# Training

## Setup

In [8]:
import torch
import torchvision
from torchvision import datasets, transforms
import pandas as pd
from torch.utils.data import random_split
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
#from torch.autograd import Variable

In [9]:
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
#torch.backends.cudnn.enabled = False
val_size = 5000
test_size = 5000
batch_size = 16
num_workers = 4

transform = transforms.Compose([transforms.ToTensor()])

# Downloading MNIST again :) Training (60k) and test(5k) + val(5k) split
train_loader = torch.utils.data.DataLoader(datasets.MNIST('./mnist_data',
                                            download=True,
                                            train=True,
                                            transform=transform),
                                            batch_size=batch_size,
                                            shuffle=True, num_workers=num_workers)

test_dataset = datasets.MNIST('./mnist_data',
                               download=True,
                               train=False,
                               transform=transform)

val_dataset, test_dataset = random_split(test_dataset, [val_size, test_size])

# Test set to compare with DDPM paper
test_loader = torch.utils.data.DataLoader(test_dataset,
                                            batch_size=batch_size,
                                            shuffle=False, num_workers=num_workers)

# Validation set so we can keep track of approximated FID score while training
validation_loader = torch.utils.data.DataLoader(val_dataset,
                                            batch_size=16,
                                            shuffle=False, num_workers=num_workers)
    

In [68]:
# Sets up alpha_bar for training and test so alpha_bar_t = alpha_bar[t]
T = 1000
beta_start, beta_end = [1e-4, 2e-02]
beta = torch.linspace(beta_start, beta_end, T)
alpha = 1-beta
alpha_bar = alpha.clone()
for e in range(T-1):
    alpha_bar[e+1] *= alpha_bar[e]

alpha = alpha.view((1000, 1, 1, 1)).to(device)
beta = beta.view((1000, 1, 1, 1)).to(device)
alpha_bar = alpha_bar.view((1000, 1, 1, 1)).to(device)

## Model

In [11]:
class UNET(torch.nn.Module):
    def __init__(self):
        super(UNET, self).__init__()
        channels = [32, 64, 128, 256]
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(2, channels[0], kernel_size=3, padding=1),  # (batchsize, 32, 28, 28)
                nn.ReLU(),
            ),
            nn.Sequential(
                nn.MaxPool2d(2),  # (batchsize, 32, 14, 14)
                nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1),  # (batchsize, 64, 14, 14)
                nn.ReLU(),
            ),
            nn.Sequential(
                nn.MaxPool2d(2),  # (batchsize, 64, 7, 7)
                nn.Conv2d(channels[1], channels[2], kernel_size=3, padding=1),  # (batchsize, 128, 7, 7)
                nn.ReLU(),
            ),
            nn.Sequential(
                nn.MaxPool2d(2, padding=1),  # (batchsize, 128, 4, 4)
                nn.Conv2d(channels[2], channels[3], kernel_size=3, padding=1),  # (batchsize, 256, 4, 4)
                nn.ReLU(),
            )
        ])

        self.tconvs = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose2d(channels[3], channels[2], kernel_size=3, 
                                   stride=2, padding=1, output_padding=0),   # (batchsize, 128, 7, 7)
                nn.ReLU()
            ),
            nn.Sequential(
                nn.ConvTranspose2d(channels[2]*2, channels[1], kernel_size=3,
                                   stride=2, padding=1, output_padding=1),   # (batchsize, 64, 14, 14)
                nn.ReLU()
            ),
            nn.Sequential(
                nn.ConvTranspose2d(channels[1]*2, channels[0], kernel_size=3, 
                                   stride=2, padding=1, output_padding=1),   # (batchsize, 32, 28, 28)
                nn.ReLU()
            ),
            nn.Sequential(
                nn.Conv2d(channels[0]*2,channels[0],kernel_size=3,padding=1),  # (batchsize, 32, 28, 28)
                nn.ReLU(),
                nn.Conv2d(channels[0],1,kernel_size=1) # (batchsize, 1, 28, 28)
            )      
        ])

    def forward(self, x, t):
        x_trans = torch.cat((x, t), dim=-3)
        signal = x_trans
        signals = []

        for i, conv in enumerate(self.convs):
            # print(f"conv {i}")
            signal = conv(signal)
            # print(signal.shape)
            if i < len(conv):
                signals.append(signal)
        
        for i, tconv in enumerate(self.tconvs):
            # print(f"tconv {i}")
            # print(f"signal shape: {signal.shape}")
            if i == 0:
                signal = tconv(signal)
            else:
                signal = torch.cat((signal, signals[-i]), dim=-3)
                signal = tconv(signal)
        return signal

## Training loop

In [12]:
#from UNET import UNET
epochs = 10
model = UNET()
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.MSELoss()
running_loss = 0

