In [None]:
import os
import sys
import pickle
sys.path.append('../metanas')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
plt.style.use('ggplot')

from metanas.utils.visualize import plot
from IPython.display import Image, display, Markdown

In [None]:
def plot_loss_accuracy(path, eval_every=5):
    """Plot training/test loss and accuracy curves stored in an experiment pickle.

    Draws a 1x2 figure: losses on the left, accuracies on the right. Training
    series (``train_test_loss`` / ``train_test_accu``) are plotted once per
    logged entry; test series (``test_test_loss`` / ``test_test_accu``) are
    plotted at x = 0, eval_every, 2 * eval_every, ...

    Args:
        path: Path to the experiment pickle. It is unpickled, so only pass
            files you trust (unpickling can execute arbitrary code).
        eval_every: Number of training epochs between two logged test
            evaluations.

    Returns:
        Tuple ``(sparse_params, train_test_loss, test_test_loss,
        train_test_accu, test_test_accu)`` with the last logged entry of
        ``sparse_params_logger`` and of each metric series.

    Raises:
        ValueError: If ``eval_every`` is smaller than 1.
    """
    if eval_every < 1:
        raise ValueError(f"eval_every must be >= 1, got {eval_every}")

    with open(path, 'rb') as f:
        res = pickle.load(f)

    # Test metrics are logged once every `eval_every` training epochs, so place
    # them on the same epoch axis as the training curves. `eval_every` must be
    # used as the step here: passing it as numpy's boolean `retstep` flag (as
    # was done before) silently ignores its value.
    # NOTE(review): assumes the first test evaluation happened at epoch 0 —
    # TODO confirm against the training loop.
    test_epochs = np.arange(len(res['test_test_loss'])) * eval_every

    _, (ax_loss, ax_accu) = plt.subplots(1, 2, figsize=(20, 5))

    ax_loss.plot(res['train_test_loss'], 'o-', color="r",
                 label="Training test loss")
    ax_loss.plot(test_epochs, res['test_test_loss'], 'o-', color="g",
                 label="Test test loss")
    ax_loss.set(title="Loss", xlabel="Epochs", ylabel="Loss")
    ax_loss.legend(loc="best")

    ax_accu.plot(res['train_test_accu'], 'o-', color="r",
                 label="Training test accuracy")
    ax_accu.plot(test_epochs, res['test_test_accu'], 'o-', color="g",
                 label="Test test accuracy")
    ax_accu.set(title="Accuracy", xlabel="Epochs", ylabel="Accuracy")
    ax_accu.legend(loc="best")

    # Final sparse parameters, then the final loss and accuracy values.
    return (res['sparse_params_logger'][-1], res['train_test_loss'][-1],
            res['test_test_loss'][-1], res['train_test_accu'][-1],
            res['test_test_accu'][-1])

def plot_genotype(path, eval_every=5):
    """Display the normal and reduction cells of every ``eval_every``-th genotype.

    Loads the experiment pickle at ``path``. For genotype indices
    0, eval_every, 2 * eval_every, ... it shows a Markdown heading, the
    textual genotype, and the images of the ``normal`` and ``reduce`` cells
    rendered by ``plot``.

    Args:
        path: Path to the experiment pickle. It is unpickled, so only pass
            files you trust (unpickling can execute arbitrary code).
        eval_every: Stride between displayed entries of ``res['genotype']``.

    Raises:
        ValueError: If ``eval_every`` is smaller than 1.
    """
    if eval_every < 1:
        raise ValueError(f"eval_every must be >= 1, got {eval_every}")

    with open(path, 'rb') as f:
        res = pickle.load(f)

    cwd = os.getcwd()
    # Scratch files produced by `plot` in the working directory; they are only
    # needed long enough to be displayed.
    scratch_files = [os.path.join(cwd, name)
                     for name in ("normal", "normal.png", "reduce", "reduce.png")]

    genotypes = res['genotype']
    for i in range(0, len(genotypes), eval_every):
        genotype = genotypes[i]
        display(Markdown(f'# Iteration: {i}'))
        # Show only the genotype being drawn, instead of dumping the whole list.
        print(genotype)
        try:
            plot(genotype.normal, 'normal', 'normal cell')
            plot(genotype.reduce, 'reduce', 'reduce cell')

            display(Image('normal.png'))
            display(Image('reduce.png'))
        finally:
            # Clean the directory even if plotting/display failed, and
            # tolerate scratch files that were never created.
            for scratch in scratch_files:
                if os.path.exists(scratch):
                    os.remove(scratch)

In [None]:
# Path to the experiment pickle written by the training run. Set the
# METANAS_EXPERIMENT environment variable to point elsewhere; the default is
# relative to this notebook, like the `sys.path.append('../metanas')` import.
path_exp = os.environ.get(
    'METANAS_EXPERIMENT', '../metanas/results/og_train/experiment.pickle')
# Stride between displayed genotypes.
eval_every_n = 5
# Training epochs between two logged test evaluations; must match the
# evaluation frequency used by the training run.
test_eval_every = 5

(final_sparse_params, final_train_loss, final_test_loss,
 final_train_accu, final_test_accu) = plot_loss_accuracy(
    path_exp, eval_every=test_eval_every)
# Render the curves now: with the inline backend the figure would otherwise
# only be flushed at the end of the cell, below all genotype images.
plt.show()
display(pd.Series({'train test loss': final_train_loss,
                   'test test loss': final_test_loss,
                   'train test accuracy': final_train_accu,
                   'test test accuracy': final_test_accu},
                  name='final metrics'))

plot_genotype(path_exp, eval_every_n)