<font size="+1">
<font color='red'>
<b> IMPORTANT NOTE: </b> 
</font>
Make sure to save a copy of this notebook in your personal drive to maintain the changes you make!
</font>

Run the following cells to prepare your working environment:

In [None]:
#@title Clone Repo & Install Requirements { display-mode: "form" }
%%capture

# Clone Repo
%cd /content
!git clone https://github.com/mhsotoudeh/ProbUNet-Tutorial.git
# !export PYTHONPATH="${PYTHONPATH}:$PWD/ProbUNet-Tutorial"
%cd /content/ProbUNet-Tutorial

# Install Requirements
!pip install -r requirements.txt

In [None]:
#@title Imports { display-mode: "form" }

%load_ext autoreload
%autoreload 2

%load_ext tensorboard

from model import *
from train import *

from random import randrange

from dotmap import DotMap
from IPython import display
from tqdm.notebook import tqdm_notebook

import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.animation as animation

import numpy as np
from sklearn.decomposition import PCA
import torch
from torch.utils.tensorboard import SummaryWriter

In [None]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Device is {}".format(device))

# Part 2: Rescue the Randomness (Loss Functions)

## Required Functions

In [None]:
#@title Initialization { display-mode: "form" }
# Initialization

def initialize(args):
    """Build the model, loss criterion, optimizer and LR scheduler described by ``args``.

    Parameters
    ----------
    args : DotMap-like configuration object
        Fields read here: ``random_seed``; the model fields ``in_ch``, ``out_ch``,
        ``intermediate_ch``, ``latent_num``, ``latent_chs``, ``latent_locks``,
        ``scale_depth``, ``kernel_size``, ``dilation``, ``padding_mode``;
        ``rec_type`` and ``loss_type`` (plus ``beta*`` fields for ELBO or
        ``kappa``, ``kappa_px``, ``pixels``, ``decay``, ``update_rate`` for GECO);
        ``optimizer``, ``lr``, ``wd``; ``scheduler_type`` (plus ``epochs`` or the
        ``scheduler_*`` fields, depending on the type).
        ``args.trainable_params`` is written as a side effect.

    Returns
    -------
    tuple
        ``(model, criterion, optimizer, lr_scheduler)``; the model and the
        criterion are moved to the notebook-level ``device``.

    Raises
    ------
    ValueError
        If the reconstruction loss, loss type, optimizer or scheduler type is
        not one of the supported options. (An exception is raised instead of
        calling ``exit()``, which would shut down the notebook kernel.)
    """
    # Set Random Seed
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    # Initialize Model
    model = HPUNet( in_ch=args.in_ch, out_ch=args.out_ch, chs=args.intermediate_ch,
                    latent_num=args.latent_num, latent_channels=args.latent_chs, latent_locks=args.latent_locks,
                    scale_depth=args.scale_depth, kernel_size=args.kernel_size, dilation=args.dilation,
                    padding_mode=args.padding_mode, conv_dim=1 )

    # Record the number of trainable parameters on the config object
    args.trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    model.to(device)


    # Set Loss Function

    ## Reconstruction Loss
    rec_type = args.rec_type.lower()
    if rec_type == 'mse':
        reconstruction_loss = MSELossWrapper()

    else:
        raise ValueError('Invalid reconstruction loss type: {!r}'.format(args.rec_type))


    ## Total Loss
    loss_type = args.loss_type.lower()
    if loss_type == 'elbo':
        # A constant beta is used unless an ascending phase is configured
        if args.beta_asc_steps is None:
            beta_scheduler = BetaConstant(args.beta)
        else:
            beta_scheduler = BetaLinearScheduler(ascending_steps=args.beta_asc_steps, constant_steps=args.beta_cons_steps, max_beta=args.beta, saturation_step=args.beta_saturation_step)
        criterion = ELBOLoss(reconstruction_loss=reconstruction_loss, beta=beta_scheduler).to(device)

    elif loss_type == 'geco':
        kappa = args.kappa
        if args.kappa_px is True:
            # kappa was given per pixel: scale it by the pixel count stored in the
            # config rather than by a notebook-level global
            kappa *= args.pixels
        criterion = GECOLoss(reconstruction_loss=reconstruction_loss, kappa=kappa, decay=args.decay, update_rate=args.update_rate, device=device).to(device)

    else:
        raise ValueError('Invalid loss type: {!r}'.format(args.loss_type))


    # Set Optimizer
    if args.optimizer == 'adamax':
        optimizer = optim.Adamax(model.parameters(), lr=args.lr, weight_decay=args.wd)

    elif args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

    elif args.optimizer == 'adamw':
        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)

    else:
        raise ValueError('Optimizer not known: {!r}'.format(args.optimizer))


    # Set LR Scheduler
    if args.scheduler_type == 'cons':
        # "Constant" LR: a StepLR whose first decay would only happen after args.epochs steps
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.epochs)

    elif args.scheduler_type == 'step':
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.scheduler_step_size, gamma=args.scheduler_gamma)

    elif args.scheduler_type == 'milestones':
        lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.scheduler_milestones, gamma=args.scheduler_gamma)

    else:
        # Previously this fell through and failed with UnboundLocalError at `return`
        raise ValueError('Scheduler type not known: {!r}'.format(args.scheduler_type))

    return model, criterion, optimizer, lr_scheduler

