Uncomment and run the following cells to clone the repository and install its requirements:

In [65]:
# %cd /content
# !git clone https://github.com/mhsotoudeh/ProbUNet-Tutorial.git
# # !export PYTHONPATH="${PYTHONPATH}:$PWD/ProbUNet-Tutorial"
# %cd /content/ProbUNet-Tutorial

In [66]:
# !pip install -r requirements.txt

In [137]:
%load_ext autoreload
%autoreload 2

%load_ext tensorboard

from data import *
from model import *

import json

from IPython import display
from tqdm.notebook import tqdm_notebook

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from mpl_toolkits.axes_grid1 import make_axes_locatable

import numpy as np
import torch
from torch.utils.data import DataLoader

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [68]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Device is {}".format(device))

Device is cuda


## Part 1: Using the Hierarchical Probabilistic U-Net for Source Reconstruction

### TensorBoard Session

In [69]:
# %tensorboard --logdir runs/part1

### Train the Model

In [70]:
# !./run.sh

In [150]:
def _prepare_example_figure(ex, observations, truths, preds, bounds, common_pred_colormap, figsize, pred_idx):
    """Draw the five-panel summary figure shared by `plot_sample` and `animate_samples`.

    Panels, left to right: observation, ground truth, prediction `pred_idx`,
    per-pixel mean of all predictions, per-pixel standard deviation of all
    predictions. Every panel has its axis hidden and gets its own colour bar.

    Args:
        ex: Index of the example within the batch.
        observations, truths: Batched tensors; `observations[ex]` and `truths[ex]`
            are squeezed before plotting (assumed to become 2-D images — TODO confirm
            for multi-channel data).
        preds: Batched tensor; `preds[ex]` stacks the sampled predictions along dim 0.
        bounds: Optional `(vmin, vmax)` for the observation/truth panels. When None,
            the range spans the observation, the truth and all predictions of `ex`.
        common_pred_colormap: If truthy, the prediction and mean panels reuse the
            observation/truth range; otherwise they span the predictions' own min/max.
        figsize: Matplotlib figure size.
        pred_idx: Index of the prediction drawn in the third panel.

    Returns:
        Tuple `(fig, axs, pred_image, preds_np)` where `pred_image` is the AxesImage of
        the third panel and `preds_np` is `preds[ex]` as a NumPy array.
    """
    observation_t, truth_t, preds_t = observations[ex], truths[ex], preds[ex]
    observation = observation_t.squeeze().cpu().numpy()
    truth = truth_t.squeeze().cpu().numpy()
    preds_np = preds_t.cpu().numpy()

    if bounds is None:
        # One shared range over everything drawn on the common colour scale.
        # .item() turns the 0-d (possibly CUDA) tensors into plain floats for matplotlib.
        _all = torch.cat([observation_t, truth_t, preds_t], dim=0)
        _min, _max = _all.min().item(), _all.max().item()
    else:
        _min, _max = bounds

    if common_pred_colormap:
        pred_min, pred_max = _min, _max
    else:
        pred_min, pred_max = preds_np.min(), preds_np.max()

    fig, axs = plt.subplots(1, 5, figsize=figsize)
    fig.suptitle('Training Example {}'.format(ex + 1), size=14)

    titles = ['Observation', 'Ground Truth', 'Prediction {}'.format(pred_idx + 1), 'Mean', 'STD']
    images = [
        axs[0].imshow(observation, vmin=_min, vmax=_max),
        axs[1].imshow(truth, vmin=_min, vmax=_max),
        axs[2].imshow(preds_np[pred_idx], vmin=pred_min, vmax=pred_max),
        axs[3].imshow(preds_np.mean(axis=0), vmin=pred_min, vmax=pred_max),
        # The STD panel keeps matplotlib's default autoscaled colour range.
        axs[4].imshow(preds_np.std(axis=0)),
    ]
    for ax, title, image in zip(axs, titles, images):
        ax.set_title(title)
        ax.set_axis_off()
        cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
        fig.colorbar(image, cax=cax, orientation='vertical')

    return fig, axs, images[2], preds_np


