# Libraries, parameters

In [1]:
%load_ext rpy2.ipython 

import numpy as np
import pickle
import gzip
import pandas as pd

In [3]:
# Optimization / experiment parameters. All of these are interpolated into
# the data filenames loaded below, so they must match the values used when
# those files were written.
batch_size = 256
epochs_per_task = 20
learning_rate = 1e-3
n_tasks = 10
opt_name = 'adam'

# Suffix of the importance-data filenames.
importance_metric = "fishers"

# Protocols to compare; each string appears verbatim in the data filenames.
protocols = [
    "fisher[omega_decay=sum]",
    "path_int[omega_decay=sum,xi=0.1]",
    "unregularized[replay_prior=True]",
    "unregularized[replay_prior=False]",
]

# Data loading and quick analyses of importance correlations

In [4]:
# For each layer's weights: how correlated are a weight's importance
# measures on one task and on another? Positive correlations mean weights
# are generally being reused; negative means weights are being split up
# between tasks without much knowledge sharing.
results = {column: [] for column in ('protocol',
                                     'cval',
                                     'parameter_array',
                                     'age_1',
                                     'age_2',
                                     'importance_correlation')}

# Every task-index pair (age_1, age_2) with age_1 < age_2, ordered
# (0, 1), (0, 2), ..., (n_tasks - 2, n_tasks - 1).
age_pairs = [(first, second)
             for first in range(n_tasks - 1)
             for second in range(first + 1, n_tasks)]

for protocol in protocols:
    filename = ("data_%s_opt%s_lr%.2e_bs%i_ep%i_tsks%i_%s.pkl.gz"
                % (protocol, opt_name, learning_rate, batch_size,
                   epochs_per_task, n_tasks, importance_metric))
    # NOTE(review): pickle.load can execute arbitrary code; only load data
    # files produced by this project.
    with gzip.open(filename, 'rb') as f:
        loaded_data = pickle.load(f)
        for cval, importances_by_layer in loaded_data.items():
            for layer, layer_importances in importances_by_layer.items():
                # One flattened importance vector per task (indexed by age).
                flat = [np.ndarray.flatten(x) for x in layer_importances]
                pair_correlations = [np.corrcoef(flat[a1], flat[a2])[0, 1]
                                     for a1, a2 in age_pairs]
                n_rows = len(pair_correlations)
                results['protocol'] += [protocol] * n_rows
                results['cval'] += [cval] * n_rows
                results['parameter_array'] += [layer] * n_rows
                results['age_1'] += [a1 for a1, _ in age_pairs]
                results['age_2'] += [a2 for _, a2 in age_pairs]
                results['importance_correlation'] += pair_correlations

In [5]:
# Turn the column -> list mapping into a DataFrame so the R cell below can
# import it with `%%R -i results`.
results = pd.DataFrame.from_dict(results)

In [41]:
%%R -i results -w 10 -h 10 --units in -r 200
library(tidyverse)
theme_set(theme_bw() +
          theme(panel.grid=element_blank()))

# Regularization strength (cval) shown for the regularized protocols;
# rows from the unregularized protocols are kept regardless of cval.
plot_cval <- 0.1

plot_data <- results %>%
  filter(grepl('unregularized', protocol) | cval == plot_cval)

# Shared layers for the density plots of pairwise importance correlations.
# NOTE(review): scale_x_continuous(limits=) drops correlations < -0.5 before
# the density is estimated; switch to coord_cartesian(xlim=) if those
# values should contribute to the curves.
density_base <- ggplot(plot_data,
                       aes(x=importance_correlation, color=protocol)) +
    geom_density() +
    scale_x_continuous(breaks=c(-0.5, 0., 0.5, 1.),
                       limits=c(-0.5, 1.)) +
    scale_color_brewer(palette="Dark2")

# Plots are kept as named objects and passed to ggsave() explicitly instead
# of relying on last_plot(), and printed so they render in the notebook.

# One panel per parameter array, all protocols overlaid.
p_by_method <- density_base + facet_grid(. ~ parameter_array)
print(p_by_method)
ggsave("importance_correlations_by_method.png", p_by_method, width=10, height=5)

