In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torch import Tensor, nn, optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from tqdm import tqdm

# Hyperparameters

In [None]:
# --- Architecture -----------------------------------------------------------
Z_DIM = 128  # size of the latent vector (generator input / encoder output)
G_HIDDEN = 64  # base channel width of the generator
D_HIDDEN = 64  # base channel width of the encoder and the image discriminator
NUM_CHANNELS = 3  # channels of the images the models consume / produce
EPOCHS = 100_000
NUM_DISCRIMINATORS = 10  # NOTE(review): not referenced in the cells visible here

# --- Optimisation -----------------------------------------------------------
BATCH_SIZE = 64
generator_lr = 1e-4
encoder_lr = 1e-4
image_discriminator_lr = 2e-4
noise_discriminator_lr = 2e-4
betas = (.5, .999)  # handed to every Adam optimizer as `betas`

# Prefer the first GPU, but fall back to the CPU so that the notebook still
# runs (slowly) on a machine without CUDA instead of failing on `.to(device)`.
device = torch.device('cuda', index=0) if torch.cuda.is_available() else torch.device('cpu')

# Data Preparation

In [None]:
class AnimeFacesDS(Dataset):
    """Unlabelled image dataset built from every ``*.png`` below a directory.

    Items are tensors produced by ``transforms.ToTensor`` after a random
    horizontal flip, so the same index can come back mirrored between calls.
    """

    def __init__(self, base_path: str) -> None:
        """Collect the PNG files under ``base_path`` recursively.

        Args:
            base_path: Directory that is searched (recursively) for ``*.png``.
        """
        super().__init__()
        # Sorted so that an index always refers to the same file.
        self.files = sorted(Path(base_path).rglob('*.png'))
        self.transform = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ]
        )

    def __len__(self) -> int:
        """Return the number of image files that were found."""
        return len(self.files)

    def __getitem__(self, index: int) -> Tensor:
        """Load the image at ``index`` and return it as a 3-channel tensor.

        Args:
            index: Position in the sorted file list.

        Returns:
            The transformed image with channels first.
        """
        # Image.open is lazy and keeps the file open; the context manager
        # releases the handle once convert() has copied the pixel data.
        # Converting to RGB guarantees three channels even for RGBA, palette
        # or greyscale PNGs, which would otherwise not fit the models.
        with Image.open(self.files[index]) as raw:
            image = raw.convert('RGB')
        return self.transform(image)


# Build the dataset from the local `data` directory and preview its first
# sample. The tensor is channels-first, so the channel axis is moved to the
# end before handing it to imshow.
dataset = AnimeFacesDS(base_path='data')
image = dataset[0]
preview_hwc = image.permute(1, 2, 0)
plt.imshow(preview_hwc)

In [None]:
# Mini-batches for training. The dataset indexes its files in sorted order, so
# shuffling is enabled: without it every epoch would replay exactly the same
# batches in exactly the same order.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,  # background worker processes that load and decode images
    pin_memory=True,
)

# Models

