In [None]:
# import important libraries
import numpy as np
from math import log2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.utils
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import os
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision.utils import save_image
from torchvision import transforms

# **Create Helper Functions and Classes**

In [None]:
factors = [1, 1, 1, 1, 1/2, 1/4, 1/18, 1/16, 1/32]

In [None]:
class WSConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, gain = 2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,)
        self.scale = (gain / (in_channels * (kernel_size**2))) ** 0.5
        self.bias = self.conv.bias
        self.conv.bias = None
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)
    def forward(self, x):
        x = self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)
        return x

In [None]:
class WSConvTranspose2d(nn.Module):
    def __init__(self, z_dim, in_channels, kernel_size, stride, padding, gain = 2):
        super().__init__()
        self.convTranspose = nn.ConvTranspose2d(z_dim, in_channels, kernel_size, stride, padding,)
        self.scale = (gain / (z_dim * kernel_size**2)) ** 0.5
        self.bias = self.convTranspose.bias
        self.convTranspose.bias = None
        nn.init.normal_(self.convTranspose.weight)
        nn.init.zeros_(self.bias)
    def forward(self, x):
        x = self.convTranspose(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)
        return x

In [None]:
class PixelNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.epsilon = 1e-8

    def forward(self, x):
        return x / torch.sqrt(torch.mean(x**2, dim=1, keepdim=True) + self.epsilon)

In [None]:
class Conv2dBlock(nn.Module):
    def __init__(self, in_channels, out_channels, use_pixelnorm = True):
        super().__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(.2)
        self.pix = PixelNorm()
        self.use_pix = use_pixelnorm
        
    def forward(self, x):
        x = self.leaky(self.conv1(x))
        x = self.pix(x) if self.use_pix else x
        x = self.leaky(self.conv2(x))
        x = self.pix(x) if self.use_pix else x
        return x

# **Create Generator**

In [8]:
class Generator(nn.Module):
    def __init__(self,z_dim, in_channels, image_channels = 3):
        super().__init__()
        self.initial = nn.Sequential(
            WSConvTranspose2d(z_dim, in_channels, kernel_size = 4, stride = 1, padding = 0),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, 3, 1, 1),
            nn.LeakyReLU(0.2),
            PixelNorm(),
        )

        self.initial_rgb = WSConv2d(in_channels, image_channels, kernel_size = 1, stride = 1, padding = 0)
        self.prog_blocks, self.rgb_layers = nn.ModuleList(), nn.ModuleList([self.initial_rgb])

        for i in range(len(factors) - 1):
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i+1])
            self.prog_blocks.append(Conv2dBlock(conv_in_c, conv_out_c))
            self.rgb_layers.append(WSConv2d(conv_out_c, image_channels, kernel_size = 1, stride = 1, padding = 0))

    def fade(self, alpha, upscaled, generated):
        return torch.tanh(alpha * generated + (1 - alpha)*upscaled)

    def forward(self, x, alpha, steps):
        out = self.initial(x)
        if steps == 0:
            return self.initial_rgb(out)
            
        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor = 2, mode='nearest')
            out = self.prog_blocks[step](upscaled)
            
        final_upscaled = self.rgb_layers[steps-1](upscaled)
        final_out = self.rgb_layers[steps](out)
        
        return self.fade(alpha, final_upscaled, final_out)

# **Create Discriminator**

