In [1]:
%load_ext autoreload
%autoreload 2
import sys
import torch
from torch import nn
import torchmetrics
sys.path.append('..')
# sys.path.append('/system/user/beck/pwbeck/projects/regularization/ml_utilities')
from pathlib import Path
from typing import Union
from pprint import pprint
from ml_utilities.torch_models.base_model import BaseModel
from ml_utilities.torch_models.fc import FC
from ml_utilities.torch_models import get_model_class
from ml_utilities.output_loader.repo import Repo
from ml_utilities.output_loader.job_output import JobResult, SweepResult
from omegaconf import OmegaConf

from erank.data.datasetgenerator import DatasetGenerator

import matplotlib.pyplot as plt
# Presumably the GPU index to use; it is not referenced by any cell shown in this notebook.
gpu_id = 0

# Handle on the erank repository outputs. Both paths are relative to the
# directory the notebook is started from.
REPO = Repo(
    dir=Path('../../erank'),
    hydra_defaults=OmegaConf.load('../configs/hydra/jobname_outputdir_format.yaml'),
)

# Linear interpolation debug notebook
This notebook is used to implement linear interpolation of models.

Do linear interpolation on MNIST. Use data from Experiment 11.7.3.

Start from a model pretrained for 100 steps.

#### Setup

In [2]:
# Sweep configuration of experiment 11.7 (start_num 3), kept inline as YAML so the
# matching outputs can be looked up from it. The `XXX` values appear to be placeholders
# for the three parameters listed under `sweep.axes` (init_model_step, rotation_angle,
# restrict_n_samples_train_task).
config_yaml = """
run_config:
  exec_type: parallel # sequential
  hostname: gorilla
  gpu_ids: [0,1,2,3,4,5,6,7]
  runs_per_gpu: 3

  wandb: # wandb config for run_handler, if "wandb: null" then logging to wandb is disabled for run_handler
    init:
      tags:
        - ${config.experiment_data.experiment_tag}_exps
        - run_handler
      notes: #
      group: ${config.experiment_data.experiment_tag}
      job_type: run_handler

seeds: [1,2,3]

sweep:
  type: grid
  axes:
    - parameter: trainer.init_model_step
      vals: [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100, 125, 150, 175, 200, 225, 250, 275, 300, 325, 350, 375, 400, 425, 450, 475]
    - parameter: data.dataset_kwargs.rotation_angle
      vals: linspace(0,180,50,endpoint=True)
    - parameter: data.dataset_split.restrict_n_samples_train_task
      vals: [300] #[5, 20, 50, 100, 500, 1000, 10000, 48000]

start_num: 3 # use this to count how often this config is run
###
config:
  experiment_data:
    entity: jkuiml-fsl
    project_name: sparsity
    experiment_tag: "11.7"
    experiment_type: startnum_${start_num}
    experiment_name: mnist-${config.experiment_data.experiment_tag}.${start_num}-lenet_rottasks_ft
    experiment_dir: null
    experiment_notes: Hyperparameter search.
    job_name: null
    seed: 0
    hostname: null # the server on which the run is run, will be filled by run_handler
    gpu_id: 0

  # wandb:
  #   init:
  #     tags: # list(), used to tag wandblogger
  #       - ${config.experiment_data.experiment_tag}_exps
  #     notes: ${config.experiment_data.experiment_notes} # str, used to make notes to wandblogger
  #     group: ${config.experiment_data.experiment_tag} # null
  #     job_type: ${config.experiment_data.experiment_type} # examples: hypsearch, pretrain, eval, etc.

  #   watch:
  #     log: null #parameters #null #all
  #     log_freq: 5000

  model:
    name: fc
    model_kwargs:
      input_size: 784
      hidden_sizes:
        - 300
        - 100
      output_size: 10
      flatten_input: True
      dropout: null
      act_fn: relu

  trainer:
    training_setup: supervised
    n_steps: 2000
    log_train_step_every: 1
    log_additional_train_step_every_multiplier: 1
    log_additional_logs: True
    val_every: 5
    save_every: 5 #500
    early_stopping_patience: 200 #500
    batch_size: 128
    optimizer_scheduler:
      optimizer_name: adamw #sgd #adamw
      optimizer_kwargs:
        lr: 0.001
        weight_decay: 0.0
    
    init_model_step: XXX
    init_model: /system/user/beck/pwbeck/projects/regularization/erank/outputs/mnist-11.5.0-lenet--221015_122552/model_step_${config.trainer.init_model_step}.p

    loss: crossentropy

    metrics:
      - Accuracy
    num_workers: 4
    verbose: False

  data:
    dataset: rotatedvision
    dataset_kwargs:
      data_root_path: /system/user/beck/pwbeck/data
      dataset: mnist
      rotation_angle: XXX
    dataset_split:
      train_val_split: 0.8
      restrict_n_samples_train_task: XXX

"""
# Parse the YAML text into an OmegaConf config object.
cfg = OmegaConf.create(config_yaml)

