In [1]:
%load_ext autoreload
%autoreload 2
import sys
import torch
import itertools
import numpy as np
from torch import nn
import torchmetrics
import pandas as pd
sys.path.append('..')
# sys.path.append('/system/user/beck/pwbeck/projects/regularization/ml_utilities')
from pathlib import Path
from typing import Union
from pprint import pprint
from ml_utilities.torch_models.base_model import BaseModel
from ml_utilities.torch_models.fc import FC
from ml_utilities.torch_models import get_model_class
from ml_utilities.output_loader.repo import Repo
from ml_utilities.output_loader.job_output import JobResult, SweepResult
from ml_utilities.torch_utils.metrics import TAccuracy
from ml_utilities.utils import match_number_list_to_interval, flatten_hierarchical_dict, convert_listofdicts_to_dictoflists, hyp_param_cfg_to_str
from ml_utilities.run_utils.sweep import Sweeper
from omegaconf import OmegaConf

from erank.data.datasetgenerator import DatasetGenerator
from erank.mode_connectivity import interpolate_linear, interpolate_linear_runs, interpolation_result2series, InstabilityAnalyzer

import matplotlib.pyplot as plt
gpu_id = 0
REPO = Repo(dir=Path('../../erank'), hydra_defaults=OmegaConf.load('../configs/hydra/jobname_outputdir_format.yaml'))

# Instability analysis debug notebook
This notebook is used to implement linear interplation of models. 

Do linear interpolation with on MNIST. Use data from Experiment 11.7.4. 

Start from pretrained model with 100 steps.

In [2]:
# some constants
score_fn = TAccuracy()

## Instability Analysis on Experiment 11.7.3

In [3]:
config_yaml = """
run_config:
  exec_type: parallel # sequential
  hostname: gorilla
  gpu_ids: [0,1,2,3,4,5,6,7]
  runs_per_gpu: 3

  wandb: # wandb config for run_handler, if "wandb: null" then logging to wandb is disabled for run_handler
    init:
      tags:
        - ${config.experiment_data.experiment_tag}_exps
        - run_handler
      notes: #
      group: ${config.experiment_data.experiment_tag}
      job_type: run_handler

seeds: [1,2,3]

sweep:
  type: grid
  axes:
    - parameter: trainer.init_model_step
      vals: [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100, 125, 150, 175, 200, 225, 250, 275, 300, 325, 350, 375, 400, 425, 450, 475]
    - parameter: data.dataset_kwargs.rotation_angle
      vals: linspace(0,180,50,endpoint=True)
    - parameter: data.dataset_split.restrict_n_samples_train_task
      vals: [300] #[5, 20, 50, 100, 500, 1000, 10000, 48000]

start_num: 3 # use this to count how often this config is run
###
config:
  experiment_data:
    entity: jkuiml-fsl
    project_name: sparsity
    experiment_tag: "11.7"
    experiment_type: startnum_${start_num}
    experiment_name: mnist-${config.experiment_data.experiment_tag}.${start_num}-lenet_rottasks_ft
    experiment_dir: null
    experiment_notes: Hyperparameter search.
    job_name: null
    seed: 0
    hostname: null # the server on which the run is run, will be filled by run_handler
    gpu_id: 0

  # wandb:
  #   init:
  #     tags: # list(), used to tag wandblogger
  #       - ${config.experiment_data.experiment_tag}_exps
  #     notes: ${config.experiment_data.experiment_notes} # str, used to make notes to wandblogger
  #     group: ${config.experiment_data.experiment_tag} # null
  #     job_type: ${config.experiment_data.experiment_type} # examples: hypsearch, pretrain, eval, etc.

  #   watch:
  #     log: null #parameters #null #all
  #     log_freq: 5000

  model:
    name: fc
    model_kwargs:
      input_size: 784
      hidden_sizes:
        - 300
        - 100
      output_size: 10
      flatten_input: True
      dropout: null
      act_fn: relu

  trainer:
    training_setup: supervised
    n_steps: 2000
    log_train_step_every: 1
    log_additional_train_step_every_multiplier: 1
    log_additional_logs: True
    val_every: 5
    save_every: 5 #500
    early_stopping_patience: 200 #500
    batch_size: 128
    optimizer_scheduler:
      optimizer_name: adamw #sgd #adamw
      optimizer_kwargs:
        lr: 0.001
        weight_decay: 0.0
    
    init_model_step: XXX
    init_model: /system/user/beck/pwbeck/projects/regularization/erank/outputs/mnist-11.5.0-lenet--221015_122552/model_step_${config.trainer.init_model_step}.p

    loss: crossentropy

    metrics:
      - Accuracy
    num_workers: 4
    verbose: False

  data:
    dataset: rotatedvision
    dataset_kwargs:
      data_root_path: /system/user/beck/pwbeck/data
      dataset: mnist
      rotation_angle: XXX
    dataset_split:
      train_val_split: 0.8
      restrict_n_samples_train_task: XXX

"""
cfg = OmegaConf.create(config_yaml)

