In [1]:
# Enable IPython's autoreload extension in mode 2 (reload modules before each cell executes),
# so edits to the project packages imported further down take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
from __future__ import print_function

#%matplotlib inline

import os
import random

import torch
import torch.nn as nn
import torch.optim as optim

import torch.utils.data
import torchvision.datasets as dset
from torch.utils.tensorboard import SummaryWriter

import sys
from utils.commandline import load_config
from importlib.machinery import SourceFileLoader
from datasets.web_caricature import WebCaricatureDataset
from losses.model_losses  import AdversarialLoss, PatchAdversarialLoss
from models.model_warpgan import WarpGANGenerator, WarpGANDiscriminator

%matplotlib inline

# Fixed seed so that repeated runs produce the same results.
# (Swap in `random.randint(1, 10000)` here to draw a fresh seed and get new results.)
manualSeed = 999

# Seed Python's RNG, report the seed in use, then seed PyTorch.
# `torch.manual_seed` stays the last expression of the cell, so its return value is what the
# cell displays.
random.seed(manualSeed)
print("Random Seed: ", manualSeed)
torch.manual_seed(manualSeed)

Random Seed:  999


<torch._C.Generator at 0x7fe072059b50>

In [3]:
# Load the training configuration by executing the Python file at CONFIG_PATH as a module
# named "config"; the file's top-level names become attributes of `config`.
import importlib.util  # cell-local import: only this cell needs the spec helpers

CONFIG_PATH = './config/config_train.py'

# `SourceFileLoader(...).load_module()` is deprecated, so the module is built from a spec and
# executed explicitly instead.
config_spec = importlib.util.spec_from_file_location('config', CONFIG_PATH)
config = importlib.util.module_from_spec(config_spec)

# `load_module()` registered the module in `sys.modules`; keep doing so under the same name,
# and undo the registration if executing the file fails so no half-initialised module lingers.
sys.modules[config_spec.name] = config
try:
    config_spec.loader.exec_module(config)
except BaseException:
    sys.modules.pop(config_spec.name, None)
    raise

In [4]:
# Build the dataset from the loaded configuration and wrap it in a shuffling DataLoader
# that yields batches of `config.in_batch` samples.
dataset = WebCaricatureDataset(config)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.in_batch, shuffle=True)

# Decide which device to run on.
# The previous condition read a bare `ngpu` that is not defined anywhere in this notebook; it
# only "worked" on CPU-only machines because `torch.cuda.is_available()` short-circuited the
# `and`, and raised NameError as soon as CUDA was present.
# NOTE(review): `ngpu` is read from the config when the config declares it (0 forces CPU);
# otherwise a GPU is used whenever one is available. Confirm whether the config should own this.
ngpu = getattr(config, "ngpu", 1)
use_cuda = torch.cuda.is_available() and ngpu > 0
device = torch.device("cuda:0" if use_cuda else "cpu")
print(f"\nDevice: {device}\n")


6128 images of 126 classes loaded.
Classes 126: 3016 photos, 3112 caricatures

Device: cpu



In [5]:
# Instantiate both WarpGAN networks from the shared configuration and place them on `device`.
# The generator is constructed before the discriminator; the order is kept deliberately because
# construction order can affect seeded weight initialisation.

# Generator
warpgan_generator = WarpGANGenerator(config)
warpgan_generator = warpgan_generator.to(device)

# Discriminator
warpgan_discriminator = WarpGANDiscriminator(config)
warpgan_discriminator = warpgan_discriminator.to(device)


In [6]:
# Loss objects used by the training loop.

# Adversarial loss: configured from `config`; called with a dict of logits and labels.
adversarial_loss = AdversarialLoss(config)

# Patch adversarial loss: takes no configuration; called with the three logit tensors.
patch_adversarial_loss = PatchAdversarialLoss()


In [7]:
# Adam optimizers for the generator (G) and the discriminator (D).
# Both use identical hyper-parameters from the configuration, so they are gathered once.
adam_betas = (config.optimizer[1]["beta1"], config.optimizer[1]["beta2"])
adam_kwargs = dict(lr=config.lr, weight_decay=config.weight_decay, betas=adam_betas)

optimizerG = optim.Adam(warpgan_generator.parameters(), **adam_kwargs)
optimizerD = optim.Adam(warpgan_discriminator.parameters(), **adam_kwargs)


In [9]:

# Training loop
#
# Each iteration runs one forward pass through the generator and the discriminator, builds the
# adversarial, patch-adversarial and identity (reconstruction) losses, logs them to TensorBoard
# and then takes one Adam step on each network.

LOG_EVERY = 100    # print a progress line every LOG_EVERY batches
IMAGE_EVERY = 1    # write a generated caricature to TensorBoard every IMAGE_EVERY iterations

BATCH_KEYS = (
    "images_photo", "images_caric",
    "labels_photo", "labels_caric",
    "scales_photo", "scales_caric",
)


def _to_device(value):
    """Return `value` moved to `device` when it is a tensor; anything else is returned as-is."""
    return value.to(device) if torch.is_tensor(value) else value


# Trainable parameters of each network, collected once so every iteration can request
# gradients for exactly one network at a time.
params_G = [p for p in warpgan_generator.parameters() if p.requires_grad]
params_D = [p for p in warpgan_discriminator.parameters() if p.requires_grad]

# Loop-invariant, used to derive the global iteration index.
batches_per_epoch = len(dataloader)

