In [None]:
%matplotlib inline

In [None]:
import os
import sys

# Put the directory two levels above the working directory on sys.path (only once),
# so the `DeepSparseCoding.*` imports in this cell can be resolved from it.
ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

import scipy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import proplot as plot
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.colors import LinearSegmentedColormap

from DeepSparseCoding.utils.file_utils import Logger
import DeepSparseCoding.utils.run_utils as run_utils
import DeepSparseCoding.utils.dataset_utils as dataset_utils
import DeepSparseCoding.utils.loaders as loaders
import DeepSparseCoding.utils.plot_functions as pf
import DeepSparseCoding.utils.data_processing as dp
from DeepSparseCoding.params.lca_cifar10_params import params as LcaParams

In [None]:
# Locate and parse the training log of the model to analyze.
# The workspace root can be overridden with the DSC_WORKSPACE_DIR environment variable
# so the notebook is not tied to one machine's mount point; the original path is the default.
workspace_dir = os.environ.get('DSC_WORKSPACE_DIR', '/mnt/qb/bethge/dpaiton/')
model_name = 'lca_pool_lca_pool_cifar10'
model_version = '0'
# os.path.join works whether or not workspace_dir ends with a separator.
log_file = os.path.join(
    workspace_dir, 'Projects', model_name, 'logfiles', f'{model_name}_v{model_version}.log')
logger = Logger(log_file, overwrite=False)
log_text = logger.load_file()
# read_params presumably returns one entry per params dump in the log; use the latest.
params = logger.read_params(log_text)[-1]

In [None]:
# Plot the training curves, then the test curves if the log recorded any.
model_stats = logger.read_stats(log_text)

x_key = "epoch"
y_keys = [stat_key for stat_key in model_stats.keys() if 'test_' not in stat_key]
stats_fig = pf.plot_stats(model_stats, x_key, y_keys=y_keys)

if 'test_epoch' in model_stats.keys():
    x_key = "test_epoch"
    y_keys = [stat_key for stat_key in model_stats.keys() if 'test_' in stat_key]
    test_stats_fig = pf.plot_stats(model_stats, x_key, y_keys=y_keys)

In [None]:
# Instantiate the architecture recorded in the logged params, configure it,
# move it to the configured device, then restore the saved weights and report them.
model = loaders.load_model(params.model_type)
model.setup(params, logger)
model.to(params.device)

model_state_str = model.load_checkpoint()
print(model_state_str)

In [None]:
# Load the data splits; the dataset mean/std images are moved to the model's device.
train_loader, val_loader, test_loader, data_stats, data_mean_std = dataset_utils.load_dataset(params)
train_mean_image, train_std_image = (
    data_mean_std[image_key].to(model.params.device)
    for image_key in ('dataset_mean_image', 'dataset_std_image'))

In [None]:
# Copy the first LCA layer's dictionary off the device into a NumPy array.
lca_weights = (
    model.lca_1.weight
    .detach()
    .cpu()
    .numpy())

In [None]:
def normalize_data_with_max(data):
    """
    Normalize each sample by its maximum absolute value, max(abs(sample))
    The max is taken over every axis except the first (sample) axis, so [N, H, W]
    and [N, C, H, W] inputs are both normalized per sample.
    If a sample's max absolute value is zero, that sample's output is all zeros.
    Inputs:
        data: [np.ndarray] of rank >= 2 to be normalized; axis 0 indexes samples
    Outputs:
        norm_data: [np.ndarray] normalized data, same shape as data, floating point
        data_max: [np.ndarray] per-sample max(abs(data)) that was divided out, with
            singleton dims kept so it broadcasts against data
    """
    data = np.asarray(data)
    reduce_axes = tuple(range(1, data.ndim))
    data_max = np.max(np.abs(data), axis=reduce_axes, keepdims=True)
    # Integer inputs need a floating output buffer; true division cannot be cast into ints.
    out_dtype = data.dtype if np.issubdtype(data.dtype, np.floating) else np.float64
    norm_data = np.divide(data, data_max, out=np.zeros(data.shape, dtype=out_dtype),
        where=data_max != 0)
    return norm_data, data_max

def pad_matrix_to_image(matrix, pad_size=0, pad_value=0, normalize=False):
    """
    Tile a stack of (multi-channel) images into one square mosaic per channel.
    Inputs:
        matrix: [np.ndarray] of shape [num_weights, img_c, img_h, img_w]
        pad_size: [int] number of border pixels added on every side of each tile
        pad_value: [float] value used for the border pixels
        normalize: [bool] if True, first rescale the tiles with
            dp.rescale_data_to_one(..., samplewise=True)
    Outputs:
        padded_matrix: [np.ndarray] of shape
            [img_c, n * (img_h + 2*pad_size), n * (img_w + 2*pad_size)]
            with n = ceil(sqrt(num_weights)); tiles are laid out row-major and
            unused grid slots are filled with zero-valued tiles
    """
    if normalize:
        matrix = dp.rescale_data_to_one(torch.from_numpy(matrix), eps=1e-10, samplewise=True)[0].numpy()
    num_weights, img_c, img_h, img_w = matrix.shape
    num_edge_tiles = int(np.ceil(np.sqrt(num_weights)))
    num_extra_images = num_edge_tiles**2 - num_weights
    matrices = []
    for channel_idx in range(img_c):
        channel_matrix = matrix[:, channel_idx, ...].copy()
        if num_extra_images > 0:
            # Fill the rest of the square grid with blank (zero) tiles.
            channel_matrix = np.concatenate(
                [channel_matrix, np.zeros((num_extra_images, img_h, img_w))], axis=0)
        # Use the pad_size argument (previously a global `num_pad_pix` was read by mistake).
        channel_matrix = np.pad(channel_matrix,
            pad_width=((0, 0), (pad_size, pad_size), (pad_size, pad_size)),
            mode='constant', constant_values=pad_value)
        padded_img_h, padded_img_w = channel_matrix.shape[1:]
        # [tile_row, tile_col, h, w] -> [tile_row, h, tile_col, w] places tiles row-major.
        tiles = channel_matrix.reshape(num_edge_tiles, num_edge_tiles, padded_img_h, padded_img_w)
        tiles = tiles.swapaxes(1, 2)
        matrices.append(tiles.reshape(num_edge_tiles * padded_img_h, num_edge_tiles * padded_img_w))
    padded_matrix = np.stack(matrices, axis=0) # channel dim first
    return padded_matrix
    
def plot_matrix(matrix, title='', cmap=None):
    """
    Display `matrix` as an image on a single 10x10 proplot figure with the axis cleared.
    Inputs:
        matrix: [np.ndarray] image data passed directly to imshow
        title: [str] axis title
        cmap: colormap passed to imshow (None uses the default)
    Outputs:
        figure: the figure that was drawn
    """
    figure, axis = plot.subplots(figsize=(10,10))
    axis = pf.clear_axis(axis)
    axis.imshow(matrix, cmap=cmap)
    axis.format(title=title)
    plot.show()
    return figure

# Tile the normalized RGB dictionary elements into one mosaic, show it, and save it.
pad_value = 0.5
num_pad_pix = 1
padded_matrix = pad_matrix_to_image(lca_weights, num_pad_pix, pad_value, normalize=True)
# The mosaic is [C, H, W]; imshow expects channels last.
fig = plot_matrix(padded_matrix.transpose(1, 2, 0), title=f'{model.params.model_name} weights')
fig.savefig(f'{model.params.disp_dir}/weights_plot_matrix.png', bbox_inches='tight', transparent=False)

