In [1]:
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
import random
from tqdm.auto import tqdm
import imageio
import einops
import torch
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda
from torchvision.datasets import FashionMNIST
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU so
# the notebook still runs end-to-end on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
import importlib
from DDPM_gen_vis import *
from GAN_modules import * 

In [3]:
# Mini-batch size used by the training DataLoader.
batch_size = 64

# Convert each image to a tensor, then rescale it with (x - 0.5) * 2
# (this maps the interval [0, 1] onto [-1, 1]).
transform = Compose([ToTensor(), Lambda(lambda x: (x - 0.5) * 2)])

# FashionMNIST training split, stored under ./datasets (downloaded if missing).
dataset = FashionMNIST(root="./datasets", train=True, transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [4]:
class DDGAN(nn.Module):
    """Generator/discriminator pair bundled with a linear beta noise schedule.

    The schedule tensors have ``n_steps + 1`` entries: index 0 holds a
    near-zero beta (1e-8), indices ``1..n_steps`` hold
    ``linspace(min_beta, max_beta, n_steps)``.

    Attributes:
        betas, alphas: per-step schedule, ``alphas = 1 - betas``.
        alpha_bars: cumulative product of ``alphas``.
        sigmas_cum: ``sqrt(1 - alpha_bars)``.
        emb_matr: table returned by ``sinusoidal_embedding(n_steps, emb_dim)``.

    The schedule tensors are stored as plain attributes (not parameters or
    registered buffers), so they are not part of ``state_dict()``.

    Args:
        generator: generator network (moved to ``device``).
        discriminator: discriminator network (moved to ``device``).
        n_steps: number of diffusion steps.
        min_beta, max_beta: end points of the linear beta schedule.
        device: device on which the schedule tensors are created.
        image_chw: image shape, stored for reference only.
        emb_dim: width of the time-embedding table.
    """

    def __init__(self, generator, discriminator, n_steps=15,
                 min_beta=2e-1, max_beta=9e-1, device=None,
                 image_chw=(1, 28, 28), emb_dim=100):
        super().__init__()
        self.n_steps = n_steps
        self.device = device
        self.image_chw = image_chw
        self.generator = generator.to(device)
        self.discriminator = discriminator.to(device)

        # Prepend a tiny beta so that index 0 leaves the input almost untouched.
        first = torch.tensor(1e-8).to(device)
        self.betas = torch.linspace(min_beta, max_beta, n_steps).to(device)
        self.betas = torch.cat((first[None], self.betas)).to(device)
        self.alphas = 1 - self.betas
        # Running product in a single call (same form Posterior_Coefficients uses)
        # instead of a Python loop of torch.prod over growing slices.
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)
        self.sigmas_cum = (1 - self.alpha_bars) ** 0.5
        self.emb_matr = sinusoidal_embedding(n_steps, emb_dim).to(device)
        self.emb_dim = emb_dim

    def forward(self, x0, t, eta=None):
        """Noise ``x0`` directly to schedule index ``t`` in closed form.

        Args:
            x0: batch of inputs; the first dimension is the batch.
            t: integer tensor with one index into ``alpha_bars`` per sample.
            eta: optional noise with the shape of ``x0``; drawn from a standard
                normal when omitted.

        Returns:
            ``sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eta`` with ``a_bar = alpha_bars[t]``.
        """
        n = x0.shape[0]
        a_bar = self.alpha_bars[t]

        if eta is None:
            eta = torch.randn_like(x0)

        # One coefficient per sample, broadcast over all remaining dimensions.
        bshape = (n,) + (1,) * (x0.dim() - 1)
        noisy = a_bar.sqrt().reshape(bshape) * x0 + (1 - a_bar).sqrt().reshape(bshape) * eta
        return noisy

    def backward(self, x, t, n_st=None):
        """Run the generator on ``x`` with an embedded timestep and a fresh latent.

        Args:
            x: batch of (noisy) inputs.
            t: integer tensor of timestep indices, looked up in the embedding table.
            n_st: if given, a new embedding table with ``n_st`` rows is built
                instead of using the one created in ``__init__``.

        Returns:
            Whatever ``self.generator(x, t_emb, latent_z)`` returns.

        NOTE(review): the training loop in this notebook calls the generator with
        the raw integer ``t`` and a latent of shape ``(n, zsize)``, and uses its
        output as a prediction of x_0. This method passes an embedded ``t`` and an
        image-shaped latent instead - confirm it is still meant to be used.
        """
        if n_st is None:
            emb_matr = self.emb_matr
        else:
            emb_matr = sinusoidal_embedding(n_st, self.emb_dim).to(self.device)

        # The looked-up embedding is a constant input; keep it out of autograd.
        t_emb = F.embedding(t, emb_matr).detach()

        # randn_like already matches x's device and dtype, so no notebook-global
        # `device` is needed here.
        latent_z = torch.randn_like(x)

        return self.generator(x, t_emb, latent_z)

