### Minimal example of a Jupyter notebook that can be run with Cromwell

### Known issue: running this notebook leaves `run2.html` and `trial_v1_movie_rec.gif` in the local working folder, outside the results folder

### IMPORT NECESSARY MODULES

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import os
import logging
import sys
from datetime import datetime
import moviepy.editor as mpy
import numpy as np
#os.environ['CUDA_LAUNCH_BLOCKING'] = "1" #for debugging, it decrease performance dramatically

In [4]:
%matplotlib inline  
#%matplotlib notebook
import matplotlib.pyplot as plt
from IPython.core.display import display, HTML, Image

In [5]:
from utilities import *
from model import * 

import torch
import pyro

# Set up the pyro environment: start from an empty parameter store and a
# fixed RNG seed so that runs are reproducible.
pyro.clear_param_store()
pyro.set_rng_seed(0)

# Check versions. The notebook was written against pyro 0.4 / torch 1.2.
# Use explicit checks instead of `assert` so the check also runs under
# `python -O`, and report the version that was actually found.
print("pyro.__version__  --> ",pyro.__version__)
print("torch.__version__ --> ",torch.__version__)
if not pyro.__version__.startswith('0.4'):
    raise RuntimeError("expected pyro 0.4.x, found " + pyro.__version__)
if not torch.__version__.startswith('1.2'):
    raise RuntimeError("expected torch 1.2.x, found " + torch.__version__)

pyro.__version__  -->  0.4.0
torch.__version__ -->  1.2.0


### Read the JSON parameter file

In [6]:
# Load the run configuration; every later cell reads its settings from `params`.
params = load_json_as_dict("./input_params.json")  
print(params)