In [None]:
class Generator(nn.Module):
    """Decode a latent tensor into an image.

    Input:  (N, Z_DIM, 1, 1)
    Output: (N, NUM_CHANNELS, 64, 64) with values squashed by a sigmoid.

    Five transposed-convolution stages grow the spatial size
    1 -> 4 -> 8 -> 16 -> 32 -> 64 while the channel count shrinks from
    ``G_HIDDEN * 16`` to ``G_HIDDEN``; each stage is followed by LayerNorm and
    LeakyReLU(0.1). A final 1x1 convolution maps to ``NUM_CHANNELS``.
    """

    def __init__(self):
        super().__init__()
        # One row per upsampling stage:
        # (in_channels, out_channels, stride, padding, output side length)
        stages = (
            (Z_DIM, G_HIDDEN * 16, 1, 0, 4),
            (G_HIDDEN * 16, G_HIDDEN * 8, 2, 1, 8),
            (G_HIDDEN * 8, G_HIDDEN * 4, 2, 1, 16),
            (G_HIDDEN * 4, G_HIDDEN * 2, 2, 1, 32),
            (G_HIDDEN * 2, G_HIDDEN * 1, 2, 1, 64),
        )
        layers = []
        for in_channels, out_channels, stride, padding, side in stages:
            layers.append(
                nn.ConvTranspose2d(in_channels, out_channels, 4, stride, padding, bias=False)
            )
            # LayerNorm normalises over the full (C, H, W) feature map.
            layers.append(nn.LayerNorm([out_channels, side, side]))
            layers.append(nn.LeakyReLU(.1, inplace=True))
        # Project the G_HIDDEN feature maps onto image channels.
        layers.append(nn.Conv2d(G_HIDDEN, NUM_CHANNELS, 1, bias=False))
        layers.append(nn.Sigmoid())
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the image batch."""
        image = self.main(input)
        return image


class Encoder(nn.Module):
    """Encode an image into a latent vector.

    Input:  (N, NUM_CHANNELS, 64, 64)
    Output: (N, Z_DIM)

    Six stages of 3x3 convolution -> LayerNorm -> LeakyReLU(0.2) -> 2x2 average
    pooling halve the spatial size each time (64 -> 1) while the channel count
    doubles from ``D_HIDDEN`` up to ``D_HIDDEN * 32``. The resulting
    (N, D_HIDDEN * 32, 1, 1) tensor is flattened and linearly projected.
    """

    def __init__(self):
        super().__init__()
        num_stages = 6
        # Channel widths entering/leaving each stage: C, D, 2D, 4D, ..., 32D.
        widths = [NUM_CHANNELS] + [D_HIDDEN * 2 ** i for i in range(num_stages)]
        layers = []
        for stage in range(num_stages):
            in_channels, out_channels = widths[stage], widths[stage + 1]
            side = 64 >> stage  # spatial size seen by this stage: 64, 32, ..., 2
            layers.extend(
                [
                    nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
                    nn.LayerNorm([out_channels, side, side]),
                    nn.LeakyReLU(0.2, inplace=True),
                    nn.AvgPool2d(2, 2),
                ]
            )
        # (N, D_HIDDEN * 32, 1, 1) -> (N, D_HIDDEN * 32) -> (N, Z_DIM)
        layers.extend(
            [
                nn.Flatten(start_dim=1, end_dim=3),
                nn.Linear(D_HIDDEN * 32, Z_DIM, bias=False),
            ]
        )
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the latent codes."""
        latent = self.main(input)
        return latent


class ImageDiscriminator(nn.Module):
    """Score images with a value in (0, 1).

    Input:  (N, NUM_CHANNELS, 64, 64)
    Output: (N,)

    Six stages of 3x3 convolution -> LayerNorm -> LeakyReLU(0.2) -> 2x2 average
    pooling reduce the image to (N, D_HIDDEN * 32, 1, 1); a linear layer and a
    sigmoid turn that into one score per sample.
    """

    def __init__(self):
        super().__init__()
        num_stages = 6
        # Channel widths entering/leaving each stage: C, D, 2D, 4D, ..., 32D.
        widths = [NUM_CHANNELS] + [D_HIDDEN * 2 ** i for i in range(num_stages)]
        layers = []
        for stage in range(num_stages):
            in_channels, out_channels = widths[stage], widths[stage + 1]
            side = 64 >> stage  # spatial size seen by this stage: 64, 32, ..., 2
            layers.extend(
                [
                    nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
                    nn.LayerNorm([out_channels, side, side]),
                    nn.LeakyReLU(0.2, inplace=True),
                    nn.AvgPool2d(2, 2),
                ]
            )
        # (N, D_HIDDEN * 32, 1, 1) -> (N, D_HIDDEN * 32) -> (N, 1) in (0, 1)
        layers.extend(
            [
                nn.Flatten(start_dim=1, end_dim=3),
                nn.Linear(D_HIDDEN * 32, 1, bias=False),
                nn.Sigmoid(),
            ]
        )
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return one score per sample; the trailing size-1 axis is dropped."""
        scores = self.main(input)
        return scores.squeeze(1)


class NoiseDiscriminator(nn.Module):
    """Score latent vectors with a value in (0, 1).

    Input:  (N, Z_DIM)
    Output: (N,)

    A two-layer perceptron (hidden width ``2 * Z_DIM``, LeakyReLU(0.2)) followed
    by a sigmoid.
    """

    def __init__(self):
        super().__init__()
        hidden_dim = 2 * Z_DIM
        layers = [
            nn.Linear(Z_DIM, hidden_dim),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
            # Merges the first two axes, turning (N, 1) into (N,).
            nn.Flatten(start_dim=0, end_dim=1),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the scores."""
        scores = self.main(input)
        return scores


