In [1]:
# Enable the autoreload extension in mode 2 so imported modules are reloaded before each cell runs.
%load_ext autoreload
%autoreload 2
# Show matplotlib figures inline in the cell output.
%matplotlib inline

In [2]:
# Authenticate with Weights & Biases without embedding the secret in the notebook.
# The key is taken from the WANDB_API_KEY environment variable; when it is unset,
# `key=None` lets wandb fall back to its own credential lookup.
# NOTE(security): an API key used to be hard-coded on this line. It has been shared with
# the notebook and should be treated as compromised -- revoke it in the W&B settings.
import os
import wandb

wandb.login(key=os.environ.get('WANDB_API_KEY'))

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/guydavidson/.netrc
[32mSuccessfully logged in to Weights & Biases![0m


In [4]:
import numpy as np
import pandas as pd
import scipy
from scipy import stats
from scipy.special import factorial

from mpl_toolkits.mplot3d import Axes3D
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import patches
from matplotlib import path as mpath

import pickle
import tabulate
import wandb
from collections import namedtuple
import sys

import meta_learning_data_analysis as analysis
import meta_learning_analysis_plots as plots

In [5]:
# Fetch the current analysis cache (no new entries are passed in) and list the keys it holds.
cache = analysis.refresh_cache()
print(cache.keys())

dict_keys(['six_replications_analyses', 'control_analyses', 'query_mod_replications', 'six_replications_updated_analyses', 'updated_control_analyses', 'query_mod_updated_analyses', 'forgetting_curves_raw_data', 'preliminary_maml_analyses', 'baseline_maml_comparison_analyses', 'maml_analyses', 'maml_alpha_0_analyses', 'maml_meta_test_analyses', 'balanced_batches_analyses', 'baseline_total_curve_analyses', 'control_total_curve_analyses', 'query_mod_total_curve_analyses', 'simultaneous_training_analyses', 'per_task_simultaneous_training_analyses', 'task_conditional_analyses', 'task_conditional_multiplicative_only_analyses', 'task_conditional_additive_only_analyses'])


-------

# Baseline analyses

In [None]:
# Baseline analyses, one entry per condition field. The runs are only loaded and
# processed when the cache does not already hold the result.
if 'six_replications_analyses' not in cache:
    six_replications_by_dimension_runs = analysis.load_runs(60)
    print('Loaded runs')

    # Pair each run set with its field name, in the order given by CONDITION_ANALYSES_FIELDS.
    six_reps_dict = {}
    for run_set, dimension_name in zip(six_replications_by_dimension_runs,
                                       analysis.CONDITION_ANALYSES_FIELDS):
        six_reps_dict[dimension_name] = analysis.process_multiple_runs(run_set)

    six_replications_analyses = analysis.ConditionAnalysesSet(**six_reps_dict)

    # Store the freshly computed result and pick up the refreshed cache.
    cache = analysis.refresh_cache({'six_replications_analyses': six_replications_analyses})

else:
    six_replications_analyses = cache['six_replications_analyses']


In [None]:
# Baseline analyses re-parsed with parse_run_results_with_new_task_accuracy_and_equal_size.
# There is no cache lookup here: the runs are reloaded and reprocessed on every execution,
# and the result overwrites the 'six_replications_updated_analyses' cache entry.
six_replications_by_dimension_runs = analysis.load_runs(60)
print('Loaded runs')

# note: the equal accuracy field will come in as accuracy_drops
updated_six_reps_dict = {}

# A start_index above 0 skips that many leading run sets / field names.
start_index = 0
pending_run_sets = six_replications_by_dimension_runs[start_index:]
pending_field_names = analysis.CONDITION_ANALYSES_FIELDS[start_index:]

for run_set, dimension_name in zip(pending_run_sets, pending_field_names):
    updated_six_reps_dict[dimension_name] = analysis.process_multiple_runs(
        run_set,
        parse_func=analysis.parse_run_results_with_new_task_accuracy_and_equal_size)

six_replications_updated_analyses = analysis.ConditionAnalysesSet(**updated_six_reps_dict)

cache = analysis.refresh_cache(
    {'six_replications_updated_analyses': six_replications_updated_analyses})


In [None]:
# Total training-curve statistics for the baseline runs, one entry per condition field.
# There is no cache lookup here: the runs are reloaded and reprocessed on every execution,
# and the result overwrites the 'baseline_total_curve_analyses' cache entry.
six_replications_by_dimension_runs = analysis.load_runs(60)
print('Loaded runs')

analyses_per_dimension = {}

for run_set, dimension_name in zip(six_replications_by_dimension_runs,
                                   analysis.CONDITION_ANALYSES_FIELDS):
    print(f'Starting {dimension_name}')

    # The helper returns four values: raw curves, mean, std and sem.
    curve_stats = analysis.process_multiple_runs_total_task_training_curves(run_set)
    total_curve_raw, total_curve_mean, total_curve_std, total_curve_sem = curve_stats

    analyses_per_dimension[dimension_name] = analysis.TotalCurveResults(
        raw=total_curve_raw,
        mean=total_curve_mean,
        std=total_curve_std,
        sem=total_curve_sem)

total_curve_analyses = analysis.ConditionAnalysesSet(**analyses_per_dimension)
cache = analysis.refresh_cache({'baseline_total_curve_analyses': total_curve_analyses})


-------

# Control analyses

In [None]:
# Control-condition analyses over the combined run set; computed only on a cache miss.
if 'control_analyses' not in cache:
    control_runs = analysis.load_runs(150, 'meta-learning-scaling/sequential-benchmark-control', False)
    print('Loaded runs')

    # Only the `combined` field of the analyses set is populated for the control condition.
    combined_control_results = analysis.process_multiple_runs(control_runs.combined)
    control_analyses = analysis.ConditionAnalysesSet(combined=combined_control_results)

    cache = analysis.refresh_cache({'control_analyses': control_analyses})

else:
    control_analyses = cache['control_analyses']

In [None]:
# Control-condition analyses re-parsed with parse_run_results_with_new_task_accuracy_and_equal_size.
# There is no cache lookup here: the runs are reloaded and reprocessed on every execution,
# and the result overwrites the 'updated_control_analyses' cache entry.
control_runs = analysis.load_runs(150, 'meta-learning-scaling/sequential-benchmark-control', False)
print('Loaded runs')

# Only the `combined` field of the analyses set is populated for the control condition.
combined_updated_control_results = analysis.process_multiple_runs(
    control_runs.combined,
    parse_func=analysis.parse_run_results_with_new_task_accuracy_and_equal_size)
updated_control_analyses = analysis.ConditionAnalysesSet(combined=combined_updated_control_results)

cache = analysis.refresh_cache({'updated_control_analyses': updated_control_analyses})


In [None]:
# Total training-curve statistics for the control runs; computed only on a cache miss.
if 'control_total_curve_analyses' not in cache:
    control_runs = analysis.load_runs(150, 'meta-learning-scaling/sequential-benchmark-control', False)
    print('Loaded runs')

    # The helper returns four values: raw curves, mean, std and sem.
    curve_stats = analysis.process_multiple_runs_total_task_training_curves(control_runs.combined)
    total_curve_raw, total_curve_mean, total_curve_std, total_curve_sem = curve_stats

    combined_curve_results = analysis.TotalCurveResults(raw=total_curve_raw,
                                                        mean=total_curve_mean,
                                                        std=total_curve_std,
                                                        sem=total_curve_sem)
    control_total_curve_analyses = analysis.ConditionAnalysesSet(combined=combined_curve_results)

    cache = analysis.refresh_cache({'control_total_curve_analyses': control_total_curve_analyses})

else:
    control_total_curve_analyses = cache['control_total_curve_analyses']


In [None]:
# Store control_analyses under the 'control_analyses' cache key again.
# NOTE(review): the cell that builds control_analyses already saves it whenever it recomputes,
# so this looks like a redundant manual re-save -- confirm before removing.
cache = analysis.refresh_cache(dict(control_analyses=control_analyses))

# Plot the results

## Plot the number of examples by dimension

In [None]:
ylim = (1000, 520000)

# NOTE(review): `first_replication_analyses` is not assigned anywhere in the visible part of
# this notebook, so this cell raises NameError on a fresh kernel -- confirm where it comes from.
# One figure per dimension; the `texture` field is labelled "Material" in the titles.
for dimension_field, plot_title in (('color', 'Color 10-run average'),
                                    ('shape', 'Shape 10-run average'),
                                    ('texture', 'Material 10-run average')):
    dimension_analyses = getattr(first_replication_analyses, dimension_field)
    plots.plot_processed_results(dimension_analyses.examples, plot_title, ylim)

In [None]:
ylim = (1000, 700000)

# One figure per dimension; the `texture` field is labelled "Material" in the titles.
for dimension_field, plot_title in (('color', 'Color 60-run average'),
                                    ('shape', 'Shape 60-run average'),
                                    ('texture', 'Material 60-run average')):
    dimension_analyses = getattr(six_replications_analyses, dimension_field)
    plots.plot_processed_results(dimension_analyses.examples, plot_title, ylim)

## Plot the log of the number of examples to criterion, in each dimension, with error bars

## Plot the combined results over all 180 runs

In [None]:
ylim = (7.75, 13.25)

# Log-examples results of the combined condition, drawn with a shaded error band (sem_n=180).
combined_log_examples = six_replications_analyses.combined.log_examples
plots.plot_processed_results(
    combined_log_examples,
    'Combined 180-run average',
    ylim,
    log_x=(True, True),
    log_y=True,
    sem_n=180,
    shade_error=True)

## Plot the absolute accuracy after introducing a new task

In [None]:
ylim = None

# (analyses field, figure title, sem_n) for each figure, in plotting order.
accuracy_plot_specs = (
    ('color', 'Color 60-run average', 60),
    ('shape', 'Shape 60-run average', 60),
    ('texture', 'Material 60-run average', 60),
    ('combined', 'Combined 180-run average', 180),
)

for condition_field, plot_title, n_runs in accuracy_plot_specs:
    condition_analyses = getattr(six_replications_analyses, condition_field)
    plots.plot_processed_results(condition_analyses.accuracies, plot_title,
                                 ylim, log_x=False, log_y=False, sem_n=n_runs, shade_error=True)

## Plot the accuracy drop after introducing a new task

In [None]:
ylim = None

# One accuracy-drop figure per dimension; the `texture` field is labelled "Material".
for dimension_field, plot_title in (('color', 'Color 60-run average'),
                                    ('shape', 'Shape 60-run average'),
                                    ('texture', 'Material 60-run average')):
    dimension_analyses = getattr(six_replications_analyses, dimension_field)
    plots.plot_processed_results(dimension_analyses.accuracy_drops, plot_title,
                                 ylim, log_x=False, log_y=False, sem_n=60, shade_error=True)

-------

# Query-modulated analyses

In [None]:
# Query-modulated analyses: a dict keyed by modulation level, each value holding one
# analysis per condition field. Computed only on a cache miss.
if 'query_mod_replications' not in cache:
    query_mod_runs = analysis.query_modulated_runs_by_dimension(30)
    query_mod_replications = {}

    # No runs are excluded at the moment (a previously excluded run: 'at6pkicv').
    ignore_runs = []
    for mod_level in query_mod_runs:
        mod_level_runs = query_mod_runs[mod_level]

        # The i-th run set of this level belongs to the i-th condition field.
        mod_level_dict = {}
        for i, dimension_name in enumerate(analysis.CONDITION_ANALYSES_FIELDS):
            mod_level_dict[dimension_name] = analysis.process_multiple_runs(
                mod_level_runs[i], ignore_runs=ignore_runs)

        query_mod_replications[mod_level] = analysis.ConditionAnalysesSet(**mod_level_dict)

    cache = analysis.refresh_cache({'query_mod_replications': query_mod_replications})

else:
    query_mod_replications = cache['query_mod_replications']

In [None]:
# Store the current value of query_mod_replications under the 'query_mod_updated_analyses' key.
# NOTE(review): when the notebook is run top to bottom, query_mod_replications at this point is
# the result of the cell directly above (default parser), not the updated-parser result, so the
# key and the value do not match here -- confirm whether this cell should run at all.
cache = analysis.refresh_cache(dict(query_mod_updated_analyses=query_mod_replications))

In [None]:
# Query-modulated analyses re-parsed with parse_run_results_with_new_task_accuracy_and_equal_size.
# There is no active cache lookup: the runs are reloaded and reprocessed on every execution and
# the result is stored under the 'query_mod_updated_analyses' cache key.
# (The disabled guard that used to sit here referred to 'six_replications_updated_analyses',
# which is a different cache entry.)
# NOTE(review): the result is bound to `query_mod_replications`, the same name an earlier cell
# uses for the default-parser result. After this cell runs, every later use of that name --
# including the next cell, which saves it under the 'query_mod_replications' key -- sees the
# updated-parser result instead. Confirm whether a separate variable name is wanted.
query_mod_runs = analysis.query_modulated_runs_by_dimension(30)
print('Loaded runs')

# note: the equal accuracy field will come in as accuracy_drops
query_mod_replications = {}
mod_levels = list(query_mod_runs.keys())
# A start_index above 0 skips that many leading modulation levels.
start_index = 0

for mod_level in mod_levels[start_index:]:
    print(f'Starting mod level {mod_level}')
    mod_level_runs = query_mod_runs[mod_level]

    # The i-th run set of this level belongs to the i-th condition field.
    mod_level_dict = {dimension_name: analysis.process_multiple_runs(mod_level_runs[i], 
                                                                     parse_func=analysis.parse_run_results_with_new_task_accuracy_and_equal_size) 
                      for i, dimension_name 
                      in enumerate(analysis.CONDITION_ANALYSES_FIELDS)}

    query_mod_replications[mod_level] = analysis.ConditionAnalysesSet(**mod_level_dict)


cache = analysis.refresh_cache(dict(query_mod_updated_analyses=query_mod_replications))


In [None]:
# Store three previously computed analyses under their own cache keys in a single call.
# NOTE(review): if the cell directly above has run, `query_mod_replications` holds the
# updated-parser result at this point, so that is what lands under the 'query_mod_replications'
# key -- confirm this is intended.
cache = analysis.refresh_cache(dict(query_mod_replications=query_mod_replications,
                                    control_analyses=control_analyses,
                                    six_replications_analyses=six_replications_analyses))

In [None]:
# Total training-curve statistics for the query-modulated runs: a dict keyed by modulation
# level, each value holding one TotalCurveResults per condition field. Computed on a cache miss.
if 'query_mod_total_curve_analyses' not in cache:
    query_mod_runs = analysis.query_modulated_runs_by_dimension(30)
    print('Loaded runs')

    query_mod_total_curve_analyses = {}
    mod_levels = list(query_mod_runs.keys())
    # A start_index above 0 skips that many leading modulation levels.
    start_index = 0

    for mod_level in mod_levels[start_index:]:
        print(f'Starting mod level {mod_level}')
        mod_level_runs = query_mod_runs[mod_level]

        analyses_per_dimension = {}

        for run_set, dimension_name in zip(mod_level_runs, analysis.CONDITION_ANALYSES_FIELDS):
            print(f'Starting {dimension_name}')

            # The helper returns four values: raw curves, mean, std and sem.
            curve_stats = analysis.process_multiple_runs_total_task_training_curves(run_set)
            total_curve_raw, total_curve_mean, total_curve_std, total_curve_sem = curve_stats

            analyses_per_dimension[dimension_name] = analysis.TotalCurveResults(
                raw=total_curve_raw,
                mean=total_curve_mean,
                std=total_curve_std,
                sem=total_curve_sem)

        query_mod_total_curve_analyses[mod_level] = analysis.ConditionAnalysesSet(**analyses_per_dimension)

    cache = analysis.refresh_cache({'query_mod_total_curve_analyses': query_mod_total_curve_analyses})

else:
    query_mod_total_curve_analyses = cache['query_mod_total_curve_analyses']


In [None]:
# Show which modulation levels the total-curve analyses contain.
query_mod_total_curve_analyses.keys()

-------

# MAML

In [None]:
# MAML analyses, one entry per condition field, parsed with
# parse_run_results_with_new_task_accuracy_and_equal_size. Computed only on a cache miss.
if 'maml_analyses' not in cache:
    maml_runs = analysis.load_runs(30, 'meta-learning-scaling/maml-sequential-benchmark')
    print('Loaded runs')

    # No runs are excluded at the moment. Previously excluded run ids:
    # 'oz996ztv', 'oin8pqu2', 'h09i9xyg', 'umo8x16f', 'pm76ui1s',
    # 'qnyo08x8', 'hmjtjheq', 'wki7q8cs', 'q8ku840y'
    ignore_runs = set()

    maml_analyses_dict = {}
    # A start_index above 0 skips that many leading run sets / field names.
    start_index = 0
    pending_run_sets = maml_runs[start_index:]
    pending_field_names = analysis.CONDITION_ANALYSES_FIELDS[start_index:]

    for run_set, dimension_name in zip(pending_run_sets, pending_field_names):
        maml_analyses_dict[dimension_name] = analysis.process_multiple_runs(
            run_set,
            parse_func=analysis.parse_run_results_with_new_task_accuracy_and_equal_size,
            ignore_runs=ignore_runs)

    maml_analyses = analysis.ConditionAnalysesSet(**maml_analyses_dict)
    cache = analysis.refresh_cache({'maml_analyses': maml_analyses})

else:
    maml_analyses = cache['maml_analyses']


In [None]:
# MAML alpha-0 analyses. There is no active cache lookup in this cell.
# (The disabled guard that used to sit here referred to 'six_replications_updated_analyses',
# which is a different cache entry; this cell stores its result under 'maml_alpha_0_analyses'.)
maml_alpha_0_runs = analysis.load_runs(20, 'meta-learning-scaling/maml-alpha-0')
print('Loaded runs')

# Deliberate stop: the cell always fails here, after the runs have been loaded, so nothing
# below this line executes.
# NOTE(review): the reason the processing "will not work" is not visible in this notebook --
# confirm, and consider raising before the (potentially slow) load above.
raise ValueError('This will not work yeet')

# Previously excluded run ids: 'ac82mceh', '7kau3ypy', 'g9ujw7gg', 'avmcbnot'
ignore_runs = set()

maml_alpha_0_analyses_dict = {}
# A start_index above 0 skips that many leading run sets / field names.
start_index = 0
for run_set, dimension_name in zip(maml_alpha_0_runs[start_index:], 
                                   analysis.CONDITION_ANALYSES_FIELDS[start_index:]):
    
    maml_alpha_0_analyses_dict[dimension_name] = analysis.process_multiple_runs(
        run_set, parse_func=analysis.parse_run_results_with_new_task_accuracy_and_equal_size,
        ignore_runs=ignore_runs) 

maml_alpha_0_analyses = analysis.ConditionAnalysesSet(**maml_alpha_0_analyses_dict)
cache = analysis.refresh_cache(dict(maml_alpha_0_analyses=maml_alpha_0_analyses))


In [None]:
# MAML meta-test analyses, one entry per condition field, parsed with
# parse_run_results_with_new_task_accuracy_and_equal_size. Computed only on a cache miss.
if 'maml_meta_test_analyses' not in cache:
    maml_meta_test_runs = analysis.load_runs(30, 'meta-learning-scaling/maml-meta-test')
    print('Loaded runs')

    # No runs are excluded at the moment. Previously excluded run id: 'u3gk9oio'
    ignore_runs = set()

    maml_meta_test_analyses_dict = {}
    # A start_index above 0 skips that many leading run sets / field names.
    start_index = 0
    pending_run_sets = maml_meta_test_runs[start_index:]
    pending_field_names = analysis.CONDITION_ANALYSES_FIELDS[start_index:]

    for run_set, dimension_name in zip(pending_run_sets, pending_field_names):
        maml_meta_test_analyses_dict[dimension_name] = analysis.process_multiple_runs(
            run_set,
            parse_func=analysis.parse_run_results_with_new_task_accuracy_and_equal_size,
            ignore_runs=ignore_runs)

    maml_meta_test_analyses = analysis.ConditionAnalysesSet(**maml_meta_test_analyses_dict)
    cache = analysis.refresh_cache({'maml_meta_test_analyses': maml_meta_test_analyses})

else:
    maml_meta_test_analyses = cache['maml_meta_test_analyses']


In [None]:
# Balanced-batches analyses, one entry per condition field, parsed with
# parse_run_results_with_new_task_accuracy_and_equal_size. Computed only on a cache miss.
if 'balanced_batches_analyses' not in cache:
    balanced_batches_runs = analysis.load_runs(
        30, 'meta-learning-scaling/balanced-batches-sequential-benchmark')
    print('Loaded runs')

    # No runs are excluded at the moment. Previously excluded run id: 'u3gk9oio'
    ignore_runs = set()

    balanced_batches_analyses_dict = {}
    # A start_index above 0 skips that many leading run sets / field names.
    start_index = 0
    pending_run_sets = balanced_batches_runs[start_index:]
    pending_field_names = analysis.CONDITION_ANALYSES_FIELDS[start_index:]

    for run_set, dimension_name in zip(pending_run_sets, pending_field_names):
        balanced_batches_analyses_dict[dimension_name] = analysis.process_multiple_runs(
            run_set,
            parse_func=analysis.parse_run_results_with_new_task_accuracy_and_equal_size,
            ignore_runs=ignore_runs)

    balanced_batches_analyses = analysis.ConditionAnalysesSet(**balanced_batches_analyses_dict)
    cache = analysis.refresh_cache({'balanced_batches_analyses': balanced_batches_analyses})

else:
    balanced_batches_analyses = cache['balanced_batches_analyses']


In [None]:
# Run ids used for the baseline-vs-MAML comparison: 1000-1001, 2000-2009 and 3000-3008.
maml_comparison_run_ids = [1000, 1001, *range(2000, 2010), *range(3000, 3009)]

# Load the runs as a single group (no per-dimension split), restricted to the ids above.
maml_comparison_runs = analysis.load_runs(
    10,
    split_runs_by_dimension=False,
    valid_run_ids=set(maml_comparison_run_ids))
print('Loaded runs')

# Only the `combined` field of the analyses set is populated.
combined_comparison_results = analysis.process_multiple_runs(
    maml_comparison_runs.combined,
    parse_func=analysis.parse_run_results_with_new_task_accuracy_and_equal_size)
baseline_maml_comparison_analyses = analysis.ConditionAnalysesSet(combined=combined_comparison_results)

# This cell has no cache lookup; the cache entry is overwritten on every execution.
cache = analysis.refresh_cache(
    {'baseline_maml_comparison_analyses': baseline_maml_comparison_analyses})



-----

# Simultaneous training in a dimension

In [None]:
# Simultaneous-training analyses, overall and per task, one entry per condition field.
# There is no cache lookup here: the runs are reloaded and reprocessed on every execution,
# and both cache entries are overwritten at the end.
simultaneous_training_runs = analysis.load_runs(20, 'meta-learning-scaling/simultaneous-training')
print('Loaded runs')

# No runs are excluded at the moment. Previously excluded run id: 'u3gk9oio'
ignore_runs = set()

simultaneous_training_analyses_dict = {}
per_task_simultaneous_training_analyses_dict = {}
# A start_index above 0 skips that many leading run sets / field names.
start_index = 0
pending_run_sets = simultaneous_training_runs[start_index:]
pending_field_names = analysis.CONDITION_ANALYSES_FIELDS[start_index:]

for run_set, dimension_name in zip(pending_run_sets, pending_field_names):

    # The helper returns seven values: the stacked raw results, then mean/std/sem over all
    # results, then mean/std/sem per task.
    simultaneous_stats = analysis.process_multiple_runs_simultaneous_training(
        run_set, ignore_runs=ignore_runs)
    (stacked_results,
     all_results_mean, all_results_std, all_results_sem,
     per_task_results_mean, per_task_results_std, per_task_results_sem) = simultaneous_stats

    simultaneous_training_analyses_dict[dimension_name] = analysis.TotalCurveResults(
        raw=stacked_results,
        mean=all_results_mean,
        std=all_results_std,
        sem=all_results_sem)

    # The per-task entry shares the same stacked raw results as the overall entry.
    per_task_simultaneous_training_analyses_dict[dimension_name] = analysis.TotalCurveResults(
        raw=stacked_results,
        mean=per_task_results_mean,
        std=per_task_results_std,
        sem=per_task_results_sem)

simultaneous_training_analyses = analysis.ConditionAnalysesSet(**simultaneous_training_analyses_dict)
per_task_simultaneous_training_analyses = analysis.ConditionAnalysesSet(
    **per_task_simultaneous_training_analyses_dict)
cache = analysis.refresh_cache({
    'simultaneous_training_analyses': simultaneous_training_analyses,
    'per_task_simultaneous_training_analyses': per_task_simultaneous_training_analyses,
})


-----

# New Task Modulation

In [None]:
# Task-conditional (all layers) analyses, one entry per condition field, parsed with
# parse_run_results_with_new_task_accuracy_and_equal_size. Computed only on a cache miss.
if 'task_conditional_analyses' not in cache:
    task_modulated_runs = analysis.load_runs(20, 'meta-learning-scaling/task-conditional-all-layers')
    print('Loaded runs')

    # No runs are excluded at the moment. Previously excluded:
    # 'u3gk9oio', 'Task-conditional-[0, 1, 2, 3]-1017'
    ignore_runs = set()

    task_modulated_analyses_dict = {}
    # A start_index above 0 skips that many leading run sets / field names.
    start_index = 0
    pending_run_sets = task_modulated_runs[start_index:]
    pending_field_names = analysis.CONDITION_ANALYSES_FIELDS[start_index:]

    for run_set, dimension_name in zip(pending_run_sets, pending_field_names):
        task_modulated_analyses_dict[dimension_name] = analysis.process_multiple_runs(
            run_set,
            parse_func=analysis.parse_run_results_with_new_task_accuracy_and_equal_size,
            ignore_runs=ignore_runs)

    task_conditional_analyses = analysis.ConditionAnalysesSet(**task_modulated_analyses_dict)
    cache = analysis.refresh_cache({'task_conditional_analyses': task_conditional_analyses})

else:
    task_conditional_analyses = cache['task_conditional_analyses']


In [None]:
# Task-conditional (multiplicative-only) analyses, one entry per condition field, parsed with
# parse_run_results_with_new_task_accuracy_and_equal_size. Computed only on a cache miss.
if 'task_conditional_multiplicative_only_analyses' not in cache:
    task_modulated_runs = analysis.load_runs(
        20, 'meta-learning-scaling/task-conditional-all-layers-multiplicative-only')
    print('Loaded runs')

    # No runs are excluded at the moment. Previously excluded:
    # 'Task-conditional-multiplicative-[0, 1, 2, 3]-1006',
    # 'Task-conditional-multiplicative-[0, 1, 2, 3]-1005'
    ignore_runs = set()

    task_modulated_analyses_dict = {}
    # A start_index above 0 skips that many leading run sets / field names.
    start_index = 0
    pending_run_sets = task_modulated_runs[start_index:]
    pending_field_names = analysis.CONDITION_ANALYSES_FIELDS[start_index:]

    for run_set, dimension_name in zip(pending_run_sets, pending_field_names):
        task_modulated_analyses_dict[dimension_name] = analysis.process_multiple_runs(
            run_set,
            parse_func=analysis.parse_run_results_with_new_task_accuracy_and_equal_size,
            ignore_runs=ignore_runs)

    task_conditional_multiplicative_only_analyses = analysis.ConditionAnalysesSet(
        **task_modulated_analyses_dict)
    cache = analysis.refresh_cache({
        'task_conditional_multiplicative_only_analyses': task_conditional_multiplicative_only_analyses,
    })

else:
    task_conditional_multiplicative_only_analyses = cache['task_conditional_multiplicative_only_analyses']


In [None]:
# Task-conditional (additive-only) analyses, one entry per condition field, parsed with
# parse_run_results_with_new_task_accuracy_and_equal_size. Computed only on a cache miss.
if 'task_conditional_additive_only_analyses' in cache:
    task_conditional_additive_only_analyses = cache['task_conditional_additive_only_analyses']

else:
    # Fixed: this call used to read `analyses_caches/alysis.load_runs(...)`, a mangled form of
    # `analysis.load_runs(...)` that raised NameError whenever the cache entry was missing.
    task_modulated_runs = analysis.load_runs(20, 'meta-learning-scaling/task-conditional-all-layers-additive-only')
    print('Loaded runs')

    # No runs are excluded at the moment. Previously excluded run id: 'u3gk9oio'
    ignore_runs = set()

    task_modulated_analyses_dict = {}
    # A start_index above 0 skips that many leading run sets / field names.
    start_index = 0
    for run_set, dimension_name in zip(task_modulated_runs[start_index:],
                                       analysis.CONDITION_ANALYSES_FIELDS[start_index:]):

        task_modulated_analyses_dict[dimension_name] = analysis.process_multiple_runs(
            run_set, parse_func=analysis.parse_run_results_with_new_task_accuracy_and_equal_size,
            ignore_runs=ignore_runs)

    task_conditional_additive_only_analyses = analysis.ConditionAnalysesSet(**task_modulated_analyses_dict)
    cache = analysis.refresh_cache(dict(task_conditional_additive_only_analyses=task_conditional_additive_only_analyses))


# Saving the task-conditional weights en masse

In [52]:
# Task-conditional weights (additive and multiplicative) parsed from the combined run set.
# Computed only on a cache miss.
if 'task_conditional_weights' not in cache:
    task_modulated_runs = analysis.load_runs(20, 'meta-learning-scaling/task-conditional-all-layers')
    print('Loaded runs')

    # Runs excluded from weight parsing (an earlier exclusion list held only 'u3gk9oio').
    ignore_runs = {
        'Task-conditional-[0, 1, 2, 3]-1013',
        'Task-conditional-[0, 1, 2, 3]-1011',
        'Task-conditional-[0, 1, 2, 3]-1001',
    }

    task_conditional_weights = analysis.parse_task_conditional_weights(
        task_modulated_runs.combined,
        additive=True,
        multiplicative=True,
        ignore_runs=ignore_runs)

    cache = analysis.refresh_cache({'task_conditional_weights': task_conditional_weights})

else:
    task_conditional_weights = cache['task_conditional_weights']



Loaded runs
Task-conditional-[0, 1, 2, 3]-1017
Task-conditional-[0, 1, 2, 3]-1019
Task-conditional-[0, 1, 2, 3]-1018
Task-conditional-[0, 1, 2, 3]-1009
Task-conditional-[0, 1, 2, 3]-1008
Task-conditional-[0, 1, 2, 3]-1007
Task-conditional-[0, 1, 2, 3]-1006
Task-conditional-[0, 1, 2, 3]-1016
Task-conditional-[0, 1, 2, 3]-1015
Task-conditional-[0, 1, 2, 3]-3009 10
Task-conditional-[0, 1, 2, 3]-3019
Task-conditional-[0, 1, 2, 3]-1014
Task-conditional-[0, 1, 2, 3]-3018
Task-conditional-[0, 1, 2, 3]-3008
Task-conditional-[0, 1, 2, 3]-3017
Task-conditional-[0, 1, 2, 3]-1005
Task-conditional-[0, 1, 2, 3]-3007
Task-conditional-[0, 1, 2, 3]-3016
Task-conditional-[0, 1, 2, 3]-3006
Task-conditional-[0, 1, 2, 3]-3015 20
Task-conditional-[0, 1, 2, 3]-3005
Task-conditional-[0, 1, 2, 3]-1004
Task-conditional-[0, 1, 2, 3]-3014
Task-conditional-[0, 1, 2, 3]-3004
Task-conditional-[0, 1, 2, 3]-3013
Task-conditional-[0, 1, 2, 3]-3003
Task-conditional-[0, 1, 2, 3]-3012
Task-conditional-[0, 1, 2, 3]-1003
Ta

In [53]:
# Task-conditional weights for the multiplicative-only models, parsed from the combined run
# set. Computed only on a cache miss.
if 'task_conditional_multiplicative_only_weights' not in cache:
    task_modulated_runs = analysis.load_runs(
        20, 'meta-learning-scaling/task-conditional-all-layers-multiplicative-only')
    print('Loaded runs')

    # No runs are excluded at the moment. Previously excluded:
    # 'Task-conditional-multiplicative-[0, 1, 2, 3]-1006',
    # 'Task-conditional-multiplicative-[0, 1, 2, 3]-1005'
    ignore_runs = set()

    task_conditional_multiplicative_only_weights = analysis.parse_task_conditional_weights(
        task_modulated_runs.combined,
        additive=False,
        multiplicative=True,
        ignore_runs=ignore_runs)

    cache = analysis.refresh_cache({
        'task_conditional_multiplicative_only_weights': task_conditional_multiplicative_only_weights,
    })

else:
    task_conditional_multiplicative_only_weights = cache['task_conditional_multiplicative_only_weights']



Loaded runs
Task-conditional-multiplicative-[0, 1, 2, 3]-1005
Task-conditional-multiplicative-[0, 1, 2, 3]-1006
Task-conditional-multiplicative-[0, 1, 2, 3]-1009
Task-conditional-multiplicative-[0, 1, 2, 3]-1008
Task-conditional-multiplicative-[0, 1, 2, 3]-1007
Task-conditional-multiplicative-[0, 1, 2, 3]-1019
Task-conditional-multiplicative-[0, 1, 2, 3]-1018
Task-conditional-multiplicative-[0, 1, 2, 3]-1017
Task-conditional-multiplicative-[0, 1, 2, 3]-1016
Task-conditional-multiplicative-[0, 1, 2, 3]-1015 10
Task-conditional-multiplicative-[0, 1, 2, 3]-1014
Task-conditional-multiplicative-[0, 1, 2, 3]-1004
Task-conditional-multiplicative-[0, 1, 2, 3]-1003
Task-conditional-multiplicative-[0, 1, 2, 3]-1013
Task-conditional-multiplicative-[0, 1, 2, 3]-3009
Task-conditional-multiplicative-[0, 1, 2, 3]-3019
Task-conditional-multiplicative-[0, 1, 2, 3]-3008
Task-conditional-multiplicative-[0, 1, 2, 3]-1012
Task-conditional-multiplicative-[0, 1, 2, 3]-3018
Task-conditional-multiplicative-[0,

In [55]:
# Task-conditional weights for the additive-only models, parsed from the combined run set.
# Computed only on a cache miss.
if 'task_conditional_additive_only_weights' not in cache:
    task_modulated_runs = analysis.load_runs(
        20, 'meta-learning-scaling/task-conditional-all-layers-additive-only')
    print('Loaded runs')

    # No runs are excluded at the moment. Previously excluded run id: 'u3gk9oio'
    ignore_runs = set()

    task_conditional_additive_only_weights = analysis.parse_task_conditional_weights(
        task_modulated_runs.combined,
        additive=True,
        multiplicative=False,
        ignore_runs=ignore_runs)

    cache = analysis.refresh_cache({
        'task_conditional_additive_only_weights': task_conditional_additive_only_weights,
    })

else:
    task_conditional_additive_only_weights = cache['task_conditional_additive_only_weights']

Loaded runs
Task-conditional-additive-[0, 1, 2, 3]-3014
Task-conditional-additive-[0, 1, 2, 3]-3019
Task-conditional-additive-[0, 1, 2, 3]-3004
Task-conditional-additive-[0, 1, 2, 3]-3009
Task-conditional-additive-[0, 1, 2, 3]-3013
Task-conditional-additive-[0, 1, 2, 3]-3018
Task-conditional-additive-[0, 1, 2, 3]-3003
Task-conditional-additive-[0, 1, 2, 3]-3008
Task-conditional-additive-[0, 1, 2, 3]-3017
Task-conditional-additive-[0, 1, 2, 3]-3002 10
Task-conditional-additive-[0, 1, 2, 3]-3012
Task-conditional-additive-[0, 1, 2, 3]-3007
Task-conditional-additive-[0, 1, 2, 3]-1009
Task-conditional-additive-[0, 1, 2, 3]-3016
Task-conditional-additive-[0, 1, 2, 3]-3006
Task-conditional-additive-[0, 1, 2, 3]-3001
Task-conditional-additive-[0, 1, 2, 3]-3011
Task-conditional-additive-[0, 1, 2, 3]-3015
Task-conditional-additive-[0, 1, 2, 3]-3010
Task-conditional-additive-[0, 1, 2, 3]-3005 20
Task-conditional-additive-[0, 1, 2, 3]-3000
Task-conditional-additive-[0, 1, 2, 3]-1008
Task-condition

# Playing around with reading from the weights

In [8]:
import sys
sys.path.append('../projects/')

from metalearning import cnnmlp

In [9]:
DEFAULT_LEARNING_RATE = 5e-4
DEFAULT_WEIGHT_DECAY = 1e-4

def create_task_conditional_model(multiplicative=True, additive=True, checkpoint_path=None, name=None):
    """Build a ``cnnmlp.TaskConditionalCNNMLP`` and optionally restore saved state into it.

    The architecture arguments are fixed in this function; only the two modulation switches,
    the model name and the checkpoint are configurable.

    Args:
        multiplicative: forwarded to the model as ``multiplicative_mod``.
        additive: forwarded to the model as ``additive_mod``.
        checkpoint_path: when not None, passed to ``model.load_state`` after construction.
        name: forwarded to the model as ``name``.

    Returns:
        The constructed model.
    """
    model_kwargs = dict(
        mod_level=[0, 1, 2, 3],
        multiplicative_mod=multiplicative,
        additive_mod=additive,
        query_length=30,
        conv_filter_sizes=(16, 32, 48, 64),
        conv_output_size=4480,
        mlp_layer_sizes=(512, 512, 512, 512),
        lr=DEFAULT_LEARNING_RATE,
        weight_decay=DEFAULT_WEIGHT_DECAY,
        use_lr_scheduler=False,
        conv_dropout=False,
        mlp_dropout=False,
        name=name,
    )
    model = cnnmlp.TaskConditionalCNNMLP(**model_kwargs)

    # Nothing to restore: hand back the freshly constructed model.
    if checkpoint_path is None:
        return model

    model.load_state(checkpoint_path)
    return model

In [10]:
# W&B API client; the following cells use it to look up a run and its files.
api = wandb.Api()

'Task-conditional-additive-3014'

In [15]:
# Look up one specific additive-only run by its W&B path.
run = api.run('meta-learning-scaling/task-conditional-all-layers-additive-only/runs/49dq79wf')
# The checkpoint file name is the run name with "[0, 1, 2, 3]-" removed, followed by
# "-query-9.pth" (presumably the checkpoint saved for the last query -- TODO confirm).
last_checkpoint_file = run.file(f'{run.name.replace("[0, 1, 2, 3]-", "")}-query-9.pth')


In [16]:
# Download the checkpoint into /tmp, replacing any existing copy.
last_checkpoint = last_checkpoint_file.download(replace=True, root='/tmp')

In [22]:
# Build an additive-only model (multiplicative modulation off; additive keeps its default of
# True) and restore the checkpoint downloaded to /tmp.
model = create_task_conditional_model(multiplicative=False, 
                                      checkpoint_path='/tmp/' + last_checkpoint_file.name)

In [36]:
layers = range(4)

# Collect the weight matrix of each additive modulation layer ('additive-0' .. 'additive-3')
# as a NumPy array, in layer order.
additive_weights = []
for i in layers:
    layer_module = model.conv.additive_mod_layers[f'additive-{i}']
    additive_weights.append(layer_module.weight.detach().cpu().numpy())

In [43]:
# For each layer, sum the additive weights over one 10-column block of the weight matrix.
# NOTE(review): the value read from run.config is overwritten by the hard-coded 0 on the very
# next line, so the run's own benchmark_dimension is never used -- confirm which is intended.
dimension = run.config['benchmark_dimension']
dimension = 0
# Columns [dimension * 10, (dimension + 1) * 10) are selected; presumably these are the query
# inputs belonging to that dimension -- TODO confirm.
[w[:, dimension * 10:(dimension + 1) * 10].sum() for w in additive_weights]

[-2.8157453e-23, 3.222276e-22, -3.1739396e-22, -5.7911973e-22]

AttributeError: 'dict' object has no attribute 'benchmark_dimension'

In [39]:
# Display the configuration stored with the run.
run.config

{'lr': 0.0005,
 'loss': 'CE',
 'decay': 0.0001,
 'epochs': 1000,
 'mod_level': [0, 1, 2, 3],
 'batch_size': 1500,
 'query_order': [28, 24, 21, 26, 22, 29, 20, 23, 25, 27],
 'test_coreset_size': 5000,
 'accuracy_threshold': 0.95,
 'latin_square_index': 3014,
 'train_coreset_size': 22500,
 'benchmark_dimension': 2,
 'dataset_random_seed': 3014,
 'latin_square_random_seed': 3010}

In [26]:
# Display the additive modulation layers of the model's conv module.
model.conv.additive_mod_layers

ModuleDict(
  (additive-0): Linear(in_features=30, out_features=3, bias=True)
  (additive-1): Linear(in_features=30, out_features=16, bias=True)
  (additive-2): Linear(in_features=30, out_features=32, bias=True)
  (additive-3): Linear(in_features=30, out_features=48, bias=True)
)

In [None]:
# Sum the 'additive-0' weight matrix along its first axis.
model.conv.additive_mod_layers['additive-0'].weight.sum(0)

In [None]:
# Total 'additive-0' weight in each of the three consecutive 10-column blocks (0-9, 10-19, 20-29).
[model.conv.additive_mod_layers['additive-0'].weight.detach().cpu().numpy()[:, i * 10:(i + 1) * 10].sum() for i in range(3)]

In [None]:
# Flatten the 'additive-0' weights of columns 20-29 into a 1-D array.
np.ravel(model.conv.additive_mod_layers['additive-0'].weight.detach().cpu().numpy()[:, 20:30])

In [None]:
# Histogram (10 bins) of the value IPython cached as the output of execution count 44.
# NOTE(review): Out[44] only exists in the kernel session that produced it, so this cell fails
# after a restart. Presumably it refers to the flattened weights from the cell above -- confirm
# and pass that expression (or a named variable) to plt.hist instead.
plt.hist(Out[44], bins=10)

# Scratch work