In [1]:
# Load the autoreload extension and set it to mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline

In [2]:
!wandb login 9676e3cc95066e4865586082971f2653245f09b4

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/guydavidson/.netrc
[32mSuccessfully logged in to Weights & Biases![0m


In [3]:
import numpy as np
import pandas as pd
import scipy
from scipy import stats
from scipy.special import factorial

from mpl_toolkits.mplot3d import Axes3D
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import patches
from matplotlib import path as mpath

import pickle
import tabulate
import wandb
from collections import namedtuple
import sys
from ipypb import ipb

import meta_learning_data_analysis as analysis
import meta_learning_analysis_plots as plots

In [4]:
# Fetch the analysis cache with no new entries and list the keys it holds.
# NOTE(review): refresh_cache is defined in meta_learning_data_analysis; presumably it
# loads the persisted cache, merges any dict passed in, and returns the result - confirm.
cache = analysis.refresh_cache()
print(cache.keys())

dict_keys(['six_replications_analyses', 'control_analyses', 'query_mod_replications', 'six_replications_updated_analyses', 'updated_control_analyses', 'query_mod_updated_analyses', 'forgetting_curves_raw_data', 'preliminary_maml_analyses', 'baseline_maml_comparison_analyses', 'maml_analyses', 'maml_alpha_0_analyses', 'maml_meta_test_analyses', 'balanced_batches_analyses', 'baseline_total_curve_analyses', 'control_total_curve_analyses', 'query_mod_total_curve_analyses', 'simultaneous_training_analyses', 'per_task_simultaneous_training_analyses', 'task_conditional_analyses', 'task_conditional_multiplicative_only_analyses', 'task_conditional_additive_only_analyses', 'task_conditional_weights', 'task_conditional_multiplicative_only_weights', 'task_conditional_additive_only_weights', 'forgetting_exp_decay_params', 'baseline_ratio_curriculum_analyses', 'baseline_power_curriculum_analyses', 'epochs_to_completion'])


-------

# Baseline analyses

In [None]:
# Baseline analyses: rebuild them from the raw runs only when the cache has no entry,
# otherwise reuse the cached value.
if 'six_replications_analyses' not in cache:
    six_replications_by_dimension_runs = analysis.load_runs(60)
    print('Loaded runs')

    # One processed result per entry of CONDITION_ANALYSES_FIELDS.
    six_reps_dict = {}
    for run_set, dimension_name in zip(six_replications_by_dimension_runs,
                                       analysis.CONDITION_ANALYSES_FIELDS):
        six_reps_dict[dimension_name] = analysis.process_multiple_runs(run_set)

    six_replications_analyses = analysis.ConditionAnalysesSet(**six_reps_dict)

    cache = analysis.refresh_cache({'six_replications_analyses': six_replications_analyses})

else:
    six_replications_analyses = cache['six_replications_analyses']


In [None]:
# The cache guard is commented out, so this cell always reloads the runs and recomputes.
# if 'six_replications_updated_analyses' in cache:
#     six_replications_updated_analyses = cache['six_replications_updated_analyses']
# else:
six_replications_by_dimension_runs = analysis.load_runs(60)
print('Loaded runs')

# note: the equal accuracy field will come in as accuracy_drops
updated_six_reps_dict = {}
start_index = 0  # raise this to skip the leading run sets / fields
updated_parse_func = analysis.parse_run_results_with_new_task_accuracy_and_equal_size

for dimension_name, run_set in zip(analysis.CONDITION_ANALYSES_FIELDS[start_index:],
                                   six_replications_by_dimension_runs[start_index:]):
    updated_six_reps_dict[dimension_name] = analysis.process_multiple_runs(
        run_set, parse_func=updated_parse_func)

# (To process a single run set instead, the original kept a commented-out call on
#  six_replications_by_dimension_runs[3] with the same parse_func.)

six_replications_updated_analyses = analysis.ConditionAnalysesSet(**updated_six_reps_dict)

cache = analysis.refresh_cache(
    {'six_replications_updated_analyses': six_replications_updated_analyses})


In [None]:
# The cache guard is commented out, so the total training curves are always recomputed.
# if 'baseline_total_curve_analyses' in cache:
#     baseline_total_curve_analyses = cache['baseline_total_curve_analyses']
# else:
six_replications_by_dimension_runs = analysis.load_runs(60)
print('Loaded runs')

analyses_per_dimension = {}

for run_set, dimension_name in zip(six_replications_by_dimension_runs,
                                   analysis.CONDITION_ANALYSES_FIELDS):
    print(f'Starting {dimension_name}')
    (total_curve_raw,
     total_curve_mean,
     total_curve_std,
     total_curve_sem) = analysis.process_multiple_runs_total_task_training_curves(run_set)

    analyses_per_dimension[dimension_name] = analysis.TotalCurveResults(
        raw=total_curve_raw,
        mean=total_curve_mean,
        std=total_curve_std,
        sem=total_curve_sem,
    )

# Stored in the cache under 'baseline_total_curve_analyses'.
total_curve_analyses = analysis.ConditionAnalysesSet(**analyses_per_dimension)
cache = analysis.refresh_cache({'baseline_total_curve_analyses': total_curve_analyses})


-------

# Control analyses

In [None]:
# Control analyses: only the `combined` field of the loaded runs is processed.
if 'control_analyses' not in cache:
    control_runs = analysis.load_runs(
        150, 'meta-learning-scaling/sequential-benchmark-control', False)
    print('Loaded runs')
    control_analyses = analysis.ConditionAnalysesSet(
        combined=analysis.process_multiple_runs(control_runs.combined))

    cache = analysis.refresh_cache({'control_analyses': control_analyses})

else:
    control_analyses = cache['control_analyses']

In [None]:
# The cache guard is commented out, so this cell always recomputes the updated control
# analyses (the disabled guard would check the 'updated_control_analyses' key).
# if 'updated_control_analyses' in cache:
#     updated_control_analyses = cache['updated_control_analyses']
# else:
control_runs = analysis.load_runs(
    150, 'meta-learning-scaling/sequential-benchmark-control', False)
print('Loaded runs')

updated_control_analyses = analysis.ConditionAnalysesSet(
    combined=analysis.process_multiple_runs(
        control_runs.combined,
        parse_func=analysis.parse_run_results_with_new_task_accuracy_and_equal_size,
    )
)

cache = analysis.refresh_cache({'updated_control_analyses': updated_control_analyses})


In [None]:
# Total training curves for the control runs (combined field only), cached after the
# first computation.
if 'control_total_curve_analyses' not in cache:
    control_runs = analysis.load_runs(
        150, 'meta-learning-scaling/sequential-benchmark-control', False)
    print('Loaded runs')

    (total_curve_raw,
     total_curve_mean,
     total_curve_std,
     total_curve_sem) = analysis.process_multiple_runs_total_task_training_curves(
        control_runs.combined)

    control_total_curve_analyses = analysis.ConditionAnalysesSet(
        combined=analysis.TotalCurveResults(
            raw=total_curve_raw,
            mean=total_curve_mean,
            std=total_curve_std,
            sem=total_curve_sem,
        )
    )

    cache = analysis.refresh_cache(
        {'control_total_curve_analyses': control_total_curve_analyses})

else:
    control_total_curve_analyses = cache['control_total_curve_analyses']


In [None]:
# Push the in-memory control analyses into the cache again.
cache = analysis.refresh_cache({'control_analyses': control_analyses})

# Plot the results

## Plot the number of examples by dimension

In [None]:
# NOTE(review): `first_replication_analyses` is not assigned anywhere in the visible
# notebook and is not among the cache keys printed earlier; on a fresh kernel this cell
# raises NameError. Confirm where it should come from.
ylim = (1000, 520000)

# The `texture` field is labelled "Material" in the plot titles.
for dimension, label in (('color', 'Color'), ('shape', 'Shape'), ('texture', 'Material')):
    plots.plot_processed_results(getattr(first_replication_analyses, dimension).examples,
                                 f'{label} 10-run average', ylim)

In [None]:
ylim = (1000, 700000)

# One plot of the `examples` result per dimension; `texture` is titled "Material".
for dimension, label in (('color', 'Color'), ('shape', 'Shape'), ('texture', 'Material')):
    plots.plot_processed_results(getattr(six_replications_analyses, dimension).examples,
                                 f'{label} 60-run average', ylim)

## Plot the log of the number of examples to criterion, in each dimension, with error bars

## Plot the combined results over all 180 runs

In [None]:
ylim = (7.75, 13.25)

# Log-examples plot for the combined analyses, with a shaded error band and sem_n=180.
combined_plot_options = dict(log_x=(True, True), log_y=True, sem_n=180, shade_error=True)
plots.plot_processed_results(
    six_replications_analyses.combined.log_examples,
    'Combined 180-run average',
    ylim,
    **combined_plot_options,
)

## Plot the absolute accuracy after introducing a new task

In [None]:
ylim = None

# (field on the analyses set, plot title, sem_n) for each accuracy plot, in order.
accuracy_plot_specs = (
    ('color', 'Color 60-run average', 60),
    ('shape', 'Shape 60-run average', 60),
    ('texture', 'Material 60-run average', 60),
    ('combined', 'Combined 180-run average', 180),
)

for dimension, title, n_runs in accuracy_plot_specs:
    plots.plot_processed_results(
        getattr(six_replications_analyses, dimension).accuracies, title,
        ylim, log_x=False, log_y=False, sem_n=n_runs, shade_error=True)

## Plot the accuracy drop after introducing a new task

In [None]:
ylim = None

# One accuracy-drop plot per dimension; `texture` is titled "Material".
for dimension, label in (('color', 'Color'), ('shape', 'Shape'), ('texture', 'Material')):
    plots.plot_processed_results(
        getattr(six_replications_analyses, dimension).accuracy_drops,
        f'{label} 60-run average',
        ylim, log_x=False, log_y=False, sem_n=60, shade_error=True)

-------

# Query-modulated analyses

In [None]:
# Query-modulated analyses: a dict keyed by modulation level, each value a
# ConditionAnalysesSet built from that level's run sets.
if 'query_mod_replications' not in cache:
    query_mod_runs = analysis.query_modulated_runs_by_dimension(30)
    query_mod_replications = {}

    ignore_runs = [] # ('at6pkicv', )
    for mod_level in query_mod_runs:
        mod_level_runs = query_mod_runs[mod_level]

        # The i-th run set is paired with the i-th CONDITION_ANALYSES_FIELDS name.
        mod_level_dict = {}
        for i, dimension_name in enumerate(analysis.CONDITION_ANALYSES_FIELDS):
            mod_level_dict[dimension_name] = analysis.process_multiple_runs(
                mod_level_runs[i], ignore_runs=ignore_runs)

        query_mod_replications[mod_level] = analysis.ConditionAnalysesSet(**mod_level_dict)

    cache = analysis.refresh_cache({'query_mod_replications': query_mod_replications})

else:
    query_mod_replications = cache['query_mod_replications']

In [None]:
# NOTE(review): this stores whatever `query_mod_replications` currently holds under the
# 'query_mod_updated_analyses' key, so the cached contents depend on which cell last
# assigned that name. Confirm this is intended before running.
cache = analysis.refresh_cache({'query_mod_updated_analyses': query_mod_replications})

In [None]:
# The cache guard is commented out, so this cell always recomputes (the disabled guard
# would check the 'query_mod_updated_analyses' key).
# if 'query_mod_updated_analyses' in cache:
#     query_mod_replications = cache['query_mod_updated_analyses']
# else:
query_mod_runs = analysis.query_modulated_runs_by_dimension(30)
print('Loaded runs')

# note: the equal accuracy field will come in as accuracy_drops
query_mod_replications = {}
mod_levels = list(query_mod_runs.keys())
start_index = 0  # raise this to skip the leading modulation levels

for mod_level in mod_levels[start_index:]:
    print(f'Starting mod level {mod_level}')
    mod_level_runs = query_mod_runs[mod_level]

    # The i-th run set is paired with the i-th CONDITION_ANALYSES_FIELDS name.
    mod_level_dict = {}
    for i, dimension_name in enumerate(analysis.CONDITION_ANALYSES_FIELDS):
        mod_level_dict[dimension_name] = analysis.process_multiple_runs(
            mod_level_runs[i],
            parse_func=analysis.parse_run_results_with_new_task_accuracy_and_equal_size)

    query_mod_replications[mod_level] = analysis.ConditionAnalysesSet(**mod_level_dict)


cache = analysis.refresh_cache({'query_mod_updated_analyses': query_mod_replications})


In [None]:
# Re-store three in-memory analysis sets in one call.
# NOTE(review): `query_mod_replications` is assigned by more than one cell in this
# notebook; make sure it holds the intended analyses before running this.
entries_to_cache = {
    'query_mod_replications': query_mod_replications,
    'control_analyses': control_analyses,
    'six_replications_analyses': six_replications_analyses,
}
cache = analysis.refresh_cache(entries_to_cache)

In [None]:
# Total training curves for every query-modulation level and dimension, cached after
# the first computation.
if 'query_mod_total_curve_analyses' not in cache:
    query_mod_runs = analysis.query_modulated_runs_by_dimension(30)
    print('Loaded runs')

    query_mod_total_curve_analyses = {}
    mod_levels = list(query_mod_runs.keys())
    start_index = 0  # raise this to skip the leading modulation levels

    for mod_level in mod_levels[start_index:]:
        print(f'Starting mod level {mod_level}')
        mod_level_runs = query_mod_runs[mod_level]

        analyses_per_dimension = {}

        for run_set, dimension_name in zip(mod_level_runs, analysis.CONDITION_ANALYSES_FIELDS):
            print(f'Starting {dimension_name}')
            (total_curve_raw,
             total_curve_mean,
             total_curve_std,
             total_curve_sem) = analysis.process_multiple_runs_total_task_training_curves(run_set)

            analyses_per_dimension[dimension_name] = analysis.TotalCurveResults(
                raw=total_curve_raw,
                mean=total_curve_mean,
                std=total_curve_std,
                sem=total_curve_sem,
            )

        query_mod_total_curve_analyses[mod_level] = analysis.ConditionAnalysesSet(
            **analyses_per_dimension)

    cache = analysis.refresh_cache(
        {'query_mod_total_curve_analyses': query_mod_total_curve_analyses})

else:
    query_mod_total_curve_analyses = cache['query_mod_total_curve_analyses']


In [None]:
# Display the keys of the total-curve dict (it is keyed by modulation level).
query_mod_total_curve_analyses.keys()

-------

# MAML

In [None]:
# MAML sequential-benchmark analyses, computed once and then served from the cache.
if 'maml_analyses' not in cache:
    maml_runs = analysis.load_runs(30, 'meta-learning-scaling/maml-sequential-benchmark')
    print('Loaded runs')

    # ignore_runs = set(['oz996ztv', 'oin8pqu2', 'h09i9xyg', 'umo8x16f', 'pm76ui1s',
    #                    'qnyo08x8', 'hmjtjheq', 'wki7q8cs', 'q8ku840y'])
    ignore_runs = set()

    maml_analyses_dict = {}
    start_index = 0  # raise this to skip the leading run sets / fields
    process_kwargs = dict(
        parse_func=analysis.parse_run_results_with_new_task_accuracy_and_equal_size,
        ignore_runs=ignore_runs)

    for dimension_name, run_set in zip(analysis.CONDITION_ANALYSES_FIELDS[start_index:],
                                       maml_runs[start_index:]):
        maml_analyses_dict[dimension_name] = analysis.process_multiple_runs(
            run_set, **process_kwargs)

    maml_analyses = analysis.ConditionAnalysesSet(**maml_analyses_dict)
    cache = analysis.refresh_cache({'maml_analyses': maml_analyses})

else:
    maml_analyses = cache['maml_analyses']


In [None]:
# Build and cache the analyses for the 'maml-alpha-0' project runs.
# NOTE(review): the commented-out guard below names 'six_replications_updated_analyses',
# not 'maml_alpha_0_analyses'; it looks copied from another cell and is stale.
# if 'six_replications_updated_analyses' in cache:
#     six_replications_updated_analyses = cache['six_replications_updated_analyses']

# else:
# The runs are loaded before the unconditional raise below, so the load cost is paid
# even though nothing after the raise executes.
maml_alpha_0_runs = analysis.load_runs(20, 'meta-learning-scaling/maml-alpha-0')
print('Loaded runs')

# Unconditional stop: this cell always fails here, and every statement below is
# unreachable until the raise is removed.
raise ValueError('This will not work yeet')

# ignore_runs = set(['ac82mceh', '7kau3ypy', 'g9ujw7gg', 'avmcbnot'])
ignore_runs = set()

maml_alpha_0_analyses_dict = {}
start_index = 0
for run_set, dimension_name in zip(maml_alpha_0_runs[start_index:], 
                                   analysis.CONDITION_ANALYSES_FIELDS[start_index:]):
    
    maml_alpha_0_analyses_dict[dimension_name] = analysis.process_multiple_runs(
        run_set, parse_func=analysis.parse_run_results_with_new_task_accuracy_and_equal_size,
        ignore_runs=ignore_runs) 

maml_alpha_0_analyses = analysis.ConditionAnalysesSet(**maml_alpha_0_analyses_dict)
cache = analysis.refresh_cache(dict(maml_alpha_0_analyses=maml_alpha_0_analyses))


In [None]:
# MAML meta-test analyses, computed once and then served from the cache.
if 'maml_meta_test_analyses' not in cache:
    maml_meta_test_runs = analysis.load_runs(30, 'meta-learning-scaling/maml-meta-test')
    print('Loaded runs')

    # ignore_runs = set(['u3gk9oio'])
    ignore_runs = set()

    maml_meta_test_analyses_dict = {}
    start_index = 0  # raise this to skip the leading run sets / fields
    process_kwargs = dict(
        parse_func=analysis.parse_run_results_with_new_task_accuracy_and_equal_size,
        ignore_runs=ignore_runs)

    for dimension_name, run_set in zip(analysis.CONDITION_ANALYSES_FIELDS[start_index:],
                                       maml_meta_test_runs[start_index:]):
        maml_meta_test_analyses_dict[dimension_name] = analysis.process_multiple_runs(
            run_set, **process_kwargs)

    maml_meta_test_analyses = analysis.ConditionAnalysesSet(**maml_meta_test_analyses_dict)
    cache = analysis.refresh_cache({'maml_meta_test_analyses': maml_meta_test_analyses})

else:
    maml_meta_test_analyses = cache['maml_meta_test_analyses']


In [None]:
# Balanced-batches sequential-benchmark analyses, computed once and then cached.
if 'balanced_batches_analyses' not in cache:
    balanced_batches_runs = analysis.load_runs(
        30, 'meta-learning-scaling/balanced-batches-sequential-benchmark')
    print('Loaded runs')

    # ignore_runs = set(['u3gk9oio'])
    ignore_runs = set()

    balanced_batches_analyses_dict = {}
    start_index = 0  # raise this to skip the leading run sets / fields
    process_kwargs = dict(
        parse_func=analysis.parse_run_results_with_new_task_accuracy_and_equal_size,
        ignore_runs=ignore_runs)

    for dimension_name, run_set in zip(analysis.CONDITION_ANALYSES_FIELDS[start_index:],
                                       balanced_batches_runs[start_index:]):
        balanced_batches_analyses_dict[dimension_name] = analysis.process_multiple_runs(
            run_set, **process_kwargs)

    balanced_batches_analyses = analysis.ConditionAnalysesSet(**balanced_batches_analyses_dict)
    cache = analysis.refresh_cache({'balanced_batches_analyses': balanced_batches_analyses})

else:
    balanced_batches_analyses = cache['balanced_batches_analyses']


In [None]:
# Run ids to keep: 1000-1001, 2000-2009 and 3000-3008 (21 ids in total).
maml_comparison_run_ids = [1000, 1001, *range(2000, 2010), *range(3000, 3009)]

maml_comparison_runs = analysis.load_runs(
    10,
    split_runs_by_dimension=False,
    valid_run_ids=set(maml_comparison_run_ids),
)
print('Loaded runs')

# Only the `combined` field of the loaded runs is processed.
baseline_maml_comparison_analyses = analysis.ConditionAnalysesSet(
    combined=analysis.process_multiple_runs(
        maml_comparison_runs.combined,
        parse_func=analysis.parse_run_results_with_new_task_accuracy_and_equal_size,
    )
)

cache = analysis.refresh_cache(
    {'baseline_maml_comparison_analyses': baseline_maml_comparison_analyses})



-----

# Simultaneous training in a dimension

In [None]:
# The cache guard is commented out, so this cell always reloads and recomputes.
# if 'simultaneous_training_analyses' in cache:
#     simultaneous_training_analyses = cache['simultaneous_training_analyses']
# else:
simultaneous_training_runs = analysis.load_runs(20, 'meta-learning-scaling/simultaneous-training')
print('Loaded runs')

# ignore_runs = set(['u3gk9oio'])
ignore_runs = set()

simultaneous_training_analyses_dict = {}
per_task_simultaneous_training_analyses_dict = {}
start_index = 0  # raise this to skip the leading run sets / fields

for dimension_name, run_set in zip(analysis.CONDITION_ANALYSES_FIELDS[start_index:],
                                   simultaneous_training_runs[start_index:]):

    (stacked_results,
     all_results_mean, all_results_std, all_results_sem,
     per_task_results_mean, per_task_results_std, per_task_results_sem) = \
        analysis.process_multiple_runs_simultaneous_training(run_set, ignore_runs=ignore_runs)

    simultaneous_training_analyses_dict[dimension_name] = analysis.TotalCurveResults(
        raw=stacked_results,
        mean=all_results_mean,
        std=all_results_std,
        sem=all_results_sem,
    )

    # The per-task entry reuses the same stacked raw results; only the summary
    # statistics differ.
    per_task_simultaneous_training_analyses_dict[dimension_name] = analysis.TotalCurveResults(
        raw=stacked_results,
        mean=per_task_results_mean,
        std=per_task_results_std,
        sem=per_task_results_sem,
    )

simultaneous_training_analyses = analysis.ConditionAnalysesSet(
    **simultaneous_training_analyses_dict)
per_task_simultaneous_training_analyses = analysis.ConditionAnalysesSet(
    **per_task_simultaneous_training_analyses_dict)
cache = analysis.refresh_cache({
    'simultaneous_training_analyses': simultaneous_training_analyses,
    'per_task_simultaneous_training_analyses': per_task_simultaneous_training_analyses,
})


-----

# New Task Modulation

In [None]:
# Task-conditional (all layers) analyses, computed once and then cached.
if 'task_conditional_analyses' not in cache:
    task_modulated_runs = analysis.load_runs(20, 'meta-learning-scaling/task-conditional-all-layers')
    print('Loaded runs')

    # ignore_runs = set(['u3gk9oio'])
    # ignore_runs = set(['Task-conditional-[0, 1, 2, 3]-1017'])
    ignore_runs = set()

    task_modulated_analyses_dict = {}
    start_index = 0  # raise this to skip the leading run sets / fields
    process_kwargs = dict(
        parse_func=analysis.parse_run_results_with_new_task_accuracy_and_equal_size,
        ignore_runs=ignore_runs)

    for dimension_name, run_set in zip(analysis.CONDITION_ANALYSES_FIELDS[start_index:],
                                       task_modulated_runs[start_index:]):
        task_modulated_analyses_dict[dimension_name] = analysis.process_multiple_runs(
            run_set, **process_kwargs)

    task_conditional_analyses = analysis.ConditionAnalysesSet(**task_modulated_analyses_dict)
    cache = analysis.refresh_cache({'task_conditional_analyses': task_conditional_analyses})

else:
    task_conditional_analyses = cache['task_conditional_analyses']


In [None]:
# Task-conditional (multiplicative only) analyses, computed once and then cached.
if 'task_conditional_multiplicative_only_analyses' not in cache:
    task_modulated_runs = analysis.load_runs(
        20, 'meta-learning-scaling/task-conditional-all-layers-multiplicative-only')
    print('Loaded runs')

    # ignore_runs = set(['Task-conditional-multiplicative-[0, 1, 2, 3]-1006',
    #                   'Task-conditional-multiplicative-[0, 1, 2, 3]-1005'])
    ignore_runs = set()

    task_modulated_analyses_dict = {}
    start_index = 0  # raise this to skip the leading run sets / fields
    process_kwargs = dict(
        parse_func=analysis.parse_run_results_with_new_task_accuracy_and_equal_size,
        ignore_runs=ignore_runs)

    for dimension_name, run_set in zip(analysis.CONDITION_ANALYSES_FIELDS[start_index:],
                                       task_modulated_runs[start_index:]):
        task_modulated_analyses_dict[dimension_name] = analysis.process_multiple_runs(
            run_set, **process_kwargs)

    task_conditional_multiplicative_only_analyses = analysis.ConditionAnalysesSet(
        **task_modulated_analyses_dict)
    cache = analysis.refresh_cache({
        'task_conditional_multiplicative_only_analyses':
            task_conditional_multiplicative_only_analyses,
    })

else:
    task_conditional_multiplicative_only_analyses = cache['task_conditional_multiplicative_only_analyses']


In [None]:
# Task-conditional (additive only) analyses: reuse the cached value when present,
# otherwise load the runs, process each run set with the updated parse function, and
# store the result in the cache.
if 'task_conditional_additive_only_analyses' in cache:
    task_conditional_additive_only_analyses = cache['task_conditional_additive_only_analyses']

else:
    # Fixed: this line previously read `analyses_caches/alysis.load_runs(...)`, a
    # corrupted expression that raised NameError whenever the cache entry was missing.
    task_modulated_runs = analysis.load_runs(20, 'meta-learning-scaling/task-conditional-all-layers-additive-only')
    print('Loaded runs')

    # ignore_runs = set(['u3gk9oio'])
    ignore_runs = set()

    task_modulated_analyses_dict = {}
    start_index = 0  # raise this to skip the leading run sets / fields
    for run_set, dimension_name in zip(task_modulated_runs[start_index:],
                                       analysis.CONDITION_ANALYSES_FIELDS[start_index:]):

        task_modulated_analyses_dict[dimension_name] = analysis.process_multiple_runs(
            run_set, parse_func=analysis.parse_run_results_with_new_task_accuracy_and_equal_size,
            ignore_runs=ignore_runs)

    task_conditional_additive_only_analyses = analysis.ConditionAnalysesSet(**task_modulated_analyses_dict)
    cache = analysis.refresh_cache(dict(task_conditional_additive_only_analyses=task_conditional_additive_only_analyses))


# Saving the task-conditional weights en masse

In [None]:
# Task-conditional weights (additive and multiplicative) parsed from the combined runs,
# skipping three named runs.
if 'task_conditional_weights' not in cache:
    task_modulated_runs = analysis.load_runs(20, 'meta-learning-scaling/task-conditional-all-layers')
    print('Loaded runs')

    # ignore_runs = set(['u3gk9oio'])
    # ignore_runs = set()
    ignore_runs = {
        'Task-conditional-[0, 1, 2, 3]-1013',
        'Task-conditional-[0, 1, 2, 3]-1011',
        'Task-conditional-[0, 1, 2, 3]-1001',
    }

    task_conditional_weights = analysis.parse_task_conditional_weights(
        task_modulated_runs.combined,
        additive=True,
        multiplicative=True,
        ignore_runs=ignore_runs,
    )

    cache = analysis.refresh_cache({'task_conditional_weights': task_conditional_weights})

else:
    task_conditional_weights = cache['task_conditional_weights']



In [None]:
# Multiplicative-only task-conditional weights parsed from the combined runs.
if 'task_conditional_multiplicative_only_weights' not in cache:
    task_modulated_runs = analysis.load_runs(
        20, 'meta-learning-scaling/task-conditional-all-layers-multiplicative-only')
    print('Loaded runs')

    # ignore_runs = set(['Task-conditional-multiplicative-[0, 1, 2, 3]-1006',
    #                   'Task-conditional-multiplicative-[0, 1, 2, 3]-1005'])
    ignore_runs = set()

    task_conditional_multiplicative_only_weights = analysis.parse_task_conditional_weights(
        task_modulated_runs.combined,
        additive=False,
        multiplicative=True,
        ignore_runs=ignore_runs,
    )

    cache = analysis.refresh_cache({
        'task_conditional_multiplicative_only_weights':
            task_conditional_multiplicative_only_weights,
    })

else:
    task_conditional_multiplicative_only_weights = cache['task_conditional_multiplicative_only_weights']



In [None]:
# Additive-only task-conditional weights parsed from the combined runs.
if 'task_conditional_additive_only_weights' not in cache:
    task_modulated_runs = analysis.load_runs(
        20, 'meta-learning-scaling/task-conditional-all-layers-additive-only')
    print('Loaded runs')

    # ignore_runs = set(['u3gk9oio'])
    ignore_runs = set()

    task_conditional_additive_only_weights = analysis.parse_task_conditional_weights(
        task_modulated_runs.combined,
        additive=True,
        multiplicative=False,
        ignore_runs=ignore_runs,
    )

    cache = analysis.refresh_cache({
        'task_conditional_additive_only_weights': task_conditional_additive_only_weights,
    })

else:
    task_conditional_additive_only_weights = cache['task_conditional_additive_only_weights']

--------

# The curriculum experiments

In [7]:
# Baseline ratio-curriculum analyses, computed once and then served from the cache.
if 'baseline_ratio_curriculum_analyses' not in cache:
    runs = analysis.load_runs(20, 'meta-learning-scaling/baseline-curriculum-balanced-batches')
    print('Loaded runs')

    # ignore_runs = set(['oz996ztv', 'oin8pqu2', 'h09i9xyg', 'umo8x16f', 'pm76ui1s',
    #                    'qnyo08x8', 'hmjtjheq', 'wki7q8cs', 'q8ku840y'])
    ignore_runs = set()

    analyses_dict = {}
    start_index = 0  # raise this to skip the leading run sets / fields
    process_kwargs = dict(
        parse_func=analysis.parse_run_results_with_new_task_accuracy_and_equal_size,
        ignore_runs=ignore_runs)

    for dimension_name, run_set in zip(analysis.CONDITION_ANALYSES_FIELDS[start_index:],
                                       runs[start_index:]):
        analyses_dict[dimension_name] = analysis.process_multiple_runs(run_set, **process_kwargs)

    baseline_ratio_curriculum_analyses = analysis.ConditionAnalysesSet(**analyses_dict)
    cache = analysis.refresh_cache(
        {'baseline_ratio_curriculum_analyses': baseline_ratio_curriculum_analyses})

else:
    baseline_ratio_curriculum_analyses = cache['baseline_ratio_curriculum_analyses']


Loaded runs
baseline-curriculum-balanced-batches-1009
baseline-curriculum-balanced-batches-1008
baseline-curriculum-balanced-batches-1007
baseline-curriculum-balanced-batches-1006
baseline-curriculum-balanced-batches-1005
baseline-curriculum-balanced-batches-1004
baseline-curriculum-balanced-batches-1003
baseline-curriculum-balanced-batches-1002
baseline-curriculum-balanced-batches-1001
baseline-curriculum-balanced-batches-1000
Removing extraneous nans
Max first nan index: 241
Examples to criterion examples
Log examples to criterion log_examples
New task accuracy accuracies
New task accuracy delta accuracy_drops
First task accuracy by epoch first_task_accuracies
New task accuracy by epoch new_task_accuracies
baseline-curriculum-balanced-batches-2009


  mean=np.nanmean(result_set, axis=0),
  keepdims=keepdims)


baseline-curriculum-balanced-batches-2008
baseline-curriculum-balanced-batches-2007
baseline-curriculum-balanced-batches-2006
baseline-curriculum-balanced-batches-2005
baseline-curriculum-balanced-batches-2004
baseline-curriculum-balanced-batches-2003
baseline-curriculum-balanced-batches-2002
baseline-curriculum-balanced-batches-2001
baseline-curriculum-balanced-batches-2000
Removing extraneous nans
Max first nan index: 35
Examples to criterion examples
Log examples to criterion log_examples
New task accuracy accuracies
New task accuracy delta accuracy_drops
First task accuracy by epoch first_task_accuracies
New task accuracy by epoch new_task_accuracies
baseline-curriculum-balanced-batches-3009
baseline-curriculum-balanced-batches-3008
baseline-curriculum-balanced-batches-3007
baseline-curriculum-balanced-batches-3006
baseline-curriculum-balanced-batches-3005
baseline-curriculum-balanced-batches-3004
baseline-curriculum-balanced-batches-3003
baseline-curriculum-balanced-batches-3002
b

In [8]:
# Ratio-curriculum (1-5) analyses, computed once and then served from the cache.
if 'ratio_curriculum_1_5_analyses' not in cache:
    runs = analysis.load_runs(20, 'meta-learning-scaling/curriculum-ratio-1-5-balanced-batches')
    print('Loaded runs')

    # ignore_runs = set(['oz996ztv', 'oin8pqu2', 'h09i9xyg', 'umo8x16f', 'pm76ui1s',
    #                    'qnyo08x8', 'hmjtjheq', 'wki7q8cs', 'q8ku840y'])
    ignore_runs = set()

    analyses_dict = {}
    start_index = 0  # raise this to skip the leading run sets / fields
    process_kwargs = dict(
        parse_func=analysis.parse_run_results_with_new_task_accuracy_and_equal_size,
        ignore_runs=ignore_runs)

    for dimension_name, run_set in zip(analysis.CONDITION_ANALYSES_FIELDS[start_index:],
                                       runs[start_index:]):
        analyses_dict[dimension_name] = analysis.process_multiple_runs(run_set, **process_kwargs)

    ratio_curriculum_1_5_analyses = analysis.ConditionAnalysesSet(**analyses_dict)
    cache = analysis.refresh_cache(
        {'ratio_curriculum_1_5_analyses': ratio_curriculum_1_5_analyses})

else:
    ratio_curriculum_1_5_analyses = cache['ratio_curriculum_1_5_analyses']


Loaded runs
curriculum-ratio-1-5-balanced-batches-1002
curriculum-ratio-1-5-balanced-batches-1004
curriculum-ratio-1-5-balanced-batches-1003
curriculum-ratio-1-5-balanced-batches-1009
curriculum-ratio-1-5-balanced-batches-1001
curriculum-ratio-1-5-balanced-batches-1008
curriculum-ratio-1-5-balanced-batches-1007
curriculum-ratio-1-5-balanced-batches-1006
curriculum-ratio-1-5-balanced-batches-1005
curriculum-ratio-1-5-balanced-batches-1000
Removing extraneous nans
Max first nan index: 576
Examples to criterion examples
Log examples to criterion log_examples
New task accuracy accuracies
New task accuracy delta accuracy_drops
First task accuracy by epoch first_task_accuracies
New task accuracy by epoch new_task_accuracies
curriculum-ratio-1-5-balanced-batches-2009
curriculum-ratio-1-5-balanced-batches-2008
curriculum-ratio-1-5-balanced-batches-2007
curriculum-ratio-1-5-balanced-batches-2006
curriculum-ratio-1-5-balanced-batches-2005
curriculum-ratio-1-5-balanced-batches-2004
curriculum-rat

In [10]:
# Power-curriculum (default alpha) analyses, computed once and then cached.
if 'baseline_power_curriculum_analyses' not in cache:
    runs = analysis.load_runs(
        20, 'meta-learning-scaling/power-curriculum-default-alpha-balanced-batches')
    print('Loaded runs')

    # ignore_runs = set(['oz996ztv', 'oin8pqu2', 'h09i9xyg', 'umo8x16f', 'pm76ui1s',
    #                    'qnyo08x8', 'hmjtjheq', 'wki7q8cs', 'q8ku840y'])
    ignore_runs = set()

    analyses_dict = {}
    start_index = 0  # raise this to skip the leading run sets / fields
    process_kwargs = dict(
        parse_func=analysis.parse_run_results_with_new_task_accuracy_and_equal_size,
        ignore_runs=ignore_runs)

    for dimension_name, run_set in zip(analysis.CONDITION_ANALYSES_FIELDS[start_index:],
                                       runs[start_index:]):
        analyses_dict[dimension_name] = analysis.process_multiple_runs(run_set, **process_kwargs)

    baseline_power_curriculum_analyses = analysis.ConditionAnalysesSet(**analyses_dict)
    cache = analysis.refresh_cache(
        {'baseline_power_curriculum_analyses': baseline_power_curriculum_analyses})

else:
    baseline_power_curriculum_analyses = cache['baseline_power_curriculum_analyses']


Loaded runs
power-curriculum-default-alpha-balanced-batches-1009
power-curriculum-default-alpha-balanced-batches-1008
power-curriculum-default-alpha-balanced-batches-1007
power-curriculum-default-alpha-balanced-batches-1006
power-curriculum-default-alpha-balanced-batches-1005
power-curriculum-default-alpha-balanced-batches-1004
power-curriculum-default-alpha-balanced-batches-1003
power-curriculum-default-alpha-balanced-batches-1002
power-curriculum-default-alpha-balanced-batches-1001
power-curriculum-default-alpha-balanced-batches-1000
Removing extraneous nans
Max first nan index: 84
Examples to criterion examples
Log examples to criterion log_examples
New task accuracy accuracies
New task accuracy delta accuracy_drops
First task accuracy by epoch first_task_accuracies
New task accuracy by epoch new_task_accuracies
power-curriculum-default-alpha-balanced-batches-2009
power-curriculum-default-alpha-balanced-batches-2008
power-curriculum-default-alpha-balanced-batches-2007
power-curricul

In [19]:
# Power-curriculum (alpha = 2) analyses: reuse the cached result when the key is present;
# otherwise load the runs, process each run set, and write the result to the cache.
# NOTE(review): this cell has the same structure as the other curriculum cache cells and
# differs only in the cache key and project path — a shared helper would remove the duplication.
if 'power_curriculum_2_analyses' in cache:
    power_curriculum_2_analyses = cache['power_curriculum_2_analyses']

else:
    runs = analysis.load_runs(20, 'meta-learning-scaling/power-curriculum-alpha-2-balanced-batches')
    print('Loaded runs')

    # Run ids that were excluded at some point, kept for reference; nothing is excluded now.
    # ignore_runs = set(['oz996ztv', 'oin8pqu2', 'h09i9xyg', 'umo8x16f', 'pm76ui1s', 
    #                    'qnyo08x8', 'hmjtjheq', 'wki7q8cs', 'q8ku840y'])
    ignore_runs = set()

    analyses_dict = {}
    # With start_index = 0 both slices below are full copies, so every run set is processed.
    start_index = 0
    # Pair each run set with the field name it is stored under in the ConditionAnalysesSet.
    for run_set, dimension_name in zip(runs[start_index:], 
                                       analysis.CONDITION_ANALYSES_FIELDS[start_index:]):

        analyses_dict[dimension_name] = analysis.process_multiple_runs(
            run_set, parse_func=analysis.parse_run_results_with_new_task_accuracy_and_equal_size,
            ignore_runs=ignore_runs) 

    power_curriculum_2_analyses = analysis.ConditionAnalysesSet(**analyses_dict)
    # refresh_cache receives the new entry and its return value replaces the in-memory cache.
    cache = analysis.refresh_cache(dict(power_curriculum_2_analyses=power_curriculum_2_analyses))


Loaded runs
a4hj9y7k power-curriculum-alpha-2-balanced-batches-1002
0a1iifo9 power-curriculum-alpha-2-balanced-batches-1004
t071il8w power-curriculum-alpha-2-balanced-batches-1009
8dtqbr1j power-curriculum-alpha-2-balanced-batches-1008
mdjzx1jh power-curriculum-alpha-2-balanced-batches-1003
p88x7rzh power-curriculum-alpha-2-balanced-batches-1001
81so866q power-curriculum-alpha-2-balanced-batches-1007
itt7kotx power-curriculum-alpha-2-balanced-batches-1006
j94dgyhy power-curriculum-alpha-2-balanced-batches-1005
4s2sznro power-curriculum-alpha-2-balanced-batches-1000
Removing extraneous nans
Max first nan index: 914
Examples to criterion examples
Log examples to criterion log_examples
New task accuracy accuracies
New task accuracy delta accuracy_drops
First task accuracy by epoch first_task_accuracies
New task accuracy by epoch new_task_accuracies
1j7x7cx0 power-curriculum-alpha-2-balanced-batches-2009


  mean=np.nanmean(result_set, axis=0),
  keepdims=keepdims)


