### MINIMAL EXAMPLE OF A JUPYTER NOTEBOOK THAT CAN BE RUN WITH CROMWELL

### IMPORT NECESSARY MODULES

In [None]:
# Automatically reload all imported modules (e.g. utilities, model) before running each cell,
# so edits to those files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
#os.environ['CUDA_LAUNCH_BLOCKING'] = "1" #for debugging, it decrease performance dramatically

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

from utilities import show_batch, save_obj, load_obj 
from utilities import load_json_as_dict, save_dict_as_json, DatasetInMemory
from utilities import train_one_epoch, evaluate_one_epoch
from model import VaeClass 

#!pip install moviepy
from IPython.display import Image, display, HTML
import moviepy.editor as mpy
import numpy as np
import torch
import pyro

# Set up pyro environment: drop any parameters left in Pyro's global param store
# (e.g. from a previous execution of this cell) and seed the RNGs via pyro.set_rng_seed.
pyro.clear_param_store()
pyro.set_rng_seed(0)

# Check versions. The notebook pins pyro 0.4.x and torch 1.2.x; use explicit raises
# rather than `assert` so the check cannot be stripped when Python runs with -O.
print("pyro.__version__  --> ", pyro.__version__)
print("torch.__version__ --> ", torch.__version__)
if not pyro.__version__.startswith('0.4'):
    raise RuntimeError("expected pyro 0.4.x, found " + pyro.__version__)
if not torch.__version__.startswith('1.2'):
    raise RuntimeError("expected torch 1.2.x, found " + torch.__version__)

### Read the JSON parameter file

In [None]:
params = load_json_as_dict("./input_params.json")  
print(params)

### get the data

In [None]:
# Locate the input data and output directory. Only the local-machine layout is implemented.
# Both directories keep their trailing "/" because later cells build paths by plain
# string concatenation (output_dir + name).
local_machine = True
if local_machine:
    output_dir = "/Users/ldalessi/cromwell_for_ML/RESULTS/"
    input_dir = "/Users/ldalessi/cromwell_for_ML/DATA/"
    # Every later cell writes into output_dir, so make sure it exists up front.
    os.makedirs(output_dir, exist_ok=True)
    train_file = input_dir + str(params["cloud"]["train_dataset"]) + ".pkl"
    test_file = input_dir + str(params["cloud"]["test_dataset"]) + ".pkl"
    use_cuda = torch.cuda.is_available()
    train_dataset = DatasetInMemory(train_file, use_cuda=use_cuda)
    test_dataset = DatasetInMemory(test_file, use_cuda=use_cuda)
else:
    raise NotImplementedError(
        "only local_machine=True is supported; define input_dir/output_dir for this environment"
    )

In [None]:
# Run DatasetInMemory.check() on the training set (what it verifies is defined in utilities);
# as the cell's last expression, its return value is displayed by Jupyter.
train_dataset.check()

In [None]:
# Run DatasetInMemory.check() on the test set (what it verifies is defined in utilities);
# as the cell's last expression, its return value is displayed by Jupyter.
test_dataset.check()

# Instantiate everything

In [None]:
# Build (or restore) the model, optimizer, training history and LR scheduler.
# NOTE(review): `instantiate_optimizer`, `load_everything` and `save_everything` are used in
# this notebook but are not among the names imported from `utilities` at the top;
# presumably they live there — confirm and add them to the import cell.
if params["run_type"]["from_scratch"]:
    # Fresh run: epoch_restart = -1 makes the first trained epoch 0 in the training loop.
    epoch_restart = -1
    min_loss = 99999999
    history_dict = {}
    vae = VaeClass(params)
    optimizer = instantiate_optimizer(vae, params)
else:
    # Resume: restore params, epoch counter, history, model and optimizer from a checkpoint.
    # TODO(review): the third argument should be the checkpoint path (cf. the
    # load_everything(vae, optimizer, <path>) call in the movie loop), but this passes the
    # from_scratch flag, which is falsy in this branch. Replace it with the real path.
    resumed = load_everything(None, None, params["run_type"]["from_scratch"])

    params = resumed.params
    epoch_restart = resumed.epoch
    history_dict = resumed.history_dict
    vae = resumed.model
    optimizer = resumed.optimizer
    # Best test loss recorded so far, so "save on new minimum" in the training loop keeps
    # working after a restart (it needs min_loss to be defined).
    min_loss = min(history_dict.get("test_loss", []), default=99999999)

descriptor = params["run_type"]["identifier"]
name_vae = descriptor + "_vae"
name_history = descriptor + "_hystory"  # (sic) spelling kept: it is part of existing output file names

save_dict_as_json(params, output_dir + "input_params.json")

if params["training"]["scheduler_is_active"]:
    # Step-decay the learning rate; the training loop passes the absolute epoch to step().
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=params["training"]["step_size"],
                                                gamma=params["training"]["gamma"],
                                                last_epoch=-1)

### Train loop