def plot_sample(ex, idx, observations, truths, preds, bounds=None, common_pred_colormap=True, show=False):
    """Plot one example together with its `idx`-th prediction and the prediction mean/STD.

    Args:
        ex: Index of the example within the batch.
        idx: Index of the prediction to show in the "Prediction" panel.
        observations, truths, preds: Batched tensors indexed by `ex`; `preds[ex]`
            stacks the sampled predictions along dim 0.
        bounds: Optional `(vmin, vmax)` for the observation/truth panels.
        common_pred_colormap: If truthy, the prediction and mean panels share the
            observation/truth colour range.
        show: If truthy, call `plt.show()` before the figure is closed.

    Returns:
        The matplotlib Figure (already closed in pyplot).
    """
    fig, _, _, _ = _prepare_example_figure(
        ex, observations, truths, preds, bounds, common_pred_colormap,
        figsize=(25, 5), pred_idx=idx)

    if show:
        plt.show()

    # Release this figure from pyplot's state; the caller still receives the object.
    plt.close(fig)
    return fig


def animate_samples(ex, observations, truths, preds, bounds=None, common_pred_colormap=True, num=None, output_type='jshtml'):
    """Animate the sampled predictions of one example next to its summary panels.

    The figure matches `plot_sample`, except that the third panel cycles through
    predictions 1..num (100 ms per frame, 1000 ms pause before repeating).

    Args:
        ex, observations, truths, preds, bounds, common_pred_colormap: As in
            `plot_sample`; `preds.shape[1]` is taken as the number of predictions
            per example.
        num: Number of predictions to animate, 1 <= num <= preds.shape[1];
            None animates all of them.
        output_type: 'jshtml' (`FuncAnimation.to_jshtml`) or 'video'
            (`FuncAnimation.to_html5_video`).

    Returns:
        An IPython `display.HTML` object containing the rendered animation.

    Raises:
        ValueError: If `num` is out of range or `output_type` is not recognised.
    """
    # Validate before any rendering. `preds` is still the full batch here, so dim 1
    # is the number of predictions per example.
    num_preds = preds.shape[1]
    if num is None:
        num = num_preds
    elif not 1 <= num <= num_preds:
        raise ValueError('num must be in [1, {}], got {}'.format(num_preds, num))
    if output_type not in ('jshtml', 'video'):
        raise ValueError("output_type must be 'jshtml' or 'video', got {!r}".format(output_type))

    fig, axs, pred_image, preds_np = _prepare_example_figure(
        ex, observations, truths, preds, bounds, common_pred_colormap,
        figsize=(21.5, 4.3), pred_idx=0)

    def animate(i):
        # Update the existing image in place instead of stacking a new AxesImage per frame.
        pred_image.set_data(preds_np[i])
        title = axs[2].set_title('Prediction {}'.format(i + 1))
        return pred_image, title

    fig.tight_layout(pad=3)
    anim = animation.FuncAnimation(fig, animate, frames=num, interval=100, blit=True, repeat_delay=1000)
    plt.close(fig)

    out = anim.to_html5_video() if output_type == 'video' else anim.to_jshtml()
    return display.HTML(out)

### Visualize Samples

In [209]:
# Reproducibility: fix the NumPy and PyTorch RNGs before any predictions are sampled.
np.random.seed(0)
torch.manual_seed(0)

# Default colormap for every imshow call below.
mpl.rc('image', cmap='hot')

# Sampling configuration
bs = 32   # batch size of the test DataLoader
k = 100  # Num of Predictions per Input

# Checkpoint to evaluate: <model_dir>/<model_stamp>/model<model_suffix>.pth
model_dir = 'runs/part1/16'
model_stamp = '0219-1042_bg_559622_lensing_16_16-32-64-128_geco_lr1e-4_kappa7e-4_update10.0'
model_suffix = ''

