In [2]:
%load_ext autoreload
%autoreload 2
import sys
import torch
import itertools
import copy
import numpy as np
from torch import nn
import torchmetrics
import pandas as pd
sys.path.append('..')
# sys.path.append('/system/user/beck/pwbeck/projects/regularization/ml_utilities')
from pathlib import Path
from typing import Union
from pprint import pprint
from ml_utilities.torch_models.base_model import BaseModel
from ml_utilities.torch_models.fc import FC
from ml_utilities.torch_models import get_model_class
from ml_utilities.output_loader.repo import Repo
from ml_utilities.output_loader.job_output import JobResult, SweepResult
from ml_utilities.torch_utils.metrics import TAccuracy
from ml_utilities.utils import match_number_list_to_interval, flatten_hierarchical_dict, convert_listofdicts_to_dictoflists, hyp_param_cfg_to_str
from ml_utilities.run_utils.sweep import Sweeper
from omegaconf import OmegaConf

from erank.data.datasetgenerator import DatasetGenerator
from erank.mode_connectivity import interpolate_linear, interpolate_linear_runs, interpolation_result2series, InstabilityAnalyzer

import matplotlib.pyplot as plt
# Notebook-level settings: the GPU index and the repository handle through which
# experiment outputs are located (see `REPO.get_output_loader` below).
gpu_id = 0
REPO = Repo(
    dir=Path('../../erank'),
    hydra_defaults=OmegaConf.load('../configs/hydra/jobname_outputdir_format.yaml'),
)

# Instability analysis debug notebook
This notebook implements instability analysis à la Frankle et al. (2020), "Linear Mode Connectivity and the Lottery Ticket Hypothesis".
It is used to debug the large-scale experiment.

It performs linear interpolation on MNIST, using data from Experiment 11.7.3.

In [3]:
# some constants
# NOTE(review): `score_fn` is not referenced by the cells in this notebook; the instability
# config below selects its metric by name (`score_fn: TAccuracy`). Remove if unused elsewhere.
score_fn = TAccuracy()

## Instability Analysis on Experiment 11.7.3

In [4]:
# Config of sweep 11.7.3 as it was launched. In this notebook it is only passed to
# `REPO.get_output_loader(cfg)` to find that sweep's outputs, so keep the YAML as-run.
# The `XXX` values are placeholders for the three parameters swept under `sweep.axes`
# (trainer.init_model_step, data.dataset_kwargs.rotation_angle,
#  data.dataset_split.restrict_n_samples_train_task).
# NOTE(review): `init_model` points at /system/user/beck/pwbeck/... while the sweep outputs
# live under /system/user/publicwork/beck/... — presumably the same storage; confirm.
config_yaml = """
run_config:
  exec_type: parallel # sequential
  hostname: gorilla
  gpu_ids: [0,1,2,3,4,5,6,7]
  runs_per_gpu: 3

  wandb: # wandb config for run_handler, if "wandb: null" then logging to wandb is disabled for run_handler
    init:
      tags:
        - ${config.experiment_data.experiment_tag}_exps
        - run_handler
      notes: #
      group: ${config.experiment_data.experiment_tag}
      job_type: run_handler

seeds: [1,2,3]

sweep:
  type: grid
  axes:
    - parameter: trainer.init_model_step
      vals: [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100, 125, 150, 175, 200, 225, 250, 275, 300, 325, 350, 375, 400, 425, 450, 475]
    - parameter: data.dataset_kwargs.rotation_angle
      vals: linspace(0,180,50,endpoint=True)
    - parameter: data.dataset_split.restrict_n_samples_train_task
      vals: [300] #[5, 20, 50, 100, 500, 1000, 10000, 48000]

start_num: 3 # use this to count how often this config is run
###
config:
  experiment_data:
    entity: jkuiml-fsl
    project_name: sparsity
    experiment_tag: "11.7"
    experiment_type: startnum_${start_num}
    experiment_name: mnist-${config.experiment_data.experiment_tag}.${start_num}-lenet_rottasks_ft
    experiment_dir: null
    experiment_notes: Hyperparameter search.
    job_name: null
    seed: 0
    hostname: null # the server on which the run is run, will be filled by run_handler
    gpu_id: 0

  # wandb:
  #   init:
  #     tags: # list(), used to tag wandblogger
  #       - ${config.experiment_data.experiment_tag}_exps
  #     notes: ${config.experiment_data.experiment_notes} # str, used to make notes to wandblogger
  #     group: ${config.experiment_data.experiment_tag} # null
  #     job_type: ${config.experiment_data.experiment_type} # examples: hypsearch, pretrain, eval, etc.

  #   watch:
  #     log: null #parameters #null #all
  #     log_freq: 5000

  model:
    name: fc
    model_kwargs:
      input_size: 784
      hidden_sizes:
        - 300
        - 100
      output_size: 10
      flatten_input: True
      dropout: null
      act_fn: relu

  trainer:
    training_setup: supervised
    n_steps: 2000
    log_train_step_every: 1
    log_additional_train_step_every_multiplier: 1
    log_additional_logs: True
    val_every: 5
    save_every: 5 #500
    early_stopping_patience: 200 #500
    batch_size: 128
    optimizer_scheduler:
      optimizer_name: adamw #sgd #adamw
      optimizer_kwargs:
        lr: 0.001
        weight_decay: 0.0
    
    init_model_step: XXX
    init_model: /system/user/beck/pwbeck/projects/regularization/erank/outputs/mnist-11.5.0-lenet--221015_122552/model_step_${config.trainer.init_model_step}.p

    loss: crossentropy

    metrics:
      - Accuracy
    num_workers: 4
    verbose: False

  data:
    dataset: rotatedvision
    dataset_kwargs:
      data_root_path: /system/user/beck/pwbeck/data
      dataset: mnist
      rotation_angle: XXX
    dataset_split:
      train_val_split: 0.8
      restrict_n_samples_train_task: XXX

"""
# Parse the YAML into an OmegaConf config; `${...}` interpolations are kept as references.
cfg = OmegaConf.create(config_yaml)