In [5]:
def extract(inp, t, shape):
    """Look up ``inp[t]`` per batch element and make it broadcastable.

    :param inp: 1-D tensor of per-step values
    :param t: 1-D integer tensor of indices into ``inp``
    :param shape: target shape; only ``shape[0]`` and ``len(shape)`` are used
    :return: tensor of shape ``(shape[0], 1, ..., 1)`` with ``len(shape)`` dims
    """
    picked = inp.gather(0, t)
    singleton_dims = (1,) * (len(shape) - 1)
    return picked.reshape(shape[0], *singleton_dims)

def q_sample(ddgan, x_start, t, *, noise=None):
    """
    Diffuse x_0 directly to schedule index t.
    :param ddgan: object exposing the ``alpha_bars`` and ``sigmas_cum`` tensors
    :param x_start: x_0
    :param t: per-sample index into the cumulative schedule
    :param noise: optional noise shaped like x_start (standard normal if omitted)
    :return: sqrt(alpha_bars[t]) * x_0 + sigmas_cum[t] * noise
    """
    eps = torch.randn_like(x_start) if noise is None else noise

    signal_scale = extract(ddgan.alpha_bars**0.5, t, x_start.shape)
    noise_scale = extract(ddgan.sigmas_cum, t, x_start.shape)

    return signal_scale * x_start + noise_scale * eps

    
def q_sample_pairs(ddgan, x_start, t):
    """
    Build a pair of consecutive noisy images for training.
    x_t comes from q_sample at index t; x_{t+1} applies one more diffusion
    step to it using the per-step schedule at index t + 1 and independent noise.
    :param ddgan: object exposing the ``alphas`` and ``betas`` tensors
    :param x_start: x_0
    :param t: time step t
    :return: x_t, x_{t+1}
    """
    # Drawn before q_sample so the order of random draws matches the original.
    step_noise = torch.randn_like(x_start)
    x_t = q_sample(ddgan, x_start, t)

    next_t = t + 1
    keep = extract(ddgan.alphas**0.5, next_t, x_start.shape)
    spread = extract(ddgan.betas**0.5, next_t, x_start.shape)
    x_t_plus_one = x_t * keep + spread * step_noise

    return x_t, x_t_plus_one

In [6]:
class Posterior_Coefficients():
    """Per-step coefficient tables derived from ``ddgan.betas``.

    The first entry of ``ddgan.betas`` is discarded, so every table here has
    one entry fewer than ``ddgan.betas``. ``device`` is where the leading 1 of
    ``alphas_cumprod_prev`` is created; it has to match the device of
    ``ddgan.betas`` for the concatenation to succeed.
    """

    def __init__(self, ddgan, device):
        # Skip the placeholder entry at index 0 and work in float32.
        betas = ddgan.betas.to(torch.float32)[1:]
        alphas = 1 - betas

        cumprod = torch.cumprod(alphas, 0)
        leading_one = torch.tensor([1.], dtype=torch.float32, device=device)
        cumprod_prev = torch.cat((leading_one, cumprod[:-1]), 0)

        one_minus_cumprod = 1 - cumprod
        one_minus_prev = 1 - cumprod_prev

        self.betas = betas
        self.alphas = alphas
        self.alphas_cumprod = cumprod
        self.alphas_cumprod_prev = cumprod_prev

        self.posterior_variance = betas * one_minus_prev / one_minus_cumprod

        self.sqrt_alphas_cumprod = torch.sqrt(cumprod)
        self.sqrt_recip_alphas_cumprod = torch.rsqrt(cumprod)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1 / cumprod - 1)

        # Weights of x_0 (coef1) and x_t (coef2) in the posterior mean.
        self.posterior_mean_coef1 = betas * torch.sqrt(cumprod_prev) / one_minus_cumprod
        self.posterior_mean_coef2 = one_minus_prev * torch.sqrt(alphas) / one_minus_cumprod

        # Clamp before the log so a zero variance does not produce -inf.
        self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(min=1e-20))
        
