In [1]:
import torch
import torch.nn as nn

import torchvision
import torchvision.transforms as transforms

import torch.nn.functional as F

from torchvision.utils import save_image, make_grid

from torchvision.transforms.functional import center_crop

import matplotlib.pyplot as plt

from tqdm import tqdm

In [2]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [3]:
# Training configuration.
NUM_EPOCHS = 10
BATCH_SIZE = 128

# Seed PyTorch's RNGs (all devices) so weight init, DataLoader shuffling, and the
# timestep/noise sampling inside the diffusion model are reproducible on re-run.
SEED = 0
torch.manual_seed(SEED);

In [4]:
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

def _load_mnist_split(train: bool):
    """Return the MNIST train (``train=True``) or test split stored under ./data.

    Downloads the files if missing; images are converted with ``transforms.ToTensor()``.
    """
    return MNIST(
        root="data",
        train=train,
        download=True,
        transform=transforms.ToTensor(),
    )


training_data = _load_mnist_split(train=True)
test_data = _load_mnist_split(train=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 89818855.79it/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 5674068.75it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 25694064.18it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 19559064.44it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw






In [5]:
# Summarize the training split (number of datapoints, root, split, transform).
print(str(training_data))

Dataset MNIST
    Number of datapoints: 60000
    Root location: data
    Split: Train
    StandardTransform
Transform: ToTensor()


In [6]:
# Reshuffle the training split every epoch (standard for SGD-style training);
# the test split keeps a fixed order.
train_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE)

# Peek at a single test batch to confirm tensor shapes and the label dtype.
X, y = next(iter(test_dataloader))
print(f"Shape of X [N, C, H, W]: {X.shape}")
print(f"Shape of y: {y.shape} {y.dtype}")

Shape of X [N, C, H, W]: torch.Size([128, 1, 28, 28])
Shape of y: torch.Size([128]) torch.int64


In [7]:
class ConvolutionBlock(nn.Module):
    """Two stacked (3x3 conv -> ReLU) stages that preserve the spatial size.

    Args:
        in_shape: number of input channels.
        out_shape: number of output channels produced by both convolutions.
    """

    def __init__(self, in_shape: int, out_shape: int):
        super().__init__()

        def conv_relu(channels_in: int, channels_out: int) -> nn.Sequential:
            # padding=1 with a 3x3 kernel keeps H and W unchanged.
            return nn.Sequential(
                nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1),
                nn.ReLU(),
            )

        self.conv1 = conv_relu(in_shape, out_shape)
        self.conv2 = conv_relu(out_shape, out_shape)

    def forward(self, x: torch.Tensor):
        return self.conv2(self.conv1(x))

class U_Net(nn.Module):
    """Small U-Net mapping (N, in_shape, H, W) -> (N, out_shape, H, W).

    Encoder widths are 8 -> 16 -> 32 with a 2x max-pool between stages. The decoder
    mirrors them with 2x transposed convolutions plus skip connections, and ends in
    a linear 1x1 convolution so the output can take any sign.

    NOTE(review): forward takes only ``x`` -- there is no timestep input, so when used
    as a diffusion noise predictor it cannot tell which noise level it is denoising.

    Args:
        in_shape: number of input channels.
        out_shape: number of output channels.
    """

    def __init__(self, in_shape: int, out_shape: int):
        super().__init__()

        two_powers = [2 ** i for i in range(3, 6)]  # [8, 16, 32]; kept small to limit memory

        conv_list = [ConvolutionBlock(two_powers[i], two_powers[i + 1]) for i in range(len(two_powers) - 1)]
        conv_list.insert(0, ConvolutionBlock(in_shape, two_powers[0]))

        deconv_list = [ConvolutionBlock(two_powers[i], two_powers[i - 1]) for i in range(len(two_powers) - 1, 0, -1)]
        # The last decoder block keeps the base width; `head` maps it to `out_shape`.
        # A ConvolutionBlock ends in ReLU, so using it as the output layer (as before)
        # forced non-negative predictions, which cannot match signed targets such as noise.
        deconv_list.append(ConvolutionBlock(two_powers[0], two_powers[0]))

        self.conv = nn.ModuleList(conv_list)
        self.conv_pool = nn.ModuleList([nn.MaxPool2d(2) for _ in range(len(two_powers) - 1)])

        self.deconv = nn.ModuleList(deconv_list)
        self.deconv_pool = nn.ModuleList([
            nn.ConvTranspose2d(two_powers[i], two_powers[i - 1], kernel_size=2, stride=2)
            for i in range(len(two_powers) - 1, 0, -1)
        ])

        # Linear output projection (no activation).
        self.head = nn.Conv2d(two_powers[0], out_shape, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skips = []

        # Encoder: zip stops at the shorter list, so the bottleneck block conv[-1] is excluded here.
        for block, pool in zip(self.conv, self.conv_pool):
            x = block(x)
            skips.append(x)
            x = pool(x)

        x = self.conv[-1](x)

        # Decoder: upsample, align with the matching skip, concatenate channels, convolve.
        for up, block, skip in zip(self.deconv_pool, self.deconv, reversed(skips)):
            x = up(x)
            # Max-pooling floors odd sizes, so the upsampled map can be smaller than the skip;
            # pad it (left/right, top/bottom) so torch.cat sees matching H and W.
            dh = skip.shape[-2] - x.shape[-2]
            dw = skip.shape[-1] - x.shape[-1]
            if dh or dw:
                x = F.pad(x, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])
            x = block(torch.cat([x, skip], dim=1))

        x = self.deconv[-1](x)

        return self.head(x)

In [8]:
class DDPM(nn.Module):
    """Denoising diffusion probabilistic model wrapper around a noise-prediction network.

    Training (``forward`` / ``encoder``):
      1. Sample a timestep t in [0, T) and noise eps ~ N(0, I) per example.
      2. Build x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.
      3. Return eps_model(x_t).

    ``loss`` compares that prediction with the *same* eps used in step 2.

    Sampling (``decoder``): run the reverse chain from t = T-1 down to t = 0.

    Args:
        eps_model: network whose output has the same shape as its input.
        T: number of diffusion steps.
        device: device on which the noise schedule is created.
    """

    def __init__(self, eps_model: nn.Module, T: int, device: torch.cuda.device):
        super().__init__()

        self.eps_model = eps_model

        # Linear beta schedule (1e-4 .. 0.02). Registered as non-persistent buffers so
        # `.to(...)` moves them with the parameters without changing the state_dict.
        beta = torch.linspace(0.0001, 0.02, T, device=device)
        alpha = 1. - beta
        self.register_buffer("beta", beta, persistent=False)
        self.register_buffer("alpha", alpha, persistent=False)
        self.register_buffer("alpha_bar", torch.cumprod(alpha, dim=0), persistent=False)
        # Reverse-process variance choice sigma_t^2 = beta_t.
        self.register_buffer("sigma_squared", beta.clone(), persistent=False)

        self.T = T

        self.encoder = self._encoder
        self.decoder = self._decoder

        self.device = device

        # Noise drawn by the most recent `_encoder` call; `loss` regresses against it.
        self._last_noise = None

        self.optimizer = torch.optim.Adam(self.parameters(), 1e-3)

    # https://nn.labml.ai/diffusion/ddpm/utils.html
    def _get_teeth(self, consts: torch.Tensor, t: torch.Tensor):  # get t-th const
        """Return consts[t] for each entry of the LongTensor ``t``.

        The result is reshaped to (len(t), 1, 1, 1) so it broadcasts per sample over
        a 4-D (N, C, H, W) batch.
        """
        const = consts.gather(-1, t)
        return const.reshape(-1, 1, 1, 1)

    def reparam(self, mean, var, x_0, epsilon=None):
        """Return mean + sqrt(var) * epsilon.

        ``epsilon`` defaults to fresh standard-normal noise shaped like ``x_0``.
        """
        if epsilon is None:
            epsilon = torch.randn_like(x_0)
        return mean + (var ** 0.5) * epsilon

    def _encoder(self, x_0):
        """Noise ``x_0`` at random timesteps and return the model's noise prediction.

        The noise actually used is stored in ``self._last_noise`` for ``loss``.
        """
        batch_size = x_0.shape[0]

        t = torch.randint(0, self.T, (batch_size,), device=self.alpha_bar.device, dtype=torch.long)

        alpha_bar_t = self._get_teeth(self.alpha_bar, t)
        mean = alpha_bar_t ** 0.5 * x_0
        var = 1 - alpha_bar_t

        # Keep the sampled noise: it is the regression target for eps_model.
        noise = torch.randn_like(x_0)
        x_t = self.reparam(mean, var, x_0, noise)
        self._last_noise = noise

        # NOTE(review): eps_model is not given t; standard DDPM conditions the
        # network on the timestep (e.g. via a timestep embedding).
        eps_theta = self.eps_model(x_t)

        return eps_theta

    @torch.no_grad()
    def _decoder(self, x_T):
        """Run the reverse diffusion chain from ``x_T`` and return the final sample.

        Implements x_{t-1} = alpha_t^{-1/2} * (x_t - (1 - alpha_t) / sqrt(1 - alpha_bar_t)
        * eps_theta) + sigma_t * z, where z = 0 at the final step.
        """
        x_t = x_T
        batch_size = x_T.shape[0]
        for step in reversed(range(self.T)):
            # `gather` needs a LongTensor index, one entry per sample.
            t = torch.full((batch_size,), step, device=self.alpha.device, dtype=torch.long)

            if step > 0:
                z = torch.randn_like(x_T)
            else:
                z = torch.zeros_like(x_T)  # no noise is added on the last step

            alpha = self._get_teeth(self.alpha, t)
            alpha_bar = self._get_teeth(self.alpha_bar, t)

            one_sqrt_ath = alpha ** (-0.5)
            one_a = 1. - alpha
            one_sqrt_abth = (1 - alpha_bar) ** (-0.5)

            eps_theta = self.eps_model(x_t)

            sigma = self._get_teeth(self.sigma_squared, t) ** 0.5

            x_t = one_sqrt_ath * (x_t - one_a * one_sqrt_abth * eps_theta) + sigma * z
        return x_t

    def loss(self, eps_theta, noise=None):
        """Mean-squared error between predicted and true noise.

        Args:
            eps_theta: prediction returned by ``forward``/``encoder``.
            noise: true noise; defaults to the noise drawn by the latest ``encoder`` call.

        Raises:
            RuntimeError: if no matching noise is available.
        """
        if noise is None:
            noise = self._last_noise
        if noise is None or noise.shape != eps_theta.shape:
            raise RuntimeError(
                "loss() needs the noise used to build x_t: call forward()/encoder() first "
                "or pass noise=..."
            )
        return F.mse_loss(eps_theta, noise)

    def forward(self, x):
        """Training forward pass: return the noise prediction for a noised ``x``."""
        eps = self.encoder(x)
        return eps

In [9]:
def train(dataloader, model, device):
    """Run one epoch of noise-prediction training.

    Args:
        dataloader: yields ``(images, labels)`` batches; labels are ignored.
        model: module whose forward returns a prediction and which exposes
            ``loss(prediction)`` and an ``optimizer`` attribute (as ``DDPM`` does).
        device: device the image batches are moved to.
    """
    model.train()
    size = len(dataloader.dataset)
    for batch, (X, _) in enumerate(dataloader):
        # Keep the (N, C, H, W) image layout. The conv U-Net expects 4-D input, and
        # DDPM reshapes its per-sample coefficients to (N, 1, 1, 1). Flattening X to
        # (N, 784) made that broadcast produce an (N, 1, N, 784) tensor, inflating
        # every activation by a factor of N.
        X = X.to(device)

        eps_theta = model(X)
        loss = model.loss(eps_theta)

        model.optimizer.zero_grad()
        loss.backward()
        model.optimizer.step()

        if batch % 100 == 0:
            # DDPM.loss uses F.mse_loss with mean reduction, so it is already a
            # per-element average; dividing by the batch size again was misleading.
            loss_value, current = loss.item(), batch * len(X)
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{size:>5d}]")

In [10]:
# Build the diffusion model (1-channel in/out U-Net, 1000 diffusion steps) on `device`.
model = DDPM(U_Net(1, 1), 1000, device).to(device)

print(model)

# `epoch` rather than `t`, so the loop index is not confused with a diffusion timestep.
for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch + 1}\n-------------------------------")
    train(train_dataloader, model, device)
print("Done!")

DDPM(
  (eps_model): U_Net(
    (conv): ModuleList(
      (0): ConvolutionBlock(
        (conv1): Sequential(
          (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU()
        )
        (conv2): Sequential(
          (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU()
        )
      )
      (1): ConvolutionBlock(
        (conv1): Sequential(
          (0): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU()
        )
        (conv2): Sequential(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU()
        )
      )
      (2): ConvolutionBlock(
        (conv1): Sequential(
          (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU()
        )
        (conv2): Sequential(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU()
        

KeyboardInterrupt: ignored

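# GPU 램 문제로 충분한 용량의 파라미터를 할당하지 못하다보니 학습이 제대로 안됨
# (English) Because of GPU memory limits, a model of sufficient capacity could not be allocated, so training did not work properly.
# Note: the memory blow-up was caused by flattening each batch with `X.view(-1, 28*28)` in `train`.
# The (N, 1, 1, 1) noise-schedule coefficients then broadcast against the (N, 784) batch into an (N, 1, N, 784) tensor.
# As a result, every U-Net activation was N times larger than for the intended (N, 1, 28, 28) images.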