In [91]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
import torch
from torch import nn, optim
import numpy

In [92]:
class ImageDataset(Dataset):
    """Dataset of sprite images stored in a single ``.npy`` array.

    Each item is ``images[index]`` passed through ``transforms.ToTensor()`` and
    then rescaled with ``x * 2 - 1``.

    NOTE(review): assumes the array holds uint8 images shaped (N, H, W, C) —
    TODO confirm. For that input ``ToTensor`` yields a float (C, H, W) tensor
    in [0, 1], so items end up in [-1, 1], the range the diffusion model uses.
    """

    # Location used by the original Kaggle notebook; override via ``path``.
    DEFAULT_PATH = '/kaggle/input/pixel-sprites/sprites_1788_16x16.npy'

    def __init__(self, path=DEFAULT_PATH):
        """Load the whole image array into memory.

        Args:
            path: path to the ``.npy`` file (defaults to the Kaggle sprite set).
        """
        self.images = numpy.load(path)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            # Affine map [0, 1] -> [-1, 1]: values are only shifted and scaled,
            # so the shape of the pixel distribution is unchanged.
            transforms.Lambda(lambda x: x * 2 - 1),
        ])

    def __len__(self):
        """Number of images (length of the array's first axis)."""
        return len(self.images)

    def __getitem__(self, index):
        """Return image ``index`` as a transformed tensor."""
        return self.transform(self.images[index])

In [156]:
class DDPM:
    """Linear-schedule DDPM: closed-form noising and ancestral sampling.

    Precomputed schedule tensors (length ``T``, created on torch's default
    device, CPU unless changed):
        betas            beta_t, linearly spaced from ``beta_start`` to ``beta_end``
        sqrt_alphas      sqrt(alpha_t), where alpha_t = 1 - beta_t
        sqrt_betas       sqrt(beta_t)
        sqrt_alpha_bars  sqrt(alpha_bar_t), where alpha_bar_t = cumprod(alpha)[t]
        sqrt_beta_bars   sqrt(1 - alpha_bar_t)
    """

    def __init__(self, T, beta_start=0.0001, beta_end=0.02):
        """Build the noise schedule for ``T`` steps (default bounds as originally hard-coded)."""
        self.betas = torch.linspace(beta_start, beta_end, T)
        alphas = 1 - self.betas
        alpha_bars = alphas.cumprod(0)
        beta_bars = 1 - alpha_bars
        self.sqrt_alpha_bars = alpha_bars.sqrt()
        self.sqrt_beta_bars = beta_bars.sqrt()
        self.sqrt_alphas = alphas.sqrt()
        self.sqrt_betas = self.betas.sqrt()
        self.T = T

    @staticmethod
    def _gather(schedule, t, like):
        """Return ``schedule[t]`` reshaped to ``(-1, 1, ..., 1)`` to broadcast against ``like``.

        ``t`` may be an int, a list, or an index tensor on any device; the index is
        moved to the schedule's device and the result to ``like``'s device.
        """
        idx = torch.as_tensor(t, device=schedule.device)
        values = schedule[idx].reshape(-1, *([1] * (like.dim() - 1)))
        return values.to(like.device)

    @staticmethod
    def _infer_device(model):
        """Device of ``model``'s first parameter; else CUDA if available; else CPU."""
        parameters = getattr(model, 'parameters', None)
        if callable(parameters):
            first = next(iter(parameters()), None)
            if first is not None:
                return first.device
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def forward_process(self, x0, z, t):
        """Sample x_t from q(x_t | x_0): ``sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * z``.

        Args:
            x0: clean batch; the first axis is the batch axis (any rank).
            z: noise with the same shape as ``x0``.
            t: per-sample timestep indices (length = batch size) or a single int.
        """
        sqrt_alpha_bar = self._gather(self.sqrt_alpha_bars, t, x0)
        sqrt_beta_bar = self._gather(self.sqrt_beta_bars, t, x0)
        return sqrt_alpha_bar * x0 + sqrt_beta_bar * z

    def backward_process(self, model, n_samples=16, shape=(3, 16, 16), device=None):
        """Generate images by ancestral sampling, starting from pure Gaussian noise.

        Args:
            model: callable ``model(x_t, t) -> eps`` predicting the noise in ``x_t``;
                ``t`` is a long tensor of shape ``(n_samples,)`` filled with the step.
            n_samples: number of images to draw (originally fixed at 16).
            shape: per-sample shape (originally fixed at ``(3, 16, 16)``).
            device: device to run the model on. ``None`` uses the device of the
                model's first parameter, falling back to CUDA/CPU.

        Returns:
            CPU tensor of shape ``(n_samples, *shape)``; values are not clamped.
        """
        if device is None:
            device = self._infer_device(model)
        # Initial noise is drawn on the CPU (as originally) and moved once.
        xt = torch.randn(n_samples, *shape).to(device)
        with torch.no_grad():
            for t in range(self.T - 1, -1, -1):
                t_batch = torch.full((n_samples,), t, dtype=torch.long, device=device)
                eps = model(xt, t_batch)
                # Posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
                eps_coef = float(self.betas[t] / self.sqrt_beta_bars[t])
                mean = (xt - eps_coef * eps) / float(self.sqrt_alphas[t])
                if t > 0:
                    # sigma_t = sqrt(beta_tilde_t) = sqrt(beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t))
                    sigma = float(self.sqrt_betas[t] * self.sqrt_beta_bars[t - 1] / self.sqrt_beta_bars[t])
                    xt = mean + sigma * torch.randn_like(xt)
                else:
                    # Final step returns the mean with no added noise.
                    xt = mean
        return xt.cpu()

