#Import libraries

In [2]:
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from keras.datasets.mnist import load_data
# import unet.py file into current foler
from unet import UNet
import matplotlib.pyplot as plt

#Download Mnist dataset

In [3]:
(trainX, trainy), (testX, testy) = load_data()
trainX = np.float32(trainX) / 255.
testX = np.float32(testX) / 255.

def sample_images(batch_size, device):
    indices = torch.randperm(trainX.shape[0])[:batch_size]
    data = torch.from_numpy(trainX[indices]).unsqueeze(1).to(device)
    return torch.nn.functional.interpolate(data, 32)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


# DDM Modle build



In [4]:
class DenoisingDiffusionModel():

    def __init__(self, T : int, model : nn.Module, device : str):

        self.T = T
        self.UNet_Model = model.to(device)
        self.device = device
        self.beta = torch.linspace(1e-4, 0.02, T).to(device)
        self.alpha = 1. - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)

    def training(self, batch_size, optimizer):

        x0 = sample_images(batch_size, self.device)

        t = torch.randint(1, self.T + 1, (batch_size,), device=self.device, dtype=torch.long)

        epsilon = torch.randn_like(x0)

        # Take one gradient descent step
        alpha_bar_t = self.alpha_bar[t-1].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        predicted_epsilon = self.UNet_Model(torch.sqrt(
            alpha_bar_t) * x0 + torch.sqrt(1 - alpha_bar_t) * epsilon, t-1)
        loss = nn.functional.mse_loss(epsilon, predicted_epsilon)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()

    @torch.no_grad()
    def sampling(self, num_samples=1, img_chn=1, img_size=(32, 32)):

        x = torch.randn((num_samples, img_chn, img_size[0], img_size[1]),
                         device=self.device)

        #progress_bar = tqdm if use_tqdm else lambda x : x
        for t in (range(self.T, 0, -1)):
            z = torch.randn_like(x) if t > 1 else torch.zeros_like(x)

            t = torch.ones(num_samples, dtype=torch.long, device=self.device) * t

            beta_t = self.beta[t-1].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            alpha_t = self.alpha[t-1].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            alpha_bar_t = self.alpha_bar[t-1].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

            mean = 1 / torch.sqrt(alpha_t) * (x - ((1 - alpha_t) / torch.sqrt(
                1 - alpha_bar_t)) * self.UNet_Model(x, t-1))
            sigma = torch.sqrt(beta_t)
            x = mean + sigma * z

        return x


# Training of Model

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 64
model = UNet()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
ddm = DenoisingDiffusionModel(1000, model, device)

In [6]:
train_loss = []
for epoch in range(1000):
    loss = ddm.training(batch_size, optimizer)
    print(loss)
    train_loss.append(loss)


0.9744899272918701
0.8153575658798218
0.6853034496307373
0.576460599899292
0.49690714478492737
0.4365069270133972
0.35328859090805054
0.2989097535610199
0.27143803238868713
0.27725455164909363
0.24187184870243073
0.2321491837501526
0.20969411730766296
0.1879080981016159
0.15793690085411072
0.16810056567192078
0.17400312423706055
0.13233375549316406
0.14311863481998444
0.15915602445602417
0.1767292022705078
0.1809975802898407
0.1584121584892273
0.18061307072639465
0.1394634246826172
0.18140162527561188
0.15334096550941467
0.12480732053518295
0.15064270794391632
0.11791599541902542
0.12731242179870605
0.151548832654953
0.11946102976799011
0.12095235288143158
0.128426194190979
0.09740369021892548
0.10144276916980743
0.13254623115062714
0.17123118042945862
0.12255419790744781
0.16394579410552979
0.1566484570503235
0.10948579013347626
0.12692636251449585
0.12286484241485596
0.11506035923957825
0.10043691098690033
0.1225321963429451
0.09635448455810547
0.10486321896314621
0.10622981190681458

# Output

In [7]:
input_images=100
samples = ddm.sampling(num_samples=input_images)
plt.figure(figsize=(17, 17))
for i in range(input_images):
    plt.subplot(10, 10, 1 + i)
    plt.axis('off')
    plt.imshow(samples[i].squeeze(0).clip(0, 1).data.cpu().numpy(), cmap='gray')
plt.savefig(f'epoch_{epoch}.png')
plt.close()

torch.save(model.cpu(), f'DDM_epoch_{epoch}')