In [1]:
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
import random
from tqdm.auto import tqdm
import imageio
import einops
import torch
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda
from torchvision.datasets import FashionMNIST
device = torch.device("cuda:0")

# TO DO: fix discriminator and generator

In [2]:
def show_images(images, title=""):
    """Shows the provided images as sub-pictures in a square"""

    # Converting images to CPU numpy arrays
    if type(images) is torch.Tensor:
        images = images.detach().cpu().numpy()

    # Defining number of rows and columns
    fig = plt.figure(figsize=(8, 8))
    rows = int(len(images) ** (1 / 2))
    cols = round(len(images) / rows)

    # Populating figure with sub-plots
    idx = 0
    for r in range(rows):
        for c in range(cols):
            fig.add_subplot(rows, cols, idx + 1)

            if idx < len(images):
                cur_img = np.transpose(images[idx], (1, 2, 0))
                cur_img = (cur_img - np.amin(cur_img, axis=(0,1),keepdims=True)) / np.ptp(cur_img, axis=(0,1))                
                plt.imshow(cur_img)
                idx += 1
    fig.suptitle(title, fontsize=30)

    # Showing the figure
    plt.show()

In [3]:
transform = Compose([
    ToTensor(),
    Lambda(lambda x: (x - 0.5) * 2)]
)
batch_size = 64

dataset = FashionMNIST("./datasets", download=True, train=True, transform=transform)
loader = DataLoader(dataset, batch_size, shuffle=True)

In [4]:
def generate_new_images(ddg, n_samples=16, device=None, frames_per_gif=100, gif_name="sampling.gif", c=1, h=28, w=28):

    frame_idxs = np.linspace(0, ddg.n_steps, frames_per_gif).astype(np.uint)
    frames = []

    with torch.no_grad():
        if device is None:
            device = ddg.device

        # Starting from random noise
        x = torch.randn(n_samples, c, h, w).to(device)

        for idx, t in enumerate(list(range(ddg.n_steps))[::-1]):
            # Estimating noise to be removed
            time_tensor = (torch.ones(n_samples, 1) * t).to(device).long()
            eta_theta = ddg.backward(x, time_tensor)

            alpha_t = ddg.alphas[t]
            alpha_t_bar = ddg.alpha_bars[t]

            # Partially denoising the image
            x = (1 / alpha_t.sqrt()) * (x - (1 - alpha_t) / (1 - alpha_t_bar).sqrt() * eta_theta)

            if t > 0:
                z = torch.randn(n_samples, c, h, w).to(device)

                # Option 1: sigma_t squared = beta_t
                beta_t = ddg.betas[t]
                sigma_t = beta_t.sqrt()


                # Adding some more noise like in Langevin Dynamics fashion
                x = x + sigma_t * z

            # Adding frames to the GIF
            if idx in frame_idxs or t == 0:
                # Putting digits in range [0, 255]
                normalized = x.clone()
                for i in range(len(normalized)):
                    normalized[i] -= torch.min(normalized[i])
                    normalized[i] *= 255 / torch.max(normalized[i])

                # Reshaping batch (n, c, h, w) to be a (as much as it gets) square frame
                frame = einops.rearrange(normalized, "(b1 b2) c h w -> (b1 h) (b2 w) c", b1=int(n_samples ** 0.5))
                frame = frame.cpu().numpy().astype(np.uint8)

                # Rendering frame
                frames.append(frame)

    # Storing the gif
    with imageio.get_writer(gif_name, mode="I") as writer:
        for idx, frame in enumerate(frames):
            writer.append_data(frame)
            if idx == len(frames) - 1:
                for _ in range(frames_per_gif // 3):
                    writer.append_data(frames[-1])
    return x

In [5]:
class MyBlock(nn.Module):
    def __init__(self, shape, in_c, out_c, kernel_size=3, stride=1, 
                 padding=1, activation=None, normalize=True):
        super(MyBlock, self).__init__()
        self.ln = nn.LayerNorm(shape)
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size, stride, padding)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size, stride, padding)
        self.activation = nn.SiLU() if activation is None else activation
        self.normalize = normalize

    def forward(self, x):
        out = self.ln(x) if self.normalize else x
        out = self.conv1(out)
        out = self.activation(out)
        out = self.conv2(out)
        out = self.activation(out)
        return out

