In [1]:
import cv2
from math import log2
from tqdm import tqdm
import random
import numpy as np
from PIL import Image
import os
from scipy.stats import truncnorm

import torch
import torch.optim as optim
import torchvision.datasets as datasets
from torchvision.datasets import CelebA
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torch.nn as nn
from torchvision.utils import save_image
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

# Seed every random number generator the notebook uses so runs are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Request deterministic cuDNN kernels and keep cuDNN's benchmark mode off.
# `benchmark` (singular) is the real flag name; no other spelling is read by PyTorch.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(DEVICE)

cuda


# **Modelling**

In [5]:
class ConvBlock(nn.Module):
    """Two 3x3 convolutions, each followed by ``LeakyReLU(0.2)``.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is kept; only the channel count changes from ``in_channels`` to
    ``out_channels`` (the change happens in the first convolution).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        stages = []
        # First conv maps in_channels -> out_channels, second keeps out_channels.
        for conv_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(conv_in, out_channels, kernel_size=3, stride=1, padding=1))
            stages.append(nn.LeakyReLU(0.2))
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv + activation stages and return the result."""
        return self.model(x)

In [6]:
class Generator(nn.Module):
    """Progressively grown image generator.

    A 256-channel latent is expanded to a 4x4 feature map by ``initial``; each
    entry of ``prog_blocks`` then doubles the resolution (nearest-neighbour
    upsampling followed by a ``ConvBlock``). ``rgb_layers[i]`` is the 1x1
    convolution that turns the features available after ``i`` blocks into a
    3-channel image; ``rgb_layers[0]`` is the same module as ``initial_rgb``.
    """

    def __init__(self):
        super(Generator, self).__init__()

        # Latent -> 4x4 feature map (a 1x1 input becomes 4x4 through the
        # kernel-4, stride-1, padding-0 transposed convolution).
        self.initial = nn.Sequential(
            nn.ConvTranspose2d(256, 256, 4, 1, 0), nn.LeakyReLU(0.2),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), nn.LeakyReLU(0.2),
        )

        self.initial_rgb = nn.Conv2d(256, 3, kernel_size=1, stride=1, padding=0)

        # Feature width after the initial block and after each progressive block.
        widths = [256, 256, 256, 256, 128, 64, 32, 16, 8]

        self.prog_blocks = nn.ModuleList(
            [ConvBlock(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:])]
        )
        self.rgb_layers = nn.ModuleList(
            [self.initial_rgb]
            + [nn.Conv2d(c, 3, kernel_size=1, stride=1, padding=0) for c in widths[1:]]
        )

    def forward(self, x, steps, alpha = 0.0001):
        """Generate images using the first ``steps`` progressive blocks.

        Args:
            x: latent tensor fed to ``initial``; the training cells pass a
                tensor of shape ``(batch, 256, 1, 1)``.
            steps: number of progressive blocks to apply. ``0`` returns the
                image produced directly from the initial 4x4 features.
            alpha: blend weight of the newest block's image; ``1 - alpha`` is
                given to the image rendered from the upsampled features of the
                previous resolution. Ignored when ``steps == 0``.

        Returns:
            A 3-channel image tensor squashed to ``[-1, 1]`` by ``tanh``.

        Raises:
            ValueError: if ``steps`` is negative.
        """
        if steps < 0:
            raise ValueError(f"steps must be non-negative, got {steps}")

        out = self.initial(x)

        if steps == 0:
            # No progressive block has run, so there is nothing to blend with;
            # tanh keeps the output range the same as at every other depth.
            return torch.tanh(self.initial_rgb(out))

        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="nearest")
            out = self.prog_blocks[step](upscaled)

        # `upscaled` holds the input of the last block, i.e. the previous
        # resolution's features at the new size, so it pairs with the previous
        # depth's RGB layer.
        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return torch.tanh(alpha * final_out + (1 - alpha) * final_upscaled)