In [None]:
def rgb_to_gray(rgb):
    """
    Convert a stack of RGB images to a single luminance channel.
    Luminance = 0.2125 R + 0.7154 G + 0.0721 B; channels after the third are ignored.
    Inputs:
        rgb: [np.ndarray] of shape [num, chan, height, width]
    Outputs:
        gray: [np.ndarray] float64 array of shape [num, 1, height, width]
    """
    num, _, height, width = rgb.shape
    if num == 0:
        return np.zeros((0, 1, height, width))
    # Each weighted channel is formed in the input precision and accumulated into a
    # float64 buffer, reproducing the per-image loop arithmetic exactly.
    luminance = (0.2125 * rgb[:, 0, ...]).astype(np.float64)
    luminance += 0.7154 * rgb[:, 1, ...]
    luminance += 0.0721 * rgb[:, 2, ...]
    return luminance[:, np.newaxis, ...]

# Same mosaic as above, but for a luminance version of the dictionary.
gray_lca_weights = rgb_to_gray(lca_weights)
pad_value = 0.5
num_pad_pix = 1
padded_matrix = pad_matrix_to_image(gray_lca_weights, num_pad_pix, pad_value, normalize=True)
fig = plot_matrix(
    np.squeeze(padded_matrix.transpose(1, 2, 0)),  # [1, H, W] -> [H, W]
    title=f'{model.params.model_name} weights',
    cmap='grays_r')
fig.savefig(f'{model.params.disp_dir}/weights_grayscale_plot_matrix.png', bbox_inches='tight', transparent=False)

In [None]:
def generate_gaussian(shape, mean, cov):
    """
    Evaluate a 2-D Gaussian PDF with the given mean & cov on the integer pixel grid
    Inputs:
        shape: [tuple] specifying (num_rows, num_cols)
        mean: [np.ndarray] of shape (2,) specifying the 2-D Gaussian center as (y, x)
            in pixel-index units (the convention used by gaussian_fit)
        cov: [np.ndarray] of shape (2,2) specifying the 2-D Gaussian covariance matrix
    Outputs:
        tuple containing (Gaussian PDF, grid_points used to generate PDF)
            grid_points are specified as a tuple of (y,x) points
            NOTE: np.meshgrid's default 'xy' indexing is used, so the PDF and grid arrays
            have shape (num_cols, num_rows); element [j, i] is the pixel at y=i, x=j.
    Raises:
        scipy raises np.linalg.LinAlgError when cov is singular (handled by get_gauss_fit)
    """
    # `import scipy` alone does not guarantee that the stats submodule is loaded.
    from scipy.stats import multivariate_normal
    (y_size, x_size) = shape
    # Sample at pixel indices 0..size-1 (the coordinates the mean is expressed in),
    # rather than stretching `size` samples over [0, size].
    y = np.arange(int(np.floor(y_size)), dtype=np.float64)
    x = np.arange(int(np.floor(x_size)), dtype=np.float64)
    y, x = np.meshgrid(y, x)
    pos = np.stack([y, x], axis=-1)  # [..., 0] is y, [..., 1] is x
    gauss = multivariate_normal(mean, cov)
    return (gauss.pdf(pos), (y, x))


def gaussian_fit(pyx):
    """
    Compute the expected mean & covariance matrix for a 2-D gaussian fit of input distribution
    Moments are computed in (y, x) pixel-index coordinates. pyx is used as given
    (it is assumed to already sum to 1; it is not renormalized here).
    Inputs:
        pyx: [np.ndarray] of shape [num_rows, num_cols] that indicates the probability function to fit
    Outputs:
        mean: [np.ndarray] float32 of shape (2,) specifying the 2-D Gaussian center [mu_y, mu_x]
        cov: [np.ndarray] float32 of shape (2,2) specifying the 2-D Gaussian covariance matrix
    """
    assert pyx.ndim == 2, (
        "Input must have 2 dimensions specifying [num_rows, num_cols]")
    pyx = np.asarray(pyx, dtype=np.float64)
    ys, xs = np.indices(pyx.shape, dtype=np.float64)
    mean = np.array([np.sum(pyx * ys), np.sum(pyx * xs)])
    dy = ys - mean[0]
    dx = xs - mean[1]
    cov_yx = np.sum(pyx * dy * dx)
    cov = np.array([[np.sum(pyx * dy * dy), cov_yx],
                    [cov_yx, np.sum(pyx * dx * dx)]])
    # Vectorized and accumulated in float64 for accuracy; the float32 return dtype is kept
    # for compatibility with existing callers.
    return (mean.astype(np.float32), cov.astype(np.float32))


def get_gauss_fit(prob_map, num_attempts=1, perc_mean=0.33):
    """
    Returns a gaussian fit for a given probability map
    Fitting is done via robust regression, where a fit is
    continuously refined by deleting outliers num_attempts times
    Inputs:
        prob_map: 2-D probability map to be fit
        num_attempts: Number of times to fit & remove outliers
        perc_mean: All probability values below perc_mean*mean(gauss_fit) will be
            considered outliers for repeated attempts
    Outputs:
        gauss_fit: [np.ndarray] specifying the 2-D Gaussian PDF, rescaled back to the
            input's range (pdf * map_sum + map_min), laid out as returned by generate_gaussian
        grid: [tuple] containing (y,x) points with which the Gaussian PDF can be plotted
        gauss_mean: [np.ndarray] of shape (2,) specifying the 2-D Gaussian center
        gauss_cov: [np.ndarray] of shape (2,2) specifying the 2-D Gaussian covariance matrix
    Raises:
        np.linalg.LinAlgError if not even the first fit succeeds (e.g. singular covariance
            or a constant map)
    """
    assert prob_map.ndim==2, (
        "get_gauss_fit: Input prob_map must have 2 dimension specifying [num_rows, num_cols")
    if num_attempts < 1:
        num_attempts = 1
    orig_prob_map = prob_map.copy()
    gauss_success = False
    while not gauss_success:
        prob_map = orig_prob_map.copy()
        try:
            for i in range(num_attempts):
                map_min = np.min(prob_map)
                prob_map -= map_min
                map_sum = np.sum(prob_map)
                if map_sum == 0:
                    # A constant map has nothing to fit; route it through the same
                    # failure handling as a singular covariance instead of dividing by 0.
                    raise np.linalg.LinAlgError("get_gauss_fit: probability map is constant")
                if map_sum != 1.0:
                    prob_map /= map_sum
                gauss_mean, gauss_cov = gaussian_fit(prob_map)
                gauss_fit, grid = generate_gaussian(prob_map.shape, gauss_mean, gauss_cov)
                gauss_fit = (gauss_fit * map_sum) + map_min
                if i < num_attempts-1:
                    # gauss_fit is transposed relative to prob_map (meshgrid 'xy' layout),
                    # so transpose back before building the outlier mask.
                    gauss_mask = gauss_fit.copy().T
                    mask_slice = np.where(gauss_mask<perc_mean*np.mean(gauss_mask))
                    gauss_mask[mask_slice] = 0
                    gauss_mask[np.where(gauss_mask>0)] = 1
                    prob_map *= gauss_mask
            gauss_success = True
        except np.linalg.LinAlgError as err: # Usually means cov matrix is singular
            print("get_gauss_fit: Failed to fit Gaussian at attempt ",i,", trying again."+
                "\n  To avoid this try decreasing perc_mean.")
            # Attempts 0..i-1 succeeded; rerunning exactly that many (deterministically)
            # reproduces them and returns the last good fit.
            num_attempts = i
            if num_attempts <= 0:
                raise np.linalg.LinAlgError(
                    "get_gauss_fit: np.linalg.LinAlgError - Unable to fit gaussian.") from err
    return (gauss_fit, grid, gauss_mean, gauss_cov)


