In [None]:
import os
import sys
import types

# Add the directory two levels above the working directory to the import path,
# presumably the folder containing the DeepSparseCoding package imported below.
parent_path = os.path.dirname(os.path.dirname(os.getcwd()))
if parent_path not in sys.path:
    sys.path.append(parent_path)

import numpy as np
import proplot as plot

import torch
import torch.nn.functional as F
from torchvision import datasets, transforms

In [None]:
from DeepSparseCoding.utils.file_utils import Logger
import DeepSparseCoding.models.model_loader as ml
import DeepSparseCoding.utils.run_utils as run_utils
import DeepSparseCoding.utils.dataset_utils as dataset_utils

In [None]:
# Locate the training log of the run to analyze and read its logged parameters.
# NOTE(review): user-specific hard-coded location; adjust workspace_dir for your machine.
# Trailing "" keeps the trailing separator so workspace_dir is still "<home>/Work/".
workspace_dir = os.path.join(os.path.expanduser("~"), "Work", "")
# Alternative run:
# params_file = os.path.join(workspace_dir, "Projects", "mlp_mnist", "logfiles", "mlp_mnist_v0.log")
# os.path.join avoids the "Work//Torch_projects" double separator that string concatenation produced.
params_file = os.path.join(
    workspace_dir, "Torch_projects", "lca_mlp_mnist", "logfiles", "lca_mlp_mnist_v0.log")
# overwrite=False: presumably opens the existing log without truncating it — confirm in Logger
logger = Logger(params_file, overwrite=False)

log_text = logger.load_file()
# Keep the last parameter set recorded in the log (presumably the most recent one written)
params = logger.read_params(log_text)[-1]

In [None]:
# Preview how the first logged parameter key is split: for ensemble params the text before
# the first "_" is the ensemble index, and the remainder is the param name.
first_key = list(params.__dict__)[0]
key_split = first_key.split('_')
"_".join(key_split[1:])

In [None]:
# Split the flat logged params into general params and per-ensemble params.
# A key of the form "<idx>_<name>" is stored as read_params.ensemble_params[idx].<name>;
# every other key is copied onto read_params unchanged.
read_params = types.SimpleNamespace()
read_params.ensemble_params = []

ensemble_params_by_num = {}
for key, value in params.__dict__.items():
    # partition gives the same (prefix, remainder) as split("_")[0] / "_".join(split[1:])
    ens_prefix, _, ens_key = key.partition("_")
    if ens_prefix.isdigit(): # ensemble params are prefaced with the ensemble index
        ens_num = int(ens_prefix)
        # Group by ensemble number first so that list position equals the ensemble index
        # even when the log does not list ensemble keys in ascending order.
        ens_namespace = ensemble_params_by_num.setdefault(ens_num, types.SimpleNamespace())
        setattr(ens_namespace, ens_key, value)
    else: # if it is not a digit then it is a general param
        setattr(read_params, key, value)

ensemble_nums = set(ensemble_params_by_num)
# ensemble_params is indexed by ensemble number, so the indices must be exactly 0..N-1
assert sorted(ensemble_nums) == list(range(len(ensemble_nums))), (
    "Ensemble indices must be contiguous starting at 0, found "+str(sorted(ensemble_nums)))
read_params.ensemble_params.extend(
    ensemble_params_by_num[num] for num in sorted(ensemble_nums))

In [None]:
# Display the logged run name as a sanity check that the intended log was loaded
getattr(params, "name")

In [None]:
# Build the train / validation / test loaders from the split params.
# NOTE(review): this rebinds `params` (the raw logger params read earlier) to the params
# returned by load_dataset, so re-running the ensemble-splitting cell afterwards would
# parse this object instead of the logged one.
dataset_outputs = dataset_utils.load_dataset(read_params)
train_loader, val_loader, test_loader, params = dataset_outputs

In [None]:
# Rebuild the model described by the logged params and restore its trained weights.
model = ml.load_model(read_params.model_type)
model.setup(read_params, logger)
model.to(read_params.device)
# os.path.join tolerates a cp_save_dir with or without a trailing separator.
checkpoint_path = os.path.join(model.params.cp_save_dir, "trained_model.pt")
# map_location remaps stored tensors onto the configured device, so a checkpoint saved
# on a GPU can still be loaded on a machine where that GPU is unavailable.
model.load_state_dict(torch.load(checkpoint_path, map_location=read_params.device))

In [None]:
# Score the restored model on the full test set.
model.eval()  # put the model into evaluation mode before scoring
device = model.params.device
num_test_samples = len(test_loader.dataset)
with torch.no_grad():
    test_loss, correct = 0, 0
    for data, target in test_loader:
        data = data.to(device)
        target = target.to(device)
        # NOTE(review): the trailing positional 0 is forwarded unchanged; its meaning lives in run_utils
        batch_test_loss, batch_correct = run_utils.test_single_model(model, data, target, 0)
        test_loss += batch_test_loss
        correct += batch_correct
    # Normalize by the number of test samples (presumably the per-batch loss is a sum)
    test_loss /= num_test_samples
    test_accuracy = 100. * correct / num_test_samples
print("Test loss:", test_loss)
print("Test accuracy:", test_accuracy, "%")

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import re

