In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.init as init

%matplotlib inline
import matplotlib
import numpy as np
import matplotlib.pyplot as plt

from torch.nn.parameter import Parameter
from torchvision.datasets import MNIST
from torchvision import transforms

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

**Dataset**

In [None]:
# Download (if needed) and load both MNIST splits from the same local folder,
# converting each image to a tensor on access.
mnist_root = "./temp/"
to_tensor = transforms.ToTensor()
mnist_trainset = MNIST(mnist_root, train=True, download=True, transform=to_tensor)
mnist_testset = MNIST(mnist_root, train=False, download=True, transform=to_tensor)

In [None]:
# Flatten every image of both splits to a 784-vector of floats. The train and
# test splits are concatenated, so the model is fitted on all available images.
flat_splits = [split.data.view(-1, 784).float() for split in (mnist_trainset, mnist_testset)]
x_train_set, x_test = flat_splits
x_train = torch.cat([x_train_set, x_test], dim=0)

print("Information on dataset")
print("x_train", x_train.shape)

# Scale the raw pixel values by 1/255 (in place).
x_train.div_(255)

**Network**

In [None]:
class ScoreNetwork0(torch.nn.Module):
    """U-Net style convolutional network conditioned on a time step.

    The flattened 28x28 image and the time step (broadcast to a constant 28x28
    plane) are stacked into a 2-channel map. Five encoder stages halve the
    resolution 28 -> 14 -> 7 -> 4 -> 2, and five decoder stages restore it,
    each decoder stage after the first receiving the matching encoder output
    through a skip connection. The output has the same shape as the input
    image vector.
    """

    def __init__(self):
        super().__init__()
        chs = [32, 64, 128, 256, 256]
        self._convs = torch.nn.ModuleList([
            torch.nn.Sequential(
                # input: image plane + time plane
                torch.nn.Conv2d(2, chs[0], kernel_size=3, padding=1),  # (batch, chs[0], 28, 28)
                torch.nn.LogSigmoid(),
            ),
            torch.nn.Sequential(
                torch.nn.MaxPool2d(kernel_size=2, stride=2),  # (batch, chs[0], 14, 14)
                torch.nn.Conv2d(chs[0], chs[1], kernel_size=3, padding=1),  # (batch, chs[1], 14, 14)
                torch.nn.LogSigmoid(),
            ),
            torch.nn.Sequential(
                torch.nn.MaxPool2d(kernel_size=2, stride=2),  # (batch, chs[1], 7, 7)
                torch.nn.Conv2d(chs[1], chs[2], kernel_size=3, padding=1),  # (batch, chs[2], 7, 7)
                torch.nn.LogSigmoid(),
            ),
            torch.nn.Sequential(
                # padding=1 turns the odd 7x7 map into 4x4 instead of 3x3
                torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=1),  # (batch, chs[2], 4, 4)
                torch.nn.Conv2d(chs[2], chs[3], kernel_size=3, padding=1),  # (batch, chs[3], 4, 4)
                torch.nn.LogSigmoid(),
            ),
            torch.nn.Sequential(
                torch.nn.MaxPool2d(kernel_size=2, stride=2),  # (batch, chs[3], 2, 2)
                torch.nn.Conv2d(chs[3], chs[4], kernel_size=3, padding=1),  # (batch, chs[4], 2, 2)
                torch.nn.LogSigmoid(),
            ),
        ])
        self._tconvs = torch.nn.ModuleList([
            torch.nn.Sequential(
                # input is the output of convs[4]
                torch.nn.ConvTranspose2d(chs[4], chs[3], kernel_size=3, stride=2, padding=1, output_padding=1),  # (batch, chs[3], 4, 4)
                torch.nn.LogSigmoid(),
            ),
            torch.nn.Sequential(
                # input is the output above concatenated with the output of convs[3];
                # output_padding=0 yields the odd size 7 rather than 8
                torch.nn.ConvTranspose2d(chs[3] * 2, chs[2], kernel_size=3, stride=2, padding=1, output_padding=0),  # (batch, chs[2], 7, 7)
                torch.nn.LogSigmoid(),
            ),
            torch.nn.Sequential(
                # input is the output above concatenated with the output of convs[2]
                torch.nn.ConvTranspose2d(chs[2] * 2, chs[1], kernel_size=3, stride=2, padding=1, output_padding=1),  # (batch, chs[1], 14, 14)
                torch.nn.LogSigmoid(),
            ),
            torch.nn.Sequential(
                # input is the output above concatenated with the output of convs[1]
                torch.nn.ConvTranspose2d(chs[1] * 2, chs[0], kernel_size=3, stride=2, padding=1, output_padding=1),  # (batch, chs[0], 28, 28)
                torch.nn.LogSigmoid(),
            ),
            torch.nn.Sequential(
                # input is the output above concatenated with the output of convs[0]
                torch.nn.Conv2d(chs[0] * 2, chs[0], kernel_size=3, padding=1),  # (batch, chs[0], 28, 28)
                torch.nn.LogSigmoid(),
                torch.nn.Conv2d(chs[0], 1, kernel_size=3, padding=1),  # (batch, 1, 28, 28)
            ),
        ])

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Run the network on flattened images ``x`` at time steps ``t``.

        Parameters:
        x (torch.Tensor): Flattened images of shape (..., 784).
        t (torch.Tensor): Time steps of shape (..., 1); integer or floating
            dtype. Its leading dimensions must equal, or be broadcastable to,
            those of ``x`` (e.g. a single ``(1, 1)`` step shared by a batch).

        Returns:
        torch.Tensor: Tensor of shape (..., 784), same leading dims as ``x``.
        """
        x2 = torch.reshape(x, (*x.shape[:-1], 1, 28, 28))  # (..., 1, 28, 28)
        # Cast explicitly so integer time steps can be stacked with the float
        # image, and expand against x's leading dims so a shared step broadcasts.
        tt = t[..., None, None].to(device=x2.device, dtype=x2.dtype)
        tt = tt.expand(*x2.shape[:-3], 1, 28, 28)  # (..., 1, 28, 28)
        signal = torch.cat((x2, tt), dim=-3)  # (..., 2, 28, 28)

        # Encoder: keep every stage output except the deepest one for the skips.
        skips = []
        deepest = len(self._convs) - 1
        for i, conv in enumerate(self._convs):
            signal = conv(signal)
            if i < deepest:
                skips.append(signal)

        # Decoder: stage i (i >= 1) consumes the i-th most recent skip.
        for i, tconv in enumerate(self._tconvs):
            if i > 0:
                signal = torch.cat((signal, skips[-i]), dim=-3)
            signal = tconv(signal)

        return torch.reshape(signal, (*signal.shape[:-3], -1))  # (..., 1 * 28 * 28)

# Build the score network, then place its parameters on the selected device.
score_network = ScoreNetwork0()
score_network = score_network.to(device)

**Define loss and optimizer**

In [None]:
# AdamW over every parameter of the score network.
opt = optim.AdamW(params=score_network.parameters(), lr=2e-4)

# Mean-squared error between the true and the predicted noise.
loss_MSE = nn.MSELoss()

**Define parameters**

In [None]:
# Number of noising steps in the forward process.
T = 1000

# Linear variance schedule (T values from 1e-4 to 0.02) and the per-step alphas.
betas = torch.linspace(0.0001, 0.02, T)
alphas = 1 - betas


def alpha_calculate_cumulative(idx):
    """Return the product of ``alphas[0]`` .. ``alphas[idx]`` as a 0-d tensor.

    The slice clips silently, so any ``idx >= T - 1`` gives the product of
    all alphas.
    """
    return torch.prod(alphas[:idx + 1])


# Number of passes over the data.
max_epochs = 100
# Convergence tolerance; in the visible code only a commented-out check refers to it.
convergence_threshold = 1e-6

# Samples per optimisation step.
batch_size = 128
# Whole batches only: the floor division drops a trailing partial batch.
num_batches = x_train.shape[0] // batch_size
print("num_batches", num_batches)

**Training**

In [None]:
def training(losses):
    """
    Train ``score_network`` to predict the noise added by the forward process.

    For every batch a time step is drawn per sample, the clean images are
    noised in closed form with the cumulative alpha of that step, and the
    network is fitted (MSE) to recover the noise that was added.

    Uses the notebook globals ``x_train``, ``alphas``, ``T``, ``max_epochs``,
    ``num_batches``, ``batch_size``, ``score_network``, ``loss_MSE``, ``opt``
    and ``device``.

    Parameters:
    losses (list): Mutated in place; the mean batch loss of each epoch is appended.
    """
    # alphas_cumprod[k] == prod(alphas[:k + 1]); computed once instead of one
    # full product per sample per batch.
    alphas_cumprod = torch.cumprod(alphas, dim=0).to(device)

    score_network.train()
    for epoch in range(max_epochs):
        epoch_losses = []
        for j in range(num_batches):
            # Plain slice of the training tensor; no copy or re-wrapping needed.
            x_0 = x_train[j * batch_size:(j + 1) * batch_size].to(device)
            x_0 = x_0.view(batch_size, 28 * 28)

            # One step index per sample, drawn from the valid 0-based range
            # [0, T - 1] (randint's upper bound is exclusive). Drawing from
            # [1, T] would index one past the end of the schedule for t == T.
            t = torch.randint(0, T, (batch_size, 1), dtype=torch.int64, device=x_0.device)

            epsilon = torch.randn_like(x_0)  # N(0, I)

            # Create the noisy observation in closed form.
            alpha_ = alphas_cumprod[t]  # (batch_size, 1), broadcasts over pixels
            x_noisy = torch.sqrt(alpha_) * x_0 + torch.sqrt(1 - alpha_) * epsilon

            # One optimisation step.
            epsilon_hat = score_network(x_noisy, t)
            batch_loss = loss_MSE(epsilon_hat, epsilon)
            opt.zero_grad()
            batch_loss.backward()
            opt.step()
            epoch_losses.append(batch_loss.item())
        losses.append(np.mean(epoch_losses))

        print(f"Epoch {epoch+1} | Loss {losses[-1]}")


In [None]:
def plot_loss(losses):
    """
    Draw the recorded training loss as a line chart and display it.

    Parameters:
    losses (list or array): Loss values in order, one per epoch or iteration.
    """
    _, ax = plt.subplots(figsize=(10, 5))
    ax.plot(losses, label='Training Loss', color='blue', linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training Loss over Time')
    ax.legend()
    ax.grid(True)
    plt.show()

# Run the training loop, collecting one mean loss per epoch, then chart the
# curve and report the last epoch's loss.
losses = list()
training(losses)
plot_loss(losses)
final_epoch_loss = losses[-1]
print(final_epoch_loss)

**Sampling**

In [None]:
def sampling(num_samples=1):
    """
    Generate images by running the learned reverse process from pure noise.

    Starts from Gaussian noise and applies the denoising update for the step
    indices ``T - 1`` down to ``1`` (index 0 is not visited). Fresh noise is
    added after every update except the last one. The variance of each
    reverse step is set to ``betas[t]``.

    Uses the notebook globals ``score_network``, ``alphas``, ``betas``, ``T``
    and ``device``. Leaves ``score_network`` in eval mode.

    Parameters:
    num_samples (int): Number of images to generate in one batched pass.
        Defaults to 1.

    Returns:
    torch.Tensor: Tensor of shape (num_samples, 784) on ``device``.

    Raises:
    ValueError: If ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")

    # alphas_cumprod[k] == prod(alphas[:k + 1]); computed once for all steps.
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    score_network.eval()
    with torch.no_grad():
        # Start from pure noise.
        x_t = torch.randn(num_samples, 784, device=device)

        for t in range(T - 1, 0, -1):
            # No noise is added on the final update.
            z = torch.randn_like(x_t) if t > 1 else torch.zeros_like(x_t)

            # Schedule values as plain Python floats, so the update mixes
            # only floats and tensors living on `device`.
            alpha_t = alphas[t].item()
            beta_t = betas[t].item()
            alpha_cum_t = alphas_cumprod[t].item()

            # One time-step entry per sample in the batch.
            time = torch.full((num_samples, 1), t, dtype=torch.int64, device=x_t.device)
            epsilon_hat = score_network(x_t, time)

            # Posterior mean from the predicted noise, plus sqrt(beta_t) * z.
            noise_scale = (1 - alpha_t) / (1 - alpha_cum_t) ** 0.5
            mean = (x_t - noise_scale * epsilon_hat) / alpha_t ** 0.5
            x_t = mean + beta_t ** 0.5 * z

    return x_t