def sample_posterior(coefficients, x_0, x_t, t):
    """Draw a sample from the posterior built from ``x_0`` and ``x_t`` at index ``t``.

    The mean is ``posterior_mean_coef1[t] * x_0 + posterior_mean_coef2[t] * x_t``;
    Gaussian noise scaled by ``exp(0.5 * posterior_log_variance_clipped[t])`` is
    added, except for batch entries with ``t == 0``, which return the mean.

    :param coefficients: object exposing the posterior coefficient tables
    :param x_0: (predicted) clean images
    :param x_t: noisy images; expected to be 4-D because the mask is indexed
                with three trailing singleton axes
    :param t: per-sample integer indices into the coefficient tables
    :return: tensor with the shape of x_t
    """
    target_shape = x_t.shape

    x0_weight = extract(coefficients.posterior_mean_coef1, t, target_shape)
    xt_weight = extract(coefficients.posterior_mean_coef2, t, target_shape)
    mean = x0_weight * x_0 + xt_weight * x_t

    log_var = extract(coefficients.posterior_log_variance_clipped, t, target_shape)
    noise = torch.randn_like(x_t)

    # 0.0 where t == 0 (no noise on that row), 1.0 everywhere else.
    add_noise = 1 - (t == 0).type(torch.float32)

    return mean + add_noise[:, None, None, None] * torch.exp(0.5 * log_var) * noise

