# VAE Beamsynthesis
### Overview
This notebook runs all Beamsynthesis experiments involving VAE (i.e. BetaVAE with $\beta=1$), BetaVAE, FactorVAE, or BetaTCVAE. It trains a VAE-based model on a fixed amount of data using the hyperparameters defined in the cell below. It then generates latent traversals, compares the original data against their reconstructions, creates 3D visualizations of the learned latent space, and evaluates the latent space using the BetaVAE disentanglement metric.

### Instructions
Set the hyperparameters for the run in the cell below, then select Run All in the Jupyter notebook.


In [None]:
import random

# A fresh seed is drawn for every run and printed below, so a run can be
# reproduced by replacing the randint call with the printed value.
seed = random.randint(0, 1000)
random.seed(seed)

### set hyperparameters for this run

# Model choice: change only the imported name and keep the `as VAE_BASED_MODEL` alias.
# Options: VAE (for VAE, BetaVAE), FACTOR_VAE, or B_TCVAE (for BetaTCVAE).
from ae_utils_exp import B_TCVAE as VAE_BASED_MODEL

beta = 125.0      # beta weight for BetaVAE, FactorVAE, and BetaTCVAE
n_lat = 4         # size of the VAE bottleneck (m)
batch_size = 100  # batch size used for training
lr = 0.01         # training learning rate; should be 1e-4 if the model is FACTOR_VAE

# Echo the run configuration (label, value) in a fixed order.
for _cfg_label, _cfg_value in (
    ("Seed: ", seed),
    ("Batch Size: ", batch_size),
    ("Learning Rate: ", lr),
    ("Beta: ", beta),
):
    print(_cfg_label, _cfg_value)

In [None]:
import torch
import numpy as np
import torch.nn.functional as F
from torchvision.datasets import DatasetFolder
import matplotlib.pyplot as plt
from ae_utils_exp import s_init, beam_s2s2_norm, beam_s2s2_inorm
from torchvision.transforms import Compose

# Seed torch and NumPy with the same run seed that was given to Python's `random`.
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
from architectures import enc_beamform_vae as enc
from architectures import dec_beamform as dec

# Input normalization handed to the model; beam_s2s2_inorm is passed as inp_inorm
# (presumably its inverse — TODO confirm in ae_utils_exp).
inp_bn = beam_s2s2_norm

# Encoder and decoder both sized to the n_lat-dimensional bottleneck
# (built in the same order as before: encoder first, then decoder).
encoder = enc(lat=n_lat)
decoder = dec(lat=n_lat)
ae = VAE_BASED_MODEL(inp_bn, encoder, decoder, device, z_dim=n_lat, inp_inorm=beam_s2s2_inorm)

In [None]:
# Set up dataset retrieval: every ".npy" file under ./beamsynthesis is one example.
def loadfunc(path):
    """Load one example from an ``.npy`` file.

    Returns a 3-tuple ``(tensor, path[-10:-4], path[-16:-11])``:

    * the array as a float tensor;
    * two fixed-width substrings of the file path, which later cells parse
      with ``float``. The first is used as the frequency parameter (param1)
      and the second as the duty cycle (param2).

    The slicing relies on a fixed-width filename layout, presumably
    ``<d.ddd>_<dd.ddd>.npy`` — TODO confirm against the data files.
    """
    return (torch.tensor(np.load(path)).type(torch.float), path[-10:-4], path[-16:-11],)


def tform(sample):
    """Keep only the tensor from a ``loadfunc`` tuple, dropping the parameter strings."""
    return sample[0]


dataset = DatasetFolder("./beamsynthesis", loadfunc, (".npy",), transform=Compose([tform]))

In [None]:
# Train the model on the full dataset. The second positional argument (100) is
# presumably the epoch count — TODO confirm against ae.fit. Note that the
# generator passed as generator_ae is seeded with a constant 0, not the run `seed`.
ae_generator = torch.Generator().manual_seed(0)
rec_loss, kl_loss = ae.fit(
    dataset,
    100,
    beta=beta,
    lr=lr,
    batch_size=batch_size,
    generator_ae=ae_generator,
)