In [5]:
# Output loader for sweep 11.7.3. Printing it shows the sweep summary
# (sweep axes, seeds, number of jobs, timestamps).
sweepr = REPO.get_output_loader(cfg)
print(sweepr)

Exp. Tag(start_num): 11.7(3)
Exp. Name: mnist-11.7.3-lenet_rottasks_ft
Training setup: supervised
Model name: fc
Dataset name: rotatedvision
Sweep type: grid
  trainer.init_model_step: [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100, 125, 150, 175, 200, 225, 250, 275, 300, 325, 350, 375, 400, 425, 450, 475]
  data.dataset_kwargs.rotation_angle: linspace(0,180,50,endpoint=True)
  data.dataset_split.restrict_n_samples_train_task: [300]
Seeds: [1, 2, 3]
Num. jobs: 5400
Config updated: 2022-11-25 12:34:14
Sweep started:  2022-11-25 12:36:51



In [6]:
# Second grid value of the rotation-angle sweep axis `linspace(0,180,50,endpoint=True)`,
# i.e. 180/49 ≈ 3.6735 degrees. This is the angle analysed in the instability config below,
# so name it instead of copying a truncated literal by hand.
second_rotation_angle = np.linspace(0, 180, 50, endpoint=True)[1]
second_rotation_angle

3.673469387755102

### Instability analysis

In [7]:
# Same sweep summary as printed above, repeated as a reference for this section.
print(sweepr)

Exp. Tag(start_num): 11.7(3)
Exp. Name: mnist-11.7.3-lenet_rottasks_ft
Training setup: supervised
Model name: fc
Dataset name: rotatedvision
Sweep type: grid
  trainer.init_model_step: [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100, 125, 150, 175, 200, 225, 250, 275, 300, 325, 350, 375, 400, 425, 450, 475]
  data.dataset_kwargs.rotation_angle: linspace(0,180,50,endpoint=True)
  data.dataset_split.restrict_n_samples_train_task: [300]
Seeds: [1, 2, 3]
Num. jobs: 5400
Config updated: 2022-11-25 12:34:14
Sweep started:  2022-11-25 12:36:51