In [None]:
#@title Training { display-mode: "form" }
# Training

def record_history(idx, loss_dict, type='train'):
    """Log per-pixel loss curves (and loss coefficients) to TensorBoard.

    Uses the notebook globals ``args``, ``writer`` and ``criterion``.

    Parameters
    ----------
    idx : int
        Global step passed to the summary writer.
    loss_dict : dict
        Must provide ``'loss'``, ``'reconstruction_term'``, ``'kl_term'`` and
        ``'kls'`` (one entry per latent scale); each value must support ``.item()``.
    type : str
        ``'train'`` selects the training tag prefix and additionally logs the
        loss coefficients; any other value selects the validation prefix.
    """
    is_train = (type == 'train')
    prefix = 'Minibatch Training ' if is_train else 'Mean Validation '

    def per_pixel(value):
        # Convert a scalar tensor to a Python float, normalised by the pixel count
        return value.item() / args.pixels

    total = per_pixel(loss_dict['loss'])
    reconstruction = per_pixel(loss_dict['reconstruction_term'])
    kl_term = per_pixel(loss_dict['kl_term'])
    kl_by_scale = [per_pixel(loss_dict['kls'][v]) for v in range(args.latent_num)]

    # Total Loss
    loss_curves = {}
    loss_curves['total'] = total
    loss_curves['kl term'] = kl_term
    loss_curves['reconstruction'] = reconstruction
    writer.add_scalars(prefix + 'Loss Curve', loss_curves, idx)

    # KL Term Decomposition
    kl_curves = {'sum': sum(kl_by_scale)}
    for v, value in enumerate(kl_by_scale):
        kl_curves['scale {}'.format(v)] = value
    writer.add_scalars(prefix + 'Loss Curve (K-L)', kl_curves, idx)

    # Coefficients are only tracked for training steps
    if not is_train:
        return

    loss_kind = args.loss_type.lower()
    if loss_kind == 'elbo':
        writer.add_scalar('Beta', criterion.beta_scheduler.beta, idx)
    elif loss_kind == 'geco':
        lamda = criterion.log_inv_function(criterion.log_lamda).item()
        writer.add_scalar('Lagrange Multiplier', lamda, idx)
        # The reciprocal of the multiplier is logged under 'Beta'; the epsilon avoids division by zero
        writer.add_scalar('Beta', 1/(lamda+1e-20), idx)


def train():
    """Train the global ``model`` for ``args.epochs`` full-batch steps.

    Relies on notebook globals: ``args``, ``model``, ``criterion``, ``optimizer``,
    ``lr_scheduler``, ``writer``, ``stamp``, ``inputs``, ``raw_truths``, ``n``,
    ``sigma`` and ``device``. Every step draws a fresh noisy, randomly rolled
    realization of the truths, logs the training losses, and every
    ``args.val_period`` steps logs a figure of the latent samples. When the loop
    finishes, the model and the criterion are saved under ``runs/part2/<stamp>/``.

    NOTE(review): ``lr_scheduler.step()`` is never called in this loop, so the
    'step' and 'milestones' scheduler types never change the learning rate --
    confirm whether that is intended.
    """
    for e in tqdm_notebook(range(args.epochs)):
        # Initialization
        criterion.train()
        model.train()
        model.zero_grad()

        # Train One Step

        ## Generate Truths
        # Random permutation of the n training examples for this iteration
        # (sized by the dataset size instead of a hard-coded 32)
        p = torch.randperm(n, device=device)
        noise = torch.randn(n,1,n, device=device)*sigma
        unrolled_truths = raw_truths + noise
        truths = random_roll(unrolled_truths)[p]

        ## Get Predictions and Prepare for Loss Calculation
        if args.rec_type.lower() == 'mse':
            preds, infodicts = model(inputs[p], truths)
            # Keep the first prediction and its info dict
            preds, infodict = preds[:,0], infodicts[0]

        # Drop the singleton dimension at index 1 so truths line up with preds[:,0]
        truths = truths.squeeze(dim=1)


        ## Calculate Loss
        loss = criterion(preds, truths, kls=infodict['kls'], lr=lr_scheduler.get_last_lr()[0])


        ## Backpropagate
        loss.backward()             # Calculate Gradients
        optimizer.step()            # Update Weights


        ## Step Beta Scheduler
        if args.loss_type.lower() == 'elbo':
            criterion.beta_scheduler.step()


        # Record Train History
        loss_dict = criterion.last_loss.copy()
        loss_dict.update( { 'kls': infodict['kls'] } )

        record_history(e, loss_dict)


        # Validation
        if (e+1) % args.val_period == 0:
            criterion.eval()
            model.eval()

            noise = torch.randn(n,1,n, device=device)*sigma
            unrolled_truths = raw_truths + noise
            truths = random_roll(unrolled_truths)

            with torch.no_grad():
                preds, infodicts = model(inputs, truths, times=args.k, insert_from_postnet=False)

            # scale=-1 selects the last entry of the per-scale latent lists
            fig = plot_latents_or_pca(infodicts, scale=-1, step=e+1)
            writer.add_figure('Evolution of Latents', fig, e+1)


    # Save Model & Loss After Training is Done
    torch.save(model, 'runs/part2/{}/model.pth'.format(stamp))
    torch.save(criterion, 'runs/part2/{}/loss.pth'.format(stamp))