In [6]:
def sinusoidal_embedding(n, d):
    # Returns the standard positional embedding
    embedding = torch.zeros(n, d)
    wk = torch.tensor([1 / 10_000 ** (2 * j / d) for j in range(d)])
    wk = wk.reshape((1, d))
    t = torch.arange(n).reshape((n, 1))
    embedding[:,::2] = torch.sin(t * wk[:,::2])
    embedding[:,1::2] = torch.cos(t * wk[:,::2])

    return embedding

In [7]:
class Generator(nn.Module):
    def __init__(self, n_steps=4, time_emb_dim=4):
        super(Generator, self).__init__()

        # Sinusoidal embedding
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        # First half
        self.te1 = self._make_te(time_emb_dim, 1)
        self.b1 = nn.Sequential(
            MyBlock((1, 28, 28), 1, 15),
            MyBlock((15, 28, 28), 15, 15),
            MyBlock((15, 28, 28), 15, 15)
        )
        self.down1 = nn.MaxPool2d(2, 2)

        self.te2 = self._make_te(time_emb_dim, 15)
        self.b2 = nn.Sequential(
            MyBlock((15, 14, 14), 15, 30),
            MyBlock((30, 14, 14), 30, 30),
            MyBlock((30, 14, 14), 30, 30)
        )
        self.down2 = nn.MaxPool2d(2, 2)

        self.te3 = self._make_te(time_emb_dim, 30)
        self.b3 = nn.Sequential(
            MyBlock((30, 7, 7), 30, 60),
            MyBlock((60, 7, 7), 60, 60),
            MyBlock((60, 7, 7), 60, 60)
        )
        self.down3 = nn.Sequential(
            nn.Conv2d(60, 60, 2, 1),
            nn.SiLU(),
            nn.Conv2d(60, 60, 4, 2, 1)
        )

        # Bottleneck
        self.te_mid = self._make_te(time_emb_dim, 60)
        self.b_mid = nn.Sequential(
            MyBlock((60, 3, 3), 60, 30),
            MyBlock((30, 3, 3), 30, 30),
            MyBlock((30, 3, 3), 30, 60)
        )

        # Second half
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(60, 60, 4, 2, 1),
            nn.SiLU(),
            nn.ConvTranspose2d(60, 60, 2, 1)
        )

        self.te4 = self._make_te(time_emb_dim, 120)
        self.b4 = nn.Sequential(
            MyBlock((120, 7, 7), 120, 60),
            MyBlock((60, 7, 7), 60, 30),
            MyBlock((30, 7, 7), 30, 30)
        )

        self.up2 = nn.ConvTranspose2d(30, 30, 4, 2, 1)
        self.te5 = self._make_te(time_emb_dim, 60)
        self.b5 = nn.Sequential(
            MyBlock((60, 14, 14), 60, 30),
            MyBlock((30, 14, 14), 30, 15),
            MyBlock((15, 14, 14), 15, 15)
        )

        self.up3 = nn.ConvTranspose2d(15, 15, 4, 2, 1)
        self.te_out = self._make_te(time_emb_dim, 30)
        self.b_out = nn.Sequential(
            MyBlock((30, 28, 28), 30, 15),
            MyBlock((15, 28, 28), 15, 15),
            MyBlock((15, 28, 28), 15, 15, normalize=False)
        )

        self.conv_out = nn.Conv2d(15, 1, 3, 1, 1)

    def forward(self, x, t):

        t = self.time_embed(t)
        n = len(x)
        out1 = self.b1(x + self.te1(t).reshape(n, -1, 1, 1))
        out2 = self.b2(self.down1(out1) + self.te2(t).reshape(n, -1, 1, 1))
        out3 = self.b3(self.down2(out2) + self.te3(t).reshape(n, -1, 1, 1)) 

        out_mid = self.b_mid(self.down3(out3) + self.te_mid(t).reshape(n, -1, 1, 1))

        out4 = torch.cat((out3, self.up1(out_mid)), dim=1)
        out4 = self.b4(out4 + self.te4(t).reshape(n, -1, 1, 1)) 

        out5 = torch.cat((out2, self.up2(out4)), dim=1) 
        out5 = self.b5(out5 + self.te5(t).reshape(n, -1, 1, 1))

        out = torch.cat((out1, self.up3(out5)), dim=1)
        out = self.b_out(out + self.te_out(t).reshape(n, -1, 1, 1))

        out = self.conv_out(out)

        return out

    def _make_te(self, dim_in, dim_out):
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.SiLU(),
            nn.Linear(dim_out, dim_out)
        )