def hilbert_amplitude(weights, padding=None):
    """
    Compute Hilbert amplitude envelope of weight matrix
    Inputs:
        weights: [np.ndarray] of shape [num_inputs, num_outputs]
            num_inputs must be a perfect square (each column is a flattened square patch)
        padding: [int] FFT size N to zero-pad each patch to; None or any value
            <= sqrt(num_inputs) falls back to the next power of 2 >= sqrt(num_inputs)
    Outputs:
        env: [np.ndarray] complex, of shape [num_outputs, num_inputs]
            analytic signal cropped back to the patch size; its magnitude is the Hilbert envelope
        bff_filt: [np.ndarray] of shape [num_outputs, N**2]
            Half-plane filtered Fourier transform of each basis function (flattened)
        hil_filt: [np.ndarray] of shape [num_outputs, N, N]
            Hilbert (half-plane) filter applied in Fourier space
        bffs: [np.ndarray] of shape [num_outputs, N, N]
            DC-centered Fourier transform of the mean-subtracted input weights
    """
    num_inputs, num_outputs = weights.shape
    assert np.sqrt(num_inputs) == np.floor(np.sqrt(num_inputs)), (
        "weights.shape[0] must have an even square root.")
    patch_edge_size = int(np.sqrt(num_inputs))
    # NOTE: the builtin int replaces np.int, which was removed in NumPy 1.24.
    if padding is None or padding <= patch_edge_size:
        # Amount of zero padding for fft2 (closest power of 2)
        N = int(2**(np.ceil(np.log2(patch_edge_size))))
    else:
        N = int(padding)
    # Analytic signal envelope for weights
    # (Hilbert transform of each basis function)
    env = np.zeros((num_outputs, num_inputs), dtype=complex)
    # Fourier transform of weights
    bffs = np.zeros((num_outputs, N, N), dtype=complex)
    # Filtered Fourier transform of weights
    bff_filt = np.zeros((num_outputs, N**2), dtype=complex)
    # Hilbert filters
    hil_filt = np.zeros((num_outputs, N, N))
    # Frequency grid (DC-centered, matching fftshift) used to build the half-plane filters
    f = (2/N) * np.pi * np.arange(-N/2.0, N/2.0)
    (fx, fy) = np.meshgrid(f, f)
    theta = np.arctan2(fy, fx)
    for neuron_idx in range(num_outputs):
        # Grab single basis function, reshape to a square image
        bf = weights[:, neuron_idx].reshape(patch_edge_size, patch_edge_size)
        # Convert basis function into DC-centered Fourier domain
        bff = np.fft.fftshift(np.fft.fft2(bf-np.mean(bf), [N, N]))
        bffs[neuron_idx, ...] = bff
        # Find indices of the peak amplitude
        max_ys = np.abs(bff).argmax(axis=0) # Returns row index for each col
        max_x = np.argmax(np.abs(bff).max(axis=0))
        # Convert peak amplitude location into angle in freq domain
        fx_ang = f[max_x]
        fy_ang = f[max_ys[max_x]]
        theta_max = np.arctan2(fy_ang, fx_ang)
        # Define the half-plane with respect to the maximum
        ang_diff = np.abs(theta-theta_max)
        idx = (ang_diff>np.pi).nonzero()
        ang_diff[idx] = 2.0 * np.pi - ang_diff[idx]
        hil_filt[neuron_idx, ...] = (ang_diff < np.pi/2.0).astype(int)
        # Create analytic signal from the inverse FT of the half-plane filtered bf.
        # ifftshift is the exact inverse of fftshift (they only differ for odd N).
        abf = np.fft.ifft2(np.fft.ifftshift(hil_filt[neuron_idx, ...]*bff))
        env[neuron_idx, ...] = abf[0:patch_edge_size, 0:patch_edge_size].reshape(num_inputs)
        bff_filt[neuron_idx, ...] = (hil_filt[neuron_idx, ...]*bff).reshape(N**2)
    return (env, bff_filt, hil_filt, bffs)


def get_dictionary_stats(weights, padding=None, num_gauss_fits=20, gauss_thresh=0.2):
    """
    Compute summary statistics on dictionary elements using Hilbert amplitude envelope
    Inputs:
        weights: [np.ndarray] of shape [num_inputs, num_outputs]; each column is a
            flattened square patch (num_inputs must be a perfect square)
        padding: [int] total image size to pad out to in the FFT computation
        num_gauss_fits: [int] total number of attempts to make when fitting the BFs
        gauss_thresh: All probability values below gauss_thresh*mean(gauss_fit) will be
            considered outliers for repeated fits
    Outputs:
      The function output is a dictionary containing the keys for each type of analysis
      Each per-element key dereferences a list of len num_outputs (i.e. one entry for each weight vector)
      The keys and their entries are as follows:
          basis_functions: [np.ndarray] of shape [patch_edge_size, patch_edge_size]
          envelopes: [np.ndarray] of shape [patch_edge_size, patch_edge_size]; magnitude of the
              Hilbert analytic signal, cropped back from the padded FFT size
          envelope_centers: [tuple of ints] (y, x) position of the envelope maximum
          gauss_fits: [tuple] (gaussian_fit, grid) where gaussian_fit is returned from
              get_gauss_fit and specifies the 2D Gaussian PDF fit to the Hilbert envelope
              and grid is a tuple containing (y,x) points with which the PDF can be plotted
          gauss_centers: [np.ndarray] (y,x) mean of the Gaussian fit
          gauss_orientations: [tuple] (eigenvalues, eigenvectors) of the covariance matrix of
              the Gaussian fit, both sorted from highest to lowest eigenvalue
          diameters: [float] sqrt of the sum of the squared covariance eigenvalues
          areas: [float] pi * product of the covariance eigenvalues
          fourier_maps: [np.ndarray] of shape [N, N] Fourier amplitude map with the DC entry
              zeroed, where N is the FFT size used by hilbert_amplitude
          fourier_centers: [list of floats] [fy, fx] offset of the Fourier amplitude peak from
              the map center, scaled by patch_edge_size/N
          ellipse_orientations: [float] arctan2(fx, fy) of fourier_centers
          spatial_frequencies: [float] distance of fourier_centers from the origin
          phases: [float] Fourier phase sampled at envelope_centers
          num_inputs: [int] dim[0] of input weights
          num_outputs: [int] dim[1] of input weights
          patch_edge_size: [int] int(floor(sqrt(num_inputs)))
    """
    envelope, _, _, bffs = hilbert_amplitude(weights, padding)
    num_inputs, num_outputs = weights.shape
    # The builtin int replaces np.int, which was removed in NumPy 1.24.
    patch_edge_size = int(np.floor(np.sqrt(num_inputs)))
    basis_funcs = [None]*num_outputs
    envelopes = [None]*num_outputs
    gauss_fits = [None]*num_outputs
    gauss_centers = [None]*num_outputs
    diameters = [None]*num_outputs
    gauss_orientations = [None]*num_outputs
    envelope_centers = [None]*num_outputs
    fourier_centers = [None]*num_outputs
    ellipse_orientations = [None]*num_outputs
    fourier_maps = [None]*num_outputs
    spatial_frequencies = [None]*num_outputs
    areas = [None]*num_outputs
    phases = [None]*num_outputs
    for bf_idx in range(num_outputs):
        # Reformatted individual basis function
        basis_funcs[bf_idx] = weights.T[bf_idx,...].reshape((patch_edge_size, patch_edge_size))
        # Reformatted individual envelope filter
        envelopes[bf_idx] = np.abs(envelope[bf_idx,...]).reshape((patch_edge_size, patch_edge_size))
        # Basis function center
        max_ys = envelopes[bf_idx].argmax(axis=0) # Returns row index for each col
        max_x = np.argmax(envelopes[bf_idx].max(axis=0))
        y_cen = max_ys[max_x]
        x_cen = max_x
        envelope_centers[bf_idx] = (y_cen, x_cen)
        # Gaussian fit to Hilbert amplitude envelope
        gauss_fit, grid, gauss_mean, gauss_cov = get_gauss_fit(envelopes[bf_idx],
            num_gauss_fits, gauss_thresh)
        gauss_fits[bf_idx] = (gauss_fit, grid)
        gauss_centers[bf_idx] = gauss_mean
        evals, evecs = np.linalg.eigh(gauss_cov)
        sort_indices = np.argsort(evals)[::-1]
        gauss_orientations[bf_idx] = (evals[sort_indices], evecs[:,sort_indices])
        width, height = evals[sort_indices] # Width & height are relative to orientation
        diameters[bf_idx] = np.sqrt(width**2+height**2)
        # Fourier function center, spatial frequency, orientation
        fourier_map = np.abs(bffs[bf_idx, ...])
        # fourier_maps keeps a reference to this array, so zeroing the DC bin below
        # is also reflected in the stored map.
        fourier_maps[bf_idx] = fourier_map
        N = fourier_map.shape[0]
        center_freq = int(np.floor(N/2))
        fourier_map[center_freq, center_freq] = 0 # remove DC component
        max_fys = fourier_map.argmax(axis=0)
        max_fx = np.argmax(fourier_map.max(axis=0))
        fy_cen = (max_fys[max_fx] - (N/2)) * (patch_edge_size/N)
        fx_cen = (max_fx - (N/2)) * (patch_edge_size/N)
        fourier_centers[bf_idx] = [fy_cen, fx_cen]
        # NOTE: we flip fourier_centers because fx_cen is the peak of the x frequency,
        # which would be a y coordinate
        ellipse_orientations[bf_idx] = np.arctan2(*fourier_centers[bf_idx][::-1])
        spatial_frequencies[bf_idx] = np.sqrt(fy_cen**2 + fx_cen**2)
        areas[bf_idx] = np.pi * np.prod(evals)
        # NOTE(review): this indexes the N x N Fourier phase map with the spatial envelope
        # center (y_cen, x_cen); confirm that is the intended sampling location.
        phases[bf_idx] = np.angle(bffs[bf_idx])[y_cen, x_cen]
    output = {"basis_functions":basis_funcs, "envelopes":envelopes, "gauss_fits":gauss_fits,
        "gauss_centers":gauss_centers, "gauss_orientations":gauss_orientations, "areas":areas,
        "fourier_centers":fourier_centers, "fourier_maps":fourier_maps, "num_inputs":num_inputs,
        "spatial_frequencies":spatial_frequencies, "envelope_centers":envelope_centers,
        "num_outputs":num_outputs, "patch_edge_size":patch_edge_size, "phases":phases,
        "ellipse_orientations":ellipse_orientations, "diameters":diameters}
    return output