In [None]:
#@title Evaluation { display-mode: "form" }
# Evaluation

# colors = ['#ff0000', '#ff3100', '#ff6100', '#ff9200', '#ffc200', '#fef200', '#daff00', '#aaff00', '#79ff00', '#49ff00', '#18ff00', '#00ff18', '#00ff49', '#00ff7a', '#00ffaa', '#00ffdb', '#00f3ff', '#00c2ff', '#0091ff', '#0061ff', '#0030ff', '#0505ff', '#3100ff', '#6100ff', '#9200ff', '#c300ff', '#f300ff', '#ff00da', '#ff00aa', '#ff0079', '#ff0048', '#ff0018']

# function to sample n colors from the hsv colormap (to assign each training example a color)
def get_colors(n):
    """Return ``n`` hex colour strings taken from an ``n``-entry 'hsv' colormap.

    ``plt.get_cmap`` is used because ``matplotlib.cm.get_cmap`` was deprecated in
    matplotlib 3.7 and removed in 3.9; the pyplot function accepts the same
    ``(name, lut)`` arguments.
    """
    cmap = plt.get_cmap('hsv', n)
    # cmap(i) is an RGBA tuple; drop the alpha channel before converting to hex
    color_list = [mpl.colors.rgb2hex(cmap(i)[:3]) for i in range(cmap.N)]

    return color_list


# function to plot components (dim = 2) or two given components (dim > 2) of the latent space at a given scale against each other
def plot_latents(infodicts, scale, step, comp_pair=None, show=False):
    """Scatter-plot two latent components of the prior and posterior samples.

    Parameters
    ----------
    infodicts : sequence of dict
        One dict per drawn sample; each must provide ``'prior_latents'`` and
        ``'post_latents'``, indexable by ``scale`` and yielding tensors.
        (After ``squeeze`` each is assumed to be (examples, dim) -- confirm against the model.)
    scale : int
        Index of the latent scale to plot.
    step : int
        Training step, used only in the figure title.
    comp_pair : sequence of two ints, optional
        Components to plot; required when the latent dimension is larger than 2.
    show : bool
        If True, call ``plt.show()`` before the figure is closed.

    Returns
    -------
    matplotlib.figure.Figure
        The (already closed) figure. Points are coloured with the notebook-level ``colors``.

    Raises
    ------
    AssertionError
        If the latent dimension is below 2, or it is above 2 and ``comp_pair``
        is missing or does not have exactly two entries.
    """
    k = len(infodicts)

    prior_latents, post_latents = [], []
    for i in range(k):
        prior_latents.append( infodicts[i]['prior_latents'][scale].squeeze().cpu().numpy() )
        post_latents.append( infodicts[i]['post_latents'][scale].squeeze().cpu().numpy() )

    prior_latents, post_latents = np.stack(prior_latents), np.stack(post_latents)

    dim = prior_latents.shape[-1]
    assert dim >= 2
    if dim > 2:  # if dim > 2, the two components given by comp_pair argument will be plotted
        assert comp_pair is not None
        assert len(comp_pair) == 2
        selected = list(comp_pair)  # index with a list, whatever sequence type was passed
        prior_latents = prior_latents[:,:,selected]
        post_latents = post_latents[:,:,selected]

    # Create a figure with two subplots (one for PriorNet latents and another for PosteriorNet latents)
    fig, axs = plt.subplots(1, 2, figsize=(12,6))

    # Set the titles for the figure and each subplot
    fig.suptitle('Latent Space Samples of Scale {} at Step {} (dim = {})'.format(scale, step, dim) + (' / Components: {}'.format(comp_pair) if dim > 2 else ''), size=14)
    axs[0].set_title('PriorNet Latents')
    axs[1].set_title('PosteriorNet Latents')  # fixed the "PoteriorNet" misspelling

    # For each subplot, turn off the axes
    for axi in axs.ravel():
        axi.set_axis_off()

    # One scatter call per drawn sample; each example keeps its colour across samples
    for i in range(k):
        axs[0].scatter(prior_latents[i,:,0], prior_latents[i,:,1], c=colors, alpha=0.7, s=20)
        axs[1].scatter(post_latents[i,:,0], post_latents[i,:,1], c=colors, alpha=0.7, s=20)

    # If the show parameter is True, display the plot
    if show is True:
        plt.show()

    # Close this specific figure (not whichever figure happens to be current) and return it
    plt.close(fig)
    return fig


