# NashAE dSprites
### Overview
This notebook trains a NashAE (or a plain AE, i.e. a NashAE with $\lambda=0$) on the dSprites dataset and then evaluates it. Using the hyperparameters set in the cell below, it trains the network on a fixed amount of data, compares original images with their reconstructions, plots each learned latent variable against the predictor's estimate of it, renders latent traversals, and scores the model with the BetaVAE disentanglement metric.

### Instructions
Set the hyperparameters for the run in the cell below, make sure the dSprites `.npz` path in the dataset cell points to your local copy, then select Run All in Jupyter.


In [None]:
import random

### SELECT HYPERPARAMETERS FOR THE MODEL ###

# Seed shared by python's `random`, numpy and torch. Leave as None to draw a fresh
# seed in [0, 1000]; set it to a previously printed value to reproduce that run.
seed = None
ar = 0.008  # adversarial ratio (\lambda); 0 gives a plain AE
n_lat = 10  # AE bottleneck size (m)
batch_size = 200  # batch size used for training
lr = 0.001  # learning rate used for training

if seed is None:
    seed = random.randint(0, 1000)
random.seed(seed)

print("Seed: ", seed)
print("Batch Size: ", batch_size)
print("LR: ", lr)
print("AdvRatio: ", ar)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from ae_utils_exp import s_init, AutoEncoder, InpNorm1D
from torchvision.transforms import Compose

# Seed torch and numpy's global RNG with the same seed used for python's `random`.
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
from dsprites import DSPRITES
from torchvision.transforms import Compose, ToTensor

def ten_type(x):
    """Convert a sample to a float32 torch tensor."""
    return torch.tensor(x, dtype=torch.float)


def flatten(x):
    """Flatten a tensor to 1-D (not used by the transform below)."""
    return x.view(-1)


def chan_insert(x):
    """Insert a leading singleton (channel) dimension."""
    return x.unsqueeze(0)


# Each sample is converted to a float tensor and then given a channel dimension.
dsprites_path = "../beamsynthesizer/data/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz"
dataset = DSPRITES(path=dsprites_path, transform=Compose([ten_type, chan_insert]))

In [None]:
from architectures import dec_dsprites_fc as dec
from architectures import enc_dsprites_fc as enc

# `ident` is passed both as the first positional argument (presumably the input
# normalizer, later used as `ae.inp_norm`) and as `inp_inorm`, so no input
# normalization is applied — confirm against AutoEncoder's signature.
ident = torch.nn.Identity()
# The encoder is constructed before the decoder (weight init consumes the RNG in this order).
encoder_net = enc(lat=n_lat)
decoder_net = dec(lat=n_lat)
ae = AutoEncoder(ident, encoder_net, decoder_net, device, z_dim=n_lat, inp_inorm=ident)


In [None]:
# Train the autoencoder; the returned reconstruction, adversarial and predictor
# losses are plotted per group in the next cell.
rec_loss, adv_loss, pred_loss = ae.fit(
    dataset,
    200,
    preds_train_iters=5,
    batch_per_group=20,
    batch_size=batch_size,
    lr=lr,
    pred_lr=0.01,
    ar=ar,
)

In [None]:
# Left panel: reconstruction and predictor losses; right panel: adversarial loss.
# Both panels use a log10 y-scale.
fig, ax = plt.subplots(1, 2, figsize=(16, 3))
loss_ax, adv_ax = ax

for curve, curve_label in ((rec_loss, 'Reconstruction'), (pred_loss, 'Predictor')):
    loss_ax.plot(np.log10(curve), linewidth=2, label=curve_label)
loss_ax.set_ylabel("$log_{10}$(MSE Loss)")
loss_ax.set_xlabel("Group")
loss_ax.legend()
loss_ax.grid(True, which='both', ls='-')

# absolute adversarial loss divided by the latent size
scaled_adv = adv_loss.abs() / ae.z_dim
adv_ax.plot(np.log10(scaled_adv), linewidth=2, label='Adversarial')
adv_ax.set_ylabel("$log_{10}$ (Abs. Mean Cov.)")
adv_ax.set_xlabel("Group")
adv_ax.legend()
adv_ax.grid(True, which='both', ls='-')