In [None]:
# Plot the training loss curves as log10 values, one panel per loss term.
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
_loss_panels = (
    (ax[0], rec_loss, "$log_{10}$(Reconstruction Loss)", 'Reconstruction'),
    (ax[1], kl_loss, "$log_{10}$(KL Loss)", 'KL'),
)
for _panel, _curve, _panel_ylabel, _curve_label in _loss_panels:
    _panel.set_ylabel(_panel_ylabel)
    _panel.set_xlabel("Group")
    _panel.plot(np.log10(_curve), linewidth=2, label=_curve_label)
    _panel.grid(True, which='both', ls='-')

In [None]:
# Latent traversals: record latent codes for a sample of the data, then, for each
# latent dimension, sweep it across its observed [min, max] range while the other
# dimensions stay at a fixed base code, and decode every sweep point.
z_scores, std_scores, inp, rec = ae.record_latent_space(dataset, batch_size=10, n_batches=10)

n_steps = 5           # sweep points per latent dimension (one subplot column each)
min_plot_range = 0.2  # dimensions whose observed range is <= this are left blank

z_base = z_scores[1]  # base code: the second recorded sample

# squeeze=False keeps `ax` 2-D even when ae.z_dim == 1, so ax[i][j] always works.
fig, ax = plt.subplots(ae.z_dim, n_steps, figsize=(16, 16), squeeze=False)
with torch.no_grad():  # visualization only: no autograd graph needed for decoding
    for i in range(ae.z_dim):
        _min = z_scores[:, i].min()
        _max = z_scores[:, i].max()
        is_active = _max - _min > min_plot_range
        variation = torch.linspace(_min, _max, steps=n_steps)
        for j in range(len(variation)):
            z = z_base.clone()
            z[i] = variation[j]
            out = ae.dec(z.to(ae.device)).detach()
            if is_active:
                ax[i][j].plot(out.squeeze().cpu().numpy(), linewidth=2)
plt.tight_layout()

In [None]:
# Plot original data vs. reconstructions, and record each example's latent code
# (ae.mu after the forward pass) together with the data-generating parameters
# parsed from its filename by loadfunc.

# A separate name keeps the training `dataset` (which has the tensor-only transform)
# intact, so earlier cells can still be re-run after this one.
eval_dataset = DatasetFolder("./beamsynthesis", loadfunc, (".npy",))
dataloader = torch.utils.data.DataLoader(eval_dataset, shuffle=False, batch_size=1, num_workers=0)

ae.eval()

n_ex = len(dataloader)
latent = np.zeros((n_ex, ae.z_dim))

param1 = np.zeros((n_ex,))  # first parsed filename value (frequency)
param2 = np.zeros((n_ex,))  # second parsed filename value (duty cycle)
use_param2 = False