# function to plot components (dim = 2) or the first two principal components (dim > 2) of the latent space at a given scale against each other
def plot_latents_or_pca(infodicts, scale, step, show=False):
    """Scatter-plot prior and posterior latent samples, using PCA when dim > 2.

    Parameters
    ----------
    infodicts : sequence of dict
        One dict per drawn sample; each must provide ``'prior_latents'`` and
        ``'post_latents'``, indexable by ``scale`` and yielding tensors.
        (After ``squeeze`` each is assumed to be (examples, dim) -- confirm against the model.)
    scale : int
        Index of the latent scale to plot.
    step : int
        Training step, used only in the figure title.
    show : bool
        If True, call ``plt.show()`` before the figure is closed.

    Returns
    -------
    matplotlib.figure.Figure
        The (already closed) figure. Points are coloured with the notebook-level ``colors``.

    Raises
    ------
    AssertionError
        If the latent dimension is below 2.
    """
    k = len(infodicts)

    prior_latents, post_latents = [], []
    for i in range(k):
        prior_latents.append( infodicts[i]['prior_latents'][scale].squeeze().cpu().numpy() )
        post_latents.append( infodicts[i]['post_latents'][scale].squeeze().cpu().numpy() )

    prior_latents, post_latents = np.stack(prior_latents), np.stack(post_latents)

    dim = prior_latents.shape[-1]
    assert dim >= 2
    if dim > 2:  # if dim > 2, the first two principal components will be plotted
        # Restore the leading axes from the arrays themselves instead of relying
        # on a notebook-level dataset-size global
        prior_lead, post_lead = prior_latents.shape[:-1], post_latents.shape[:-1]
        pca = PCA(n_components=2)
        # The projection is fitted separately for the prior and the posterior samples
        prior_latents = pca.fit_transform(prior_latents.reshape(-1,dim)).reshape(*prior_lead, 2)
        post_latents = pca.fit_transform(post_latents.reshape(-1,dim)).reshape(*post_lead, 2)

    # Create a figure with two subplots (one for PriorNet latents and another for PosteriorNet latents)
    fig, axs = plt.subplots(1, 2, figsize=(12,6))

    # Set the titles for the figure and each subplot
    fig.suptitle('Latent Space Samples of Scale {} at Step {} (dim = {})'.format(scale, step, dim) + (' / First Two PCs' if dim > 2 else ''), size=14)
    axs[0].set_title('PriorNet Latents')
    axs[1].set_title('PosteriorNet Latents')  # fixed the "PoteriorNet" misspelling

    # For each subplot, turn off the axes
    for axi in axs.ravel():
        axi.set_axis_off()

    # One scatter call per drawn sample; each example keeps its colour across samples
    for i in range(k):
        axs[0].scatter(prior_latents[i,:,0], prior_latents[i,:,1], c=colors, alpha=0.7, s=20)
        axs[1].scatter(post_latents[i,:,0], post_latents[i,:,1], c=colors, alpha=0.7, s=20)

    # If the show parameter is True, display the plot
    if show is True:
        plt.show()

    # Close this specific figure (not whichever figure happens to be current) and return it
    plt.close(fig)
    return fig


# function to plot a given prediction (idx) for all examples in the dataset, together with the inputs, truths, and the mean / std over predictions
def plot_sample(idx, inputs, truths, preds, bounds=None, common_pred_colormap=True, show=False):
    """Plot inputs, truths, one prediction, and the mean / std over all predictions.

    Parameters
    ----------
    idx : int
        Index (along dim 1 of ``preds``) of the prediction to display.
    inputs, truths, preds : torch.Tensor
        Concatenated along dim 1 to find common colour limits, so they must
        agree in all other dimensions. (Assumed (N, 1, L), (N, 1, L) and
        (N, K, L) -- confirm against the model output.)
    bounds : tuple of (min, max), optional
        Colour-axis limits; when None they are taken from the data.
    common_pred_colormap : bool
        If True the prediction and mean panels share the limits of the input /
        truth panels; otherwise they use the min / max of ``preds``.
    show : bool
        If True, call ``plt.show()`` before the figure is closed.

    Returns
    -------
    matplotlib.figure.Figure
        The (already closed) figure.
    """
    # Concatenate the observations, ground truth, and predictions into a single tensor (to calculate common color axis limits)
    _all = torch.cat([inputs, truths, preds], dim=1)

    # Convert torch tensors to numpy arrays
    inputs = inputs.squeeze().cpu().numpy()
    truths = truths.squeeze().cpu().numpy()
    preds = preds.cpu().numpy()

    # Determine color axis limits (as plain Python floats, so matplotlib never
    # receives a tensor that may live on the GPU)
    if bounds is None:
        _min, _max = _all.min().item(), _all.max().item()
    else:
        _min, _max = bounds

    if common_pred_colormap is True:
        pred_min, pred_max = _min, _max
    else:
        pred_min, pred_max = preds.min(), preds.max()

    # Create a figure with five subplots
    fig, axs = plt.subplots(1, 5, figsize=(25,5))

    # Set the titles for each subplot
    axs[0].set_title('Inputs')
    axs[1].set_title('Ground Truths')
    axs[2].set_title('Outputs {}'.format(idx+1))
    axs[3].set_title('Means')
    axs[4].set_title('STDs')

    # Display the observation, ground truth, prediction, mean and std maps
    im0 = axs[0].imshow(inputs, vmin=_min, vmax=_max)
    im1 = axs[1].imshow(truths, vmin=_min, vmax=_max)
    im2 = axs[2].imshow(preds[:,idx], vmin=pred_min, vmax=pred_max)
    im3 = axs[3].imshow(preds.mean(axis=1), vmin=pred_min, vmax=pred_max)
    im4 = axs[4].imshow(preds.std(axis=1), cmap='magma')

    # Image objects, in panel order, used for color bar display
    imlist = [im0, im1, im2, im3, im4]

    # For each subplot, turn off the axes, create a new axis for the color bar, and add it to the figure
    for axi, im in zip(axs.ravel(), imlist):
        axi.set_axis_off()

        divider = make_axes_locatable(axi)
        cax = divider.append_axes('right', size='5%', pad=0.05)
        fig.colorbar(im, cax=cax, orientation='vertical')

    # If the show parameter is True, display the plot
    if show is True:
        plt.show()

    # Close this specific figure and return it
    plt.close(fig)
    return fig