# Same densities, one row per protocol/cval combination.
p_spread <- density_base + facet_grid(protocol + cval ~ parameter_array)
print(p_spread)
ggsave("importance_correlations_by_method_spread.png", p_spread, width=10, height=10)

# Mean correlation for each (age_1, age_2) task pair, shown as a heatmap.
age_summary <- plot_data %>%
  group_by(age_1, age_2, protocol, cval, parameter_array) %>%
  summarize(mean_importance_correlation=mean(importance_correlation)) %>%
  ungroup()

p_age <- ggplot(age_summary,
                aes(x=age_1, y=age_2, fill=mean_importance_correlation)) +
    geom_raster() +
    facet_grid(protocol + cval ~ parameter_array) +
    # Pin the colour range to the full correlation range [-1, 1] so that 0
    # falls on the middle of the diverging palette. (`values=` would expect
    # per-colour positions in [0, 1], not data limits.)
    scale_fill_distiller(palette="RdBu", limits=c(-1, 1))
print(p_age)
ggsave("importance_correlations_by_method_age.png", p_age, width=15, height=10)


# Data loading and quick analyses of performance

In [18]:
# Long-format accuracy table: one row per (protocol, cval, age, task), read
# from raw_accuracies[age][task] in each protocol's data file.
# Assumes each raw_accuracies is indexable as [age][task] with at least
# n_tasks x n_tasks entries.
# NOTE(review): "age" presumably counts how many tasks had been trained when
# the accuracy was measured -- confirm against the training script.
#
# There is no 'parameter_array' column: accuracies are not per-layer, and
# the only value available here would be a `layer` left over from an
# earlier cell (a NameError on a fresh kernel).
perf_results = {'protocol': [],
                'cval': [],
                'age': [],
                'task': [],
                'accuracy': []}

for protocol in protocols:
    filename = "data_%s_opt%s_lr%.2e_bs%i_ep%i_tsks%i.pkl.gz" % (protocol, opt_name, learning_rate, batch_size, epochs_per_task, n_tasks)
    # NOTE(review): pickle.load can execute arbitrary code; only load data
    # files produced by this project.
    with gzip.open(filename, 'rb') as f:
        loaded_data = pickle.load(f)
    # The file is no longer needed once unpickled, so process outside `with`.
    for cval, raw_accuracies in loaded_data.items():
        for age in range(n_tasks):
            for task in range(n_tasks):
                perf_results['protocol'].append(protocol)
                perf_results['cval'].append(cval)
                perf_results['age'].append(age)
                perf_results['task'].append(task)
                perf_results['accuracy'].append(raw_accuracies[age][task])

In [19]:
# Turn the column -> list mapping into a DataFrame so the R cell below can
# import it with `%%R -i perf_results`.
perf_results = pd.DataFrame.from_dict(perf_results)

In [28]:
%%R -i perf_results -w 10 -h 10 --units in -r 200
library(tidyverse)
theme_set(theme_bw() +
          theme(panel.grid=element_blank()))

# Regularization strength (cval) shown for the regularized protocols;
# rows from the unregularized protocols are kept regardless of cval.
plot_cval <- 0.1

# Accuracy vs. age, one panel per task. Plots are printed so they render in
# the notebook, then saved explicitly.
g <- ggplot(perf_results %>%
              filter(grepl('unregularized', protocol) | cval == plot_cval),
            aes(x=age, y=accuracy, color=protocol)) +
    geom_line() +
    geom_point() +
    facet_wrap(~ task) +
    scale_color_brewer(palette="Dark2")
print(g)
ggsave("accuracy_by_method.png", g, width=8, height=8)

# Zoom with coord_cartesian() rather than ylim(): ylim() sets scale limits,
# which discards every row with accuracy outside [0.9, 1] before drawing
# (breaking the lines), whereas coord_cartesian() only changes the visible
# window.
g2 <- g + coord_cartesian(ylim=c(0.9, 1))
print(g2)
ggsave("accuracy_by_method_zoomed_in.png", g2, width=8, height=8)

