In [10]:
%load_ext autoreload
%autoreload 2
import sys
import torch
import itertools
import copy
import numpy as np
from torch import nn
import torchmetrics
import pandas as pd
sys.path.append('..')
# sys.path.append('/system/user/beck/pwbeck/projects/regularization/ml_utilities')
from pathlib import Path
from typing import Union
from pprint import pprint
from ml_utilities.torch_models.base_model import BaseModel
from ml_utilities.torch_models.fc import FC
from ml_utilities.torch_models import get_model_class
from ml_utilities.output_loader.repo import Repo
from ml_utilities.output_loader.job_output import JobResult, SweepResult
from ml_utilities.torch_utils.metrics import TAccuracy
from ml_utilities.utils import match_number_list_to_interval, flatten_hierarchical_dict, convert_listofdicts_to_dictoflists, hyp_param_cfg_to_str
from ml_utilities.run_utils.sweep import Sweeper
from omegaconf import OmegaConf

from erank.data.datasetgenerator import DatasetGenerator
from erank.mode_connectivity import interpolate_linear, interpolate_linear_runs, interpolation_result2series, InstabilityAnalyzer

import matplotlib.pyplot as plt
# GPU index; interpolated as the `device` entry of the instability-analysis config further down.
gpu_id = 0
# Repo handle used below to obtain the output loader for a sweep config.
# Both paths are relative, so they resolve against the kernel's working directory.
REPO = Repo(dir=Path('../../erank'), hydra_defaults=OmegaConf.load('../configs/hydra/jobname_outputdir_format.yaml'))

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Instability analysis debug notebook
This notebook is used to implement linear interpolation of models.

Do linear interpolation on MNIST. Use data from Experiment 11.7.3.

Start from pretrained model with 100 steps.

In [2]:
# Notebook-wide constants.
# Accuracy metric instance. NOTE(review): not referenced again in the visible cells; the
# instability-analysis config below names the metric by the string `TAccuracy` instead.
score_fn = TAccuracy()

## Instability Analysis on Experiment 11.7.3

In [26]:
# Sweep configuration of experiment 11.7 (start_num 3), embedded as YAML text.
# This cell only parses the text into `cfg`; the following cell passes `cfg` to REPO to
# locate the outputs of that sweep.
# Entries set to XXX are placeholders for the three parameters listed under `sweep.axes`.
config_yaml = """
run_config:
  exec_type: parallel # sequential
  hostname: gorilla
  gpu_ids: [0,1,2,3,4,5,6,7]
  runs_per_gpu: 3

  wandb: # wandb config for run_handler, if "wandb: null" then logging to wandb is disabled for run_handler
    init:
      tags:
        - ${config.experiment_data.experiment_tag}_exps
        - run_handler
      notes: #
      group: ${config.experiment_data.experiment_tag}
      job_type: run_handler

seeds: [1,2,3]

sweep:
  type: grid
  axes:
    - parameter: trainer.init_model_step
      vals: [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100, 125, 150, 175, 200, 225, 250, 275, 300, 325, 350, 375, 400, 425, 450, 475]
    - parameter: data.dataset_kwargs.rotation_angle
      vals: linspace(0,180,50,endpoint=True)
    - parameter: data.dataset_split.restrict_n_samples_train_task
      vals: [300] #[5, 20, 50, 100, 500, 1000, 10000, 48000]

start_num: 3 # use this to count how often this config is run
###
config:
  experiment_data:
    entity: jkuiml-fsl
    project_name: sparsity
    experiment_tag: "11.7"
    experiment_type: startnum_${start_num}
    experiment_name: mnist-${config.experiment_data.experiment_tag}.${start_num}-lenet_rottasks_ft
    experiment_dir: null
    experiment_notes: Hyperparameter search.
    job_name: null
    seed: 0
    hostname: null # the server on which the run is run, will be filled by run_handler
    gpu_id: 0

  # wandb:
  #   init:
  #     tags: # list(), used to tag wandblogger
  #       - ${config.experiment_data.experiment_tag}_exps
  #     notes: ${config.experiment_data.experiment_notes} # str, used to make notes to wandblogger
  #     group: ${config.experiment_data.experiment_tag} # null
  #     job_type: ${config.experiment_data.experiment_type} # examples: hypsearch, pretrain, eval, etc.

  #   watch:
  #     log: null #parameters #null #all
  #     log_freq: 5000

  model:
    name: fc
    model_kwargs:
      input_size: 784
      hidden_sizes:
        - 300
        - 100
      output_size: 10
      flatten_input: True
      dropout: null
      act_fn: relu

  trainer:
    training_setup: supervised
    n_steps: 2000
    log_train_step_every: 1
    log_additional_train_step_every_multiplier: 1
    log_additional_logs: True
    val_every: 5
    save_every: 5 #500
    early_stopping_patience: 200 #500
    batch_size: 128
    optimizer_scheduler:
      optimizer_name: adamw #sgd #adamw
      optimizer_kwargs:
        lr: 0.001
        weight_decay: 0.0
    
    init_model_step: XXX
    init_model: /system/user/beck/pwbeck/projects/regularization/erank/outputs/mnist-11.5.0-lenet--221015_122552/model_step_${config.trainer.init_model_step}.p

    loss: crossentropy

    metrics:
      - Accuracy
    num_workers: 4
    verbose: False

  data:
    dataset: rotatedvision
    dataset_kwargs:
      data_root_path: /system/user/beck/pwbeck/data
      dataset: mnist
      rotation_angle: XXX
    dataset_split:
      train_val_split: 0.8
      restrict_n_samples_train_task: XXX

"""
# Parse the YAML text into an OmegaConf config object.
cfg = OmegaConf.create(config_yaml)