In [None]:
# Record latents, predictions, inputs and reconstructions (5 batches of size 200),
# then show the first `num_to_plot` inputs (top row) above their reconstructions (bottom row).
plt_batch_size = 200
num_to_plot = 20
z_scores, z_pred_scores, inp, rec = ae.record_latent_space(dataset, batch_size=plt_batch_size, n_batches=5)

# reshape both to numpy stacks of 64x64 images
inp, rec = (t.view(-1, 64, 64).cpu().numpy() for t in (inp, rec))

hide_ticks = dict(axis='both', which='both', bottom=False, top=False,
                  labelbottom=False, left=False, labelleft=False)
fig, axes = plt.subplots(2, num_to_plot, figsize=(20, 4))
for col in range(num_to_plot):
    for row, images in enumerate((inp, rec)):
        axes[row][col].imshow(images[col], cmap='gray')
        axes[row][col].tick_params(**hide_ticks)
plt.tight_layout()

In [None]:
# Count the latent features whose recorded values span a range greater than 0.2.
_mins = z_scores.min(dim=0).values
_maxes = z_scores.max(dim=0).values
diff = _maxes - _mins
n_learned = (diff > 0.2).sum().item()
print("NashAE Count: ", n_learned)

In [None]:
# Scatter each latent variable (x) against the predictor's estimate of it (y) and
# report the squared correlation R^2 of the pair in each panel title.
from ae_utils_exp import covariance

n_cols = 5
# ceil(z_dim / n_cols) rows so any bottleneck size fits (2x5 for z_dim == 10)
n_rows = max(1, -(-ae.z_dim // n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False)
# hide panels left over when z_dim is not a multiple of n_cols
for extra_ax in axes.flat[ae.z_dim:]:
    extra_ax.set_visible(False)
for ind in range(ae.z_dim):
    i, j = divmod(ind, n_cols)
    axes[i][j].tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False, left=False, labelleft=False)
    axes[i][j].scatter(z_scores[..., ind], z_pred_scores[..., ind])
    axes[i][j].set_xlim((-0.05, 1.05))
    axes[i][j].set_ylim((-0.05, 1.05))
    cov = covariance(z_scores[..., ind], z_pred_scores[..., ind]).item()
    std = z_scores[..., ind].std(dim=0).item()
    std_p = z_pred_scores[..., ind].std(dim=0).item()
    # a constant latent or prediction has no defined correlation; report 0
    rho2 = 0.
    if std > 0. and std_p > 0.:
        rho2 = (cov / (std * std_p)) ** 2
    axes[i][j].set_title("R2: {:1.3f}".format(rho2))

In [None]:
# Use the first recorded sample as the base point for the latent traversals and
# show its input image next to its reconstruction.
ind = 0
z_base = z_scores[ind]
fig, axes = plt.subplots(1, 2, figsize=(6, 3))
for panel, image in zip(axes, (inp[ind], rec[ind])):
    panel.imshow(image, cmap='gray')
    panel.tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False, left=False, labelleft=False)


In [None]:
# Latent traversals for the first half of the latents: starting from z_base, sweep one
# latent at a time across its recorded [min, max] range and decode each point. Rows for
# latents whose range is <= 0.2 (the threshold used for the NashAE count) stay blank.
n_steps = 10
n_rows = ae.z_dim // 2
# squeeze=False keeps `axes` 2-D even when there is only one row
fig, axes = plt.subplots(n_rows, n_steps, figsize=(16, 8), squeeze=False)
with torch.no_grad():
    for i in range(n_rows):
        # python floats keep linspace compatible with torch versions that reject tensor bounds
        _min = z_scores[:, i].min().item()
        _max = z_scores[:, i].max().item()
        variation = torch.linspace(_min, _max, steps=n_steps)
        for j in range(len(variation)):
            axes[i][j].tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False, left=False, labelleft=False)
            if _max - _min > 0.2:
                z = z_base.clone()
                z[i] = variation[j]
                # The AE was built with inp_inorm=ident (no input normalization), so the
                # decoder output is mapped back with that same identity.
                im = ident(ae.dec(z.to(ae.device))).view(64, 64).cpu().numpy()
                axes[i][j].imshow(im, cmap='gray', vmin=0., vmax=1.)