In [4]:
sweepr = REPO.get_output_loader(cfg)
print(sweepr)

Exp. Tag(start_num): 11.7(3)
Exp. Name: mnist-11.7.3-lenet_rottasks_ft
Training setup: supervised
Model name: fc
Dataset name: rotatedvision
Sweep type: grid
  trainer.init_model_step: [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100, 125, 150, 175, 200, 225, 250, 275, 300, 325, 350, 375, 400, 425, 450, 475]
  data.dataset_kwargs.rotation_angle: linspace(0,180,50,endpoint=True)
  data.dataset_split.restrict_n_samples_train_task: [300]
Seeds: [1, 2, 3]
Num. jobs: 5400
Config updated: 2022-11-25 12:34:14
Sweep started:  2022-11-25 12:36:51



### Instability analysis

In [5]:
print(sweepr)

Exp. Tag(start_num): 11.7(3)
Exp. Name: mnist-11.7.3-lenet_rottasks_ft
Training setup: supervised
Model name: fc
Dataset name: rotatedvision
Sweep type: grid
  trainer.init_model_step: [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100, 125, 150, 175, 200, 225, 250, 275, 300, 325, 350, 375, 400, 425, 450, 475]
  data.dataset_kwargs.rotation_angle: linspace(0,180,50,endpoint=True)
  data.dataset_split.restrict_n_samples_train_task: [300]
Seeds: [1, 2, 3]
Num. jobs: 5400
Config updated: 2022-11-25 12:34:14
Sweep started:  2022-11-25 12:36:51



In [6]:
sweepr.directory

PosixPath('/system/user/publicwork/beck/projects/regularization/erank/outputs/mnist-11.7.3-lenet_rottasks_ft--221125_123651')

In [7]:
cfg_instability_yaml = f"""
instability_sweep: /system/user/publicwork/beck/projects/regularization/erank/outputs/mnist-11.7.3-lenet_rottasks_ft--221125_123651
score_fn: TAccuracy
interpolation_factors: [0.0000, 0.2500, 0.5000, 0.7500, 1.0000]
init_model_idxes_ks_or_every: 15
train_model_idxes: [-1]
device: {gpu_id}
interpolate_linear_kwargs: 
  interpolation_on_train_data: False
  dataloader_kwargs:
    batch_size: 1024
  compute_model_distances: False
hpparam_sweep:
  type: grid
  axes:
    - parameter: data.dataset_kwargs.rotation_angle
      vals: [180.] #linspace(0,180,50,endpoint=True)
    - parameter: data.dataset_split.restrict_n_samples_train_task
      vals:
      - 300
"""
cfg_instability = OmegaConf.create(cfg_instability_yaml)

In [9]:
insta = InstabilityAnalyzer(**cfg_instability)

2022-12-20 17:27:55,148: Logging to /system/user/publicwork/beck/projects/regularization/erank/outputs/mnist-11.7.3-lenet_rottasks_ft--221125_123651/instability_analysis/output--221220_172755.log initialized.
2022-12-20 17:27:55,163: Setup instability analysis with config: 
instability_sweep: /system/user/publicwork/beck/projects/regularization/erank/outputs/mnist-11.7.3-lenet_rottasks_ft--221125_123651
score_fn: TAccuracy
interpolation_factors:
- 0.0
- 0.25
- 0.5
- 0.75
- 1.0
interpolate_linear_kwargs:
  interpolation_on_train_data: false
  dataloader_kwargs:
    batch_size: 1024
  compute_model_distances: false