tkbqabcy power-curriculum-alpha-2-balanced-batches-2008
e2yilpsf power-curriculum-alpha-2-balanced-batches-2007
oze9neb3 power-curriculum-alpha-2-balanced-batches-2006
iozr2j1y power-curriculum-alpha-2-balanced-batches-2005
hgbn92hj power-curriculum-alpha-2-balanced-batches-2004
1uyhapjj power-curriculum-alpha-2-balanced-batches-2003
90qs1ggr power-curriculum-alpha-2-balanced-batches-2002
hzviecjv power-curriculum-alpha-2-balanced-batches-2001
auhn18le power-curriculum-alpha-2-balanced-batches-2000
Removing extraneous nans
Max first nan index: 91
Examples to criterion examples
Log examples to criterion log_examples
New task accuracy accuracies
New task accuracy delta accuracy_drops
First task accuracy by epoch first_task_accuracies
New task accuracy by epoch new_task_accuracies
s801akir power-curriculum-alpha-2-balanced-batches-3009
mnma5wha power-curriculum-alpha-2-balanced-batches-3008
9hcoqk36 power-curriculum-alpha-2-balanced-batches-3007
wq7l6495 power-curriculum-alpha-2-balanced-

# Computing the epochs to completion in each condition

In [20]:
# Run sets for which epochs-to-completion are computed. Each entry is a 5-tuple that the
# computing cell unpacks as (name, path, num_runs, ignore_runs, split_runs_by_dimension):
#   name               - key under which the results are stored
#   URI                - project path handed to analysis.load_runs
#   num_runs           - run count handed to analysis.load_runs
#   ignore_runs        - forwarded to the per-run-set computation (None for every entry here)
#   split_by_condition - True: compute per run set; False: compute once on the combined runs
EPOCHS_TO_COMPLETION_SETS = (
    # name, URI, num_runs, ignore_runs, split_by_condition
    ('baseline', 'meta-learning-scaling/sequential-benchmark-baseline', 60, None, True),
    ('heterogeneous', 'meta-learning-scaling/sequential-benchmark-control', 150, None, False),
    # (),  # TODO: query-modulated, too?
    ('ratio_curriculum', 'meta-learning-scaling/baseline-curriculum-balanced-batches', 20, None, True),
    ('ratio_curriculum_1_5', 'meta-learning-scaling/curriculum-ratio-1-5-balanced-batches', 20, None, True),
    ('power_curriculum', 'meta-learning-scaling/power-curriculum-default-alpha-balanced-batches', 20, None, True),
    ('power_curriculum_2', 'meta-learning-scaling/power-curriculum-alpha-2-balanced-batches', 20, None, True),
)

