# Dataset & Baseline

In [None]:
############### PICK ONE ################
#dataroot = "/scratch/<BU user name>/img_align_celeba/" # if on SCC
dataroot = "../score_sde_dev/img_align_celeba/" # if on colab

In [None]:
from __future__ import print_function
import os, math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm as tqdm_class
import torchvision
from PIL import Image
from copy import deepcopy
from model import UNet

# The CelebA dataset contains 40 binary attribute labels for each image
attributes = ['5_o_Clock_Shadow', 'Arched_Eyebrows', 
 'Attractive', 'Bags_Under_Eyes', 'Bald', 'Bangs', 
 'Big_Lips', 'Big_Nose', 'Black_Hair', 
 'Blond_Hair', 'Blurry', 'Brown_Hair', 
 'Bushy_Eyebrows', 'Chubby', 'Double_Chin', 
 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 
 'High_Cheekbones', 'Male', 'Mouth_Slightly_Open', 
 'Mustache', 'Narrow_Eyes', 'No_Beard', 'Oval_Face', 
 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 
 'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 
 'Wavy_Hair', 'Wearing_Earrings', 'Wearing_Hat', 
 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 
 'Young']

def set_random_seed(seed=999):
    # Set random seed for reproducibility
    print("Random Seed: ", seed)
    random.seed(seed)
    torch.manual_seed(seed)

In [None]:
class CelebADataset(torch.utils.data.Dataset):
    def __init__(self, transform = None):
        '''Initialize the dataset.'''
        self.transform = transform
        self.root = dataroot
        self.attr_txt = dataroot + 'list_attr_celeba.txt'
        self._parse()
    
    def _parse(self):
        '''
        Parse the celeba text file.
        Pupulate the following private variables:
         - self.ys: A list of 1D tensors with 40 binary attribute labels.
         - self.im_paths: A list of strings (image paths).
        '''
        self.im_paths = [] # list of jpeg filenames 
        self.ys = []       # list of attribute labels
        
        def _to_binary(lst):
            return torch.tensor([0 if lab == '-1' else 1 for lab in lst])
            
        with open(self.attr_txt) as f:
            for line in f:
                assert len(line.strip().split()) == 41
                fl = line.strip().split()
                if fl[0][-4:] == '.jpg': # if not header
                    self.im_paths.append(self.root + fl[0]) # jpeg filename
                    self.ys.append(_to_binary(fl[1:]))      # 1D tensor of 40 binary attributes
        
    def __len__(self):
        '''Return length of the dataset.'''
        return len(self.ys)

    def __getitem__(self, index):
        '''
        Return the (image, attributes) tuple.
        This function gets called when you index the dataset.
        '''
        def img_load(index):
            imraw = Image.open(self.im_paths[index])
            im = self.transform(imraw)
            return im

        target = self.ys[index]
        return img_load(index), target

In [None]:
class Diffusion:
    ''' 
    Implements the Diffusion process,
    including both training and sampling.
    '''
    def __init__(self, num_timesteps=1000, beta_start=1e-4, beta_end=0.02, img_size=64, device = 'cuda'):
        self.num_timesteps = num_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device

        self.beta = torch.linspace(beta_start,beta_end,num_timesteps).to(device)
        self.alpha = 1 - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
    def get_noisy_image(self, x_0, t):
        '''
        This function is only used for training.
        '''     
        
        eps = torch.randn_like(x_0).to(self.device)
        alpha_t = self.alpha[t]
        alpha_bar_t = self.alpha_bar[t]
        x_t = torch.sqrt(alpha_bar_t)[:,None,None,None] * x_0 + torch.sqrt(1 - alpha_bar_t)[:,None,None,None] * eps
        return (x_t, eps)
    def sample(self, model, n, y=None):
        '''
        This function is used  to generate images.
        '''
        model.eval()
        
        with torch.no_grad():

            T = self.num_timesteps
            H = self.img_size
            W = self.img_size
            x_T = torch.randn(n, 3, H, W, device=self.device)
            for k in range(T):
                t = T-k-1
                if t == (T-1):
                  x_t = x_T
                
                epsilon = model(x_t, (torch.ones(n)*t).long().to(self.device), y)
                
                mu = (1/torch.sqrt(self.alpha[t]))*(x_t - self.beta[t]*epsilon/torch.sqrt(1-self.alpha_bar[t]))
                if t>0:
                  x_t = torch.randn_like(mu, device=self.device)*torch.sqrt(self.beta[t]) + mu
            x=x_t
        model.train()
        x = (x.clamp(-1, 1) + 1) / 2
        x = (x * 255).type(torch.uint8)
        return x
    
def show_images(images, **kwargs):
    plt.figure(figsize=(10, 10), dpi=80)
    grid = torchvision.utils.make_grid(images, **kwargs)
    ndarr = grid.permute(1, 2, 0).to('cpu').numpy()
    im = Image.fromarray(ndarr)
    plt.imshow(im)
    plt.show()

In [None]:
class EMA:
    '''
    This class implements the Expontial Moving Average (EMA) for model weights.
    Only used for evaluation.
    Using the EMA averaged model increases the quality of generated images.
    '''
    def __init__(self, beta=0.995):
        '''
        beta is a hyperparameter.
        New model weights = beta * (old model weights) + 
                            (1 - beta) * (new model weights)
        '''
        super().__init__()
        self.beta = beta

    def step_ema(self, ma_model, current_model):
        '''
        ma_model: the averaged model we will use for evaluation
        current_model: The model being explicitly trained
        This function updates the weights of ma_model. Return None.
        '''
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        '''Private function used to update individual parameters.'''
        return old * self.beta + (1 - self.beta) * new