plt.tight_layout()

In [None]:
# Latent traversals for the remaining latents (indices z_dim//2 .. z_dim-1): starting
# from z_base, sweep one latent at a time across its recorded [min, max] range and
# decode each point. Rows for latents whose range is <= 0.2 stay blank.
n_steps = 10
offset = ae.z_dim // 2
# z_dim - z_dim//2 rows so the last latent is not skipped when z_dim is odd
n_rows = ae.z_dim - offset
# squeeze=False keeps `axes` 2-D even when there is only one row
fig, axes = plt.subplots(n_rows, n_steps, figsize=(16, 8), squeeze=False)
with torch.no_grad():
    for i in range(n_rows):
        lat = i + offset
        # python floats keep linspace compatible with torch versions that reject tensor bounds
        _min = z_scores[:, lat].min().item()
        _max = z_scores[:, lat].max().item()
        variation = torch.linspace(_min, _max, steps=n_steps)
        for j in range(len(variation)):
            axes[i][j].tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False, left=False, labelleft=False)
            if _max - _min > 0.2:
                z = z_base.clone()
                z[lat] = variation[j]
                # The AE was built with inp_inorm=ident (no input normalization), so the
                # decoder output is mapped back with that same identity.
                im = ident(ae.dec(z.to(ae.device))).view(64, 64).cpu().numpy()
                axes[i][j].imshow(im, cmap='gray', vmin=0., vmax=1.)
plt.tight_layout()

In [None]:
# Initialize the BetaVAE disentanglement metric classifier over n_lat latent
# differences with 4 outputs — presumably one per held-constant factor (the training
# cell labels batches with const_dgf - 2) — confirm against DisentanglementMetric.
from ae_utils_exp import DisentanglementMetric as DM

dm_lr = 1.0
dm = DM(n_lat, 4, lr=dm_lr)

In [None]:
# Train the disentanglement-metric linear classifier. Each step samples two batches
# that share one randomly chosen data-generating factor (dgf), encodes both with the
# (fixed) autoencoder, and trains on |z1 - z2| labelled with that factor.
n_groups = 3000
batch_per_group = 20
bsize = 100
lr_drop_group = int(0.95 * n_groups)  # from this group on, the classifier lr is 0.05
losses = torch.zeros(n_groups)
for i in range(n_groups):
    loss = 0.
    for _ in range(batch_per_group):
        const_dgf, batch_tup1, batch_tup2 = dataset.sample_latent_dm(bsize)
        # The AE is not trained here, so encode without building an autograd graph.
        with torch.no_grad():
            z1 = ae.z_act(ae.enc(ae.inp_norm(batch_tup1[0].to(device))))
            z2 = ae.z_act(ae.enc(ae.inp_norm(batch_tup2[0].to(device))))
        z_diff = (z1 - z2).abs()
        # labels are shifted because the smallest dgf index used is 2
        loss += dm.fit_batch(const_dgf - 2, z_diff.cpu())
    losses[i] = loss / batch_per_group
    if i % 500 == 0:
        print(i)
    if i == lr_drop_group:
        dm.set_lr(0.05)
plt.figure()
plt.plot(losses)


In [None]:
n_iterations = 1000
bsize = 100
n_correct = 0
for i in range(n_iterations):
    # construct the batch
    # randomly choose a data generating factor to hold constant, and create batches
    const_dgf, batch_tup1, batch_tup2 = dataset.sample_latent_dm(bsize)
    # push batches through the autoencoder
    z1 = ae.z_act(ae.enc(ae.inp_norm(batch_tup1[0].to(device)))).detach()
    z2 = ae.z_act(ae.enc(ae.inp_norm(batch_tup2[0].to(device)))).detach()
    z_diff = (z1 - z2).abs()
    # batch is now constructed
    prediction = dm(z_diff.mean(dim=0).unsqueeze(0).cpu())
    n_correct += 1. if prediction == const_dgf - 2 else 0.
print("Acc: {:1.2f}".format(n_correct/n_iterations*100.))

In [None]:
# Echo the seed at the end so it sits next to the final results for reproduction.
print("Seed: ", seed, sep=" ")