<font size="+1">
<font color='red'>
<b> IMPORTANT NOTE: </b> 
</font>
Make sure to save a copy of this notebook in your personal drive to maintain the changes you make!
</font>

Run the following cells to prepare your working environment:

In [None]:
#@title Clone Repo & Install Requirements { display-mode: "form" }
%%capture

# Clone Repo
%cd /content
!git clone https://github.com/mhsotoudeh/ProbUNet-Tutorial.git
# !export PYTHONPATH="${PYTHONPATH}:$PWD/ProbUNet-Tutorial"
%cd /content/ProbUNet-Tutorial

# Install Requirements
!pip install -r requirements.txt

In [None]:
#@title Imports { display-mode: "form" }

%load_ext autoreload
%autoreload 2

%load_ext tensorboard

from data import *
from model import *

import json
import os

import gdown

from IPython import display
from tqdm.notebook import tqdm_notebook

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from mpl_toolkits.axes_grid1 import make_axes_locatable

import numpy as np
import torch
from torch.utils.data import DataLoader

In [None]:
# Prefer the GPU when PyTorch can see one; the model and data below are placed on `device`
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Device is {device}")

# Part 1: Using Hierarchical Probabilistic U-Net for Source Reconstruction

## Tensorboard Session

In [None]:
# Optional: uncomment the line below to open TensorBoard on the logs under runs/part1
# and watch training progress (uses the `tensorboard` extension loaded with %load_ext)
# %tensorboard --logdir runs/part1

## Data

Select the desired resolution (16, 32 or 64) and run the following cell.

In [None]:
#@title Download { display-mode: "form" }
resolution = '16' #@param ['16', '32', '64']

# Create data directory
if not os.path.exists('data'):
    os.makedirs('data')

# Set File ID
if resolution == '16':
    file_id = '1p8uCmPFvC4KWFVSjdUyqafQ7niUHmUDc&confirm=t'
elif resolution == '32':
    file_id = '1KIOwnlnwcwc76G-VFA3Y8-nvpMhDayVr&confirm=t'
elif resolution == '64':
    file_id = '1WhD8JiZ2bty1Pq_T7hg-oX8jB5Q_chx_&confirm=t'

# Download File
gdown.download(id=file_id, output='data/data.zip', quiet=False)

# Unzip File
!unzip data/data.zip -d data/

## Training

In [None]:
# First, set your desired dataset, model architecture and training hyperparameters in a run script
# (e.g. training_scripts/run.sh). Then execute it using the following command to start training :-)

# !./training_scripts/run.sh

## Visualizing Samples

In [None]:
#@title Required Functions { display-mode: "form" }

# Plot one prediction (idx) of one example (ex) from a batch, next to the observation,
# the ground truth and the per-pixel mean / std over all predictions of that example.
def plot_sample(ex, idx, observations, truths, preds, bounds=None, common_pred_colormap=True, show=False):
    """Render a 1x5 panel figure for a single batch example.

    Args:
        ex: index of the example within the batch.
        idx: index of the prediction shown in the "Prediction" panel.
        observations, truths: batched tensors indexed by `ex`; each example is
            squeezed before display (assumed to yield a 2-D image).
        preds: batched predictions; `preds[ex]` is indexed by prediction along
            its first axis, and the mean/std panels reduce over that axis.
        bounds: optional (vmin, vmax) colour limits. By default they are the
            min/max over the example's observation, truth and all predictions.
        common_pred_colormap: if True, the prediction and mean panels share those
            limits; otherwise they use the min/max of the predictions alone.
        show: if True, call `plt.show()` before the figure is closed.

    Returns:
        The (closed) matplotlib Figure.
    """
    obs_t, truth_t, pred_t = observations[ex], truths[ex], preds[ex]

    # One tensor holding everything, so colour limits can span all panels
    stacked = torch.cat([obs_t, truth_t, pred_t], dim=0)

    observation = obs_t.squeeze().cpu().numpy()
    truth = truth_t.squeeze().cpu().numpy()
    samples = pred_t.cpu().numpy()

    _min, _max = bounds if bounds is not None else (stacked.min(), stacked.max())

    if common_pred_colormap is True:
        pred_min, pred_max = _min, _max
    else:
        pred_min, pred_max = samples.min(), samples.max()

    fig, axs = plt.subplots(1, 5, figsize=(25,5))
    fig.suptitle('Training Example {}'.format(ex+1), size=14)

    # (title, image, imshow colour limits) for each panel, left to right;
    # the STD panel keeps matplotlib's automatic scaling
    panels = [
        ('Observation', observation, dict(vmin=_min, vmax=_max)),
        ('Ground Truth', truth, dict(vmin=_min, vmax=_max)),
        ('Prediction {}'.format(idx+1), samples[idx], dict(vmin=pred_min, vmax=pred_max)),
        ('Mean', samples.mean(axis=0), dict(vmin=pred_min, vmax=pred_max)),
        ('STD', samples.std(axis=0), {}),
    ]

    for ax, (title, image, limits) in zip(axs, panels):
        ax.set_title(title)
        shown = ax.imshow(image, **limits)
        ax.set_axis_off()
        # Attach a slim colour bar to the right of each panel
        cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
        fig.colorbar(shown, cax=cax, orientation='vertical')

    if show is True:
        plt.show()

    # Close so the notebook does not render the figure automatically; the caller gets it back
    plt.close()
    return fig