In [9]:
class Discriminator(nn.Module):
    def __init__(self, in_channels, image_channels = 3):
        super().__init__()    
        self.initial_rgb = WSConv2d(image_channels, in_channels, kernel_size = 1, stride = 1,padding = 0)
        self.avgpool = nn.AvgPool2d(kernel_size = 2, stride = 2)
        self.prog_blocks, self.rgb_layers = nn.ModuleList(), nn.ModuleList()
        self.leaky = nn.LeakyReLU(.2)
        factors_reversed = list(reversed(factors))
        for i in range(len(factors_reversed) - 1,):
            conv_in_c = int(in_channels * factors_reversed[i])
            conv_out_c = int(in_channels * factors_reversed[i+1])
            self.prog_blocks.append(Conv2dBlock(conv_in_c, conv_out_c, use_pixelnorm = False))
            self.rgb_layers.append(WSConv2d(image_channels, conv_in_c, kernel_size = 1, stride = 1,padding = 0))

        self.rgb_layers.append(self.initial_rgb)
        self.final_block = nn.Sequential(
            WSConv2d(in_channels + 1, in_channels, kernel_size = 3, padding = 1),
            nn.LeakyReLU(.2),
            WSConv2d(in_channels, in_channels, kernel_size = 4, stride = 1, padding = 0),
            nn.LeakyReLU(.2),
            nn.Flatten(),
            nn.Linear(in_channels, 1)
            # WSConv2d(in_channels, 1, kernel_size = 1, stride = 1, padding = 0),
        )
        

    def fade(self, alpha, downscaled, out):
        return alpha * out + (1 - alpha)*downscaled


    def mini_batch(self, x):
        batch_statistics = torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        return torch.cat([x, batch_statistics], dim=1)
        

    
    def forward(self, x, alpha, steps):
        current_step = len(self.prog_blocks) - steps
        out = self.leaky(self.rgb_layers[current_step](x))

        if steps == 0:
            out = self.mini_batch(out)
            return self.final_block(out).view(out.shape[0], -1)

        downscaled = self.leaky(self.rgb_layers[current_step + 1](self.avgpool(x)))
        out = self.avgpool(self.prog_blocks[current_step](out))
        out = self.fade(alpha, downscaled, out)

        for step in range(current_step +1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avgpool(out)
        out = self.mini_batch(out)
        out = self.final_block(out)
        return out
        

In [10]:
def test_model():
    z_dim = 50
    in_channels = 512
    gen = Generator(z_dim, in_channels)
    disc = Discriminator(in_channels)
    for img_size in [4,8,16,32,64,128,256,512,1024]:
        num_steps = int(log2(img_size/4))
        x = torch.randn(1, z_dim, 1, 1)
        z = gen(x, .5, steps = num_steps)
        assert z.shape == (1, 3, img_size, img_size)
        out = disc(z, .5, steps = num_steps)
        assert out.shape == (1,1)
test_model()

  batch_statistics = torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])


# **Load and Augment Data**

In [11]:
# Load dataset
dataset = ImageFolder('/kaggle/input/celebahq/celeba_hq/train')


In [44]:
ROOT_DIR = '/kaggle/input/celebahq/celeba_hq'
# Initialize Hyperparameters
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
LEARNING_RATE = 1e-3
IMAGE_SIZE = 512
CHANNELS_IMG = 3
Z_DIM = 256
IN_CHANNELS = 256
LAMDA_GP = 10
NUM_STEPS = int(log2(IMAGE_SIZE/4)) + 1
NUM_EPOCH = 200
NUM_WORKERS = 4
BATCH_SIZES = [16, 16, 16, 16, 16, 16, 16, 8, 4]
PROGRESSIVE_EPOCHS = [10] * len(BATCH_SIZES)
FIXED_NOISE = torch.randn(1, Z_DIM, 1, 1).to(DEVICE)


In [45]:
# Loader Function
def get_loader(image_size, BATCH_SIZES):
    transform = transforms.Compose([
        transforms.Resize((image_size,image_size)),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(),
        transforms.Normalize([0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)])
    ])
    batch_size = BATCH_SIZES[int(log2(image_size/4))]
    dataset = ImageFolder(root=ROOT_DIR, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True
    )
    return loader, dataset



# **Create Loss Functions**

In [46]:
def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * beta + fake.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images, alpha, train_step)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

In [47]:
# Initialize Models
gen = Generator(Z_DIM, IN_CHANNELS, CHANNELS_IMG).to(DEVICE)
critic = Discriminator(IN_CHANNELS, CHANNELS_IMG).to(DEVICE)


In [48]:
# Initialize Optimizers
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))



In [49]:
os.makedirs('images',exist_ok = True)

# **Model Training**

