# Lab 4 - Fast Sampling

DeepLearning.AI - How Diffusion Models Work

[REF] https://learn.deeplearning.ai/courses/diffusion-models/

In [1]:
from typing import Dict, Tuple
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np
from IPython.display import HTML
from diffusion_utilities import *

import warnings
warnings.filterwarnings('ignore')

# Setting Things Up

In [2]:
# Build the ContextUnet that every sampling cell below reuses, then move it to
# the compute device.
# NOTE(review): n_feat, n_cfeat, height and device are not defined in any
# visible cell; presumably they come from `from diffusion_utilities import *`
# -- confirm, and prefer defining them explicitly in a config cell.
nn_model = ContextUnet(
    in_channels=3,
    n_feat=n_feat,
    n_cfeat=n_cfeat,
    height=height,
)
nn_model = nn_model.to(device)

# Fast Sampling (DDIM)

In [3]:
# Restore the checkpoint for the model without context (see the message printed
# below) and put the network in evaluation mode before sampling.
# NOTE(review): depending on the installed torch version, torch.load may
# unpickle arbitrary objects -- only load checkpoints from a trusted source.
no_context_ckpt = f"{save_dir}/L4_model_31.pth"
nn_model.load_state_dict(
    torch.load(no_context_ckpt, map_location=device)
)
nn_model.eval()
print("Loaded in Model without context")

Loaded in Model without context


In [4]:
# Run the DDIM sampler with the currently loaded weights and render the
# intermediate results as an inline JS/HTML animation.
# The literal 32 is passed to both the sampler and plot_sample (presumably the
# number of samples) and 4 is presumably the number of grid rows -- confirm
# against diffusion_utilities. n=25 matches the 25 frames reported in the cell
# output.
# NOTE(review): plt.clf() creates a blank figure when none is open, which is
# presumably the source of the "<Figure size 640x480 with 0 Axes>" output.
plt.clf()
samples, intermediate = sample_ddim(nn_model, 32, n=25)
animation_ddim = plot_sample(
    intermediate, 32, 4, save_dir, "ani_run", None, save=False
)
# Keep this expression last so the notebook displays the animation.
HTML(animation_ddim.to_jshtml())

gif animating frame 24 of 25

<Figure size 640x480 with 0 Axes>

In [5]:
# Overwrite the network's weights with the context-model checkpoint (see the
# message printed below) and put it in evaluation mode; every later cell
# samples from these weights.
# NOTE(review): depending on the installed torch version, torch.load may
# unpickle arbitrary objects -- only load checkpoints from a trusted source.
context_ckpt = f"{save_dir}/L4_context_model_31.pth"
nn_model.load_state_dict(
    torch.load(context_ckpt, map_location=device)
)
nn_model.eval()
print("Loaded in Context Model")

Loaded in Context Model


In [6]:
# Run the DDIM sampler conditioned on one random one-hot context vector per
# sample, then render the intermediate results as an inline animation.
#
# Changes from the previous version of this cell:
#   * the animation variable is named `animation_ddim_context`; it was called
#     `animation_ddpm_context` even though the samples come from the DDIM
#     sampler (and the unconditioned cell already uses `animation_ddim`);
#   * the repeated literals 32 and 5 are named once so the context batch, the
#     sampler call and the plot cannot drift apart.
plt.clf()

num_samples = 32  # used for the context batch, the sampler and the plot
num_classes = 5   # one-hot width; presumably must equal the model's n_cfeat -- TODO confirm

# Draw the class ids on the CPU first (as before) so the default CPU RNG stream
# is consumed exactly as in the original cell, then move the one-hot floats to
# the device.
class_ids = torch.randint(0, num_classes, (num_samples,))
ctx = F.one_hot(class_ids, num_classes).to(device=device).float()

samples, intermediate = sample_ddim_context(nn_model, num_samples, ctx)
# 4 is presumably the number of grid rows -- confirm against plot_sample.
animation_ddim_context = plot_sample(
    intermediate, num_samples, 4, save_dir, "ani_run", None, save=False
)
# Keep this expression last so the notebook displays the animation.
HTML(animation_ddim_context.to_jshtml())

gif animating frame 19 of 20

<Figure size 640x480 with 0 Axes>

#### Compare DDPM and DDIM sampling speed

In [7]:
%timeit -r 3 -n 10 sample_ddim(nn_model, 32, )
%timeit -r 3 -n 10 sample_ddpm(nn_model, 32, )

117 ms ± 3.95 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
3.48 s ± 328 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