In [None]:
# Summarize each grayscale dictionary element (Hilbert envelope, Gaussian fit, Fourier peak, ...).
# Weights are flattened to [pixels, neurons], the layout get_dictionary_stats expects.
bf_stats = get_dictionary_stats(
    gray_lca_weights.reshape(gray_lca_weights.shape[0], -1).T,
    padding=32,
    num_gauss_fits=20,
    gauss_thresh=0.2)

# os.path.join avoids writing a sibling file when save_dir has no trailing separator.
# The dict is stored as a 0-d object array, so np.load(..., allow_pickle=True) is needed to read it.
np.savez(
    os.path.join(model.params.save_dir, 'bf_summary_stats.npz'),
    data={'bf_stats':bf_stats})

In [None]:
def clear_axis(ax, spines="none"):
    """
    Strip all decoration (spine color, tick labels, axis visibility, ticks) from an axis.
    Inputs:
        ax: matplotlib axis, modified in place
        spines: color applied to all four spines ("none" hides them)
    Outputs:
        ax: the same axis, to allow call chaining
    """
    for side in ("top", "bottom", "left", "right"):
        ax.spines[side].set_color(spines)
    for set_tick_labels in (ax.set_yticklabels, ax.set_xticklabels):
        set_tick_labels([])
    for get_axis in (ax.get_xaxis, ax.get_yaxis):
        get_axis().set_visible(False)
    ax.tick_params(axis="both", bottom=False, top=False, left=False, right=False)
    return ax

def plot_ellipse(axis, center, shape, angle, color_val="auto", alpha=1.0, lines=False,
    fill_ellipse=False):
    """
    Add an ellipse to given axis
    Inputs:
        axis [matplotlib.axes._subplots.AxesSubplot] axis on which ellipse should be drawn
        center [tuple or list] specifying [y, x] center coordinates
        shape [tuple or list] specifying [width, height] shape of ellipse
        angle [float] specifying angle of ellipse in degrees (matplotlib Ellipse convention)
        color_val [matplotlib color spec] specifying the color of the edge & face of the ellipse;
            "auto" draws the edge in matplotlib's default patch edge color
            (rcParams["patch.edgecolor"]) and never fills the face
        alpha [float] specifying the transparency of the ellipse
        lines [bool] if true, the shorter of width/height is collapsed to 0.1 so the ellipse
            is drawn as a line (nothing is collapsed when width == height)
        fill_ellipse [bool] if true and color_val is not "auto", the face is filled with color_val
    Outputs:
        ellipse [matplotlib.patches.Ellipse] ellipse object
    """
    # isinstance guard: comparing an array color to "auto" would be ambiguous.
    is_auto_color = isinstance(color_val, str) and color_val == "auto"
    # "auto" is not a valid matplotlib color, so map it to the default patch edge color.
    edge_color_val = matplotlib.rcParams["patch.edgecolor"] if is_auto_color else color_val
    face_color_val = color_val if (fill_ellipse and not is_auto_color) else "none"
    y_cen, x_cen = center
    width, height = shape
    if lines:
        min_length = 0.1
        if width < height:
            width = min_length
        elif width > height:
            height = min_length
    ellipse = matplotlib.patches.Ellipse(xy=[x_cen, y_cen], width=width,
        height=height, angle=angle, edgecolor=edge_color_val, facecolor=face_color_val,
        alpha=alpha, fill=True)
    axis.add_artist(ellipse)
    ellipse.set_clip_box(axis.bbox)
    return ellipse

def plot_ellipse_summaries(bf_stats, num_bf=-1, lines=False, rand_bf=False):
    """
    Plot basis functions with summary ellipses drawn over them
    Each ellipse is centered on the Gaussian-fit center, sized by the sorted covariance
    eigenvalues, and rotated according to the basis function's Fourier peak.
    Inputs:
        bf_stats [dict] output of get_dictionary_stats()
        num_bf [int] number of basis functions to plot (<=0 is all; >total is all)
        lines [bool] If true, will plot lines instead of ellipses
        rand_bf [bool] If true, will choose a random set of basis functions
            (drawn with the global np.random state)
    Outputs:
        fig [matplotlib.figure.Figure] the figure that was drawn
    """
    tot_num_bf = len(bf_stats["basis_functions"])
    if num_bf <= 0 or num_bf > tot_num_bf:
        num_bf = tot_num_bf
    if rand_bf:
        bf_range = np.random.choice(tot_num_bf, num_bf, replace=False)
    else:
        bf_range = np.arange(num_bf)
    num_plots_y = int(np.ceil(np.sqrt(num_bf)))
    num_plots_x = int(np.ceil(np.sqrt(num_bf)))
    gs = gridspec.GridSpec(num_plots_y, num_plots_x)
    fig = plt.figure(figsize=(17,17))
    filter_idx = 0
    for plot_id in np.ndindex((num_plots_y, num_plots_x)):
        ax = clear_axis(fig.add_subplot(gs[plot_id]))
        # num_bf <= tot_num_bf, so this single bound covers both limits.
        if filter_idx < num_bf:
            bf_idx = bf_range[filter_idx]
            bf = bf_stats["basis_functions"][bf_idx]
            ax.imshow(bf, interpolation="nearest", cmap="grays_r")
            ax.set_title(str(bf_idx), fontsize="8")
            center = bf_stats["gauss_centers"][bf_idx]
            evals, _ = bf_stats["gauss_orientations"][bf_idx]
            orientations = bf_stats["fourier_centers"][bf_idx]
            angle = np.rad2deg(np.pi/2 + np.arctan2(*orientations))
            plot_ellipse(ax, center, evals, angle, color_val="b", alpha=1.0, lines=lines)
            filter_idx += 1
        ax.set_aspect("equal")
    plt.show()
    return fig