### Load model

In [3]:
# Ask the repository for the output loader that belongs to `cfg`, then print its
# summary (experiment name, sweep axes, number of jobs, timestamps).
sweepr = REPO.get_output_loader(cfg)
print(sweepr)

Exp. Tag(start_num): 11.7(3)
Exp. Name: mnist-11.7.3-lenet_rottasks_ft
Training setup: supervised
Model name: fc
Dataset name: rotatedvision
Sweep type: grid
  trainer.init_model_step: [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100, 125, 150, 175, 200, 225, 250, 275, 300, 325, 350, 375, 400, 425, 450, 475]
  data.dataset_kwargs.rotation_angle: linspace(0,180,50,endpoint=True)
  data.dataset_split.restrict_n_samples_train_task: [300]
Num. jobs: 5400
Config updated: 2022-11-25 12:34:14
Sweep started:  2022-11-25 12:36:51



In [4]:
# sw_summary = sweepr.get_summary()

In [5]:
# pv_df = sw_summary[(sw_summary['trainer.init_model_step'] == 100) & sw_summary['data.dataset_kwargs.rotation_angle'].between(0,30)]

In [6]:
# pv_df.sort_values(by=['data.dataset_kwargs.rotation_angle', 'seed'])

In [7]:
# List the job output directories for this name pattern. The result (see output) holds
# one directory per seed, all with init_model_step 100 and rotation angle 25.7143.
sweepr.find_jobs('init_model_step-100-rotation_angle-25')

['/system/user/publicwork/beck/projects/regularization/erank/outputs/mnist-11.7.3-lenet_rottasks_ft--221125_123651/outputs/mnist-11.7.3-lenet_rottasks_ft--init_model_step-100-rotation_angle-25.7143-restrict_n_samples_train_task-300-seed-2--221127_232024',
 '/system/user/publicwork/beck/projects/regularization/erank/outputs/mnist-11.7.3-lenet_rottasks_ft--221125_123651/outputs/mnist-11.7.3-lenet_rottasks_ft--init_model_step-100-rotation_angle-25.7143-restrict_n_samples_train_task-300-seed-3--221126_021736',
 '/system/user/publicwork/beck/projects/regularization/erank/outputs/mnist-11.7.3-lenet_rottasks_ft--221125_123651/outputs/mnist-11.7.3-lenet_rottasks_ft--init_model_step-100-rotation_angle-25.7143-restrict_n_samples_train_task-300-seed-1--221125_232309']

In [8]:
# Load the job results for the same name pattern that was passed to find_jobs.
jobs_pretrainsteps100_rotangle25 = sweepr.get_jobs('init_model_step-100-rotation_angle-25')