In [7]:
def training_loop(ddgan, loader, n_epochs, optimizerG, optimizerD, device,
                  schedulerG=None, schedulerD=None, store_path="ddgan_model.pt"):
    """Adversarially train ``ddgan.generator`` against ``ddgan.discriminator``.

    Each batch performs one discriminator update followed by one generator
    update, both with the non-saturating (softplus) GAN loss:

    * D step: real pairs ``(x_t, x_{t+1})`` from ``q_sample_pairs`` versus fake
      pairs, where the fake x_t is sampled from the posterior given the
      generator's x_0 prediction.
    * G step: a fresh ``t`` and pair are drawn and the generator is updated so
      that the discriminator scores its posterior sample as real.

    Args:
        ddgan: model exposing ``n_steps``, ``generator``, ``discriminator`` and
            the noise schedule; the generator must expose ``zsize``.
        loader: DataLoader whose batches carry the images at index 0.
        n_epochs: number of passes over ``loader``.
        optimizerG, optimizerD: optimizers for generator / discriminator.
        device: device the batches and random tensors are placed on.
        schedulerG, schedulerD: optional LR schedulers, stepped once per epoch.
        store_path: file that receives ``ddgan.state_dict()`` after every epoch.

    Returns:
        None. Per-epoch mean losses are printed.
    """

    def _set_requires_grad(module, flag):
        # Toggle gradient tracking for every parameter of `module`.
        for p in module.parameters():
            p.requires_grad = flag

    n_steps = ddgan.n_steps
    netG = ddgan.generator
    netD = ddgan.discriminator

    pos_coeff = Posterior_Coefficients(ddgan, device)

    for epoch in tqdm(range(n_epochs), desc="Training progress", colour="#00ff00"):
        print(f'Epoch {epoch+1}/{n_epochs}')

        epoch_errG = 0.0
        epoch_errD = 0.0

        for batch in tqdm(loader, leave=False, desc=f"Epoch {epoch + 1}/{n_epochs}", colour="#005500"):

            x0 = batch[0].to(device)
            n = len(x0)

            # ---------------- discriminator update ----------------
            # The generator step below freezes D. It has to be unfrozen again
            # here, otherwise D receives no gradients after the very first batch.
            _set_requires_grad(netD, True)
            netD.zero_grad()

            t = torch.randint(0, n_steps, (n,), device=device)
            x_t, x_tp1 = q_sample_pairs(ddgan, x0, t)

            # train D with real
            D_real = netD(x_t, t, x_tp1.detach()).view(-1)
            errD_real = (F.softplus(-D_real)).mean()
            errD_real.backward()

            # train D with fake from G; the generator is not updated in this
            # step, so its forward pass does not need an autograd graph.
            with torch.no_grad():
                latent_z = torch.randn(n, netG.zsize, device=device)
                x_0_predict = netG(x_tp1, t, latent_z)
                x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)

            output = netD(x_pos_sample, t, x_tp1.detach()).view(-1)

            errD_fake = (F.softplus(output)).mean()
            errD_fake.backward()

            errD = errD_real.detach() + errD_fake.detach()
            optimizerD.step()

            # ---------------- generator update ----------------
            # Freeze D so the generator loss does not accumulate gradients in it.
            _set_requires_grad(netD, False)
            netG.zero_grad()

            t = torch.randint(0, n_steps, (n,), device=device)

            x_t, x_tp1 = q_sample_pairs(ddgan, x0, t)

            latent_z = torch.randn(n, netG.zsize, device=device)

            x_0_predict = netG(x_tp1.detach(), t, latent_z)
            x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)

            output = netD(x_pos_sample, t, x_tp1.detach()).view(-1)

            errG = (F.softplus(-output)).mean()

            errG.backward()
            optimizerG.step()

            # Leave D trainable again so it is never returned in a frozen state.
            _set_requires_grad(netD, True)

            # Sample-weighted running means over the whole dataset.
            epoch_errG += errG.detach() * n / len(loader.dataset)
            epoch_errD += errD.detach() * n / len(loader.dataset)

        if schedulerD is not None:
            schedulerD.step()
        if schedulerG is not None:
            schedulerG.step()

        log_string = f"G loss: {epoch_errG:.4f}, D loss: {epoch_errD:.4f}"

        # Storing the model
        torch.save(ddgan.state_dict(), store_path)
        print(log_string)
        print('-' * 75)
    torch.cuda.empty_cache()

In [8]:
from CustomizableCosineDecayScheduler import CosineDecayWithWarmUpScheduler as CD_scheduler
# Diffusion schedule settings used when the models are built below.
n_steps = 15
min_beta = 2e-1
max_beta = 9e-1

In [None]:
# Drop models left over from a previous run of this cell so their memory can be
# released before new ones are allocated. Missing names are simply skipped.
for _stale_name in ("generator", "discriminator", "ddgan"):
    globals().pop(_stale_name, None)
torch.cuda.empty_cache()


generator = Generator(time_emb_dim=20, n_steps=n_steps, device=device, zsize=100)
discriminator = Discriminator(time_emb_dim=20, n_steps=n_steps, device=device)


ddgan = DDGAN(generator, discriminator, n_steps=n_steps,
              min_beta=min_beta, max_beta=max_beta, emb_dim=20,
              device=device).to(device)

optimizerG = optim.Adam(ddgan.generator.parameters(), betas=(0.7, 0.99),
                       lr=3e-3)
optimizerD = optim.Adam(ddgan.discriminator.parameters(), betas=(0.7, 0.99),
                       lr=3e-3)

schedulerG = CD_scheduler(optimizerG,
                    max_lr=3e-3, min_lr=3e-6, num_step_down=20,
                    num_step_up=0, gamma=0.5, alpha=0.3)