In [None]:
# Show a grid of freshly generated samples
idx, dim, classes = 0, 28, 5

# Blank canvas that holds classes x classes tiles of dim x dim pixels
canvas = np.zeros((dim * classes, classes * dim))

# Fill the canvas tile by tile, one sampling() call per tile
for i in range(classes):
    rows = slice(i * dim, (i + 1) * dim)
    for j in range(classes):
        cols = slice(j * dim, (j + 1) * dim)
        # Bring the sample to the CPU and reshape it into a dim x dim NumPy image
        tile = sampling().cpu().detach().reshape((dim, dim)).numpy()
        canvas[rows, cols] = tile

    print(str(i) + ' sample')

# Render the whole canvas as one gray-scale image
plt.figure(figsize=(4, 4))
plt.axis('off')
plt.imshow(canvas, cmap='gray')
plt.title('MNIST handwritten digits')
plt.show()


**Evaluation**

In [None]:
from torchvision.models import inception_v3
from scipy.linalg import sqrtm

def _inception_pool_features(model, images, batch_size):
    """Return the model outputs for ``images`` as one NumPy array, computed in mini-batches."""
    chunks = []
    with torch.no_grad():
        for start in range(0, images.shape[0], batch_size):
            batch = images[start:start + batch_size].to(device)
            # Upsample to the 299x299 Inception input size, then copy the
            # single gray channel into three channels.
            batch = F.interpolate(batch, size=(299, 299), mode='bilinear')
            batch = batch.repeat(1, 3, 1, 1)
            chunks.append(model(batch).cpu())
    return torch.cat(chunks, dim=0).numpy()