# function to create an animation of the first "num" predictions for all examples in the dataset (one prediction per frame)
def animate_samples(inputs, truths, preds, bounds=None, common_pred_colormap=True, num=None, output_type='jshtml'):
    """Animate successive predictions next to the inputs, truths, mean and std maps.

    Parameters
    ----------
    inputs, truths, preds : torch.Tensor
        Concatenated along dim 1 to find common colour limits, so they must
        agree in all other dimensions. (Assumed (N, 1, L), (N, 1, L) and
        (N, K, L) -- confirm against the model output.)
    bounds : tuple of (min, max), optional
        Colour-axis limits; when None they are taken from the data.
    common_pred_colormap : bool
        If True the prediction and mean panels share the limits of the input /
        truth panels; otherwise they use the min / max of ``preds``.
    num : int, optional
        Number of frames (predictions) to animate; defaults to ``preds.shape[1]``.
    output_type : {'jshtml', 'video'}
        Rendering used for the animation.

    Returns
    -------
    IPython.display.HTML
        The rendered animation.

    Raises
    ------
    ValueError
        If ``output_type`` is not 'jshtml' or 'video' (this used to surface as
        an UnboundLocalError after the figure had already been built).
    AssertionError
        If ``num`` exceeds the number of available predictions.
    """
    if output_type not in ('video', 'jshtml'):
        raise ValueError("output_type must be 'video' or 'jshtml', got {!r}".format(output_type))

    # Make sure the number of predictions to display is less than or equal to the total number of available predictions
    if num is not None:
        assert num <= preds.shape[1]

    # Concatenate the observations, ground truth, and predictions into a single tensor (to calculate common color axis limits)
    _all = torch.cat([inputs, truths, preds], dim=1)

    # Convert torch tensors to numpy arrays
    inputs = inputs.squeeze().cpu().numpy()
    truths = truths.squeeze().cpu().numpy()
    preds = preds.cpu().numpy()

    # Determine color axis limits (as plain Python floats, so matplotlib never
    # receives a tensor that may live on the GPU)
    if bounds is None:
        _min, _max = _all.min().item(), _all.max().item()
    else:
        _min, _max = bounds

    if common_pred_colormap is True:
        pred_min, pred_max = _min, _max
    else:
        pred_min, pred_max = preds.min(), preds.max()

    # Create a figure with five subplots
    fig, axs = plt.subplots(1, 5, figsize=(21.5,4.3))

    # Set the titles for each subplot
    axs[0].set_title('Inputs')
    axs[1].set_title('Ground Truths')
    axs[2].set_title('Outputs 1')
    axs[3].set_title('Means')
    axs[4].set_title('STDs')

    # Display the observation, ground truth, prediction, mean and std maps
    im0 = axs[0].imshow(inputs, vmin=_min, vmax=_max)
    im1 = axs[1].imshow(truths, vmin=_min, vmax=_max)
    im2 = axs[2].imshow(preds[:,0], vmin=pred_min, vmax=pred_max)
    im3 = axs[3].imshow(preds.mean(axis=1), vmin=pred_min, vmax=pred_max)
    im4 = axs[4].imshow(preds.std(axis=1), cmap='magma')

    # Image objects, in panel order, used for color bar display
    imlist = [im0, im1, im2, im3, im4]

    # For each subplot, turn off the axes, create a new axis for the color bar, and add it to the figure
    for axi, im in zip(axs.ravel(), imlist):
        axi.set_axis_off()

        divider = make_axes_locatable(axi)
        cax = divider.append_axes('right', size='5%', pad=0.05)
        fig.colorbar(im, cax=cax, orientation='vertical')

    # Function to update the prediction subplot for each frame of the animation
    def animate(i):
        axs[2].set_title('Outputs {}'.format(i+1))
        # Update the existing image in place instead of stacking a new AxesImage
        # on the axes for every frame; its colour limits were fixed at creation
        im2.set_data(preds[:,i])

        return im2,

    # Set the total number of frames
    frms = num if num is not None else preds.shape[1]

    # Set the padding of the plot
    fig.tight_layout(pad=2)

    # Generate animation frames
    anim = animation.FuncAnimation(fig, animate, frames=frms, interval=100, blit=True, repeat_delay=1000)

    # Close the figure so the notebook does not also display it as a static image
    plt.close(fig)

    # Generate and return the animation output
    if output_type == 'video':
        out = anim.to_html5_video()
    else:  # 'jshtml' (validated at the top of the function)
        out = anim.to_jshtml()

    html = display.HTML(out)
    return html

## Data