In [None]:
# Overlay each basis function's Gaussian-fit ellipse and save the figure.
fig = plot_ellipse_summaries(bf_stats, lines=False)
fig.savefig(f'{model.params.disp_dir}/basis_function_fits.png', bbox_inches='tight', transparent=False)

In [None]:
def bgr_colormap():
    """
    Build the diverging blue -> light gray -> red colormap named "bgr".

    Each channel in the segment dict is a sequence of (x, y_below, y_above) triples:
    x is the position in [0, 1] being mapped; y_below / y_above give the channel value
    approached from below / above x (unequal values would create a color break).
    All channels meet at `darkness` in the middle (0 would be black, 1 white).
    """
    darkness = 0.85
    midpoint = (0.5, darkness, darkness)
    cdict = {
        'red': ((0.0, 0.0, 0.0), midpoint, (1.0, 1.0, 1.0)),
        'green': ((0.0, 0.0, 0.0), midpoint, (1.0, 0.0, 0.0)),
        'blue': ((0.0, 1.0, 1.0), midpoint, (1.0, 0.0, 0.0)),
    }
    return LinearSegmentedColormap("bgr", cdict)

def plot_pooling_centers(bf_stats, pooling_filters, num_pooling_filters, num_connected_weights,
    spot_size=10, figsize=None):
    """
    Plot 2nd layer (fully-connected) weights in terms of spatial/frequency centers of
        1st layer weights
    For each 2nd layer neuron, two scatter panels are drawn side by side: the Gaussian
    centers (spatial) and the Fourier peak locations (frequency) of the 1st layer weights,
    colored by connection weight divided by that neuron's max |weight|.
    Inputs:
        bf_stats [dict] Output of get_dictionary_stats() which was run on the 1st layer weights
        pooling_filters [np.ndarray] 2nd layer weights
            should be shape [num_1st_layer_neurons, num_2nd_layer_neurons]
        num_pooling_filters [int] How many 2nd layer neurons to plot
        num_connected_weights [int] How many of the strongest connections (by |weight|) to
            color for each 2nd layer neuron; the rest are drawn in near-black
        spot_size [int] How big to make the points
        figsize [tuple] Containing the (width, height) of the figure, in inches
    Outputs:
        fig [matplotlib.figure.Figure] the figure that was drawn
    """
    num_filters_y = int(np.ceil(np.sqrt(num_pooling_filters)))
    num_filters_x = int(np.ceil(np.sqrt(num_pooling_filters)))
    tot_pooling_filters = pooling_filters.shape[1]
    filter_indices = np.arange(tot_pooling_filters, dtype=np.int32)
    cmap = plt.get_cmap(bgr_colormap())# Could also use "nipy_spectral", "coolwarm", "bwr"
    cNorm = matplotlib.colors.SymLogNorm(linthresh=0.03, linscale=0.01, vmin=-1.0, vmax=1.0)
    scalarMap = matplotlib.cm.ScalarMappable(norm=cNorm, cmap=cmap)
    # Public replacement for the private `scalarMap._A = []` hack so fig.colorbar accepts it.
    scalarMap.set_array([])
    x_p_cent = [x for (y,x) in bf_stats["gauss_centers"]]# Get raw points
    y_p_cent = [y for (y,x) in bf_stats["gauss_centers"]]
    x_f_cent = [x for (y,x) in bf_stats["fourier_centers"]]
    y_f_cent = [y for (y,x) in bf_stats["fourier_centers"]]
    # List concatenation: the frequency axes span the largest |component| of either axis.
    max_sf = np.max(np.abs(x_f_cent+y_f_cent))
    pair_w_gap = 0.01
    group_w_gap = 0.03
    h_gap = 0.03
    plt_w = (num_filters_x/num_pooling_filters)
    plt_h = plt_w
    if figsize is None:
        fig = plt.figure()
        figsize = (fig.get_figwidth(), fig.get_figheight())
    else:
        fig = plt.figure(figsize=figsize) #figsize is (w,h)
    axes = []
    filter_id = 0
    for plot_id in np.ndindex((num_filters_y, num_filters_x)):
        if all(pid == 0 for pid in plot_id):
            # Blank axis above the grid that hosts the shared colorbar.
            axes.append(clear_axis(fig.add_axes([0, plt_h+h_gap, 2*plt_w, plt_h])))
            cbar = fig.colorbar(scalarMap, ax=axes[-1], ticks=[-1, 0, 1], aspect=10, location="bottom")
            cbar.ax.set_xticklabels(["-1", "0", "1"])
            cbar.ax.xaxis.set_ticks_position('top')
            cbar.ax.xaxis.set_label_position('top')
            for label in cbar.ax.xaxis.get_ticklabels():
                label.set_weight("bold")
                label.set_fontsize(10+figsize[0])
        if (filter_id < num_pooling_filters):
            example_filter = pooling_filters[:, filter_indices[filter_id]]
            top_indices = np.argsort(np.abs(example_filter))[::-1] #descending
            selected_indices = top_indices[:num_connected_weights][::-1] #select top, plot weakest first
            filter_norm = np.max(np.abs(example_filter))
            if filter_norm == 0:
                # All-zero filter: avoid 0/0 (NaN colors); every connection maps to the midpoint.
                filter_norm = 1.0
            connection_colors = [scalarMap.to_rgba(example_filter[bf_idx]/filter_norm)
                for bf_idx in range(bf_stats["num_outputs"])]
            if num_connected_weights < top_indices.size:
                black_indices = top_indices[num_connected_weights:][::-1]
                xp = [x_p_cent[i] for i in black_indices]+[x_p_cent[i] for i in selected_indices]
                yp = [y_p_cent[i] for i in black_indices]+[y_p_cent[i] for i in selected_indices]
                xf = [x_f_cent[i] for i in black_indices]+[x_f_cent[i] for i in selected_indices]
                yf = [y_f_cent[i] for i in black_indices]+[y_f_cent[i] for i in selected_indices]
                c = [(0.1,0.1,0.1,1.0) for i in black_indices]+[connection_colors[i] for i in selected_indices]
            else:
                xp = [x_p_cent[i] for i in selected_indices]
                yp = [y_p_cent[i] for i in selected_indices]
                xf = [x_f_cent[i] for i in selected_indices]
                yf = [y_f_cent[i] for i in selected_indices]
                c = [connection_colors[i] for i in selected_indices]
            (y_id, x_id) = plot_id
            if x_id == 0:
                ax_l = 0
                ax_b = - y_id * (plt_h+h_gap)
            else:
                # Place this pair to the right of the previous pair's frequency panel.
                bbox = axes[-1].get_position().get_points()[0]#bbox is [[x0,y0],[x1,y1]]
                prev_l = bbox[0]
                prev_b = bbox[1]
                ax_l = prev_l + plt_w + group_w_gap
                ax_b = prev_b
            ax_w = plt_w
            ax_h = plt_h
            # Spatial panel: Gaussian centers in patch coordinates
            axes.append(clear_axis(fig.add_axes([ax_l, ax_b, ax_w, ax_h])))
            axes[-1].invert_yaxis()
            axes[-1].scatter(xp, yp, c=c, s=spot_size, alpha=0.8)
            axes[-1].set_xlim(0, bf_stats["patch_edge_size"]-1)
            axes[-1].set_ylim(bf_stats["patch_edge_size"]-1, 0)
            axes[-1].set_aspect("equal")
            axes[-1].set_facecolor("w")
            # Frequency panel: Fourier peak locations
            axes.append(clear_axis(fig.add_axes([ax_l+ax_w+pair_w_gap, ax_b, ax_w, ax_h])))
            axes[-1].scatter(xf, yf, c=c, s=spot_size, alpha=0.8)
            axes[-1].set_xlim([-max_sf, max_sf])
            axes[-1].set_ylim([-max_sf, max_sf])
            axes[-1].xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))
            axes[-1].yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))
            axes[-1].set_aspect("equal")
            axes[-1].set_facecolor("w")
            filter_id += 1
    plt.show()
    return fig

