In [16]:
import wandb
import numpy as np
import torch
import lightning
import copy
from pathlib import Path
from model_fusion.config import BASE_DATA_DIR, CHECKPOINT_DIR
from model_fusion.datasets import DataModuleType
from model_fusion.models import ModelType
from model_fusion.models.lightning import BaseModel 
from Experiments import lmc_experiment
from model_fusion import lmc_utils
from Experiments import baselines_experiment
from Experiments import otfusion_experiment
from Experiments import pyhessian_experiment
from model_fusion.train import setup_training, setup_testing


# Seed numpy's legacy global RNG so that numpy-based calculations in this
# notebook are reproducible. (torch / lightning RNGs are not seeded here.)
NUMPY_SEED = 100
np.random.seed(seed=NUMPY_SEED)

In [17]:
print("------- Loading models -------")

# select wandb run names
runA = 'bbecqkxs'
runB = 'k9q16yq1'  # trained from a different initialization than runA

api = wandb.Api()
# Hyperparameters are read from runA only; runB is assumed to share the same
# datamodule/model configuration (NOTE(review): confirm).
source_run = api.run(f'model-fusion/Model Fusion/{runA}')

print(source_run.config)

batch_size = source_run.config['datamodule_hparams'].get('batch_size')

datamodule_type_str = source_run.config['datamodule_type'].split('.')[1].lower()
datamodule_type = DataModuleType(datamodule_type_str)
# Copy so the edit below does not alias the wandb run's config dict.
datamodule_hparams = dict(source_run.config['datamodule_hparams'])
datamodule_hparams['data_augmentation'] = False  # fuse/evaluate on un-augmented data

model_type_str = source_run.config['model_type'].split('.')[1].lower()
model_type = ModelType(model_type_str)

model_hparams = source_run.config['model_hparams']

print(datamodule_hparams)
print(model_hparams)

checkpointA = f'model-fusion/Model Fusion/model-{runA}:best_k'
checkpointB = f'model-fusion/Model Fusion/model-{runB}:best_k'


def load_model_artifact(wandb_run, artifact_name):
    """Download a W&B model artifact into CHECKPOINT_DIR and load its `model.ckpt` as a BaseModel."""
    artifact = wandb_run.use_artifact(artifact_name, type='model')
    artifact_dir = artifact.download(root=CHECKPOINT_DIR)
    return BaseModel.load_from_checkpoint(Path(artifact_dir) / "model.ckpt")


# A run is needed only to fetch the artifacts (use_artifact records the lineage).
run = wandb.init()

modelA = load_model_artifact(run, checkpointA)
modelB = load_model_artifact(run, checkpointB)

# Close the download-only run. Otherwise later WandbLoggers reuse it
# ("There is a wandb run already in progress ...") and log experiment results
# into this unnamed run instead of their own.
wandb.finish()


------- Loading models -------
{'lr': 0.025, 'momentum': 0.9, 'optimizer': 'sgd', 'max_epochs': 200, 'min_epochs': 50, 'model_seed': 42, 'model_type': 'ModelType.RESNET18', 'loss_module': 'CrossEntropyLoss', 'lr_scheduler': 'plateau', 'weight_decay': 0.0001, 'model_hparams': {'bias': False, 'num_classes': 10, 'num_channels': 3}, 'early_stopping': True, 'datamodule_type': 'DataModuleType.CIFAR10', 'lr_decay_factor': 0.1, 'lightning_params': {'lr': 0.025, 'momentum': 0.9, 'optimizer': 'sgd', 'model_seed': 42, 'lr_scheduler': 'plateau', 'weight_decay': 0.0001, 'lr_decay_factor': 0.1, 'lr_monitor_metric': 'val_loss'}, 'lr_monitor_metric': 'val_loss', 'datamodule_hparams': {'seed': 42, 'data_dir': 'data', 'batch_size': 32, 'data_augmentation': True}, 'model_hparams/bias': False, 'model_hparams/num_classes': 10, 'model_hparams/num_channels': 3}
{'seed': 42, 'data_dir': 'data', 'batch_size': 32, 'data_augmentation': False}
{'bias': False, 'num_classes': 10, 'num_channels': 3}