In [21]:
# Epochs to completion for every entry of EPOCHS_TO_COMPLETION_SETS: reuse the cached
# dictionary when present, otherwise compute it entry by entry and store it in the cache.
if 'epochs_to_completion' in cache:
    epochs_to_completion = cache['epochs_to_completion']
    
else:
    # Maps each entry name to dict(raw=<ConditionAnalysesSet>, log=<ConditionAnalysesSet>).
    epochs_to_completion = dict()
    
    for name, path, num_runs, ignore_runs, split_runs_by_dimension in EPOCHS_TO_COMPLETION_SETS:
        runs = analysis.load_runs(num_runs, path, split_runs_by_dimension=split_runs_by_dimension)
        raw_analyses_dict = {}
        log_analyses_dict = {}
        
        if split_runs_by_dimension:
            # One result pair per run set, keyed by the matching ConditionAnalysesSet field name.
            for run_set, dimension_name in zip(runs, analysis.CONDITION_ANALYSES_FIELDS):
                # `epochs_to_taks_completions` (sic) is the name exposed by the analysis module.
                results, log_results = analysis.epochs_to_taks_completions(run_set, ignore_runs=ignore_runs, ipb_desc=f'{name}/{dimension_name}') 
                raw_analyses_dict[dimension_name] = results
                log_analyses_dict[dimension_name] = log_results
                
        else:
            # A single result pair computed on the combined runs, stored under analysis.COMBINED.
            results, log_results = analysis.epochs_to_taks_completions(runs.combined, ignore_runs=ignore_runs, ipb_desc=f'{name}') 
            raw_analyses_dict[analysis.COMBINED] = results
            log_analyses_dict[analysis.COMBINED] = log_results
            
        # NOTE(review): in the non-split case only the COMBINED field is supplied, so
        # ConditionAnalysesSet presumably has defaults for its other fields — confirm.
        raw_analyses_set = analysis.ConditionAnalysesSet(**raw_analyses_dict)
        log_analyses_set = analysis.ConditionAnalysesSet(**log_analyses_dict)
        
        epochs_to_completion[name] = dict(raw=raw_analyses_set, log=log_analyses_set)
        
    # refresh_cache receives the new entry and its return value replaces the in-memory cache.
    cache = analysis.refresh_cache(dict(epochs_to_completion=epochs_to_completion)) 
        