In [None]:
def random_roll(a):
    """Roll every slice of ``a`` along dim 0 by an independent random shift.

    One uniform number is drawn per slice (a single ``torch.rand`` call) and
    mapped to a shift: below 0.1 gives -2, between 0.1 and 0.3 gives -1,
    between 0.7 and 0.9 gives +1, above 0.9 gives +2, and everything else
    (including the exact boundary values) gives 0. Each slice is then rolled
    with ``torch.roll`` without a ``dims`` argument, and the rolled slices are
    stacked back along dim 0.
    """
    count = a.shape[0]

    draws = torch.rand(count)
    shifts = torch.zeros_like(draws, dtype=torch.int8)

    # (mask, shift) pairs; entries matching no mask keep the default shift of 0
    shift_table = (
        (draws < 0.1, -2),
        ((draws > 0.1) & (draws < 0.3), -1),
        ((draws > 0.7) & (draws < 0.9), 1),
        (draws > 0.9, 2),
    )
    for mask, amount in shift_table:
        shifts[mask] = amount

    rolled = [
        torch.roll(piece, int(amount))
        for piece, amount in zip(a.unbind(dim=0), shifts.tolist())
    ]
    return torch.stack(rolled, dim=0)

In [None]:
# Dataset Parameters

# n is used both as the number of training examples and as the length of each
# 1-D signal (the dataset tensors are created with shape (n, n) below)
n = 32                        # Dataset Size
sigma = 0.1                   # Noise Level

# assign the color of each training example
# (`colors` is read as a global by the latent-plotting functions)
colors = get_colors(n)

In [None]:
# Create the Dataset

# Example i is a one-hot signal of length n with its peak at position i;
# its raw truth is the same signal mirrored along the signal axis.
a = torch.arange(n, device=device)

inputs = torch.zeros(n, n, device=device)
inputs[a, a] = 1

raw_truths = inputs.flip(dims=(1,))

# Insert a singleton dimension at index 1: (n, n) -> (n, 1, n)
inputs = inputs.unsqueeze(dim=1)
raw_truths = raw_truths.unsqueeze(dim=1)

In [None]:
# Generate a Realization of Truths

# Gaussian noise with standard deviation sigma, one value per element of raw_truths
noise = torch.randn(n,1,n, device=device)*sigma
unrolled_truths = raw_truths + noise
# Shift every example by an independent random offset (see random_roll)
truths = random_roll(unrolled_truths)

In [None]:
# Visualize the Dataset

# One panel per stage of the truth-generation pipeline; after squeeze() each
# tensor is an (n, n) matrix whose rows are the training examples
fig, axs = plt.subplots(1, 4, figsize=(24,6))

fig.suptitle('Training Dataset (each row is a training example)', size=14)

axs[0].set_title('Input')
axs[1].set_title('Raw Truth (without noise and unrolled)')
axs[2].set_title('Unrolled Truth (noise added)')
axs[3].set_title('Truth')

# Tensors may live on the GPU, so they are moved to the CPU before plotting
axs[0].imshow(inputs.squeeze().cpu().numpy())
axs[1].imshow(raw_truths.squeeze().cpu().numpy())
axs[2].imshow(unrolled_truths.squeeze().cpu().numpy())
axs[3].imshow(truths.squeeze().cpu().numpy())

# Hide the axes of every panel
[axi.set_axis_off() for axi in axs.ravel()]

plt.show()

In [None]:
# Display the shapes of the dataset tensors as a quick sanity check
inputs.shape, raw_truths.shape, unrolled_truths.shape, truths.shape

## Model

In [None]:
# Constant Args
# Settings shared by every experiment below; the per-experiment cells only
# override the loss-related fields.
args = DotMap()

# Seed applied to numpy and torch inside initialize()
args.random_seed = 42

## Data
# Normaliser used by record_history() to report per-pixel losses
args.pixels = n

## Model
# These fields are forwarded to the HPUNet constructor in initialize()
args.in_ch, args.out_ch = 1, 1
args.intermediate_ch = [4, 8, 8, 16, 16]
args.kernel_size = [3, 3, 3, 3, 3]
args.scale_depth, args.dilation = [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]
args.padding_mode = 'zeros' 

args.latent_num = 5
args.latent_chs = [1, 1, 1, 1, 1]
args.latent_locks = [False, False, False, False, False]

## Training
# One training step is performed per "epoch" in train()
args.epochs = 8000
args.optimizer = 'adamax'
args.wd = 1e-5
args.lr = 5e-4
# 'cons' maps to a StepLR with step_size=args.epochs in initialize()
args.scheduler_type = 'cons'

## Validation
# Validation runs every `val_period` steps and draws `k` samples per example
args.val_period = 1000
args.k = 100

## Tensorboard Session

In [None]:
%tensorboard --logdir runs/part2

## Experiments

#### **ELBO** - Constant Beta ($\beta = 1$)

In [None]:
# Stamp
# The stamp names the run directory under runs/part2/.
# NOTE: the ' / ' inside the stamp acts as a path separator, so the run is
# stored in nested directories.
stamp = 'elbo / beta = 1'

# Initialize SummaryWriter (for tensorboard)
writer = SummaryWriter('runs/part2/{}/tb'.format(stamp))