In [None]:
# Use the pooling weights at one kernel location (row = col = kernel_pos) as a
# [num_lca_neurons, num_pool_neurons] connection matrix.
# Assumes conv weight layout [outputs, inputs, kernel_h, kernel_w], as the unpack below names it.
kernel_pos = 0
pool_weights = model.pool_1.layer.weight.detach().cpu().numpy()
outputs, inputs, kernel_h, kernel_w = pool_weights.shape

fig = plot_pooling_centers(
    bf_stats=bf_stats,
    pooling_filters=np.transpose(pool_weights[:, :, kernel_pos, kernel_pos]),
    num_pooling_filters=outputs,
    num_connected_weights=inputs,
    spot_size=3,
    figsize=(5, 5))
fig.savefig(f'{model.params.disp_dir}/pooling_spots.png', bbox_inches='tight', transparent=False)

In [None]:
def plot_pooling_summaries(bf_stats, pooling_filters, num_pooling_filters,
    num_connected_weights, lines=False, figsize=None):
    """
    Plot 2nd layer (fully-connected) weights in terms of connection strengths to 1st layer weights
    Each panel is one 2nd layer neuron (one column of pooling_filters). Its strongest
    num_connected_weights connections are drawn as 1st layer weight summaries, colored by the
    signed connection strength divided by that neuron's largest absolute weight. A colorbar is
    placed in the top cell of the extra last grid column.
    Inputs:
        bf_stats [dict] output of dp.get_dictionary_stats() which was run on the 1st layer weights
        pooling_filters [np.ndarray] 2nd layer weights
            should be shape [num_1st_layer_neurons, num_2nd_layer_neurons]
        num_pooling_filters [int] How many 2nd layer neurons to plot;
            must be at most pooling_filters.shape[1]
        num_connected_weights [int] How many 1st layer weight summaries to include
            for a given 2nd layer neuron
        lines [bool] if True, 1st layer weight summaries will appear as lines instead of ellipses
        figsize [tuple or None] passed to plt.figure; None uses the matplotlib default size
    Outputs:
        fig [matplotlib.figure.Figure] the drawn figure
    """
    num_inputs = bf_stats["num_inputs"]
    tot_pooling_filters = pooling_filters.shape[1]
    # Panels index columns of pooling_filters, so that column count (not the number of
    # 1st layer neurons) is the bound that prevents an out-of-range index below
    assert num_pooling_filters <= tot_pooling_filters, (
        "num_pooling_filters must be less than or equal to pooling_filters.shape[1]")
    patch_edge_size = np.int32(np.sqrt(num_inputs))
    cmap = bgr_colormap()
    cNorm = matplotlib.colors.SymLogNorm(linthresh=0.03, linscale=0.01, vmin=-1.0, vmax=1.0)
    scalarMap = matplotlib.cm.ScalarMappable(norm=cNorm, cmap=cmap)
    num_plots_y = np.int32(np.ceil(np.sqrt(num_pooling_filters)))
    num_plots_x = num_plots_y + 1 # extra column holds the colorbar
    gs_widths = [1 for _ in range(num_plots_x-1)]+[0.3]
    gs = gridspec.GridSpec(num_plots_y, num_plots_x, width_ratios=gs_widths)
    fig = plt.figure(figsize=figsize)
    filter_total = 0
    # Visit every grid cell except the colorbar column
    for plot_id in np.ndindex((num_plots_y, num_plots_x-1)):
        ax = fig.add_subplot(gs[plot_id])
        if filter_total < num_pooling_filters:
            ax = clear_axis(ax, spines="k")
            example_filter = pooling_filters[:, filter_total]
            top_indices = np.argsort(np.abs(example_filter))[::-1] # descending
            filter_norm = np.max(np.abs(example_filter))
            if filter_norm == 0:
                # All-zero filter: avoid 0/0; every connection strength is then 0
                filter_norm = 1.0
            # Plot weakest of the top connected filters first because of occlusion
            for bf_idx in top_indices[:num_connected_weights][::-1]:
                connection_strength = example_filter[bf_idx]/filter_norm
                color_val = scalarMap.to_rgba(connection_strength)
                center = bf_stats["gauss_centers"][bf_idx]
                evals, _ = bf_stats["gauss_orientations"][bf_idx]
                orientations = bf_stats["fourier_centers"][bf_idx]
                angle = np.rad2deg(np.pi/2 + np.arctan2(*orientations))
                alpha = 0.5 # TODO: spatial_freq for filled ellipses?
                plot_ellipse(ax, center, evals, angle, color_val, alpha=alpha, lines=lines)
            ax.set_xlim(0, patch_edge_size-1)
            ax.set_ylim(patch_edge_size-1, 0)
            filter_total += 1
        else:
            ax = clear_axis(ax, spines="none")
        ax.set_aspect("equal")
    # Public API for giving the mappable an (empty) data array so colorbar accepts it
    scalarMap.set_array([])
    ax = clear_axis(fig.add_subplot(gs[0, -1]))
    cbar = fig.colorbar(scalarMap, ax=ax, ticks=[-1, 0, 1])
    cbar.ax.set_yticklabels(["-1", "0", "1"])
    for label in cbar.ax.yaxis.get_ticklabels():
        label.set_weight("bold")
        label.set_fontsize(14)
    plt.show()
    return fig

In [None]:
# Line summaries of pool_1 connectivity at the chosen kernel tap ([inputs, outputs] after .T)
fig = plot_pooling_summaries(bf_stats, pool_weights[:, :, kernel_pos, kernel_pos].T,
    num_pooling_filters=outputs, num_connected_weights=40, lines=True, figsize=(18, 18))
fig.savefig(f'{model.params.disp_dir}/pooling_lines.png', bbox_inches='tight', transparent=False)

In [None]:
# Cosine similarity between columns of the pooling weights, i.e. between 1st layer neurons
# compared through their connections to every pooling neuron at this kernel tap.
P = pool_weights[:, :, kernel_pos, kernel_pos] # [outputs, inputs]
p_norm = np.linalg.norm(P, ord=2, axis=0) # [inputs]
gram = np.dot(P.T, P) # [inputs, inputs]
norm_outer = np.outer(p_norm, p_norm)
# Inputs whose weights are all zero have no direction; give them zero affinity instead of NaN.
# The result is symmetric, so no transpose is needed.
affinity = np.divide(gram, norm_outer, out=np.zeros_like(gram), where=norm_outer > 0) # [inputs, inputs]

In [None]:
# Spot plot of the cosine-similarity affinity matrix computed above
fig = plot_pooling_centers(bf_stats, affinity, num_pooling_filters=outputs,
    num_connected_weights=128, spot_size=30, figsize=(5, 5))