wandb: Network error resolved after 0:00:26.981506, resuming normal operation.


wandb: Network error resolved after 0:00:14.102628, resuming normal operation.


wandb: Network error resolved after 0:00:12.721399, resuming normal operation.


  ResultSet(name=name, mean=np.nanmean(np.log(results_per_run), axis=0), std=np.nanstd(np.log(results_per_run), axis=0))
  np.subtract(arr, avg, out=arr, casting='unsafe')


'ai99wvyu'

In [None]:
# Uncached, single-project variant of the epochs-to-completion computation for the
# baseline-curriculum runs; the results are only collected into two local dictionaries.
runs = analysis.load_runs(20, 'meta-learning-scaling/baseline-curriculum-balanced-batches')
print('Loaded runs')

# Run ids that were excluded at some point, kept for reference; nothing is excluded now.
# ignore_runs = set(['oz996ztv', 'oin8pqu2', 'h09i9xyg', 'umo8x16f', 'pm76ui1s', 
#                    'qnyo08x8', 'hmjtjheq', 'wki7q8cs', 'q8ku840y'])
ignore_runs = set()

raw_analyses_dict = {}
log_analyses_dict = {}
# With start_index = 0 both slices below are full copies, so every run set is processed.
start_index = 0
for run_set, dimension_name in zip(runs[start_index:], 
                                   analysis.CONDITION_ANALYSES_FIELDS[start_index:]):

    # `epochs_to_taks_completions` (sic) is the name exposed by the analysis module.
    results, log_results = analysis.epochs_to_taks_completions(run_set, ignore_runs=ignore_runs) 
    raw_analyses_dict[dimension_name] = results
    log_analyses_dict[dimension_name] = log_results
    

