# Libraries, parameters

In [1]:
%load_ext rpy2.ipython 

import numpy as np
import pickle
import gzip
import pandas as pd

In [2]:
# Optimization params.
# Every value in this cell is interpolated into the result-file name that the
# data-loading cell opens, so it has to match the run that wrote those files.
batch_size = 256
epochs_per_task = 20
learning_rate = 1e-3
n_tasks = 10

# Optimizer tag as it appears in the result-file names.
opt_name = 'adam'

# Importance-measure tag as it appears in the result-file names.
importance_metric = "fishers"

# One result file is read per protocol string listed here.
protocols = [
    "fisher[omega_decay=sum]",
    "path_int[omega_decay=sum,xi=0.1]",
    "unregularized[replay_prior=True]",
    "unregularized[replay_prior=False]",
]

# Data loading and quick analyses

In [3]:
# For each layer's weights: how strongly do the importance measures of a weight
# on one task correlate with its importance on another task? Positive
# correlations mean weights are generally being reused across tasks; negative
# correlations mean weights are being split up between tasks without much
# knowledge sharing.
#
# `results` is a column-oriented table: one row per
# (protocol, cval, parameter array, task pair).
results = {
    'protocol': [],
    'cval': [],
    'parameter_array': [],
    'age_1': [],
    'age_2': [],
    'importance_correlation': [],
}

# Every pair of task indices with first < second. The pairing depends only on
# n_tasks, so it is built once and reused for every file and layer.
task_pairs = [(first, second)
              for first in range(n_tasks - 1)
              for second in range(first + 1, n_tasks)]

for protocol in protocols:
    filename = "data_%s_opt%s_lr%.2e_bs%i_ep%i_tsks%i_%s.pkl.gz" % (
        protocol, opt_name, learning_rate, batch_size,
        epochs_per_task, n_tasks, importance_metric)

    # NOTE(review): pickle can run arbitrary code while loading; only point
    # this at result files produced by a trusted run.
    with gzip.open(filename, 'rb') as f:
        loaded_data = pickle.load(f)

    # loaded_data maps cval -> layer -> sequence of importance arrays, indexed
    # below by task position (presumably one array per task - TODO confirm
    # against the script that writes these files).
    for cval, per_layer in loaded_data.items():
        for layer, importances in per_layer.items():
            flattened = [np.ndarray.flatten(x) for x in importances]

            # Pearson correlation between the flattened importances of each
            # task pair; [0, 1] picks the off-diagonal entry of the 2x2 matrix.
            pair_correlations = [
                np.corrcoef(flattened[first], flattened[second])[0, 1]
                for first, second in task_pairs
            ]

            n_rows = len(pair_correlations)
            results['protocol'] += [protocol] * n_rows
            results['cval'] += [cval] * n_rows
            results['parameter_array'] += [layer] * n_rows
            results['age_1'] += [first for first, _ in task_pairs]
            results['age_2'] += [second for _, second in task_pairs]
            results['importance_correlation'] += pair_correlations

In [4]:
# Turn the dict of equal-length column lists into a DataFrame so it can be
# handed to the R plotting cell (`%%R -i results`). Note that this rebinds
# `results` from a dict to a DataFrame.
results = pd.DataFrame(results)

In [13]:
%%R -i results -w 10 -h 10 --units in -r 200
# Density plots of the task-to-task importance correlations, written to two
# PNG files: one faceted by parameter array only, one additionally split by
# protocol and cval.
library(tidyverse)
theme_set(theme_bw() +
          theme(panel.grid=element_blank()))

# Keep all rows of the 'unregularized' protocols, plus rows of any protocol
# whose cval equals 0.1.
plot_data <- results %>%
    filter(grepl('unregularized', protocol) | cval==0.1)

# Layers shared by both figures; only the faceting differs between them.
density_base <- ggplot(plot_data,
                       aes(x=importance_correlation, color=protocol)) +
    geom_density() +
    scale_x_continuous(breaks=c(-0.5, 0., 0.5, 1.),
                       limits=c(-0.5, 1.)) +
    scale_color_brewer(palette="Dark2")

# One panel per parameter array, protocols overlaid.
by_method <- density_base +
    facet_grid(. ~ parameter_array)

ggsave("importance_correlations_by_method.png", plot=by_method, width=10, height=5)

# One row per protocol/cval combination, one column per parameter array.
by_method_spread <- density_base +
    facet_grid(protocol + cval ~ parameter_array)

ggsave("importance_correlations_by_method_spread.png", plot=by_method_spread, width=10, height=10)

