# Visualize compression along the irrelevant axis over training

In [3]:
import pickle
import os
import imageio
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import Normalize
from itertools import combinations

#### Function for loading results

In [4]:
def _index_to_group(n_states, groups):
    """Map each state index in range(n_states) to the position of the group containing it.

    Indices that appear in no group are omitted. If an index appears in several
    groups, the last such group wins.
    """
    return {idx: g
            for idx in range(n_states)
            for g, group in enumerate(groups)
            if idx in group}


def load_results(results_fn, rep_name, averaged=True, results_dir='../../results/'):
    """Load a pickled results file and extract the quantities used by build_gif.

    Args:
        results_fn: name of the pickled results file, joined onto results_dir
            with os.path.join (so an absolute path overrides results_dir).
        rep_name: representation key (e.g. 'hidden', 'average'). Visualization
            parameters are read from '<rep_name>_ctx' and the regression
            results from '<rep_name>'.
        averaged: if True, average every per-run quantity over runs; if False,
            keep only the first run.
        results_dir: directory containing results_fn.

    Returns:
        dict with keys:
          'n_states', 'locs'       -- taken from the first checkpoint of the first run
          'idx2c', 'idx2p'         -- state index -> C group / P group position
          'alpha_0', 'beta_0',
          'alpha_1', 'beta_1'      -- arrays of shape [n_checkpoints, n_params]
          'train_accs', 'test_accs'-- 1-D arrays of accuracies
          't_vals'                 -- array of shape [n_checkpoints]
    """
    results_path = os.path.join(results_dir, results_fn)
    # NOTE: pickle can execute arbitrary code on load; only open trusted files.
    with open(results_path, 'rb') as f:
        data = pickle.load(f)
    analysis = data['analysis']

    # params[run][checkpoint] -> visualization-parameter dict for this representation
    params = [[s['get_orth_vis_params'][rep_name + '_ctx'] for s in run]
              for run in analysis]

    # Grid metadata is read once, from the first checkpoint of the first run
    # (treated as fixed across checkpoints/runs).
    first = params[0][0]
    n_states = first['n_states']
    locs = first['locs']
    idx2c = _index_to_group(n_states, first['C_idxs'])
    idx2p = _index_to_group(n_states, first['P_idxs'])

    def collect(key):
        # -> [n_runs, n_checkpoints, n_params]
        return np.array([[p[key] for p in run] for run in params])

    alpha_0, beta_0 = collect('alpha_0'), collect('beta_0')  # context 0
    alpha_1, beta_1 = collect('alpha_1'), collect('beta_1')  # context 1

    # Per-run accuracy curves (kept as lists until reduced below)
    train_results = data['results']
    train_accs = [[s['acc'] for s in run['train_accs']] for run in train_results]
    test_accs = [[s['acc'] for s in run['test_accs']] for run in train_results]

    # t statistic of the regression with 1D rank differences, one per checkpoint.
    # NOTE(review): index 1 presumably selects the coefficient of interest
    # (index 0 likely the intercept) -- confirm against the analysis code.
    t_vals = np.array([
        [s['regression_with_1D'][rep_name]['categorical_regression']['t_statistic'][1]
         for s in run]
        for run in analysis
    ])

    if averaged:
        # Average over runs (axis 0)
        alpha_0 = np.mean(alpha_0, axis=0)
        beta_0 = np.mean(beta_0, axis=0)
        alpha_1 = np.mean(alpha_1, axis=0)
        beta_1 = np.mean(beta_1, axis=0)
        train_accs = np.mean(train_accs, axis=0)
        test_accs = np.mean(test_accs, axis=0)
        t_vals = np.mean(t_vals, axis=0)
    else:
        # Keep only the first run
        alpha_0 = alpha_0[0]
        beta_0 = beta_0[0]
        alpha_1 = alpha_1[0]
        beta_1 = beta_1[0]
        train_accs = np.array(train_accs[0])
        test_accs = np.array(test_accs[0])
        t_vals = t_vals[0]

    return {'n_states': n_states,
            'locs': locs,
            'idx2c': idx2c,
            'idx2p': idx2p,
            'alpha_0': alpha_0,
            'beta_0': beta_0,
            'alpha_1': alpha_1,
            'beta_1': beta_1,
            'train_accs': train_accs,
            'test_accs': test_accs,
            't_vals': t_vals}

#### Function for reconstructing the grid from parameters