In [27]:
# Obtain the output loader for the sweep described by `cfg` and print its summary.
sweepr = REPO.get_output_loader(cfg)
print(sweepr)

Exp. Tag(start_num): 11.7(3)
Exp. Name: mnist-11.7.3-lenet_rottasks_ft
Training setup: supervised
Model name: fc
Dataset name: rotatedvision
Sweep type: grid
  trainer.init_model_step: [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100, 125, 150, 175, 200, 225, 250, 275, 300, 325, 350, 375, 400, 425, 450, 475]
  data.dataset_kwargs.rotation_angle: linspace(0,180,50,endpoint=True)
  data.dataset_split.restrict_n_samples_train_task: [300]
Seeds: [1, 2, 3]
Num. jobs: 5400
Config updated: 2022-11-25 12:34:14
Sweep started:  2022-11-25 12:36:51



In [28]:
# Second value of the rotation-angle grid (50 evenly spaced angles from 0 to 180 inclusive);
# this is the angle queried a few cells below.
np.linspace(0,180,50, endpoint=True)[1]

3.673469387755102

In [29]:
# Names of the swept parameters as reported by the loader.
sweepr.sweep_params

['trainer.init_model_step',
 'data.dataset_kwargs.rotation_angle',
 'data.dataset_split.restrict_n_samples_train_task']

In [30]:
# Query all jobs that started from init_model_step 0 and used the first non-zero
# rotation angle of the sweep grid.
# The angle is deliberately given in rounded form (the exact grid value is
# np.linspace(0, 180, 50, endpoint=True)[1] = 3.6734...), so a match depends on the
# tolerance passed as `float_eps` -- presumably an absolute tolerance; verify in query_jobs.
qdict = {
    'trainer.init_model_step': 0,
    'data.dataset_kwargs.rotation_angle': 3.673,
}
ret = sweepr.query_jobs(qdict, float_eps=1e-3)

Collecting summaries: 100%|██████████| 5400/5400 [00:46<00:00, 117.38it/s]


In [31]:
# Sanity check that the queried angle is a Python float.
# Note: `(float)` is just a parenthesised `float`, not a one-element tuple.
isinstance(qdict['data.dataset_kwargs.rotation_angle'], (float))

True

In [32]:
# First element of the query result; the recorded output is a table with one row per
# matching job (one per seed).
ret[0]

Unnamed: 0,best_step,best_val_score,trainer.init_model_step,data.dataset_kwargs.rotation_angle,data.dataset_split.restrict_n_samples_train_task,seed
mnist-11.7.3-lenet_rottasks_ft--init_model_step-0-rotation_angle-3.67347-restrict_n_samples_train_task-300-seed-1--221127_112533,110,0.850417,0,3.673469,300,1
mnist-11.7.3-lenet_rottasks_ft--init_model_step-0-rotation_angle-3.67347-restrict_n_samples_train_task-300-seed-2--221128_003316,160,0.8475,0,3.673469,300,2
mnist-11.7.3-lenet_rottasks_ft--init_model_step-0-rotation_angle-3.67347-restrict_n_samples_train_task-300-seed-3--221127_122758,355,0.85125,0,3.673469,300,3


In [33]:
# Dump the sweep section of the loaded config as YAML.
print(OmegaConf.to_yaml(sweepr.sweep_cfg))

type: grid
axes:
- parameter: trainer.init_model_step
  vals:
  - 0
  - 5
  - 10
  - 15
  - 20
  - 25
  - 30
  - 35
  - 40
  - 45
  - 50
  - 55
  - 60
  - 65
  - 70
  - 75
  - 80
  - 85
  - 90
  - 95
  - 100
  - 125
  - 150
  - 175
  - 200
  - 225
  - 250
  - 275
  - 300
  - 325
  - 350
  - 375
  - 400
  - 425
  - 450
  - 475