In [8]:
# Output folder of the loaded sweep; the instability analysis below runs on this folder.
sweepr.directory

PosixPath('/system/user/publicwork/beck/projects/regularization/erank/outputs/mnist-11.7.3-lenet_rottasks_ft--221125_123651')

In [9]:
# Config for the instability analysis (linear interpolation between fine-tuned runs).
# Values that already exist in this notebook are interpolated rather than copied by hand,
# so they cannot drift apart:
#   - instability_sweep: output folder of the sweep loaded into `sweepr` above
#   - device:            the notebook-level `gpu_id`
#   - rotation_angle:    2nd value of the sweep grid linspace(0,180,50) (= 180/49). It was
#                        previously hard-coded as the truncated 3.67346 and only matched the
#                        sweep job thanks to `float_eps_query_job`.
_sweep_rotation_angles = np.linspace(0, 180, 50, endpoint=True)
cfg_instability_yaml = f"""
instability_sweep: {sweepr.directory}
score_fn: TAccuracy
interpolation_factors: [0.0000, 0.2500, 0.5000, 0.7500, 1.0000]
init_model_idxes_ks_or_every: 15 #3 # 0
train_model_idxes: [-1] #[100, -1]
device: {gpu_id}
interpolate_linear_kwargs:
  interpolation_on_train_data: True
  dataloader_kwargs:
    batch_size: 1024
  compute_model_distances: True
override_files: False
num_seed_combinations: 1
save_folder_suffix: 1
float_eps_query_job: 1e-3
save_results_to_disc: False
hpparam_sweep:
  type: grid
  axes:
    - parameter: data.dataset_kwargs.rotation_angle
      vals: [{float(_sweep_rotation_angles[1])}] #[180.] #linspace(0,180,50,endpoint=True)
    - parameter: data.dataset_split.restrict_n_samples_train_task
      vals:
      - 300
"""
cfg_instability = OmegaConf.create(cfg_instability_yaml)

In [10]:
# Build the analyzer: every top-level key of `cfg_instability` is passed as a keyword argument.
insta = InstabilityAnalyzer(
    **cfg_instability,
)

Collecting failed jobs: 100%|██████████| 5400/5400 [00:49<00:00, 109.35it/s]


In [11]:
# res_ret = insta.instability_analysis(override_files=True)

In [22]:
# Stored combined-result folders in ascending order; the `--YYMMDD_HHMMSS` suffix makes this
# chronological. (`reverse=False` is the default and was dropped.)
sorted(insta.combined_results)

['combined_result--221221_134957', 'combined_result--221222_144957']

In [34]:
# Combined-result folders currently known to the analyzer.
# NOTE(review): this cell ran after In[22] yet lists one entry fewer — the notebook was executed
# out of order; re-run top to bottom to get a consistent state.
insta.combined_results

['combined_result--221221_134957']

In [27]:
# Names of the result tables held by the analyzer (here: 'datasets' and 'distances').
insta.combined_results_dfs.keys()

dict_keys(['datasets', 'distances'])

In [31]:
# insta.combined_results_dfs['datasets']

In [30]:
# Re-check the sweep folder (same value as shown earlier) before reloading results from it.
sweepr.directory

PosixPath('/system/user/publicwork/beck/projects/regularization/erank/outputs/mnist-11.7.3-lenet_rottasks_ft--221125_123651')

In [39]:
# Reload stored instability-analysis results for this sweep. The folder suffix is read from the
# config that produced them (`save_folder_suffix`) instead of being repeated as a literal `1`,
# so the two cannot get out of sync.
insta2 = InstabilityAnalyzer.reload(
    sweepr, instability_folder_suffix=cfg_instability.save_folder_suffix
)

Collecting failed jobs:  90%|████████▉ | 4841/5400 [01:17<00:08, 62.28it/s]
Collecting failed jobs: 100%|██████████| 5400/5400 [00:51<00:00, 105.20it/s]


In [40]:
# Result folders found by the reloaded analyzer; should match `insta.combined_results` above.
insta2.combined_results

['combined_result--221221_134957']