In [5]:
def run_finished_succesfully(run, samples=5000):
    """Check whether a run ended with every query's test accuracy at or above 0.95.

    The run's history is fetched as a DataFrame and the ten columns
    'Test Accuracy, Query #1' ... 'Test Accuracy, Query #10' are inspected.
    For each column the last *logged* (non-NaN) value is used, so a trailing
    history row that did not record the test accuracies does not make a
    successful run look like a failure.

    Args:
        run: object exposing ``history(pandas=True, samples=...)`` that returns
            a pandas DataFrame.
        samples: forwarded to ``run.history`` as the ``samples`` argument.

    Returns:
        bool: True only if all ten columns exist, each has at least one logged
        value, and the last logged value of each is >= 0.95.
    """
    df = run.history(pandas=True, samples=samples)
    test_acc_column_names = [f'Test Accuracy, Query #{i + 1}' for i in range(10)]
    if any(col not in df.columns for col in test_acc_column_names):
        return False

    last_epoch_accuracies = []
    for col in test_acc_column_names:
        # Rows that did not log this metric hold NaN; skip them to reach the last real value.
        logged_values = df[col].dropna()
        if logged_values.empty:
            # The column exists but nothing was ever logged for it.
            return False

        last_epoch_accuracies.append(logged_values.iloc[-1])

    return bool(np.all(np.array(last_epoch_accuracies) >= 0.95))