In [8]:
class Discriminator(nn.Module):
    def __init__(self, n_steps=4, time_emb_dim=4):
        super(Discriminator, self).__init__()

        # Sinusoidal embedding
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        # First half
        self.te1 = self._make_te(time_emb_dim, 1)
        self.b1 = nn.Sequential(
            MyBlock((1, 28, 28), 1, 15),
            MyBlock((15, 28, 28), 15, 15),
            MyBlock((15, 28, 28), 15, 15)
        )
        self.down1 = nn.MaxPool2d(2, 2)

        self.te2 = self._make_te(time_emb_dim, 15)
        self.b2 = nn.Sequential(
            MyBlock((15, 14, 14), 15, 30),
            MyBlock((30, 14, 14), 30, 30),
            MyBlock((30, 14, 14), 30, 30)
        )
        self.down2 = nn.MaxPool2d(2, 2)

        self.te3 = self._make_te(time_emb_dim, 30)
        self.b3 = nn.Sequential(
            MyBlock((30, 7, 7), 30, 60),
            MyBlock((60, 7, 7), 60, 60),
            MyBlock((60, 7, 7), 60, 60)
        )
        self.down3 = nn.Sequential(
            nn.Conv2d(60, 60, 2, 1),
            nn.SiLU(),
            nn.Conv2d(60, 60, 4, 2, 1)
        )

        # Bottleneck
        self.te_mid = self._make_te(time_emb_dim, 60)
        self.b_mid = nn.Sequential(
            MyBlock((60, 3, 3), 60, 30),
            MyBlock((30, 3, 3), 30, 30),
            MyBlock((30, 3, 3), 30, 60)
        )

        self.fc_out = nn.Linear(3 * 3 * 60, 1)

    def forward(self, x, t):

        t = self.time_embed(t)
        n = len(x)
        out1 = self.b1(x + self.te1(t).reshape(n, -1, 1, 1))
        out2 = self.b2(self.down1(out1) + self.te2(t).reshape(n, -1, 1, 1))
        out3 = self.b3(self.down2(out2) + self.te3(t).reshape(n, -1, 1, 1)) 

        out_mid = self.b_mid(self.down3(out3) + self.te_mid(t).reshape(n, -1, 1, 1))

        out = self.fc_out(out_mid)

        return out

    def _make_te(self, dim_in, dim_out):
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.SiLU(),
            nn.Linear(dim_out, dim_out)
        )

