# Lab 1 - Sampling

DeepLearning.AI - How Diffusion Models Work

[REF] https://learn.deeplearning.ai/courses/diffusion-models/

In [1]:
from typing import Dict, Tuple
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np
from IPython.display import HTML
from diffusion_utilities import *

import warnings
warnings.filterwarnings('ignore')

# Setting Things Up

In [2]:
# Build the noise-prediction network (ContextUnet, 3 input channels) and move it
# onto `device`. ContextUnet, n_feat, n_cfeat, height and device are not defined
# in this section — presumably they come from `from diffusion_utilities import *`
# (confirm); on a fresh kernel this cell raises NameError if they do not.
nn_model = ContextUnet(
    in_channels=3,
    n_feat=n_feat,
    n_cfeat=n_cfeat,
    height=height,
).to(device)

# Sampling

In [3]:
# Restore the trained weights into the freshly built network, then switch it to
# evaluation mode (nn.Module.eval) before sampling.
# map_location=device remaps the stored tensors onto the device this kernel uses.
# NOTE(review): torch.load unpickles the checkpoint file — only load trusted files.
nn_model.load_state_dict(
    torch.load(f"{save_dir}/L1_model_trained.pth", map_location=device)
)
nn_model.eval()
print("Loaded in Model")

Loaded in Model


In [4]:
# Generate 32 images with the standard sampler (sample_ddpm, defined outside this
# section) and animate the intermediate frames it returns.
# plot_sample args: 32 samples; 4 is presumably the grid row count — confirm.
plt.clf()
samples, intermediate_ddpm = sample_ddpm(nn_model, 32)
animation_ddpm = plot_sample(
    intermediate_ddpm, 32, 4, save_dir, "ani_run", None, save=False
)
HTML(animation_ddpm.to_jshtml())  # last expression -> rendered inline

gif animating frame 31 of 32

<Figure size 640x480 with 0 Axes>

#### Demonstrate incorrect sampling without adding the 'extra noise'

In [5]:
# incorrectly sample without adding in noise
@torch.no_grad()
def sample_ddpm_incorrect(nn_model, n_sample, save_rate=20):
    """Run the DDPM reverse loop *without* re-injecting noise (deliberately wrong).

    Structured like a normal reverse-diffusion sampler, except that the extra
    noise term ``z`` passed to ``denoise_add_noise`` is always 0. It exists to
    show the effect of dropping that stochastic term.

    Relies on globals not defined in this cell: ``timesteps``, ``height``,
    ``device`` and ``denoise_add_noise`` (presumably provided by
    ``diffusion_utilities`` — confirm).

    Args:
        nn_model: noise-prediction network, called as ``nn_model(x_t, t)`` where
            ``t`` is a (1, 1, 1, 1) tensor holding ``i / timesteps``.
        n_sample: number of images to generate.
        save_rate: keep a snapshot every ``save_rate`` timesteps (default 20,
            the value that used to be hard-coded). The first step
            (``i == timesteps``) and the final 7 steps (``i < 8``) are always kept.

    Returns:
        ``(samples, intermediate)``: ``samples`` is the final tensor of the loop,
        which starts from Gaussian noise of shape (n_sample, 3, height, height)
        on ``device``; ``intermediate`` is a NumPy array stacking the saved
        snapshots along a new leading axis, in generation order.

    Raises:
        ValueError: if ``save_rate`` is not positive (it is used as a modulus).
    """
    if save_rate <= 0:
        raise ValueError(f"save_rate must be a positive integer, got {save_rate!r}")

    # x_T ~ N(0, 1), sample initial noise
    samples = torch.randn(n_sample, 3, height, height).to(device)

    # snapshots of the generation process, used for plotting/animation
    intermediate = []
    for i in range(timesteps, 0, -1):
        print(f'sampling timestep {i:3d}', end='\r')

        # normalized time i / timesteps, reshaped to (1, 1, 1, 1) so it broadcasts over the batch
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)

        # the deliberate mistake: don't add back in the "extra noise"
        z = 0

        eps = nn_model(samples, t)    # predict noise e_(x_t,t)
        samples = denoise_add_noise(samples, i, eps, z)
        if i % save_rate == 0 or i == timesteps or i < 8:
            intermediate.append(samples.detach().cpu().numpy())

    intermediate = np.stack(intermediate)
    return samples, intermediate

In [6]:
# Same visualization as for the standard sampler, but driven by the no-extra-noise
# sampler above so the two animations can be compared side by side.
# plot_sample args: 32 samples; 4 is presumably the grid row count — confirm.
plt.clf()
samples, intermediate = sample_ddpm_incorrect(nn_model, 32)
animation = plot_sample(
    intermediate, 32, 4, save_dir, "ani_run", None, save=False
)
HTML(animation.to_jshtml())  # last expression -> rendered inline

gif animating frame 31 of 32

<Figure size 640x480 with 0 Axes>