In [9]:
# Pick the best model of each of the three jobs (one job per seed).
# All three were finetuned from the same pretrained checkpoint, so they are
# expected to be linearly mode connected.
model_a, model_b, model_c = (
    jobs_pretrainsteps100_rotangle25[job_idx].best_model for job_idx in range(3)
)

In [10]:
# Display the layer structure of the first selected model.
model_a

FC(
  (fc): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=784, out_features=300, bias=True)
    (2): ReLU(inplace=True)
    (3): Linear(in_features=300, out_features=100, bias=True)
    (4): ReLU(inplace=True)
    (5): Linear(in_features=100, out_features=10, bias=True)
  )
)

In [11]:
# sanity check: values should not be the same
# model_a.state_dict()['fc.1.weight'], model_b.state_dict()['fc.1.weight'] # check successful

### Load dataset

In [12]:
# Show the hyperparameter values the sweep set for the first selected job.
jobs_pretrainsteps100_rotangle25[0].override_hpparams

{'trainer.init_model_step': 100,
 'data.dataset_kwargs.rotation_angle': 25.714285714285715,
 'data.dataset_split.restrict_n_samples_train_task': 300}

In [13]:
# Data section of the sweep config. The swept entries (rotation_angle,
# restrict_n_samples_train_task) still hold their `XXX` value at this point.
data_cfg = cfg.config.data

In [14]:
# Fill the swept data entries from the selected job's own hyperparameters instead of
# hard-coding them, so the dataset stays in sync with the models loaded from that job.
job_hpparams = jobs_pretrainsteps100_rotangle25[0].override_hpparams
for hp_name in ('dataset_kwargs.rotation_angle', 'dataset_split.restrict_n_samples_train_task'):
    # override_hpparams keys carry a leading `data.` that data_cfg keys do not have.
    OmegaConf.update(data_cfg, hp_name, job_hpparams[f'data.{hp_name}'])

In [15]:
# Print the data config as YAML to check the values it now holds.
print(OmegaConf.to_yaml(data_cfg))

dataset: rotatedvision
dataset_kwargs:
  data_root_path: /system/user/beck/pwbeck/data
  dataset: mnist
  rotation_angle: 25.714285714285715
dataset_split:
  train_val_split: 0.8
  restrict_n_samples_train_task: 300



In [16]:
# Build the dataset generator from the data config (entries passed as keyword
# arguments) and generate the dataset; the splits are read from it afterwards.
ds_generator = DatasetGenerator(**data_cfg)
ds_generator.generate_dataset()

In [17]:
# Number of samples in the train and validation splits.
len(ds_generator.train_split), len(ds_generator.val_split)

(300, 12000)

## Model interpolation: models finetuned from the same checkpoint

In [18]:
# autoreload is already enabled in the first cell; it is re-issued here before
# importing the interpolation function used in the cells below.
%autoreload 2
from erank.mode_connectivity import interpolate_linear

In [19]:
# Score function handed to interpolate_linear in the cells below.
# The two commented-out lines are alternative choices.
# score_fn = nn.CrossEntropyLoss()
# score_fn = torchmetrics.Accuracy()
from ml_utilities.torch_utils.metrics import SimpleAccuracy
score_fn = SimpleAccuracy()

In [20]:
# Six evenly spaced interpolation weights from 0 to 1 (both endpoints included).
weights = torch.linspace(start=0, end=1.0, steps=6)
weights

tensor([0.0000, 0.2000, 0.4000, 0.6000, 0.8000, 1.0000])

In [21]:
# Interpolate between model_a and model_b at every weight and score on the 'val' split.
# The result also reports the cosine similarity and L2 distance of the pair (see output).
res_dict = interpolate_linear(
    model_a,
    model_b,
    ds_generator.train_split,
    score_fn,
    weights,
    {'val': ds_generator.val_split},
)
pprint(res_dict)

