In [1]:
import wandb
import numpy as np
import torch
import lightning
import copy
from pathlib import Path
from model_fusion.config import BASE_DATA_DIR, CHECKPOINT_DIR
from model_fusion.datasets import DataModuleType
from model_fusion.models import ModelType
from model_fusion.models.lightning import BaseModel 
from Experiments import lmc_experiment
from model_fusion import lmc_utils
from Experiments import baselines_experiment
from Experiments import otfusion_experiment
from Experiments import pyhessian_experiment
from model_fusion.train import setup_training, setup_testing


# Fix the seed of NumPy's global RNG so numpy-based calculations are repeatable
NUMPY_SEED = 100
np.random.seed(seed=NUMPY_SEED)

  if not hasattr(numpy, tp_name):
  if not hasattr(numpy, tp_name):
  "lr_options": generate_power_seq(LEARNING_RATE_CIFAR, 11),
  contrastive_task: Union[FeatureMapContrastiveTask] = FeatureMapContrastiveTask("01, 02, 11"),
  self.nce_loss = AmdimNCELoss(tclip)


In [2]:
print("------- Loading models -------")

# W&B run ids of the two base models, and the checkpoint artifacts they produced
runA = 'bbecqkxs'  # 32 (presumably the training batch size)
runB = 'kvuejplb'  # same init, 512 (presumably the training batch size)
checkpointA = f'model-fusion/Model Fusion/model-{runA}:best_k'
checkpointB = f'model-fusion/Model Fusion/model-{runB}:best_k'

# The experiment configuration is read from run A only
api = wandb.Api()
run = api.run(f'model-fusion/Model Fusion/{runA}')
print(run.config)

# Data module settings; augmentation is switched off after the config was printed.
# Note that this edits the dict held by run.config in place.
datamodule_hparams = run.config['datamodule_hparams']
batch_size = datamodule_hparams.get('batch_size')
datamodule_hparams['data_augmentation'] = False

# The config stores enums as 'EnumName.MEMBER'; the lower-cased member name
# is what the enum constructors are given.
datamodule_type_str = run.config['datamodule_type'].split('.')[1].lower()
datamodule_type = DataModuleType(datamodule_type_str)
model_type_str = run.config['model_type'].split('.')[1].lower()
model_type = ModelType(model_type_str)

model_hparams = run.config['model_hparams']

print(datamodule_hparams)
print(model_hparams)

# From here on `run` is the live W&B run (no longer the API handle of run A)
run = wandb.init()

# Model A: fetch the checkpoint artifact and restore the Lightning module
artifact = run.use_artifact(checkpointA, type='model')
artifact_dir = artifact.download(root=CHECKPOINT_DIR)
modelA = BaseModel.load_from_checkpoint(Path(artifact_dir, "model.ckpt"))

# Model B: same procedure, downloaded into the same root directory
artifact = run.use_artifact(checkpointB, type='model')
artifact_dir = artifact.download(root=CHECKPOINT_DIR)
modelB = BaseModel.load_from_checkpoint(Path(artifact_dir, "model.ckpt"))

------- Loading models -------


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


{'lr': 0.025, 'momentum': 0.9, 'optimizer': 'sgd', 'max_epochs': 200, 'min_epochs': 50, 'model_seed': 42, 'model_type': 'ModelType.RESNET18', 'loss_module': 'CrossEntropyLoss', 'lr_scheduler': 'plateau', 'weight_decay': 0.0001, 'model_hparams': {'bias': False, 'num_classes': 10, 'num_channels': 3}, 'early_stopping': True, 'datamodule_type': 'DataModuleType.CIFAR10', 'lr_decay_factor': 0.1, 'lightning_params': {'lr': 0.025, 'momentum': 0.9, 'optimizer': 'sgd', 'model_seed': 42, 'lr_scheduler': 'plateau', 'weight_decay': 0.0001, 'lr_decay_factor': 0.1, 'lr_monitor_metric': 'val_loss'}, 'lr_monitor_metric': 'val_loss', 'datamodule_hparams': {'seed': 42, 'data_dir': 'data', 'batch_size': 32, 'data_augmentation': True}, 'model_hparams/bias': False, 'model_hparams/num_classes': 10, 'model_hparams/num_channels': 3}
{'seed': 42, 'data_dir': 'data', 'batch_size': 32, 'data_augmentation': False}
{'bias': False, 'num_classes': 10, 'num_channels': 3}