# Variable Args
## Loss
args.rec_type = 'mse'
args.loss_type = 'elbo'
args.beta = 1.0
# None makes initialize() use a constant beta
args.beta_asc_steps = None

In [None]:
# Initialize and Train the Model

model, criterion, optimizer, lr_scheduler = initialize(args)
train()

In [None]:
# Generate Samples for Evaluation

k = 100                        # Num of samples per training example

# Switch the model and the loss to evaluation mode before sampling
model.eval()
criterion.eval()

# Draw a fresh noisy, randomly shifted realization of the ground truths
noise = torch.randn(n,1,n, device=device)*sigma
unrolled_truths = raw_truths + noise
truths = random_roll(unrolled_truths)

# Gradients are not needed for sampling
# (presumably insert_from_postnet=False samples latents from the prior -- confirm in model.py)
with torch.no_grad():
    preds, infodicts = model(inputs, truths, times=k, insert_from_postnet=False)

In [None]:
# Plot or Animate Samples

# plot_sample(80, inputs, truths, preds)

html = animate_samples(inputs, truths, preds, bounds=(-0.2,1.2), num=30)
display.display(html)

In [None]:
# Visualize Latent Representations

plot_latents_or_pca(infodicts, scale=4, step=args.epochs)

#### **ELBO** - Constant Beta ($\beta = 0.7$)

In [None]:
# Stamp
# The stamp names the run directory under runs/part2/.
# NOTE: the ' / ' inside the stamp acts as a path separator, so the run is
# stored in nested directories.
stamp = 'elbo / beta = 0.7'

# Initialize SummaryWriter (for tensorboard)
writer = SummaryWriter('runs/part2/{}/tb'.format(stamp))

# Variable Args
## Loss
args.rec_type = 'mse'
args.loss_type = 'elbo'
# Matches the section heading and the run stamp (the cell previously set 0.1,
# so the run labelled "beta = 0.7" was trained with a different beta)
args.beta = 0.7
# None makes initialize() use a constant beta
args.beta_asc_steps = None

In [None]:
# Initialize and Train the Model

model, criterion, optimizer, lr_scheduler = initialize(args)
train()

In [None]:
# Generate Samples for Evaluation

k = 100                        # Num of samples per training example

# Switch the model and the loss to evaluation mode before sampling
model.eval()
criterion.eval()

# Draw a fresh noisy, randomly shifted realization of the ground truths
noise = torch.randn(n,1,n, device=device)*sigma
unrolled_truths = raw_truths + noise
truths = random_roll(unrolled_truths)

# Gradients are not needed for sampling
# (presumably insert_from_postnet=False samples latents from the prior -- confirm in model.py)
with torch.no_grad():
    preds, infodicts = model(inputs, truths, times=k, insert_from_postnet=False)

In [None]:
# Plot or Animate Samples

# plot_sample(80, inputs, truths, preds)

html = animate_samples(inputs, truths, preds, bounds=(-0.2,1.2), num=30)
display.display(html)

In [None]:
# Visualize Latent Representations

plot_latents_or_pca(infodicts, scale=4, step=args.epochs)

#### **ELBO** - Linear Beta

In [None]:
# Stamp
# The stamp names the run directory under runs/part2/.
# NOTE: the ' / ' inside the stamp acts as a path separator, so the run is
# stored in nested directories.
stamp = 'elbo / linear beta'

# Initialize SummaryWriter (for tensorboard)
writer = SummaryWriter('runs/part2/{}/tb'.format(stamp))

# Variable Args
## Loss
args.rec_type = 'mse'
args.loss_type = 'elbo'
args.beta = 1.0
# A non-None beta_asc_steps makes initialize() use BetaLinearScheduler.
# The schedule length is tied to the number of training steps instead of a
# hard-coded 8000, so it stays aligned if args.epochs is changed.
args.beta_asc_steps = args.epochs
args.beta_cons_steps = None
args.beta_saturation_step = args.epochs

In [None]:
# Initialize and Train the Model

model, criterion, optimizer, lr_scheduler = initialize(args)
train()

In [None]:
# Generate Samples for Evaluation

k = 100                        # Num of samples per training example

# Switch the model and the loss to evaluation mode before sampling
model.eval()
criterion.eval()

# Draw a fresh noisy, randomly shifted realization of the ground truths
noise = torch.randn(n,1,n, device=device)*sigma
unrolled_truths = raw_truths + noise
truths = random_roll(unrolled_truths)

# Gradients are not needed for sampling
# (presumably insert_from_postnet=False samples latents from the prior -- confirm in model.py)
with torch.no_grad():
    preds, infodicts = model(inputs, truths, times=k, insert_from_postnet=False)

In [None]:
# Plot or Animate Samples

# plot_sample(80, inputs, truths, preds)

html = animate_samples(inputs, truths, preds, bounds=(-0.2,1.2), num=30)
display.display(html)

In [None]:
# Visualize Latent Representations

plot_latents_or_pca(infodicts, scale=4, step=args.epochs)

#### **ELBO** - Cyclical Beta

In [None]:
# Stamp
# The stamp names the run directory under runs/part2/.
# NOTE: the ' / ' inside the stamp acts as a path separator, so the run is
# stored in nested directories.
stamp = 'elbo / cyclical beta'

