In [None]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
import numpy as np
import sys
import torch
from torch.utils.data import DataLoader
from torchvision.utils import make_grid

sys.path.append('..')
from utils import ImSet, ConditionalVariationalAutoEncoder, show_tensor
%matplotlib inline

In [None]:
# Train the conditional VAE on the whole image set, checkpointing the best
# weights seen so far and redrawing diagnostics while training runs.

# Prefer the first CUDA device, but fall back to the CPU so the cell still
# runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

imset = ImSet()

# Full-batch training: the batch size equals the dataset size, so every epoch
# consists of a single batch.
batch_size = len(imset)
loader = DataLoader(imset, shuffle=True, batch_size=batch_size, drop_last=True)

cvae = ConditionalVariationalAutoEncoder().to(device)
opt = torch.optim.Adam(cvae.parameters())

n_epochs = 10000
best_loss = 9e9
losses = []

model_fname = 'cvae.pth'
# Resume from an earlier checkpoint when one exists. map_location lets a
# checkpoint saved on another device load onto `device`; strict=False
# tolerates missing/unexpected keys in the saved state dict.
try:
    cvae.load_state_dict(torch.load(model_fname, map_location=device), strict=False)
except FileNotFoundError:
    # First run on a fresh checkout: there is nothing to resume from.
    print('no checkpoint found at %s; training from scratch' % model_fname)

# Set once the running loss becomes NaN; checked by every loop level so the
# whole run stops instead of continuing to train a diverged model.
breakout = False
avg_loss = 9e9

for trial in range(1000):

    for epoch in range(n_epochs):
        counter = 0

        for (img_fname, img, species_, class_, gender_) in loader:
            img = img.to(device)
            species_ = species_.to(device)
            class_ = class_.to(device)
            gender_ = gender_.to(device)

            # Loss = summed squared reconstruction error + the KL term exposed
            # by the encoder (presumably refreshed by the forward pass above;
            # confirm in utils).
            Xt = cvae(img, species_, class_, gender_)
            loss = ((img - Xt)**2).sum() + cvae.enc.kl

            opt.zero_grad()
            loss.backward()
            opt.step()

            # Read the scalar once; each .item() call forces a device sync.
            loss_value = loss.item()
            losses.append(loss_value)

            with open('cvae.log', 'a') as fout:
                fout.write('epoch=%d,iter=%d,loss=%.2f,kl=%.2f\n' % (epoch, counter, loss_value, cvae.enc.kl))

            # Keep only the most recent 1000 losses for the plot below.
            losses = losses[-1000:]

            if counter % 10 == 0:
                # update best loss and save best model
                avg_loss = np.mean(losses[-20:])
                if avg_loss < best_loss:
                    best_loss = avg_loss
                    torch.save(cvae.state_dict(), model_fname)

                clear_output(wait=True)

                # show real images
                grid = make_grid(img[:10], nrow=10)
                fig, ax = plt.subplots(figsize=(15, 5))
                show_tensor(grid, ax)

                # show encoded images
                grid = make_grid(Xt[:10], nrow=10)
                fig, ax = plt.subplots(figsize=(15, 5))
                show_tensor(grid, ax)

                # show loss curve
                fig, ax = plt.subplots(figsize=(15, 3))
                ax.plot(losses)
                plt.show()
                # Release the figures explicitly so they cannot pile up on
                # backends where show() does not close them.
                plt.close('all')

            # A NaN running average means training has diverged; stop here.
            if np.isnan(avg_loss):
                breakout = True
                break

            counter += 1

        if breakout:
            break

    # Propagate the stop out of the outermost loop as well; previously the
    # flag was set but never consulted, so training restarted after a NaN.
    if breakout:
        print('training stopped: average loss became NaN')
        break

In [None]:
# Xt = cvae.enc(img, species_, class_, gender_)
# Xt.shape

In [None]:
# prior scores: loss=17845560.00, kl=1967967.62