Importing relevant packages and modules

In [None]:
%load_ext autoreload
%autoreload 2
#%matplotlib notebook
%matplotlib inline

In [None]:
import pickle
import sys
from pathlib import Path
import torch
import numpy as np
import tqdm

# Plotting
import matplotlib
from matplotlib import pyplot as plt
plt.style.use('bioAI.mplstyle')

'''
CUSTOM PACKAGES
'''
# avoid adding multiple relative paths to sys.path
sys.path.append("../src") if "../src" not in sys.path else None

from Models import SorscherRNN
from Experiment import Experiment
from datahandling import Dataset, MESampler
from plotting_functions import *
from synthetic_grid_cells import *
from methods import *
from stats import *

import utils

In [None]:
# Root directory containing the experiment data (machine-specific -- adjust per host).
base_path = "/mnt/WD12TB-HDD"

# Build the experiment handle and let it prepare its environments/agents/paths.
experiment = Experiment(name="gg-3ME", base_path=base_path)
experiment.setup()

# Box size of the first environment of the experiment.
first_environment = experiment.environments[0]
boxsize = first_environment.boxsize

In [None]:
# Cluster memberships: one array of cell indices per key in the .npz file.
# Using np.load as a context manager closes the underlying zip file; the
# arrays read out of it are ordinary in-memory ndarrays.
with np.load(f"{experiment.paths['experiment']}/module_indices_uninteresting.npz") as cluster_file:
    print(cluster_file.files)
    clusters = [cluster_file[key] for key in cluster_file.files]

# Cell indices of cluster 'C0_from_env_2' (presumably a module found in
# environment 2 -- confirm); its size sets the number of cells pruned later.
with np.load(f"{experiment.paths['experiment']}/module_indices_new.npz") as module_file:
    print(module_file.files)
    module_indices = module_file['C0_from_env_2']
ncells = len(module_indices)

In [None]:
# Remove the 8 cluster-0 cells with the highest mean firing rate.
# NOTE(review): this mutates clusters[0] and is NOT idempotent -- re-running the
# cell removes 8 *more* cells. Re-run the module-index loading cell first.
# Assumes ratemaps[1] is indexed by cell on axis 0 with spatial axes 1 and 2 -- TODO confirm.
ratemaps = utils.load_ratemaps(experiment)
mean_fire = np.nanmean(ratemaps[1][clusters[0]], axis=(1, 2))
# Positions *within* clusters[0] of the largest means, largest first.
# argsort puts NaN last, so after flipping any all-NaN cells would be picked first.
baddies = np.flip(np.argsort(mean_fire))[:8]
ratemaps = []  # free up memory
clusters[0] = np.delete(clusters[0], baddies)

In [None]:
# For each cluster: its size and how many of its cells also appear in module_indices.
for cluster in clusters:
    shared = set(cluster) & set(module_indices)
    print(len(cluster), len(shared))

In [None]:
def load_model(experiment, random_model=False, map_location='cpu'):
    """Build a SorscherRNN for `experiment`, optionally loading trained weights.

    Parameters
    ----------
    experiment : Experiment
        Supplies ``paths['checkpoints']``, ``pc_ensembles`` and
        ``params['Ng']`` / ``params['Np']``.
    random_model : bool, default False
        If True, return the freshly initialised (untrained) network and do not
        touch the checkpoint directory at all.
    map_location : str or torch.device, default 'cpu'
        Forwarded to ``torch.load`` so checkpoints saved on a GPU also load on
        CPU-only machines. ``load_state_dict`` copies the tensors into the
        model's own parameters, so the resulting model is the same.

    Returns
    -------
    SorscherRNN

    Raises
    ------
    FileNotFoundError
        If trained weights are requested but no checkpoint file exists.
    """
    model = SorscherRNN(experiment.pc_ensembles, Ng=experiment.params['Ng'], Np=experiment.params['Np'])
    if random_model:
        return model

    checkpoint_filenames = filenames(experiment.paths['checkpoints'])
    if len(checkpoint_filenames) == 0:
        raise FileNotFoundError(f"No checkpoints found in {experiment.paths['checkpoints']}")
    # Assumes filenames() returns checkpoints ordered by epoch, so the last one
    # is the most trained -- TODO confirm the ordering (lexicographic vs numeric).
    latest_path = experiment.paths['checkpoints'] / checkpoint_filenames[-1]
    print(f"Loading model at epoch = {checkpoint_filenames[-1]}", latest_path)
    checkpoint = torch.load(latest_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model

# Trained network (latest checkpoint) and an untrained network of the same
# architecture, used as the decoding-error baseline.
model, random_model = (load_model(experiment, random_model=untrained) for untrained in (False, True))

In [None]:
# detach experiment specifics
params = experiment.params
environments = experiment.environments
agents = experiment.agents
pc_ensembles = experiment.pc_ensembles
paths = experiment.paths

num_workers = 16
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"{device=}")