fig.savefig(f'{model.params.disp_dir}/affinity_spots.png', bbox_inches='tight', transparent=False)

In [None]:
# Line summaries of the cosine-similarity affinity matrix computed above
fig = plot_pooling_summaries(bf_stats, affinity, num_pooling_filters=outputs,
    num_connected_weights=20, lines=True, figsize=(10, 10))
fig.savefig(f'{model.params.disp_dir}/affinity_lines.png', bbox_inches='tight', transparent=False)

In [None]:
example_batch = next(iter(train_loader))[0].to(model.params.device)
example_batch = model[0].preprocess_data(example_batch)
# Map the batch through the dataset std/mean images (presumably undoing the loader's
# standardization — confirm); everything below works with this de-standardized batch
example_batch *= train_std_image
example_batch += train_mean_image
batch_min = example_batch.min().item()
batch_max = example_batch.max().item()

example_image = example_batch[0, ...]
print(
    f'min = {example_image.min().item()}'+
    f'\nmean = {example_image.mean().item()}'+
    f'\nmax = {example_image.max().item()}'+
    f'\nstd = {example_image.std().item()}')

# example_image already had std/mean applied above; applying them again here would
# apply the transform twice, so plot it as-is
plot_example_image = example_image.cpu().numpy().transpose(1,2,0)
fig, ax = plot.subplots(nrows=1, ncols=1)
ax = pf.clear_axis(ax)
ax.imshow(plot_example_image, vmin=0, vmax=1)
plot.show()

In [None]:
# Forward the example image through the trained model as a batch of one; beta_2 is its output
beta_2 = model(example_image.unsqueeze(0))

In [None]:
class lca_2_recon_params(LcaParams):
    """
    Params for an fc LCA that recovers a sparse alpha_2_hat from beta_2.
    Its weights are replaced by the trained pool_2 weights after setup (below).
    """
    def set_params(self):
        super(lca_2_recon_params, self).set_params()
        self.model_type = 'lca'
        self.model_name = 'lca_2_recon'
        self.version = '0'
        self.layer_types = ['fc']
        self.standardize_data = False
        self.rescale_data_to_one = False
        self.center_dataset = False
        self.batch_size = 1
        self.dt = 0.001
        self.tau = 0.2
        self.num_steps = 75
        self.rectify_a = True
        self.thresh_type = 'hard'
        self.compute_helper_params()

# NOTE(review): this rebinds the notebook-level `params` that held the trained model's config.
params = lca_2_recon_params()
params.set_params()
# Latent size and sparsity are copied from the trained lca_2 layer
params.layer_channels = model.lca_2.params.layer_channels
params.sparse_mult = model.lca_2.params.sparse_mult
params.data_shape = list(beta_2.shape)
params.epoch_size = 1
params.num_pixels = np.prod(params.data_shape)
# NOTE(review): compute_helper_params() ran inside set_params() before the overrides above;
# confirm none of its derived values depend on them.

lca_2_recon_model = loaders.load_model(params.model_type)
lca_2_recon_model.setup(params)
lca_2_recon_model.to(params.device)
lca_2_recon_model.eval()
with torch.no_grad():
    # Use the trained pool_2 weights as this LCA's weight matrix
    lca_2_recon_model.weight = nn.Parameter(model.pool_2.weight)
    # Inference only: run it under no_grad so autograd does not record the inference graph
    # (beta_2 itself carries the trained model's graph). Singleton spatial dims are appended
    # so alpha_2_hat can be fed to conv_transpose2d.
    alpha_2_hat = lca_2_recon_model(beta_2)[:, :, None, None]

In [None]:
# Map the recovered alpha_2 code back through lca_2's weights with a transposed convolution
# to estimate beta_1; inference only, so no autograd graph is recorded.
with torch.no_grad():
    beta_1_hat = F.conv_transpose2d(
        alpha_2_hat,
        model.lca_2.weight,
        bias=None,
        padding=model.lca_2.params.padding,
        stride=model.lca_2.params.stride)

In [None]:
from DeepSparseCoding.modules.lca_module import LcaModule
from DeepSparseCoding.models.base_model import BaseModel
from DeepSparseCoding.utils.run_utils import compute_deconv_output_shape
import DeepSparseCoding.modules.losses as losses

class TransposedLcaModule(LcaModule):
    """
    LCA variant whose conv dictionary is applied as a strided convolution.
    For 'conv' layers, latents are mapped to the input with F.conv2d and the reconstruction
    error is fed back with F.conv_transpose2d (the reverse of the pairing used when a
    dictionary is applied with conv_transpose2d). Latent maps are sized with
    compute_deconv_output_shape, i.e. spatially larger than the input. This lets the
    weights of a trained strided conv layer (e.g. a pooling layer) serve as the dictionary.
    'fc' layers use plain matrix products.
    """
    def setup_module(self, params):
        super(TransposedLcaModule, self).setup_module(params)
        if self.params.layer_types[0] == 'conv':
            self.layer_output_shapes = [self.params.data_shape] # [channels, height, width]
            assert (self.params.data_shape[-1] % self.params.stride == 0), (
                f'Stride = {self.params.stride} must divide evenly into input edge size = {self.params.data_shape[-1]}')
            # get_recon_from_latents applies F.conv2d(latents, weight), which needs weight laid out
            # as [out=data channels, in=latent channels, k, k]; that is also the [in, out, k, k]
            # layout F.conv_transpose2d needs in compute_excitatory_current.
            self.w_shape = [
                self.params.data_shape[0], # data (input) channels
                self.params.layer_channels, # latent channels
                self.params.kernel_size,
                self.params.kernel_size
            ]
            output_height = compute_deconv_output_shape(
                self.layer_output_shapes[-1][1],
                self.params.kernel_size,
                self.params.stride,
                self.params.padding,
                output_padding=self.params.output_padding,
                dilation=1)
            output_width = compute_deconv_output_shape(
                self.layer_output_shapes[-1][2],
                self.params.kernel_size,
                self.params.stride,
                self.params.padding,
                output_padding=self.params.output_padding,
                dilation=1)
            self.layer_output_shapes.append([self.params.layer_channels, output_height, output_width])
        w_init = torch.randn(self.w_shape)
        w_init_normed = dp.l2_normalize_weights(w_init, eps=self.params.eps)
        self.weight = nn.Parameter(w_init_normed, requires_grad=True)

    def compute_excitatory_current(self, input_tensor, a_in, weight=None):
        """
        Drive for the latent dynamics.
        'fc': input_tensor @ weight.
        'conv': conv_transpose2d of (input_tensor - recon(a_in)) plus a_in.
        weight defaults to self.weight.
        """
        if weight is None:
            weight = self.weight
        if self.params.layer_types[0] == 'fc':
            excitatory_current = torch.matmul(input_tensor, weight)
        else:
            recon = self.get_recon_from_latents(a_in, weight)
            recon_error = input_tensor - recon
            error_injection = F.conv_transpose2d(
                input=recon_error,
                weight=weight,
                bias=None,
                stride=self.params.stride,
                padding=self.params.padding,
                output_padding=self.params.output_padding,
                dilation=1
            )
            excitatory_current = error_injection + a_in
        return excitatory_current

    def get_recon_from_latents(self, a_in, weight=None):
        """Map latents to input space ('fc': a_in @ weight^T, 'conv': strided conv2d); weight defaults to self.weight."""
        if weight is None:
            weight = self.weight
        if self.params.layer_types[0] == 'fc':
            recon = torch.matmul(a_in, torch.transpose(weight, dim0=0, dim1=1))
        else:
            recon = F.conv2d(
                input=a_in,
                weight=weight,
                bias=None,
                stride=self.params.stride,
                padding=self.params.padding,
                dilation=1
            )
        return recon

