In [None]:
import os
import sys

ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
# Put the directory two levels above the notebook on the import path so the
# `DeepSparseCoding.*` imports below can be resolved; append only once.
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

import numpy as np
import torch

from DeepSparseCoding.utils.file_utils import Logger
import DeepSparseCoding.utils.run_utils as run_utils
import DeepSparseCoding.utils.dataset_utils as dataset_utils
import DeepSparseCoding.utils.loaders as loaders
import DeepSparseCoding.utils.plot_functions as pf

In [None]:
# Both projects share the layout ROOT_DIR/Torch_projects/<project>/{logfiles,checkpoints}.
log_files = [
    os.path.join(ROOT_DIR, 'Torch_projects', project, 'logfiles', f'{project}_v0.log')
    for project in ('mlp_768_mnist', 'lca_768_mlp_mnist')
]

cp_latest_filenames = [
    os.path.join(ROOT_DIR, 'Torch_projects', project, 'checkpoints',
        f'{project}_latest_checkpoint_v0.pt')
    for project in ('mlp_768_mnist', 'lca_768_mlp_mnist')
]

# Selects the project: 0 -> 'mlp_768_mnist', 1 -> 'lca_768_mlp_mnist'.
target_index = 1

logger = Logger(log_files[target_index], overwrite=False)
log_text = logger.load_file()
# NOTE(review): read_params presumably returns every params dump found in the log;
# the last (most recent) one is used.
params = logger.read_params(log_text)[-1]

In [None]:
# Plot the statistics parsed from the log, using "epoch" as the x-axis key.
x_key = "epoch"
model_stats = logger.read_stats(log_text)
stats_fig = pf.plot_stats(model_stats, x_key)

In [None]:
# Instantiate the model class named by params.model_type, configure it from the
# logged params, move it to params.device, then restore saved weights.
model = loaders.load_model(params.model_type)
model.setup(params, logger)
model.to(params.device)
# NOTE(review): cp_latest_filenames is never used in this notebook; presumably
# load_checkpoint() locates the checkpoint from params itself — confirm it matches
# cp_latest_filenames[target_index].
model.load_checkpoint()

In [None]:
# NOTE(review): load_dataset returns a `params` that rebinds the one already passed
# to model.setup() above; if it adds fields the model needs, the dataset should be
# loaded before the model is set up — verify.
train_loader, val_loader, test_loader, params = dataset_utils.load_dataset(params)

In [None]:
# Evaluate the restored model on the test loader. The leading 0 is presumably the
# epoch index reported with the results; log_to_file=False presumably keeps this
# evaluation out of the log file.
test_results = run_utils.test_epoch(0, model, test_loader, log_to_file=False)
print(test_results)

In [None]:
# Extract the first weight tensor as a [num_neurons, num_pixels] array, then reshape
# each neuron's row into a square image for plotting.
if model.params.model_type == "ensemble":
    # The LCA sub-module's first parameter is transposed so rows index neurons,
    # matching the [num_neurons, num_pixels] layout unpacked below.
    weights = next(iter(model.lca.parameters())).data.cpu().numpy().T
else:
    weights = next(iter(model.parameters())).data.cpu().numpy()

num_neurons, num_pixels = weights.shape
img_edge = int(round(np.sqrt(num_pixels)))
# Fail with a clear message instead of an opaque reshape error for non-square inputs.
assert img_edge * img_edge == num_pixels, (
    f'Expected square input images, but num_pixels={num_pixels} is not a perfect square')
weights = np.reshape(weights, [num_neurons, img_edge, img_edge])

In [None]:
import proplot as plot
import numpy as np

def normalize_data_with_max(data):
    """
    Normalize each entry along the first axis by its own maximum absolute value.
    Every data[i] is divided by max(abs(data[i])) (reduced over axes 1 and 2), so
    normalized values lie in [-1, 1]. Entries whose max absolute value is zero are
    returned as zeros.
    Inputs:
        data: [np.ndarray] data shaped [num_entries, height, width] to be normalized;
            integer arrays are accepted and yield floating-point output
    Outputs:
        norm_data: [np.ndarray] normalized data, same shape as data
        data_max: [np.ndarray] per-entry max(abs(data)) that was divided out,
            shaped [num_entries, 1, 1]
    """
    data = np.asarray(data)
    data_max = np.max(np.abs(data), axis=(1, 2), keepdims=True)
    # All-zero entries are divided by 1 instead of 0, so they stay zero with no
    # divide-by-zero warning. Plain true division (rather than np.divide into a buffer
    # of data's dtype) lets integer input promote to float instead of raising.
    safe_max = np.where(data_max == 0, np.ones_like(data_max), data_max)
    norm_data = data / safe_max
    return norm_data, data_max