def fid_score(sampled_data, real_data, batch_size=64):
    """
    Compute the Frechet Inception Distance between generated and real images.

    Both image sets are embedded with a pretrained Inception-v3 whose final
    classification layer is replaced by an identity, so the statistics are
    taken over the pooled features rather than over the class logits. The
    distance is ``|mu_s - mu_r|^2 + Tr(C_s + C_r - 2 (C_s C_r)^(1/2))``.

    NOTE(review): images are fed as given (no ImageNet normalization and no
    clamping of generated values) - confirm whether that is intended.

    Parameters:
    sampled_data (torch.Tensor): Generated images, shape (N, 1, H, W).
    real_data (torch.Tensor): Real images, shape (M, 1, H, W).
    batch_size (int): Images per forward pass; bounds peak memory use.

    Returns:
    numpy scalar: The FID value (lower means the two sets are closer).

    Raises:
    ValueError: If either set has fewer than two images (covariance undefined).
    """
    if sampled_data.shape[0] < 2 or real_data.shape[0] < 2:
        raise ValueError("fid_score needs at least two images in each set")

    model = inception_v3(pretrained=True, transform_input=False)
    # Drop the classifier head so the forward pass returns the pooled features.
    model.fc = torch.nn.Identity()
    model = model.to(device).eval()

    sampled_features = _inception_pool_features(model, sampled_data, batch_size)
    real_features = _inception_pool_features(model, real_data, batch_size)

    mu_sampled = np.mean(sampled_features, axis=0)
    cov_sampled = np.cov(sampled_features, rowvar=False)

    mu_real = np.mean(real_features, axis=0)
    cov_real = np.cov(real_features, rowvar=False)

    mu_diff = mu_sampled - mu_real
    cov_sqrtm = sqrtm(cov_sampled.dot(cov_real))
    # Numerical error can leave a tiny imaginary part in the matrix square root.
    if np.iscomplexobj(cov_sqrtm):
        cov_sqrtm = cov_sqrtm.real
    return mu_diff.dot(mu_diff) + np.trace(cov_sampled + cov_real - 2 * cov_sqrtm)