[34m[1mwandb[0m: Currently logged in as: [33mframbelli[0m ([33mmodel-fusion[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Downloading large artifact model-bbecqkxs:best_k, 85.20MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.4
[34m[1mwandb[0m: Downloading large artifact model-kvuejplb:best_k, 85.20MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.4


In [3]:
# LMC barrier between the two base models, before any alignment is applied
print("------- Computing LMC barrier before alignment-------")

lmc_experiment.run_lmc(datamodule_type=datamodule_type, modelA=modelA, modelB=modelB, granularity=21)

------- Computing LMC barrier before alignment-------
Files already downloaded and verified
Files already downloaded and verified
Alpha: 0.00 (model 2), Train average loss: 0.03120 Train barrier:  0
Alpha: 1.00 (model 1), Train average loss: 0.05230 Train barrier:  0
Alpha: 0.05, Train average loss: 0.06109 Train barrier 0.028833368214103913
Alpha: 0.10, Train average loss: 0.16498 Train barrier 0.1316721314605077
Alpha: 0.15, Train average loss: 0.38893 Train barrier 0.3545673969027731
Alpha: 0.20, Train average loss: 0.73206 Train barrier 0.6966395740011003
Alpha: 0.25, Train average loss: 1.11352 Train barrier 1.0770417473514875
Alpha: 0.30, Train average loss: 1.42316 Train barrier 1.3856339737621943
Alpha: 0.35, Train average loss: 1.61437 Train barrier 1.5757897419667244
Alpha: 0.40, Train average loss: 1.69707 Train barrier 1.6574350990147062
Alpha: 0.45, Train average loss: 1.68769 Train barrier 1.6469913048709763
Alpha: 0.50, Train average loss: 1.59088 Train barrier 1.5491266

In [4]:
# Baselines: prediction ensembling and vanilla weight averaging of the two base models
print("------- Computing baselines -------")

# Tag identifying this pair of runs in W&B
wandb_tag = f'baselines-{runA}-{runB}'

# The returned vanilla-averaged model is reused by later cells
vanilla_averaging_model = baselines_experiment.run_baselines(
    modelA=modelA, modelB=modelB,
    model_type=model_type, model_hparams=model_hparams,
    datamodule_type=datamodule_type, datamodule_hparams=datamodule_hparams,
    wandb_tag=wandb_tag,
)

------- Computing baselines -------
------- Prediction based ensembling -------
------- Naive ensembling of weights -------
------- Evaluating baselines -------


c:\Users\filos\OneDrive\Desktop\ETH\model-fusion\.venv\Lib\site-packages\pytorch_lightning\loggers\wandb.py:395: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


You are using a CUDA device ('NVIDIA GeForce RTX 4060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\filos\OneDrive\Desktop\ETH\model-fusion\.venv\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


------- Evaluating base models -------
Testing DataLoader 0: 100%|██████████| 10/10 [00:02<00:00,  4.03it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val_accuracy           0.906499981880188
        val_loss            0.37029629945755005
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Testing DataLoader 0: 100%|██████████| 10/10 [00:02<00:00,  4.62it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val_accuracy          0.9037999510765076
        val_loss            0.42905566096305847
──────────────────────────────────────────────

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Test set: Average loss: 0.3251, Accuracy: 91.96%
------- Evaluating vanilla averaging -------
Testing DataLoader 0: 100%|██████████| 10/10 [00:03<00:00,  3.22it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val_accuracy          0.4903999865055084
        val_loss            1.6003756523132324
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


0,1
epoch,▁▁▁
trainer/global_step,▁▁▁
val_accuracy,██▁
val_loss,▁▁█

0,1
epoch,0.0
trainer/global_step,0.0
val_accuracy,0.4904
val_loss,1.60038


In [5]:
# OT model fusion, plus evaluation of the aligned model
print("------- Computing model fusion -------")

# Tag identifying this pair of runs in W&B
wandb_tag = f"ot_model_fusion-{runA}-{runB}"

# Returns the fused model and model A after alignment; both are reused by later cells
ot_fused_model, modelA_aligned = otfusion_experiment.run_otfusion(
    modelA=modelA, modelB=modelB,
    model_type=model_type, model_hparams=model_hparams,
    datamodule_type=datamodule_type, datamodule_hparams=datamodule_hparams,
    batch_size=batch_size,
    wandb_tag=wandb_tag,
)

------- Computing model fusion -------
------- Setting up parameters -------
{'seed': 42, 'data_dir': 'data', 'batch_size': 32, 'data_augmentation': False}
The parameters are: 
 {'eval_aligned': True, 'num_models': 2, 'width_ratio': 1, 'handle_skips': True, 'exact': True, 'activation_seed': 21, 'activation_histograms': True, 'ground_metric': 'euclidean', 'ground_metric_normalize': 'none', 'same_model': False, 'geom_ensemble_type': 'acts', 'act_num_samples': 200, 'skip_last_layer': False, 'skip_last_layer_type': 'average', 'softmax_temperature': 1, 'past_correction': True, 'correction': True, 'normalize_acts': False, 'normalize_wts': False, 'activation_normalize': False, 'center_acts': False, 'prelu_acts': False, 'pool_acts': False, 'pool_relu': False, 'importance': None, 'proper_marginals': False, 'not_squared': True, 'ground_metric_eff': False, 'dist_normalize': False, 'clip_gm': False, 'clip_min': 0, 'clip_max': 5, 'tmap_stats': False, 'ensemble_step': 0.5, 'reg': 0.01}
------- OT mo

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\filos\OneDrive\Desktop\ETH\model-fusion\.venv\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 313/313 [00:03<00:00, 85.40it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val_accuracy          0.6305000185966492
        val_loss            1.0602638721466064
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


0,1
epoch,▁
trainer/global_step,▁
val_accuracy,▁
val_loss,▁

0,1
epoch,0.0
trainer/global_step,0.0
val_accuracy,0.6305
val_loss,1.06026


# model parameters:  21
# new parameters:  21
fusing:  model.conv1.weight
fusing:  model.layer1.0.conv1.weight
fusing:  model.layer1.0.conv2.weight
fusing:  model.layer1.1.conv1.weight
fusing:  model.layer1.1.conv2.weight
fusing:  model.layer2.0.conv1.weight
fusing:  model.layer2.0.conv2.weight
fusing:  model.layer2.0.shortcut.0.weight
fusing:  model.layer2.1.conv1.weight
fusing:  model.layer2.1.conv2.weight
fusing:  model.layer3.0.conv1.weight
fusing:  model.layer3.0.conv2.weight
fusing:  model.layer3.0.shortcut.0.weight
fusing:  model.layer3.1.conv1.weight
fusing:  model.layer3.1.conv2.weight
fusing:  model.layer4.0.conv1.weight
fusing:  model.layer4.0.conv2.weight
fusing:  model.layer4.0.shortcut.0.weight
fusing:  model.layer4.1.conv1.weight
fusing:  model.layer4.1.conv2.weight
fusing:  model.fc.weight
------- Evaluating ot fusion model -------


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\filos\OneDrive\Desktop\ETH\model-fusion\.venv\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 313/313 [00:03<00:00, 88.30it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val_accuracy          0.5566999912261963
        val_loss             1.379953145980835
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


0,1
epoch,▁
trainer/global_step,▁
val_accuracy,▁
val_loss,▁

0,1
epoch,0.0
trainer/global_step,0.0
val_accuracy,0.5567
val_loss,1.37995


In [6]:
# LMC barrier between the aligned model A and model B
print("------- Computing LMC barrier after alignment -------")

lmc_experiment.run_lmc(
    datamodule_type=datamodule_type,
    modelA=modelA_aligned,
    modelB=modelB,
    granularity=21
)

# Losses for ot fusion model and vanilla averaging model.
# The dedicated LMC settings (batch size 1024, BASE_DATA_DIR) were previously
# built but never used: the data module was created from `datamodule_hparams`.
# They are now layered on top of the existing hparams, so the remaining keys
# (seed, data_augmentation=False) stay exactly as before, and the merged dict
# is what is passed on. `datamodule_hparams` itself is not modified.
datamodule_hparams_lmc = {**datamodule_hparams, 'batch_size': 1024, 'data_dir': BASE_DATA_DIR}
datamodule_lmc = datamodule_type.get_data_module(**datamodule_hparams_lmc)
datamodule_lmc.prepare_data()
datamodule_lmc.setup('fit')

# `datamodule_lmc` is reused by the finetuning cells further down
vanilla_loss = lmc_utils.compute_loss(vanilla_averaging_model, datamodule_lmc)
fused_loss = lmc_utils.compute_loss(ot_fused_model, datamodule_lmc)

print(f"Vanilla loss: {vanilla_loss}")
print(f"Fused loss: {fused_loss}")

------- Computing LMC barrier after alignment -------
Files already downloaded and verified
Files already downloaded and verified
Alpha: 0.00 (model 2), Train average loss: 0.03120 Train barrier:  0
Alpha: 1.00 (model 1), Train average loss: 0.98812 Train barrier:  0
Alpha: 0.05, Train average loss: 0.05570 Train barrier -0.023344237804578415
Alpha: 0.10, Train average loss: 0.12890 Train barrier 0.0020119033154513966
Alpha: 0.15, Train average loss: 0.26597 Train barrier 0.09123515401101773
Alpha: 0.20, Train average loss: 0.45719 Train barrier 0.23460416985352833
Alpha: 0.25, Train average loss: 0.67603 Train barrier 0.4055960914060473
Alpha: 0.30, Train average loss: 0.88917 Train barrier 0.5708950679497586
Alpha: 0.35, Train average loss: 1.07046 Train barrier 0.7043376100316643
Alpha: 0.40, Train average loss: 1.20871 Train barrier 0.794746105760336
Alpha: 0.45, Train average loss: 1.30429 Train barrier 0.842474383788473
Alpha: 0.50, Train average loss: 1.36265 Train barrier 0.852

In [None]:
# Finetuning of the OT-fused model.
# Requires `ot_fused_model` (OT fusion cell) and `datamodule_lmc` (LMC-after-alignment cell).
# NOTE(review): these imports would ideally live in the notebook's first cell;
# they are kept here so the cell still runs on its own.
from pytorch_lightning.loggers import WandbLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint

from model_fusion.train import get_wandb_logger

# Reset these names before rebuilding them (presumably to release objects from
# an earlier run of this cell). `_` is deliberately not assigned: in IPython it
# holds the last cell output.
datamodule, trainer = None, None

min_epochs = 50
max_epochs = 100

# Finetuning uses its own data settings (batch size 256; 32 was the other
# setting tried) with augmentation enabled. They are applied to a copy so the
# shared `datamodule_hparams` is not mutated and earlier cells can be re-run
# with their original settings.
finetune_datamodule_hparams = {**datamodule_hparams, 'batch_size': 256, 'data_augmentation': True}

datamodule = datamodule_type.get_data_module(**finetune_datamodule_hparams)
lightning_params = {'optimizer': 'sgd', 'lr': 0.1, 'momentum': 0.9, 'weight_decay': 0.0001, 'lr_scheduler': 'plateau', 'lr_decay_factor': 0.5, 'lr_monitor_metric': 'val_loss'}
# Train a deep copy so `ot_fused_model` keeps its fused (un-finetuned) weights
otfused_lit_model = BaseModel(model_type=model_type, model_hparams=model_hparams, model=copy.deepcopy(ot_fused_model.model), **lightning_params)
#vanilla_averaged_lit_model = BaseModel(model_type=model_type, model_hparams=model_hparams, lightning_params=lightning_params, model=vanilla_averaging_model.model)

# Configuration logged to W&B; it records the finetuning data settings actually used
logger_config = {'model_hparams': model_hparams} | {'datamodule_hparams': finetune_datamodule_hparams} | {'lightning_params': lightning_params} | {'min_epochs': min_epochs, 'max_epochs': max_epochs, 'model_type': model_type, 'datamodule_type': datamodule_type, 'early_stopping': True}
logger = get_wandb_logger("otfusion finetuning", logger_config, [])
callbacks = []
# Early stopping watches the validation loss ...
monitor = 'val_loss'
patience = 20
callbacks.append(EarlyStopping(monitor=monitor, patience=patience))

# ... while the checkpoint callback tracks the highest validation accuracy
checkpoint_callback = ModelCheckpoint(monitor="val_accuracy", mode="max")
callbacks.append(checkpoint_callback)
trainer = lightning.Trainer(min_epochs=min_epochs, max_epochs=max_epochs, logger=logger, callbacks=callbacks, deterministic='warn')


datamodule.prepare_data()

datamodule.setup('fit')

trainer.fit(otfused_lit_model, train_dataloaders=datamodule.train_dataloader(), val_dataloaders=datamodule.val_dataloader())
#trainer.fit(vanilla_averaged_lit_model, train_dataloaders=datamodule.train_dataloader(), val_dataloaders=datamodule.val_dataloader())

datamodule.setup('test')
# NOTE(review): the model object is passed directly, so this evaluates the weights
# held in memory after fit(), not necessarily the checkpoint tracked by
# `checkpoint_callback` - confirm that this is intended.
trainer.test(otfused_lit_model, dataloaders=datamodule.test_dataloader())
#trainer.test(vanilla_averaged_lit_model, dataloaders=datamodule.test_dataloader())

wandb.finish()

# Loss of the finetuned model on the same data module used for the vanilla/fused losses
finetuned_loss = lmc_utils.compute_loss(otfused_lit_model, datamodule_lmc)

print(f"Finetuned otfusion loss: {finetuned_loss}")

In [None]:
# Evaluate a previously finetuned OT-fusion checkpoint pulled from W&B.
# Requires `datamodule_lmc` from the LMC-after-alignment cell.
#runFT = 'wc7z0q4u' #32
runFT = 'is4t96nh' #256
checkpointFT = f'model-fusion/Model Fusion/model-{runFT}:best_k'

# Configuration of the finetuning run
api = wandb.Api()
run = api.run(f'model-fusion/Model Fusion/{runFT}')
print(run.config)

# Data module settings; augmentation is switched off after the config was printed
datamodule_hparams = run.config['datamodule_hparams']
batch_size = datamodule_hparams.get('batch_size')
datamodule_hparams['data_augmentation'] = False

# The config stores enums as 'EnumName.MEMBER'; the lower-cased member name
# is what the enum constructors are given.
datamodule_type_str = run.config['datamodule_type'].split('.')[1].lower()
datamodule_type = DataModuleType(datamodule_type_str)
model_type_str = run.config['model_type'].split('.')[1].lower()
model_type = ModelType(model_type_str)

model_hparams = run.config['model_hparams']

print(datamodule_hparams)
print(model_hparams)

# From here on `run` is the live W&B run used to fetch the artifact
run = wandb.init()

artifact = run.use_artifact(checkpointFT, type='model')
artifact_dir = artifact.download(root=CHECKPOINT_DIR)
otfused_lit_model = BaseModel.load_from_checkpoint(Path(artifact_dir, "model.ckpt"))

# Test-only data module and trainer, tagged with the model and dataset names
wandb_tags = [f"{model_type.value}", f"{datamodule_type.value}"]
datamodule, trainer = setup_testing(
    f'eval finetuning {runFT}',
    model_type,
    model_hparams,
    datamodule_type,
    datamodule_hparams,
    wandb_tags,
)

datamodule.prepare_data()
datamodule.setup('test')

trainer.test(otfused_lit_model, dataloaders=datamodule.test_dataloader())

wandb.finish()

# Loss of the finetuned model on the data module built for the LMC loss comparison
finetuned_loss = lmc_utils.compute_loss(otfused_lit_model, datamodule_lmc)

print(f"Finetuned otfusion loss: {finetuned_loss}")

In [7]:
# Pyhessian (compute sharpness and eigenspectrum of base models, vanilla avg, ot fusion and finetuned solutions)
print("------- Computing sharpness -------")

# One entry per model to analyse, processed in order: (banner, model, figure file).
# Commented-out entries are analyses that are currently disabled.
pyhessian_targets = [
    ("------- Model A -------", modelA, 'modelA.pdf'),
    ("------- Model B -------", modelB, 'modelB.pdf'),
    #("------- Model A aligned to B -------", modelA_aligned, 'modelA_aligned.pdf'),
    ("------- OT fusion model -------", ot_fused_model, 'otmodel512_32.pdf'),
    ("------- Vanilla avg model -------", vanilla_averaging_model, 'vanilla_avg.pdf'),
    #("------- OT fusion model (finetuned) -------", otfused_lit_model, 'otmodel.pdf'),
]

for banner, target_model, figure_name in pyhessian_targets:
    print(banner)
    # `hessian_comp` is overwritten each time; only the last result stays bound
    hessian_comp = pyhessian_experiment.run_pyhessian(
        datamodule_type=datamodule_type,
        model=target_model,
        compute_density=False,
        figure_name=figure_name,
    )

Seed set to 42


------- Computing sharpness -------
------- Model A -------
Files already downloaded and verified
Files already downloaded and verified


  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


The top Hessian eigenvalue of this model is 2.4751


Seed set to 42



***Trace:  126.4693588256836
------- Model B -------
Files already downloaded and verified
Files already downloaded and verified
The top Hessian eigenvalue of this model is 8.0340


Seed set to 42



***Trace:  231.17194213867188
------- OT fusion model -------
Files already downloaded and verified
Files already downloaded and verified
The top Hessian eigenvalue of this model is 6.9863


Seed set to 42



***Trace:  119.44730949401855
------- Vanilla avg model -------
Files already downloaded and verified
Files already downloaded and verified
The top Hessian eigenvalue of this model is 4.2220

***Trace:  62.63650944768166


In [8]:
# Same PyHessian analysis, applied to model A after alignment
print("------- Aligned model -------")
hessian_comp = pyhessian_experiment.run_pyhessian(
    datamodule_type=datamodule_type,
    model=modelA_aligned,
    compute_density=False,
    figure_name='aligned.pdf',
)

Seed set to 42


------- Aligned model -------
Files already downloaded and verified
Files already downloaded and verified
The top Hessian eigenvalue of this model is 13.8469

***Trace:  490.30308884840747
