In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import sys
import os
model_folder = '/content/drive/MyDrive/GM Project/VQGAN Project/model'
helper_methods_folder = '/content/drive/MyDrive/GM Project/VQGAN Project/utils'

sys.path.append(os.path.abspath(model_folder))
sys.path.append(os.path.abspath(helper_methods_folder))

In [3]:
import os
import argparse
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import utils as vutils
from discriminator import Discriminator
from perceptualloss import LPIPS
from vqgan import VQGAN
from utils import load_data, weights_init

In [4]:
def configure_optimizers():
    lr = 2.25e-05
    opt_vq = torch.optim.Adam(
        list(vqgan.encoder.parameters()) +
        list(vqgan.decoder.parameters()) +
        list(vqgan.codebook.parameters()) +
        list(vqgan.quant_conv.parameters()) +
        list(vqgan.post_quant_conv.parameters()),
        lr=lr, eps=1e-08, betas=(0.5, 0.9)
    )
    opt_disc = torch.optim.Adam(discriminator.parameters(),
                                lr=lr, eps=1e-08, betas=(0.5, 0.9))

    return opt_vq, opt_disc

In [None]:
vqgan = VQGAN().to(device="cuda:0")
discriminator = Discriminator().to(device="cuda:0")
discriminator.apply(weights_init)
perceptual_loss = LPIPS().eval().to(device="cuda:0")
opt_vq, opt_disc = configure_optimizers()

In [None]:
train_dataset = load_data(dataset_path=r"/content/drive/MyDrive/GM Project/coco2017")
steps_per_epoch = len(train_dataset)
for epoch in range(10):
    with tqdm(range(len(train_dataset))) as pbar:
        for i, imgs in zip(pbar, train_dataset):
            imgs = imgs.to(device="cuda:0")
            decoded_images, _, q_loss = vqgan(imgs)

            disc_real = discriminator(imgs)
            disc_fake = discriminator(decoded_images)

            disc_factor = vqgan.adopt_weight(1., epoch * steps_per_epoch + i, threshold=10000)

            _perceptual_loss = perceptual_loss(imgs, decoded_images)
            rec_loss = torch.abs(imgs - decoded_images)
            perceptual_rec_loss = 1. * _perceptual_loss + 1. * rec_loss
            perceptual_rec_loss = perceptual_rec_loss.mean()
            g_loss = -torch.mean(disc_fake)

            λ = vqgan.calculate_lambda(perceptual_rec_loss, g_loss)
            vq_loss = perceptual_rec_loss + q_loss + disc_factor * λ * g_loss

            d_loss_real = torch.mean(F.relu(1. - disc_real))
            d_loss_fake = torch.mean(F.relu(1. + disc_fake))
            gan_loss = disc_factor * 0.5 * (d_loss_real + d_loss_fake)

            opt_vq.zero_grad()
            vq_loss.backward(retain_graph=True)

            opt_disc.zero_grad()
            gan_loss.backward()

            opt_vq.step()
            opt_disc.step()

            if i % 100 == 0:
                with torch.no_grad():
                    real_fake_images = torch.cat((imgs.add(1).mul(0.5)[:4], decoded_images.add(1).mul(0.5)[:4]))
                    vutils.save_image(real_fake_images, os.path.join("/content/drive/MyDrive/GM Project/results/", f"{epoch}_{i}.jpg"), nrow=4)

            pbar.set_postfix(
                VQ_Loss=np.round(vq_loss.cpu().detach().numpy().item(), 5),
                GAN_Loss=np.round(gan_loss.cpu().detach().numpy().item(), 3)
            )
            pbar.update(0)

            torch.save(vqgan.state_dict(), os.path.join("/content/drive/MyDrive/GM Project/checkpoints/", f"vqgan_epoch_{epoch}.pt"))
                
torch.save(vqgan.state_dict(), os.path.join("/content/drive/MyDrive/GM Project/checkpoints/", f"vqgan_final_.pt"))

100%|██████████| 1637/1637 [3:22:34<00:00,  7.42s/it, GAN_Loss=0, VQ_Loss=1.26]
100%|██████████| 1637/1637 [1:26:09<00:00,  3.16s/it, GAN_Loss=0, VQ_Loss=0.473]
 52%|█████▏    | 856/1637 [45:03<39:53,  3.07s/it, GAN_Loss=0, VQ_Loss=0.369]