def pad_matrix_to_image(matrix, pad_size=0, pad_value=0, normalize=False):
    """
    Tile a stack of 2D images into one square mosaic image.
    Inputs:
        matrix: [np.ndarray] stack of images shaped [num_images, height, width]
        pad_size: [int] number of border pixels added on every side of each image
        pad_value: [float] value written into the border pixels
        normalize: [bool] if True, scale each image by its max absolute value
            (via normalize_data_with_max) before tiling
    Outputs:
        matrix: [np.ndarray] 2D mosaic of shape
            [n * (height + 2*pad_size), n * (width + 2*pad_size)] where
            n = ceil(sqrt(num_images)); unused grid cells hold all-zero images
    """
    if normalize:
        matrix = normalize_data_with_max(matrix)[0]
    num_images, img_h, img_w = matrix.shape
    num_edge_tiles = int(np.ceil(np.sqrt(num_images)))
    # Fill the grid up to a perfect square with blank (all-zero) images.
    num_extra_images = num_edge_tiles**2 - num_images
    if num_extra_images > 0:
        matrix = np.concatenate(
            [matrix, np.zeros((num_extra_images, img_h, img_w))], axis=0)
    # Pad by the caller's pad_size (not a notebook global), only on the spatial axes.
    matrix = np.pad(matrix,
        pad_width=((0, 0), (pad_size, pad_size), (pad_size, pad_size)),
        mode='constant', constant_values=pad_value)
    tile_h, tile_w = matrix.shape[1:]
    # [rows, cols, h, w] -> [rows, h, cols, w] so that flattening interleaves pixel rows
    # of horizontally adjacent tiles correctly.
    tiles = matrix.reshape(num_edge_tiles, num_edge_tiles, tile_h, tile_w)
    return tiles.swapaxes(1, 2).reshape(num_edge_tiles * tile_h, num_edge_tiles * tile_w)
    
def plot_matrix(matrix, title='', cmap='greys_r', vmin=0.0, vmax=1.0):
    """
    Show a 2D array as an image on a 10x10 proplot figure.
    The axis is passed through pf.clear_axis first (presumably hiding ticks/spines).
    Inputs:
        matrix: [np.ndarray] 2D image to display
        title: [str] axis title
        cmap: [str] colormap name passed to imshow
        vmin, vmax: [float] color limits passed to imshow; values outside are clipped.
            The defaults clip the negative half of data normalized by
            normalize_data_with_max (which lies in [-1, 1]); pass vmin=-1.0 to show it.
    Outputs:
        fig: the proplot figure that was shown
    """
    fig, ax = plot.subplots(figsize=(10, 10))
    ax = pf.clear_axis(ax)
    ax.imshow(matrix, cmap=cmap, vmin=vmin, vmax=vmax)
    ax.format(title=title)
    plot.show()
    return fig

# Mosaic border settings: 0.5 is the midpoint of plot_matrix's 0-1 color range.
pad_value = 0.5
num_pad_pix = 2
fig = plot_matrix(
    pad_matrix_to_image(weights, pad_size=num_pad_pix, pad_value=pad_value, normalize=True),
    title=f'{model.params.model_name} weights',
)
fig.savefig(f'{model.params.disp_dir}/weights_plot_matrix.png')

In [None]:
import DeepSparseCoding.tf1x.utils.plot_functions as tfpf

# Render the same weights with the legacy TF1.x plotting helpers for comparison.
# The padding is passed explicitly so this mosaic does not rely on
# pad_matrix_to_image reading the `num_pad_pix` global (defined in the cell above).
tfpf.plot_image(pad_matrix_to_image(weights, pad_size=num_pad_pix), vmin=None, vmax=None,
    title="", save_filename=model.params.disp_dir+"/weights_plot_image.png")
tfpf.plot_weights(weights, save_filename=model.params.disp_dir+"/weights_plot_weights.png")
# weights[..., None] appends a trailing singleton axis (presumably the channel axis
# plot_data_tiled expects).
tfpf.plot_data_tiled(weights[..., None], save_filename=model.params.disp_dir+"/weights_plot_data_tiled.png")