In [None]:
from IPython.display import clear_output
from matplotlib.gridspec import GridSpec
import matplotlib.pyplot as plt
import numpy as np
import sys
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.utils import make_grid

sys.path.append('..')
from utils import (
    ImSet,
    ConditionalVariationalAutoEncoder,
    ConditionalDiscriminator,
    show_tensor
)
%matplotlib inline

In [None]:
# Training device: the first CUDA GPU when one is available, otherwise the CPU,
# so the notebook still runs (slowly) on a machine without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Dataset. The batch size equals the dataset length and drop_last=True, so the
# loader yields exactly one shuffled, full-size batch per epoch.
imset = ImSet()

batch_size = len(imset)
loader = DataLoader(imset, shuffle=True, batch_size=batch_size, drop_last=True)

# Conditional VAE; the training cell loads its weights from `model_fname`,
# uses its decoder as the generator and its encoder to seed the discriminator.
cvae = ConditionalVariationalAutoEncoder().to(device)

# Number of restarts from the checkpoint, and epochs trained per restart.
n_trials = 5000
n_epochs = 5000

# Checkpoint file holding the pretrained CVAE weights.
model_fname = 'cvae.pth'

# Width of the random noise vector fed to the generator.
# NOTE(review): presumably must match the decoder's expected latent size — confirm in utils.
latent_size = 1024

In [None]:
# Build the conditional discriminator, then move it onto the training device.
dsc = ConditionalDiscriminator()
dsc = dsc.to(device)

In [None]:
# Binary cross-entropy between discriminator outputs and the 0/1 targets below.
lossfun = nn.BCELoss()

# Fixed target columns, one row per sample in a batch: 1 = real, 0 = fake.
# They are allocated directly on the training device.
real_labels = torch.ones((batch_size, 1), device=device)
fake_labels = torch.zeros((batch_size, 1), device=device)

In [None]:
# Adversarial fine-tuning: the CVAE decoder is trained as a conditional GAN
# generator against `dsc`. Every trial restarts from the checkpoint in
# `model_fname` and then trains for `n_epochs` epochs.
# NOTE: figures and checkpoints are written into CGAN_output/, which must
# already exist; neither savefig nor torch.save creates missing directories.
for trial in range(n_trials):

    # per-trial histories, used only for the live plot below
    d_losses = []
    g_losses = []

    real_preds = []
    fake_preds = []

    # Reset both networks from the pretrained CVAE: the decoder becomes the
    # generator and the discriminator is seeded from the encoder weights.
    # strict=False tolerates missing/unexpected keys between the state dicts;
    # map_location lets the checkpoint load whatever device it was saved from.
    cvae.load_state_dict(torch.load(model_fname, map_location=device), strict=False)
    gen = cvae.dec
    dsc.load_state_dict(cvae.enc.state_dict(), strict=False)

    # initialize optimizers
    opt_d = torch.optim.Adam(dsc.parameters(), lr=.0002, betas=(.5,.999))
    opt_g = torch.optim.Adam(gen.parameters(), lr=.0002, betas=(.5,.999))

    # Freeze the generator sub-layers whose index in gen.net is above 13.
    # requires_grad_ is a method and has to be called; assigning False to it
    # only shadows the method and leaves every parameter trainable.
    # `gen` is the same module object in every trial, so the flags persist.
    # NOTE(review): this was labelled "disable early layers", but `> 13`
    # selects the higher-indexed layers — confirm which side should be frozen.
    for i, layer in gen.net.named_children():
        if int(i) > 13:
            layer.requires_grad_(False)

    for epoch in range(n_epochs):
        fig = None  # most recent progress figure of this epoch, if any

        for (img_fname, img, species_, class_, gender_) in loader:
            img = img.to(device)
            species_ = species_.to(device)
            class_ = class_.to(device)
            gender_ = gender_.to(device)

            # ---- discriminator step ----
            # forward pass for real
            pred_real = dsc(img, species_, class_, gender_)
            d_loss_real = lossfun(pred_real, real_labels)

            # forward pass for fake; the generator is not updated in this
            # step, so no autograd graph is built through it
            fake_data = torch.randn(batch_size, latent_size).to(device)
            with torch.no_grad():
                fake_images = gen(fake_data, species_, class_, gender_)
            pred_fake = dsc(fake_images, species_, class_, gender_)
            d_loss_fake = lossfun(pred_fake, fake_labels)

            # backprop
            d_loss = d_loss_real + d_loss_fake
            opt_d.zero_grad()
            d_loss.backward()
            opt_d.step()
            d_losses.append(d_loss.item())

            # ---- generator step ----
            # fresh noise, scored by the just-updated discriminator
            fake_data = torch.randn(batch_size, latent_size).to(device)
            fake_images = gen(fake_data, species_, class_, gender_)
            pred_fake = dsc(fake_images, species_, class_, gender_)

            # backprop: the generator is rewarded when fakes are scored as real
            g_loss = lossfun(pred_fake, real_labels)
            opt_g.zero_grad()
            g_loss.backward()
            opt_g.step()
            g_losses.append(g_loss.item())

            # fraction of samples the discriminator scored above 0.5
            real_preds.append((pred_real > 0.5).float().mean().item())
            fake_preds.append((pred_fake > 0.5).float().mean().item())

            # visualize output
            grid = make_grid(fake_images[:36].detach(), nrow=9)

            # release the previous figure so figures do not pile up in memory
            if fig is not None:
                plt.close(fig)

            fig = plt.figure(figsize=(15, 12))
            gs = GridSpec(5, 1, figure=fig)
            ax0 = fig.add_subplot(gs[:3])
            ax1 = fig.add_subplot(gs[3])
            ax2 = fig.add_subplot(gs[4])

            clear_output(wait=True)
            show_tensor(grid, ax0)
            ax1.plot(g_losses)
            ax1.plot(d_losses)
            ax2.plot(real_preds)
            ax2.plot(fake_preds)
            plt.show()

        # periodic snapshot of the latest figure and both networks
        if epoch % 100 == 0:
            if fig is not None:
                fig.savefig(f'CGAN_output/trial_{trial:09}_epoch_{epoch:09}.png')
            torch.save(gen.state_dict(), f'CGAN_output/trial_{trial:09}_epoch_{epoch:09}_gen.pth')
            torch.save(dsc.state_dict(), f'CGAN_output/trial_{trial:09}_epoch_{epoch:09}_dsc.pth')

        # close after saving; a no-op if the backend already closed the figure
        if fig is not None:
            plt.close(fig)