In [94]:
T = 100  # placeholder horizon for this quick check only (redefined in the config cell)
batch_size = 10
ddpm = DDPM(T)

x0 = torch.zeros(batch_size, 3, 16, 16)
z = torch.randn_like(x0)
t = torch.randint(0, T, (batch_size,))
xt = ddpm.forward_process(x0, z, t)

# With x0 == 0 the signal term vanishes, so x_t must equal sqrt(1 - alpha_bar_t) * z.
assert xt.shape == x0.shape
assert torch.allclose(xt, ddpm.sqrt_beta_bars[t].reshape(-1, 1, 1, 1) * z)

In [95]:
class Positional_Encoding:
    """Precomputed sinusoidal timestep embeddings.

    Row ``t`` of the ``(T, dim)`` table is the interleaved sequence
    ``[sin(t*w_0), cos(t*w_0), sin(t*w_1), cos(t*w_1), ...]`` with
    ``w_k = 10000 ** (-2k / dim)``, truncated to exactly ``dim`` entries.
    """

    def __init__(self, T, dim, device=None):
        """Build the lookup table.

        Args:
            T: number of timesteps (table rows).
            dim: length of each embedding vector.
            device: where to store the table. ``None`` keeps the original CUDA
                placement but falls back to CPU when CUDA is unavailable.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        positions = torch.arange(T, dtype=torch.float32).unsqueeze(1)            # (T, 1)
        freqs = 10000 ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)   # (ceil(dim/2),)
        angles = positions * freqs                                               # (T, ceil(dim/2))
        # Stack sin/cos on a trailing axis and flatten -> sin, cos, sin, cos, ...
        table = torch.stack([angles.sin(), angles.cos()], dim=-1).reshape(T, 2 * freqs.numel())
        # For odd dim the pairwise build has one extra column; keep exactly dim.
        self.pes = table[:, :dim].to(device)

    def __call__(self, t):
        """Look up embeddings for timestep index/indices ``t``; result shape is ``(*t.shape, dim)``."""
        if isinstance(t, torch.Tensor):
            # Indexing requires the index to live on the table's device (CPU table + CUDA index fails).
            t = t.to(self.pes.device)
        return self.pes[t]

In [None]:
pe = Positional_Encoding(T, 128)
batch_size = 10

# Smoke test: embed a random batch of timestep indices.
t = torch.randint(0, T, (batch_size,))
print(t.shape)
t_emb = pe(t)  # own name: `t` stays the index tensor instead of silently becoming an embedding
print(t_emb.shape)
assert t_emb.shape == (batch_size, 128)

In [97]:
class UnetBlock(nn.Module):
    """Residual block: conv3x3 -> ReLU -> conv3x3, add a skip path, then ReLU.

    Spatial size is preserved (3x3 kernels, stride 1, padding 1). When the
    channel count changes, the skip path is a 1x1 convolution; otherwise the
    input is added back unchanged.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        conv_a = nn.Conv2d(in_channel, out_channel, 3, 1, 1)
        relu = nn.ReLU()
        conv_b = nn.Conv2d(out_channel, out_channel, 3, 1, 1)
        self.net = nn.Sequential(conv_a, relu, conv_b)
        self.act = nn.ReLU()
        # Identity skip when channels already match, learned 1x1 projection otherwise.
        self.residual = nn.Conv2d(in_channel, out_channel, 1) if in_channel != out_channel else None

    def forward(self, in_x):
        skip = in_x if self.residual is None else self.residual(in_x)
        return self.act(skip + self.net(in_x))

