Instructions for setting up a conda kernel on an NCAR machine:  
- Open a terminal (e.g., in JupyterHub).
- Execute the following series of commands:

`module load conda`  
`conda create -n pyroved`  
`conda activate pyroved`  
`conda install pytorch pyro-ppl pytorch-cuda ipykernel jupyter ipywidgets -c pytorch -c nvidia`  
`pip install pyroved`  

- You should now see a [conda-env:pyroved] conda kernel available when you open a Jupyter notebook.

# Test: Shift-VAE example from documentation
Source (adapted from the pyroVED shift-VAE example notebook): https://colab.research.google.com/github/ziatdinovmax/pyroVED/blob/master/examples/shiftVAE.ipynb#scrollTo=8CIc2tL_45qT

In [None]:
import pyroved as pv
import torch
import matplotlib.pyplot as plt

In [None]:
gpu_available = torch.cuda.is_available()  # True when PyTorch can see a CUDA device
gpu_available

In [None]:
def gaussian(x, mu, sig):
    """Evaluate an unnormalized Gaussian profile exp(-(x - mu)^2 / (2 * sig^2)).

    Args:
        x: tensor of positions.
        mu: peak center(s); a tensor broadcastable against ``x`` or a Python number.
        sig: peak width(s); a tensor broadcastable against ``x`` or a Python number.

    Returns:
        Tensor with the broadcast shape of the inputs; equals 1 where ``x == mu``.
    """
    # ``**`` (instead of ``torch.pow``) also accepts plain Python numbers for
    # ``sig``; ``torch.pow(1.0, 2.)`` would raise a TypeError.
    return torch.exp(-(x - mu) ** 2 / (2 * sig ** 2))

n_samples = 5000  # number of synthetic signals
l_signal = 100  # points per signal

# Synthetic dataset: one noisy Gaussian peak per signal with a random center (mu),
# width (sig) and noise amplitude. Seeded so every run draws the same data.
torch.manual_seed(1)
x = torch.linspace(-12, 12, l_signal).expand(n_samples, l_signal)
noise = torch.randint(1, 100, (n_samples, 1)) / 1e3  # amplitudes 0.001 ... 0.099
mu = torch.randint(-30, 30, size=(n_samples, 1)) / 10  # centers -3.0 ... 2.9
sig = torch.randint(50, 500, size=(n_samples, 1)) / 1e2  # widths 0.50 ... 4.99
peaks = gaussian(x, mu, sig)
train_data_raw = peaks + noise * torch.randn(size=(n_samples, l_signal))

# Min-max scale into [0, 1] using the global (whole-dataset) extremes
data_min, data_max = train_data_raw.min(), train_data_raw.max()
train_data = (train_data_raw - data_min) / (data_max - data_min)

# Add a channel axis -> (n_samples, 1, l_signal) before batching
train_loader = pv.utils.init_dataloader(train_data.unsqueeze(1), batch_size=64)

In [None]:
# Preview: the first signal of each of the first 64 mini-batches, one per panel
fig, axes = plt.subplots(
    8, 8, figsize=(8, 8),
    subplot_kw={'xticks': [], 'yticks': []},
    gridspec_kw=dict(hspace=0.1, wspace=0.1),
)
for panel, batch in zip(axes.flat, train_loader):
    (signals,) = batch  # each batch is a 1-tuple; tensor assumed (batch, 1, l_signal)
    panel.plot(x[0], signals[0, 0])

In [None]:
# Input shape follows the generated signal length instead of a hard-coded 100
in_dim = (l_signal,)
n_epochs = 250

# Initialize vanilla VAE (no invariances)
vae = pv.models.iVAE(in_dim, latent_dim=2, invariances=None, seed=0)

# Initialize SVI trainer
trainer = pv.trainers.SVItrainer(vae)

# Train, printing statistics every 10 epochs and after the final epoch
for e in range(n_epochs):
    trainer.step(train_loader)
    if e % 10 == 0 or e == n_epochs - 1:
        trainer.print_statistics()

In [None]:
# Encode the training set with the vanilla VAE (presumably one z_mean column per latent)
z_mean, z_sd = vae.encode(train_data)

# Latent space colored by each ground-truth parameter: mu (left), sigma (right)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
for panel, color_by, cbar_label in ((ax1, mu, r"$\mu$"), (ax2, sig, r"$\sigma$")):
    points = panel.scatter(z_mean[:, -1], z_mean[:, -2], s=1, c=color_by)
    panel.set_xlabel(r"$z_2$", fontsize=14)
    panel.set_ylabel(r"$z_1$", fontsize=14)
    fig.colorbar(points, ax=panel, shrink=.8).set_label(cbar_label, fontsize=14)
    panel.set_xlim(-2, 2)
    panel.set_ylim(-2, 2)