[34m[1mwandb[0m: Downloading large artifact model-bbecqkxs:best_k, 85.20MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.4
[34m[1mwandb[0m: Downloading large artifact model-k9q16yq1:best_k, 85.20MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.5


In [3]:
# Linear mode connectivity (LMC) barrier between the two independently trained
# models, before any alignment. `granularity=21` gives interpolation steps of
# 0.05 between alpha=0 and alpha=1 (see the printed alphas below).
print("------- Computing LMC barrier before alignment-------")

lmc_experiment.run_lmc(
    modelA=modelA,
    modelB=modelB,
    datamodule_type=datamodule_type,
    granularity=21,
)

------- Computing LMC barrier before alignment-------
Files already downloaded and verified
Files already downloaded and verified
Alpha: 0.00 (model 2), Train average loss: 0.02452 Train barrier:  0
Alpha: 1.00 (model 1), Train average loss: 0.05230 Train barrier:  0
Alpha: 0.05, Train average loss: 0.03960 Train barrier 0.01368531780148546
Alpha: 0.10, Train average loss: 0.09384 Train barrier 0.06654240260101027
Alpha: 0.15, Train average loss: 0.24460 Train barrier 0.21590681302264333
Alpha: 0.20, Train average loss: 0.57929 Train barrier 0.5492082939424779
Alpha: 0.25, Train average loss: 1.10763 Train barrier 1.0761669906641873
Alpha: 0.30, Train average loss: 1.64552 Train barrier 1.6126673706676893
Alpha: 0.35, Train average loss: 2.01071 Train barrier 1.9764659624682699
Alpha: 0.40, Train average loss: 2.19720 Train barrier 2.1615661021882295
Alpha: 0.45, Train average loss: 2.27190 Train barrier 2.234877763523112
Alpha: 0.50, Train average loss: 2.28801 Train barrier 2.2495970

In [12]:
# Baselines: prediction-based ensembling and vanilla (naive) weight averaging.
# Only the vanilla-averaged model is kept; it is evaluated and fine-tuned later.
print("------- Computing baselines -------")

wandb_tag = f'baselines-{runA}-{runB}'

baseline_kwargs = dict(
    datamodule_type=datamodule_type,
    datamodule_hparams=datamodule_hparams,
    model_type=model_type,
    model_hparams=model_hparams,
    modelA=modelA,
    modelB=modelB,
    wandb_tag=wandb_tag,
)
vanilla_averaging_model = baselines_experiment.run_baselines(**baseline_kwargs)

------- Computing baselines -------
------- Prediction based ensembling -------
------- Naive ensembling of weights -------
------- Evaluating baselines -------


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\filos\OneDrive\Desktop\ETH\model-fusion\.venv\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


------- Evaluating base models -------
Testing DataLoader 0:  10%|█         | 1/10 [00:00<00:03,  2.44it/s]

  return F.linear(input, self.weight, self.bias)


Testing DataLoader 0: 100%|██████████| 10/10 [00:02<00:00,  3.55it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val_accuracy           0.906499981880188
        val_loss            0.37029629945755005
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Testing DataLoader 0: 100%|██████████| 10/10 [00:02<00:00,  4.50it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val_accuracy           0.918999969959259
        val_loss            0.3654583990573883
───────────────────────────────────────────────

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Test set: Average loss: 0.3027, Accuracy: 92.50%
------- Evaluating vanilla averaging -------
Testing DataLoader 0: 100%|██████████| 10/10 [00:02<00:00,  3.34it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val_accuracy          0.17899999022483826
        val_loss             2.287587881088257
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


0,1
epoch,▁▁▁
trainer/global_step,▁▁▁
val_accuracy,██▁
val_loss,▁▁█

0,1
epoch,0.0
trainer/global_step,0.0
val_accuracy,0.179
val_loss,2.28759


In [3]:
# Optimal-transport model fusion of modelA and modelB. The second return value
# (bound to modelA_aligned) is presumably modelA after alignment to modelB; it is
# used in the post-alignment LMC cell.
print("------- Computing model fusion -------")

wandb_tag = f"ot_model_fusion-{runA}-{runB}"

fusion_kwargs = dict(
    batch_size=batch_size,
    datamodule_type=datamodule_type,
    datamodule_hparams=datamodule_hparams,
    model_type=model_type,
    model_hparams=model_hparams,
    modelA=modelA,
    modelB=modelB,
    wandb_tag=wandb_tag,
)
ot_fused_model, modelA_aligned = otfusion_experiment.run_otfusion(**fusion_kwargs)


------- Computing model fusion -------
------- Setting up parameters -------
{'seed': 42, 'data_dir': 'data', 'batch_size': 32, 'data_augmentation': False}
The parameters are: 
 {'eval_aligned': True, 'num_models': 2, 'width_ratio': 1, 'handle_skips': True, 'exact': True, 'activation_seed': 21, 'activation_histograms': True, 'ground_metric': 'euclidean', 'ground_metric_normalize': 'none', 'same_model': False, 'geom_ensemble_type': 'acts', 'act_num_samples': 200, 'skip_last_layer': False, 'skip_last_layer_type': 'average', 'softmax_temperature': 1, 'past_correction': True, 'correction': True, 'normalize_acts': False, 'normalize_wts': False, 'activation_normalize': False, 'center_acts': False, 'prelu_acts': False, 'pool_acts': False, 'pool_relu': False, 'importance': None, 'proper_marginals': False, 'not_squared': True, 'ground_metric_eff': False, 'dist_normalize': False, 'clip_gm': False, 'clip_min': 0, 'clip_max': 5, 'tmap_stats': False, 'ensemble_step': 0.5, 'reg': 0.01}
------- OT mo

c:\Users\filos\OneDrive\Desktop\ETH\model-fusion\.venv\Lib\site-packages\pytorch_lightning\loggers\wandb.py:395: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


You are using a CUDA device ('NVIDIA GeForce RTX 4060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\filos\OneDrive\Desktop\ETH\model-fusion\.venv\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 313/313 [00:03<00:00, 87.91it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val_accuracy          0.6895999908447266
        val_loss            0.9318426847457886
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


0,1
epoch,▁
trainer/global_step,▁
val_accuracy,▁
val_loss,▁

0,1
epoch,0.0
trainer/global_step,0.0
val_accuracy,0.6896
val_loss,0.93184


# model parameters:  21
# new parameters:  21
fusing:  model.conv1.weight
fusing:  model.layer1.0.conv1.weight
fusing:  model.layer1.0.conv2.weight
fusing:  model.layer1.1.conv1.weight
fusing:  model.layer1.1.conv2.weight
fusing:  model.layer2.0.conv1.weight
fusing:  model.layer2.0.conv2.weight
fusing:  model.layer2.0.shortcut.0.weight
fusing:  model.layer2.1.conv1.weight
fusing:  model.layer2.1.conv2.weight
fusing:  model.layer3.0.conv1.weight
fusing:  model.layer3.0.conv2.weight
fusing:  model.layer3.0.shortcut.0.weight
fusing:  model.layer3.1.conv1.weight
fusing:  model.layer3.1.conv2.weight
fusing:  model.layer4.0.conv1.weight
fusing:  model.layer4.0.conv2.weight
fusing:  model.layer4.0.shortcut.0.weight
fusing:  model.layer4.1.conv1.weight
fusing:  model.layer4.1.conv2.weight
fusing:  model.fc.weight
------- Evaluating ot fusion model -------


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\filos\OneDrive\Desktop\ETH\model-fusion\.venv\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 313/313 [00:03<00:00, 93.93it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val_accuracy          0.7246000170707703
        val_loss            0.9455543160438538
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


0,1
epoch,▁
trainer/global_step,▁
val_accuracy,▁
val_loss,▁

0,1
epoch,0.0
trainer/global_step,0.0
val_accuracy,0.7246
val_loss,0.94555


In [9]:
# LMC barrier between the aligned modelA and modelB.
print("------- Computing LMC barrier after alignment -------")

lmc_experiment.run_lmc(
    datamodule_type=datamodule_type,
    modelA=modelA_aligned,
    modelB=modelB,
    granularity=21
)

# Losses for ot fusion model and vanilla averaging model.
# Previously `datamodule_hparams_lmc` was built but `datamodule_hparams` was
# passed instead, so the large batch size was ignored. Augmentation is forced off
# explicitly because fine-tuning cells may have enabled it on the shared dict.
datamodule_hparams_lmc = {
    **datamodule_hparams,
    'batch_size': 1024,
    'data_dir': BASE_DATA_DIR,
    'data_augmentation': False,
}
datamodule_lmc = datamodule_type.get_data_module(**datamodule_hparams_lmc)
datamodule_lmc.prepare_data()
datamodule_lmc.setup('fit')

vanilla_loss = lmc_utils.compute_loss(vanilla_averaging_model, datamodule_lmc)
fused_loss = lmc_utils.compute_loss(ot_fused_model, datamodule_lmc)

print(f"Vanilla loss pre fine-tuning: {vanilla_loss}")
print(f"Fused loss pre fine-tuning: {fused_loss}")

------- Computing LMC barrier after alignment -------
Files already downloaded and verified
Files already downloaded and verified
Alpha: 0.00 (model 2), Train average loss: 0.02452 Train barrier:  0
Alpha: 1.00 (model 1), Train average loss: 0.84018 Train barrier:  0
Alpha: 0.05, Train average loss: 0.03192 Train barrier -0.03338956658409702
Alpha: 0.10, Train average loss: 0.05025 Train barrier -0.05583960228827264
Alpha: 0.15, Train average loss: 0.08538 Train barrier -0.06148995947043101
Alpha: 0.20, Train average loss: 0.14369 Train barrier -0.04396893484592437
Alpha: 0.25, Train average loss: 0.22960 Train barrier 0.001160084712505366
Alpha: 0.30, Train average loss: 0.34322 Train barrier 0.07399484267499712
Alpha: 0.35, Train average loss: 0.47871 Train barrier 0.16870902112987307
Alpha: 0.40, Train average loss: 0.62457 Train barrier 0.2737868930339813
Alpha: 0.45, Train average loss: 0.76628 Train barrier 0.3747137470708952
Alpha: 0.50, Train average loss: 0.89105 Train barrier

In [8]:
# Fine-tune the OT-fused model with data augmentation enabled.
from pytorch_lightning.loggers import WandbLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint

from model_fusion.train import get_wandb_logger

min_epochs = 50
max_epochs = 100

# Use a copy: mutating the shared `datamodule_hparams` in place would silently
# enable augmentation for every cell (re-)run afterwards (OT fusion, LMC losses, ...).
finetune_datamodule_hparams = {**datamodule_hparams, 'batch_size': 32, 'data_augmentation': True}
datamodule = datamodule_type.get_data_module(**finetune_datamodule_hparams)

lightning_params = {'optimizer': 'sgd', 'lr': 0.01, 'momentum': 0.95, 'weight_decay': 0.0001, 'lr_scheduler': 'plateau', 'lr_decay_factor': 0.5, 'lr_monitor_metric': 'val_loss'}
# Deep-copy the network so fine-tuning leaves `ot_fused_model` itself untouched.
otfused_lit_model = BaseModel(model_type=model_type, model_hparams=model_hparams, model=copy.deepcopy(ot_fused_model.model), **lightning_params)

logger_config = {
    'model_hparams': model_hparams,
    'datamodule_hparams': finetune_datamodule_hparams,
    'lightning_params': lightning_params,
    'min_epochs': min_epochs,
    'max_epochs': max_epochs,
    'model_type': model_type,
    'datamodule_type': datamodule_type,
    'early_stopping': True,
}
logger = get_wandb_logger("otfusion finetuning", logger_config, [])

# Early stopping watches val_loss; the checkpoint keeps the best val_accuracy.
monitor = 'val_loss'
patience = 20
checkpoint_callback = ModelCheckpoint(monitor="val_accuracy", mode="max")
callbacks = [EarlyStopping(monitor=monitor, patience=patience), checkpoint_callback]
trainer = lightning.Trainer(min_epochs=min_epochs, max_epochs=max_epochs, logger=logger, callbacks=callbacks, deterministic='warn')

datamodule.prepare_data()
datamodule.setup('fit')

trainer.fit(otfused_lit_model, train_dataloaders=datamodule.train_dataloader(), val_dataloaders=datamodule.val_dataloader())

# NOTE(review): this tests the model's current (final) weights, not the best
# checkpoint saved by `checkpoint_callback`.
datamodule.setup('test')
trainer.test(otfused_lit_model, dataloaders=datamodule.test_dataloader())

wandb.finish()

c:\Users\filos\OneDrive\Desktop\ETH\model-fusion\.venv\Lib\site-packages\lightning\pytorch\utilities\parsing.py:198: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


c:\Users\filos\OneDrive\Desktop\ETH\model-fusion\.venv\Lib\site-packages\lightning\pytorch\utilities\parsing.py:43: attribute 'model' removed from hparams because it cannot be pickled
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type               | Params
---------------------------------------------------
0 | loss_module | CrossEntropyLoss   | 0     
1 | accuracy    | MulticlassAccuracy | 0     
2 | model       | ResNet             | 11.2 M
---------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.657    Total estimated model params size (MB)


Using SGD optimizer with lr=0.01, momentum=0.95, weight_decay=0.0001, nesterov=False
Using ReduceLROnPlateau with lr_decay_factor=0.5 and lr_monitor_metric=val_loss
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00, 15.18it/s]

c:\Users\filos\OneDrive\Desktop\ETH\model-fusion\.venv\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.
  return F.linear(input, self.weight, self.bias)


                                                                           

c:\Users\filos\OneDrive\Desktop\ETH\model-fusion\.venv\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Epoch 0:   0%|          | 1/1407 [00:00<03:31,  6.65it/s, v_num=j32n, train_loss=0.763, train_accuracy=0.906]

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Epoch 21: 100%|██████████| 1407/1407 [00:46<00:00, 30.21it/s, v_num=j32n, train_loss=0.183, train_accuracy=0.875, val_loss=0.382, val_accuracy=0.892, avg_train_loss=0.144]  Epoch 00022: reducing learning rate of group 0 to 5.0000e-03.
Epoch 31: 100%|██████████| 1407/1407 [00:50<00:00, 27.64it/s, v_num=j32n, train_loss=0.00765, train_accuracy=1.000, val_loss=0.340, val_accuracy=0.915, avg_train_loss=0.0341] Epoch 00032: reducing learning rate of group 0 to 2.5000e-03.
Epoch 39: 100%|██████████| 1407/1407 [00:50<00:00, 27.61it/s, v_num=j32n, train_loss=2.33e-5, train_accuracy=1.000, val_loss=0.377, val_accuracy=0.927, avg_train_loss=0.00965] Epoch 00040: reducing learning rate of group 0 to 1.2500e-03.
Epoch 44:   0%|          | 0/1407 [00:00<?, ?it/s, v_num=j32n, train_loss=0.0249, train_accuracy=1.000, val_loss=0.368, val_accuracy=0.930, avg_train_loss=0.00455]             

Trainer was signaled to stop but the required `min_epochs=50` or `min_steps=None` has not been met. Training will continue...


Epoch 47: 100%|██████████| 1407/1407 [00:46<00:00, 29.99it/s, v_num=j32n, train_loss=7.06e-5, train_accuracy=1.000, val_loss=0.394, val_accuracy=0.928, avg_train_loss=0.00376] Epoch 00048: reducing learning rate of group 0 to 6.2500e-04.
Epoch 49: 100%|██████████| 1407/1407 [00:47<00:00, 29.81it/s, v_num=j32n, train_loss=0.000231, train_accuracy=1.000, val_loss=0.377, val_accuracy=0.931, avg_train_loss=0.00284]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\filos\OneDrive\Desktop\ETH\model-fusion\.venv\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 313/313 [00:04<00:00, 65.43it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val_accuracy          0.9269000291824341
        val_loss            0.4218737781047821
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


0,1
avg_train_loss,█▇▆▆▅▅▅▅▅▅▄▄▄▄▄▄▄▄▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▂▃▄▆▇▇▇▇▇▇▇▇▇▇▇▇███████████████████████
train_accuracy,▁▅▅▆████████████████████████████████████
train_loss,█▄▇▄▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁
trainer/global_step,▁▂▃▄▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇███████████████████
val_accuracy,▁▂▂▂▃▂▂▂▂▃▂▃▂▄▃▂▃▄▇▆▅▇▇▇▆▇█▇█▇▇█████████
val_loss,▇▅▆▅▄▅▆▇▇▄▇▆▇▄▅▇▆▆▁▂▄▄▃▃▄▅▄▄▄▅▇▅▄▅▅▅▅▆▆█

0,1
avg_train_loss,0.00291
epoch,50.0
train_accuracy,1.0
train_loss,0.00023
trainer/global_step,70350.0
val_accuracy,0.9269
val_loss,0.42187


In [13]:
# Fine-tune the vanilla weight-averaged model with data augmentation enabled.
from pytorch_lightning.loggers import WandbLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint

from model_fusion.train import get_wandb_logger

min_epochs = 50
max_epochs = 100

# Use a copy: mutating the shared `datamodule_hparams` in place would silently
# enable augmentation for every cell (re-)run afterwards (OT fusion, LMC losses, ...).
finetune_datamodule_hparams = {**datamodule_hparams, 'batch_size': 32, 'data_augmentation': True}
datamodule = datamodule_type.get_data_module(**finetune_datamodule_hparams)

lightning_params = {'optimizer': 'sgd', 'lr': 0.01, 'momentum': 0.95, 'weight_decay': 0.0001, 'lr_scheduler': 'plateau', 'lr_decay_factor': 0.5, 'lr_monitor_metric': 'val_loss'}

# Deep-copy the network so fine-tuning leaves `vanilla_averaging_model` itself untouched.
vanilla_averaged_lit_model = BaseModel(model_type=model_type, model_hparams=model_hparams, model=copy.deepcopy(vanilla_averaging_model.model), **lightning_params)

logger_config = {
    'model_hparams': model_hparams,
    'datamodule_hparams': finetune_datamodule_hparams,
    'lightning_params': lightning_params,
    'min_epochs': min_epochs,
    'max_epochs': max_epochs,
    'model_type': model_type,
    'datamodule_type': datamodule_type,
    'early_stopping': True,
}
logger = get_wandb_logger("vanilla finetuning", logger_config, [])

# Early stopping watches val_loss; the checkpoint keeps the best val_accuracy.
monitor = 'val_loss'
patience = 20
checkpoint_callback = ModelCheckpoint(monitor="val_accuracy", mode="max")
callbacks = [EarlyStopping(monitor=monitor, patience=patience), checkpoint_callback]
trainer = lightning.Trainer(min_epochs=min_epochs, max_epochs=max_epochs, logger=logger, callbacks=callbacks, deterministic='warn')

datamodule.prepare_data()
datamodule.setup('fit')

trainer.fit(vanilla_averaged_lit_model, train_dataloaders=datamodule.train_dataloader(), val_dataloaders=datamodule.val_dataloader())

# NOTE(review): this tests the model's current (final) weights, not the best
# checkpoint saved by `checkpoint_callback`.
datamodule.setup('test')
trainer.test(vanilla_averaged_lit_model, dataloaders=datamodule.test_dataloader())

wandb.finish()

c:\Users\filos\OneDrive\Desktop\ETH\model-fusion\.venv\Lib\site-packages\lightning\pytorch\utilities\parsing.py:198: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


c:\Users\filos\OneDrive\Desktop\ETH\model-fusion\.venv\Lib\site-packages\lightning\pytorch\utilities\parsing.py:43: attribute 'model' removed from hparams because it cannot be pickled
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type               | Params
---------------------------------------------------
0 | loss_module | CrossEntropyLoss   | 0     
1 | accuracy    | MulticlassAccuracy | 0     
2 | model       | ResNet             | 11.2 M
---------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.657    Total estimated model params size (MB)


Using SGD optimizer with lr=0.01, momentum=0.95, weight_decay=0.0001, nesterov=False
Using ReduceLROnPlateau with lr_decay_factor=0.5 and lr_monitor_metric=val_loss
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00, 15.63it/s]

c:\Users\filos\OneDrive\Desktop\ETH\model-fusion\.venv\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.
  return F.linear(input, self.weight, self.bias)


                                                                           

c:\Users\filos\OneDrive\Desktop\ETH\model-fusion\.venv\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Epoch 0:   0%|          | 1/1407 [00:00<02:47,  8.39it/s]

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Epoch 26: 100%|██████████| 1407/1407 [00:44<00:00, 31.29it/s, v_num=uce8, train_loss=0.0232, train_accuracy=1.000, val_loss=0.398, val_accuracy=0.888, avg_train_loss=0.143] Epoch 00027: reducing learning rate of group 0 to 5.0000e-03.
Epoch 36: 100%|██████████| 1407/1407 [00:44<00:00, 31.56it/s, v_num=uce8, train_loss=0.00294, train_accuracy=1.000, val_loss=0.332, val_accuracy=0.914, avg_train_loss=0.0338] Epoch 00037: reducing learning rate of group 0 to 2.5000e-03.
Epoch 44: 100%|██████████| 1407/1407 [00:44<00:00, 31.68it/s, v_num=uce8, train_loss=7.67e-5, train_accuracy=1.000, val_loss=0.377, val_accuracy=0.923, avg_train_loss=0.00819] Epoch 00045: reducing learning rate of group 0 to 1.2500e-03.
Epoch 49:   0%|          | 0/1407 [00:00<?, ?it/s, v_num=uce8, train_loss=0.000882, train_accuracy=1.000, val_loss=0.399, val_accuracy=0.926, avg_train_loss=0.00468]           

Trainer was signaled to stop but the required `min_epochs=50` or `min_steps=None` has not been met. Training will continue...


Epoch 49: 100%|██████████| 1407/1407 [00:44<00:00, 31.37it/s, v_num=uce8, train_loss=0.0012, train_accuracy=1.000, val_loss=0.407, val_accuracy=0.923, avg_train_loss=0.00356]  


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\filos\OneDrive\Desktop\ETH\model-fusion\.venv\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 313/313 [00:04<00:00, 73.30it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val_accuracy          0.9246000051498413
        val_loss             0.422087162733078
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


0,1
avg_train_loss,█▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▂▂▂▃▃▃▄▄▄▅▅▅▆▆▇▇▇█████████████████████
train_accuracy,▆▅█▁▃█▆▇▆▇█▇▆▇██████████████████████████
train_loss,▄▄▂█▅▂▄▃▄▂▁▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▂▂▂▃▃▃▄▄▄▅▅▅▆▆▆▇▇█████████████████████
val_accuracy,▁▃▄▄▅▅▅▄▅▆▄▆▅▆▅▅▅▅▆▆▆▅▇▆▇▇▇▇▇▇██████████
val_loss,█▆▄▅▃▄▃▅▃▃▆▃▂▁▃▃▄▃▂▂▂▄▁▃▃▁▃▃▃▂▃▄▄▃▄▃▄▄▄▅

0,1
avg_train_loss,0.00337
epoch,50.0
train_accuracy,1.0
train_loss,0.0012
trainer/global_step,70350.0
val_accuracy,0.9246
val_loss,0.42209


In [9]:
# Evaluate the fine-tuned OT-fused model (best_k checkpoint of the fine-tuning run).
# runFT = 'j1go7qte'  # earlier fine-tuning run, kept for reference
runFT = '0a6mj32n'

api = wandb.Api()
run = api.run(f'model-fusion/Model Fusion/{runFT}')

print(run.config)

batch_size = run.config['datamodule_hparams'].get('batch_size')

datamodule_type_str = run.config['datamodule_type'].split('.')[1].lower()
datamodule_type = DataModuleType(datamodule_type_str)
# Copy so the edit below does not alias the wandb run's config dict.
datamodule_hparams = dict(run.config['datamodule_hparams'])
datamodule_hparams['data_augmentation'] = False  # test without augmentation

model_type_str = run.config['model_type'].split('.')[1].lower()
model_type = ModelType(model_type_str)

model_hparams = run.config['model_hparams']

print(datamodule_hparams)
print(model_hparams)

checkpointFT = f'model-fusion/Model Fusion/model-{runFT}:best_k'

run = wandb.init()

artifact = run.use_artifact(checkpointFT, type='model')
artifact_dir = artifact.download(root=CHECKPOINT_DIR)
otfused_lit_model = BaseModel.load_from_checkpoint(Path(artifact_dir)/"model.ckpt")
# Close the download run. Otherwise the WandbLogger created in setup_testing reuses
# it ("There is a wandb run already in progress ..."), and the eval results get
# logged into this unnamed download run.
wandb.finish()

wandb_tags = [f"{model_type.value}", f"{datamodule_type.value}"]

datamodule, trainer = setup_testing(f'eval finetuning {runFT}', model_type, model_hparams, datamodule_type, datamodule_hparams, wandb_tags)

datamodule.prepare_data()
datamodule.setup('test')

trainer.test(otfused_lit_model, dataloaders=datamodule.test_dataloader())

wandb.finish()

# `datamodule_lmc` is created in the "LMC barrier after alignment" cell; run that cell first.
finetuned_loss = lmc_utils.compute_loss(otfused_lit_model, datamodule_lmc)

print(f"Finetuned otfused loss: {finetuned_loss}")

{'lr': 0.01, 'model': None, 'momentum': 0.95, 'optimizer': 'sgd', 'max_epochs': 100, 'min_epochs': 50, 'model_type': 'ModelType.RESNET18', 'loss_module': 'CrossEntropyLoss', 'lr_scheduler': 'plateau', 'weight_decay': 0.0001, 'model_hparams': {'bias': False, 'num_classes': 10, 'num_channels': 3}, 'early_stopping': True, 'datamodule_type': 'DataModuleType.CIFAR10', 'lr_decay_factor': 0.5, 'lightning_params': {'lr': 0.01, 'momentum': 0.95, 'optimizer': 'sgd', 'lr_scheduler': 'plateau', 'weight_decay': 0.0001, 'lr_decay_factor': 0.5, 'lr_monitor_metric': 'val_loss'}, 'lr_monitor_metric': 'val_loss', 'datamodule_hparams': {'seed': 42, 'data_dir': 'data', 'batch_size': 32, 'data_augmentation': True}, 'model_hparams/bias': False, 'model_hparams/num_classes': 10, 'model_hparams/num_channels': 3}
{'seed': 42, 'data_dir': 'data', 'batch_size': 32, 'data_augmentation': False}
{'bias': False, 'num_classes': 10, 'num_channels': 3}


[34m[1mwandb[0m: Downloading large artifact model-0a6mj32n:best_k, 85.20MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.4
c:\Users\filos\OneDrive\Desktop\ETH\model-fusion\.venv\Lib\site-packages\pytorch_lightning\loggers\wandb.py:395: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\filos\OneDrive\Desktop\ETH\model-fusion\.venv\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Testing DataLoader 0:   1%|          | 3/313 [00:00<00:17, 18.06it/s]

  return F.linear(input, self.weight, self.bias)


Testing DataLoader 0: 100%|██████████| 313/313 [00:04<00:00, 76.00it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val_accuracy          0.9229999780654907
        val_loss            0.41299566626548767
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


0,1
epoch,▁
trainer/global_step,▁
val_accuracy,▁
val_loss,▁

0,1
epoch,0.0
trainer/global_step,0.0
val_accuracy,0.923
val_loss,0.413


Finetuned otfused loss: 0.004341453356884854


In [14]:
# Evaluate the finetuned vanilla-averaged model stored as a W&B checkpoint artifact.
# Previously evaluated finetuning run: 'vyu3hupf'
runFT = 'sm4quce8'

api = wandb.Api()
run = api.run(f'model-fusion/Model Fusion/{runFT}')

print(run.config)

# Not used in this cell (kept as a notebook global).
batch_size = run.config['datamodule_hparams'].get('batch_size')

# Config stores enums as strings, e.g. 'DataModuleType.CIFAR10' -> 'cifar10'.
datamodule_type_str = run.config['datamodule_type'].split('.')[1].lower()
datamodule_type = DataModuleType(datamodule_type_str)
# Evaluate without augmentation; build a copy instead of mutating the fetched run config.
datamodule_hparams = {**run.config['datamodule_hparams'], 'data_augmentation': False}

model_type_str = run.config['model_type'].split('.')[1].lower()
model_type = ModelType(model_type_str)

model_hparams = run.config['model_hparams']

print(datamodule_hparams)
print(model_hparams)

checkpointFT = f'model-fusion/Model Fusion/model-{runFT}:best_k'

# A W&B run is only needed here to resolve and download the checkpoint artifact.
run = wandb.init()
artifact = run.use_artifact(checkpointFT, type='model')
artifact_dir = artifact.download(root=CHECKPOINT_DIR)
vanilla_averaged_lit_model = BaseModel.load_from_checkpoint(Path(artifact_dir)/"model.ckpt")
# Close the download run before the test logger is created. Otherwise the WandbLogger
# set up by setup_testing reuses it (Lightning warns "There is a wandb run already in
# progress ..."), so the 'eval finetuning ...' name passed below would not be used.
wandb.finish()

wandb_tags = [f"{model_type.value}", f"{datamodule_type.value}"]

datamodule, trainer = setup_testing(f'eval finetuning {runFT}', model_type, model_hparams, datamodule_type, datamodule_hparams, wandb_tags)

datamodule.prepare_data()
datamodule.setup('test')

trainer.test(vanilla_averaged_lit_model, dataloaders=datamodule.test_dataloader())

wandb.finish()

# NOTE(review): `datamodule_lmc` is not created in this cell; it is assumed to exist
# from an earlier cell, so this line depends on execution order.
finetuned_loss = lmc_utils.compute_loss(vanilla_averaged_lit_model, datamodule_lmc)

print(f"Finetuned vanilla loss: {finetuned_loss}")

{'lr': 0.01, 'model': None, 'momentum': 0.95, 'optimizer': 'sgd', 'max_epochs': 100, 'min_epochs': 50, 'model_type': 'ModelType.RESNET18', 'loss_module': 'CrossEntropyLoss', 'lr_scheduler': 'plateau', 'weight_decay': 0.0001, 'model_hparams': {'bias': False, 'num_classes': 10, 'num_channels': 3}, 'early_stopping': True, 'datamodule_type': 'DataModuleType.CIFAR10', 'lr_decay_factor': 0.5, 'lightning_params': {'lr': 0.01, 'momentum': 0.95, 'optimizer': 'sgd', 'lr_scheduler': 'plateau', 'weight_decay': 0.0001, 'lr_decay_factor': 0.5, 'lr_monitor_metric': 'val_loss'}, 'lr_monitor_metric': 'val_loss', 'datamodule_hparams': {'seed': 42, 'data_dir': 'data', 'batch_size': 32, 'data_augmentation': True}, 'model_hparams/bias': False, 'model_hparams/num_classes': 10, 'model_hparams/num_channels': 3}
{'seed': 42, 'data_dir': 'data', 'batch_size': 32, 'data_augmentation': False}
{'bias': False, 'num_classes': 10, 'num_channels': 3}


[34m[1mwandb[0m: Downloading large artifact model-sm4quce8:best_k, 85.20MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.4
c:\Users\filos\OneDrive\Desktop\ETH\model-fusion\.venv\Lib\site-packages\pytorch_lightning\loggers\wandb.py:395: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\filos\OneDrive\Desktop\ETH\model-fusion\.venv\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Testing DataLoader 0:   1%|          | 3/313 [00:00<00:17, 17.86it/s]

  return F.linear(input, self.weight, self.bias)


Testing DataLoader 0: 100%|██████████| 313/313 [00:03<00:00, 89.79it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val_accuracy          0.9200000166893005
        val_loss            0.39501965045928955
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


0,1
epoch,▁
trainer/global_step,▁
val_accuracy,▁
val_loss,▁

0,1
epoch,0.0
trainer/global_step,0.0
val_accuracy,0.92
val_loss,0.39502


Finetuned vanilla loss: 0.007130932860800223


In [12]:
# PyHessian: top Hessian eigenvalue and trace (sharpness) for the base models, the
# OT-fused model, and the finetuned vanilla-average / OT-fusion solutions.
print("------- Computing sharpness -------")


def _sharpness(model, figure_name, **extra_kwargs):
    """Run PyHessian on `model` (density disabled) and return its result object."""
    return pyhessian_experiment.run_pyhessian(
        datamodule_type=datamodule_type,
        model=model,
        compute_density=False,
        figure_name=figure_name,
        **extra_kwargs,
    )


print("------- Model A -------")
hessian_comp = _sharpness(modelA, 'modelA.pdf')

print("------- Model B -------")
hessian_comp = _sharpness(modelB, 'modelB.pdf')

# Aligned model (modelA_aligned -> 'modelA_aligned.pdf') is currently not evaluated.

print("------- OT fusion model -------")
hessian_comp = _sharpness(ot_fused_model, 'otmodel32.pdf')

print("------- Vanilla avg model (finetuned) -------")
hessian_comp = _sharpness(vanilla_averaged_lit_model, 'vanilla_avg.pdf')

print("------- OT fusion model (finetuned) -------")
# Only this model is evaluated on a limited number of batches.
hessian_comp = _sharpness(otfused_lit_model, 'otmodel.pdf', num_batches=30)

# Leave only `hessian_comp` (the last result) behind in the notebook namespace.
del _sharpness

Seed set to 42


------- Computing sharpness -------
------- Model A -------
Files already downloaded and verified
Files already downloaded and verified


  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


The top Hessian eigenvalue of this model is 2.4751


Seed set to 42



***Trace:  126.4693115234375
------- Model B -------
Files already downloaded and verified
Files already downloaded and verified
The top Hessian eigenvalue of this model is 2.4914


Seed set to 42



***Trace:  104.19172821044921
------- OT fusion model -------
Files already downloaded and verified
Files already downloaded and verified
The top Hessian eigenvalue of this model is 3.5451


Seed set to 42



***Trace:  126.33977978046124
------- Vanilla avg model -------
Files already downloaded and verified
Files already downloaded and verified
The top Hessian eigenvalue of this model is 3.4924


Seed set to 42



***Trace:  51.21888825201219
------- OT fusion model (finetuned) -------
Files already downloaded and verified
Files already downloaded and verified
The top Hessian eigenvalue of this model is 4.1499

***Trace:  61.17197711651142


In [22]:
# Sharpness of the plain (non-finetuned) vanilla-averaged model.
# NOTE(review): the recorded output reports a negative top eigenvalue (-0.1860) and a
# trace of ~1.06, far from every other model in this notebook; confirm that
# `vanilla_averaging_model` is the intended model. It also reuses figure_name
# 'vanilla_avg.pdf', the same name as the finetuned vanilla-average cells, in case a
# figure is written.
print("------- Vanilla avg model -------")
hessian_comp = pyhessian_experiment.run_pyhessian(
    datamodule_type=datamodule_type,
    model=vanilla_averaging_model,
    compute_density=False,
    figure_name='vanilla_avg.pdf',
)

Seed set to 42


------- Vanilla avg model -------
Files already downloaded and verified
Files already downloaded and verified


  return F.linear(input, self.weight, self.bias)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


The top Hessian eigenvalue of this model is -0.1860

***Trace:  1.0636819319427013


In [10]:
# Sharpness of the OT-fused model after the second finetuning run
# (limited to 30 batches, as in the full sweep).
print("------- OT fusion model (finetuned) -------")
hessian_comp = pyhessian_experiment.run_pyhessian(
    datamodule_type=datamodule_type,
    model=otfused_lit_model,
    num_batches=30,
    compute_density=False,
    figure_name='otmodel.pdf',
)

Seed set to 42


------- OT fusion model (finetuned) -------
Files already downloaded and verified
Files already downloaded and verified


  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


The top Hessian eigenvalue of this model is 3.0270

***Trace:  43.03564792209201


In [15]:
# Sharpness of the vanilla-averaged model after the second finetuning run.
print("------- Vanilla avg model (finetuned) -------")
hessian_comp = pyhessian_experiment.run_pyhessian(
    datamodule_type=datamodule_type,
    model=vanilla_averaged_lit_model,
    compute_density=False,
    figure_name='vanilla_avg.pdf',
)

Seed set to 42


------- Vanilla avg model (finetuned) -------
Files already downloaded and verified
Files already downloaded and verified


  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


The top Hessian eigenvalue of this model is 1.6497

***Trace:  54.086493900844026
