In [1]:
import os
from typing import Dict, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import HTML
from matplotlib.animation import FuncAnimation, PillowWriter
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.utils import save_image, make_grid
from tqdm import tqdm
# from diffusion_utilities import */

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class EmbedFC(nn.Module):
    """Linear projection followed by a ReLU; embeds timestep and context vectors.

    Kept from the original code for compatibility: the ``fc`` attribute name
    determines the state-dict keys (``fc.weight`` / ``fc.bias``).
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        projected = self.fc(x)
        return torch.relu(projected)

class VAE(nn.Module):
    """Convolutional VAE whose sampled latent is shifted by time and context embeddings.

    The encoder halves the spatial size twice (stride-2 convs) and predicts a mean and a
    log-variance per latent dimension. The decoder mirrors it back to ``image_size`` and
    ends in a Sigmoid, so its output lies in [0, 1].

    Args:
        in_channels: number of image channels.
        n_feat: channels of the first conv layer; the second uses ``2 * n_feat``.
        n_cfeat: length of the context vector ``c``.
        latent_dim: size of the latent vector ``z``.
        image_size: height/width of the square input images; must be divisible by 4.
            Defaults to 28, which reproduces the previously hard-coded 7x7 bottleneck.

    Raises:
        ValueError: if ``image_size`` is not divisible by 4.
    """

    def __init__(self, in_channels, n_feat=256, n_cfeat=10, latent_dim=50, image_size=28):
        super(VAE, self).__init__()
        if image_size % 4 != 0:
            raise ValueError(f"image_size must be divisible by 4, got {image_size}")

        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_cfeat = n_cfeat
        self.latent_dim = latent_dim
        self.image_size = image_size

        # spatial size after two stride-2 downsamplings (7 for 28x28, 4 for 16x16)
        bottleneck = image_size // 4
        flat_dim = n_feat * 2 * bottleneck * bottleneck

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, n_feat, 4, 2, 1),  # [batch, n_feat, image_size/2, image_size/2]
            nn.ReLU(),
            nn.Conv2d(n_feat, n_feat * 2, 4, 2, 1),  # [batch, n_feat*2, bottleneck, bottleneck]
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(flat_dim, latent_dim * 2),  # doubled: mean and log-variance
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, flat_dim),
            nn.ReLU(),
            nn.Unflatten(1, (n_feat * 2, bottleneck, bottleneck)),
            nn.ConvTranspose2d(n_feat * 2, n_feat, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(n_feat, in_channels, 4, 2, 1),
            nn.Sigmoid(),  # Ensure output is in [0,1]
        )

        # Embedding layers for context and time, similar to original
        self.contextembed = EmbedFC(n_cfeat, latent_dim)
        self.timeembed = EmbedFC(1, latent_dim)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, t, c=None):
        """Encode ``x``, add time/context embeddings to the sampled latent, and decode.

        Args:
            x: images, shape (batch, in_channels, image_size, image_size).
            t: timesteps, one value per sample or a single value shared by the batch;
                any shape holding that many elements, e.g. (B,), (B, 1) or (1, 1, 1, 1).
            c: context vectors, shape (batch, n_cfeat); zeros when ``None``.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        # Encode
        encoded = self.encoder(x)
        mu, logvar = torch.chunk(encoded, 2, dim=1)  # Split the encoder output into mu and logvar
        z = self.reparameterize(mu, logvar)

        # Embed time and context
        if c is None:
            c = torch.zeros(x.shape[0], self.n_cfeat, device=x.device, dtype=z.dtype)
        # EmbedFC(1, latent_dim) needs a trailing feature dim of size 1; a single
        # shared timestep becomes (1, 1) and broadcasts over the batch below.
        t = t.reshape(-1, 1).to(z.dtype)
        temb = self.timeembed(t)
        cemb = self.contextembed(c)

        # Add embeddings to latent space
        z = z + temb + cemb

        # Decode
        return self.decoder(z), mu, logvar

In [3]:
# hyperparameters

# reproducibility: seed torch's global RNG before any weights, noise or timesteps are drawn
SEED = 0
torch.manual_seed(SEED)

# diffusion hyperparameters
timesteps = 500
beta1 = 1e-4
beta2 = 0.02

# network hyperparameters
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
n_feat = 64 # 64 hidden dimension feature
n_cfeat = 5 # context vector is of size 5
height = 16 # 16x16 image
save_dir = './weights/'

# training hyperparameters
batch_size = 100
n_epoch = 32
lrate=1e-3