def init_weights(m):
    """Re-initialise a module in place, chosen by its class name.

    Intended to be passed to ``nn.Module.apply``.

    * Class name contains ``'Conv'`` (this also matches transposed
      convolutions): weights drawn from N(0, 0.02).
    * Class name contains ``'BatchNorm'``: weights drawn from N(1, 0.02) and
      biases set to 0.
    * Anything else is left untouched.

    Args:
        m: The module to initialise.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
# Instantiate the four networks one after another (generator, encoder, image
# discriminator, noise discriminator) and move each onto `device`.
generator, encoder, image_discriminator, noise_discriminator = (
    network_cls().to(device)
    for network_cls in (Generator, Encoder, ImageDiscriminator, NoiseDiscriminator)
)


# Optimizers, Losses, and Logging

In [None]:
# One Adam optimizer per network, each with its own learning rate and the
# shared `betas`.
generator_optimizer = optim.Adam(generator.parameters(), lr=generator_lr, betas=betas)
encoder_optimizer = optim.Adam(encoder.parameters(), lr=encoder_lr, betas=betas)
image_discriminator_optimizer = optim.Adam(image_discriminator.parameters(), lr=image_discriminator_lr, betas=betas)
noise_discriminator_optimizer = optim.Adam(noise_discriminator.parameters(), lr=noise_discriminator_lr, betas=betas)

# Losses: BCE for both adversarial games, L1 for image reconstruction and
# MSE for latent reconstruction.
bce_criterion = nn.BCELoss()
l1_criterion = nn.L1Loss()
l2_criterion = nn.MSELoss()

# Every run logs into a fresh `exp<N>` directory, where N is the number of
# entries already present in the log root. The root is created first so that
# the very first run does not fail with FileNotFoundError.
# NOTE(review): the root is a hard-coded absolute path; adjust it per machine.
log_root = Path('/home/aj/tmp/tblog/')
log_root.mkdir(parents=True, exist_ok=True)
exp_number = len(os.listdir(log_root))
logger = SummaryWriter(str(log_root / f'exp{exp_number}'))

# Fixed inputs that are re-used at the end of every epoch so that the logged
# samples and reconstructions are comparable across epochs: 64 latent vectors
# and one batch of real images.
fixed_z = torch.randn(64, Z_DIM, 1, 1, device=device)
fixed_real = next(iter(dataloader))
logger.add_images('real_images', fixed_real, 0)
fixed_real = fixed_real.to(device)

# Training Loop

In [None]:
# Training loop. Every mini-batch runs four sub-steps in a fixed order; the
# colour names are kept from the original section markers (presumably they
# refer to an architecture diagram that is not part of this notebook):
#   blue   - image reconstruction  x -> encoder -> generator -> x_hat   (L1)
#   yellow - adversarial game on latent codes (noise discriminator vs encoder)
#   green  - latent reconstruction z -> generator -> encoder -> z_hat   (MSE)
#   red    - adversarial game on images (image discriminator vs generator)
# `step` counts mini-batches across all epochs and is the x-axis of every
# logged scalar.
step = 0
for epoch in range(EPOCHS):
    print(f'{epoch = }')
    for real_image in tqdm(dataloader):
        step += 1
        logger.add_scalar('epoch', epoch, step)
        # The last batch of an epoch may be smaller than BATCH_SIZE.
        batch_size = real_image.shape[0]
        z = torch.randn(batch_size, Z_DIM, 1, 1, device=device)
        real_image = real_image.to(device)
        # BCE targets, shared by both adversarial games of this step.
        real_label = torch.ones(batch_size, device=device)
        fake_label = torch.zeros(batch_size, device=device)
        ########
        # blue #
        ########
        # Auto-encoding: encoder and generator are updated together so that
        # generator(encoder(x)) reproduces x.
        encoder_optimizer.zero_grad()
        generator_optimizer.zero_grad()
        x_hat = generator(encoder(real_image).view(batch_size, Z_DIM, 1, 1))
        loss = l1_criterion(real_image, x_hat)
        logger.add_scalar('generator(encoder(real_image))_loss', loss.item(), step)
        loss.backward()
        encoder_optimizer.step()
        generator_optimizer.step()
        ##########
        # yellow #
        ##########
        ################################
        # 1. train noise discriminator #
        ################################
        noise_discriminator_optimizer.zero_grad()
        # real: latent vectors sampled from the prior
        real_pred = noise_discriminator(z.view(batch_size, Z_DIM))
        real_loss = bce_criterion(real_pred, real_label)
        real_loss.backward()
        logger.add_scalar('dz_real_loss', real_loss.item(), step)
        # fake: encoder outputs; no_grad keeps the encoder out of this update
        with torch.no_grad():
            fake_noise = encoder(real_image)
        fake_pred = noise_discriminator(fake_noise)
        fake_loss = bce_criterion(fake_pred, fake_label)
        fake_loss.backward()
        logger.add_scalar('dz_fake_loss', fake_loss.item(), step)
        # update
        noise_discriminator_optimizer.step()
        discriminator_loss = real_loss + fake_loss
        logger.add_scalar('dz_loss', discriminator_loss.item(), step)
        ####################
        # 2. train encoder #
        ####################
        # The encoder plays the generator role here: its codes should be
        # scored as "real" by the noise discriminator. Only the encoder
        # optimizer steps.
        encoder_optimizer.zero_grad()
        fake_noise = encoder(real_image)
        pred = noise_discriminator(fake_noise)
        encoder_loss = bce_criterion(pred, real_label)
        logger.add_scalar('gz_loss', encoder_loss.item(), step)
        encoder_loss.backward()
        encoder_optimizer.step()
        #########
        # green #
        #########
        # Latent cycle: encoder(generator(z)) should recover z.
        encoder_optimizer.zero_grad()
        generator_optimizer.zero_grad()
        z_hat = encoder(generator(z))
        loss = l2_criterion(z, z_hat.view(batch_size, Z_DIM, 1, 1))
        logger.add_scalar('encoder(generator(z))_loss', loss.item(), step)
        loss.backward()
        encoder_optimizer.step()
        generator_optimizer.step()
        #######
        # red #
        #######
        ################################
        # 1. train image discriminator #
        ################################
        image_discriminator_optimizer.zero_grad()
        # real: images from the dataset
        real_pred = image_discriminator(real_image)
        real_loss = bce_criterion(real_pred, real_label)
        real_loss.backward()
        logger.add_scalar('d_real_loss', real_loss.item(), step)
        # fake: generated images; no_grad keeps the generator out of this update
        with torch.no_grad():
            fake_image = generator(z)
        fake_pred = image_discriminator(fake_image)
        fake_loss = bce_criterion(fake_pred, fake_label)
        fake_loss.backward()
        logger.add_scalar('d_fake_loss', fake_loss.item(), step)
        # update
        image_discriminator_optimizer.step()
        discriminator_loss = real_loss + fake_loss
        logger.add_scalar('d_loss', discriminator_loss.item(), step)
        ######################
        # 2. train generator #
        ######################
        # Generated images should be scored as "real"; only the generator
        # optimizer steps.
        generator_optimizer.zero_grad()
        fake_image = generator(z)
        pred = image_discriminator(fake_image)
        generator_loss = bce_criterion(pred, real_label)
        logger.add_scalar('g_loss', generator_loss.item(), step)
        generator_loss.backward()
        generator_optimizer.step()

    # log images once per epoch, always from the same fixed inputs
    with torch.no_grad():
        # fixed generation
        images = generator(fixed_z)
        logger.add_images('fake_images', images, epoch)
        # reconstruction; the batch dimension is taken from the fixed batch
        # itself, which may hold fewer than BATCH_SIZE images.
        images = generator(encoder(fixed_real).view(fixed_real.shape[0], Z_DIM, 1, 1))
        logger.add_images('reconstruction', images, epoch)