Interp. factors: 100%|██████████| 6/6 [00:23<00:00,  3.92s/it]
{'__weights': [0.0,
               0.20000000298023224,
               0.4000000059604645,
               0.6000000238418579,
               0.800000011920929,
               1.0],
 '_cosinesimilarity': 0.989238977432251,
 '_l2distance': 2.3759570121765137,
 'val': [0.8735514879226685,
         0.8734565377235413,
         0.8736345767974854,
         0.8736464977264404,
         0.8737414479255676,
         0.8745725750923157]}


In [22]:
# sanity check: values should not be the same
# model_a.state_dict()['fc.1.weight'], model_b.state_dict()['fc.1.weight'] # check successful

In [23]:
# Same interpolation as for the other model pairs, here between model_a and model_c.
res_dict = interpolate_linear(
    model_a,
    model_c,
    ds_generator.train_split,
    score_fn,
    weights,
    {'val': ds_generator.val_split},
)
pprint(res_dict)

Interp. factors: 100%|██████████| 6/6 [00:21<00:00,  3.66s/it]
{'__weights': [0.0,
               0.20000000298023224,
               0.4000000059604645,
               0.6000000238418579,
               0.800000011920929,
               1.0],
 '_cosinesimilarity': 0.9810684323310852,
 '_l2distance': 3.155956268310547,
 'val': [0.8735514879226685,
         0.8740382790565491,
         0.8744538426399231,
         0.87546306848526,
         0.8760923147201538,
         0.8754274249076843]}


In [24]:
# Same interpolation as for the other model pairs, here between model_b and model_c.
res_dict = interpolate_linear(
    model_b,
    model_c,
    ds_generator.train_split,
    score_fn,
    weights,
    {'val': ds_generator.val_split},
)
pprint(res_dict)

Interp. factors: 100%|██████████| 6/6 [00:22<00:00,  3.67s/it]
{'__weights': [0.0,
               0.20000000298023224,
               0.4000000059604645,
               0.6000000238418579,
               0.800000011920929,
               1.0],
 '_cosinesimilarity': 0.9805640578269958,
 '_l2distance': 3.192615270614624,
 'val': [0.8745725750923157,
         0.875558078289032,
         0.8773152828216553,
         0.8768997192382812,
         0.8761754035949707,
         0.8754274249076843]}


## Model interpolation: independently trained models

In [25]:
# Sweep configuration of experiment 11.5 (start_num 1), kept inline as YAML so the
# matching outputs can be looked up from it. It has two seeds, no initial model
# (`init_model: null`) and a rotation angle of 0.0.
config_yaml2 = """
run_config:
  exec_type: parallel # sequential
  hostname: raptor
  gpu_ids: [0]
  runs_per_gpu: 2

  wandb: # wandb config for run_handler, if "wandb: null" then logging to wandb is disabled for run_handler
    init:
      tags:
        - ${config.experiment_data.experiment_tag}_exps
        - run_handler
      notes: #
      group: ${config.experiment_data.experiment_tag}
      job_type: run_handler

seeds: [1,2]

sweep:
  type: grid
  axes:
    - parameter: data.dataset_kwargs.rotation_angle
      vals: [0.0] #linspace(0,180,360,endpoint=True)

start_num: 1 # use this to count how often this config is run

###
config:
  experiment_data:
    entity: jkuiml-fsl
    project_name: sparsity
    experiment_tag: "11.5"
    experiment_type: startnum_${start_num}
    experiment_name: mnist-${config.experiment_data.experiment_tag}.${start_num}-lenet
    experiment_dir: null
    experiment_notes: Different random inits for mode connectivity analysis.
    job_name: null
    seed: 0
    hostname: null # the server on which the run is run, will be filled by run_handler
    gpu_id: 5

  wandb:
    init:
      tags: # list(), used to tag wandblogger
        - ${config.experiment_data.experiment_tag}_exps
      notes: ${config.experiment_data.experiment_notes} # str, used to make notes to wandblogger
      group: ${config.experiment_data.experiment_tag} # null
      job_type: ${config.experiment_data.experiment_type} # examples: hypsearch, pretrain, eval, etc.

    watch:
      log: null #parameters #null #all
      log_freq: 5000

  model:
    name: fc
    model_kwargs:
      input_size: 784
      hidden_sizes:
        - 300
        - 100
      output_size: 10
      flatten_input: True
      dropout: null
      act_fn: relu

  trainer:
    training_setup: supervised
    n_steps: 2000
    log_train_step_every: 1
    log_additional_train_step_every_multiplier: 1
    log_additional_logs: True
    val_every: 5
    save_every: 5
    early_stopping_patience: 200
    batch_size: 128
    optimizer_scheduler:
      optimizer_name: adamw #sgd #adamw
      optimizer_kwargs:
        lr: 0.001
        weight_decay: 0.0
    init_model: null

    loss: crossentropy

    metrics:
      - Accuracy
    num_workers: 4
    verbose: False

  data:
    dataset: rotatedvision
    dataset_kwargs:
      data_root_path: /system/user/beck/pwbeck/data
      dataset: mnist
      rotation_angle: 0.0
    dataset_split:
      train_val_split: 0.8

"""
# Parse the YAML text into an OmegaConf config object.
cfg2 = OmegaConf.create(config_yaml2)