# Initialize SummaryWriter (for tensorboard)
writer = SummaryWriter('runs/part2/{}/tb'.format(stamp))

# Variable Args
## Loss
args.rec_type = 'mse'
args.loss_type = 'elbo'
args.beta = 1.0
# A non-None beta_asc_steps makes initialize() use BetaLinearScheduler; these
# three values are passed to it as ascending_steps, constant_steps and saturation_step
args.beta_asc_steps = 100
args.beta_cons_steps = 100
args.beta_saturation_step = args.epochs

In [None]:
# Initialize and Train the Model

model, criterion, optimizer, lr_scheduler = initialize(args)
train()

In [None]:
# Generate Samples for Evaluation

k = 100                        # Num of samples per training example

# Switch the model and the loss to evaluation mode before sampling
model.eval()
criterion.eval()

# Draw a fresh noisy, randomly shifted realization of the ground truths
noise = torch.randn(n,1,n, device=device)*sigma
unrolled_truths = raw_truths + noise
truths = random_roll(unrolled_truths)

# Gradients are not needed for sampling
# (presumably insert_from_postnet=False samples latents from the prior -- confirm in model.py)
with torch.no_grad():
    preds, infodicts = model(inputs, truths, times=k, insert_from_postnet=False)

In [None]:
# Plot or Animate Samples

# plot_sample(80, inputs, truths, preds)

html = animate_samples(inputs, truths, preds, bounds=(-0.2,1.2), num=30)
display.display(html)

In [None]:
# Visualize Latent Representations

plot_latents_or_pca(infodicts, scale=4, step=args.epochs)

#### **ELBO** - Constant Beta ($\beta = 0$)

In [None]:
# Stamp
# The stamp names the run directory under runs/part2/.
# NOTE: the ' / ' inside the stamp acts as a path separator, so the run is
# stored in nested directories.
stamp = 'elbo / beta = 0'

# Initialize SummaryWriter (for tensorboard)
writer = SummaryWriter('runs/part2/{}/tb'.format(stamp))

# Variable Args
## Loss
args.rec_type = 'mse'
args.loss_type = 'elbo'
args.beta = 0.0
# None makes initialize() use a constant beta
args.beta_asc_steps = None

In [None]:
# Initialize and Train the Model

model, criterion, optimizer, lr_scheduler = initialize(args)
train()

In [None]:
# Generate Samples for Evaluation

k = 100                        # Num of samples per training example

# Switch the model and the loss to evaluation mode before sampling
model.eval()
criterion.eval()

# Draw a fresh noisy, randomly shifted realization of the ground truths
noise = torch.randn(n,1,n, device=device)*sigma
unrolled_truths = raw_truths + noise
truths = random_roll(unrolled_truths)

# Gradients are not needed for sampling
# (presumably insert_from_postnet=False samples latents from the prior -- confirm in model.py)
with torch.no_grad():
    preds, infodicts = model(inputs, truths, times=k, insert_from_postnet=False)

In [None]:
# Plot or Animate Samples

# plot_sample(80, inputs, truths, preds)

html = animate_samples(inputs, truths, preds, bounds=(-0.2,1.2), num=30)
display.display(html)

In [None]:
# Visualize Latent Representations

plot_latents_or_pca(infodicts, scale=4, step=args.epochs)

#### **GECO**

In [None]:
# Stamp
# The stamp names the run directory under runs/part2/.
# NOTE: the ' / ' inside the stamp acts as a path separator, so the run is
# stored in nested directories.
stamp = 'geco / kappa 0.01'

# Initialize SummaryWriter (for tensorboard)
writer = SummaryWriter('runs/part2/{}/tb'.format(stamp))

# Variable Args
## Loss
args.rec_type = 'mse'
args.loss_type = 'geco'
args.kappa = 0.01
# kappa is given per pixel: with kappa_px True, initialize() multiplies it by
# the pixel count (n in this notebook) before building GECOLoss
args.kappa_px = True
# decay and update_rate are passed straight through to GECOLoss
args.decay = 0.9
args.update_rate = 1.0

In [None]:
# Initialize and Train the Model

model, criterion, optimizer, lr_scheduler = initialize(args)
train()

In [None]:
# Generate Samples for Evaluation

k = 100                        # Num of samples per training example

# Switch the model and the loss to evaluation mode before sampling
model.eval()
criterion.eval()

# Draw a fresh noisy, randomly shifted realization of the ground truths
noise = torch.randn(n,1,n, device=device)*sigma
unrolled_truths = raw_truths + noise
truths = random_roll(unrolled_truths)

# Gradients are not needed for sampling
# (presumably insert_from_postnet=False samples latents from the prior -- confirm in model.py)
with torch.no_grad():
    preds, infodicts = model(inputs, truths, times=k, insert_from_postnet=False)

In [None]:
# Plot or Animate Samples

# plot_sample(80, inputs, truths, preds)

html = animate_samples(inputs, truths, preds, bounds=(-0.2,1.2), num=30)
display.display(html)

In [None]:
# Visualize Latent Representations

plot_latents_or_pca(infodicts, scale=4, step=args.epochs)