# VAE CelebA
### Overview
This notebook trains a VAE (i.e. a BetaVAE with $\beta = 1$), BetaVAE, FactorVAE, or BetaTCVAE on the CelebA dataset and then saves it. The model is trained on a fixed amount of data using the hyperparameters defined in the cell below. After training, the notebook compares original images with their reconstructions, plots latent traversals, and saves the trained model.

### Instructions
Set the hyperparameters for the run in the cell below. Choose the model class by editing the `from ae_utils_exp import ... as VAE_BASED_MODEL` line. Then select Run All in Jupyter.


In [None]:
#### CHOOSE A SEED #### (or generate randomly)
seed = 60
import random
random.seed(seed)

from ae_utils_exp import B_TCVAE as VAE_BASED_MODEL  # change <model> in ".... import <model> as ...."
### options: VAE (for VAE, BetaVAE), FACTOR_VAE, or B_TCVAE (for BetaTCVAE)

### SELECT HYPERPARAMETERS FOR THIS RUN #######

beta = 50.0  # \beta for BetaVAE, FactorVAE, and BetaTCVAE
n_lat = 32  # VAE bottleneck size (m)
batch_size = 200  # batch size used for training
lr = 0.001  # learning rate used for training. should be 1e-4 if the model is FACTOR_VAE

# Build a compact scientific-notation tag from the actual learning rate
# (0.001 -> "1e-3", 2.5e-4 -> "2.5e-4"), so the filename always reflects the
# lr used instead of a hard-coded "1e-3".
_mantissa, _exponent = "{:e}".format(lr).split("e")
lr_tag = "{}e{}".format(_mantissa.rstrip("0").rstrip("."), int(_exponent))

# savename for the trained model (directory is relative to the working dir)
savename = "./models/celeba_vae_lr{}_seed{}_b{}.pt".format(lr_tag, seed, beta)


print("Seed: ", seed)
print("Beta: ", beta)
print("Learning Rate: ", lr)
print("Batch Size: ", batch_size)
print("Savename: ", savename)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from ae_utils_exp import multi_t
from torchvision.transforms import Compose

# Seed torch and numpy with the same `seed` chosen in the hyperparameter cell
# (Python's `random` was seeded there already), to make runs more repeatable.
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
from torchvision.datasets import CelebA
import torchvision.transforms as tforms

# Preprocessing: resize (an int size makes torchvision match the shorter edge
# to 96 px), keep the central 64x64 patch, then convert to a tensor.
tform = tforms.Compose(
    [
        tforms.Resize(96),
        tforms.CenterCrop(64),
        tforms.ToTensor(),
    ]
)

# All CelebA splits from a local copy; nothing is downloaded.
dataset = CelebA(
    root='../beamsynthesizer/data',
    split='all',
    download=False,
    transform=tform,
)


In [None]:
from ae_utils_exp import celeba_norm, celeba_inorm
from architectures import enc_celeba_small_vae as enc
from architectures import dec_celeba_small as dec

# Instantiate the selected VAE-based model. It gets the CelebA normalizer, an
# encoder/decoder pair sized for `n_lat` latents and 3 input channels, the
# target device, and the matching inverse normalizer.
ae = VAE_BASED_MODEL(
    celeba_norm,
    enc(lat=n_lat, inp_chan=3),
    dec(lat=n_lat, inp_chan=3),
    device,
    z_dim=n_lat,
    inp_inorm=celeba_inorm,
)


In [None]:
# Train the model and keep the returned reconstruction / KL loss histories.
# NOTE(review): the positional `200` is passed straight through to `fit`.
# Given `batch_per_group` and the "Group" x-axis on the loss plot, it is
# presumably the number of training groups — TODO confirm in ae_utils_exp.
rec_loss, kl_loss = ae.fit(
    dataset,
    200,
    batch_per_group=20,
    batch_size=batch_size,
    lr=lr,
    beta=beta,
)

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(10, 3))
# Plot each loss history in its own panel on a log10 scale.
# NOTE(review): np.log10 yields NaN/-inf for non-positive entries; confirm the
# recorded losses are strictly positive.
for panel, series, ylabel, label in (
    (ax[0], rec_loss, "$log_{10}$(LogProb Rec Loss)", 'Reconstruction'),
    (ax[1], kl_loss, "$log_{10}$(KL Loss)", 'KL'),
):
    panel.set_ylabel(ylabel)
    panel.set_xlabel("Group")
    panel.plot(np.log10(series), linewidth=2, label=label)
    panel.legend()
    panel.grid(True, which='both', ls='-')