- parameter: data.dataset_kwargs.rotation_angle
  vals: linspace(0,180,50,endpoint=True)
- parameter: data.dataset_split.restrict_n_samples_train_task
  vals:
  - 300



In [35]:
# Build a Sweeper from a deep copy of the sweep config -- presumably so that the axis
# drops in the following cells cannot modify `sweepr.sweep_cfg`; then show its parameters.
sw = Sweeper.create(sweep_config=copy.deepcopy(sweepr.sweep_cfg))
sw.sweep_params

['trainer.init_model_step',
 'data.dataset_kwargs.rotation_angle',
 'data.dataset_split.restrict_n_samples_train_task']

In [36]:
# Drop two axes by name. The recorded return value lists only the init_model_step axis;
# 'experiment_data.seed' is not among the sweep parameters shown in the previous cell.
sw.drop_axes(['trainer.init_model_step', 'experiment_data.seed'])

[{'parameter': 'trainer.init_model_step', 'vals': [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100, 125, 150, 175, 200, 225, 250, 275, 300, 325, 350, 375, 400, 425, 450, 475]}]

In [37]:
# Remaining sweep parameters after the drop.
sw.sweep_params

['data.dataset_kwargs.rotation_angle',
 'data.dataset_split.restrict_n_samples_train_task']

In [39]:
# Repeat the same drop on the already-reduced sweeper; the recorded parameter list is
# unchanged from the previous cell.
sw.drop_axes(['trainer.init_model_step', 'experiment_data.seed'])
sw.sweep_params

['data.dataset_kwargs.rotation_angle',
 'data.dataset_split.restrict_n_samples_train_task']

### Instability analysis

In [5]:
# Print the sweep summary again before setting up the instability analysis.
print(sweepr)

Exp. Tag(start_num): 11.7(3)
Exp. Name: mnist-11.7.3-lenet_rottasks_ft
Training setup: supervised
Model name: fc
Dataset name: rotatedvision
Sweep type: grid
  trainer.init_model_step: [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100, 125, 150, 175, 200, 225, 250, 275, 300, 325, 350, 375, 400, 425, 450, 475]
  data.dataset_kwargs.rotation_angle: linspace(0,180,50,endpoint=True)
  data.dataset_split.restrict_n_samples_train_task: [300]
Seeds: [1, 2, 3]
Num. jobs: 5400
Config updated: 2022-11-25 12:34:14
Sweep started:  2022-11-25 12:36:51



In [6]:
# Output directory of the sweep; the instability-analysis config in the next cell points
# at this same directory.
sweepr.directory

PosixPath('/system/user/publicwork/beck/projects/regularization/erank/outputs/mnist-11.7.3-lenet_rottasks_ft--221125_123651')

In [7]:
# Configuration for the InstabilityAnalyzer, written as YAML and parsed with OmegaConf.
# `instability_sweep` is interpolated from the loaded sweep (`sweepr.directory`) and
# `device` from `gpu_id`, so neither value has to be kept in sync by hand.
# The hyperparameter sweep is restricted to a single rotation angle (180 degrees); the
# full angle grid is kept as a trailing YAML comment on that line.
cfg_instability_yaml = f"""
instability_sweep: {sweepr.directory}
score_fn: TAccuracy
interpolation_factors: [0.0000, 0.2500, 0.5000, 0.7500, 1.0000]
init_model_idxes_ks_or_every: 15
train_model_idxes: [-1]
device: {gpu_id}
interpolate_linear_kwargs: 
  interpolation_on_train_data: False
  dataloader_kwargs:
    batch_size: 1024
  compute_model_distances: True
hpparam_sweep:
  type: grid
  axes:
    - parameter: data.dataset_kwargs.rotation_angle
      vals: [180.] #linspace(0,180,50,endpoint=True)
    - parameter: data.dataset_split.restrict_n_samples_train_task
      vals:
      - 300
"""
# Parse the YAML text into an OmegaConf config object.
cfg_instability = OmegaConf.create(cfg_instability_yaml)

In [8]:
# Construct the analyzer, passing the top-level config entries as keyword arguments.
insta = InstabilityAnalyzer(**cfg_instability)

[2022-12-21 11:13:21,512][ml_utilities.utils][INFO] - Logging to /system/user/publicwork/beck/projects/regularization/erank/outputs/mnist-11.7.3-lenet_rottasks_ft--221125_123651/instability_analysis/output--221221_111321.log initialized.
[2022-12-21 11:13:21,528][erank.mode_connectivity][INFO] - Setup instability analysis with config: 
instability_sweep: /system/user/publicwork/beck/projects/regularization/erank/outputs/mnist-11.7.3-lenet_rottasks_ft--221125_123651
score_fn: TAccuracy
interpolation_factors:
- 0.0
- 0.25
- 0.5
- 0.75
- 1.0
interpolate_linear_kwargs:
  interpolation_on_train_data: false
  dataloader_kwargs:
    batch_size: 1024
  compute_model_distances: true