# Initialise data loading
num_samples = params['nsteps'] * params['batch_size'] # * params['nepochs']
dataset = Dataset(agents = agents, pc_ensembles = pc_ensembles, num_samples = num_samples, seq_len=20)#, **params)
datasampler = eval(params['sampler'])(num_environments = len(environments), num_samples = num_samples, \
                                      num_epochs = params['nepochs'])
dataloader = torch.utils.data.DataLoader(dataset, batch_size=params['batch_size'], sampler = datasampler, num_workers=num_workers)

In [None]:
def pred_inference(model, inputs, labels, positions, indices, trajectory_slice=None):
    """Run `model` on one batch and return its position-decoding error as a float.

    Parameters
    ----------
    model
        Object exposing ``model(inputs, log_softmax=True)``, ``device``,
        ``dtype``, ``position_error(...)`` and ``place_cell_ensembles``
        (e.g. ``SorscherRNN``).
    inputs
        Batch passed unchanged to ``model``.
    labels
        Unused; accepted so a full dataloader batch can be forwarded.
    positions : torch.Tensor
        Positions with time on dim 1; cast to ``model.device`` / ``model.dtype``.
    indices
        Per-sample indices forwarded (as ``np.ndarray``) to ``position_error``.
    trajectory_slice : slice, optional
        Time steps of the predictions to score. Defaults to
        ``slice(0, positions.shape[1])``. Positions are sliced with one extra
        trailing step (``stop + 1``); an open-ended ``stop=None`` stays open-ended.

    Returns
    -------
    float
        ``position_error(...).item()``.
    """
    if trajectory_slice is None:
        trajectory_slice = slice(0, positions.shape[1])
    stop = trajectory_slice.stop
    position_slice = slice(trajectory_slice.start, None if stop is None else stop + 1)
    indices = np.array(indices)
    # Pure evaluation: no gradients are needed, so don't build an autograd graph.
    with torch.no_grad():
        log_predictions = model(inputs, log_softmax=True)
        positions = positions.to(model.device, dtype=model.dtype)
        pred_error = model.position_error(
            log_predictions[:, trajectory_slice],
            positions[:, position_slice],
            indices,
            model.place_cell_ensembles,
        )
    return pred_error.item()

In [None]:
def _prune_step_size(total, nsteps, step):
    """Cells to prune at `step` so `total` cells are spread over `nsteps` steps (first `total % nsteps` steps get one extra)."""
    base, remainder = divmod(total, nsteps)
    return base + 1 if step < remainder else base


def _grow_prune_set(pruned, candidates, nprune):
    """Return `pruned` extended by up to `nprune` randomly drawn, not-yet-pruned `candidates` (as an int array)."""
    remaining = list(set(candidates) - set(pruned))
    chosen = np.random.choice(remaining, size=min(nprune, len(remaining)), replace=False)
    return np.append(pruned, chosen).astype(int)