# Create an animation of "num" predictions of a given example (ex) in the dataset
def animate_samples(ex, observations, truths, preds, bounds=None, common_pred_colormap=True, num=None, output_type='jshtml'):
    """Animate the predictions of one batch example next to its inputs.

    The figure has five panels: observation, ground truth, the current
    prediction (changes every frame), and the per-pixel mean and std over all
    predictions of the example.

    Args:
        ex: index of the example within the batch.
        observations, truths: batched tensors indexed by `ex`; each example is
            squeezed before display (assumed to yield a 2-D image).
        preds: batched predictions whose second axis (`preds.shape[1]`) indexes
            the predictions of each example.
        bounds: optional (vmin, vmax) colour limits. By default they are the
            min/max over the example's observation, truth and all predictions.
        common_pred_colormap: if True, the prediction and mean panels share those
            limits; otherwise they use the min/max of the predictions alone.
        num: number of predictions to animate (default: all of them).
        output_type: 'jshtml' (JavaScript player) or 'video' (HTML5 video,
            which typically needs ffmpeg installed).

    Returns:
        An `IPython.display.HTML` object wrapping the animation.

    Raises:
        AssertionError: if `num` exceeds the number of available predictions.
        ValueError: if `output_type` is neither 'jshtml' nor 'video'.
    """
    if output_type not in ('video', 'jshtml'):
        raise ValueError("output_type must be 'video' or 'jshtml', got {!r}".format(output_type))

    # Number of predictions per example. Read it BEFORE `preds` is rebound to this
    # example's (num_preds, H, W) array below; after that, shape[1] would be the image height.
    n_preds = preds.shape[1]

    # Make sure the number of predictions to display does not exceed what is available
    if num is not None:
        assert num <= n_preds, 'num={} exceeds the {} available predictions'.format(num, n_preds)

    # Concatenate the observation, ground truth and predictions (to compute common colour limits)
    _all = torch.cat([observations[ex], truths[ex], preds[ex]], dim=0)

    # Convert torch tensors to numpy arrays
    observation = observations[ex].squeeze().cpu().numpy()
    truth = truths[ex].squeeze().cpu().numpy()
    preds = preds[ex].cpu().numpy()

    # Determine colour axis limits
    if bounds is None:
        _min, _max = _all.min(), _all.max()
    else:
        _min, _max = bounds
    if common_pred_colormap is True:
        pred_min, pred_max = _min, _max
    else:
        pred_min, pred_max = preds.min(), preds.max()

    # Create a figure with five subplots
    fig, axs = plt.subplots(1, 5, figsize=(21.5,4.3))

    # Set the titles for each subplot
    fig.suptitle('Training Example {}'.format(ex+1), size=14)
    axs[0].set_title('Observation')
    axs[1].set_title('Ground Truth')
    axs[2].set_title('Prediction 1')
    axs[3].set_title('Mean')
    axs[4].set_title('STD')

    # Display the observation, ground truth, first prediction, mean and std maps
    im0 = axs[0].imshow(observation, vmin=_min, vmax=_max)
    im1 = axs[1].imshow(truth, vmin=_min, vmax=_max)
    im2 = axs[2].imshow(preds[0], vmin=pred_min, vmax=pred_max)
    im3 = axs[3].imshow(preds.mean(axis=0), vmin=pred_min, vmax=pred_max)
    im4 = axs[4].imshow(preds.std(axis=0), cmap='viridis')

    # For each subplot, turn off the axes and attach a colour bar on its right
    for axi, im in zip(axs.ravel(), [im0, im1, im2, im3, im4]):
        axi.set_axis_off()

        divider = make_axes_locatable(axi)
        cax = divider.append_axes('right', size='5%', pad=0.05)
        fig.colorbar(im, cax=cax, orientation='vertical')

    # Update the prediction panel for frame i. The existing image is updated in place
    # (instead of a new imshow per frame, which would stack one image per frame on the axes).
    def animate(i):
        axs[2].set_title('Prediction {}'.format(i+1))
        im2.set_data(preds[i])
        return im2,

    # Total number of frames: `num` if given, otherwise every prediction of the example
    frms = num if num is not None else n_preds

    # Set the padding of the plot
    fig.tight_layout(pad=2)

    # Generate animation frames
    anim = animation.FuncAnimation(fig, animate, frames=frms, interval=100, blit=True, repeat_delay=1000)

    # Close the figure so the notebook does not also render it as a static image
    plt.close(fig)

    # Generate and return the animation output
    if output_type == 'video':
        out = anim.to_html5_video()
    else:
        out = anim.to_jshtml()

    return display.HTML(out)