class TransposedLcaModel(BaseModel, TransposedLcaModule):
    """Model wrapper combining BaseModel with TransposedLcaModule: setup, loss, and logged stats."""

    def setup(self, params, logger=None):
        """Run base setup, build the module and optimizer, then optionally restore a checkpoint."""
        super(TransposedLcaModel, self).setup(params, logger)
        self.setup_module(params)
        self.setup_optimizer()
        boot_log = params.checkpoint_boot_log
        if boot_log == '':
            return
        checkpoint = self.get_checkpoint_from_log(boot_log)
        self.module.load_state_dict(checkpoint['model_state_dict'])
        self.module.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    def get_total_loss(self, input_tuple):
        """Half squared l2 reconstruction loss plus sparse_mult times the l1 norm of the latents."""
        input_tensor, _ = input_tuple
        codes = self.get_encodings(input_tensor)
        reconstruction = self.get_recon_from_latents(codes)
        recon_term = losses.half_squared_l2(input_tensor, reconstruction)
        sparse_term = self.params.sparse_mult * losses.l1_norm(codes)
        return recon_term + sparse_term

    def generate_update_dict(self, input_data, input_labels=None, batch_step=0, update_dict=None):
        """Add loss values, value ranges, and latent activity fractions to the base update dict."""
        if update_dict is None:
            update_dict = super(TransposedLcaModel, self).generate_update_dict(
                input_data, input_labels, batch_step)
        codes = self.get_encodings(input_data)
        reconstruction = self.get_recon_from_latents(codes)
        recon_term = losses.half_squared_l2(input_data, reconstruction).item()
        sparse_term = self.params.sparse_mult * losses.l1_norm(codes).item()

        def nonzero_count(tensor, dim):
            # TODO: github issue 23907 requests torch.count_nonzero, integrated in torch 1.7
            return torch.sum(tensor != 0, dim=dim, dtype=torch.float)

        stats = {
            'weight_lr': self.scheduler.get_lr()[0],
            'loss_recon': recon_term,
            'loss_sparse': sparse_term,
            'loss_total': recon_term + sparse_term,
            'input_max_mean_min': [
                input_data.max().item(), input_data.mean().item(), input_data.min().item()],
            'recon_max_mean_min': [
                reconstruction.max().item(), reconstruction.mean().item(), reconstruction.min().item()],
        }
        every_dim = tuple(range(codes.dim()))
        stats['fraction_active_all_latents'] = nonzero_count(codes, dim=every_dim).item() / codes.numel()
        if self.params.layer_types[0] == 'conv':
            map_size = np.prod(list(codes.shape[2:]))
            per_channel = nonzero_count(codes, dim=every_dim[2:]) / map_size
            stats['fraction_active_latents_per_channel'] = torch.mean(per_channel).item()
            per_patch = nonzero_count(codes, dim=1) / codes.shape[1]
            stats['fraction_active_latents_per_patch'] = torch.mean(per_patch).item()
        update_dict.update(stats)
        return update_dict

In [None]:
class lca_1_recon_params(LcaParams):
    """
    Params for a conv TransposedLcaModel that recovers a sparse alpha_1_hat from beta_1_hat.
    Its weights are replaced by the trained pool_1 weights after setup (below).
    """
    def set_params(self):
        super(lca_1_recon_params, self).set_params()
        self.model_type = 'lca'
        self.model_name = 'lca_1_recon'
        self.version = '0'
        self.layer_types = ['conv']
        self.standardize_data = False
        self.rescale_data_to_one = False
        self.center_dataset = False
        self.batch_size = 1
        self.dt = 0.001
        self.tau = 0.2
        self.num_steps = 75
        self.rectify_a = True
        self.thresh_type = 'hard'
        self.compute_helper_params()

# NOTE(review): this rebinds the notebook-level `params` again.
params = lca_1_recon_params()
params.set_params()
# Kernel geometry and latent channel count are copied from the trained pool_1 layer
params.layer_channels = model.pool_1.params.layer_channels[0]
params.kernel_size = model.pool_1.params.pool_ksize
params.stride = model.pool_1.params.pool_stride
params.padding = 0
params.sparse_mult = 0.01 # fixed override; commented-out alternative was model.lca_1.params.sparse_mult
params.data_shape = list(beta_1_hat.shape[1:]) # drop the batch dimension
params.epoch_size = 1
# NOTE(review): presumably chosen so the latent map matches alpha_1's spatial size — confirm
params.output_padding = 1
params.num_pixels = np.prod(params.data_shape)

lca_1_recon_model = TransposedLcaModel()
lca_1_recon_model.setup(params)
lca_1_recon_model.to(params.device)
lca_1_recon_model.eval()
with torch.no_grad():
    # Use the trained pool_1 weights as this LCA's dictionary
    lca_1_recon_model.weight = nn.Parameter(model.pool_1.weight)
    # Inference only: run it under no_grad so autograd does not record the inference graph
    alpha_1_hat = lca_1_recon_model(beta_1_hat)

In [None]:
# Decode alpha_1_hat through lca_1's weights with a transposed convolution, giving a
# reconstruction in lca_1's input space; inference only, so no autograd graph is recorded.
with torch.no_grad():
    recon = F.conv_transpose2d(
        alpha_1_hat,
        model.lca_1.weight,
        bias=None,
        padding=model.lca_1.params.padding,
        stride=model.lca_1.params.stride)

In [None]:
def fraction_nonzero(tensor):
    """Fraction of entries of `tensor` that are nonzero, returned as a 0-d float tensor."""
    # torch.count_nonzero only arrived in torch 1.7, so count via a boolean sum over all dims
    return torch.sum(tensor != 0, dtype=torch.float) / tensor.numel()

alpha_2_nnz = fraction_nonzero(alpha_2_hat)
alpha_1_nnz = fraction_nonzero(alpha_1_hat)
# .item() prints plain floats (matching the other stat printouts) instead of tensor reprs
print(
    f'beta2 shape = {beta_2.shape}' +
    f'\nalpha2^ nnz = {alpha_2_nnz.item()}' +
    f'\nalpha2^ shape = {alpha_2_hat.shape}' +
    f'\nbeta1^ shape = {beta_1_hat.shape}' +
    f'\nalpha1^ nnz = {alpha_1_nnz.item()}' +
    f'\nalpha1^ shape = {alpha_1_hat.shape}' +
    f'\nimage^ shape = {recon.shape}'
)

In [None]:
recon_summary = [
    f'recon min = {recon.min().item()}',
    f'recon mean = {recon.mean().item()}',
    f'recon max = {recon.max().item()}',
    f'recon std = {recon.std().item()}',
]
print('\n'.join(recon_summary))

# NOTE(review): this applies the dataset std/mean to recon. The model input example_image
# already had std/mean applied, so confirm recon is really in standardized space here.
destandardized_recon = recon.squeeze() * train_std_image + train_mean_image
plot_recon = destandardized_recon.cpu().numpy().transpose(1,2,0)
fig, ax = plot.subplots(nrows=1, ncols=1)
ax = pf.clear_axis(ax)
ax.imshow(plot_recon, vmin=0, vmax=1)
plot.show()

In [None]:
plot_recon = recon.squeeze().cpu().numpy().transpose(1,2,0)
# Min-max rescale to [0, 1] for display. A constant reconstruction (e.g. when every entry of
# alpha_1_hat is zero, recon is all zeros) would otherwise divide by zero and render as NaN.
recon_range = plot_recon.max() - plot_recon.min()
if recon_range > 0:
    plot_recon = (plot_recon - plot_recon.min()) / recon_range
else:
    plot_recon = np.zeros_like(plot_recon)
fig, ax = plot.subplots(nrows=1, ncols=1)
ax = pf.clear_axis(ax)
ax.imshow(plot_recon)
plot.show()