In [13]:
for epoch in range(epochs):
    for e, data in enumerate(train_loader):
        x0, _ = data
        x0 = x0.to(device)
        t = torch.randint(1, T+1, (batch_size,)).to(device)
        eps = torch.randn(batch_size, 1, 28, 28).to(device)
        # print(eps.shape)
        # print(x0.shape)
        loss = criterion(eps, model(torch.sqrt(alpha_bar[t-1]) * x0 + 
                                    torch.sqrt(1 - alpha_bar[t-1]) * eps, t.view(batch_size, 1, 1, 1).expand(batch_size, 1, 28, 28)))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()

        if e % 100 == 99:
            print(f'{epoch, e+1}, loss: {running_loss:.3f}')
            running_loss = 0.0
        
    if epoch % 5 == 4:
        torch.save(model.state_dict(), f"DDPM_{epoch}.pth")

(0, 100), loss: 876.564
(0, 200), loss: 92.737
(0, 300), loss: 76.344
(0, 400), loss: 63.235
(0, 500), loss: 45.309
(0, 600), loss: 33.599
(0, 700), loss: 40.094
(0, 800), loss: 22.579
(0, 900), loss: 18.521
(0, 1000), loss: 53.051
(0, 1100), loss: 18.092
(0, 1200), loss: 17.168
(0, 1300), loss: 14.462
(0, 1400), loss: 14.974
(0, 1500), loss: 13.980
(0, 1600), loss: 13.768
(0, 1700), loss: 59.343
(0, 1800), loss: 13.305
(0, 1900), loss: 9.422
(0, 2000), loss: 8.914
(0, 2100), loss: 7.550
(0, 2200), loss: 6.786
(0, 2300), loss: 6.842
(0, 2400), loss: 28.081
(0, 2500), loss: 8.800
(0, 2600), loss: 6.271
(0, 2700), loss: 6.877
(0, 2800), loss: 5.492
(0, 2900), loss: 5.122
(0, 3000), loss: 5.252
(0, 3100), loss: 5.006
(0, 3200), loss: 4.867
(0, 3300), loss: 4.591
(0, 3400), loss: 4.713
(0, 3500), loss: 4.623
(0, 3600), loss: 6.994
(0, 3700), loss: 5.285
(1, 100), loss: 7.183
(1, 200), loss: 12.929
(1, 300), loss: 4.558
(1, 400), loss: 3.988
(1, 500), loss: 4.148
(1, 600), loss: 4.135
(1, 7

KeyboardInterrupt: 

# Sampling

In [69]:
import matplotlib.pyplot as plt

In [71]:
xt = [torch.randn(batch_size, 1, 28, 28).to(device)]

for t in torch.arange(T, 0, -1):
    t = t.expand(16).to(device)
    z = torch.randn(batch_size, 1, 28, 28).to(device) if t[0] > 1 else torch.zeros(batch_size, 1, 28, 28).to(device)
    # z = 16, 1, 28, 28
    print(xt[0].shape)
    print(alpha[t - 1].shape)
    xt_minus_one = 1 / torch.sqrt(alpha[t - 1]) * (xt[0] - (1 - alpha[t - 1])/(torch.sqrt(1 - alpha_bar[t - 1])) * 
                                                   model(xt[0], t.view(batch_size, 1, 1, 1).expand(batch_size, 1, 28, 28))) + torch.sqrt(beta[t-1]) * z
    xt.insert(0, xt_minus_one)


    

torch.Size([16, 1, 28, 28])
torch.Size([16, 1, 1, 1])
torch.Size([16, 1, 28, 28])
torch.Size([16, 1, 1, 1])
torch.Size([16, 1, 28, 28])
torch.Size([16, 1, 1, 1])
torch.Size([16, 1, 28, 28])
torch.Size([16, 1, 1, 1])
torch.Size([16, 1, 28, 28])
torch.Size([16, 1, 1, 1])
torch.Size([16, 1, 28, 28])
torch.Size([16, 1, 1, 1])
torch.Size([16, 1, 28, 28])
torch.Size([16, 1, 1, 1])
torch.Size([16, 1, 28, 28])
torch.Size([16, 1, 1, 1])
torch.Size([16, 1, 28, 28])
torch.Size([16, 1, 1, 1])
torch.Size([16, 1, 28, 28])
torch.Size([16, 1, 1, 1])
torch.Size([16, 1, 28, 28])
torch.Size([16, 1, 1, 1])
torch.Size([16, 1, 28, 28])
torch.Size([16, 1, 1, 1])
torch.Size([16, 1, 28, 28])
torch.Size([16, 1, 1, 1])
torch.Size([16, 1, 28, 28])
torch.Size([16, 1, 1, 1])
torch.Size([16, 1, 28, 28])
torch.Size([16, 1, 1, 1])
torch.Size([16, 1, 28, 28])
torch.Size([16, 1, 1, 1])
torch.Size([16, 1, 28, 28])
torch.Size([16, 1, 1, 1])
torch.Size([16, 1, 28, 28])
torch.Size([16, 1, 1, 1])
torch.Size([16, 1, 28, 28])


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 6.00 GiB total capacity; 4.10 GiB already allocated; 3.12 MiB free; 4.15 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF