In [None]:
import os
import sys
import types

parent_path = os.path.dirname(os.path.dirname(os.getcwd()))
if parent_path not in sys.path: sys.path.append(parent_path) 

import numpy as np
import proplot as plot

import torch
import torch.nn.functional as F
from torchvision import datasets, transforms

from DeepSparseCoding.utils.file_utils import Logger
import DeepSparseCoding.models.model_loader as ml
import DeepSparseCoding.utils.run_utils as run_utils
import DeepSparseCoding.utils.dataset_utils as dataset_utils
import DeepSparseCoding.utils.run_utils as ru
import DeepSparseCoding.utils.plot_functions as pf

In [None]:
# Root of the local workspace that holds the Torch_projects experiment folders.
# NOTE(review): this is a per-user directory layout; adjust for your machine.
workspace_dir = os.path.expanduser("~")+"/Work/"
# Experiment whose training log should be analyzed. Other experiments this
# notebook has been pointed at: "mlp_mnist", "lca_mnist".
experiment_name = "lca_mlp_mnist"
# os.path.join avoids the doubled separator that plain concatenation produced
# (workspace_dir already ends in "/").
params_file = os.path.join(
    workspace_dir, "Torch_projects", experiment_name, "logfiles", experiment_name + "_v0.log")
logger = Logger(params_file, overwrite=False)

log_text = logger.load_file()
# read_params returns a sequence of parameter records; [-1] takes the last one
# (presumably the most recent run written to this log — confirm if logs are appended).
params = logger.read_params(log_text)[-1]

In [None]:
# Build the train/validation/test loaders; load_dataset also hands back a params
# object, which replaces the one read from the log.
(train_loader,
 val_loader,
 test_loader,
 params) = dataset_utils.load_dataset(params)

In [None]:
# Instantiate the model class named in params, configure it, and restore the
# trained weights saved during training.
model = ml.load_model(params.model_type, params.lib_root_dir)
model.setup(params, logger)
model.to(params.device)
# map_location remaps stored tensors onto params.device, so a checkpoint written
# from a GPU can still be restored on a CPU-only machine (and vice versa).
# os.path.join works whether or not cp_save_dir ends with a separator.
checkpoint_path = os.path.join(model.params.cp_save_dir, "trained_model.pt")
model.load_state_dict(torch.load(checkpoint_path, map_location=params.device))

In [None]:
# Horizontal-axis key used when plotting the logged statistics.
x_key = "epoch"
# Parse the statistics recorded in the log text and plot them against x_key.
model_stats = logger.read_stats(log_text)
stats_fig = pf.plot_stats(model_stats, x_key)

In [None]:
# Evaluate the restored model on the test loader without writing to the log file.
# NOTE(review): the leading 0 is presumably the epoch index passed to test_epoch.
test_results = ru.test_epoch(
    0,
    model,
    test_loader,
    log_to_file=False,
)
print(test_results)

In [None]:
def set_size(width, fraction=1, subplot=(1, 1)):
    """ Set aesthetic figure dimensions to avoid scaling in latex.

    Parameters
    ----------
    width: float
            Width in pts (TeX points, 72.27 pt = 1 inch)
    fraction: float
            Fraction of the width which you wish the figure to occupy
    subplot: sequence of two numbers
            (rows, columns) of the subplot grid. The figure height is scaled
            by rows / columns so that each individual panel keeps the
            golden-ratio aspect (ignoring inter-panel spacing).

    Returns
    -------
    fig_dim: tuple
            (width, height) of the figure in inches

    Usage: figsize = set_size(text_width, fraction=1, subplot=[1, 1])
    Code obtained from: https://jwalton.info/Embed-Publication-Matplotlib-Latex/
    """
    # A tuple default avoids sharing one mutable list across calls.
    num_rows, num_cols = subplot[0], subplot[1]
    fig_width_pt = width * fraction # Width of figure
    inches_per_pt = 1 / 72.27 # Convert from pt to inches
    golden_ratio = (5**.5 - 1) / 2 # Golden ratio to set aesthetic figure height
    fig_width_in = fig_width_pt * inches_per_pt # Figure width in inches
    fig_height_in = fig_width_in * golden_ratio * (num_rows / num_cols) # Figure height in inches
    fig_dim = (fig_width_in, fig_height_in) # Final figure dimensions
    return fig_dim

def plot_weights(weights, title="", figsize=None):
    num_weights, num_input_y, num_input_x = weights.shape
    num_plots_y = int(np.ceil(np.sqrt(num_weights))+1)
    num_plots_x = int(np.floor(np.sqrt(num_weights)))
    fig, axs = plot.subplots()