# Denoising Diffusion Probabilistic Model

In [20]:
# Mount Google Drive; CONFIG.generated_csv_path points under /content/drive, so the
# generated-sample CSV can only be written once Drive is mounted.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [21]:
!pip install scipy==1.11.1

Collecting scipy==1.11.1
  Downloading scipy-1.11.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (59 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting numpy<1.28.0,>=1.21.6 (from scipy==1.11.1)
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading scipy-1.11.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m36.2/36.2 MB[0m [31m29.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.3/18.3 MB[0m [31m84.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected pack

In [22]:
import torch
import torch.nn as nn

## Forward Process

In [23]:
class ForwardProcess(nn.Module):
  """Forward (noising) process q(x_t | x_0) of DDPM with a linear beta schedule.

  The schedule tensors are plain attributes (not registered buffers), so they stay
  where they were created; `add_noise` moves the needed table to the input's device.
  """
  def __init__(self,
              num_time_steps = 1000,
              beta_start = 1e-4,
              beta_end = 0.02,
              ):
    super().__init__()
    # Linear schedule beta_0..beta_{T-1} and the derived cumulative products.
    self.betas = torch.linspace(beta_start, beta_end, num_time_steps)
    self.alphas = 1 - self.betas
    self.alpha_bars = torch.cumprod(self.alphas, dim=0)
    self.sqrt_alpha_bars = torch.sqrt(self.alpha_bars)
    self.sqrt_one_minus_alpha_bars = torch.sqrt(1 - self.alpha_bars)

  def add_noise(self, # add noise to a batch of original samples at timestep t
                original, # input tensor, batch dimension first: (B, ...) of any rank
                noise, # random noise tensor sampled from N(0, I), same shape as `original`
                t, # timestep index per sample, shape (B,), or a scalar shared by the batch
                ):
    """Return x_t = sqrt(alpha_bar_t) * original + sqrt(1 - alpha_bar_t) * noise.

    The result has the same shape as `original`.
    """
    sqrt_alpha_bar_t = self.sqrt_alpha_bars.to(original.device)[t]
    sqrt_one_minus_alpha_bar_t = self.sqrt_one_minus_alpha_bars.to(original.device)[t]

    # Reshape the per-sample coefficients to (B, 1, ..., 1) so they broadcast over every
    # non-batch dimension, whatever the rank of `original` (images, sequences, vectors).
    bshape = (-1,) + (1,) * (original.dim() - 1)
    sqrt_alpha_bar_t = sqrt_alpha_bar_t.reshape(bshape)
    sqrt_one_minus_alpha_bar_t = sqrt_one_minus_alpha_bar_t.reshape(bshape)

    return sqrt_alpha_bar_t * original + sqrt_one_minus_alpha_bar_t * noise

In [24]:
# test: add_noise must return a tensor with the same shape as the input batch
original = torch.randn(4, 1, 28, 28)
noise = torch.randn(4, 1, 28, 28)
t_steps = torch.randint(0, 1000, (4,)) # 4 random timesteps, one per image, in [0, 1000)

# test forward process
fp = ForwardProcess()
out = fp.add_noise(original, noise, t_steps)
assert out.shape == original.shape
print(out.shape)

torch.Size([4, 1, 28, 28])


## Reverse Process

In [25]:
class ReverseProcess(nn.Module):
  """Reverse (denoising) step p(x_{t-1} | x_t) of DDPM with a linear beta schedule.

  The schedule must match the one used by the forward process during training.
  """
  def __init__(self,
               num_time_steps = 1000,
               beta_start = 1e-4,
               beta_end = 0.02,
               ):
    super().__init__()
    self.betas = torch.linspace(beta_start, beta_end, num_time_steps)
    self.alphas = 1 - self.betas
    self.alpha_bars = torch.cumprod(self.alphas, dim=0)

  def sample_prev_timestep(self, # sample x_(t-1) given x_t and noise predicted by model
                           xt, # tensor at timestep t, batch dimension first: (B, ...)
                           noise_pred,# noise predicted by model, same shape as xt
                           t, # current timestep: a scalar, or one index per sample (shape (B,))
                           ):
    """Return (x_{t-1}, x0_pred) for one DDPM reverse step.

    mean    = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
    var     = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
    x_{t-1} = mean + sqrt(var) * z,  z ~ N(0, I); no noise is added where t == 0.
    x0_pred is the estimate of the clean sample implied by `noise_pred`, clamped to [-1, 1].
    """
    device = xt.device
    t = torch.as_tensor(t, device=device)

    # Coefficients shaped (B, 1, ..., 1) (or (1, 1, ..., 1) for a scalar t) so they
    # broadcast over all non-batch dimensions of xt.
    bshape = (-1,) + (1,) * (xt.dim() - 1)
    def coef(table, idx):
      return table.to(device)[idx].reshape(bshape)

    alpha_t = coef(self.alphas, t)
    alpha_bar_t = coef(self.alpha_bars, t)
    beta_t = coef(self.betas, t)

    # original image prediction at current timestep t
    x0 = (xt - torch.sqrt(1 - alpha_bar_t) * noise_pred) / torch.sqrt(alpha_bar_t)
    x0 = torch.clamp(x0, -1., 1.)

    # mean of x_(t-1); the final division by sqrt(alpha_t) is part of the DDPM posterior mean
    mean = (xt - beta_t * noise_pred / torch.sqrt(1 - alpha_bar_t)) / torch.sqrt(alpha_t)

    if bool(torch.all(t == 0)):
      return mean, x0

    # Posterior variance. t - 1 is clamped so entries with t == 0 still index a valid
    # element; their noise is masked out below.
    alpha_bar_prev = coef(self.alpha_bars, (t - 1).clamp(min=0))
    var = (1 - alpha_bar_prev) / (1 - alpha_bar_t) * beta_t
    sigma = var ** 0.5
    z = torch.randn(xt.shape).to(device)
    noise_mask = (t > 0).reshape(bshape).to(xt.dtype)

    return mean + noise_mask * sigma * z, x0


In [26]:
# test: one reverse step must keep the input shape for both outputs
xt = torch.randn(1, 1, 28, 28) # random xt
noise_pred = torch.randn(1, 1, 28, 28) # random noise_pred
t = torch.randint(0, 1000, (1,)) # one random timestep in [0, 1000)
# test reverse process
rp = ReverseProcess()
out, x0 = rp.sample_prev_timestep(xt, noise_pred, t)
assert out.shape == xt.shape and x0.shape == xt.shape
print(out.shape)
print(x0.shape)


torch.Size([1, 1, 28, 28])
torch.Size([1, 1, 28, 28])


## Model Architecture

### Time Embedding

In [27]:
def get_time_embedding(
    time_steps: torch.Tensor, # scalar time-step (batch,)
    t_emb_dim: int, # embedding dimension
) -> torch.Tensor: # (batch, t_emb_dim)
  """Sinusoidal timestep embedding.

  For frequency index i in [0, t_emb_dim // 2) the angle is t / 10000^(2i / t_emb_dim).
  The first half of the output holds sin(angle), the second half cos(angle).

  Args:
    time_steps: 1-D tensor of timesteps, shape (batch,).
    t_emb_dim: embedding size; must be even.
  Returns:
    Tensor of shape (batch, t_emb_dim).
  """
  assert t_emb_dim % 2 == 0, "time embedding must be divisible by 2."
  exponents = torch.arange(start=0, end=t_emb_dim // 2, dtype=torch.float32,
                           device=time_steps.device)
  exponents = 2 * exponents / t_emb_dim
  angles = time_steps.unsqueeze(1) / (10000 ** exponents)  # (B, 1) / (half,) -> (B, half)
  return torch.cat((angles.sin(), angles.cos()), dim=1)


### U-net

#### Utility Modules

In [28]:
class NormActConv(nn.Module):
    """GroupNorm -> SiLU -> Conv2d, each of the first two optional.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution.
        num_groups: number of GroupNorm groups (must divide in_channels).
        kernel_size: conv kernel size; padding (k - 1) // 2 keeps H and W for odd k.
        norm: GroupNorm is applied only when this is exactly True.
        act: SiLU is applied only when this is exactly True.
    """
    def __init__(self,
                 in_channels:int,
                 out_channels:int,
                 num_groups:int = 8,
                 kernel_size: int = 3,
                 norm:bool = True,
                 act:bool = True
                ):
        super().__init__()

        # Submodule attribute names are the state_dict keys of saved checkpoints.
        if norm is True:
            self.g_norm = nn.GroupNorm(num_groups, in_channels)
        else:
            self.g_norm = nn.Identity()

        if act is True:
            self.act = nn.SiLU()
        else:
            self.act = nn.Identity()

        same_padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=same_padding)

    def forward(self, x):
        return self.conv(self.act(self.g_norm(x)))

class TimeEmbedding(nn.Module):
    """
    Project a (batch, t_emb_dim) time embedding to (batch, n_out) with SiLU -> Linear.
    """
    def __init__(self,
                n_out:int, # Output Dimension
                t_emb_dim:int = 128 # Time Embedding Dimension
                ):
        super().__init__()

        # Sequential indices (0: SiLU, 1: Linear) appear in the state_dict keys.
        layers = [nn.SiLU(), nn.Linear(t_emb_dim, n_out)]
        self.te_block = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.te_block(x)
        return projected

class SelfAttentionBlock(nn.Module):
    """
    GroupNorm followed by multi-head self-attention over the H*W spatial positions.

    Each spatial location is a token whose features are its `num_channels` values.
    Returns only the attention output (same shape as the input); no residual is
    added inside this block.
    """
    def __init__(self,
                 num_channels:int,
                 num_groups:int = 8,
                 num_heads:int = 4,
                 norm:bool = True
                ):
        super().__init__()

        # GroupNorm only when norm is exactly True.
        if norm is True:
            self.g_norm = nn.GroupNorm(num_groups, num_channels)
        else:
            self.g_norm = nn.Identity()

        # batch_first=True: attention inputs are laid out as (batch, tokens, channels).
        self.attn = nn.MultiheadAttention(num_channels, num_heads, batch_first=True)

    def forward(self, x):
        b, c, h, w = x.shape
        tokens = self.g_norm(x.reshape(b, c, h * w))   # (B, C, H*W)
        tokens = tokens.permute(0, 2, 1)               # (B, H*W, C)
        attended, _ = self.attn(tokens, tokens, tokens)
        return attended.permute(0, 2, 1).reshape(b, c, h, w)

class Downsample(nn.Module):
    """
    Reduce H and W by a factor of k.

    Two branches are available:
      * `cv`:    1x1 conv, then a strided 4x4 conv.
      * `mpool`: k x k max-pool, then a 1x1 conv.
    When both are enabled each yields out_channels // 2 channels and the results
    are concatenated along channels; a disabled branch is nn.Identity.
    """
    def __init__(self,
                 in_channels:int,
                 out_channels:int,
                 k:int = 2, # Downsampling factor
                 use_conv:bool = True, # If Downsampling using conv-block
                 use_mpool:bool = True # If Downsampling using max-pool
                ):
        super().__init__()

        self.use_conv = use_conv
        self.use_mpool = use_mpool

        conv_ch = out_channels // 2 if use_mpool else out_channels
        pool_ch = out_channels // 2 if use_conv else out_channels

        # Strided-convolution branch (built first; creation order fixes weight init order).
        if use_conv:
            self.cv = nn.Sequential(
                nn.Conv2d(in_channels, in_channels, kernel_size=1),
                nn.Conv2d(in_channels, conv_ch, kernel_size=4, stride=k, padding=1),
            )
        else:
            self.cv = nn.Identity()

        # Max-pool branch.
        if use_mpool:
            self.mpool = nn.Sequential(
                nn.MaxPool2d(k, k),
                nn.Conv2d(in_channels, pool_ch, kernel_size=1, stride=1, padding=0),
            )
        else:
            self.mpool = nn.Identity()

    def forward(self, x):
        if self.use_conv and self.use_mpool:
            return torch.cat([self.cv(x), self.mpool(x)], dim=1)
        # At most one branch is active here; the inactive one is nn.Identity.
        return self.cv(x) if self.use_conv else self.mpool(x)

class Upsample(nn.Module):
    """
    Enlarge H and W by a factor of k.

    Two branches are available:
      * `cv`: transposed 4x4 conv with stride k, then a 1x1 conv.
      * `up`: bilinear nn.Upsample, then a 1x1 conv.
    When both are enabled each yields out_channels // 2 channels and the results
    are concatenated along channels; a disabled branch is nn.Identity.
    """
    def __init__(self,
                 in_channels:int,
                 out_channels:int,
                 k:int = 2, # Upsampling factor
                 use_conv:bool = True, # Upsampling using conv-block
                 use_upsample:bool = True # Upsampling using nn.upsample
                ):
        super().__init__()

        self.use_conv = use_conv
        self.use_upsample = use_upsample

        conv_ch = out_channels // 2 if use_upsample else out_channels
        interp_ch = out_channels // 2 if use_conv else out_channels

        # Transposed-convolution branch (built first; creation order fixes weight init order).
        if use_conv:
            self.cv = nn.Sequential(
                nn.ConvTranspose2d(in_channels, conv_ch, kernel_size=4, stride=k, padding=1),
                nn.Conv2d(conv_ch, conv_ch, kernel_size=1, stride=1, padding=0),
            )
        else:
            self.cv = nn.Identity()

        # Interpolation branch.
        if use_upsample:
            self.up = nn.Sequential(
                nn.Upsample(scale_factor=k, mode='bilinear', align_corners=False),
                nn.Conv2d(in_channels, interp_ch, kernel_size=1, stride=1, padding=0),
            )
        else:
            self.up = nn.Identity()

    def forward(self, x):
        if self.use_conv and self.use_upsample:
            return torch.cat([self.cv(x), self.up(x)], dim=1)
        # At most one branch is active here; the inactive one is nn.Identity.
        return self.cv(x) if self.use_conv else self.up(x)

In [29]:
#Test
layer = Upsample(16, 32, 2, True, True)
x = torch.randn(4, 16, 32, 32)
up_out = layer(x)
# Both branches contribute out_channels // 2 = 16 channels at twice the resolution.
assert up_out.shape == (4, 32, 64, 64)
up_out.shape

torch.Size([4, 32, 64, 64])

#### Down-Conv Block

In [30]:
class DownC(nn.Module):
    """
    Down-path block of the U-net.

    For each of `num_layers` layers:
      1. NormActConv, plus the layer's projected time embedding broadcast over H and W.
      2. A second NormActConv.
      3. Residual add of the layer input passed through a 1x1 conv (channel match).
      4. Self-attention whose output is added back to its input.
    The result is then downsampled by `Downsample`, or returned unchanged in size
    when `down_sample` is False.
    """
    def __init__(self,
                 in_channels:int,
                 out_channels:int,
                 t_emb_dim:int = 128, # Time Embedding Dimension
                 num_layers:int=2,
                 down_sample:bool = True # True for Downsampling
                ):
        super().__init__()

        self.num_layers = num_layers

        # Input channels seen by each layer: only the first one receives `in_channels`.
        layer_in = [in_channels if i == 0 else out_channels for i in range(num_layers)]

        # Attribute names and creation order are kept stable: they define the state_dict
        # keys and the order in which weights are initialised.
        self.conv1 = nn.ModuleList(NormActConv(ch, out_channels) for ch in layer_in)
        self.conv2 = nn.ModuleList(NormActConv(out_channels, out_channels) for _ in layer_in)
        self.te_block = nn.ModuleList(TimeEmbedding(out_channels, t_emb_dim) for _ in layer_in)
        self.attn_block = nn.ModuleList(SelfAttentionBlock(out_channels) for _ in layer_in)
        self.down_block = Downsample(out_channels, out_channels) if down_sample else nn.Identity()
        self.res_block = nn.ModuleList(nn.Conv2d(ch, out_channels, kernel_size=1) for ch in layer_in)

    def forward(self, x, t_emb):
        out = x
        layers = zip(self.conv1, self.conv2, self.te_block, self.res_block, self.attn_block)
        for conv_a, conv_b, time_proj, skip_proj, attention in layers:
            hidden = conv_a(out)
            hidden = hidden + time_proj(t_emb)[:, :, None, None]
            hidden = conv_b(hidden)
            out = hidden + skip_proj(out)
            out = out + attention(out)
        return self.down_block(out)

#### Mid-Conv Block

In [31]:
class MidC(nn.Module):
    """
    Bottleneck block of the U-net, refining the features produced by the DownC blocks.

    Structure: one time-conditioned resnet sub-block, followed by `num_layers`
    repetitions of (self-attention with residual add -> resnet sub-block).
    Only the first resnet sub-block changes the channel count (in -> out).
    """
    def __init__(self,
                 in_channels:int,
                 out_channels:int,
                 t_emb_dim:int = 128,
                 num_layers:int = 2
                ):
        super().__init__()

        self.num_layers = num_layers

        # There is one more resnet sub-block than attention blocks.
        resnet_in = [in_channels if i == 0 else out_channels for i in range(num_layers + 1)]

        # Attribute names and creation order are kept stable (state_dict keys, init order).
        self.conv1 = nn.ModuleList(NormActConv(ch, out_channels) for ch in resnet_in)
        self.conv2 = nn.ModuleList(NormActConv(out_channels, out_channels) for _ in resnet_in)
        self.te_block = nn.ModuleList(TimeEmbedding(out_channels, t_emb_dim) for _ in resnet_in)
        self.attn_block = nn.ModuleList(SelfAttentionBlock(out_channels) for _ in range(num_layers))
        self.res_block = nn.ModuleList(nn.Conv2d(ch, out_channels, kernel_size=1) for ch in resnet_in)

    def _resnet(self, idx, x, t_emb):
        """Resnet sub-block `idx`: two NormActConvs with time conditioning plus a 1x1 skip."""
        hidden = self.conv1[idx](x)
        hidden = hidden + self.te_block[idx](t_emb)[:, :, None, None]
        hidden = self.conv2[idx](hidden)
        return hidden + self.res_block[idx](x)

    def forward(self, x, t_emb):
        out = self._resnet(0, x, t_emb)
        for i, attention in enumerate(self.attn_block):
            out = out + attention(out)
            out = self._resnet(i + 1, out, t_emb)
        return out

#### Up-Conv Block

In [32]:
class UpC(nn.Module):
    """
    Up-path block of the U-net.

    1. Upsample `x` with `Upsample(in_channels, in_channels // 2)`, or leave it unchanged
       when `up_sample` is False. Then concatenate it with the skip tensor `down_out`
       along channels; the concatenation must have `in_channels` channels.
    2. For each of `num_layers` layers: NormActConv + projected time embedding, a second
       NormActConv, a residual add through a 1x1 conv, then self-attention added back
       to its input.
    """
    def __init__(self,
                 in_channels:int,
                 out_channels:int,
                 t_emb_dim:int = 128, # Time Embedding Dimension
                 num_layers:int = 2,
                 up_sample:bool = True # True for Upsampling
                ):
        super().__init__()

        self.num_layers = num_layers

        # Input channels seen by each layer: only the first one receives `in_channels`.
        layer_in = [in_channels if i == 0 else out_channels for i in range(num_layers)]

        # Attribute names and creation order are kept stable (state_dict keys, init order).
        self.conv1 = nn.ModuleList(NormActConv(ch, out_channels) for ch in layer_in)
        self.conv2 = nn.ModuleList(NormActConv(out_channels, out_channels) for _ in layer_in)
        self.te_block = nn.ModuleList(TimeEmbedding(out_channels, t_emb_dim) for _ in layer_in)
        self.attn_block = nn.ModuleList(SelfAttentionBlock(out_channels) for _ in layer_in)
        self.up_block = Upsample(in_channels, in_channels // 2) if up_sample else nn.Identity()
        self.res_block = nn.ModuleList(nn.Conv2d(ch, out_channels, kernel_size=1) for ch in layer_in)

    def forward(self, x, down_out, t_emb):
        out = torch.cat([self.up_block(x), down_out], dim=1)

        for i in range(self.num_layers):
            hidden = self.conv1[i](out)
            hidden = hidden + self.te_block[i](t_emb)[:, :, None, None]
            hidden = self.conv2[i](hidden)
            out = hidden + self.res_block[i](out)
            out = out + self.attn_block[i](out)

        return out

#### U-net

In [33]:
class Unet(nn.Module):
    """
    U-net used to predict the noise added to an image, as in
    "Denoising Diffusion Probabilistic Models".

    Layout: 3x3 conv -> DownC blocks -> MidC blocks -> UpC blocks -> GroupNorm + 3x3 conv.
    The input of each DownC block is stored and fed, in reverse order, as the skip
    connection of the UpC blocks.

    Args:
        im_channels: channels of the input/output image (1 = grayscale, e.g. MNIST).
        down_ch: down-path channels; DownC i maps down_ch[i] -> down_ch[i+1].
        mid_ch: MidC channels; MidC i maps mid_ch[i] -> mid_ch[i+1].
        up_ch: up-path channels; UpC i maps up_ch[i] -> up_ch[i+1].
        down_sample: one flag per DownC block (len(down_ch) - 1 entries); reversed
            to decide which UpC blocks upsample.
        t_emb_dim: size of the sinusoidal time embedding.
        num_downc_layers, num_midc_layers, num_upc_layers: layers per block type.

    Raises:
        ValueError: if `down_sample` does not have exactly one entry per DownC block.
    """

    def __init__(self,
                 im_channels: int = 1, # image channels (1 = grayscale)
                 down_ch: list = [32, 64, 128, 256],
                 mid_ch: list = [256, 256, 128],
                 up_ch: list[int] = [256, 128, 64, 16],
                 down_sample: list[bool] = [True, True, False],
                 t_emb_dim: int = 128,
                 num_downc_layers:int = 2,
                 num_midc_layers:int = 2,
                 num_upc_layers:int = 2
                ):
        super(Unet, self).__init__()

        # A shorter list would fail with IndexError below; a longer one would be
        # silently misaligned once reversed for the UpC blocks.
        if len(down_sample) != len(down_ch) - 1:
            raise ValueError(
                f"down_sample needs {len(down_ch) - 1} entries (one per DownC block), "
                f"got {len(down_sample)}"
            )

        self.im_channels = im_channels
        # Copy list arguments so instances never alias the shared default lists.
        self.down_ch = list(down_ch)
        self.mid_ch = list(mid_ch)
        self.up_ch = list(up_ch)
        self.t_emb_dim = t_emb_dim
        self.down_sample = list(down_sample)
        self.num_downc_layers = num_downc_layers
        self.num_midc_layers = num_midc_layers
        self.num_upc_layers = num_upc_layers

        # The up path mirrors the down path, e.g. [True, True, False] -> [False, True, True]
        self.up_sample = list(reversed(self.down_sample))

        # Initial Convolution
        self.cv1 = nn.Conv2d(self.im_channels, self.down_ch[0], kernel_size=3, padding=1)

        # Initial Time Embedding Projection
        self.t_proj = nn.Sequential(
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
            nn.SiLU(),
            nn.Linear(self.t_emb_dim, self.t_emb_dim)
        )

        # DownC Blocks
        self.downs = nn.ModuleList([
            DownC(
                self.down_ch[i],
                self.down_ch[i+1],
                self.t_emb_dim,
                self.num_downc_layers,
                self.down_sample[i]
            ) for i in range(len(self.down_ch) - 1)
        ])

        # MidC Blocks
        self.mids = nn.ModuleList([
            MidC(
                self.mid_ch[i],
                self.mid_ch[i+1],
                self.t_emb_dim,
                self.num_midc_layers
            ) for i in range(len(self.mid_ch) - 1)
        ])

        # UpC Blocks
        self.ups = nn.ModuleList([
            UpC(
                self.up_ch[i],
                self.up_ch[i+1],
                self.t_emb_dim,
                self.num_upc_layers,
                self.up_sample[i]
            ) for i in range(len(self.up_ch) - 1)
        ])

        # Final Convolution back to image channels
        self.cv2 = nn.Sequential(
            nn.GroupNorm(8, self.up_ch[-1]),
            nn.Conv2d(self.up_ch[-1], self.im_channels, kernel_size=3, padding=1)
        )

    def forward(self, x, t):
        """Predict the noise in `x` at timesteps `t` (a 1-D tensor, one entry per sample)."""
        out = self.cv1(x)

        # Time Projection
        t_emb = get_time_embedding(t, self.t_emb_dim)
        t_emb = self.t_proj(t_emb)

        # Keep each DownC input as the skip connection for the matching UpC block
        down_outs = []

        for down in self.downs:
            down_outs.append(out)
            out = down(out, t_emb)

        for mid in self.mids:
            out = mid(out, t_emb)

        # Skips are consumed deepest-first
        for up in self.ups:
            down_out = down_outs.pop()
            out = up(out, down_out, t_emb)

        out = self.cv2(out)

        return out
# Test
# NOTE(review): this instance is never used — the next cell builds a fresh `model = Unet()`
# before testing it, so this line only costs an extra model construction.
model = Unet()

In [34]:
# Test: the predicted noise must have the same shape as the input image batch
model = Unet()
x = torch.randn(4, 1, 32, 32)
t = torch.randint(0, 10, (4,))
noise_pred_shape = model(x, t).shape
assert noise_pred_shape == x.shape
noise_pred_shape

torch.Size([4, 1, 32, 32])

## Training

### Dataset

In [35]:
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
from torchvision.datasets import MNIST
from torchvision import transforms

class CustomMnistDataset(Dataset):
    """
    Unlabelled MNIST images scaled to [-1, 1], loaded through
    torchvision.datasets.MNIST (no CSV file needed).

    Args:
        root: directory where torchvision stores / downloads MNIST.
        train: True for the training split, False for the test split.
        num_datapoints: if given, keep only the first `num_datapoints` samples.
    """
    def __init__(self, root="./data", train=True, num_datapoints=None):
        super().__init__()

        to_signed_range = transforms.Compose([
            transforms.ToTensor(),                  # [0, 1]
            transforms.Lambda(lambda x: x * 2 - 1)  # [0, 1] -> [-1, 1]
        ])
        self.dataset = MNIST(root=root, train=train, download=True, transform=to_signed_range)

        # Slice the underlying data/targets tensors to keep only the first samples.
        if num_datapoints is not None:
            self.dataset.data = self.dataset.data[:num_datapoints]
            self.dataset.targets = self.dataset.targets[:num_datapoints]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Only the image is returned: the label is not needed by the diffusion model.
        image, _label = self.dataset[idx]
        return image


### Training-Loop

In [36]:
class CONFIG:
    """Hyper-parameters and file paths shared by the training, sampling and FID cells."""
    model_path = 'ddpm_unet.pth'  # where train() saves the U-net weights (relative to the CWD)
    # The sampling cell writes generated images here as a CSV (one row of pixels per
    # image). A .csv name avoids confusing it with, or overwriting, a model checkpoint.
    generated_csv_path = '/content/drive/MyDrive/Phase 2/DDPM/ddpm_generated.csv'
    num_epochs = 20
    lr = 1e-4
    num_timesteps = 1000   # number of diffusion steps T
    batch_size = 128
    img_size = 28          # MNIST images are 28 x 28
    in_channels = 1        # grayscale
    num_img_to_generate = 256


In [37]:
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
import numpy as np

def train(cfg):
    """
    Train the DDPM noise-prediction U-net on the MNIST training split.

    Each step draws a random timestep per image, noises the batch with the forward
    process and regresses the added noise with an MSE loss. After every epoch the
    mean training loss is printed, and the model's state_dict is saved to
    `cfg.model_path` whenever that loss improves.

    Args:
        cfg: configuration providing batch_size, lr, num_epochs, num_timesteps,
             in_channels and model_path.
    """
    # Dataset and Dataloader (MNIST is loaded via torchvision rather than from a CSV)
    mnist_ds = CustomMnistDataset(root="./data", train=True)
    mnist_dl = DataLoader(mnist_ds, cfg.batch_size, shuffle=True)

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Device: {device}\n')

    # Initiate Model; its image channels must match the training images
    model = Unet(im_channels=cfg.in_channels).to(device)

    # Initialize Optimizer and Loss Function
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    criterion = torch.nn.MSELoss()

    # Forward process with the same number of steps that timesteps are sampled from,
    # so the noise schedule always covers the range of `t` drawn below.
    dfp = ForwardProcess(num_time_steps=cfg.num_timesteps)

    # Lowest mean *training* loss seen so far (there is no separate eval split)
    best_eval_loss = float('inf')

    for epoch in range(cfg.num_epochs):

        losses = []
        model.train()

        for imgs in tqdm(mnist_dl):

            imgs = imgs.to(device)

            # Noise and timesteps are created directly on the target device
            noise = torch.randn_like(imgs)
            t = torch.randint(0, cfg.num_timesteps, (imgs.shape[0],), device=device)

            # Add noise to the images using Forward Process
            noisy_imgs = dfp.add_noise(imgs, noise, t)

            # Avoid Gradient Accumulation
            optimizer.zero_grad()

            # Predict noise using U-net Model
            noise_pred = model(noisy_imgs, t)

            loss = criterion(noise_pred, noise)
            losses.append(loss.item())

            # Backprop + Update model params
            loss.backward()
            optimizer.step()

        mean_epoch_loss = np.mean(losses)

        print('Epoch:{} | Loss : {:.4f}'.format(
            epoch + 1,
            mean_epoch_loss,
        ))

        # Save based on train-loss; store the state_dict rather than the whole pickled model
        if mean_epoch_loss < best_eval_loss:
            best_eval_loss = mean_epoch_loss
            torch.save(model.state_dict(), cfg.model_path)

    print(f'Done training...')


In [None]:
# Build the run configuration, then launch training with it.
cfg = CONFIG()
train(cfg=cfg)

Device: cpu



  0%|          | 0/469 [00:00<?, ?it/s]

In [None]:
def generate(cfg, model=None, num_images=1):
    """
    Sample images from a trained DDPM U-net by running the reverse process
    step by step from pure Gaussian noise down to t = 0.

    Args:
        cfg: configuration providing model_path, num_timesteps, in_channels and img_size.
        model: optional already-loaded U-net to sample with. If None (default), a fresh
            Unet is built and its weights are loaded from `cfg.model_path`, which holds
            a state_dict (that is what train() saves).
        num_images: number of images denoised in parallel (default 1).

    Returns:
        CPU tensor of shape (num_images, cfg.in_channels, cfg.img_size, cfg.img_size)
        with values in [0, 1].
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Reverse process with as many steps as the sampling loop below walks through
    drp = ReverseProcess(num_time_steps=cfg.num_timesteps)

    if model is None:
        # torch.load returns the saved state_dict (an OrderedDict, which has no .to()),
        # so rebuild the architecture and load the weights into it.
        model = Unet(im_channels=cfg.in_channels)
        state_dict = torch.load(cfg.model_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    # Start from x_T ~ N(0, I)
    xt = torch.randn(num_images, cfg.in_channels, cfg.img_size, cfg.img_size).to(device)

    # Denoise step by step by going backward.
    with torch.no_grad():
        for t in reversed(range(cfg.num_timesteps)):
            t_batch = torch.full((num_images,), t, dtype=torch.long, device=device)
            noise_pred = model(xt, t_batch)
            xt, _ = drp.sample_prev_timestep(xt, noise_pred, torch.as_tensor(t).to(device))

    # Map from [-1, 1] to [0, 1]
    xt = torch.clamp(xt, -1., 1.).detach().cpu()
    xt = (xt + 1) / 2

    return xt

In [None]:
# Load config, sample images, save them as CSV and show a grid
from pathlib import Path
from matplotlib import pyplot as plt

cfg = CONFIG()

# Generate cfg.num_img_to_generate images, one reverse-diffusion run each
generated_imgs = []
for i in tqdm(range(cfg.num_img_to_generate)):
    xt = generate(cfg)
    # Round (instead of truncating) when mapping [0, 1] floats to 0..255 pixel values
    xt = np.rint(255 * xt[0][0].numpy())
    generated_imgs.append(xt.astype(np.uint8).flatten())

# Save Generated Data CSV: one row per image, one `pixel{i}` column per pixel
num_pixels = cfg.img_size * cfg.img_size
generated_df = pd.DataFrame(generated_imgs, columns=[f'pixel{i}' for i in range(num_pixels)])
Path(cfg.generated_csv_path).parent.mkdir(parents=True, exist_ok=True)
generated_df.to_csv(cfg.generated_csv_path, index=False)

# Visualize up to 64 samples on an 8 x 8 grid (extra panels stay blank)
fig, axes = plt.subplots(8, 8, figsize=(5, 5))
for i, ax in enumerate(axes.flat):
    ax.axis('off')  # Turn off axis labels
    if i < len(generated_imgs):
        ax.imshow(np.reshape(generated_imgs[i], (cfg.img_size, cfg.img_size)), cmap='gray')

plt.tight_layout()  # Adjust spacing between subplots
plt.show()

In [None]:
def get_activation(dataloader,
                   model,
                   preprocess, # Preprocessing Transform for InceptionV3
                   device = 'cpu'
                  ):
    """
    Run every sample of `dataloader` through `model` and collect the outputs.

    Args:
        dataloader: yields batches of image tensors; each image in a batch is passed
            individually to `preprocess`.
        model: feature extractor whose output has 2048 features per sample
            (e.g. InceptionV3 with `fc` replaced by nn.Identity).
        preprocess: per-image transform applied before the forward pass.
        device: device to run the model on.

    Returns:
        numpy array of shape (len(dataloader.dataset), 2048).
    """
    model.to(device)
    model.eval()

    pred_arr = np.zeros((len(dataloader.dataset), 2048))

    # Write offset into pred_arr, advanced by the actual size of each batch. Unlike
    # i * dataloader.batch_size, this stays correct when batch_size is None
    # (e.g. a DataLoader built with a custom batch_sampler).
    start = 0
    with torch.no_grad():
        for batch in tqdm(dataloader):

            # Transform the Batch according to InceptionV3 specification
            batch = torch.stack([preprocess(img) for img in batch]).to(device)

            pred = model(batch).cpu().numpy()

            end = start + pred.shape[0]
            pred_arr[start:end, :] = pred
            start = end

    return pred_arr

def calculate_activation_statistics(dataloader,
                                    model,
                                    preprocess,
                                    device='cpu'
                                   ):
    """
    Return (mu, sigma): the per-feature mean vector and the feature covariance
    matrix of the activations `get_activation` produces for `dataloader`.
    """
    features = get_activation(dataloader, model, preprocess, device)
    # Rows are samples and columns are features, hence rowvar=False.
    return features.mean(axis=0), np.cov(features, rowvar=False)

from scipy import linalg

def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """
    Squared Frechet distance between N(mu1, sigma1) and N(mu2, sigma2):

        d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))

    Args:
        mu1, mu2: mean feature vectors (scalars are promoted to 1-D).
        sigma1, sigma2: covariance matrices (promoted to 2-D).
        eps: added to both diagonals if sqrt(sigma1 @ sigma2) is not finite.

    Returns:
        The scalar d^2.

    Raises:
        ValueError: if the matrix square root has a non-negligible imaginary part.
    """
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

    mean_diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)

    # A (near-)singular product can make the square root non-finite: retry with jitter.
    if not np.isfinite(covmean).all():
        print((
            "fid calculation produces singular product; "
            "adding %s to diagonal of cov estimates"
        ) % eps)
        jitter = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + jitter).dot(sigma2 + jitter))

    # Round-off may leave a tiny imaginary part; accept it only when it is negligible.
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            raise ValueError("Imaginary component {}".format(np.max(np.abs(covmean.imag))))
        covmean = covmean.real

    return mean_diff.dot(mean_diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)

In [None]:
# Transform converting [-1, 1] single-channel image tensors into the InceptionV3 input format.
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.models.inception import Inception_V3_Weights

transform_inception = transforms.Compose([
    transforms.Lambda(lambda x: (x + 1.0)/2.0), # [-1, 1] => [0, 1]
    transforms.ToPILImage(), # Tensor to PIL Image
    transforms.Resize((299, 299)),
    transforms.Grayscale(num_output_channels=3),  # Convert to RGB format
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalization

])

fid_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load InceptionV3 as a feature extractor (classifier head removed). It gets its own
# name so it does not shadow the diffusion `model` used earlier.
inception = models.inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1)
inception.fc = nn.Identity()

# Generated data: the sampling cell stores uint8 pixels, one image per CSV row.
# CustomMnistDataset only wraps torchvision MNIST, so read the CSV directly and map to [-1, 1].
gen_pixels = pd.read_csv(cfg.generated_csv_path).to_numpy(dtype=np.float32)
gen_imgs = torch.from_numpy(gen_pixels).reshape(-1, cfg.in_channels, cfg.img_size, cfg.img_size)
gen_imgs = gen_imgs / 255.0 * 2.0 - 1.0
gen_dl = DataLoader(gen_imgs, cfg.batch_size // 4, shuffle=False)
mu1, sigma1 = calculate_activation_statistics(gen_dl, inception, preprocess=transform_inception, device=fid_device)

# Real data: the same number of images from the MNIST test split
# (there is no test CSV in CONFIG, so torchvision's test split is used).
test_ds = CustomMnistDataset(root="./data", train=False, num_datapoints=cfg.num_img_to_generate)
test_dl = DataLoader(test_ds, cfg.batch_size // 4, shuffle=False)
mu2, sigma2 = calculate_activation_statistics(test_dl, inception, preprocess=transform_inception, device=fid_device)

# Calculate FID. NOTE: with only cfg.num_img_to_generate samples the 2048 x 2048
# covariances have rank < 2048, so treat the score as a rough, biased estimate.
fid = calculate_frechet_distance(mu1, sigma1, mu2, sigma2)
print(f'FID-Score: {fid}')