def prune_model(model, dataloader, module_indices, clusters, nsteps=40,
                trajectory_slice=slice(19, 20), baseline_model=None):
    """Progressively prune random cells of each cluster and record decoding errors.

    Each step consumes one batch from `dataloader`. On that batch the decoding
    error is measured for:
      * the unpruned `model`,
      * `model` with each cluster's cumulative pruned set masked out via
        ``model.prune_mask`` (one series per cluster),
      * `baseline_model` (its prune mask is not touched here).
    Before measuring, every cluster's pruned set grows by
    ``len(module_indices) / nsteps`` cells (rounded so the total over all steps
    is ``len(module_indices)``), capped by the cells the cluster has left.
    Cells are drawn with the global NumPy RNG (``np.random``); seed it for
    reproducible runs.

    Parameters
    ----------
    model
        Network with a settable ``prune_mask`` attribute, usable by ``pred_inference``.
    dataloader
        Iterable of ``(inputs, labels, positions, indices)`` batches.
    module_indices
        Its length sets the total number of cells pruned per cluster.
    clusters
        Sequence of cell-index arrays, one per cluster.
    nsteps : int
        Number of pruning steps (batches consumed, at most).
    trajectory_slice : slice
        Time steps scored by ``pred_inference``; the default ``slice(19, 20)``
        is the last step of the 20-step sequences built in this notebook.
    baseline_model : optional
        Model evaluated at every step as a reference. Defaults to the
        notebook-level ``random_model`` (the untrained network).

    Returns
    -------
    tuple of lists
        ``(pe_true, pe_cluster_0, ..., pe_cluster_{K-1}, pe_random_model)``,
        one float per completed step each (5 lists for three clusters).

    ``model.prune_mask`` is reset to ``[]`` (unpruned) on exit, even on error.
    """
    if baseline_model is None:
        # Backward-compatible fallback to the notebook-level untrained model.
        baseline_model = random_model

    ncells_total = len(module_indices)
    pe_true = []
    pe_random_model = []
    pe_clusters = [[] for _ in clusters]
    # Cumulative pruned cell indices per cluster ("bag" of already-pruned cells).
    pruned = [np.array([]) for _ in clusters]

    try:
        # zip stops on range first, so no batch beyond `nsteps` is fetched.
        for step, (inputs, labels, positions, indices) in zip(range(nsteps), dataloader):
            nprune = _prune_step_size(ncells_total, nsteps, step)

            # Reference: the trained model without pruning.
            model.prune_mask = []
            pe_true.append(pred_inference(model, inputs, labels, positions, indices, trajectory_slice))

            # Clusters are processed in order, so random draws happen cluster 0, 1, 2, ...
            for k, cluster in enumerate(clusters):
                pruned[k] = _grow_prune_set(pruned[k], cluster, nprune)
                model.prune_mask = pruned[k]
                pe_clusters[k].append(pred_inference(model, inputs, labels, positions, indices, trajectory_slice))

            pe_random_model.append(pred_inference(baseline_model, inputs, labels, positions, indices, trajectory_slice))
    finally:
        # Don't leave the shared model silently pruned with the last cluster's mask.
        model.prune_mask = []

    return (pe_true, *pe_clusters, pe_random_model)

In [None]:
def prune_stats(nstats, *args, **kwargs):
    """Repeat ``prune_model(*args, **kwargs)`` `nstats` times and stack the results.

    Each repetition draws its own random pruning order, so the stack gives a
    distribution of decoding errors per pruning step.

    Returns
    -------
    np.ndarray
        Shape ``(nstats, n_series, nsteps)``, where ``n_series`` is the number
        of lists ``prune_model`` returns (5: full model, three clusters,
        untrained model).
    """
    runs = [np.array(prune_model(*args, **kwargs)) for _ in tqdm.trange(nstats)]
    return np.array(runs)

In [None]:
nsteps = 60          # pruning steps per run
nstats = 30          # independent pruning runs used for the error statistics
load_stats = False   # True: reload the cached result instead of recomputing

# Computing the statistics takes long, so the result is cached on disk.
# `pickle` must be imported at the top of the notebook (it previously was not).
# NOTE: pickle.load can execute arbitrary code -- only load files you produced yourself.
pruning_errors_path = experiment.paths['experiment'] / "pruning_errors_uninteresting_clusters.pkl"

if load_stats:
    with open(pruning_errors_path, "rb") as f:
        pruning_errors = pickle.load(f)
else:
    pruning_errors = prune_stats(nstats, model, dataloader, module_indices, clusters, nsteps=nsteps)
    with open(pruning_errors_path, "wb") as f:
        pickle.dump(pruning_errors, f)

In [None]:
# Panel size from set_size(width=345, mode='tall'); the pruning figure keeps the
# full width and one third (2/6) of the height.
panesize = set_size(width=345, mode='tall')
figsize = np.array([panesize[0], panesize[1] * 2 / 6])

In [None]:
fig, ax = plt.subplots(figsize=figsize)