In [98]:
class Encoder_Decoder(nn.Module):
    """Two stacked UnetBlocks: in_channel -> out_channel -> out_channel at unchanged spatial size."""

    def __init__(self, in_channel, out_channel):
        super().__init__()
        first = UnetBlock(in_channel, out_channel)
        second = UnetBlock(out_channel, out_channel)
        self.net = nn.Sequential(first, second)

    def forward(self, x):
        return self.net(x)

In [99]:
class pe_Linear(nn.Module):
    """Two-layer MLP for timestep embeddings: Linear(dim, ch) -> ReLU -> Linear(ch, ch)."""

    def __init__(self, dim, ch):
        super().__init__()
        layers = [nn.Linear(dim, ch), nn.ReLU(), nn.Linear(ch, ch)]
        self.net = nn.Sequential(*layers)

    def forward(self, t):
        return self.net(t)

In [100]:
class Encoder_Block(nn.Module):
    """One U-Net down step conditioned on a timestep embedding.

    The embedding ``t`` (width ``dim``) is projected by ``pe_Linear`` to
    ``in_channel`` values per sample and added to every pixel of ``x``. The
    result goes through an Encoder_Decoder and is then downsampled by a
    2x2 stride-2 convolution.

    Returns:
        ``(downsampled, skip)``, where ``skip`` is the feature map before
        downsampling.
    """

    def __init__(self, dim, in_channel, out_channel):
        super().__init__()
        self.pe_linear = pe_Linear(dim, in_channel)
        self.encoder = Encoder_Decoder(in_channel, out_channel)
        self.conv = nn.Conv2d(out_channel, out_channel, kernel_size=2, stride=2)

    def forward(self, x, t):
        batch, channels = x.shape[0], x.shape[1]
        bias = self.pe_linear(t).reshape(batch, channels, 1, 1)
        skip = self.encoder(x + bias)
        return self.conv(skip), skip
    

In [101]:
class BottleNeck(nn.Module):
    """Lowest-resolution stage: add a projected timestep embedding, then two UnetBlocks.

    The embedding (width ``dim``) is projected by a single ``nn.Linear`` to
    ``in_channel`` values per sample. No down- or up-sampling happens here.
    """

    def __init__(self, dim, in_channel, out_channel):
        super().__init__()
        self.linear = nn.Linear(dim, in_channel)
        blocks = [UnetBlock(in_channel, out_channel), UnetBlock(out_channel, out_channel)]
        self.net = nn.Sequential(*blocks)

    def forward(self, x, t):
        batch, channels = x.shape[:2]
        return self.net(x + self.linear(t).reshape(batch, channels, 1, 1))

In [102]:
class DecoderBlock(nn.Module):
    """One U-Net up step conditioned on a timestep embedding.

    Steps:
    1. ``x`` is upsampled 2x by a stride-2 transposed conv (``in_channel -> out_channel``).
    2. It is concatenated after the skip tensor ``sc_x`` along channels.
    3. The projected embedding (``2 * out_channel`` values per sample) is added.
    4. An Encoder_Decoder fuses the result back to ``out_channel`` channels.

    ``sc_x`` must therefore carry ``out_channel`` channels at the upsampled
    resolution.
    """

    def __init__(self, dim, in_channel, out_channel):
        super().__init__()
        self.linear = nn.Linear(dim, out_channel * 2)
        self.up = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2)
        self.decoder = Encoder_Decoder(out_channel * 2, out_channel)

    def forward(self, sc_x, x, t):
        bias = self.linear(t)
        merged = torch.cat([sc_x, self.up(x)], dim=1)
        batch, channels = merged.shape[:2]
        return self.decoder(merged + bias.reshape(batch, channels, 1, 1))

In [103]:
class UNet(nn.Module):
    """Noise-prediction U-Net for 3-channel images with three down/up levels.

    Channel widths go 3 -> 16 -> 32 -> 64 -> 128 (bottleneck) -> 64 -> 32 -> 16 -> 3.
    Each encoder halves H and W and each decoder doubles them, so H and W must
    be divisible by 8 for the skip concatenations to line up. Timesteps are
    embedded with a 128-wide sinusoidal table covering ``0 .. T-1``.
    """

    def __init__(self, T):
        super().__init__()
        emb_dim = 128
        self.en1 = Encoder_Block(emb_dim, 3, 16)
        self.en2 = Encoder_Block(emb_dim, 16, 32)
        self.en3 = Encoder_Block(emb_dim, 32, 64)

        self.bottleneck = BottleNeck(emb_dim, 64, 128)

        self.de1 = DecoderBlock(emb_dim, 128, 64)
        self.de2 = DecoderBlock(emb_dim, 64, 32)
        self.de3 = DecoderBlock(emb_dim, 32, 16)

        self.conv = nn.Conv2d(16, 3, kernel_size=3, stride=1, padding=1)
        # Plain attribute (not a submodule or buffer), so model.cuda()/.to() leaves it alone.
        self.pe = Positional_Encoding(T, emb_dim)

    def forward(self, x, t):
        emb = self.pe(t)

        skips = []
        for encoder in (self.en1, self.en2, self.en3):
            x, skip = encoder(x, emb)
            skips.append(skip)

        x = self.bottleneck(x, emb)

        # Deepest skip pairs with the first decoder.
        for decoder, skip in zip((self.de1, self.de2, self.de3), reversed(skips)):
            x = decoder(skip, x, emb)

        return self.conv(x)