def inception_score(sampled_data, real_data=None, batch_size=64, eps=1e-12):
    """
    Compute the Inception Score of the generated images.

    The score is ``exp(mean_x KL(p(y|x) || p(y)))``, where ``p(y|x)`` is the
    softmax of the pretrained Inception-v3 logits for image ``x`` and ``p(y)``
    is the mean of ``p(y|x)`` over all generated images.

    Parameters:
    sampled_data (torch.Tensor): Generated images, shape (N, 1, H, W).
    real_data: Unused; accepted only so existing two-argument calls keep working.
    batch_size (int): Images per forward pass; bounds peak memory use.
    eps (float): Small constant added inside the logarithms for stability.

    Returns:
    float: The Inception Score (higher is better).

    Raises:
    ValueError: If ``sampled_data`` contains no images.
    """
    if sampled_data.shape[0] == 0:
        raise ValueError("inception_score needs at least one image")

    model = inception_v3(pretrained=True, transform_input=False).to(device).eval()

    class_probs = []
    with torch.no_grad():
        for start in range(0, sampled_data.shape[0], batch_size):
            batch = sampled_data[start:start + batch_size].to(device)
            # Upsample to the 299x299 Inception input size, then copy the
            # single gray channel into three channels.
            batch = F.interpolate(batch, size=(299, 299), mode='bilinear')
            batch = batch.repeat(1, 3, 1, 1)
            class_probs.append(F.softmax(model(batch), dim=1).cpu())
    class_probs = torch.cat(class_probs, dim=0)  # (N, num_classes)

    marginal = class_probs.mean(dim=0, keepdim=True)  # p(y)
    kl_per_image = (class_probs * (torch.log(class_probs + eps) - torch.log(marginal + eps))).sum(dim=1)
    return torch.exp(kl_per_image.mean()).item()

In [None]:
# generate samples
num_samples = 10000

# Draw the samples one sampling() call at a time, then stack them into
# image-shaped tensors on the selected device.
generated = []
for sample_index in range(num_samples):
    generated.append(sampling())
sampled_data = torch.cat(generated, dim=0).view(-1, 1, 28, 28).to(device)

# The same number of real images, taken from the start of the data tensor.
real_data = x_train[:num_samples].view(-1, 1, 28, 28).to(device)


In [None]:

# Score the generated samples against the real images and report the result.
fid = fid_score(sampled_data=sampled_data, real_data=real_data)
print(f'FID score: {fid}')