In [26]:
# Ask the repository for the output loader that belongs to `cfg2`, then print its summary.
sweepr2 = REPO.get_output_loader(cfg2)
print(sweepr2)

Exp. Tag(start_num): 11.5(1)
Exp. Name: mnist-11.5.1-lenet
Training setup: supervised
Model name: fc
Dataset name: rotatedvision
Sweep type: grid
  data.dataset_kwargs.rotation_angle: [0.0]
Num. jobs: 2
Config updated: 2022-12-14 10:24:30
Sweep started:  2022-12-14 10:26:23



In [27]:
# Check this sweep for failed jobs; the empty list in the output means none were reported.
sweepr2.get_failed_jobs()

100%|██████████| 2/2 [00:00<00:00, 52.41it/s]


[]

In [28]:
# Load every job result of this sweep and display them (one per seed, see output).
jobs_mnist = sweepr2.get_jobs()
jobs_mnist

[JobResult(mnist-11.5.1-lenet--rotation_angle-0-seed-1--221214_102651),
 JobResult(mnist-11.5.1-lenet--rotation_angle-0-seed-2--221214_102651)]

In [29]:
# Best model of each of the two jobs loaded from this sweep.
model_A, model_B = (jobs_mnist[job_idx].best_model for job_idx in range(2))

In [30]:
# Data section of the second config; it contains no `XXX` entries, so it is used as-is.
data_cfg2 = cfg2.config.data
# Build the dataset generator from it (entries passed as keyword arguments) and generate the dataset.
ds_generator2 = DatasetGenerator(**data_cfg2)
ds_generator2.generate_dataset()

In [31]:
# Interpolate between the two models of the second sweep, reusing the same weights and
# score function. This call additionally passes dataloader settings (batch size 2048).
res_dict = interpolate_linear(
    model_A,
    model_B,
    ds_generator2.train_split,
    score_fn,
    weights,
    {'val': ds_generator2.val_split},
    dataloader_kwargs={'batch_size': 2048},
)
pprint(res_dict)

Interp. factors: 100%|██████████| 6/6 [00:22<00:00,  3.70s/it]
{'__weights': [0.0,
               0.20000000298023224,
               0.4000000059604645,
               0.6000000238418579,
               0.800000011920929,
               1.0],
 '_cosinesimilarity': 0.010202181525528431,
 '_l2distance': 30.355226516723633,
 'val': [0.974958598613739,
         0.9707711338996887,
         0.8915689587593079,
         0.9109537601470947,
         0.9726828932762146,
         0.976451575756073]}