In [5]:
def reconstruct_grid(alpha, beta, n_states, idx2c, idx2p):
    """Rebuild 2-D coordinates for every state from per-group step sizes.

    A state in C group c gets x = sum(alpha[:c]) and a state in P group p gets
    y = sum(beta[:p]); group 0 sits at 0. The points are then mean-centred.

    Returns:
        float array of shape [n_states, 2] (columns: x, y).
    """
    # Prefix sums with a leading zero: offsets[k] == sum(steps[:k])
    x_offsets = np.concatenate(([0.0], np.cumsum(alpha)))
    y_offsets = np.concatenate(([0.0], np.cumsum(beta)))

    # Group position of every state, in state-index order
    states = range(n_states)
    c_groups = np.array([idx2c[i] for i in states], dtype=int)
    p_groups = np.array([idx2p[i] for i in states], dtype=int)

    coords = np.column_stack((x_offsets[c_groups], y_offsets[p_groups]))
    return coords - coords.mean(axis=0, keepdims=True)

#### Function for building .gif

In [6]:
def _reconstruct_context(alphas, betas, n_steps, n_states, idx2c, idx2p):
    """Reconstruct one context's grid at every checkpoint.

    Returns an array of shape [n_steps, n_states, 2] whose row t is
    reconstruct_grid(alphas[t], betas[t], ...).
    """
    out = np.zeros([n_steps, n_states, 2])
    for t, (alpha_t, beta_t) in enumerate(zip(alphas, betas)):
        out[t] = reconstruct_grid(alpha_t, beta_t, n_states, idx2c, idx2p)
    return out


def _plot_context(ax, locs, points, color):
    """Scatter one context's states on ax and label each point with its location."""
    ax.scatter(points[:, 0], points[:, 1], color=color)
    for loc, m in zip(locs, points):
        ax.annotate(loc, m)


def _plot_grid_lines(ax, locs, points, color):
    """Join with dashed lines every pair of states whose locations are equal on
    one coordinate and differ by exactly 1 on the other."""
    for (loc1, m1), (loc2, m2) in combinations(zip(locs, points), 2):
        x1, y1 = loc1
        x2, y2 = loc2
        one_up = x1 - x2 == 0 and abs(y1 - y2) == 1
        one_over = y1 - y2 == 0 and abs(x1 - x2) == 1
        if one_up or one_over:
            ax.plot([m1[0], m2[0]], [m1[1], m2[1]], '--', color=color)