In [9]:
class DDGAN(nn.Module):
    def __init__(self, generator, discriminator, n_steps=4, 
                 min_beta=2e-1, max_beta=9e-1, device=None, 
                 image_chw=(1, 28, 28)):
        super(DDGAN, self).__init__()
        self.n_steps = n_steps
        self.device = device
        self.image_chw = image_chw
        self.generator = generator.to(device)
        self.discriminator = discriminator.to(device)
        self.betas = torch.linspace(min_beta, max_beta, n_steps).to(
            device)  # Number of steps is typically in the order of thousands
        self.alphas = 1 - self.betas
        self.alpha_bars = torch.tensor([torch.prod(self.alphas[:i + 1]) for i in range(len(self.alphas))]).to(device)      
        self.sigmas_cum = (1 - self.alpha_bars) ** 0.5
        
    def forward(self, x0, t, eta=None):
        # Make input image more noisy (we can directly skip to the desired step)
        n, c, h, w = x0.shape
        a_bar = self.alpha_bars[t]

        if eta is None:
            eta = torch.randn(n, c, h, w).to(self.device)

        noisy = a_bar.sqrt().reshape(n, 1, 1, 1) * x0 + (1 - a_bar).sqrt().reshape(n, 1, 1, 1) * eta
        return noisy

In [10]:
def extract(inp, t, shape):
    out = torch.gather(inp, 0, t)
    reshape = [shape[0]] + [1] * (len(shape) - 1)
    out = out.reshape(*reshape)
    return out

def q_sample(ddgan, x_start, t, *, noise=None):
    """
    Diffuse the data (t == 0 means diffused for t step)
    """
    if noise is None:
        noise = torch.randn_like(x_start)
      
    x_t = extract(ddgan.alpha_bars**0.5, t, x_start.shape) * x_start + \
          extract(ddgan.sigmas_cum, t, x_start.shape) * noise
    
    return x_t

    
def q_sample_pairs(ddgan, x_start, t):
    """
    Generate a pair of disturbed images for training
    :param x_start: x_0
    :param t: time step t
    :return: x_t, x_{t+1}
    """
    noise = torch.randn_like(x_start)
    x_t = q_sample(ddgan, x_start, t)
    x_t_plus_one = extract(ddgan.alphas**0.5, t+1, x_start.shape) * x_t + \
                   extract(ddgan.betas**0.5, t+1, x_start.shape) * noise
    
    return x_t, x_t_plus_one

In [11]:
class Posterior_Coefficients():
    def __init__(self, ddgan, device):
        
        self.betas = ddgan.betas
        
        #we don't need the zeros
        self.betas = self.betas.type(torch.float32)[1:]
        
        self.alphas = 1 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, 0)
        self.alphas_cumprod_prev = torch.cat(
                                    (torch.tensor([1.], dtype=torch.float32,device=device), self.alphas_cumprod[:-1]), 0
                                        )               
        self.posterior_variance = self.betas * (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod)
        
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = torch.rsqrt(self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1 / self.alphas_cumprod - 1)
        
        self.posterior_mean_coef1 = (self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1 - self.alphas_cumprod))
        self.posterior_mean_coef2 = ((1 - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1 - self.alphas_cumprod))
        
        self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(min=1e-20))
        