writer = SummaryWriter("runs/Train-1")

try:
    # For each epoch
    for epoch in range(config.num_epochs):

        # For each batch in the dataloader
        for i, data in enumerate(dataloader, 0):

            global_iter = epoch * batches_per_epoch + i

            # ------------------------------------------
            # Input dicts for Generator and Discriminator
            # ------------------------------------------

            # The networks were moved to `device`, so the batch tensors must live there too.
            batch = {key: _to_device(data[key]) for key in BATCH_KEYS}

            generator_input_dict = dict(batch)

            discriminator_input_dict = {
                "images_photo": batch["images_photo"],
                "images_caric": batch["images_caric"],
                "generated_caric": None,  # filled in after the generator forward pass
            }

            # ------------------------------------------
            # Forward passes
            # ------------------------------------------

            generator_output = warpgan_generator(generator_input_dict)

            # add generated caricature to discriminator input dict
            discriminator_input_dict["generated_caric"] = generator_output["generated_caric"]

            discriminator_output = warpgan_discriminator(discriminator_input_dict)

            # ------------------------------------------
            # Losses
            # ------------------------------------------

            # adversarial losses; the generated caricature is labelled with the photo labels
            adversarial_loss_input_dict = {
                "logits_caric": discriminator_output["logits_caric"],
                "logits_photo": discriminator_output["logits_photo"],
                "logits_generated_caric": discriminator_output["logits_generated_caric"],

                "labels_caric": generator_input_dict["labels_caric"],
                "labels_photo": generator_input_dict["labels_photo"],
                "labels_generated_caric": generator_input_dict["labels_photo"],
            }

            loss_DA, loss_GA = adversarial_loss(adversarial_loss_input_dict)
            loss_DA, loss_GA = config.coef_adv * loss_DA, config.coef_adv * loss_GA

            # patch adversarial losses
            # NOTE(review): these reuse config.coef_adv; confirm no separate patch
            # coefficient is intended.
            loss_DP, loss_GP = patch_adversarial_loss(discriminator_output["logits_caric"],
                                                      discriminator_output["logits_photo"],
                                                      discriminator_output["logits_generated_caric"])
            loss_DP, loss_GP = config.coef_adv * loss_DP, config.coef_adv * loss_GP

            # identity mapping (reconstruction) losses: mean absolute error between the
            # rendered image and the corresponding input image
            loss_idt_caric = torch.mean(torch.abs(generator_output["rendered_caric"] - generator_input_dict["images_caric"]))
            loss_idt_caric = config.coef_idt * loss_idt_caric

            loss_idt_photo = torch.mean(torch.abs(generator_output["rendered_photo"] - generator_input_dict["images_photo"]))
            loss_idt_photo = config.coef_idt * loss_idt_photo

            loss_G_idt = loss_idt_caric + loss_idt_photo

            # tensorboard writer: save all individual losses (tags kept as-is so existing
            # runs remain comparable)
            writer.add_scalar('Loss-Generator/Adversial',     loss_GA.item(),        global_iter)
            writer.add_scalar('Loss-Generator/Patch',         loss_GP.item(),        global_iter)
            writer.add_scalar('Loss-Generator/IdentityCaric', loss_idt_caric.item(), global_iter)
            writer.add_scalar('Loss-Generator/IdentityPhoto', loss_idt_photo.item(), global_iter)

            writer.add_scalar('Loss-Discriminator/Adversial', loss_DA.item(),        global_iter)
            writer.add_scalar('Loss-Discriminator/Patch',     loss_DP.item(),        global_iter)

            # total loss per network
            loss_G = loss_GA + loss_GP + loss_G_idt
            loss_D = loss_DA + loss_DP

            # ------------------------------------------
            # Gradients and optimizer steps
            # ------------------------------------------

            # Both generator adversarial terms are computed from discriminator logits, so
            # back-propagating loss_G also reaches the discriminator's parameters. Calling
            # loss_D.backward() followed by loss_G.backward() therefore accumulated
            # generator-loss gradients into the discriminator before optimizerD.step().
            # Requesting each loss's gradients for its own network only avoids that.
            grads_D = torch.autograd.grad(loss_D, params_D, retain_graph=True, allow_unused=True)
            grads_G = torch.autograd.grad(loss_G, params_G, allow_unused=True)

            # Overwrite (not accumulate) the gradients; parameters a loss does not touch get
            # None, which the optimizer skips.
            for param, grad in zip(params_D, grads_D):
                param.grad = grad
            for param, grad in zip(params_G, grads_G):
                param.grad = grad

            # optimize discriminator for a single step
            optimizerD.step()

            # optimize generator for a single step
            optimizerG.step()

            # output training stats
            if i % LOG_EVERY == 0:

                log  = f"[{epoch}/{config.num_epochs}][{i}/{len(dataloader)}]\t"
                log += f"Loss_G: {loss_G}\t"
                log += f"Loss_D: {loss_D}\t"

                print(log)

            # check how the generator is doing by saving the first generated caricature
            if global_iter % IMAGE_EVERY == 0:

                caricature = generator_output["generated_caric"][0].detach().cpu()
                writer.add_image("Caricature", caricature, global_iter)

finally:
    # Flush and close the event file even when training is interrupted.
    writer.close()
       

[0/20][0/3064]	Loss_G: 20.178382873535156	Loss_D: 27.349363327026367	


KeyboardInterrupt: 