init_model_idx_k_param_name: trainer.init_model_step
device: 0
save_results_to_disc: true
num_seed_combinations: 1
init_model_idxes_ks_or_every: 15
train_model_idxes:
- -1
hpparam_sweep:
  type: grid
  axes:
  - parameter: data.dataset_kwargs.rotation_angle
    vals:
    - 180.0
  - parameter: data.dataset_split.restrict_n_samples_train_task
    vals:
    - 300

2022-12-20 17:

100%|██████████| 5400/5400 [00:44<00:00, 122.70it/s]


2022-12-20 17:28:39,564: Using init_model_idxes / k parameters: [0, 75, 350]
2022-12-20 17:28:39,565: Finding seed combinations..
2022-12-20 17:28:39,567: Using seed combinations: [(1, 2)]


In [10]:
res_ret = insta.instability_analysis()

2022-12-20 17:28:39,707: Starting instability analysis..
2022-12-20 17:28:39,708: Generating sweep type: grid
2022-12-20 17:28:39,710: Number of hyperparameter combinations for instability analysis: 1
HP combinations: 0it [00:00, ?it/s]2022-12-20 17:28:39,712: Params `rotation_angle-180-restrict_n_samples_train_task-300`: compute


100%|██████████| 5400/5400 [00:19<00:00, 271.83it/s]


2022-12-20 17:29:00,945: Generating dataset: rotatedvision
2022-12-20 17:29:00,947: Rotated vision dataset with mnist and rotation 180.0.
HP combinations: 0it [00:49, ?it/s]


ValueError: All objects passed were None

In [None]:
ret[0]

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,datasets,val,val,val,val,val,val
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,score,interpolation_scores,interpolation_scores,interpolation_scores,interpolation_scores,interpolation_scores,instability
Unnamed: 0_level_2,Unnamed: 1_level_2,Unnamed: 2_level_2,alpha,0.00,0.25,0.50,0.75,1.00,NaN
init_model_idx_k,job,seeds,model_idxes,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3
0,mnist-11.7.3-lenet_rottasks_ft--init_model_step-0-rotation_angle-22.0408-restrict_n_samples_train_task-300,"(1, 2)","(210, 160)",0.847348,0.846765,0.846418,0.848941,0.84876,-0.001636
50,mnist-11.7.3-lenet_rottasks_ft--init_model_step-50-rotation_angle-22.0408-restrict_n_samples_train_task-300,"(1, 2)","(310, 1275)",0.86918,0.86861,0.869537,0.869275,0.870379,-0.001169
100,mnist-11.7.3-lenet_rottasks_ft--init_model_step-100-rotation_angle-22.0408-restrict_n_samples_train_task-300,"(1, 2)","(1690, 470)",0.879214,0.878029,0.877378,0.877505,0.876529,-0.001343
350,mnist-11.7.3-lenet_rottasks_ft--init_model_step-350-rotation_angle-22.0408-restrict_n_samples_train_task-300,"(1, 2)","(115, 135)",0.906933,0.907078,0.907241,0.906997,0.90654,-0.000196


In [None]:
ret[1]

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,distances,l2distance,cosinesimilarity
init_model_idx_k,job,seeds,model_idxes,Unnamed: 4_level_1,Unnamed: 5_level_1
0,mnist-11.7.3-lenet_rottasks_ft--init_model_step-0-rotation_angle-22.0408-restrict_n_samples_train_task-300,"(1, 2)","(210, 160)",5.768037,0.933539
50,mnist-11.7.3-lenet_rottasks_ft--init_model_step-50-rotation_angle-22.0408-restrict_n_samples_train_task-300,"(1, 2)","(310, 1275)",3.530757,0.975657
100,mnist-11.7.3-lenet_rottasks_ft--init_model_step-100-rotation_angle-22.0408-restrict_n_samples_train_task-300,"(1, 2)","(1690, 470)",2.413693,0.988928
350,mnist-11.7.3-lenet_rottasks_ft--init_model_step-350-rotation_angle-22.0408-restrict_n_samples_train_task-300,"(1, 2)","(115, 135)",2.155845,0.992363