In [None]:
plt_batch_size = 200
num_to_plot = 20
# Encode/decode a few batches and keep latents, inputs and reconstructions.
z_scores, z_pred_scores, inp, rec = ae.record_latent_space(dataset, batch_size=plt_batch_size, n_batches=5)

# Reorder dims with multi_t (presumably channels-last for imshow), clamp to
# [0, 1] and move to numpy; processed in the order inp, then rec.
inp, rec = (multi_t(t, 1, 3).clamp(0, 1).cpu().numpy() for t in (inp, rec))

# Top row: originals; bottom row: reconstructions. Ticks/labels hidden.
fig, axes = plt.subplots(2, num_to_plot, figsize=(20, 4))
for col in range(num_to_plot):
    for row, images in enumerate((inp, rec)):
        axes[row][col].imshow(images[col])
        axes[row][col].tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False, left=False, labelleft=False)
plt.tight_layout()

In [None]:
from ae_utils_exp import InvNorm

invn = celeba_inorm

# Use one recorded sample's latent vector as the base point for traversals.
ind = 0
z_base = z_scores[ind]

# Show that sample (left) next to its reconstruction (right), ticks hidden.
fig, axes = plt.subplots(1, 2, figsize=(6, 3))
for panel, image in zip(axes, (inp[ind], rec[ind])):
    panel.imshow(image, cmap='gray')
    panel.tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False, left=False, labelleft=False)


In [None]:
# decode: latent traversals for the first half of the latent dimensions.
# For each dimension, sweep its value over the range seen in `z_scores`
# (10 evenly spaced steps) while holding the others at `z_base`, then decode.
n_rows = ae.z_dim // 2
# squeeze=False keeps `axes` 2-D even when there is a single row, so
# axes[i][j] indexing always works. Height is 1.5 in per row (24 in for 16 rows).
fig, axes = plt.subplots(n_rows, 10, figsize=(16, 1.5 * n_rows), squeeze=False)
with torch.no_grad():
    for i in range(n_rows):
        _min = z_scores[:, i].min()
        _max = z_scores[:, i].max()
        variation = torch.linspace(_min, _max, steps=10)
        for j in range(len(variation)):
            axes[i][j].tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False, left=False, labelleft=False)
            # Dimensions whose observed range is below 0.2 are left blank
            # (presumably inactive/collapsed latents).
            if _max - _min >= 0.2:
                z = z_base.clone()
                z[i] = variation[j]
                im = multi_t(invn(ae.dec(z.to(ae.device))), 1, 3).clamp(0, 1).squeeze().cpu().numpy()
                axes[i][j].imshow(im, cmap='gray')
plt.tight_layout()

In [None]:
# decode: latent traversals for the remaining latent dimensions
# (z_dim//2 .. z_dim-1). Iterating up to z_dim, rather than z_dim//2 more
# dims, ensures the last dimension is not skipped when z_dim is odd.
second_dims = range(ae.z_dim // 2, ae.z_dim)
n_rows = len(second_dims)
# squeeze=False keeps `axes` 2-D even for a single row; 1.5 in per row.
fig, axes = plt.subplots(n_rows, 10, figsize=(16, 1.5 * n_rows), squeeze=False)
with torch.no_grad():
    for row, dim in enumerate(second_dims):
        _min = z_scores[:, dim].min()
        _max = z_scores[:, dim].max()
        variation = torch.linspace(_min, _max, steps=10)
        for j in range(len(variation)):
            axes[row][j].tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False, left=False, labelleft=False)
            # Dimensions whose observed range is below 0.2 are left blank
            # (presumably inactive/collapsed latents).
            if _max - _min >= 0.2:
                z = z_base.clone()
                z[dim] = variation[j]
                im = multi_t(invn(ae.dec(z.to(ae.device))), 1, 3).clamp(0, 1).squeeze().cpu().numpy()
                axes[row][j].imshow(im, cmap='gray')
plt.tight_layout()

In [None]:
import os

# Make sure the target directory exists; otherwise torch.save raises
# FileNotFoundError and the freshly trained weights are lost.
os.makedirs(os.path.dirname(savename) or ".", exist_ok=True)
torch.save(ae.state_dict(), savename)