{'use_local_machine': True, 'cloud': {'__comment': 'cloud parameters', 'VM_image': 'us.gcr.io/broad-dsde-methods/pyro@sha256:4c4745a22762852cf14263c537f645182df3557a1163527a3aaeca7e5da37b4225', 'GPU_type': 'tesla_k80', 'results_folder': 'gs://ld-results-bucket/fashionmnist_results', 'train_dataset': 'gs://ld-data-bucket/data/fashionmnist_train.pkl', 'test_dataset': 'gs://ld-data-bucket/data/fashionmnist_test.pkl'}, 'simulation': {'__comment': 'there are 3 types of runs: scratch resume pretrain', 'name': 'trial_v1', 'type': 'scratch', 'path_to_file': None}, 'architecture': {'__comment': 'parameters specifying the architecture of the model', 'dim_zwhat': 25, 'width_input_image': 28, 'ch_input_image': 1}, 'loss': {'__comment': 'parameter of the observation model', 'mse_sigma': 0.1}, 'optimizer': {'__comment': 'which optimizer to use', 'type': 'adam', 'lr': 0.001, 'betas': [0.9, 0.999], 'eps': 1e-08}, 'training': {'__comment': 'parameter of the observation model', 'EPOCHS': 1, 'TEST_FREQUE

### prepare the file names

In [7]:
if params["use_local_machine"]:

    # Inputs: keep only the last two components of the cloud path,
    # e.g. "gs://bucket/data/file.pkl" -> "./data/file.pkl".
    train_tmp = params["cloud"]["train_dataset"].split("/")
    train_file = os.path.join(".", train_tmp[-2], train_tmp[-1])
    test_tmp = params["cloud"]["test_dataset"].split("/")
    test_file = os.path.join(".", test_tmp[-2], test_tmp[-1])

    # Outputs: a local folder named after the last component of the cloud
    # results folder (e.g. "gs://bucket/fashionmnist_results" -> "fashionmnist_results").
    simulation_name = params["simulation"]["name"]
    output_dir = os.path.basename(params["cloud"]["results_folder"])
    json_param_file = os.path.join(".", output_dir, "input_params.json")
    log_file = os.path.join(".", output_dir, str(simulation_name) + ".log")

else:
    # Running directly against the cloud paths is not supported: later cells
    # write with local file APIs (logging file handler, savefig, ...), which
    # presumably cannot target "gs://" URLs. The previous (unreachable) cloud
    # branch also referenced an undefined `output_dir`; the intended output
    # folder would be params["cloud"]["results_folder"].
    raise NotImplementedError(
        "only use_local_machine=True is supported; set it in input_params.json")

print(train_file)
print(test_file)
print(log_file)
print(json_param_file)

./fashionmnist_results/input_params.json
./data/fashionmnist_train.pkl
./data/fashionmnist_test.pkl
./fashionmnist_results/trial_v1.log
./fashionmnist_results/input_params.json


### start logging some data

In [None]:
# Make sure the output folder exists before the log file is opened inside it.
os.makedirs(os.path.dirname(log_file) or ".", exist_ok=True)

# Root logger writes INFO and above to `log_file` (overwritten on each run).
# basicConfig is a no-op if the root logger already has handlers.
logging.basicConfig(level=logging.INFO,
                    format="luca_logging: %(message)s",
                    filename=log_file,
                    filemode="w")

# Also echo every record to the console. StreamHandler() writes to sys.stderr.
# The handler is named so that re-executing this cell does not attach a
# second copy (which would print every message twice).
root_logger = logging.getLogger('')
console = logging.StreamHandler()
console.name = "luca_console"
formatter = logging.Formatter("luca_logging: %(message)s")
console.setFormatter(formatter)  # Use the same format as the log file.
if not any(h.name == console.name for h in root_logger.handlers):
    root_logger.addHandler(console)

# Log the start time.
logging.info(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))

In [8]:
# Keep a copy of the run parameters alongside the results.
logging.info("saving input json in output directory")
save_dict_as_json(params, path=json_param_file)

### get the data

In [None]:
# Load the train and test sets; `use_cuda` is True whenever torch sees a GPU.
logging.info("get the data")
use_cuda = torch.cuda.is_available()
train_dataset = DatasetInMemory(train_file, use_cuda=use_cuda)
test_dataset = DatasetInMemory(test_file, use_cuda=use_cuda)

In [None]:
# Sanity-check the training set (presumably prints a summary — TODO confirm what DatasetInMemory.check reports).
train_dataset.check()

In [None]:
# Sanity-check the test set (presumably prints a summary — TODO confirm what DatasetInMemory.check reports).
test_dataset.check()

### Instantiate model and optimizer

In [None]:
# Build the VAE and its optimizer from the run parameters
# (presumably the optimizer is configured from params["optimizer"] — TODO confirm).
logging.info("Instantiate model and optimizer")
vae = VaeClass(params)
optimizer = instantiate_optimizer(vae, params)

### There are 3 possible simulation types: scratch, resumed, pretrained

In [None]:
# Initialise the training state according to the simulation type:
#   scratch    -> fresh model, empty history
#   resumed    -> restore model, optimizer, epoch counter and history
#   pretrained -> restore model weights only, fresh history
# The "__comment" in input_params.json spells these "resume"/"pretrain",
# so both spellings are accepted.
sim_type = params["simulation"]["type"]
logging.info("simulation type = "+str(sim_type))

if sim_type == "scratch":

    epoch_restart = -1
    history_dict = {}
    min_test_loss = float("inf")

elif sim_type in ("resumed", "resume"):

    resumed = load_info(path=params["simulation"]["path_to_file"],
                        load_epoch=True,
                        load_history=True)
    epoch_restart = resumed.epoch
    history_dict = resumed.history_dict
    # The checkpoint may predate the first test evaluation.
    previous_test_losses = history_dict.get("test_loss", [])
    min_test_loss = min(previous_test_losses) if previous_test_losses else float("inf")

    load_model_optimizer(path=params["simulation"]["path_to_file"],
                         model=vae,
                         optimizer=optimizer)

elif sim_type in ("pretrained", "pretrain"):

    epoch_restart = -1
    history_dict = {}
    min_test_loss = float("inf")

    load_model_optimizer(path=params["simulation"]["path_to_file"],
                         model=vae,
                         optimizer=None)

else:
    # Fail here instead of with a NameError on epoch_restart in the train loop.
    raise ValueError("unknown simulation type " + repr(sim_type) +
                     "; expected 'scratch', 'resumed' or 'pretrained'")

# instantiate the scheduler if necessary
if params["training"]["scheduler_is_active"]:
    scheduler = instantiate_scheduler(optimizer, params)

### Train loop

In [None]:
TEST_FREQUENCY = params["training"]["TEST_FREQUENCY"]
# NOTE: currently unused — a checkpoint is written at every test epoch (see below).
CHECKPOINT_FREQUENCY = params["training"]["CHECKPOINT_FREQUENCY"]
NUM_EPOCHS = params["training"]["EPOCHS"]
BATCH_SIZE = params["training"]["batch_size"]

logging.info("start training -> "+datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
try:
    for delta_epoch in range(1, NUM_EPOCHS + 1):
        # Absolute epoch index; for a scratch run (epoch_restart = -1) it starts at 0.
        epoch = delta_epoch + epoch_restart
        vae.train()

        # Switch to set_detect_anomaly(True) to debug NaNs (much slower).
        with torch.autograd.set_detect_anomaly(False):
            train_metrics = train_one_epoch(vae,
                                            train_dataset,
                                            optimizer,
                                            BATCH_SIZE,
                                            verbose=False,
                                            weight_clipper=None)
            s = pretty_print_metrics(epoch, train_metrics, is_train=True)
            logging.info(s)

            history_dict = add_named_tuple_to_dictionary(namedtuple=train_metrics,
                                                         dictionary=history_dict,
                                                         key_prefix="train_")

        if params["training"]["scheduler_is_active"]:
            scheduler.step(epoch=epoch)

        if epoch % TEST_FREQUENCY == 0:
            vae.eval()
            # NOTE(review): this reuses train_one_epoch and passes the
            # optimizer, so it presumably also updates the weights on the
            # test set. Replace with a gradient-free evaluation helper if one
            # exists.
            test_metrics = train_one_epoch(vae,
                                           test_dataset,
                                           optimizer,
                                           BATCH_SIZE,
                                           verbose=False,
                                           weight_clipper=None)

            s = pretty_print_metrics(epoch, test_metrics, is_train=False)
            logging.info(s)

            history_dict = add_named_tuple_to_dictionary(namedtuple=test_metrics,
                                                         dictionary=history_dict,
                                                         key_prefix="test_")

            # test_metrics is a namedtuple (it is passed as `namedtuple=`
            # above), so its field is read as an attribute; indexing it with
            # the string "loss" raises TypeError.
            test_loss = test_metrics.loss
            min_test_loss = min(min_test_loss, test_loss)

            # Save a checkpoint at every test epoch. The movie cell later loads
            # one checkpoint per TEST_FREQUENCY epochs. The previous condition
            # `test_loss == min_test_loss or epoch % TEST_FREQUENCY == 0` was
            # always true inside this branch, so this is unchanged behavior.
            checkpoint_file = os.path.join(output_dir, simulation_name+"_ckp_"+str(epoch)+".pkl")
            history_file = os.path.join(output_dir, simulation_name+"_history_"+str(epoch)+".pkl")

            save_everything(model=vae,
                            optimizer=optimizer,
                            history_dict=history_dict,
                            epoch=epoch,
                            params_dict=params,
                            path=checkpoint_file)

            save_dict_as_json(history_dict, path=history_file)
            logging.info("saved files -> "+checkpoint_file+"  "+history_file)

    logging.info("end training -> "+datetime.now().strftime('%Y-%m-%d %H:%M:%S'))

except KeyboardInterrupt:
    logging.info("Keyboard interrupt.  Terminated without saving.\n")

# Check the results

In [None]:
# List the recorded metric series (keys carry a "train_" or "test_" prefix).
history_dict.keys()

In [None]:
# Print the last three recorded values of every metric series.
for metric_name, series in history_dict.items():
    print(metric_name, " -->", series[-3:])

In [None]:
# Train and test loss curves (LOSS = -ELBO). Test losses are only recorded
# every TEST_FREQUENCY epochs, so their x coordinates are spaced by
# TEST_FREQUENCY.
y_shift = 0
x_shift = 0
sign = 1
fontsize = 10

train_curve = np.array(history_dict["train_loss"])
# .get: no test epoch may have run yet (e.g. training interrupted early).
test_curve = np.array(history_dict.get("test_loss", []))

fig, ax = plt.subplots(1, 1)
ax.plot(np.arange(x_shift, x_shift + len(train_curve)),
        sign * train_curve + y_shift, '-', label='train')
ax.plot(np.arange(x_shift, x_shift + len(test_curve) * TEST_FREQUENCY, TEST_FREQUENCY),
        sign * test_curve + y_shift, '.--', label='test')
ax.set_xlabel('epoch', fontsize=fontsize)
ax.set_ylabel('LOSS = - ELBO', fontsize=fontsize)
ax.set_title('Training procedure')
ax.grid(True)
# One legend entry per plotted curve. The previous list also named a
# 'test_noisy' curve that is never plotted.
ax.legend()

fig.tight_layout()
tmp_file = os.path.join(output_dir, simulation_name + "_train.png")
fig.savefig(tmp_file)

In [None]:
# Plot of KL vs evidence
fontsize=20
labelsize=20

how_many = 2000
scale= 1
N = len(history_dict["train_kl"][-how_many :])
colors = np.arange(0.0,N,1.0)/N

fig, ax = plt.subplots(1,1)
#plt.yscale('log')
#plt.xlim(xmin=1.0, xmax=1.5)
ax.set_xlabel('REC',fontsize=fontsize)
ax.set_ylabel('REG',fontsize=fontsize)
ax.tick_params(axis='both', which='major', labelsize=labelsize)
ax.scatter(history_dict["train_nll"][-how_many :], history_dict["train_kl"][-how_many :],c=colors)
ax.plot(history_dict["train_nll"][-how_many :], history_dict["train_kl"][-how_many :], '-')
ax.grid()
#plt.xlim(xmax=2.5)

fig.tight_layout()
tmp_file = os.path.join(output_dir, simulation_name+"_kl_trajectory.png")
fig.savefig(tmp_file) 

### Check reconstruction

In [None]:
# Reconstruct a fixed set of test images with the current model and save
# the reference grid and the reconstruction grid into the output folder.
reconstruction_file = os.path.join(output_dir, simulation_name + "_reconstruction.png")
reference_file = os.path.join(output_dir, simulation_name + "_reference.png")

# NOTE(review): 13 indices are listed but batch_size=9 — confirm how
# test_dataset.load combines the two arguments.
tmp_list = [291, 413, 133, 148, 1, 2, 3, 4, 5, 6, 7, 8, 9]
reference_imgs, labels = test_dataset.load(batch_size=9, indices=tmp_list)
metric, inference = vae.reconstruct_img(reference_imgs)

imgs_ref = show_batch(reference_imgs[:], n_col=3, n_padding=4, title="REFERENCE")
imgs_ref.savefig(reference_file)

imgs_rec = show_batch(inference.reconstruction, n_col=3, n_padding=4, title="REC_IMG")
imgs_rec.savefig(reconstruction_file)

display(imgs_rec, imgs_ref)

# MAKE MOVIE

In [None]:
# Preview the layout of a movie frame using the current reconstruction.
# NOTE: this overwrites the global `epoch` (last value from the training
# loop) with a placeholder string.
epoch="xxx"
a = show_batch(inference.reconstruction[:9],n_col=3,n_padding=4,title="EPOCH = "+str(epoch))
display(a)

# Actual loop: one reconstruction frame per saved checkpoint

In [None]:
# For every saved checkpoint, reload the model, reconstruct the reference
# images and save one PNG frame. The gif cell below stitches the frames
# listed in rec_filenames together in this order.
# NOTE: afterwards `vae` holds the weights of the last loaded checkpoint.
rec_filenames = []

# Last epoch index the training loop can reach (epoch = delta_epoch + epoch_restart).
last_epoch = epoch_restart + NUM_EPOCHS

for epoch in range(0, last_epoch + 1, TEST_FREQUENCY):
    # Zero-padded epoch label so frame files sort in epoch order, e.g. "_0040".
    label = "_{:04d}".format(epoch)

    checkpoint_file = os.path.join(output_dir, simulation_name+"_ckp_"+str(epoch)+".pkl")
    if not os.path.isfile(checkpoint_file):
        # No checkpoint for this epoch (e.g. training stopped earlier).
        continue

    tmp_rec_file = os.path.join(output_dir, simulation_name+label+"_rec.png")
    try:
        _ = load_model_optimizer(path=checkpoint_file, model=vae, optimizer=None)
        metric, inference = vae.reconstruct_img(reference_imgs)
        tmp_fig = show_batch(inference.reconstruction[:8],n_col=4,n_padding=4,title="EPOCH = "+str(epoch))
        tmp_fig.savefig(tmp_rec_file, bbox_inches='tight')
    except Exception:
        # Best effort: skip frames that cannot be produced, but record why
        # (a bare except would also have swallowed KeyboardInterrupt).
        logging.exception("could not make movie frame for epoch %d", epoch)
        continue

    # Append only after the frame was written, so the gif never
    # references a missing file.
    rec_filenames.append(tmp_rec_file)

print(rec_filenames)

## Check individual images

In [None]:
def show_frame_rec(n):
    """Display the n-th saved reconstruction frame inline.

    Returns whatever IPython's display() returns.
    """
    frame_path = rec_filenames[n]
    return display(Image(filename=frame_path))

def show_frame_all(n):
    """Display the n-th saved frame inline.

    This was a verbatim copy of show_frame_rec (both read rec_filenames), so it
    now delegates to it. Point it at a separate frame list if an "all" movie
    is ever added.
    """
    return show_frame_rec(n)

In [None]:
# make a gif file
#name_movie = "baseline_new_loss_v2.gif"
movie_rec_file_local = os.path.join("./", simulation_name+"_movie_rec.gif")
movie_rec_file_absolute = os.path.join(output_dir, simulation_name+"_movie_rec.gif")

frame_per_second = 2
im = mpy.ImageSequenceClip(rec_filenames, fps=frame_per_second)
im.write_gif(movie_rec_file_local, fps=frame_per_second)
im.write_gif(movie_rec_file_absolute, fps=frame_per_second)

In [None]:
# Show the gif inline. The path is quoted and HTML-escaped so that file names
# with spaces or special characters still render; <img> is a void element,
# so it takes no closing tag.
import html
HTML('<img src="' + html.escape(movie_rec_file_local, quote=True) + '">')

In [None]:
# Display the first saved reconstruction frame.
show_frame_rec(0)

In [None]:
# Reference images in the same 8-image, 4-column layout as the movie frames, for comparison.
show_batch(reference_imgs[:8],n_col=4,n_padding=4,title="REFERENCE")