# NashAE CelebA
### Overview
This notebook trains a NashAE, or a plain AE (a NashAE with $\lambda=0$), on the CelebA dataset using a fixed amount of data and the hyperparameters defined in the cell below. After training, it compares original images with their reconstructions, plots each latent variable against its predicted value, renders latent traversals, and saves the model.

### Instructions
Set the hyperparameters for the run in the cell below, then click Run All in Jupyter.


In [None]:
########## SELECT A SEED ##############

seed = 55
import random
random.seed(seed)  # numpy and torch are seeded right after they are imported

####### SELECT HYPERPARAMETERS ###########
ar = 0.2 # adversarial ratio (\lambda); ar = 0 gives a plain AE
batch_size = 200 # batch size used for training
n_lat = 32 # number of latent features used for training
lr = 0.001 # learning rate for training


def format_lr_tag(value):
    """Return a compact scientific-notation tag for a learning rate.

    Used in the checkpoint filename so the name always reflects the actual
    ``lr`` value, e.g. ``0.001 -> "1e-3"`` and ``0.00025 -> "2.5e-4"``.
    """
    mantissa, exponent = "{:e}".format(value).split("e")
    # "1.000000" -> "1", "2.500000" -> "2.5"
    mantissa = mantissa.rstrip("0").rstrip(".")
    # "-03" -> -3 so the tag reads "1e-3" rather than "1e-03"
    return "{}e{}".format(mantissa, int(exponent))


# savename for the trained model; the lr tag is derived from `lr` (it used to be
# hard-coded as "1e-3", which mislabeled checkpoints trained with any other lr)
savename = "./models/celeba_ae_lr{}_seed{}_ar{}.pt".format(format_lr_tag(lr), seed, ar)

print("Seed: ", seed)
print("Batch Size: ", batch_size)
print("Latent Dim: ", n_lat)
print("LR: ", lr)
print("AdvRatio: ", ar)
print("Savename: ", savename)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from ae_utils_exp import s_init, AutoEncoder, multi_t
from torchvision.transforms import Compose

# Seed torch's and numpy's global RNGs with the same `seed` chosen in the
# hyperparameter cell so repeated runs start from the same random state.
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
from torchvision.datasets import CelebA
import torchvision.transforms as tforms

# Preprocessing pipeline: Resize(96), take a 64x64 center crop, convert to a tensor.
tform = tforms.Compose([
    tforms.Resize(96),
    tforms.CenterCrop(64),
    tforms.ToTensor(),
])

# Every CelebA split (split='all'), read from a local copy; nothing is downloaded.
dataset = CelebA(
    root='../beamsynthesizer/data',
    split='all',
    download=False,
    transform=tform,
)


In [None]:
from ae_utils_exp import celeba_norm, celeba_inorm
from architectures import enc_celeba_small as enc
from architectures import dec_celeba_small as dec

# Assemble the autoencoder from the CelebA input normalization, an encoder/decoder
# pair sized for `n_lat` latents on 3-channel images, and `celeba_inorm`
# (presumably the inverse of `celeba_norm` — confirm in ae_utils_exp).
inp_bn = celeba_norm
encoder = enc(lat=n_lat, inp_chan=3)
decoder = dec(lat=n_lat, inp_chan=3)
ae = AutoEncoder(inp_bn, encoder, decoder, device, z_dim=n_lat, inp_inorm=celeba_inorm)


In [None]:
# Train the model. `ae.fit` returns three loss histories (reconstruction,
# adversarial, predictor), which are plotted in the next cell.
# NOTE(review): 200 is presumably the number of training groups of
# `batch_per_group` batches each — confirm against AutoEncoder.fit.
n_groups = 200
rec_loss, adv_loss, pred_loss = ae.fit(
    dataset,
    n_groups,
    batch_per_group=20,
    batch_size=batch_size,
    lr=lr,
    ar=ar,
)

In [None]:
# Two panels of training curves on a log10 scale, one point per training group:
# left = reconstruction and predictor MSE, right = adversarial term.
fig, ax = plt.subplots(1, 2, figsize=(16, 3))
loss_ax, adv_ax = ax

loss_ax.set_ylabel("$log_{10}$(MSE Loss)")
loss_ax.set_xlabel("Group")
for history, curve_name in ((rec_loss, 'Reconstruction'), (pred_loss, 'Predictor')):
    loss_ax.plot(np.log10(history), linewidth=2, label=curve_name)
loss_ax.legend()
loss_ax.grid(True, which='both', ls='-')

# The adversarial history is shown as an absolute value divided by the latent width.
adv_ax.set_ylabel("$log_{10}$ (Abs. Mean Cov.)")
adv_ax.set_xlabel("Group")
adv_ax.plot(np.log10(adv_loss.abs() / ae.z_dim), linewidth=2, label='Adversarial')
adv_ax.legend()
adv_ax.grid(True, which='both', ls='-')

In [None]:
# show original data and reconstructions

# Collect latents, predicted latents, inputs and reconstructions for 5 batches,
# then show the first `num_to_plot` inputs (top row) over their reconstructions.
plt_batch_size = 200
num_to_plot = 20
z_scores, z_pred_scores, inp, rec = ae.record_latent_space(dataset, batch_size=plt_batch_size, n_batches=5)

# multi_t(..., 1, 3) presumably moves channels last for imshow (confirm in
# ae_utils_exp); values are clamped to [0, 1] before display.
inp = multi_t(inp, 1, 3).clamp(0, 1).cpu().numpy()
rec = multi_t(rec, 1, 3).clamp(0, 1).cpu().numpy()