In [4]:
# construct DDPM noise schedule (timesteps + 1 entries, index 0..timesteps):
#   b_t  - linear ramp from beta1 to beta2
#   a_t  - 1 - b_t
#   ab_t - running product of a_t, accumulated in log space, with ab_t[0] pinned to 1
b_t = beta1 + (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device)
a_t = 1 - b_t
ab_t = a_t.log().cumsum(dim=0).exp()
ab_t[0] = 1

In [7]:
# construct the model (3 input channels) and move its parameters to `device`
# NOTE(review): confirm the VAE's spatial geometry matches `height` (16x16 images here)
nn_model = VAE(
    in_channels=3,
    n_feat=n_feat,
    n_cfeat=n_cfeat,
)
nn_model = nn_model.to(device)

In [8]:
# (re)create the AdamW optimizer over every model parameter at the base learning rate
optim = torch.optim.AdamW(params=nn_model.parameters(), lr=lrate)

In [None]:
# training with context code
# NOTE(review): this cell uses `dataloader` and `perturb_input`, which are defined in
# later cells of this notebook — run those cells before this one.
# set into train mode
nn_model.train()

for ep in range(n_epoch):
    print(f'epoch {ep}')

    # linearly decay learning rate
    optim.param_groups[0]['lr'] = lrate * (1 - ep / n_epoch)

    pbar = tqdm(dataloader, mininterval=2)
    for x, c in pbar:   # x: images  c: context
        optim.zero_grad()
        x = x.to(device)
        c = c.to(x)  # match x's device and dtype

        # randomly mask out c: each sample keeps its context with probability 0.9
        context_mask = torch.bernoulli(torch.zeros(c.shape[0]) + 0.9).to(device)
        c = c * context_mask.unsqueeze(-1)

        # perturb data to a random timestep in [1, timesteps]
        noise = torch.randn_like(x)
        t = torch.randint(1, timesteps + 1, (x.shape[0],)).to(device)
        x_pert = perturb_input(x, t, noise)

        # use network to recover noise. The model returns (output, mu, logvar), and its
        # time embedding (EmbedFC(1, ...)) needs a trailing feature dim of size 1.
        pred_noise, _, _ = nn_model(x_pert, (t / timesteps)[:, None], c=c)

        # loss is mean squared error between the predicted and true noise
        # NOTE(review): the VAE decoder ends in nn.Sigmoid (output in [0, 1]) while `noise`
        # is standard normal, so this loss cannot reach zero — confirm the intended target.
        loss = F.mse_loss(pred_noise, noise)
        loss.backward()

        optim.step()

    # save model periodically
    if ep % 4 == 0 or ep == n_epoch - 1:
        os.makedirs(save_dir, exist_ok=True)
        ckpt_path = save_dir + f"context_model_{ep}.pth"
        torch.save(nn_model.state_dict(), ckpt_path)
        print('saved model at ' + ckpt_path)

In [9]:
# hyperparameters and noise schedule, re-declared with the same values as the earlier cells
# -- diffusion --
timesteps, beta1, beta2 = 500, 1e-4, 0.02
# -- network --
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
n_feat = 64    # hidden feature width
n_cfeat = 5    # context vector length
height = 16    # images are height x height (16x16)
save_dir = './weights/'
# -- training --
batch_size, n_epoch, lrate = 100, 32, 1e-3

# DDPM noise schedule: linear b_t, a_t = 1 - b_t, ab_t = running product of a_t
# (accumulated in log space), with ab_t[0] pinned to 1
b_t = beta1 + (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device)
a_t = 1 - b_t
ab_t = a_t.log().cumsum(dim=0).exp()
ab_t[0] = 1

In [None]:
 # load dataset and construct optimizer
# NOTE(review): `CustomDataset` and `transform` are not defined in this notebook's visible
# cells (the `diffusion_utilities` import is commented out) — provide them before running.
dataset = CustomDataset(
    "./sprites_1788_16x16.npy",
    "./sprite_labels_nc_1788_16x16.npy",
    transform,
    null_context=False,
)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=12,
)
optim = torch.optim.AdamW(params=nn_model.parameters(), lr=lrate)

In [None]:
# helper: forward-diffuse a batch of images to the given timesteps
def perturb_input(x, t, noise):
    """Return sqrt(ab_t[t]) * x + sqrt(1 - ab_t[t]) * noise.

    ``t`` indexes the module-level ``ab_t`` schedule. The selected values get three
    trailing singleton dims so they broadcast against ``x`` (presumably (B, C, H, W) —
    TODO confirm).
    """
    signal_level = ab_t[t, None, None, None]
    return signal_level.sqrt() * x + (1 - signal_level).sqrt() * noise

