In [1]:
# from google.colab import drive
# drive.mount("/content/gdrive")

Mounted at /content/gdrive


## **Dual Contradistinctive Generative Autoencoder - CVPR 2021** 
---
* **Authors:** *Gaurav Parmar, Dacheng Li, Kwonjoon Lee, Zhuowen Tu*
* **Link:** https://arxiv.org/abs/2011.10063
* **Official Implementation:** https://github.com/mlpc-ucsd/DC-VAE

#### **Project Group Members**
* Aybora Köksal aybora@metu.edu.tr
* Halil Çağrı Bilgi cagri.bilgi@metu.edu.tr


# **Paper Summary**


# Imports

In [1]:
from lib import *

import torch
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import os.path as osp
import os 

import sacred
from sacred import Experiment
from sacred.observers import FileStorageObserver
from sacred import SETTINGS
SETTINGS.CONFIG.READ_ONLY_CONFIG=False

# Hyperparameters & Configurations

In [2]:
# Single source of truth for the experiment: network sizes under
# 'model_params', optimisation / logging settings under 'hparams'.
configs = dict(
    model_params=dict(
        decoder=dict(latent_dim=128, channel_dim=256),
        encoder=dict(ch_in=3, hid_ch=128, z_dim=128),
        discriminator=dict(ch_in=3, hid_ch=128, cont_dim=16),
    ),
    hparams=dict(
        epochs=800,
        train_batch_size=64,
        test_batch_size=64,
        lr=0.0002,
        disp_freq=20,        # visualise samples every N training steps
        gen_train_freq=5,    # generator/contrastive update every N discriminator steps
        checkpoint=500,      # save the model weights every N training steps
        beta1=0.0,           # Adam betas shared by all three optimizers
        beta2=0.9,
        # Prefer the GPU when one is visible to PyTorch.
        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    ),
)

In [3]:
# Pull the two top-level configuration sections out into their own names.
model_params, hparams = (configs[section] for section in ("model_params", "hparams"))

## Model & Optimizer & Loss Function Initializations

In [4]:
# Build the model and run the weights_init hook over it via model.apply.
# NOTE(review): this runs before train_in_notebook seeds the RNG, so the
# initial weights are not covered by that seed — confirm whether intended.
model = Model(model_params)
model.apply(weights_init)

# Encoder, decoder and discriminator each get their own Adam optimizer;
# all three use the same learning rate and betas.
_adam_settings = dict(lr=hparams['lr'], betas=(hparams['beta1'], hparams['beta2']))
enc_optim = torch.optim.Adam(model.encoder.parameters(), **_adam_settings)
dec_optim = torch.optim.Adam(model.decoder.parameters(), **_adam_settings)
disc_optim = torch.optim.Adam(model.discriminator.parameters(), **_adam_settings)

# Adversarial loss operates on raw discriminator logits.
gan_criterion = torch.nn.BCEWithLogitsLoss()
# contrastive_loss is imported from lib/loss.py

  torch.nn.init.xavier_uniform(m.weight.data, 1.)


## Load Datasets

In [5]:
def _cifar10_transform():
    """Return the preprocessing used for both splits: ToTensor, then per-channel Normalize(0.5, 0.5)."""
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])


# Training split: downloaded on demand, shuffled, incomplete last batch dropped.
train_set = torchvision.datasets.CIFAR10(
    './data', train=True, download=True, transform=_cifar10_transform()
)
train_loader = DataLoader(
    train_set,
    batch_size=hparams['train_batch_size'],
    shuffle=True,
    drop_last=True,
)

# Test split: read from the same root directory, also shuffled.
test_set = torchvision.datasets.CIFAR10(
    './data', train=False, transform=_cifar10_transform()
)
test_loader = DataLoader(
    test_set,
    batch_size=hparams['test_batch_size'],
    shuffle=True,
)

Files already downloaded and verified


## Training Loop and Saving Model