In [50]:
from tqdm.notebook import tqdm
# create train function
def train(epochs, critic, gen, opt_critic, opt_gen, BATCH_SIZES,device):
    critic.train()
    gen.train()
    step = 0
    for num_epochs in epochs[step:]:
        alpha = 1e-5
        image_size_ = 4*2**step
        loader, dataset = get_loader(image_size_, BATCH_SIZES)
        for epoch in range(num_epochs):
            epoch_loss_critic = 0  # Accumulate discriminator loss for the epoch
            epoch_loss_gen = 0   # Accumulate generator loss for the epoch
            with tqdm(loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch") as t:
                for batch_idx, (real, _) in enumerate(t):
                    real = real.to(device)
                    cur_batch_size = real.shape[0]
                    
                    
                    # Train critic
                    noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
                    fake = gen(noise, alpha, step)
                    critic_real = critic(real, alpha, step)
                    critic_fake = critic(fake.detach(), alpha, step )
                    gp = gradient_penalty(critic, real, fake,alpha, step, device=device)
                    loss_critic = (
                    -(torch.mean(critic_real) - torch.mean(critic_fake)) 
                    + gp * LAMDA_GP
                    + (0.001 * torch.mean(critic_real ** 2))
                    )
                    critic.zero_grad()
                    loss_critic.backward()
                    opt_critic.step()
                    
                    
                    # Train Generator
                    gen_fake = critic(fake, alpha, step)
                    loss_gen = -torch.mean(gen_fake)
                    gen.zero_grad()
                    loss_gen.backward()
                    opt_gen.step()

                    alpha += cur_batch_size/(len(dataset)) * epochs[step]*0.5
                    alpha = min(alpha,1)
                    epoch_loss_critic += loss_critic.item()
                    epoch_loss_gen += loss_gen.item()
                    t.set_postfix(d_loss=loss_critic.item(), g_loss=loss_gen.item())
                 # Average loss over all batches in the epoch
            avg_loss_critic = epoch_loss_critic / len(loader)
            avg_loss_gen = epoch_loss_gen / len(loader)
               # Log progress at the end of each epoch
            print(
                f"Epoch [{epoch+1}/{num_epochs}] | "
                f"Avg D Loss: {avg_loss_critic:.4f} | Avg G Loss: {avg_loss_gen:.4f}"
            )
            with torch.no_grad():
                generated_image = gen(FIXED_NOISE, alpha, step) *0.5 +0.5
                save_image(generated_image, f"/kaggle/working/images/generated_image{step}.png")
            print(image_size_)

        step = step + 1



In [None]:
train(PROGRESSIVE_EPOCHS, critic, gen, opt_critic, opt_gen, BATCH_SIZES, DEVICE)

Epoch 1/10:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [1/10] | Avg D Loss: -0.5085 | Avg G Loss: 1.0963
4


Epoch 2/10:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [2/10] | Avg D Loss: -0.1612 | Avg G Loss: 0.5253
4


Epoch 3/10:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [3/10] | Avg D Loss: -0.1637 | Avg G Loss: 0.3149
4


Epoch 4/10:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [4/10] | Avg D Loss: -0.1572 | Avg G Loss: 0.2744
4


Epoch 5/10:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [5/10] | Avg D Loss: -0.1401 | Avg G Loss: 0.2575
4


Epoch 6/10:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [6/10] | Avg D Loss: -0.1142 | Avg G Loss: 0.2320
4


Epoch 7/10:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [7/10] | Avg D Loss: -0.0835 | Avg G Loss: 0.2072
4


Epoch 8/10:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [8/10] | Avg D Loss: -0.0421 | Avg G Loss: 0.1544
4


Epoch 9/10:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [9/10] | Avg D Loss: -0.0299 | Avg G Loss: 0.1536
4


Epoch 10/10:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [10/10] | Avg D Loss: -0.0313 | Avg G Loss: 0.1480
4


Epoch 1/10:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [1/10] | Avg D Loss: -0.4177 | Avg G Loss: 0.9753
8


Epoch 2/10:   0%|          | 0/1875 [00:00<?, ?batch/s]