# Test data file: <data_dir>/<dataset_name>.npy
data_dir = 'utils/32768-12288-2899'
dataset_name = '16_test'

# Target device for the model and the data
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device is {}".format(device))

Device is cuda


In [210]:
# The model, loss and args.json of the selected run share one directory.
run_dir = '{}/{}'.format(model_dir, model_stamp)
map_location = torch.device(device)

# Load Model & Loss
# NOTE(review): torch.load unpickles whole objects, which can execute arbitrary code;
# only load checkpoints you trust. PyTorch >= 2.6 defaults to weights_only=True and may
# refuse full-model pickles; pass weights_only=False explicitly if loading fails.
model = torch.load('{}/model{}.pth'.format(run_dir, model_suffix), map_location=map_location)
model.eval()

criterion = torch.load('{}/loss{}.pth'.format(run_dir, model_suffix), map_location=map_location)
criterion.eval()

# Load the arguments saved with the run
with open('{}/args.json'.format(run_dir), 'r') as f:
    args = json.load(f)

# Load Data (unshuffled, so batches come out in file order)
test_data, transdict = prepare_data('{}/{}.npy'.format(data_dir, dataset_name), normalization=None)
test_loader = DataLoader(test_data, batch_size=bs, shuffle=False)

In [211]:
# Take the first test batch and move both tensors to the selected device.
data_iterator = iter(test_loader)
observations, truths = (tensor.to(device) for tensor in next(data_iterator))

# Draw k predictions per input without tracking gradients.
# NOTE(review): insert_from_postnet=False presumably samples from the prior rather
# than the posterior net; confirm in model.py.
with torch.no_grad():
    preds, infodicts = model(
        observations,
        truths,
        times=k,
        insert_from_postnet=False,
    )

In [212]:
# Sanity checks: one stack of k predictions per input, on the same spatial grid as the truth.
assert observations.shape[0] == truths.shape[0] == preds.shape[0]
assert preds.shape[1] == k
assert preds.shape[2:] == truths.shape[2:]

observations.shape, truths.shape, preds.shape

(torch.Size([32, 1, 16, 16]),
 torch.Size([32, 1, 16, 16]),
 torch.Size([32, 100, 16, 16]))

In [171]:
# Animate the first 30 sampled predictions for the example at batch index 12.
# Static alternative showing a single prediction:
#   plot_sample(12, 17, observations, truths, preds, common_pred_colormap=True)
# Recorded run stamp (NOTE(review): differs from model_stamp above; confirm which
# checkpoint produced this output):
#   0219-1141_bg_849458_lensing_32_16-32-64-64-64_geco_lr1e-4_kappa3e-4_update1.0
html = animate_samples(ex=12, observations=observations, truths=truths, preds=preds,
                       num=30, output_type='jshtml')
display.display(html)

In [176]:
# ELBO
# NOTE(review): on a fresh Restart & Run All this re-animates the same `preds` as the
# cell above (checkpoint `model_stamp`); the 'ELBO' label presumably refers to a
# different checkpoint loaded out of order. Confirm or remove.
html = animate_samples(ex=12, observations=observations, truths=truths, preds=preds,
                       num=30, output_type='jshtml')
display.display(html)

In [203]:
# ELBO
# NOTE(review): identical call to the cells above; on a fresh run it shows the same
# `preds`. The 'ELBO' label presumably refers to a different checkpoint loaded out of
# order. Confirm or remove.
html = animate_samples(ex=12, observations=observations, truths=truths, preds=preds,
                       num=30, output_type='jshtml')
display.display(html)

In [213]:
# Animate the first 30 sampled predictions for the example at batch index 12.
html = animate_samples(ex=12, observations=observations, truths=truths, preds=preds,
                       num=30, output_type='jshtml')
display.display(html)