# The discriminator's scheduler has to wrap optimizerD. Wrapping optimizerG
# would step the generator's learning rate twice per epoch and never touch
# the discriminator's.
# NOTE(review): max_lr=7e-3 differs from optimizerD's lr=3e-3 - confirm intended.
schedulerD = CD_scheduler(optimizerD,
                    max_lr=7e-3, min_lr=3e-6, num_step_down=20,
                    num_step_up=0, gamma=0.3, alpha=0.3)

ddgan.train()
training_loop(ddgan, loader, n_epochs=40, optimizerG=optimizerG,
              optimizerD=optimizerD, schedulerG=schedulerG, schedulerD=schedulerD,
              device=device, store_path="ddgan_model.pt")

Training progress:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 1/40


Epoch 1/40:   0%|          | 0/938 [00:00<?, ?it/s]

In [None]:
store_path = "ddgan_model.pt"

# Rebuild the networks with the same hyper-parameters as the training cell
# (zsize, emb_dim and beta range) so the checkpoint matches the architecture
# and the noise schedule used at sampling time equals the one used in training.
generator = Generator(time_emb_dim=20, n_steps=n_steps, device=device, zsize=100)
discriminator = Discriminator(time_emb_dim=20, n_steps=n_steps, device=device)

best_model = DDGAN(generator, discriminator, n_steps=n_steps,
                   min_beta=min_beta, max_beta=max_beta, emb_dim=20,
                   device=device)
# map_location lets a checkpoint saved on one device be loaded on another.
best_model.load_state_dict(torch.load(store_path, map_location=device))
# Trailing semicolon keeps the notebook from echoing the module repr.
best_model.eval();

In [None]:
def generate_new_images(model=None, n_samples=16, device=None,
                        c=1, h=28, w=28, n_steps=None):
    """Sample images by running the reverse process from pure noise.

    Starting from standard-normal noise, for ``t = n_steps - 1, ..., 0`` the
    generator predicts x_0 from the current image and the next image is drawn
    with ``sample_posterior``.

    Args:
        model: trained DDGAN-like model (``generator``, ``betas``, ``n_steps``,
            ``device``).
        n_samples: number of images to generate.
        device: device to sample on; defaults to ``model.device``.
        c, h, w: channels / height / width of the generated images.
        n_steps: number of reverse steps; defaults to ``model.n_steps`` and
            must not exceed the length of the model's schedule.

    Returns:
        Tensor of shape ``(n_samples, c, h, w)``.

    Raises:
        ValueError: if ``n_steps`` exceeds the number of steps in the schedule.
    """
    ddgan = model
    torch.cuda.empty_cache()
    if n_steps is None:
        n_steps = ddgan.n_steps
    if device is None:
        device = ddgan.device

    with torch.no_grad():
        pos_coeff = Posterior_Coefficients(ddgan, device)

        # Fail early with a readable message instead of an out-of-range gather
        # inside the sampling loop.
        available_steps = len(pos_coeff.betas)
        if n_steps > available_steps:
            raise ValueError(
                f"n_steps={n_steps} exceeds the {available_steps} steps of the model's schedule"
            )

        x = torch.randn(n_samples, c, h, w, device=device)

        # The training loop feeds the generator a latent of shape (n, zsize);
        # use the same shape here. Fall back to an image-shaped latent only for
        # generators that do not expose `zsize`.
        zsize = getattr(ddgan.generator, "zsize", None)

        for t in tqdm(range(n_steps - 1, -1, -1), leave=False, desc="Steps", colour="#005500"):

            time_tensor = torch.full((n_samples,), t, dtype=torch.long, device=device)
            if zsize is None:
                latent_z = torch.randn_like(x)
            else:
                latent_z = torch.randn(n_samples, zsize, device=device)

            x_0_predict = ddgan.generator(x, time_tensor, latent_z)
            x = sample_posterior(pos_coeff, x_0_predict, x, time_tensor)

    torch.cuda.empty_cache()

    return x

In [None]:
# Put the loaded model in evaluation mode, draw 49 samples with 15 reverse
# steps, and display them.
best_model.eval()
generated = generate_new_images(best_model, n_samples=49, device=device, n_steps=15)
show_images(generated)