In [None]:
#@title Set Parameters + Select Model & Dataset { display-mode: "both" }

# Reproducibility: seed PyTorch and NumPy before any sampling happens
torch.manual_seed(0)
np.random.seed(0)

# Plotting default: every imshow without an explicit cmap uses 'hot'
mpl.rcParams['image.cmap'] = 'hot'

# Sampling parameters
bs = 128   # examples per batch
k = 100    # predictions drawn per input

# NOTE(review): presumably keep the model/dataset below in sync with the `resolution`
# picked in the Download cell — confirm which pretrained stamps exist for 32/64.

# Pretrained model, read from <model_dir>/<model_stamp>/model<model_suffix>.pth
model_dir, model_stamp, model_suffix = 'pretrained_models', '16_elbo', ''

# Evaluation data, read from <data_dir>/<dataset_name>.npy
data_dir, dataset_name = 'data', '16_test'

In [None]:
#@title Load Model & Data { display-mode: "both" }
import inspect

# These .pth files contain whole pickled nn.Module objects (they are called and put in
# eval mode below), not state dicts. PyTorch >= 2.6 defaults to torch.load(weights_only=True),
# which refuses to unpickle them, so opt out explicitly when the argument exists (PyTorch >= 1.13).
# NOTE: full unpickling can execute arbitrary code — only load checkpoints you trust.
load_kwargs = {'map_location': torch.device(device)}
if 'weights_only' in inspect.signature(torch.load).parameters:
    load_kwargs['weights_only'] = False

# Load Model & Loss, both switched to eval mode for inference
model = torch.load('{}/{}/model{}.pth'.format(model_dir, model_stamp, model_suffix), **load_kwargs)
model.eval()

criterion = torch.load('{}/{}/loss{}.pth'.format(model_dir, model_stamp, model_suffix), **load_kwargs)
criterion.eval()


# Load Args (presumably the arguments the checkpoint was trained with)
with open('{}/{}/args.json'.format(model_dir, model_stamp), 'r') as f:
    args = json.load(f)


# Load Data (shuffle=False keeps file order, so the first batch is always the same)
test_data, transdict = prepare_data('{}/{}.npy'.format(data_dir, dataset_name), normalization=None)
test_loader = DataLoader(test_data, batch_size=bs, shuffle=False)

In [None]:
#@title Generate a Batch of Predictions { display-mode: "both" }

# Take the first batch of the (unshuffled) test set and move it to the compute device
data_iterator = iter(test_loader)
observations, truths = [tensor.to(device) for tensor in next(data_iterator)]

# Draw k predictions per input, first with PriorNet latents, then with PosteriorNet latents
with torch.no_grad():
    preds_prior, infodicts_prior = model(observations, truths, times=k, insert_from_postnet=False)
    preds_post, infodicts_post = model(observations, truths, times=k, insert_from_postnet=True)

# Sanity check: shapes of the inputs and of both prediction tensors
(observations.shape, truths.shape, preds_prior.shape, preds_post.shape)

In [None]:
#@title Visualize Predictions (Using PriorNet Latents) { display-mode: "both" }

# Alternative: a static figure of a single prediction instead of an animation
# plot_sample(ex=12, idx=17, observations=observations, truths=truths,
#             preds=preds_prior, common_pred_colormap=True)

# `ex` picks the example within the batch: any index from 0 to 127 (bs - 1)
html = animate_samples(
    ex=12,
    observations=observations,
    truths=truths,
    preds=preds_prior,
    num=30,
    output_type='jshtml',
)
display.display(html)

In [None]:
#@title Visualize Predictions (Using PosteriorNet Latents) { display-mode: "both" }

# Alternative: a static figure of a single prediction instead of an animation
# plot_sample(ex=12, idx=17, observations=observations, truths=truths,
#             preds=preds_post, common_pred_colormap=True)

# `ex` picks the example within the batch: any index from 0 to 127 (bs - 1)
html = animate_samples(
    ex=12,
    observations=observations,
    truths=truths,
    preds=preds_post,
    num=30,
    output_type='jshtml',
)
display.display(html)