## Startup

### Import Modules

In [None]:
import math
import os
from math import log2

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.utils as vutils
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms

### Starting variables

In [None]:
# Root directory for dataset. Raw string: in a normal literal "\.", "\A" and "\d"
# are invalid escape sequences that recent Python versions warn about.
dataroot = r"C:\.Python Projects\AnimeGAN\data"

# Number of workers for dataloader
workers = 14

# Set True to prevent releasing and reassigning workers between epochs
persistent_workers = True

# Number of batches to prefetch per worker (workers * prefetch_factor = number of batches preloaded)
prefetch_factor = 4

# Batch size per resolution step (index 0 -> 4x4, 1 -> 8x8, ...)
batch_sizes = [64, 64, 64, 64, 40, 16, 8, 4] # 512: [64, 64, 64, 64, 10, 6, 4, 2], 256: [64, 64, 64, 64, 32, 16, 8, 4]

# Image size to train up to (inclusive). Paper uses 1024
image_size = 256

assert image_size >= 4, f"{image_size} is not greater than or equal to 4!"
assert image_size <= 1024, f"{image_size} is not less than or equal to 1024!"
assert math.ceil(log2(image_size)) == math.floor(log2(image_size)), f"{image_size} is not a power of 2!"

# Image size to start training at
start_train_at = 4

assert start_train_at >= 4, f"{start_train_at} is not greater than or equal to 4!"
assert start_train_at <= image_size, f"{start_train_at} is not less than or equal to {image_size}"
assert math.ceil(log2(start_train_at)) == math.floor(log2(start_train_at)), f"{start_train_at} is not a power of 2!"

# Device to push to
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
# Pinned host memory only helps host -> GPU copies.
pin_memory = device.type == "cuda"

# Size of latent vector, Paper uses 512 for z_dim and in_channels
z_dim = 256

in_channels = 256

# Learning rate for optimizers. Paper uses 1e-3
lr = 1e-3

# Weight of the WGAN-GP gradient penalty term
lambda_gp = 10

# Number of steps to reach desired image size (4x4 counts as the first step)
num_steps = int(log2(image_size / 4)) + 1

# Progressive Epochs
#   4x4: Train the model to see 800k images.
#   8x8, 16x16, ..., img_size x img_size: Train model to see 800k images fading in the new layer and 800k images to stabilize.
#
#   4x4 epochs: 800,000 / dataset_size
#   onwards... 2 times 4x4 epochs

prog_epochs = [32] + [64] * (num_steps - 1)

# Channel multipliers per resolution step, relative to in_channels (index 0 is 4x4).
factors = [1, 1, 1, 1, 1/2, 1/4, 1/8, 1/16, 1/32]

# Fixed noise to monitor progression of model
fixed_noise = torch.randn(64, z_dim, 1, 1).to(device)

# Display images after each epoch
display_images = False

# Use pretrained model
use_pretrained = True

# Set true to show every step with tqdm
update_last = True

start_epoch = 0

step = int(log2(start_train_at / 4))

if use_pretrained:
    # Forward slashes work on Windows and POSIX and match the paths used when saving.
    path = "models/pretrained_imgsize_128_zdim_256.pth" # "../PATH_TO_CHECKPOINT.pth"
    # Load onto the CPU so a checkpoint written on a GPU also opens on a CPU-only
    # machine; load_state_dict later copies model/optimizer state onto the right device.
    checkpoint = torch.load(path, map_location="cpu")
    batch_sizes = checkpoint["batch_sizes"]
    start_train_at = checkpoint["start_training_at"]
    fixed_noise = checkpoint["fixed_noise"].to(device)
    z_dim = checkpoint["z_dim"]
    in_channels = checkpoint["in_channels"]
    step = int(log2(start_train_at / 4))
    start_epoch = checkpoint["epoch"]

# get_loader indexes batch_sizes by resolution step, so every step needs an entry.
assert len(batch_sizes) >= num_steps, (
    f"batch_sizes has {len(batch_sizes)} entries but {num_steps} resolution steps are required"
)

## Build Network

### Weighted-Scaled Convolutional Layer

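Equalized learning rate: the convolution weights are initialized from $\mathcal{N}(0, 1)$, and the layer input is multiplied at runtime by the He constant $\sqrt{\text{gain} / \text{fan\_in}}$ with $\text{fan\_in} = C_{in} \cdot k^2$. This keeps the effective learning rate comparable across layers. The bias is applied after the scaled convolution.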

In [None]:
class WSConv2d(nn.Module):
    """Conv2d with an equalized learning rate (weight scaling applied at runtime).

    Weights are drawn from N(0, 1). The input is multiplied by
    sqrt(gain / (input_channel * kernel_size**2)) before the convolution, and a
    separately held bias (initialized to zero) is added afterwards.

    Args:
        input_channel: number of input channels.
        out_channel: number of output channels.
        kernel_size, stride, padding: forwarded to nn.Conv2d.
        gain: numerator of the He scaling constant.
    """

    def __init__(self, input_channel, out_channel, kernel_size=3, stride=1, padding=1, gain=2):
        super().__init__()
        self.conv = nn.Conv2d(input_channel, out_channel, kernel_size, stride, padding)
        fan_in = input_channel * kernel_size ** 2
        self.scale = (gain / fan_in) ** 0.5

        # Move the bias out of the conv so it is added *after* the scaled
        # convolution and is not itself multiplied by the scale.
        self.bias = self.conv.bias
        self.conv.bias = None

        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        scaled_input = x * self.scale
        # Broadcast the per-channel bias over batch and spatial dimensions.
        return self.conv(scaled_input) + self.bias.view(1, -1, 1, 1)

### Pixel Normalization

Used in place of batch normalization: the feature vector at each pixel is normalized across channels.

Defined as $b_{x, y} = a_{x, y} / \sqrt{\frac{1}{N} \sum_{j=0}^{N-1}(a_{x,y}^j)^2 + \epsilon}$, where $N$ is the number of channels (feature maps) and $\epsilon = 10^{-8}$.

In [None]:
class PixelNorm(nn.Module):
    """Normalize every pixel's feature vector to unit RMS across dim 1 (channels).

    out = x / sqrt(mean(x**2 over dim 1) + epsilon), with epsilon = 1e-8.
    """

    def __init__(self):
        super().__init__()
        self.epsilon = 1e-8

    def forward(self, x):
        mean_square = torch.mean(x ** 2, dim=1, keepdim=True)
        return x / torch.sqrt(mean_square + self.epsilon)

### Convolutional Block

In [None]:
class ConvBlock(nn.Module):
    """Two 3x3 equalized-LR convolutions, each followed by LeakyReLU(0.2).

    When ``pixel_norm`` is True, PixelNorm is applied after each activation
    (the generator uses it; the discriminator builds blocks with it off).
    """

    def __init__(self, input_channel, out_channel, pixel_norm=True):
        super().__init__()
        self.conv1 = WSConv2d(input_channel, out_channel)
        self.conv2 = WSConv2d(out_channel, out_channel)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()
        self.use_pn = pixel_norm

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.leaky(conv(x))
            if self.use_pn:
                x = self.pn(x)
        return x

### Generator Network

In [None]:
class Generator(nn.Module):
    """Progressive-growing generator.

    A 4x4 stem is followed by up to ``len(factors) - 1`` blocks. Each block
    upsamples by 2 (nearest neighbour) and applies a ConvBlock. Channel widths
    come from the module-level ``factors`` list scaled by ``in_channels``.

    Args:
        z_dim: channels of the latent input, fed as a 1x1 map (the test cell
            uses shape (N, z_dim, 1, 1)).
        in_channels: channel width of the 4x4 stem.
        img_channels: channels of the produced image.
    """

    def __init__(self, z_dim, in_channels, img_channels=3):
        super().__init__()
        self.initial = nn.Sequential(
            PixelNorm(),
            nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0),  # 1x1 -> 4x4
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm(),
        )

        # to-RGB for 4x4. It is also entry 0 of rgb_layers (the same module is
        # registered under both names).
        self.initial_rgb = WSConv2d(in_channels, img_channels, kernel_size=1, stride=1, padding=0)
        self.prog_blocks = nn.ModuleList()
        self.rgb_layers = nn.ModuleList([self.initial_rgb])

        # Block k maps widths[k] -> widths[k + 1]; rgb_layers[k + 1] is its to-RGB head.
        widths = [int(in_channels * f) for f in factors]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            self.prog_blocks.append(ConvBlock(c_in, c_out))
            self.rgb_layers.append(WSConv2d(c_out, img_channels, kernel_size=1, stride=1, padding=0))

    def fade_in(self, alpha, upscaled, generated):
        """Blend the new block's RGB with the upscaled previous RGB, then apply tanh."""
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, x, alpha, steps):
        """Generate an image of size 4 * 2**steps.

        ``alpha`` weights the newest block against the upsampled output of the
        previous resolution.
        """
        features = self.initial(x)

        if steps == 0:
            # NOTE(review): unlike the faded path below, the 4x4 output is not passed through tanh.
            return self.initial_rgb(features)

        for idx in range(steps):
            previous = F.interpolate(features, scale_factor=2, mode="nearest")
            features = self.prog_blocks[idx](previous)

        rgb_previous = self.rgb_layers[steps - 1](previous)
        rgb_current = self.rgb_layers[steps](features)
        return self.fade_in(alpha, rgb_previous, rgb_current)

### Discriminator Network

In [None]:
class Discriminator(nn.Module):
    """Progressive-growing critic, the mirror image of Generator.

    ``prog_blocks`` are stored highest resolution first. ``rgb_layers[k]`` is
    the from-RGB layer feeding ``prog_blocks[k]``, and the last entry is the
    4x4 from-RGB (``initial_rgb``). The output is one score per sample.

    Args:
        in_channels: channel width at 4x4; larger resolutions use ``factors``.
        img_channels: channels of the input image.
    """

    def __init__(self, in_channels, img_channels=3):
        super().__init__()
        self.prog_blocks = nn.ModuleList()
        self.rgb_layers = nn.ModuleList()
        self.leaky = nn.LeakyReLU(0.2)

        # Walk factors from the largest resolution down to 8x8.
        for i in reversed(range(1, len(factors))):
            c_in = int(in_channels * factors[i])
            c_out = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(c_in, c_out, pixel_norm=False))
            self.rgb_layers.append(WSConv2d(img_channels, c_in, kernel_size=1, stride=1, padding=0))

        # 4x4 img res
        self.initial_rgb = WSConv2d(img_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

        # Input has in_channels + 1 channels: the extra one is the minibatch-std map.
        self.final_block = nn.Sequential(
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, 1, kernel_size=1, padding=0, stride=1),
        )

    def fade_in(self, alpha, downscaled, out):
        """Linearly blend the new path ``out`` with the ``downscaled`` shortcut."""
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x: torch.Tensor):
        """Append one channel holding the batch-averaged per-feature std."""
        batch, height, width = x.shape[0], x.shape[2], x.shape[3]
        stat = torch.std(x, dim=0).mean()
        stat_map = stat.repeat(batch, 1, height, width)
        return torch.cat([x, stat_map], dim=1)

    def _score(self, features):
        """Minibatch-std, final conv head, and flatten to (batch, -1)."""
        features = self.minibatch_std(features)
        return self.final_block(features).view(features.shape[0], -1)

    def forward(self, x, alpha, steps):
        """Score images of size 4 * 2**steps; ``alpha`` fades in the newest block."""
        n_blocks = len(self.prog_blocks)
        cur_step = n_blocks - steps

        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:  # i.e, image is 4x4
            return self._score(out)

        shortcut = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))
        out = self.fade_in(alpha, shortcut, out)

        for idx in range(cur_step + 1, n_blocks):
            out = self.avg_pool(self.prog_blocks[idx](out))

        return self._score(out)

### Test Network

In [None]:
# Smoke-test every resolution the generator/critic pair supports (CPU, no autograd).
# A dedicated loop variable is used so the global `num_steps` from the config cell
# is not overwritten.
gen = Generator(z_dim=z_dim, in_channels=in_channels)
disc = Discriminator(in_channels=in_channels)

with torch.no_grad():
    for img_size in [4, 8, 16, 32, 64, 128, 256, 512]:
        test_steps = int(log2(img_size / 4))
        latent = torch.randn((1, z_dim, 1, 1))
        generated = gen(latent, 0.5, steps=test_steps)
        assert generated.shape == (1, 3, img_size, img_size)
        score = disc(generated, alpha=0.5, steps=test_steps)
        assert score.shape == (1, 1), f"unexpected critic output shape {tuple(score.shape)}"
        print(f"Success! at img size: {img_size}")

## Helper Functions

In [None]:
def get_loader(image_size):
    """Build the DataLoader and ImageFolder dataset for one training resolution.

    Images under ``dataroot`` are resized to ``image_size`` x ``image_size``
    (aspect ratio not preserved), randomly flipped horizontally, and scaled to
    [-1, 1]. The batch size is ``batch_sizes[log2(image_size / 4)]``.

    Args:
        image_size: target side length; expected to be a power of two >= 4.

    Returns:
        (dataloader, dataset)
    """
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),  # Resize to a standard size (can be adjusted)
        transforms.ToTensor(),  # Convert the image to a PyTorch tensor
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # Normalize to [-1, 1]
    ])

    batch_size = batch_sizes[int(log2(image_size / 4))]

    dataset = dset.ImageFolder(root=dataroot, transform=transform)

    # persistent_workers / prefetch_factor are only valid with worker processes;
    # DataLoader raises ValueError if they are passed while num_workers == 0.
    worker_options = {}
    if workers > 0:
        worker_options = dict(
            persistent_workers=persistent_workers,
            prefetch_factor=prefetch_factor,
        )

    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=workers,
                            drop_last=True,
                            pin_memory=pin_memory,
                            **worker_options)

    return dataloader, dataset

In [None]:
def gradient_penalty(critic, real, fake, alpha, train_step, device=None):
    """WGAN-GP gradient penalty on random real/fake interpolations.

    Args:
        critic: callable ``critic(images, alpha, train_step)`` returning scores.
        real: batch of real images; the first dimension is the batch.
        fake: batch of generated images, same shape as ``real`` (detached here).
        alpha: fade-in factor forwarded to the critic.
        train_step: resolution step forwarded to the critic.
        device: where to sample the mixing weights. Defaults to ``real.device``,
            so CUDA inputs work without passing it explicitly.

    Returns:
        Scalar tensor ``mean((||grad||_2 - 1) ** 2)``. It is built with
        ``create_graph=True`` so it can be back-propagated into the critic.
    """
    if device is None:
        device = real.device
    batch_size = real.shape[0]

    # One mixing weight per sample, broadcast over the remaining dimensions.
    beta = torch.rand((batch_size, 1, 1, 1), device=device, dtype=real.dtype)
    interpolated_images = real * beta + fake.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images, alpha, train_step)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient_norm = gradient.reshape(gradient.shape[0], -1).norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)

### Train Function

In [None]:
from tqdm import tqdm

def train(disc, gen, loader, dataset, step, alpha, opt_disc, opt_gen, epoch, num_epochs):
    """Run one epoch of WGAN-GP training at resolution step ``step``.

    Critic loss: -(E[D(real)] - E[D(fake)]) + lambda_gp * GP + 0.001 * E[D(real)^2].
    Generator loss: -E[D(fake)].

    Args:
        disc, gen: critic and generator.
        loader, dataset: data for this resolution (``len(dataset)`` sets the fade-in speed).
        step: resolution step (0 -> 4x4).
        alpha: current fade-in factor.
        opt_disc, opt_gen: their optimizers.
        epoch, num_epochs: used only for the progress-bar label.

    Returns:
        The updated alpha, capped at 1.
    """
    loop = tqdm(loader, total=len(loader), miniters=50, desc=f"Epoch [{epoch + 1}/{num_epochs}]", leave=update_last)

    # alpha reaches 1 after half of this resolution's epochs' worth of images.
    fade_in_images = prog_epochs[step] * 0.5 * len(dataset)

    for real, _ in loop:
        # Non-blocking copy is only asynchronous when the batch is in pinned memory.
        real = real.to(device, non_blocking=pin_memory)
        cur_batch_size = real.shape[0]

        # Generate noise directly on the target device, then fake data
        noise = torch.randn(cur_batch_size, z_dim, 1, 1, device=device)
        fake = gen(noise, alpha, step)

        # Train Discriminator (fake is detached so no gradient reaches the generator)
        disc_real = disc(real, alpha, step)
        disc_fake = disc(fake.detach(), alpha, step)
        gp = gradient_penalty(disc, real, fake, alpha, step, device)
        loss_disc = (
            -(torch.mean(disc_real) - torch.mean(disc_fake))
            + lambda_gp * gp
            + (0.001 * torch.mean(disc_real ** 2))
        )

        opt_disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # Train Generator (re-scores fake with the just-updated critic)
        gen_fake = disc(fake, alpha, step)
        loss_gen = -torch.mean(gen_fake)

        opt_gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Update alpha for progressive growing
        alpha = min(alpha + cur_batch_size / fade_in_images, 1)

    return alpha


## Training and Testing

### Training

#### Define Networks and Optimizers

In [None]:
import matplotlib.pyplot as plt

# Fresh generator and critic on the training device.
gen = Generator(z_dim=z_dim, in_channels=in_channels).to(device)
disc = Discriminator(in_channels=in_channels).to(device)

# Both optimizers share lr and Adam betas (beta1 = 0.0, beta2 = 0.99).
adam_betas = (0.0, 0.99)
optimizer_gen = optim.Adam(gen.parameters(), lr=lr, betas=adam_betas)
optimizer_disc = optim.Adam(disc.parameters(), lr=lr, betas=adam_betas)

# Per-epoch sample snapshots; restored from the checkpoint when resuming.
fakes = []

if use_pretrained:
    gen.load_state_dict(checkpoint['gen_state'])
    disc.load_state_dict(checkpoint['disc_state'])
    optimizer_gen.load_state_dict(checkpoint['gen_optim'])
    optimizer_disc.load_state_dict(checkpoint['disc_optim'])
    fakes = checkpoint["fakes"]

#### Begin Training

In [None]:
gen.train()
disc.train()

# torch.save does not create directories; make sure the checkpoint folder exists.
os.makedirs("models", exist_ok=True)


def save_checkpoint(file_path, start_training_at, alpha_value, epoch_value):
    """Write networks, optimizers and progress to ``file_path`` (keys read back by the config cell)."""
    torch.save({
        'batch_sizes': batch_sizes,
        'start_training_at': start_training_at,
        'alpha': alpha_value,
        'fixed_noise': fixed_noise,
        'z_dim': z_dim,
        'in_channels': in_channels,
        'epoch': epoch_value,
        'fakes': fakes,
        'gen_state': gen.state_dict(),
        'disc_state': disc.state_dict(),
        'gen_optim': optimizer_gen.state_dict(),
        'disc_optim': optimizer_disc.state_dict(),
    }, file_path)


# Only the first resolution after a resume reuses the checkpoint's alpha.
change_alpha = use_pretrained

for num_epochs in prog_epochs[step:]:
    if start_train_at > image_size:
        break
    alpha = 1e-5
    if change_alpha:
        alpha = checkpoint['alpha']
        change_alpha = False

    cur_size = 4 * 2 ** step
    loader, dataset = get_loader(cur_size)

    print(f"Current image size: {cur_size}")

    for epoch in range(start_epoch, num_epochs):

        alpha = train(
            disc,
            gen,
            loader,
            dataset,
            step,
            alpha,
            optimizer_disc,
            optimizer_gen,
            epoch,
            num_epochs,
        )

        with torch.no_grad():
            img = gen(fixed_noise, alpha, step) * 0.5 + 0.5

        # Keep snapshots on the CPU: one 64-image batch per epoch would otherwise
        # accumulate in GPU memory for the whole run (and be saved from the GPU).
        fakes.append(img.cpu())

        if display_images:
            shown = img.cpu()
            shown = (shown - shown.min()) / (shown.max() - shown.min())

            # Display each image
            fig, axes = plt.subplots(8, 8, figsize=(8, 8))  # Create a 8x8 grid
            for i, ax in enumerate(axes.flatten()):
                ax.imshow(shown[i].permute(1, 2, 0).numpy())  # Rearrange dimensions to (H, W, C)
                ax.axis('off')
            plt.tight_layout()
            plt.show()
            plt.close(fig)

        if epoch + 1 != num_epochs:
            save_checkpoint(
                f"models/training_imgsize_{cur_size}_zdim_{z_dim}_progression.pth",
                cur_size, alpha, epoch + 1,
            )

    # A resumed epoch applies only to the resolution the checkpoint was saved at;
    # every later resolution must start from epoch 0.
    start_epoch = 0

    save_checkpoint(
        f"models/pretrained_imgsize_{cur_size}_zdim_{z_dim}.pth",
        cur_size * 2, 1e-5, 0,
    )

    step += 1

gen.eval()
disc.eval()

print("Eval mode activated")

### Testing

#### Create grid of sample images

In [None]:
# Resize every per-epoch snapshot to 128x128 and tile each batch into one grid image.
# (To shorten number of images, use snapshot[:8] instead of snapshot.)
img_list = [
    vutils.make_grid(
        torch.nn.functional.interpolate(snapshot.detach().cpu(), size=(128, 128), mode="nearest"),
        padding=2,
        normalize=True,
    )
    for snapshot in fakes
]

#### Create a GIF showing how generated samples evolve during training

In [None]:
import matplotlib.animation as animation
from IPython.display import HTML

fig = plt.figure(figsize=(8, 8))
plt.axis("off")

# One single-artist frame per epoch grid, converted from (C, H, W) to (H, W, C).
frames = []
for grid in img_list:
    frames.append([plt.imshow(grid.permute(1, 2, 0).numpy(), animated=True)])
ani = animation.ArtistAnimation(fig, frames, interval=1000, repeat_delay=1000, blit=True)

ani.save("anime.gif", writer='pillow', fps=10)

HTML(ani.to_jshtml())

#### Show Real vs. Fake Images

In [None]:
# Grab up to 64 real images (dataset order, training transforms applied) at 128x128,
# the same size the fake grids in img_list were resized to.
dataloader, dataset = get_loader(image_size=128)
n_real = min(64, len(dataset))
real_batch = torch.stack([dataset[i][0] for i in range(n_real)])

fig, (ax_real, ax_fake) = plt.subplots(1, 2, figsize=(15, 15))

# Plot the real images
ax_real.axis("off")
ax_real.set_title("Real Images")
real_grid = vutils.make_grid(real_batch, padding=5, normalize=True)
ax_real.imshow(real_grid.permute(1, 2, 0).numpy())

# Plot the fake images from the last epoch
ax_fake.axis("off")
ax_fake.set_title("Fake Images")
ax_fake.imshow(img_list[-1].permute(1, 2, 0).numpy())
plt.show()

In [None]:
# Switch both networks to inference mode before sampling.
for network in (gen, disc):
    network.eval()
print("Eval mode set.")

In [None]:
from torchvision.transforms.functional import to_pil_image

noise = torch.randn(1, z_dim, 1, 1).to(device)

# Sample at the final trained resolution: the training loop leaves `step` one past it.
# No autograd graph is needed for inference.
with torch.no_grad():
    output: torch.Tensor = gen(noise, 1, step - 1) * 0.5 + 0.5

# to_pil_image expects float images in [0, 1]. Clamp, because the generator's
# 4x4 path is not tanh-bounded. Move to the CPU before conversion.
output = output.clamp(0, 1).squeeze(0).cpu()
img = to_pil_image(output)

plt.figure(figsize=(2,2))
plt.axis("off")
plt.imshow(img)
plt.show()

In [None]:
"""def interpolate_points(p1, p2, n_steps=100):
    ratios = np.linspace(0, 1, num=n_steps)[1:]
    vectors = [p1]
    for ratio in ratios:
        v = (1.0 - ratio) * p1 + ratio * p2
        vectors.append(v)
    return torch.stack(vectors, dim=0)"""

In [None]:
"""point1 = torch.randn(z_dim, 1, 1)
point2 = torch.randn(z_dim, 1, 1)

points = interpolate_points(point1, point2, n_steps=100)

images = [to_pil_image((gen(p.to(device), 1, step - 1) * 0.5 + 0.5).squeeze(0)) for p in points]
images[0].save("progress.gif", save_all=True, append_images=images[1:], duration=10, loop=2)"""

In [None]:
"""
start_train_at = 4 * 2 ** step
epoch = 0 # Specify the epoch to resume training at

torch.save({
    'batch_sizes': batch_sizes,
    'start_training_at': start_train_at,
    'fixed_noise': fixed_noise,
    'z_dim': z_dim,
    'in_channels': in_channels,
    'factors': factors,
    'epoch': epoch,
    'fakes': fakes,
    'gen_state': gen.state_dict(),
    'disc_state': disc.state_dict(),
    'gen_optim': optimizer_gen.state_dict(),
    'disc_optim': optimizer_disc.state_dict(),
}, f"model_imgsize_{4 * 2 ** step}_continue_epoch_{epoch}.pth")
"""