SAMPLING

In [None]:
from typing import Dict, Tuple
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation , PillowWriter
import numpy as np
from IPython.display import HTML
from diffusion_utilities import *


SETTING THINGS UP

In [None]:
class ContextUnet(nn.Module):
    def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28):
        """Build the layers of the context-conditioned U-Net.

        Args:
            in_channels: Number of channels of the input image; the final
                convolution maps back to this many channels.
            n_feat: Number of intermediate feature maps. It is also the
                channel count normalized by ``GroupNorm(8, n_feat)``, so it
                must be divisible by 8.
            n_cfeat: Size of the context vector (the number of classes).
            height: Image height; the image is assumed to be square
                (h == w) and the height divisible by 4.
        """
        super().__init__()

        # Number of input channels, intermediate features and classes.
        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_cfeat = n_cfeat
        # Assume h == w; must be divisible by 4 (e.g. 28, 24, 20, 16, ...).
        self.h = height

        # Initial convolutional layer.
        # NOTE(review): ResidualConvBlock, UnetDown and Embed are not defined
        # in this file; presumably they come from
        # `from diffusion_utilities import *` -- confirm the names exist there.
        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)

        # Down-sampling path of the U-Net with two levels.
        # NOTE(review): output shapes depend on UnetDown, which is not visible
        # here -- presumably each level halves the spatial size; confirm.
        self.down1 = UnetDown(n_feat, n_feat)
        self.down2 = UnetDown(n_feat, 2 * n_feat)

        # Average-pool the feature maps and apply GELU.
        # Fixed: `nn.Seuqntial` was a typo for `nn.Sequential` and raised
        # AttributeError as soon as the model was constructed.
        # NOTE(review): the pooling window 7 equals 28 // 4 for the default
        # height; presumably it should track self.h // 4 for other heights --
        # confirm against UnetDown's down-sampling factor before changing.
        self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())

        # Embed the timestep and the context label; the visible arguments are
        # (input size, embedding size).
        self.timeembed1 = Embed(1, 2 * n_feat)
        self.timeembed2 = Embed(1, 1 * n_feat)
        self.contextembed1 = Embed(n_cfeat, 2 * n_feat)
        self.contextembed2 = Embed(n_cfeat, 1 * n_feat)

        # Final convolutional layers: map back to the same number of channels
        # as the input image.
        self.out = nn.Sequential(
            # Reduce the number of feature maps:
            # (in_channels, out_channels, kernel_size=3, stride=1, padding=1).
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
            nn.GroupNorm(8, n_feat),  # normalize over 8 channel groups
            nn.ReLU(),
            # Map to the same number of channels as the input.
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
        )