def condition_finished_succesfully(runs, num_runs=20, split_runs_by_dimension=True):
    """Print a pass/fail report for every run of a condition.

    Args:
        runs: an ``analysis.ConditionAnalysesSet`` (its ``combined`` member is
            used), a project path string (the runs are loaded with
            ``analysis.load_runs`` and their ``combined`` member is used), or
            an iterable of run objects.
        num_runs: forwarded to ``analysis.load_runs`` when ``runs`` is a string.
        split_runs_by_dimension: forwarded to ``analysis.load_runs`` when
            ``runs`` is a string.

    Returns:
        None. Results are reported with ``print``; runs are identified in the
        summaries by their ``dataset_random_seed`` config value.
    """
    if isinstance(runs, analysis.ConditionAnalysesSet):
        runs = runs.combined

    if isinstance(runs, str):
        loaded = analysis.load_runs(num_runs, runs, split_runs_by_dimension=split_runs_by_dimension)
        runs = loaded.combined

    unfinished_seeds, failed_seeds = [], []
    for current_run in ipb(runs, desc='Runs'):
        seed = current_run.config['dataset_random_seed']

        # Runs still in progress are reported separately and never counted as failures.
        if current_run.state == 'running':
            print(f'Run {current_run.name} is still running')
            unfinished_seeds.append(seed)
        elif not run_finished_succesfully(current_run):
            print(f'Run {current_run.name} failed')
            failed_seeds.append(seed)

    if unfinished_seeds:
        print(f'{len(unfinished_seeds)} runs are still running: {unfinished_seeds}')

    if failed_seeds:
        print(f'{len(failed_seeds)} runs failed: {failed_seeds}')
    else:
        print('All finished runs passed')
            