# x = cumulative number of cells pruned when each point is measured. prune_model
# prunes int(ncells / nsteps) cells (+1 in the first ncells % nsteps steps) *before*
# measuring, so the first point already has one step's worth pruned (not 0).
# These are target counts; a cluster smaller than ncells stops growing once exhausted.
x_ticks = np.cumsum([ncells // nsteps + (step < ncells % nsteps) for step in range(nsteps)])
labels = ['Full Model', 'Cluster1', 'Cluster2', 'Cluster3', 'Full Untrained Model']

# Statistics across independent runs (axis 0).
mean_error = np.mean(pruning_errors, axis=0)
std_error = np.std(pruning_errors, axis=0)

mean_lines = []
for series, label in enumerate(labels):  # plotting order picks the colour from the colour cycler
    line, = ax.plot(x_ticks, mean_error[series], label=label)
    ax.fill_between(x_ticks, mean_error[series] + std_error[series],
                    mean_error[series] - std_error[series], alpha=0.1)
    mean_lines.append(line)

# Individual runs of the first cluster (series 1), coloured like its mean curve.
cluster1_runs = pruning_errors[:, 1]
ax.scatter(np.broadcast_to(x_ticks, cluster1_runs.shape), cluster1_runs,
           s=5, alpha=0.5, c=mean_lines[1].get_color())

ax.legend()
ax.set_xlabel('#Pruned Cells')
ax.set_ylabel('Decoding Error')

# Reference cell counts (meaning not documented here -- TODO confirm).
for n_pruned in (20, 50, 200):
    ax.axvline(n_pruned, ls=':')

fig.savefig(experiment.paths['experiment'] / 'plots/pruning_uninteresting_clusters')


### Investigate pruning distribution

In [None]:
# Series index along axis 1 of pruning_errors (1 = 'Cluster1' in the plot labels).
prune_type_id = 1

# Statistics across runs (axis 0), computed once instead of once per step.
hist_runs = pruning_errors[:, prune_type_id]
hist_mean = np.mean(pruning_errors, axis=0)[prune_type_id]
hist_std = np.std(pruning_errors, axis=0)[prune_type_id]
hist_median = np.median(pruning_errors, axis=0)[prune_type_id]
hist_mad = mad(pruning_errors, axis=0)[prune_type_id]  # `mad` from the project's star imports

for step in range(pruning_errors.shape[-1]):
    fig, ax = plt.subplots()
    ax.hist(hist_runs[:, step])

    # Solid: mean (green) +- std (red); dotted: median (green) +- mad (red).
    ax.axvline(hist_mean[step], color='green')
    ax.axvline(hist_mean[step] + hist_std[step], color='red')
    ax.axvline(hist_mean[step] - hist_std[step], color='red')

    ax.axvline(hist_median[step], ls=':', color='green')
    ax.axvline(hist_median[step] + hist_mad[step], ls=':', color='red')
    ax.axvline(hist_median[step] - hist_mad[step], ls=':', color='red')

    ax.set_title(f'Pruning step {step}')
    ax.set_xlabel('Decoding Error')
    ax.set_ylabel('#Runs')
    # Render now; with the inline backend this also releases the figure, so
    # dozens of figures don't accumulate in memory.
    plt.show()

In [None]:
# Per-step (min, max) over runs of the first cluster's decoding error (series 1).
cluster1_errors = pruning_errors[:, 1]
cluster1_errors.min(axis=0), cluster1_errors.max(axis=0)

In [None]:
# Legacy comparison plot from an earlier pruning analysis. The pe_* series it
# draws are not created anywhere in this notebook (prune_model keeps its lists
# local), so plot only when they exist; otherwise report what is missing instead
# of failing with NameError on "Run All".
legacy_series = {
    'pe_true': 'Full Model',
    'pe_random': 'Random Pruning',
    'pe_gcs': 'High GCS Pruning',
    'pe_random_torus': 'Random Torus Pruning',
    'pe_random_inverse_torus': 'Random Inverse Torus Pruning',
}
missing_series = [name for name in legacy_series if name not in globals()]
if missing_series:
    print(f"Skipping legacy pruning plot; undefined: {missing_series}")
else:
    fig, ax = plt.subplots()
    x_ticks = np.linspace(0, ncells, len(globals()['pe_true']))
    for name, label in legacy_series.items():
        ax.plot(x_ticks, globals()[name], label=label)
    ax.legend()
    ax.set_xlabel('#Pruned Cells')
    ax.set_ylabel('Decoding Error')
    # Saved inside the experiment directory only (no machine-specific absolute paths).
    fig.savefig(experiment.paths['experiment'] / 'plots/pruning')

In [None]:
# NEXT
# -- DONE -- prune with high GCS
# -- DONE -- prune toroid cells sorted by phase - physics phase transition?
# add error shadings to the graphs.
# include adversarial attack?
# -- DONE -- prune inverse of toroid cells. Does path integration remain? Are ratemaps still grids? Is the toroid still there?
# select phases based on e.g. the right side of the box.

# Include a randomly initialised network without pruning to show the baseline decoding error