In [None]:
def train_in_notebook(model_params, hparams, model, gan_criterion, enc_optim, dec_optim, disc_optim, _run, device=torch.device("cpu")):
    """Run the adversarial + contrastive training loop and log into a Sacred run.

    Every step updates the discriminator. Every ``hparams['gen_train_freq']``
    steps the encoder/decoder are additionally updated with the adversarial
    loss, followed by a contrastive update of all three sub-networks.
    Sample/reconstruction images are written every ``hparams['disp_freq']``
    steps, the model ``state_dict`` is saved every ``hparams['checkpoint']``
    steps, and the score returned by ``eval`` is recorded after each epoch.

    The function reads the module-level ``train_loader`` and ``test_loader``
    and uses ``contrastive_loss``, ``show_img``, ``show_img_rec`` and ``eval``
    (expected to come from ``from lib import *`` — ``eval`` must shadow the
    Python builtin; confirm against lib).

    Args:
        model_params: Configuration dict; ``['decoder']['latent_dim']`` is read.
        hparams: Configuration dict; ``epochs``, ``disp_freq``,
            ``gen_train_freq``, ``checkpoint`` and ``test_batch_size`` are read.
        model: Model exposing ``gen_from_noise(size=...)``, ``discriminator``
            (returning a ``(prediction, contrastive_features)`` pair) and a
            forward call returning ``(z_latent, reconstruction)``.
        gan_criterion: Callable ``(prediction, target) -> loss``.
        enc_optim: Optimizer for the encoder parameters.
        dec_optim: Optimizer for the decoder parameters.
        disc_optim: Optimizer for the discriminator parameters.
        _run: Sacred run object; ``info``, ``experiment_info['base_dir']`` and
            ``_id`` are used. ``<base_dir>/runs/<_id>/checkpoints`` must exist
            before the first checkpoint is written.
        device: Device on which the model and batches are placed.

    Returns:
        None. Results are stored in ``_run.info`` and on disk.
    """
    model.to(device)

    torch.manual_seed(123)

    if device.type == "cuda":
        torch.cuda.manual_seed(123)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    ####### LOG #######
    # Most recent loss values; displayed in the progress bar.
    disc_train_loss = 0
    gen_train_loss = 0
    cont_train_loss = 0

    _run.info["gen_loss_train"] = list()
    _run.info["disc_loss_train"] = list()
    _run.info["cont_loss_train"] = list()
    _run.info["fid"] = list()
    ###################

    latent_dim = model_params['decoder']['latent_dim']
    disp_freq = hparams['disp_freq']
    run_dir = osp.join(_run.experiment_info['base_dir'], 'runs', _run._id)
    results_dir = osp.join(run_dir, 'results')
    step = 1  # 1-based index of the training step currently being executed

    for epoch in range(1, hparams['epochs'] + 1):

        iterator = tqdm(train_loader, leave=True)
        iterator.set_description_str(f"Epoch: {epoch}")

        for point_batch, _ in iterator:

            # Re-assert train mode and the device attribute every iteration:
            # the end-of-epoch evaluation helper is not visible here and may
            # change either of them.
            model.train()
            model.device = device

            #### Real Data
            real_data = point_batch.to(device)
            batch_size = real_data.size(0)

            '''----------------         Discriminator Update         ----------------'''
            disc_optim.zero_grad()

            # Only the discriminator is stepped here, so the generated and
            # reconstructed images are produced without an autograd graph.
            # Encoder/decoder gradients from this step were never used (both
            # optimizers zero their gradients before every step they take).
            with torch.no_grad():
                fake_data = model.gen_from_noise(size=(batch_size, latent_dim))
                _, rec_data = model(real_data)

            disc_fake_pred, _ = model.discriminator(fake_data)
            disc_fake_loss = gan_criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))

            disc_rec_pred, _ = model.discriminator(rec_data)
            disc_rec_loss = gan_criterion(disc_rec_pred, torch.zeros_like(disc_rec_pred))

            disc_real_pred, _ = model.discriminator(real_data)
            disc_real_loss = gan_criterion(disc_real_pred, torch.ones_like(disc_real_pred))

            # Real images weigh as much as the two kinds of fakes combined.
            gan_objective = disc_real_loss + (disc_rec_loss + disc_fake_loss) * 0.5
            # Every update below recomputes its own forward pass, so the graph
            # does not need to be retained.
            gan_objective.backward()
            disc_optim.step()

            ### Log
            disc_train_loss = gan_objective.item()
            _run.info["disc_loss_train"].append(disc_train_loss)
            ########

            '''----------------         Generator Update         ----------------'''
            if step % hparams['gen_train_freq'] == 0:

                # TODO: KLD loss term missing !!!!!
                enc_optim.zero_grad()
                dec_optim.zero_grad()

                # Twice as many noise samples as real images in the batch.
                fake_data = model.gen_from_noise(size=(2 * batch_size, latent_dim))
                z_latent, rec_data = model(real_data)

                gen_fake_pred, _ = model.discriminator(fake_data)
                gen_fake_loss = gan_criterion(gen_fake_pred, torch.ones_like(gen_fake_pred))

                gen_rec_pred, _ = model.discriminator(rec_data)
                gen_rec_loss = gan_criterion(gen_rec_pred, torch.ones_like(gen_rec_pred))

                gan_objective = gen_rec_loss + gen_fake_loss

                gan_objective.backward()
                enc_optim.step()
                dec_optim.step()

                ### Log
                gen_train_loss = gan_objective.item()
                _run.info["gen_loss_train"].append(gen_train_loss)
                ########

                '''----------------         Contrastive Update         ----------------'''

                enc_optim.zero_grad()
                dec_optim.zero_grad()
                disc_optim.zero_grad()

                z_latent, rec_data = model(real_data)

                _, rec_contrastive = model.discriminator(rec_data)
                _, real_contrastive = model.discriminator(real_data)

                cont_loss = contrastive_loss(z_latent, real_contrastive, rec_contrastive)

                cont_loss.backward()

                disc_optim.step()
                enc_optim.step()
                dec_optim.step()

                ### Log
                cont_train_loss = cont_loss.item()
                _run.info["cont_loss_train"].append(cont_train_loss)
                ########

            # Visualize the generated & reconstructed images
            if step % disp_freq == 0:
                # Pure inference for plotting: no autograd graph is needed.
                with torch.no_grad():
                    gen_images = model.gen_from_noise(size=(25, latent_dim))
                    # next(iter(...)) is the Python 3 iterator protocol; the
                    # loader iterator does not reliably offer a .next() method.
                    t_data, _ = next(iter(test_loader))
                    t_data = t_data.to(device)
                    _, rec_t_data = model(t_data)
                show_img(gen_images, step, num_images=25, size=(3, 32, 32), img_save_path=results_dir, show=False)
                show_img_rec(t_data, rec_t_data, step, num_images=15, size=(3, 32, 32), img_save_path=results_dir, show=False)

            iterator.set_postfix_str(
                f"Disc Loss: {disc_train_loss:.4f}, Gen Loss: {gen_train_loss:.4f}, Cont Loss: {cont_train_loss:.4f}, Step: {step} " )

            '''----------------         Save Model         ----------------'''
            if step % hparams['checkpoint'] == 0:
                c_name = f"checkpoint_{step}.pt"
                checkpoint_path = osp.join(run_dir, "checkpoints", c_name)
                torch.save(model.state_dict(), checkpoint_path)

            # Advance only after logging and checkpointing so the displayed
            # step and checkpoint_<step>.pt refer to the step just completed.
            step += 1

        # Calculate FID score at the end of each Epoch
        fid_samp = eval(model, latent_dim, hparams['test_batch_size'], device)
        print(f"Epoch: {epoch}| sampling fid: {fid_samp}")
        _run.info["fid"].append(fid_samp)