init_model_idx_k_param_name: trainer.init_model_step
device: 0
save_results_to_disc: true
override_files: false
num_seed_combinations: 1
init_model_idxes_ks_or_every: 15
train_model_idxes:
- -1
hpparam_sweep:
  type: grid
  axes:
  - parameter: data.dataset_kwargs.rotation_angle
    vals:
    - 180.0
  - parameter: 

Collecting failed jobs: 100%|██████████| 5400/5400 [00:48<00:00, 112.29it/s]


[2022-12-21 11:14:10,130][erank.mode_connectivity][INFO] - Using init_model_idxes / k parameters: [0, 75, 350]
[2022-12-21 11:14:10,131][erank.mode_connectivity][INFO] - Finding seed combinations..
[2022-12-21 11:14:10,132][erank.mode_connectivity][INFO] - Using seed combinations: [(1, 2)]


In [9]:
# Run the analysis. override_files=True is passed explicitly here, whereas the setup log
# of the analyzer shows `override_files: false`. The recorded run took about 3.5 minutes.
res_ret = insta.instability_analysis(override_files=True)

[2022-12-21 11:15:10,190][erank.mode_connectivity][INFO] - Starting instability analysis..
[2022-12-21 11:15:10,195][ml_utilities.run_utils.sweep][INFO] - Generating sweep type: grid
[2022-12-21 11:15:10,202][erank.mode_connectivity][INFO] - Number of hyperparameter combinations for instability analysis: 1
HP combinations:   0%|          | 0/1 [00:00<?, ?it/s][2022-12-21 11:15:18,748][erank.mode_connectivity][INFO] - Params `rotation_angle-180-restrict_n_samples_train_task-300`: compute


Collecting summaries: 100%|██████████| 5400/5400 [00:21<00:00, 247.10it/s]


[2022-12-21 11:15:44,437][erank.data.datasetgenerator][INFO] - Generating dataset: rotatedvision
[2022-12-21 11:15:44,440][erank.data.rotatedvisiondataset][INFO] - Rotated vision dataset with mnist and rotation 180.0.
[2022-12-21 11:16:21,567][erank.data.datasetgenerator][INFO] - Generating dataset: rotatedvision
[2022-12-21 11:16:21,569][erank.data.rotatedvisiondataset][INFO] - Rotated vision dataset with mnist and rotation 180.0.
[2022-12-21 11:16:56,195][erank.data.datasetgenerator][INFO] - Generating dataset: rotatedvision
[2022-12-21 11:16:56,197][erank.data.rotatedvisiondataset][INFO] - Rotated vision dataset with mnist and rotation 180.0.
HP combinations: 100%|██████████| 1/1 [03:28<00:00, 208.23s/it]
[2022-12-21 11:18:47,035][erank.mode_connectivity][INFO] - Done. 
Combined results in file `/system/user/publicwork/beck/projects/regularization/erank/outputs/mnist-11.7.3-lenet_rottasks_ft--221125_123651/instability_analysis/combined_results/combined_result--221221_111321.p`.


In [14]:
# Every fifth angle of the rotation grid, as a plain Python list (10 values).
np.linspace(0,180,50,endpoint=True)[::5].tolist()

[0.0,
 18.367346938775512,
 36.734693877551024,
 55.10204081632653,
 73.46938775510205,
 91.83673469387756,
 110.20408163265306,
 128.57142857142858,
 146.9387755102041,
 165.3061224489796]

In [15]:
# Full grid of 50 rotation angles, for reference.
np.linspace(0,180,50,endpoint=True)

array([  0.        ,   3.67346939,   7.34693878,  11.02040816,
        14.69387755,  18.36734694,  22.04081633,  25.71428571,
        29.3877551 ,  33.06122449,  36.73469388,  40.40816327,
        44.08163265,  47.75510204,  51.42857143,  55.10204082,
        58.7755102 ,  62.44897959,  66.12244898,  69.79591837,
        73.46938776,  77.14285714,  80.81632653,  84.48979592,
        88.16326531,  91.83673469,  95.51020408,  99.18367347,
       102.85714286, 106.53061224, 110.20408163, 113.87755102,
       117.55102041, 121.2244898 , 124.89795918, 128.57142857,
       132.24489796, 135.91836735, 139.59183673, 143.26530612,
       146.93877551, 150.6122449 , 154.28571429, 157.95918367,
       161.63265306, 165.30612245, 168.97959184, 172.65306122,
       176.32653061, 180.        ])