In [None]:
# Plot the vanilla VAE's decoded 2D latent manifold (d=10 presumably sets the grid size per axis)
vae_manifold = vae.manifold2d(d=10)

In [None]:
# Each ground-truth parameter (x axis) against each vanilla-VAE latent variable
# (y axis). Uses `z_mean` from the vanilla VAE encoding cell above, so run this
# before the shift-VAE cells overwrite `z_mean`.
fig, ax = plt.subplots(2, 2, figsize=(10, 10))
for row, (param, param_label) in enumerate(((mu, r"$\mu$"), (sig, r"$\sigma$"))):
    for col in range(2):
        # Parameter on x and latent on y, so the data matches the axis labels
        ax[row, col].scatter(param, z_mean[:, col], c='blue')
        ax[row, col].set_xlabel(param_label, fontsize=14)
        ax[row, col].set_ylabel("Latent variable {}".format(col + 1), fontsize=14)

In [None]:
# Input shape follows the generated signal length instead of a hard-coded 100
in_dim = (l_signal,)
n_epochs = 250

# Initialize shift-invariant VAE (to do this we add 't' to invariances).
# dx_prior presumably sets the scale of the shift prior (read back later as svae.t_prior).
svae = pv.models.iVAE(in_dim, latent_dim=2, invariances=['t'], dx_prior=.3)

# Initialize SVI trainer
trainer = pv.trainers.SVItrainer(svae)

# Train, printing statistics every 10 epochs and after the final epoch
for e in range(n_epochs):
    trainer.step(train_loader)
    if e % 10 == 0 or e == n_epochs - 1:
        trainer.print_statistics()

In [None]:
# Encode with the shift-invariant VAE. Column 0 is the encoded shift (it is used
# that way in the shift-comparison plot below); columns 1 and 2 are presumably the
# content latents z_1 and z_2.
z_mean, z_sd = svae.encode(train_data)

# Content latents colored by mu (left) and sigma (right). x is column -1 (z_2) and
# y is column -2 (z_1), matching the labelling used for the vanilla VAE plot.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
im1 = ax1.scatter(z_mean[:, -1], z_mean[:, -2], s=1, c=mu)
ax1.set_xlabel(r"$z_2$", fontsize=14)
ax1.set_ylabel(r"$z_1$", fontsize=14)
cbar1 = fig.colorbar(im1, ax=ax1, shrink=.8)
cbar1.set_label(r"$\mu$", fontsize=14)
ax1.set_xlim(-2, 2)
ax1.set_ylim(-2, 2)
im2 = ax2.scatter(z_mean[:, -1], z_mean[:, -2], s=1, c=sig)
ax2.set_xlabel(r"$z_2$", fontsize=14)
ax2.set_ylabel(r"$z_1$", fontsize=14)
cbar2 = fig.colorbar(im2, ax=ax2, shrink=.8)
cbar2.set_label(r"$\sigma$", fontsize=14)
ax2.set_xlim(-2, 2)
ax2.set_ylim(-2, 2);

In [None]:
# Plot the shift-VAE's decoded 2D latent manifold (d=10 presumably sets the grid size per axis)
svae_manifold = svae.manifold2d(d=10)

In [None]:
# Rescaling coefficient for the shift latent (column 0 of z_mean): half the width of
# the x range times the model's shift-prior scale. Presumably this maps the encoded
# shift into the units of mu, which it is compared against below.
half_width = (x.max() - x.min()) / 2
rescale = half_width * svae.t_prior.cpu()
# Non-collapsed content latent. NOTE(review): hard-coded; which latent collapses
# presumably depends on the training run, so verify after retraining.
i = 2

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13.5, 6))

# Left: rescaled encoded shift vs. the true peak center; (mu, mu) points show the ideal
ax1.scatter(mu, z_mean[:, 0] * rescale, c='blue', label="Encoded shift")
ax1.scatter(mu, mu, label="Actual shift")
ax1.set_xlabel(r"$\mu$", fontsize=14)
ax1.set_ylabel("Encoded shift", fontsize=14)
ax1.legend()
ax1.grid()

# Right: the selected content latent vs. the true peak width
ax2.scatter(sig, z_mean[:, i], c='blue')
ax2.set_xlabel(r"$\sigma$", fontsize=14)
ax2.set_ylabel(f"Latent variable {i}", fontsize=14)
ax2.grid()