In [None]:
# training without context code
# set into train mode
nn_model.train()
for ep in range(n_epoch):
    print(f'epoch {ep}')
    # linearly decay learning rate
    optim.param_groups[0]['lr'] = lrate * (1 - ep / n_epoch)
    pbar = tqdm(dataloader, mininterval=2)
    for x, _ in pbar:  # x: images (labels are ignored)
        optim.zero_grad()
        x = x.to(device)
        # perturb data to a random timestep in [1, timesteps]
        noise = torch.randn_like(x)
        t = torch.randint(1, timesteps + 1, (x.shape[0],)).to(device)
        x_pert = perturb_input(x, t, noise)
        # use network to recover noise. The model returns (output, mu, logvar), and its
        # time embedding (EmbedFC(1, ...)) needs a trailing feature dim of size 1.
        pred_noise, _, _ = nn_model(x_pert, (t / timesteps)[:, None])
        # loss is mean squared error between the predicted and true noise
        # NOTE(review): the VAE decoder ends in nn.Sigmoid (output in [0, 1]) while `noise`
        # is standard normal, so this loss cannot reach zero — confirm the intended target.
        loss = F.mse_loss(pred_noise, noise)
        loss.backward()
        optim.step()
    # save model periodically
    if ep % 4 == 0 or ep == n_epoch - 1:
        os.makedirs(save_dir, exist_ok=True)
        ckpt_path = save_dir + f"model_{ep}.pth"
        torch.save(nn_model.state_dict(), ckpt_path)
        print('saved model at ' + ckpt_path)

In [None]:
def denoise_add_noise(x, t, pred_noise, z=None):
    """Take one DDPM reverse step from x_t towards x_{t-1}.

    Removes ``pred_noise`` using the module-level schedule (``a_t``, ``ab_t``) at integer
    step ``t``, then adds ``z`` scaled by sqrt(b_t[t]). ``z`` defaults to fresh standard
    normal noise shaped like ``x``; passing ``z=0`` drops the stochastic term.
    """
    if z is None:
        z = torch.randn_like(x)
    alpha = a_t[t]
    eps_scale = (1 - alpha) / (1 - ab_t[t]).sqrt()
    mean = (x - pred_noise * eps_scale) / alpha.sqrt()
    return mean + b_t.sqrt()[t] * z

In [None]:
# sample using standard algorithm
@torch.no_grad()
def sample_ddpm(n_sample, save_rate=20):
    """Generate ``n_sample`` images by running the DDPM reverse process.

    Starts from standard normal noise of shape (n_sample, 3, height, height) and applies
    ``denoise_add_noise`` for i = timesteps, ..., 1, using ``nn_model`` as the noise predictor.

    Args:
        n_sample: number of images to generate.
        save_rate: keep a snapshot whenever ``i % save_rate == 0``, plus at the first step
            (i == timesteps) and the final 7 steps (i < 8), for plotting.

    Returns:
        ``(samples, intermediate)``: the final tensor, and a numpy array of the stacked
        snapshots with shape (n_snapshots, n_sample, 3, height, height).
    """
    # x_T ~ N(0, 1), sample initial noise
    samples = torch.randn(n_sample, 3, height, height).to(device)
    # list of snapshots for plotting; stacked into one array after the loop
    intermediate = []
    for i in range(timesteps, 0, -1):
        print(f'sampling timestep {i:3d}', end='\r')
        # normalized timestep shaped (1, 1): one feature, broadcast over the batch
        t = torch.tensor([[i / timesteps]]).to(device)
        # sample some random noise to inject back in. For i = 1, don't add back in noise
        z = torch.randn_like(samples) if i > 1 else 0
        # predict noise e_(x_t, t); the model returns (prediction, mu, logvar)
        eps, _, _ = nn_model(samples, t)
        samples = denoise_add_noise(samples, i, eps, z)
        if i % save_rate == 0 or i == timesteps or i < 8:
            intermediate.append(samples.detach().cpu().numpy())
    # stack once, after the loop (stacking inside the loop turned the list into an
    # ndarray, which has no .append)
    intermediate = np.stack(intermediate)
    return samples, intermediate

In [None]:
# load in model weights and set to eval mode
# The checkpoint name follows the no-context training loop's save pattern
# (save_dir + f"model_{ep}.pth") for its final epoch, ep = n_epoch - 1.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
ckpt_path = save_dir + f"model_{n_epoch - 1}.pth"
nn_model.load_state_dict(torch.load(ckpt_path, map_location=device))
nn_model.eval()
print("Loaded in Model")
# visualize samples
# NOTE(review): `plot_sample` is not defined in this notebook's visible cells (the
# `diffusion_utilities` import is commented out) — provide it before running.
plt.clf()
samples, intermediate_ddpm = sample_ddpm(32)
animation_ddpm = plot_sample(intermediate_ddpm, 32, 4, save_dir, "ani_run", None, save=False)
HTML(animation_ddpm.to_jshtml())