In [7]:
class Discriminator(nn.Module):
    """Critic that mirrors the generator's progressive structure.

    ``rgb_layers[i]`` lifts a 3-channel image to the channel count expected by
    ``prog_blocks[i]``; the last entry of ``rgb_layers`` feeds ``final_block``
    directly. Every progressive block is followed by 2x average pooling, and
    ``final_block`` ends in a 4x4 unpadded convolution plus a 1x1 convolution
    that outputs one channel.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        # Feature width entering each progressive block, plus the width that
        # enters the final block.
        widths = [8, 16, 32, 64, 128, 256, 256, 256, 256]

        self.prog_blocks = nn.ModuleList(
            [ConvBlock(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:])]
        )

        self.rgb_layers = nn.ModuleList(
            [nn.Conv2d(3, c, kernel_size=1, stride=1, padding=0) for c in widths]
        )

        self.leaky = nn.LeakyReLU(0.2)

        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

        self.final_block = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.LeakyReLU(0.2),
            nn.Conv2d(256, 256, kernel_size=4, padding=0, stride=1), nn.LeakyReLU(0.2),
            nn.Conv2d(256, 1, kernel_size=1, padding=0, stride=1),
        )

    def forward(self, x, steps):
        """Score a batch of images using the last ``steps`` progressive blocks.

        Args:
            x: image batch with 3 channels. Each block halves the resolution
                and ``final_block`` contains an unpadded 4x4 convolution, so a
                side length of ``4 * 2**steps`` yields one score per image.
            steps: number of progressive blocks to run before ``final_block``.
                ``0`` sends the image through the last RGB layer straight into
                ``final_block``.

        Returns:
            Tensor of shape ``(batch, -1)`` holding the flattened critic output.
        """
        # Blocks are stored from highest resolution to lowest, so the entry
        # point moves toward the front of the list as `steps` grows.
        cur_step = len(self.prog_blocks) - steps

        out = self.leaky(self.rgb_layers[cur_step](x))

        # Empty range when steps == 0: the features go directly to final_block.
        for step in range(cur_step, len(self.prog_blocks)):
            out = self.avg_pool(self.prog_blocks[step](out))

        return self.final_block(out).view(out.shape[0], -1)

# **Training**

In [None]:
# Build both networks on the selected device and switch them to training mode
# (`.to()` and `.train()` both return the module itself).
gen = Generator().to(DEVICE).train()
critic = Discriminator().to(DEVICE).train()

# One Adam optimizer per network, with identical hyper-parameters.
opt_gen = optim.Adam(params=gen.parameters(), lr=1e-3, betas=(0.0, 0.99))
opt_critic = optim.Adam(params=critic.parameters(), lr=1e-3, betas=(0.0, 0.99))

# One gradient scaler per optimizer; the training cells use them together
# with `torch.cuda.amp.autocast()`.
scaler_gen = torch.cuda.amp.GradScaler()
scaler_critic = torch.cuda.amp.GradScaler()

#### Training a model to generate 128 x 128


In [None]:
# Data pipeline for the 128 x 128 stage.
image_size = 128
batch_size = 16

# NOTE(review): `get_loader` is not defined in the visible cells; its result is
# handed to DataLoader as the dataset — confirm that is what it returns.
dataset = get_loader(image_size)
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Depth and blend weight for this resolution stage.
stage_steps = 5      # 4 * 2**5 = 128 pixels per side
stage_alpha = 1e-5
# NOTE(review): alpha is constant for the whole stage, so the newest generator
# block contributes only this fraction of the output — confirm that a fixed
# value (rather than a ramp towards 1) is intended.

for epoch in range(30):
    print(f"Epoch [{epoch+1}]")

    loop = tqdm(loader, leave=True)
    for real in loop:
        real = real.to(DEVICE)

        # Train Discriminator
        # One latent per real image, created directly on the target device.
        # Sizing from `real` keeps the fake batch matched on a short final
        # batch and avoids depending on the global `batch_size`.
        noise = torch.randn(real.shape[0], 256, 1, 1, device=DEVICE)

        with torch.cuda.amp.autocast():
            fake = gen(noise, stage_steps, stage_alpha)
            critic_real = critic(real, stage_steps)
            # detach() keeps the critic update from back-propagating into the generator.
            critic_fake = critic(fake.detach(), stage_steps)
            # Wasserstein-style critic loss plus a small penalty on the
            # magnitude of the real scores.
            # NOTE(review): no gradient penalty or weight clipping is applied here — confirm.
            loss_critic = 0.001*torch.mean(critic_real ** 2) - torch.mean(critic_real) + torch.mean(critic_fake)

        opt_critic.zero_grad()
        scaler_critic.scale(loss_critic).backward()
        scaler_critic.step(opt_critic)
        scaler_critic.update()

        # Train Generator
        with torch.cuda.amp.autocast():
            # Re-score the same fakes with the updated critic, this time
            # keeping the graph back to the generator.
            gen_fake = critic(fake, stage_steps)
            loss_gen = -torch.mean(gen_fake)

        opt_gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

        loop.set_postfix(loss_critic=loss_critic.item())

#### Training a model to generate 256 x 256

In [None]:
# Data pipeline for the 256 x 256 stage.
image_size = 256
batch_size = 16

# NOTE(review): `get_loader` is not defined in the visible cells; its result is
# handed to DataLoader as the dataset — confirm that is what it returns.
dataset = get_loader(image_size)
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Depth and blend weight for this resolution stage.
stage_steps = 6      # 4 * 2**6 = 256 pixels per side
stage_alpha = 1e-4
# NOTE(review): alpha is constant for the whole stage, so the newest generator
# block contributes only this fraction of the output — confirm that a fixed
# value (rather than a ramp towards 1) is intended.

for epoch in range(30):
    print(f"Epoch [{epoch+1}]")

    loop = tqdm(loader, leave=True)
    for real in loop:
        real = real.to(DEVICE)

        # Train Discriminator
        # One latent per real image, created directly on the target device.
        # Sizing from `real` keeps the fake batch matched on a short final
        # batch and avoids depending on the global `batch_size`.
        noise = torch.randn(real.shape[0], 256, 1, 1, device=DEVICE)

        with torch.cuda.amp.autocast():
            fake = gen(noise, stage_steps, stage_alpha)
            critic_real = critic(real, stage_steps)
            # detach() keeps the critic update from back-propagating into the generator.
            critic_fake = critic(fake.detach(), stage_steps)
            # Wasserstein-style critic loss plus a small penalty on the
            # magnitude of the real scores.
            # NOTE(review): no gradient penalty or weight clipping is applied here — confirm.
            loss_critic = 0.001*torch.mean(critic_real ** 2) - torch.mean(critic_real) + torch.mean(critic_fake)

        opt_critic.zero_grad()
        scaler_critic.scale(loss_critic).backward()
        scaler_critic.step(opt_critic)
        scaler_critic.update()

        # Train Generator
        with torch.cuda.amp.autocast():
            # Re-score the same fakes with the updated critic, this time
            # keeping the graph back to the generator.
            gen_fake = critic(fake, stage_steps)
            loss_gen = -torch.mean(gen_fake)

        opt_gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

        loop.set_postfix(loss_critic=loss_critic.item())

#### Training a model to generate 512 x 512

In [None]:
# Data pipeline for the 512 x 512 stage (smaller batches than the earlier stages).
image_size = 512
batch_size = 8

# NOTE(review): `get_loader` is not defined in the visible cells; its result is
# handed to DataLoader as the dataset — confirm that is what it returns.
dataset = get_loader(image_size)
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Depth and blend weight for this resolution stage.
stage_steps = 7      # 4 * 2**7 = 512 pixels per side
stage_alpha = 1e-3
# NOTE(review): alpha is constant for the whole stage, so the newest generator
# block contributes only this fraction of the output — confirm that a fixed
# value (rather than a ramp towards 1) is intended.

for epoch in range(30):
    print(f"Epoch [{epoch+1}]")

    loop = tqdm(loader, leave=True)
    for real in loop:
        real = real.to(DEVICE)

        # Train Discriminator
        # One latent per real image, created directly on the target device.
        # Sizing from `real` keeps the fake batch matched on a short final
        # batch and avoids depending on the global `batch_size`.
        noise = torch.randn(real.shape[0], 256, 1, 1, device=DEVICE)

        with torch.cuda.amp.autocast():
            fake = gen(noise, stage_steps, stage_alpha)
            critic_real = critic(real, stage_steps)
            # detach() keeps the critic update from back-propagating into the generator.
            critic_fake = critic(fake.detach(), stage_steps)
            # Wasserstein-style critic loss plus a small penalty on the
            # magnitude of the real scores.
            # NOTE(review): no gradient penalty or weight clipping is applied here — confirm.
            loss_critic = 0.001*torch.mean(critic_real ** 2) - torch.mean(critic_real) + torch.mean(critic_fake)

        opt_critic.zero_grad()
        scaler_critic.scale(loss_critic).backward()
        scaler_critic.step(opt_critic)
        scaler_critic.update()

        # Train Generator
        with torch.cuda.amp.autocast():
            # Re-score the same fakes with the updated critic, this time
            # keeping the graph back to the generator.
            gen_fake = critic(fake, stage_steps)
            loss_gen = -torch.mean(gen_fake)

        opt_gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

        loop.set_postfix(loss_critic=loss_critic.item())

#### Training a model to generate 1024 x 1024

In [None]:
# Data pipeline for the 1024 x 1024 stage (smallest batch size of all stages).
image_size = 1024
batch_size = 4

# NOTE(review): `get_loader` is not defined in the visible cells; its result is
# handed to DataLoader as the dataset — confirm that is what it returns.
dataset = get_loader(image_size)
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Depth and blend weight for this resolution stage.
stage_steps = 8      # 4 * 2**8 = 1024 pixels per side
stage_alpha = 1e-2
# NOTE(review): alpha is constant for the whole stage, so the newest generator
# block contributes only this fraction of the output — confirm that a fixed
# value (rather than a ramp towards 1) is intended.

for epoch in range(30):
    print(f"Epoch [{epoch+1}]")

    loop = tqdm(loader, leave=True)
    for real in loop:
        real = real.to(DEVICE)

        # Train Discriminator
        # One latent per real image, created directly on the target device.
        # Sizing from `real` keeps the fake batch matched on a short final
        # batch and avoids depending on the global `batch_size`.
        noise = torch.randn(real.shape[0], 256, 1, 1, device=DEVICE)

        with torch.cuda.amp.autocast():
            fake = gen(noise, stage_steps, stage_alpha)
            critic_real = critic(real, stage_steps)
            # detach() keeps the critic update from back-propagating into the generator.
            critic_fake = critic(fake.detach(), stage_steps)
            # Wasserstein-style critic loss plus a small penalty on the
            # magnitude of the real scores.
            # NOTE(review): no gradient penalty or weight clipping is applied here — confirm.
            loss_critic = 0.001*torch.mean(critic_real ** 2) - torch.mean(critic_real) + torch.mean(critic_fake)

        opt_critic.zero_grad()
        scaler_critic.scale(loss_critic).backward()
        scaler_critic.step(opt_critic)
        scaler_critic.update()

        # Train Generator
        with torch.cuda.amp.autocast():
            # Re-score the same fakes with the updated critic, this time
            # keeping the graph back to the generator.
            gen_fake = critic(fake, stage_steps)
            loss_gen = -torch.mean(gen_fake)

        opt_gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

        loop.set_postfix(loss_critic=loss_critic.item())

# **Inference**
Generating images at all resolutions, i.e.:
- 8 x 8
- 16 x 16
- 32 x 32
- 64 x 64
- 128 x 128
- 256 x 256
- 512 x 512
- 1024 x 1024

In [None]:
# save_image does not create missing folders, so make sure the target exists.
os.makedirs("saved_examples", exist_ok=True)

gen.eval()

# NOTE(review): `truncation` and `steps` are not defined in the visible cells —
# confirm they are set before this cell runs. The heading above lists eight
# resolutions, but every iteration uses the same `steps`; verify whether it
# should vary with `i`.
# NOTE(review): `gen` is called with its default alpha (0.0001) — confirm this
# is the blend weight wanted for sampling.
with torch.no_grad():
    for i in range(8):
        # input noise to the model, drawn from a normal truncated to
        # [-truncation, truncation]
        noise = torch.tensor(
            truncnorm.rvs(-truncation, truncation, size=(1, 256, 1, 1)),
            device=DEVICE,
            dtype=torch.float32,
        )

        # generating image using the generator
        img = gen(noise, steps)

        # denormalizing (tanh range [-1, 1] -> [0, 1]) and saving the image
        save_image(img*0.5+0.5, f"saved_examples/img_{i}.png")