In [None]:
# Training loop: train one epoch at a time, evaluate every TEST_FREQUENCY epochs, and
# checkpoint on a new best test loss or every WRITE_FREQUENCY epochs.
TEST_FREQUENCY = params["training"]["TEST_FREQUENCY"]
WRITE_FREQUENCY = params["training"]["WRITE_FREQUENCY"]
NUM_EPOCHS = params["training"]["EPOCHS"]
BATCH_SIZE = params["training"]["batch_size"]

for delta_epoch in range(1, NUM_EPOCHS + 1):
    # Absolute epoch index; continues the numbering after epoch_restart when resuming.
    epoch = delta_epoch + epoch_restart
    vae.train()

    # Anomaly detection stays off: switch to True only to debug NaNs (it is much slower).
    with torch.autograd.set_detect_anomaly(False):
        metric_av = train_one_epoch(vae, train_dataset, optimizer, BATCH_SIZE, verbose=False, weight_clipper=None)
    if params["training"]["scheduler_is_active"]:
        scheduler.step(epoch=epoch)

    for k, v in metric_av.items():
        history_dict.setdefault("train_" + k, []).append(v)

    print("[epoch %03d] train loss: %.4f NLL: %.4f KL %.4f" % (epoch, history_dict["train_loss"][-1], history_dict["train_nll"][-1], history_dict["train_kl"][-1]))

    if epoch % TEST_FREQUENCY == 0:
        vae.eval()
        # Evaluate on the held-out test set without an optimizer step.
        # NOTE(review): assumes evaluate_one_epoch(model, dataset, batch_size, verbose=...)
        # mirrors train_one_epoch minus the optimizer — confirm the signature in utilities.
        metric_av = evaluate_one_epoch(vae, test_dataset, BATCH_SIZE, verbose=False)
        test_loss = metric_av["loss"]
        # Explicit comparison (not min()) so a NaN test loss never replaces the best value.
        min_loss = test_loss if test_loss < min_loss else min_loss
        for k, v in metric_av.items():
            history_dict.setdefault("test_" + k, []).append(v)

        if test_loss == min_loss or epoch % WRITE_FREQUENCY == 0:
            save_everything(vae, optimizer, history_dict, epoch, params, output_dir + name_vae + "_" + str(epoch) + ".pkl")
            # save_dict_as_json writes JSON, so give the file a .json extension.
            save_dict_as_json(history_dict, output_dir + name_history + "_" + str(epoch) + ".json")

# Check the results

In [None]:
# Show which metric series were recorded ("train_*" every epoch, "test_*" on test epochs).
history_dict.keys()

In [None]:
# Print the three most recent values of every recorded metric series.
for metric_name in history_dict:
    print(metric_name, " -->", history_dict[metric_name][-3:])

In [None]:
# Plot the loss curves: train loss for every epoch, test loss every TEST_FREQUENCY epochs.
# x_shift / y_shift / sign are manual knobs for re-plotting (e.g. sign=-1 plots the ELBO).
#plt.yscale('log')
y_shift = 0
x_shift = 0
sign = 1
plt.plot(np.arange(x_shift, x_shift + len(history_dict["train_loss"])),
         sign * np.array(history_dict["train_loss"]) + y_shift, '-')
plt.plot(np.arange(x_shift, x_shift + len(history_dict["test_loss"]) * TEST_FREQUENCY, TEST_FREQUENCY),
         sign * np.array(history_dict["test_loss"]) + y_shift, '.--')
plt.xlabel('epoch')
plt.ylabel('LOSS = - ELBO')
plt.title('Training procedure')
#plt.ylim(ymax=2)
plt.grid(True)
# One legend entry per plotted curve.
plt.legend(['train', 'test'])
#plt.show()
plt.savefig(output_dir + name_vae + '_train.png')

In [None]:
# Plot of KL vs evidence: trajectory of the last `how_many` training epochs in the
# (reconstruction NLL, KL regularizer) plane; point colour encodes position in time.
fontsize = 20
labelsize = 20

how_many = 2000
scale = 1
recent_nll = history_dict["train_nll"][-how_many:]
recent_kl = history_dict["train_kl"][-how_many:]
N = len(recent_kl)
colors = np.arange(0.0, N, 1.0) / N

#plt.yscale('log')
#plt.xlim(xmin=1.0, xmax=1.5)
plt.xlabel('REC', fontsize=fontsize)
plt.ylabel('REG', fontsize=fontsize)
plt.tick_params(axis='both', which='major', labelsize=labelsize)
plt.scatter(recent_nll, recent_kl, c=colors)
plt.plot(recent_nll, recent_kl, '-')
plt.grid()
#plt.xlim(xmax=2.5)
plt.savefig(output_dir + name_vae + '_kl_trajectory.png')

In [None]:
from matplotlib import pyplot as mp

# Fixed test-set indices whose images serve as the reference batch for the rest of the notebook.
# NOTE(review): 13 indices are listed but batch_size=9 — how DatasetInMemory.load reconciles
# the two is defined in utilities; confirm which images are actually returned.
reference_indices = [291, 413, 133, 148] + list(range(1, 10))
reference_imgs, labels = test_dataset.load(batch_size=9, indices=reference_indices)
save_obj(reference_imgs, output_dir + name_vae + "reference_img.pkl")