def build_gif(results, model_name, frames_dir='../../results/', n_hold_frames=40):
    """Render a model's training dynamics as an animated GIF.

    One frame is drawn per checkpoint t, with three panels:
      1. train/test accuracy up to and including checkpoint t,
      2. the 1-D regression t statistic up to and including checkpoint t,
      3. the grid reconstructed from (alpha, beta) at checkpoint t for both
         contexts, with dashed lines joining neighbouring grid locations.

    Args:
        results: dict as returned by load_results.
        model_name: label used (upper-cased) in the panel-3 title and in the
            frame / GIF file names.
        frames_dir: directory where the temporary PNG frames are written.
        n_hold_frames: extra copies of the last frame appended so the final
            state stays on screen longer.

    Side effects:
        Writes 'visualize_reconstructed_compression_<model_name>.gif' to the
        current working directory, then deletes the PNG frames.
    """
    n_states = results['n_states']
    locs = results['locs']
    idx2c = results['idx2c']
    idx2p = results['idx2p']
    train_accs = results['train_accs']
    test_accs = results['test_accs']
    t_vals = results['t_vals']

    # Reconstruct both contexts at every checkpoint.
    n_steps = len(results['alpha_0'])
    reconstruction_0 = _reconstruct_context(results['alpha_0'], results['beta_0'],
                                            n_steps, n_states, idx2c, idx2p)
    reconstruction_1 = _reconstruct_context(results['alpha_1'], results['beta_1'],
                                            n_steps, n_states, idx2c, idx2p)
    # [n_steps, 2*n_states, 2]: context-0 states first, then context-1 states
    reconstruction = np.concatenate([reconstruction_0, reconstruction_1], axis=1)

    # Axis limits shared by all frames so the plots do not rescale over time.
    xmin = np.min(reconstruction[:, :, 0])
    xmax = np.max(reconstruction[:, :, 0])
    ymin = np.min(reconstruction[:, :, 1])
    ymax = np.max(reconstruction[:, :, 1])
    eps = 0.1 * (np.max([xmax - xmin, ymax - ymin]))  # 10% padding
    t_vals_min = np.min(t_vals)
    t_vals_max = np.max(t_vals)

    filenames = []
    for t, M in enumerate(reconstruction):
        fig, ax = plt.subplots(3, 1,
                               figsize=[8, 12],
                               gridspec_kw={'height_ratios': [1, 1, 3]})

        # Panel 1: train vs. test accuracy up to and including checkpoint t
        # (the earlier [:t] / [t-1] version lagged the grid by one checkpoint).
        # NOTE(review): assumes one accuracy entry per checkpoint, as the
        # x-limits [0, n_steps-1] imply.
        ax[0].plot(train_accs[:t + 1], c='tab:purple')
        ax[0].plot(test_accs[:t + 1], c='tab:orange')
        ax[0].plot(t, train_accs[t], marker='o', c='tab:purple')
        ax[0].plot(t, test_accs[t], marker='o', c='tab:orange')
        ax[0].set_title("Train and test accuracy")
        ax[0].set_xlim([0, n_steps - 1])
        ax[0].set_ylim([-0.05, 1.05])
        ax[0].set_xlabel("Steps")
        ax[0].set_ylabel("Accuracy")
        # Labels pair with the first two lines drawn (the curves).
        ax[0].legend(['Train', 'Test'], loc='lower right')

        # Panel 2: t statistic for the 1D regression, up to checkpoint t
        ax[1].plot(t_vals[:t + 1], c='tab:green')
        ax[1].plot(t, t_vals[t], marker='o', c='tab:green')
        ax[1].set_title("Compression along irrelevant axis")
        ax[1].set_xlim([0, n_steps - 1])
        ax[1].set_ylim([t_vals_min, t_vals_max])
        ax[1].set_xlabel("Steps")
        ax[1].set_ylabel("T statistic")

        # Panel 3: reconstructed grid for both contexts at checkpoint t
        M_0 = M[:n_states]  # context 0
        M_1 = M[n_states:]  # context 1
        _plot_context(ax[2], locs, M_0, 'tab:red')
        _plot_context(ax[2], locs, M_1, 'tab:blue')
        ax[2].set_title("{} Representations (reconstructed)".format(model_name.upper()))
        ax[2].set_xlim([xmin - eps, xmax + eps])
        ax[2].set_ylim([ymin - eps, ymax + eps])
        ax[2].set_xticks([])
        ax[2].set_yticks([])
        # Created before the grid lines, so the labels pair with the two scatters.
        ax[2].legend(["Competence context", "Popularity context"])
        _plot_grid_lines(ax[2], locs, M_0, 'tab:red')
        _plot_grid_lines(ax[2], locs, M_1, 'tab:blue')

        plt.tight_layout()
        filename = os.path.join(frames_dir,
                                'visualize_compression_{}{}.png'.format(model_name, t))
        filenames.append(filename)
        # Hold the last frame longer by repeating it in the frame list.
        if t == n_steps - 1:
            filenames.extend([filename] * n_hold_frames)
        plt.savefig(filename, dpi=100)
        plt.close(fig)

    # Assemble the GIF; repeated frames reuse the already-decoded image.
    gif_name = 'visualize_reconstructed_compression_{}.gif'.format(model_name)
    with imageio.get_writer(gif_name, mode='I') as writer:
        prev_filename, image = None, None
        for filename in filenames:
            if filename != prev_filename:
                image = imageio.imread(filename)
                prev_filename = filename
            writer.append_data(image)

    # Remove the temporary frames.
    for filename in set(filenames):
        if os.path.isfile(filename):
            os.remove(filename)

## MLP

In [11]:
# MLP run: results file, representation key, label for the plot title and
# file names, and whether to average over runs (False -> first run only).
results_fn, rep_name, model_name, averaged = 'mlp.P', 'hidden', 'MLP0', False

In [12]:
# Load the configured MLP results and render the GIF.
results = load_results(results_fn, rep_name, averaged=averaged)
build_gif(results, model_name=model_name)

<img src="visualize_reconstructed_compression_MLP0.gif" width="750" align="center">

## RNN

In [15]:
# RNN run: results file, representation key, label for the plot title and
# file names, and whether to average over runs (False -> first run only).
results_fn, rep_name, model_name, averaged = 'rnn.P', 'average', 'RNN0', False

In [16]:
# Load the configured RNN results and render the GIF.
results = load_results(results_fn, rep_name, averaged=averaged)
build_gif(results, model_name=model_name)

<img src="visualize_reconstructed_compression_RNN0.gif" width="750" align="center">