In [None]:
# We will resize to 64 x 64 for this assignment
image_size = 64

# Hyperparameters
batch_size = 64
learning_rate = 0.0002
weight_decay = 0.00001 # (L2 penalty)

# Transform used for training
train_transform = transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), 
                                                    (0.5, 0.5, 0.5)),
                           ])

# Make the dataset
dataset = CelebADataset(transform=train_transform)
# print(type(dataset))

# index of the binary attribute for gender
gender_index = attributes.index('Male')

# Run on TPU
device = 'cuda'

In [None]:

# Instantiate denoising autoencoder
model = UNet().to(device)

# ema_model is the averaged model that we'll use for sampling
ema_model = deepcopy(model)

# ema is the helper for updaing EMA weights
ema = EMA()

# Dataloader
trainloader = torch.utils.data.DataLoader(dataset, drop_last=True,  batch_size=batch_size, shuffle=True, num_workers=8)

# Mixed precision floating point arithmetic can speed up training on some GPUs
scaler = torch.cuda.amp.GradScaler()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Diffusion wrapper
diffusion = Diffusion(img_size=image_size, device=device)

for epoch in range(10): # jtcheck 10
    pbar = tqdm_class(trainloader)
    count =0
    for images, y in pbar:
        y = y[:,gender_index].view(-1).cuda()
        
        with torch.cuda.amp.autocast(enabled=True):
            images = images.to(device)
            

            t = torch.randint(low=1, high=diffusion.num_timesteps, size=(batch_size,)).cuda()
            x_t, noise = diffusion.get_noisy_image(images, t)
            predicted_noise = model(x_t, t, y)

            mse_loss = nn.MSELoss()
            loss = mse_loss(predicted_noise, noise)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        pbar.set_postfix(MSE=loss.item(), LR=optimizer.param_groups[0]['lr'])

        # update EMA model. First epoch of training is too noisy, 
        # so we only do this after the first epoch
        if epoch > 0:
            ema.step_ema(ema_model, model)
        
    if epoch == 0:
        ema_model = deepcopy(model)

    set_random_seed() # set random seed to generate the same style face. This is handy for comparing across epochs.
    # n is number of images you want to generate
    sampled_images = diffusion.sample(ema_model, n=8, y=torch.tensor([0,0,0,0,1,1,1,1]).cuda())
    
   
    show_images(sampled_images)
    
    
torch.save((ema_model.state_dict(), model.state_dict()), 'ddpm.pt')

# Score-based SDE

In [8]:

import functools

device = 'cuda'

def marginal_prob_std(t, sigma):
  """Compute the mean and standard deviation of $p_{0t}(x(t) | x(0))$.
  Returns The standard deviation.
  """    
  t = torch.tensor(t, device=device)
  # print(t.shape)
  return torch.sqrt((sigma**(2 * t) - 1.) / 2. / np.log(sigma))

def diffusion_coeff(t, sigma):
  """Compute the diffusion coefficient of our SDE.
     returns the vector of diffusion coefficients.
  """
  return torch.tensor(sigma**t, device=device)
  
sigma =  25.0 #@param {'type':'number'}
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)

In [9]:

#@title Define the loss function (double click to expand or collapse)

def loss_fn(model, x, marginal_prob_std, eps=1e-5):
  """The loss function for training score-based generative models.
  """
  random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
  z = torch.randn_like(x)
  
  std = marginal_prob_std(random_t)

  perturbed_x = x + z * std[:, None, None, None]
  score = model(perturbed_x, random_t)
  loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1,2,3)))
  return loss

In [11]:
import torch
import functools
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
import tqdm.notebook

device = 'cuda'
# score_model = torch.nn.DataParallel(ScoreNet(marginal_prob_std=marginal_prob_std_fn))
score_model = torch.nn.DataParallel(UNet(marginal_prob_std = marginal_prob_std_fn))
score_model = score_model.to(device)

n_epochs =   50#
## size of a mini-batch
batch_size =  32 #
## learning rate
lr = 1e-4  #
 

trainloader = torch.utils.data.DataLoader(dataset, drop_last=True,  batch_size=batch_size, shuffle=True, num_workers=8)


optimizer = Adam(score_model.parameters(), lr=lr)
tqdm_epoch = tqdm.notebook.trange(1) # jtcheck 50
for epoch in tqdm_epoch:
  avg_loss = 0.
  num_items = 0
  for x, y in trainloader:
    x = x.to(device)   
    # print(x.shape)
    loss = loss_fn(score_model, x, marginal_prob_std_fn)
    optimizer.zero_grad()
    loss.backward()    
    optimizer.step()
    avg_loss += loss.item() * x.shape[0]
    num_items += x.shape[0]
  # Print the averaged training loss so far.
  tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
  # Update the checkpoint after each epoch of training.
  torch.save(score_model.state_dict(), 'ckpt.pth')

  0%|          | 0/1 [00:00<?, ?it/s]

  t = torch.tensor(t, device=device)