# To reuse a previously saved reference batch instead of reloading it:
#reference_imgs = load_obj(output_dir+name_vae+"reference_img.pkl")

imgs_ref = show_batch(reference_imgs[:], n_col=3, n_padding=4, title="REFERENCE")
imgs_ref.savefig(output_dir + name_vae + '_reference.png')
display(imgs_ref)

# Reconstruction of the reference images

In [None]:
# Reconstruct the reference batch with the current model, save the grid, and show it
# next to the reference grid.
metric, inference = vae.reconstruct_img(reference_imgs)
reconstructed = inference.reconstruction

imgs_rec = show_batch(reconstructed, n_col=3, n_padding=4, title="REC_IMG")
imgs_rec.savefig(output_dir + name_vae + '_reconstruction.png')
display(imgs_rec, imgs_ref)

# MAKE MOVIE

In [None]:
# Preview the layout of a single movie frame; "xxx" is a placeholder epoch label.
epoch = "xxx"
preview_title = "EPOCH = " + str(epoch)
tmp = show_batch(inference.reconstruction[:9], n_col=3, n_padding=4, title=preview_title)
display(tmp)

# Actual loop: render one reconstruction frame for every saved checkpoint

In [None]:
# Render one reconstruction frame per saved checkpoint (the training loop saves
# output_dir + name_vae + "_<epoch>.pkl" only on some epochs).
list_of_rec_files = []
list_of_map_files = []  # never filled below; kept because later cells use it
list_of_bg_files = []   # never filled below; kept because later cells use it
#mpl.interactive(False)

MAX_MOVIE_EPOCH = 700  # epochs 0 .. MAX_MOVIE_EPOCH-1 are scanned for checkpoints

for epoch in range(0, MAX_MOVIE_EPOCH, 1):
    if epoch >= 10000:
        raise ValueError("epoch %d does not fit the 4-digit frame label" % epoch)
    # Zero-padded label (e.g. "_0042") so frame files sort in epoch order.
    label = "_%04d" % epoch

    checkpoint_file = output_dir + name_vae + "_" + str(epoch) + ".pkl"
    # Skip epochs without a checkpoint: rendering anyway would reuse the previously
    # loaded weights and produce a frame mislabeled with this epoch.
    if not os.path.isfile(checkpoint_file):
        continue
    try:
        _ = load_everything(vae, optimizer, checkpoint_file)
    except Exception as err:
        print("could not load %s: %r" % (checkpoint_file, err))
        continue

    try:
        metric, inference = vae.reconstruct_img(reference_imgs)
        tmp = show_batch(inference.reconstruction[:8], n_col=4, n_padding=4, title="EPOCH = " + str(epoch))
        name_output_rec = name_vae + label + 'rec.png'
        tmp.savefig(output_dir + name_output_rec, bbox_inches='tight')
        # Record the frame only once its file exists, so the GIF never references a missing file.
        list_of_rec_files.append(name_output_rec)
        # Presumably `tmp` is a matplotlib Figure: close it so hundreds of frames are not
        # kept in memory (and all re-displayed by the inline backend at cell end).
        plt.close(tmp)
    except Exception as err:
        print("could not render frame for epoch %d: %r" % (epoch, err))

print(list_of_rec_files)
print(list_of_map_files)
print(list_of_bg_files)

## Check individual images

In [None]:
# Concatenate directory and filenames: the movie loop saved every frame at
# output_dir + <frame name>, so build the full paths from the same directory.
rec_filenames = [output_dir + name for name in list_of_rec_files]
map_filenames = [output_dir + name for name in list_of_map_files]
bg_filenames = [output_dir + name for name in list_of_bg_files]

print(rec_filenames)
print(map_filenames)
print(bg_filenames)

In [None]:
def show_frame_rec(n):
    """Return an IPython Image of the n-th reconstruction frame.

    Jupyter renders it when the call is the last expression of a cell.
    `display` is the IPython display *function*, so the Image class is used directly.
    """
    return Image(filename=rec_filenames[n])


def show_frame_all(n):
    """Display the n-th reconstruction frame immediately and return display()'s result."""
    c = Image(filename=rec_filenames[n])
    return display(c)

In [None]:
# Assemble the saved reconstruction frames into an animated GIF. The GIF path is
# relative, so it is written to the current working directory.
#name_movie = "baseline_new_loss_v2.gif"

movie_rec = "movie_" + name_vae + "_rec.gif"

frame_per_second = 2
clip = mpy.ImageSequenceClip(rec_filenames, fps=frame_per_second)
clip.write_gif(movie_rec, fps=frame_per_second)

In [None]:
# Embed the GIF inline. Quote the src attribute so the file name is parsed as a single
# value; <img> is a void element and takes no closing tag.
HTML('<img src="' + movie_rec + '">')

In [None]:
# Display the first saved reconstruction frame (rec_filenames[0]).
show_frame_rec(0)

In [None]:
# Show the first 9 reference images again for comparison with the reconstruction frame.
show_batch(reference_imgs[:9],n_col=3,n_padding=4,title="REFERENCE")