n_panels = 10  # the 2 x 5 grid below
fig, axes = plt.subplots(2, 5, figsize=(20, 8))
f_ind = 0
# Plot roughly every (n_ex // 11)-th example; max(1, ...) avoids a ZeroDivisionError
# when the dataset holds fewer than 11 examples.
plot_every = max(1, n_ex // 11)

with torch.no_grad():
    for i, (data, _) in enumerate(dataloader):
        # `data` is the collated loadfunc tuple: [tensor batch, (param1 str,), (param2 str,)]
        ex = data[0].to(device).detach_()
        par1 = float(data[1][0])
        param1[i] = par1
        if len(data) > 2:
            use_param2 = True
            par2 = float(data[2][0])
            # param2 in this case is S1_duty_cycle
            # (NOTE(review): the 3D plot labels call it S2_duty_cycle — reconcile)
            param2[i] = par2
        out = ae(ex).squeeze()
        latent[i] = ae.mu.cpu().numpy()
        # Apply the same normalization given to the model, so the plotted input is
        # on the reconstruction's scale (presumably — TODO confirm in ae_utils_exp).
        ex = inp_bn(ex)

        if i % plot_every == 0 and f_ind < n_panels:
            ind = f_ind // 5, f_ind % 5
            if use_param2:
                axes[ind].set_title("DC:{:1.3f}   FR:{:1.3f}".format(par2, par1))
            else:
                axes[ind].set_title("Param: {:1.3f}".format(par1))
            axes[ind].plot(ex[0].cpu().numpy(), linewidth=2, label='in')
            axes[ind].plot(out.cpu().numpy(), linewidth=2, label='out')
            axes[ind].legend()
            f_ind += 1

    
    

In [None]:
# Camera and marker settings shared by both 3D latent-space views.
view_alt = 5   # elevation passed to view_init (degrees)
view_ang = 90  # azimuth of the first view (degrees); the second view uses this + 45
alpha = 0.5    # scatter-point opacity

In [None]:
# Plot the learned latent space in 3D from two viewpoints. Each latent dimension's
# activation is scattered against param2 (x axis) and param1 (y axis). The two views
# differ only by a 45-degree azimuth offset, and only the second one gets a legend.
fig = plt.figure(figsize=(10, 4))
for _subplot_code, _azim_offset, _with_legend in ((121, 0, False), (122, 45, True)):
    ax = fig.add_subplot(_subplot_code, projection='3d')
    for i in range(ae.z_dim):
        ax.scatter(param2, param1, latent[..., i], label='L{}'.format(i+1), alpha=alpha)

    ax.view_init(view_alt, view_ang + _azim_offset)
    # NOTE(review): the reconstruction cell's comment calls param2 "S1_duty_cycle",
    # while these labels say S2 — confirm which source the parameters belong to.
    ax.set_xlabel('S2_duty_cycle')
    ax.set_ylabel('S2_frequency')
    ax.set_zlabel('Latent Activation')
    if _with_legend:
        ax.legend()

plt.tight_layout()

In [None]:
# Count the learned latent features with two heuristics:
#  * BetaVAE: a dimension counts when its squared std, averaged over the recorded
#    samples, is <= 0.8 (std_scores presumably holds per-sample posterior stds).
#  * FactorVAE / BetaTCVAE: a dimension counts when its recorded range is >= 2.
_betavae_var_threshold = 0.8
_range_threshold = 2.0

ave_vars = std_scores.square().mean(dim=0)
_n_betavae_features = (ave_vars <= _betavae_var_threshold).sum().item()
print("BetaVAE Learned Features: ", _n_betavae_features)

z_max, _ = z_scores.max(dim=0)
z_min, _ = z_scores.min(dim=0)
_n_range_features = (z_max - z_min >= _range_threshold).sum().item()
print("FactorVAE, BetaTCVAE Learned Features: ", _n_range_features)

In [None]:
# For each frequency, pair up examples that share it (row k of the first half with
# row k of the second half) and print the mean |z_a - z_b| per latent dimension.
# A dimension that encodes frequency would be expected to show a small difference here.
freqs = [10., 15., 20.]
# frequency
for f in freqs:
    filt = param1 == f
    _z_scores = torch.tensor(latent[filt, :])
    _half = int(filt.sum()) // 2
    print(f)
    if _half == 0:
        # Fewer than two examples share this frequency, so there is no pair to compare.
        print("not enough examples with this frequency")
        continue
    # Both halves are cut to `_half` rows so their shapes match when the count is odd
    # (the last row is dropped instead of causing a shape error or silent broadcasting).
    _z_diff = (_z_scores[:_half] - _z_scores[_half:2 * _half]).abs()
    print(_z_diff.mean(dim=0))

In [None]:
# Initialize the BetaVAE disentanglement metric: a classifier over the n_lat latent
# dimensions with 2 classes (0 = duty cycle held fixed, 1 = frequency held fixed,
# matching the labels used when training it below).
from ae_utils_exp import DisentanglementMetric as DM

freqs = [10., 15., 20.]  # frequency values the metric samples from
dm = DM(n_lat, 2, lr=1.0)

In [None]:
# Train the disentanglement metric's linear classifier.
# Each iteration randomly picks one data-generating factor to hold fixed (frequency
# -> label 1, duty cycle -> label 0). It then builds `bsize` samples, each being
# |z_a - z_b| for two random examples sharing that factor's sampled value, and takes
# one dm.fit_batch step on that batch.
n_iterations = 10000
bsize = 20
losses = torch.zeros(n_iterations)
for i in range(n_iterations):
    # construct the batch
    batch = torch.zeros((bsize, n_lat))
    # randomly choose data generating factor to hold constant
    is_freq = torch.rand(1) > 0.5
    for b_ind in range(bsize):
        if is_freq:  # this is ind 1
            # randomly choose a frequency (one uniform draw scaled to the list length)
            freq_ind = int(torch.rand(1) * len(freqs))
            filt = param1 == freqs[freq_ind]
        else:  # dc is ind 0
            # randomly choose a duty cycle on the 0.005 grid in [0.200, 0.795],
            # expressed in thousandths (200 ... 795)
            tenths = torch.randint(low=2, high=8, size=(1,)).item()
            hundredths = torch.randint(low=0, high=10, size=(1,)).item()
            thousandths = 0 if torch.rand(1) > 0.5 else 5
            dc = tenths * 100 + hundredths * 10. + thousandths
            # round before comparing so floating-point error in param2 * 1000 cannot
            # make an exact-equality match fail
            filt = np.round(param2 * 1000.) == dc
        _z_scores = torch.tensor(latent[filt, :])
        if _z_scores.shape[0] < 2:
            _factor = "frequency" if is_freq else "duty cycle"
            raise ValueError(
                "need at least 2 examples sharing the sampled {}, found {}".format(
                    _factor, _z_scores.shape[0]))
        _z_scores = _z_scores[torch.randperm(_z_scores.shape[0])]
        # _z_scores is shuffled, select the difference of the first 2 as elem
        ex = (_z_scores[0] - _z_scores[1]).abs()
        batch[b_ind] = ex
    # batch is now constructed
    # train on batch
    # NOTE(review): `prediction` is unused here; the forward call is kept in case
    # dm.fit_batch relies on state it sets — confirm in ae_utils_exp.
    prediction = dm(batch)
    loss = dm.fit_batch(1 if is_freq else 0, batch.cpu())
    # Store a plain float so `losses` never links into an autograd graph (which would
    # grow memory across iterations and make plotting fail) if a tensor is returned.
    losses[i] = float(loss)
plt.figure()
plt.plot(losses)

In [None]:
# Evaluate the disentanglement metric's linear classifier (no training happens here).
# The first half of the iterations hold the duty cycle fixed (label 0) and the second
# half the frequency (label 1). Each iteration builds `bsize` |z_a - z_b| samples,
# classifies their mean, and scores the prediction against the held-fixed factor.
n_iterations = 1000
bsize = 20
n_correct = 0
for i in range(n_iterations):
    if i == int(n_iterations*0.9):
        # NOTE(review): this cell never calls dm.fit_batch, so lowering the learning
        # rate has no visible effect on the accuracy below — confirm it is still needed.
        dm.set_lr(0.001)
    # construct the batch
    batch = torch.zeros((bsize, n_lat))
    # deterministically choose the data generating factor to hold constant
    is_freq = i >= n_iterations//2
    for b_ind in range(bsize):
        if is_freq:  # this is ind 1
            # randomly choose a frequency (one uniform draw scaled to the list length)
            freq_ind = int(torch.rand(1) * len(freqs))
            filt = param1 == freqs[freq_ind]
        else:  # dc is ind 0
            # randomly choose a duty cycle on the 0.005 grid in [0.200, 0.795],
            # expressed in thousandths (200 ... 795)
            tenths = torch.randint(low=2, high=8, size=(1,)).item()
            hundredths = torch.randint(low=0, high=10, size=(1,)).item()
            thousandths = 0 if torch.rand(1) > 0.5 else 5
            dc = tenths * 100 + hundredths * 10. + thousandths
            # round before comparing so floating-point error in param2 * 1000 cannot
            # make an exact-equality match fail
            filt = np.round(param2 * 1000.) == dc
        _z_scores = torch.tensor(latent[filt, :])
        if _z_scores.shape[0] < 2:
            _factor = "frequency" if is_freq else "duty cycle"
            raise ValueError(
                "need at least 2 examples sharing the sampled {}, found {}".format(
                    _factor, _z_scores.shape[0]))
        _z_scores = _z_scores[torch.randperm(_z_scores.shape[0])]
        # _z_scores is shuffled, select the difference of the first 2 as elem
        ex = (_z_scores[0] - _z_scores[1]).abs()
        batch[b_ind] = ex
    # batch is now constructed; classify the batch-mean difference vector.
    # Assumes dm(...) returns the predicted class (1 = frequency, 0 = duty cycle,
    # matching the training labels) — TODO confirm in ae_utils_exp.
    with torch.no_grad():
        prediction = dm(batch.mean(dim=0).unsqueeze(0))
    n_correct += 1. if prediction == is_freq else 0.
print("Acc: {:1.2f}".format(n_correct/n_iterations*100.))