In [6]:
# Project paths of the four curriculum conditions to verify.
curriculum_urls = (
    'meta-learning-scaling/baseline-curriculum-balanced-batches', 
    'meta-learning-scaling/curriculum-ratio-1-5-balanced-batches',
    'meta-learning-scaling/power-curriculum-default-alpha-balanced-batches',
    'meta-learning-scaling/power-curriculum-alpha-2-balanced-batches',
)

# Print each path followed by the pass/fail report for its runs; passing a string makes
# condition_finished_succesfully load the runs itself with its default num_runs.
for url in curriculum_urls:
    print(url)
    condition_finished_succesfully(url)

meta-learning-scaling/baseline-curriculum-balanced-batches


All finished runs passed
meta-learning-scaling/curriculum-ratio-1-5-balanced-batches


All finished runs passed
meta-learning-scaling/power-curriculum-default-alpha-balanced-batches


All finished runs passed
meta-learning-scaling/power-curriculum-alpha-2-balanced-batches


All finished runs passed


In [10]:
# Sanity check: the entries are plain strings, the type condition_finished_succesfully tests for.
isinstance(curriculum_urls[0], str)

True

In [None]:
# NOTE(review): `r` is not defined in any visible cell — presumably a leftover from
# interactive use; this fails on a fresh kernel. Confirm or remove.
r.state