In [104]:
T = 500          # number of diffusion steps
batch_size = 512
epochs = 30
seed = 0
# Seed torch's RNGs (all devices) so weight init, noise draws and shuffling are reproducible.
torch.manual_seed(seed)

In [157]:
ddpm = DDPM(T=T)  # noise schedule for the real run, using T from the config cell

In [110]:
model = UNet(T).cuda()
loss_fn = nn.MSELoss()  # parameter-free: there is nothing to move to the GPU
opt = optim.Adam(params=model.parameters(), lr=0.001)

In [111]:
dataset = ImageDataset()
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [112]:
from tqdm import tqdm

In [None]:
# Sampling cells below call model.eval(); switch back so re-running this cell really trains.
model.train()
for epoch in tqdm(range(epochs)):
    epoch_loss = 0.0
    n_seen = 0
    for x0 in loader:
        n = x0.shape[0]
        t = torch.randint(0, T, [n])
        z = torch.randn_like(x0)
        # Noising runs on the CPU (t must index the CPU schedule), then move to the GPU.
        xt = ddpm.forward_process(x0, z, t).cuda()
        z = z.cuda()
        t = t.cuda()
        eps = model(xt, t)

        # Objective: predict the noise that was added.
        loss = loss_fn(eps, z)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Accumulate a sample-weighted mean instead of reporting only the last batch.
        epoch_loss += loss.item() * n
        n_seen += n
    tqdm.write(f'epoch {epoch + 1}/{epochs} loss: {epoch_loss / max(n_seen, 1):.6f}')

In [158]:
model.eval()
# Gradients are not needed for sampling.
with torch.set_grad_enabled(False):
    images = ddpm.backward_process(model)

In [159]:
images = torch.clamp(images, min=-1, max=1)  # samples can overshoot the training range [-1, 1]

In [160]:
def _to_unit_range(x):
    """Map model-space pixels from [-1, 1] back to [0, 1], as ToPILImage expects for floats."""
    return (x + 1) / 2


transform = transforms.Compose([transforms.Lambda(_to_unit_range), transforms.ToPILImage()])

In [161]:
# Canvas for a 4x4 grid of 16x16 sprites, filled white.
background = Image.new(mode='RGB', size=(4 * 16, 4 * 16), color=0xffffff)

In [162]:
for i, image in enumerate(images):
    y, x = divmod(i, 4)  # row-major placement in a 4-wide grid
    image = transform(image)
    background.paste(image, (x * 16, y * 16))

In [None]:
background.resize(size=(320, 320), resample=0)  # resample 0 = nearest neighbour: keeps pixel-art edges crisp

In [None]:
model.eval()
with torch.no_grad():
    images = ddpm.backward_process(model)

images = images.clamp(-1, 1)

transform = transforms.Compose([
    transforms.Lambda(lambda x : (x + 1) / 2),
    transforms.ToPILImage()
])

padding = 1   # gap in pixels between sprites (shows the black canvas)
n_cols = 4    # sprites per row
scale = 10    # integer upscale factor so every sprite pixel becomes an equal-size block

# Grid geometry follows `padding` and the sampled tensor instead of hard-coded 16/+3/+1.
n_images, _, tile_h, tile_w = images.shape
n_rows = (n_images + n_cols - 1) // n_cols
grid_w = n_cols * tile_w + (n_cols - 1) * padding
grid_h = n_rows * tile_h + (n_rows - 1) * padding
background = Image.new('RGB', (grid_w, grid_h))

for i, image in enumerate(images):
    image = transform(image)
    x = i % n_cols
    y = i // n_cols
    background.paste(image, ((tile_w + padding) * x, (tile_h + padding) * y))

# resample 0 = nearest neighbour.
background.resize((grid_w * scale, grid_h * scale), 0)