## Initialize Experiment and Start Training

In [None]:
# realpath() receives the literal string "__file__" (a notebook defines no
# __file__ variable), so this resolves to the current working directory.
dirname = os.path.dirname(os.path.realpath("__file__"))

experiment_dir = os.path.join(dirname, 'runs')

# Sacred refuses to define an experiment in an interactive session unless
# interactive=True is passed.
ex = Experiment("ceng796", interactive=True)
ex.observers.append(FileStorageObserver(experiment_dir))
ex.add_config(configs)

# @ex.automain raises in interactive mode, so the entry point is registered
# with @ex.main and started explicitly with ex.run() below.
@ex.main
def main(_config, _run):
    """Print the run configuration, create the output folders and start training.

    Args:
        _config: Configuration injected by Sacred (not read directly; the
            module-level ``model_params`` / ``hparams`` are used instead).
        _run: Sacred run object; its ``_id`` names the output directory.
    """
    sacred.commands.print_config(_run)

    # exist_ok keeps a re-executed cell from failing on existing folders.
    os.makedirs(os.path.join(experiment_dir, _run._id, "checkpoints"), exist_ok=True)
    os.makedirs(os.path.join(experiment_dir, _run._id, "results"), exist_ok=True)

    train_in_notebook(model_params, hparams, model, gan_criterion, enc_optim, dec_optim, disc_optim, _run, device=hparams['device'])


# Start the experiment when the cell is executed, as @ex.automain would in a script.
ex.run()