fig, axes = plt.subplots(2, num_to_plot, figsize=(20, 4))
for col in range(num_to_plot):
    for row, images in enumerate((inp, rec)):
        cell = axes[row][col]
        cell.imshow(images[col])
        cell.tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False, left=False, labelleft=False)
plt.tight_layout()

In [None]:
# plot true latent variables against their predictions

from ae_utils_exp import covariance
# Plot each latent variable against its predicted value and report the squared
# correlation (R2) between the two; the mean R2 over all latents is printed.
n_rows = 4
# ceil(z_dim / n_rows) columns so every latent gets a panel even when z_dim is not
# a multiple of 4 or is smaller than 4 (z_dim // 4 would under-allocate or be 0);
# squeeze=False keeps `axes` 2-D so axes[i][j] always works.
n_cols = max(1, -(-ae.z_dim // n_rows))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(24, 8), squeeze=False)
rho2_agg = 0.
for ind in range(n_rows * n_cols):
    i, j = divmod(ind, n_cols)
    axes[i][j].tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False, left=False, labelleft=False)
    if ind >= ae.z_dim:
        # leftover panel in the last row when z_dim does not fill the grid
        axes[i][j].axis('off')
        continue
    z_col = z_scores[..., ind]
    z_pred_col = z_pred_scores[..., ind]
    axes[i][j].scatter(z_col, z_pred_col)
    # axis limits assume scores lie in [0, 1] — TODO confirm against the encoder output
    axes[i][j].set_xlim((-0.05, 1.05))
    axes[i][j].set_ylim((-0.05, 1.05))
    cov = covariance(z_col, z_pred_col).item()
    std = z_col.std(dim=0).item()
    std_p = z_pred_col.std(dim=0).item()
    # nearly constant latents (std <= 0.01) score 0 instead of dividing by ~0
    rho2 = 0.
    if std > 0.01 and std_p > 0.:
        rho2 = (cov/(std*std_p))**2
    rho2_agg += rho2
    axes[i][j].set_title("R2: {:1.3f}".format(rho2))
print(rho2_agg/ae.z_dim)

In [None]:
# base image used for traversals

from ae_utils_exp import InvNorm

# Output-side normalization applied to decoder outputs in the traversal cells.
invn = celeba_inorm

# Recorded sample 2 supplies the base latent code that every traversal perturbs.
ind = 2
z_base = z_scores[ind]

# Show that sample (left) next to its reconstruction (right).
fig, axes = plt.subplots(1, 2, figsize=(6, 3))
for panel, picture in zip(axes, (inp[ind], rec[ind])):
    panel.imshow(picture, cmap='gray')
    panel.tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False, left=False, labelleft=False)


In [None]:
# decode traversals of the base latent encoding for the first half of the latents.
# omit traversals if their max-min range is less than 0.2
n_steps = 10
n_trav_rows = ae.z_dim // 2
# squeeze=False keeps `axes` 2-D even when there is a single row (z_dim of 2 or 3),
# where axes[i][j] would otherwise index into an Axes object and fail.
fig, axes = plt.subplots(n_trav_rows, n_steps, figsize=(16, 24), squeeze=False)
with torch.no_grad():
    for i in range(n_trav_rows):
        _min = z_scores[:, i].min()
        _max = z_scores[:, i].max()
        # plain floats as endpoints work across torch versions
        variation = torch.linspace(float(_min), float(_max), steps=n_steps)
        # latents whose observed range is below 0.2 get a row of blank panels
        traverse = bool(_max - _min >= 0.2)
        for j in range(len(variation)):
            axes[i][j].tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False, left=False, labelleft=False)
            if traverse:
                z = z_base.clone()
                z[i] = variation[j]
                im = multi_t(invn(ae.dec(z.to(ae.device))), 1, 3).clamp(0, 1).squeeze().cpu().numpy()
                axes[i][j].imshow(im, cmap='gray')
plt.tight_layout()

In [None]:
# decode traversals of the base latent encoding for the remaining latents.
# omit traversals if their max-min range is less than 0.2
n_steps = 10
offset = ae.z_dim // 2
# cover every latent from `offset` to the end, so an odd z_dim no longer drops
# the last latent (z_dim // 2 rows would stop one short)
n_trav_rows = ae.z_dim - offset
# squeeze=False keeps `axes` 2-D even when there is a single row
fig, axes = plt.subplots(n_trav_rows, n_steps, figsize=(16, 24), squeeze=False)
with torch.no_grad():
    for i in range(n_trav_rows):
        k = i + offset  # latent index shown on row i
        _min = z_scores[:, k].min()
        _max = z_scores[:, k].max()
        # plain floats as endpoints work across torch versions
        variation = torch.linspace(float(_min), float(_max), steps=n_steps)
        # latents whose observed range is below 0.2 get a row of blank panels
        traverse = bool(_max - _min >= 0.2)
        for j in range(len(variation)):
            axes[i][j].tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False, left=False, labelleft=False)
            if traverse:
                z = z_base.clone()
                z[k] = variation[j]
                im = multi_t(invn(ae.dec(z.to(ae.device))), 1, 3).clamp(0,1).squeeze().cpu().numpy()
                axes[i][j].imshow(im, cmap='gray')
plt.tight_layout()

In [None]:
import os

# Make sure the checkpoint directory (e.g. ./models) exists; otherwise torch.save
# would fail here and the just-trained weights would be lost.
os.makedirs(os.path.dirname(savename) or ".", exist_ok=True)
torch.save(ae.state_dict(), savename)