# Libraries, parameters

In [1]:
%load_ext rpy2.ipython 

import numpy as np
import pickle
import gzip
import pandas as pd

In [8]:
# Optimization params
batch_size = 256
epochs_per_task = 20 
learning_rate=1e-3
n_tasks = 10

opt_name = 'adam'

importance_metric = "fishers"

protocols = ["fisher[omega_decay=sum]", "path_int[omega_decay=sum,xi=0.1]", "unregularized[replay_prior=True]"]

# Data loading and quick analyses

In [14]:
# for each layers weights, what is the correlation between the importance
# measures of a weight on one task and another? Positive correlations mean
# weights are generally being reused, negative means weights are being
# split up between tasks without much knowledge sharing.
results = {'protocol': [],
           'cval': [],
           'parameter_array': [],
           'age_1': [],
           'age_2': [],
           'importance_correlation': []
          }

for protocol in protocols:
    filename = "data_%s_opt%s_lr%.2e_bs%i_ep%i_tsks%i_%s.pkl.gz" % (protocol, opt_name, learning_rate, batch_size, epochs_per_task, n_tasks, importance_metric)
    with gzip.open(filename, 'rb') as f:
        loaded_data = pickle.load(f)
        for cval in loaded_data.keys():
            for layer in loaded_data[cval].keys():
                these_importances = [np.ndarray.flatten(x) for x in loaded_data[cval][layer]]
                age_1s = []
                age_2s = []
                correlations = []
                for age_1 in range(n_tasks - 1):
                    for age_2 in range(age_1 + 1, n_tasks):
                        correlation = np.corrcoef(these_importances[age_1], these_importances[age_2])[0, 1]
                        correlations.append(correlation)
                        age_1s.append(age_1)
                        age_2s.append(age_2)
                results['protocol'].extend([protocol] * len(correlations))
                results['cval'].extend([cval] * len(correlations))
                results['parameter_array'].extend([layer] * len(correlations))
                results['age_1'].extend(age_1s)
                results['age_2'].extend(age_2s)
                results['importance_correlation'].extend(correlations)
                
                
                        
                

0.35967755841233884
0.26689410383122414
0.17361847087171406
0.11417570525526817
0.0850348621387023
0.09726620985178687
0.059262991326426755
0.11333664464724755
0.044279390736038114
0.5802292781548332
0.4637174462979327
0.36091261132105246
0.3007417181996549
0.22399117270839386
0.1877759508767373
0.22634556744523202
0.1404086261387641
0.7252311497560073
0.5371289812852302
0.40466290307251207
0.35453789452264245
0.25938856275482997
0.254885387309397
0.26026380796350573
0.6450481930923802
0.5046223201441234
0.45614863320428506
0.3775065843369706
0.3240585285489944
0.3507928566728867
0.6788596316986153
0.6023362712613096
0.5344183266445138
0.41420212421693076
0.3719762309640678
0.7013353284012446
0.5616500026004827
0.525435294371921
0.441350737056976
0.7290602468559432
0.5298020639031785
0.5479583966410926
0.5936757776935255
0.6057787797497971
0.6292068444899195
0.43642777559673124
0.2826958396595265
0.20927971202623108
0.18370288850200553
0.16071397309323093
0.1408931772686264
0.096949915

0.01559837401927073
0.022800300018678783
0.0003413062299527916
0.0032388068657818177
-0.011934272624504002
-0.013402945250444501
-0.006592946063673171
0.03221974784632404
0.0062066195333032375
0.005203184869834186
-0.002601773048181985
0.006779286980100714
-0.011455227859342972
0.027687669546723092
0.0007347414562756312
0.010752333872302922
0.005651134934060423
-0.005216743368904186
0.011165653530688794
0.0011977726319804187
0.019042892545673307
-0.005737965920244608
0.022101665354694097
0.0009500763041940175
0.005371401964044165
0.03490073715242773
0.014208134039625037
0.04427636425696857
0.038106322178196934
0.03205842832267599
-0.013085270317976006
-0.025916304507199082
-0.025197277045325362
-0.003540312467714309
-0.017052064545180136
-0.03715620336953363
0.018133253092895728
0.12439303231213393
0.030290949191192153
0.018000986502473178
0.019621885869390975
-0.020245242212111454
-0.008638313682631493
0.010811823068397399
-0.03188847521517478
0.05640745615773351
0.07044170072976266
-

0.044567979787149894
0.03996098379995552
0.01455016116388051
-0.00749803184442316
-0.00847186643799987
-0.00940555668798826
-0.007109871338639527
-0.005908693843772749
-0.004567860188916915
0.1818351337820508
0.031767232055864424
0.02056832882739361
0.01583069889387904
0.016213287975258842
0.001098489620811984
0.0003114385497650741
-0.0018261454360289863
0.08380939411478204
0.03290965315881829
0.021159947737731524
0.022079353776050803
0.0016538500194475238
0.00385432790877418
0.0007947164260868624
0.08913876322833877
0.060136635508305385
0.027793546746425583
0.008055834165841281
0.010957092109810796
0.002110240657332501
0.08934879766135573
0.05920537596113779
0.026452029337677235
0.01774614875878861
0.007877799150774793
0.09167651846645636
0.04477506619162689
0.02727824991107178
0.012217105894571268
0.052696452871137216
0.0535952621726638
0.02223872694298921
0.11937469261702643
0.027337347402876536
0.07511365506801669
0.007476054816052869
0.005186062612790776
0.010186468923257828
-0.01

0.02362041103069926
0.021227235655796726
0.03867719262532816
0.02035371723683202
0.008340569713167977
0.49903655728911467
0.3683433647170554
0.23909781103131514
0.27195402498950955
0.26301307289252296
0.2812723947830926
0.19769979987853706
0.17703026222210935
0.39178886399212925
0.2897208288974529
0.3429383381772517
0.35900618658226546
0.3583559646768601
0.25501370305390114
0.27667942059814626
0.439792696931212
0.3844264294719928
0.3066660586115615
0.3334012262589054
0.32372397506445905
0.190292706365638
0.41425263649538663
0.2945027856369312
0.2936601776858325
0.3192844272963517
0.19920101151909358
0.4426334094366919
0.40618877758033983
0.3715676016833355
0.2924483100288557
0.5417857173458975
0.3644214938274415
0.43777885964988733
0.46265774756725897
0.48893010629119943
0.4340207663808264
0.03661043184875744
0.0049622671651067905
0.01956192395225186
0.015521591655825734
0.00291129336698141
-3.732392962777488e-05
-0.008434496209277536
-0.006192228679047056
-0.011919541194820459
0.03426

In [12]:
print(results)

{'age_1': [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 7, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 7, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 7, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 7, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 7, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 7, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 7, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,