In [None]:
# History of the first combined run as a DataFrame, for manual inspection.
# NOTE(review): depends on `runs` left over from an earlier cell and on it having `.combined`.
test_df = runs.combined[0].history(pandas=True, samples=10000)

In [None]:
# Value in the last history row for one of the test-accuracy columns.
test_df['Test Accuracy, Query #7'].iloc[-1]

In [None]:
# Membership check for a column name outside the Query #1..#10 range used by run_finished_succesfully.
'Test Accuracy, Query #17' in test_df.columns

# Playing around with reading from the weights

In [None]:
import sys
sys.path.append('../projects/')

from metalearning import cnnmlp

In [None]:
# Optimizer settings handed to the model constructor as `lr` and `weight_decay`.
DEFAULT_LEARNING_RATE = 5e-4
DEFAULT_WEIGHT_DECAY = 1e-4


def create_task_conditional_model(multiplicative=True, additive=True, checkpoint_path=None, name=None):
    """Construct a ``cnnmlp.TaskConditionalCNNMLP`` and optionally load a checkpoint.

    The architecture arguments are fixed here; only the modulation switches,
    the checkpoint and the name are configurable.

    Args:
        multiplicative: passed to the model as ``multiplicative_mod``.
        additive: passed to the model as ``additive_mod``.
        checkpoint_path: when not None, passed to ``model.load_state`` after
            construction.
        name: passed to the model as ``name``.

    Returns:
        The constructed model (with state loaded if a checkpoint was given).
    """
    model_kwargs = dict(
        mod_level=list(range(4)),  # [0, 1, 2, 3]
        multiplicative_mod=multiplicative,
        additive_mod=additive,
        query_length=30,
        conv_filter_sizes=(16, 32, 48, 64),
        conv_output_size=4480,
        mlp_layer_sizes=(512, 512, 512, 512),
        lr=DEFAULT_LEARNING_RATE,
        weight_decay=DEFAULT_WEIGHT_DECAY,
        use_lr_scheduler=False,
        conv_dropout=False,
        mlp_dropout=False,
        name=name,
    )
    model = cnnmlp.TaskConditionalCNNMLP(**model_kwargs)

    # Without a checkpoint the freshly constructed model is returned as is.
    if checkpoint_path is None:
        return model

    model.load_state(checkpoint_path)
    return model

In [None]:
# wandb API client used below to look up a run and its files.
api = wandb.Api()

In [None]:
# Look up one additive-only task-conditional run by its path.
run = api.run('meta-learning-scaling/task-conditional-all-layers-additive-only/runs/49dq79wf')
# The requested file name is the run name with the "[0, 1, 2, 3]-" substring removed,
# followed by "-query-9.pth".
last_checkpoint_file = run.file(f'{run.name.replace("[0, 1, 2, 3]-", "")}-query-9.pth')


In [None]:
# Download the checkpoint file under /tmp, replacing any existing copy.
last_checkpoint = last_checkpoint_file.download(replace=True, root='/tmp')

In [None]:
# Build the model with multiplicative modulation disabled (additive keeps its default)
# and load the checkpoint from the /tmp location it was downloaded to.
model = create_task_conditional_model(multiplicative=False, 
                                      checkpoint_path='/tmp/' + last_checkpoint_file.name)

In [None]:
# Weight of each 'additive-{i}' modulation layer (i = 0..3), detached and converted to NumPy.
layers = range(4)
additive_weights = [model.conv.additive_mod_layers[f'additive-{i}'].weight.detach().cpu().numpy() 
                                for i in layers]

In [None]:
# NOTE(review): the value read from the run config is immediately overwritten by the
# hard-coded 0 on the next line — presumably a manual override while exploring; confirm.
dimension = run.config['benchmark_dimension']
dimension = 0
# Per layer, the sum of the weights in columns [dimension * 10, (dimension + 1) * 10).
[w[:, dimension * 10:(dimension + 1) * 10].sum() for w in additive_weights]

In [None]:
# Display the configuration of the inspected run.
run.config

In [None]:
# Display the container of additive modulation layers.
model.conv.additive_mod_layers

In [None]:
# Sum of the first additive layer's weight over its first axis.
model.conv.additive_mod_layers['additive-0'].weight.sum(0)

In [None]:
# Sum of the first additive layer's weights over the column blocks 0:10, 10:20 and 20:30.
[model.conv.additive_mod_layers['additive-0'].weight.detach().cpu().numpy()[:, i * 10:(i + 1) * 10].sum() for i in range(3)]

In [None]:
# Flatten the first additive layer's weights in columns 20:30 into a 1-D array.
np.ravel(model.conv.additive_mod_layers['additive-0'].weight.detach().cpu().numpy()[:, 20:30])

In [None]:
# Histogram (10 bins) of a previously displayed cell output.
# NOTE(review): `Out[44]` only exists in the session that produced it, so this breaks on
# Restart & Run All — presumably it refers to the flattened weights; confirm and bind them to a name.
plt.hist(Out[44], bins=10)

# Scratch work