In [None]:
!pip install labml

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting labml
  Downloading labml-0.4.162-py3-none-any.whl (129 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m129.3/129.3 kB[0m [31m13.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting gitpython (from labml)
  Downloading GitPython-3.1.31-py3-none-any.whl (184 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m184.3/184.3 kB[0m [31m26.4 MB/s[0m eta [36m0:00:00[0m
Collecting gitdb<5,>=4.0.1 (from gitpython->labml)
  Downloading gitdb-4.0.10-py3-none-any.whl (62 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->gitpython->labml)
  Downloading smmap-5.0.0-py3-none-any.whl (24 kB)
Installing collected packages: smmap, gitdb, gitpython, labml
Successfully installed gitdb-4.0.10 gitpython-3.1.31 labml-0.4.162 smmap-5.0.0


In [None]:
import itertools
import random
import zipfile
from typing import Tuple

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import InterpolationMode
from torchvision.utils import make_grid
import cv2
import os
import torchvision
from torchvision.utils import save_image

from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml.utils.download import download_file
from labml.utils.pytorch import get_modules

class GeneratorResNet(nn.Module):
    """
    Residual-network generator.

    Layout: a 7x7 reflect-padded stem convolution (64 channels), two stride-2
    downsampling convolutions (64 -> 128 -> 256 channels), ``n_residual_blocks``
    residual blocks at 256 channels, two upsample + 3x3 convolution stages back down
    to 64 channels, and a final 7x7 reflect-padded convolution to ``input_channels``
    followed by Tanh. When H and W are divisible by 4, the output has the input's
    spatial size.
    """

    def __init__(self, input_channels: int, n_residual_blocks: int):
        super().__init__()
        width = 64
        modules = [
            nn.Conv2d(input_channels, width, kernel_size=7, padding=3, padding_mode='reflect'),
            nn.InstanceNorm2d(width),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: each stage halves H/W and doubles the channel count.
        for _ in range(2):
            modules.extend([
                nn.Conv2d(width, width * 2, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(width * 2),
                nn.ReLU(inplace=True),
            ])
            width *= 2

        modules.extend(ResidualBlock(width) for _ in range(n_residual_blocks))

        # Upsampling: each stage doubles H/W and halves the channel count.
        for _ in range(2):
            modules.extend([
                nn.Upsample(scale_factor=2),
                nn.Conv2d(width, width // 2, kernel_size=3, stride=1, padding=1),
                nn.InstanceNorm2d(width // 2),
                nn.ReLU(inplace=True),
            ])
            width //= 2

        modules.extend([nn.Conv2d(width, input_channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()])

        # Kept under the name `layers`: the checkpoint state_dict keys depend on it.
        self.layers = nn.Sequential(*modules)

        self.apply(weights_init_normal)

    def forward(self, x):
        return self.layers(x)


class ResidualBlock(nn.Module):
    """
    Residual unit: two identical (reflect-padded 3x3 conv -> InstanceNorm -> ReLU)
    stages whose result is added back onto the input, so the channel count and the
    spatial size are unchanged.
    """

    def __init__(self, in_features: int):
        super().__init__()
        stages = []
        for _ in range(2):
            stages += [
                nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
                nn.InstanceNorm2d(in_features),
                nn.ReLU(inplace=True),
            ]
        self.block = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor):
        residual = self.block(x)
        return x + residual


class Discriminator(nn.Module):
    """
    Convolutional discriminator that maps a ``(channels, height, width)`` image to a
    one-channel score map of size ``(height // 16, width // 16)``, recorded in
    ``self.output_shape`` so callers can build matching label tensors.
    """

    def __init__(self, input_shape: Tuple[int, int, int]):
        super().__init__()
        channels, height, width = input_shape
        # Four stride-2 blocks each halve H and W -> overall factor 2 ** 4 == 16.
        self.output_shape = (1, height // 16, width // 16)

        widths = [channels, 64, 128, 256, 512]
        # Only the first block skips InstanceNorm.
        blocks = [
            DiscriminatorBlock(c_in, c_out, normalize=(k > 0))
            for k, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]))
        ]
        self.layers = nn.Sequential(
            *blocks,
            # Pad left/top by one so the final 4x4 conv (padding=1) keeps the 1/16 size.
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, kernel_size=4, padding=1),
        )

        self.apply(weights_init_normal)

    def forward(self, img):
        return self.layers(img)


class DiscriminatorBlock(nn.Module):
    """
    Stride-2 4x4 convolution, optionally followed by InstanceNorm, then LeakyReLU(0.2).

    With the padding of 1 used here, the block halves the height and width of its
    input (for even sizes).
    """

    def __init__(self, in_filters: int, out_filters: int, normalize: bool = True):
        super().__init__()
        conv = nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)
        norm = [nn.InstanceNorm2d(out_filters)] if normalize else []
        self.layers = nn.Sequential(conv, *norm, nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x: torch.Tensor):
        return self.layers(x)


def weights_init_normal(m):
    """
    ``Module.apply`` hook: re-draw the weight of every module whose class name
    contains ``"Conv"`` from N(0, 0.02). Biases and all other modules are untouched.
    """
    if "Conv" in type(m).__name__:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)


def load_image(path: str):
    """
    Open the image at ``path`` and return it as an RGB ``PIL.Image``.

    Non-RGB images (grayscale, palette, RGBA, ...) are converted to RGB. The previous
    ``Image.new(...).paste(image)`` form returned ``None``, because ``paste`` works
    in place. The file is opened in a ``with`` block, so its handle is released;
    ``convert`` returns a fully loaded, independent copy.
    """
    with Image.open(path) as image:
        return image.convert('RGB')


class ImageDataset(Dataset):
    """
    Unpaired two-domain image dataset.

    Item ``index`` pairs image ``index % len(files_a)`` of ``dir_a`` (key ``"x"``)
    with image ``index % len(files_b)`` of ``dir_b`` (key ``"y"``). The dataset
    length is that of the larger folder, so the smaller one is cycled. Both images
    are converted to RGB and passed through ``transform``.

    Raises:
        ValueError: if either folder contains no regular files.
    """

    def __init__(self, dir_a, dir_b, transform):
        # Sorted so that the index -> file mapping is reproducible (os.listdir order is
        # arbitrary). Sub-directories are skipped because Image.open cannot read them.
        self.files_a = self._list_files(dir_a)
        self.files_b = self._list_files(dir_b)
        if not self.files_a or not self.files_b:
            # Fail early instead of a ZeroDivisionError from `index % 0` in __getitem__.
            raise ValueError(f'both image folders must be non-empty: {dir_a!r}, {dir_b!r}')
        self.root_dir_a = dir_a
        self.root_dir_b = dir_b
        self.transform = transform

    @staticmethod
    def _list_files(directory):
        """Sorted names of the regular files directly inside ``directory``."""
        return sorted(name for name in os.listdir(directory)
                      if os.path.isfile(os.path.join(directory, name)))

    @staticmethod
    def _load_rgb(path):
        """Open an image file, convert it to RGB and release the file handle."""
        with Image.open(path) as image:
            return image.convert('RGB')

    def __getitem__(self, index):
        image_a = self._load_rgb(os.path.join(self.root_dir_a, self.files_a[index % len(self.files_a)]))
        image_b = self._load_rgb(os.path.join(self.root_dir_b, self.files_b[index % len(self.files_b)]))
        # `transform` may be random; it is applied to the "x" image first, then to "y".
        return {"x": self.transform(image_a),
                "y": self.transform(image_b)}

    def __len__(self):
        # Number of items per epoch: the size of the larger folder.
        return max(len(self.files_a), len(self.files_b))


class ReplayBuffer:
    """
    Pool of up to ``max_size`` previously generated images.

    ``push_and_pop`` is fed each batch of generated images. The discriminators are
    trained on its output, which mixes current and older fakes.
    """

    def __init__(self, max_size: int = 50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data: torch.Tensor):
        """
        Insert each (detached) image of ``data`` and return a batch of the same size.

        While the pool is not full, each image is stored and returned unchanged.
        Once it is full, with probability 1/2 the image is swapped with a random
        stored one (a clone of the stored image is returned); otherwise the image
        itself is returned and the pool is left as is.
        """
        batch = data.detach()
        out = []
        for item in batch:
            if len(self.data) < self.max_size:
                self.data.append(item)
                out.append(item)
                continue
            if random.uniform(0, 1) <= 0.5:
                out.append(item)
                continue
            slot = random.randint(0, self.max_size - 1)
            out.append(self.data[slot].clone())
            self.data[slot] = item
        return torch.stack(out)


In [None]:
class Configs(BaseConfigs):
    """
    Hyper-parameters, models, optimizers and training loop of a CycleGAN that
    translates between domain X (dataloader key ``'x'``) and domain Y (key ``'y'``).

    Checkpoints are written to ``_CHECKPOINT_DIR`` as ``trained_<step>.pth``; only the
    newest one is kept. Sample grids are written to ``_RESULTS_DIR``. ``initialize``
    loads the model weights from the newest checkpoint, and ``run`` resumes the
    step counter and loss history from it.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    epochs: int = 200
    batch_size: int = 1
    data_loader_workers = 8
    learning_rate = 0.0002
    adam_betas = (0.5, 0.999)
    decay_start = 100
    gan_loss = torch.nn.MSELoss()
    cycle_loss = torch.nn.L1Loss()
    identity_loss = torch.nn.L1Loss()
    img_height = 128
    img_width = 128
    img_channels = 3
    n_residual_blocks = 9
    cyclic_loss_coefficient = 10.0
    identity_loss_coefficient = 5.
    sample_interval = 20

    # Storage locations on the mounted Google Drive (underscore names are plain
    # class constants, not configurable options).
    _CHECKPOINT_DIR = '/content/drive/Shareddrives/ECE285/CycleGAN/checkpoints'
    _RESULTS_DIR = '/content/drive/Shareddrives/ECE285/CycleGAN/results'
    _DATA_DIR_X = '/content/drive/Shareddrives/ECE285/vtuber_image/vtuber_images'
    _DATA_DIR_Y = '/content/drive/MyDrive/ECE285SPRING/animegirl-faces/size256'
    # Loss-history keys stored next to the model weights in every checkpoint.
    _LOSS_KEYS = ('loss_generator_result', 'loss_cycle_result', 'loss_gan_result',
                  'loss_identity_result', 'loss_discriminator_result')

    generator_xy: GeneratorResNet
    generator_yx: GeneratorResNet
    discriminator_x: Discriminator
    discriminator_y: Discriminator

    # Optimizers
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam

    # Learning rate schedules
    generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
    discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR

    # Data loaders
    dataloader: DataLoader
    valid_dataloader: DataLoader

    def _latest_checkpoint(self):
        """
        Return ``(path, step)`` of the ``trained_<step>.pth`` file with the largest
        step in ``_CHECKPOINT_DIR``, or ``(None, None)`` if there is none.

        Any other file in the directory is ignored, instead of being loaded or
        crashing the step-number parse.
        """
        latest_path, latest_step = None, None
        for name in os.listdir(self._CHECKPOINT_DIR):
            stem, ext = os.path.splitext(name)
            prefix, _, number = stem.partition('_')
            if ext != '.pth' or prefix != 'trained' or not number.isdecimal():
                continue
            step = int(number)
            if latest_step is None or step > latest_step:
                latest_path, latest_step = os.path.join(self._CHECKPOINT_DIR, name), step
        return latest_path, latest_step

    def sample_images(self, n: int):
        """
        Translate one batch of ``self.dataloader`` in both directions, plus pure-noise
        inputs. Display the result and save it as ``<_RESULTS_DIR>/<n>.png``.

        Sections, stacked top to bottom: real x, G_xy(x), real y, G_yx(y),
        G_yx(noise), G_xy(noise). Each section is min-max normalized by ``make_grid``.
        """
        batch = next(iter(self.dataloader))
        self.generator_xy.eval()
        self.generator_yx.eval()
        with torch.no_grad():
            data_x, data_y = batch['x'].to(self.device), batch['y'].to(self.device)
            # Standard-normal noise for both generators, so the two noise sections are
            # comparable (noise_y used to be uniform rand_like noise).
            noise_x = torch.randn_like(data_x)
            noise_y = torch.randn_like(data_y)
            gen_y = self.generator_xy(data_x)
            gen_x = self.generator_yx(data_y)
            gen_y_noise = self.generator_xy(noise_x)
            gen_x_noise = self.generator_yx(noise_y)

            sections = [data_x, gen_y, data_y, gen_x, gen_x_noise, gen_y_noise]
            image_grid = torch.cat([make_grid(t, nrow=4, normalize=True) for t in sections], 1)

        plot_image(image_grid)
        save_image(image_grid, f'{self._RESULTS_DIR}/{n}.png')

    def initialize(self):
        """
        Build the generators and discriminators, loading weights from the newest
        checkpoint if one exists. Then create the Adam optimizers, the LR schedulers
        and the training dataloader.
        """
        input_shape = (self.img_channels, self.img_height, self.img_width)

        self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
        self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
        self.discriminator_x = Discriminator(input_shape).to(self.device)
        self.discriminator_y = Discriminator(input_shape).to(self.device)

        checkpoint_path, _ = self._latest_checkpoint()
        if checkpoint_path is not None:
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            self.generator_xy.load_state_dict(checkpoint['generator_xy'])
            self.generator_yx.load_state_dict(checkpoint['generator_yx'])
            self.discriminator_x.load_state_dict(checkpoint['discriminator_x'])
            self.discriminator_y.load_state_dict(checkpoint['discriminator_y'])

        self.generator_optimizer = torch.optim.Adam(
            itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)
        self.discriminator_optimizer = torch.optim.Adam(
            itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)

        # LR factor is 1.0 up to `decay_start`, then decays linearly to 0 at `epochs`.
        # max(1, ...) avoids a ZeroDivisionError when epochs == decay_start.
        decay_epochs = max(1, self.epochs - self.decay_start)

        def lr_lambda(e):
            return 1.0 - max(0, e - self.decay_start) / decay_epochs

        self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.generator_optimizer, lr_lambda=lr_lambda)
        self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.discriminator_optimizer, lr_lambda=lr_lambda)

        transforms_ = torchvision.transforms.Compose([
            transforms.Resize(int(self.img_height * 1.12), InterpolationMode.BICUBIC),
            transforms.RandomCrop((self.img_height, self.img_width)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        self.dataloader = DataLoader(
            ImageDataset(self._DATA_DIR_X, self._DATA_DIR_Y, transforms_),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )

    def run(self):
        """
        Train for ``self.epochs`` epochs.

        Every ``sample_interval`` batches, save a sample grid and a checkpoint
        ``trained_<step>.pth`` holding the model weights plus the loss history, then
        delete the previous checkpoint. ``step`` continues from the newest existing
        checkpoint (advancing by ``sample_interval``); on a fresh start it equals the
        number of batches done.
        """
        gen_x_buffer = ReplayBuffer()
        gen_y_buffer = ReplayBuffer()

        # Resume the step counter and the loss history once, up front, instead of
        # re-reading the whole checkpoint from Drive at every sample interval.
        last_path, last_step = self._latest_checkpoint()
        history = {key: [] for key in self._LOSS_KEYS}
        if last_path is not None:
            saved = torch.load(last_path, map_location='cpu')
            for key in self._LOSS_KEYS:
                history[key] = list(saved.get(key, []))
            del saved

        for epoch in monit.loop(self.epochs):
            for i, batch in monit.enum('Train', self.dataloader):
                data_x, data_y = batch['x'].to(self.device), batch['y'].to(self.device)
                true_labels = torch.ones(data_x.size(0), *self.discriminator_x.output_shape,
                                         device=self.device, requires_grad=False)
                false_labels = torch.zeros(data_x.size(0), *self.discriminator_x.output_shape,
                                           device=self.device, requires_grad=False)
                gen_x, gen_y, loss_generator, loss_cycle, loss_gan, loss_identity = \
                    self.optimize_generators(data_x, data_y, true_labels)

                #  Train discriminators on a mix of current and buffered fakes
                loss_discriminator = self.optimize_discriminator(
                    data_x, data_y,
                    gen_x_buffer.push_and_pop(gen_x), gen_y_buffer.push_and_pop(gen_y),
                    true_labels, false_labels)
                tracker.save()
                tracker.add_global_step(max(len(data_x), len(data_y)))

                batches_done = epoch * len(self.dataloader) + i
                if batches_done % self.sample_interval == 0:
                    step = batches_done if last_step is None else last_step + self.sample_interval
                    self.sample_images(step)

                    losses = (loss_generator, loss_cycle, loss_gan, loss_identity, loss_discriminator)
                    for key, value in zip(self._LOSS_KEYS, losses):
                        history[key].append(value)

                    new_path = os.path.join(self._CHECKPOINT_DIR, f'trained_{step}.pth')
                    torch.save({"generator_xy": self.generator_xy.state_dict(),
                                "generator_yx": self.generator_yx.state_dict(),
                                "discriminator_x": self.discriminator_x.state_dict(),
                                "discriminator_y": self.discriminator_y.state_dict(),
                                **history}, new_path)
                    # Delete the previous checkpoint only after the new one has been written,
                    # so a failed/interrupted save never leaves the run without any checkpoint.
                    if last_path is not None and last_path != new_path:
                        os.remove(last_path)
                    last_path, last_step = new_path, step

            # Update learning rates
            self.generator_lr_scheduler.step()
            self.discriminator_lr_scheduler.step()
            tracker.new_line()

    def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor):
        """
        One optimizer step for both generators.

        The total loss is: GAN loss (``gan_loss``, an MSELoss, of both discriminators'
        scores on the fakes against ``true_labels``) + ``cyclic_loss_coefficient``
        * cycle L1 loss + ``identity_loss_coefficient`` * identity L1 loss
        (G_yx(x) ~ x, G_xy(y) ~ y).

        Returns:
            ``(gen_x, gen_y, loss_generator, loss_cycle, loss_gan, loss_identity)``.
            The fakes are still attached to the graph; the losses are Python floats.
        """
        self.generator_xy.train()
        self.generator_yx.train()
        loss_identity = (self.identity_loss(self.generator_yx(data_x), data_x) +
                         self.identity_loss(self.generator_xy(data_y), data_y))
        gen_y = self.generator_xy(data_x)
        gen_x = self.generator_yx(data_y)

        loss_gan = (self.gan_loss(self.discriminator_y(gen_y), true_labels) +
                    self.gan_loss(self.discriminator_x(gen_x), true_labels))

        loss_cycle = (self.cycle_loss(self.generator_yx(gen_y), data_x) +
                      self.cycle_loss(self.generator_xy(gen_x), data_y))

        # Total loss
        loss_generator = (loss_gan +
                          self.cyclic_loss_coefficient * loss_cycle +
                          self.identity_loss_coefficient * loss_identity)

        self.generator_optimizer.zero_grad()
        loss_generator.backward()
        self.generator_optimizer.step()

        # Log losses
        tracker.add({'loss.generator': loss_generator,
                     'loss.generator.cycle': loss_cycle,
                     'loss.generator.gan': loss_gan,
                     'loss.generator.identity': loss_identity})

        return (gen_x, gen_y, loss_generator.item(), loss_cycle.item(),
                loss_gan.item(), loss_identity.item())

    def optimize_discriminator(self, data_x: torch.Tensor, data_y: torch.Tensor,
                               gen_x: torch.Tensor, gen_y: torch.Tensor,
                               true_labels: torch.Tensor, false_labels: torch.Tensor):
        """
        One optimizer step for both discriminators: real images are scored against
        ``true_labels`` and (buffered) fakes against ``false_labels``.

        Returns the summed discriminator loss as a Python float.
        """
        loss_discriminator = (self.gan_loss(self.discriminator_x(data_x), true_labels) +
                              self.gan_loss(self.discriminator_x(gen_x), false_labels) +
                              self.gan_loss(self.discriminator_y(data_y), true_labels) +
                              self.gan_loss(self.discriminator_y(gen_y), false_labels))

        # Also clears the discriminator grads accumulated by the generator step's backward.
        self.discriminator_optimizer.zero_grad()
        loss_discriminator.backward()
        self.discriminator_optimizer.step()

        # Log losses
        tracker.add({'loss.discriminator': loss_discriminator})
        return loss_discriminator.item()

def train():
    """Create a fresh ``Configs``, initialize it (this may load weights from the checkpoint folder) and run training."""
    configs = Configs()
    configs.initialize()
    configs.run()


def plot_image(img: torch.Tensor):
    """
    Show a ``(C, H, W)`` image tensor with matplotlib, with the axes hidden.

    The tensor is moved to the CPU and min-max rescaled to [0, 1]; the 1e-5 in the
    denominator guards against constant images.
    """
    from matplotlib import pyplot as plt
    pixels = img.cpu()
    lo, hi = pixels.min(), pixels.max()
    pixels = (pixels - lo) / (hi - lo + 1e-5)
    # matplotlib expects channels last: (H, W, C).
    plt.imshow(pixels.permute(1, 2, 0))
    plt.axis('off')
    plt.show()


def evaluate():
    """
    Run both generators over ``Configs.dataloader`` and write four PNGs per batch
    ``idx`` under ``images_for_evaluation``:

    - ``real_vtuber/<idx>.png``: the x batch
    - ``real_faces/<idx>.png``: the y batch
    - ``fake_faces/<idx>.png``: G_xy(x)
    - ``fake_vtuber/<idx>.png``: G_yx(y)

    NOTE(review): this reuses the training dataloader (shuffled, random crop/flip
    augmentation), so the saved set differs between calls. Confirm that is intended.
    """
    conf = Configs()
    conf.initialize()
    conf.generator_xy.eval()
    conf.generator_yx.eval()

    eval_root = '/content/drive/Shareddrives/ECE285/CycleGAN/images_for_evaluation'
    subdirs = ('real_vtuber', 'real_faces', 'fake_faces', 'fake_vtuber')
    for subdir in subdirs:
        os.makedirs(os.path.join(eval_root, subdir), exist_ok=True)

    with torch.no_grad():
        for idx, batch in enumerate(conf.dataloader):
            x_image, y_image = batch['x'], batch['y']
            generated_y = conf.generator_xy(x_image.to(conf.device))
            generated_x = conf.generator_yx(y_image.to(conf.device))
            for subdir, images in zip(subdirs, (x_image, y_image, generated_y, generated_x)):
                # The inputs went through Normalize(0.5, 0.5) and the generators end in Tanh,
                # so everything lies in [-1, 1]. save_image expects [0, 1] and clamps, which
                # would turn every negative value black, so undo the normalization first.
                save_image(images * 0.5 + 0.5, f'{eval_root}/{subdir}/{idx}.png')

In [None]:
# Training reads the datasets and writes checkpoints/samples under /content/drive,
# so Drive must be mounted first (mounting again later is a no-op).
from google.colab import drive
drive.mount('/content/drive')

train()

In [None]:
# Mount Google Drive: datasets, checkpoints and results all live under /content/drive.
from google.colab import drive

drive.mount('/content/drive', force_remount=False)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
import re

import torch
import matplotlib.pyplot as plt

CHECKPOINT_DIR = '/content/drive/Shareddrives/ECE285/CycleGAN/checkpoints'

# Use the newest `trained_<step>.pth` (largest step), not whatever os.listdir happens to list first.
checkpoint_names = [name for name in os.listdir(CHECKPOINT_DIR) if re.fullmatch(r'trained_\d+\.pth', name)]
if not checkpoint_names:
    raise FileNotFoundError(f'no trained_<step>.pth checkpoint found in {CHECKPOINT_DIR}')
latest_name = max(checkpoint_names, key=lambda name: int(re.search(r'\d+', name).group()))
checkpoint = torch.load(os.path.join(CHECKPOINT_DIR, latest_name), map_location='cpu')

loss_generator_result = checkpoint["loss_generator_result"]
loss_cycle_result = checkpoint["loss_cycle_result"]
loss_gan_result = checkpoint["loss_gan_result"]
loss_identity_result = checkpoint["loss_identity_result"]
loss_discriminator_result = checkpoint["loss_discriminator_result"]

# Training appends one value per checkpoint (every `sample_interval` batches), not one per epoch.
sample_points = range(1, len(loss_generator_result) + 1)

fig, ax = plt.subplots()
for values, label in [
    (loss_generator_result, 'Generator Loss'),
    (loss_cycle_result, 'Cycle Loss'),
    (loss_gan_result, 'GAN Loss'),
    (loss_identity_result, 'Identity Loss'),
    (loss_discriminator_result, 'Discriminator Loss'),
]:
    ax.plot(sample_points, values, label=label)
ax.set(xlabel='Checkpoint (one per sample_interval batches)', ylabel='Loss', title='Loss Curves')
ax.legend()
plt.show()

In [None]:
# Write real/translated image sets to images_for_evaluation/ (requires the Drive mount above).
evaluate()