def sample_posterior(coefficients, x_0,x_t, t):
    
    def q_posterior(x_0, x_t, t):
        mean = (
            extract(coefficients.posterior_mean_coef1, t, x_t.shape) * x_0
            + extract(coefficients.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        var = extract(coefficients.posterior_variance, t, x_t.shape)
        log_var_clipped = extract(coefficients.posterior_log_variance_clipped, t, x_t.shape)
        return mean, var, log_var_clipped
    
  
    def p_sample(x_0, x_t, t):
        mean, _, log_var = q_posterior(x_0, x_t, t)
        
        noise = torch.randn_like(x_t)
        
        nonzero_mask = (1 - (t == 0).type(torch.float32))

        return mean + nonzero_mask[:,None,None,None] * torch.exp(0.5 * log_var) * noise
            
    sample_x_pos = p_sample(x_0, x_t, t)
    
    return sample_x_pos

In [21]:
def training_loop(ddgan, loader, n_epochs, optimizerG, optimizerD, device,
                  schedulerD=None, schedulerG=None, store_path="ddgan_model.pt"):

    best_loss = float("inf")
    n_steps = ddgan.n_steps
    netG = ddgan.generator
    netD = ddgan.discriminator
    pos_coeff = Posterior_Coefficients(ddgan, device)

    for epoch in tqdm(range(n_epochs), desc=f"Training progress", colour="#00ff00"):
        print(f'Epoch {epoch+1}/{n_epochs}')
        
        epoch_errG = 0.0
        epoch_errD = 0.0
           
        for step, batch in enumerate(tqdm(loader, leave=False, desc=f"Epoch {epoch + 1}/{n_epochs}", colour="#005500")):

            x0 = batch[0].to(device)
            n = len(x0)
            
            t = torch.randint(0, n_steps, (n,)).to(device)
            
            netD.zero_grad()
            x_t, x_tp1 = q_sample_pairs(ddgan, x0, t)
            x_t.requires_grad = True
            
            
            #train D with real
            D_real = netD(x_t, t, x_tp1.detach()).view(-1)
            errD_real = (F.softplus(-D_real)).mean()
            
            errD_real.backward(retain_graph=True)
            
            
            #train D with fake from G
            latent_z = torch.randn_like(x0).to(device)
            
            x_0_predict = netG(x_tp1, t, latent_z)
            x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)
            
            output = netD(x_pos_sample, t, x_tp1.detach()).view(-1)
            
            errD_fake = (F.softplus(output)).mean()
            errD_fake.backward()
            
            errD = errD_real + errD_fake
            optimizerD.step()
            
            
            #train G without D
            for p in netD.parameters():
                p.requires_grad = False
            netG.zero_grad()
            
            
            t = torch.randint(0, n_steps, (n,)).to(device)
            
            x_t, x_tp1 = q_sample_pairs(ddgan, real_data, t)
            
            latent_z = torch.randn_like(x0).to(device)
            
            
            x_0_predict = netG(x_tp1.detach(), t, latent_z)
            x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)
            
            output = netD(x_pos_sample, t, x_tp1.detach()).view(-1)
               
            
            errG = (F.softplus(-output)).mean()
            
            errG.backward()
            optimizerG.step()
            

            epoch_errG += errG * n / len(loader.dataset)
            epoch_errD += errD * n / len(loader.dataset)
        
        if schedulerD is not None:
            schedulerD.step()
        if schedulerG is not None:
            schedulerG.step()

        log_string = f"G loss: {epoch_errG:.4f}, D loss: {epoch_errD:.4f}"

        # Storing the model
        if best_loss > epoch_loss:
            torch.save(ddgan.state_dict(), store_path)
            
        print(log_string)
        print('-' * 75)

In [22]:
from CustomizableCosineDecayScheduler import CosineDecayWithWarmUpScheduler as CD_scheduler

In [23]:
try:
    del ddgan
except: pass

n_steps, min_beta, max_beta = 4, 2e-1, 9e-1

generator = Generator()
discriminator = Discriminator()


ddgan = DDGAN(generator, discriminator, n_steps=n_steps, 
              min_beta=min_beta, max_beta=max_beta, device=device)

optimizerG = optim.Adam(ddgan.generator.parameters(), betas=(0.7, 0.99),
                       lr=3e-4)
optimizerD = optim.Adam(ddgan.discriminator.parameters(), betas=(0.7, 0.99),
                       lr=3e-4)

schedulerG = CD_scheduler(optimizerG, 
                    max_lr=3e-4, min_lr=1e-6, num_step_down=20, 
                    num_step_up=0, gamma=0.5, alpha=0.3)
schedulerD = CD_scheduler(optimizerG, 
                    max_lr=3e-4, min_lr=1e-6, num_step_down=20, 
                    num_step_up=0, gamma=0.5, alpha=0.3)

ddgan.train()
training_loop(ddgan, loader, n_epochs=20, optimizerG=optimizerG, 
              optimizerD=optimizerD, device=device)

Training progress:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1/20


Epoch 1/20:   0%|          | 0/938 [00:00<?, ?it/s]

TypeError: forward() takes 3 positional arguments but 4 were given