def plot_stats(data, keys=None, labels=None, start_index=0, figsize=None, save_filename=None):
    """
    Generate time-series plots of stats specified by keys, one stacked subplot per key
    Inputs:
        data: [dict] containing data to be plotted. len of all values should be equal
            data must have the key "batch_step", which is used as the shared x axis
        keys: [list of str] optional list of keys to plot, each should exist in data.keys()
            If nothing is given, data.keys() will be used.
            The bookkeeping keys "batch_step", "schedule_index" and "global_batch_index"
            are never plotted. The caller's list is not modified.
        labels: [list of str] optional list of y-axis labels, should be the same length as
            keys after the bookkeeping keys are dropped; underscores become line breaks.
            If nothing is given, the (filtered) keys will be used
        start_index: [int] index of the first sample plotted for every series
        figsize: [tuple] optional matplotlib figure size in inches
        save_filename: [str] containing the complete output filename.
            If given, the figure is saved with a transparent background and closed.
    Outputs:
        fig: the matplotlib figure, or None when save_filename is given
    """
    if keys is None:
        keys = list(data.keys())
    else:
        assert all([key in data.keys() for key in keys]), (
            "All input keys must exist as keys in the data dictionary")
    # Build a new list instead of calling keys.remove(), which mutated the caller's list.
    bookkeeping_keys = ("batch_step", "schedule_index", "global_batch_index")
    keys = [key for key in keys if key not in bookkeeping_keys]
    # Checked after filtering, so a key list holding only bookkeeping entries fails here
    # instead of building an empty GridSpec.
    assert len(keys) > 0, "Keys must be None or have length > 0."
    if labels is None:
        labels = keys
    else:
        assert len(labels) == len(keys), (
            "The number of labels must match the number of keys")
    num_keys = len(keys)
    x_dat = data["batch_step"][start_index:]  # shared by every subplot
    ylabel_xpos = -0.05  # fixed label offset so y labels line up across subplots
    gs = gridspec.GridSpec(num_keys, 1, hspace=0.5)
    fig = plt.figure(figsize=figsize)
    for key_idx, (key, label) in enumerate(zip(keys, labels)):
        y_dat = data[key][start_index:]
        ax = fig.add_subplot(gs[key_idx])
        ax.plot(x_dat, y_dat)
        if key_idx < num_keys - 1:
            # Only the bottom subplot keeps its x tick labels
            ax.get_xaxis().set_ticklabels([])
        ax.locator_params(axis="y", nbins=5)
        ax.set_ylabel("\n".join(label.split("_")))
        # Two y ticks: min(0, data min) and max(0, data max)
        ax.set_yticks([np.minimum(0.0, np.min(y_dat)), np.maximum(0.0, np.max(y_dat))])
        ax.yaxis.set_label_coords(ylabel_xpos, 0.5)
    ax.set_xlabel("Batch Number")
    fig.suptitle("Stats per Batch", y=0.95, x=0.5)
    if save_filename is not None:
        fig.savefig(save_filename, transparent=True)
        plt.close(fig)
        return None
    plt.show()
    return fig

In [None]:
# Parse the training statistics recorded in the same log text
# (presumably a dict of per-batch series including "batch_step", as plot_stats expects)
run_stats = logger.read_stats(log_text)

In [None]:
# Plot the logged training loss and accuracy against batch number.
stat_labels = {"loss": "Loss", "train_accuracy": "Train Accuracy"}
keys = list(stat_labels.keys())
labels = list(stat_labels.values())
stats_fig = plot_stats(run_stats, keys=keys, labels=labels, start_index=0, figsize=(10, 5))

In [None]:
def set_size(width, fraction=1, subplot=(1, 1)):
    """ Set aesthetic figure dimensions to avoid scaling in latex.
    Parameters
    ----------
    width: float
            Width in pts
    fraction: float
            Fraction of the width which you wish the figure to occupy
    subplot: sequence of two numbers (rows, cols)
            Subplot grid shape; the figure height is scaled by rows / cols.
            Any indexable sequence works (a list such as [2, 1] is still accepted).
    Returns
    -------
    fig_dim: tuple
            Dimensions of figure in inches, as (width, height)

    Usage: figsize = set_size(text_width, fraction=1, subplot=[1, 1])
    Code obtained from: https://jwalton.info/Embed-Publication-Matplotlib-Latex/
    """
    # Immutable tuple default instead of a shared mutable list default.
    fig_width_pt = width * fraction # Width of figure
    inches_per_pt = 1 / 72.27 # Convert from pt to inches (TeX uses 72.27 pt per inch)
    golden_ratio = (5**.5 - 1) / 2 # Golden ratio to set aesthetic figure height
    fig_width_in = fig_width_pt * inches_per_pt # Figure width in inches
    fig_height_in = fig_width_in * golden_ratio * (subplot[0] / subplot[1]) # Figure height in inches
    fig_dim = (fig_width_in, fig_height_in) # Final figure dimensions
    return fig_dim

def plot_weights(weights, title="", figsize=None):
    num_weights, num_input_y, num_input_x = weights.shape
    num_plots_y = int(np.ceil(np.sqrt(num_weights))+1)
    num_plots_x = int(np.floor(np.sqrt(num_weights)))
    fig, axs = plot.subplots()