# Experiments

## Import dependencies

In [None]:
import copy
from pathlib import Path

import lightning
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import numpy as np
import torch
import wandb

from Experiments import baselines_experiment
from Experiments import lmc_experiment
from Experiments import otfusion_experiment
from Experiments import pyhessian_experiment
from model_fusion import lmc_utils
from model_fusion.config import BASE_DATA_DIR, CHECKPOINT_DIR
from model_fusion.datasets import DataModuleType
from model_fusion.models import ModelType
from model_fusion.models.lightning import BaseModel
from model_fusion.train import setup_training, setup_testing, get_wandb_logger

# Seed NumPy's global RNG once so NumPy-based computations in this notebook are
# reproducible. Only NumPy is seeded here; torch / Lightning RNGs are not.
NUMPY_SEED = 100
np.random.seed(seed=NUMPY_SEED)

## Experiment configurations

We already provide the finetuned models through WandB; however, the `enable_finetuning` flag can be set to `True` to run the finetuning yourself.

In [None]:
# Set to True to re-run finetuning instead of only evaluating the finetuned checkpoints already provided on W&B.
enable_finetuning: bool = False

### ResNet-18, CIFAR-10

#### Same initialization

In [None]:
# Same-initialization ResNet-18 / CIFAR-10 parent pairs. Each config holds the
# W&B run IDs of parent model A and parent model B; the numbers in the variable
# name are the parents' training batch sizes (A_B).

# Parent batch sizes 32 and 32
resnet_cifar10_same_init_32_32_config = dict(
    model_a_run="bbecqkxs",
    model_b_run="uw0w0e3e",
)

# Parent batch sizes 128 and 128
resnet_cifar10_same_init_128_128_config = dict(
    model_a_run="3bsofnmw",
    model_b_run="zp0c8n4p",
)

# Parent batch sizes 512 and 512
resnet_cifar10_same_init_512_512_config = dict(
    model_a_run="kvuejplb",
    model_b_run="kwdhgbfv",
)

# Parent batch sizes 512 and 32
resnet_cifar10_same_init_512_32_config = dict(
    model_a_run="kvuejplb",
    model_b_run="bbecqkxs",
)

# Parent batch sizes 32 and 512
resnet_cifar10_same_init_32_512_config = dict(
    model_a_run="bbecqkxs",
    model_b_run="kvuejplb",
)

#### Different initialization

In [None]:
# Different-initialization ResNet-18 / CIFAR-10 parent pairs. Besides the two
# parent runs, each config names the W&B runs of the finetuned fused models and
# the SGD learning rate / momentum used for that finetuning.

# Parent batch sizes 32 and 32
resnet_cifar10_diff_init_32_32_config = dict(
    model_a_run="bbecqkxs",
    model_b_run="k9q16yq1",
    vanilla_averaging_model_finetuning_run="sm4quce8",
    ot_fused_model_finetuning_run="0a6mj32n",
    finetuning_batch_size=32,
    vanilla_finetuning_lr=0.01,
    vanilla_finetuning_momentum=0.95,
    otf_finetuning_lr=0.01,
    otf_finetuning_momentum=0.95,
)

# Parent batch sizes 128 and 128
resnet_cifar10_diff_init_128_128_config = dict(
    model_a_run="3bsofnmw",
    model_b_run="q2135wcz",
    vanilla_averaging_model_finetuning_run="30nfxxwe",
    ot_fused_model_finetuning_run="8z426q3y",
    finetuning_batch_size=128,
    vanilla_finetuning_lr=0.05,
    vanilla_finetuning_momentum=0.9,
    otf_finetuning_lr=0.05,
    otf_finetuning_momentum=0.9,
)

# Parent batch sizes 512 and 512
resnet_cifar10_diff_init_512_512_config = dict(
    model_a_run="kvuejplb",
    model_b_run="yqpgz3ya",
    vanilla_averaging_model_finetuning_run="39ttcid9",
    ot_fused_model_finetuning_run="xbppg39l",
    finetuning_batch_size=512,
    vanilla_finetuning_lr=0.1,
    vanilla_finetuning_momentum=0.9,
    otf_finetuning_lr=0.1,
    otf_finetuning_momentum=0.9,
)

# Parent batch sizes 512 and 32 (only the OT-fused model has a finetuning run).
# Alternative: finetuning at batch size 256 -- replace the four finetuning
# entries below with
#   ot_fused_model_finetuning_run="is4t96nh", finetuning_batch_size=256,
#   otf_finetuning_lr=0.1, otf_finetuning_momentum=0.9
resnet_cifar10_diff_init_512_32_config = dict(
    model_a_run="kvuejplb",
    model_b_run="k9q16yq1",
    # Finetuning at batch size 32
    ot_fused_model_finetuning_run="wc7z0q4u",
    finetuning_batch_size=32,
    otf_finetuning_lr=0.01,
    otf_finetuning_momentum=0.9,
)

# Parent batch sizes 32 and 512 (only the OT-fused model has a finetuning run).
# Alternative: finetuning at batch size 256 -- replace the four finetuning
# entries below with
#   ot_fused_model_finetuning_run="3yaujh7p", finetuning_batch_size=256,
#   otf_finetuning_lr=0.1, otf_finetuning_momentum=0.9
resnet_cifar10_diff_init_32_512_config = dict(
    model_a_run="k9q16yq1",
    model_b_run="kvuejplb",
    # Finetuning at batch size 512
    ot_fused_model_finetuning_run="yzai8540",
    finetuning_batch_size=512,
    otf_finetuning_lr=0.1,
    otf_finetuning_momentum=0.9,
)

### ResNet-18, MNIST

#### Same initialization

In [None]:
# Same-initialization ResNet-18 / MNIST parent pairs (W&B run IDs of models A and B).

# Parent batch sizes 32 and 32
resnet_mnist_same_init_32_32_config = dict(
    model_a_run="xrjk55ng",
    model_b_run="bmt8o992",
)

# Parent batch sizes 512 and 512
resnet_mnist_same_init_512_512_config = dict(
    model_a_run="djhsbo0l",
    model_b_run="d7g313hy",
)

#### Different initialization

In [None]:
# Different-initialization ResNet-18 / MNIST parent pairs, plus the W&B run and
# SGD settings of the finetuned OT-fused model.

# Parent batch sizes 32 and 32
resnet_mnist_diff_init_32_32_config = dict(
    model_a_run="xrjk55ng",
    model_b_run="u2ckjxvf",
    ot_fused_model_finetuning_run="k6rv9lh7",
    finetuning_batch_size=32,
    otf_finetuning_lr=0.005,
    otf_finetuning_momentum=0.98,
)

# Parent batch sizes 512 and 512
resnet_mnist_diff_init_512_512_config = dict(
    model_a_run="djhsbo0l",
    model_b_run="xzolnqeo",
    ot_fused_model_finetuning_run="insax4t9",
    finetuning_batch_size=512,
    otf_finetuning_lr=0.1,
    otf_finetuning_momentum=0.9,
)

### VGG-11, CIFAR-10

#### Same initialization

In [None]:
# Same-initialization VGG-11 / CIFAR-10 parent pairs (W&B run IDs of models A and B).

# Parent batch sizes 32 and 32
vgg_cifar10_same_init_32_32_config = dict(
    model_a_run="wil30lcb",
    model_b_run="6v6im8ni",
)

# Parent batch sizes 512 and 512
vgg_cifar10_same_init_512_512_config = dict(
    model_a_run="33kyx0p1",
    model_b_run="3ezo4au5",
)

#### Different initialization

In [None]:
# Different-initialization VGG-11 / CIFAR-10 parent pairs, plus the W&B run and
# SGD settings of the finetuned OT-fused model.

# Parent batch sizes 32 and 32
vgg_cifar10_diff_init_32_32_config = dict(
    model_a_run="wil30lcb",
    model_b_run="0z3nowvr",
    ot_fused_model_finetuning_run="xhyln9am",
    finetuning_batch_size=32,
    otf_finetuning_lr=0.005,
    otf_finetuning_momentum=0.95,
)

# Parent batch sizes 512 and 512
vgg_cifar10_diff_init_512_512_config = dict(
    model_a_run="33kyx0p1",
    model_b_run="5tbujwji",
    ot_fused_model_finetuning_run="oq06dgmw",
    finetuning_batch_size=512,
    otf_finetuning_lr=0.1,
    otf_finetuning_momentum=0.95,
)

### Specify the desired experiment config here

In [None]:
# Choose one of the config dicts above; later cells read all run IDs and finetuning settings from it.
experiment_config: dict = resnet_cifar10_diff_init_32_32_config

In [None]:
# Unpack the selected experiment config into notebook-level names.
# The two parent runs are mandatory (a missing key raises KeyError).
runA, runB = experiment_config["model_a_run"], experiment_config["model_b_run"]

# Everything else is optional: dict.get returns None for absent keys, and later
# cells test for None to decide whether a finetuning run exists.
vanilla_averaging_model_finetuning_run = experiment_config.get("vanilla_averaging_model_finetuning_run")
ot_fused_model_finetuning_run = experiment_config.get("ot_fused_model_finetuning_run")

# Finetuning hyperparameters (SGD batch size, learning rate, momentum).
finetuning_batch_size = experiment_config.get("finetuning_batch_size")
vanilla_finetuning_lr = experiment_config.get("vanilla_finetuning_lr")
vanilla_finetuning_momentum = experiment_config.get("vanilla_finetuning_momentum")
otf_finetuning_lr = experiment_config.get("otf_finetuning_lr")
otf_finetuning_momentum = experiment_config.get("otf_finetuning_momentum")

## Load Experiment Configuration

In [None]:
# Read the training configuration of parent model A from W&B and reuse it
# (data module type/hparams, model type/hparams, batch size) for the experiment.
# NOTE(review): model B's run config is never inspected; both parents are
# assumed to share data module and architecture -- confirm.
api = wandb.Api()
run = api.run(f'model-fusion/Model Fusion/{runA}')
print(run.config)

parent_config = run.config
batch_size = parent_config['datamodule_hparams'].get('batch_size')

# Enum types are stored as dotted strings (presumably str(EnumMember)); the
# segment after the first '.' is lowercased to obtain the enum value.
datamodule_type = DataModuleType(parent_config['datamodule_type'].split('.')[1].lower())
datamodule_hparams = parent_config['datamodule_hparams']
# Disable data augmentation for every later use of these hparams.
datamodule_hparams['data_augmentation'] = False

model_type = ModelType(parent_config['model_type'].split('.')[1].lower())
model_hparams = parent_config['model_hparams']

print(datamodule_hparams)
print(model_hparams)

checkpointA = f'model-fusion/Model Fusion/model-{runA}:best_k'
checkpointB = f'model-fusion/Model Fusion/model-{runB}:best_k'

# Artifacts are consumed through an active W&B run.
run = wandb.init()


def _load_parent_model(wandb_run, artifact_name):
    """Download a model artifact into CHECKPOINT_DIR and load its model.ckpt as a BaseModel."""
    downloaded_dir = wandb_run.use_artifact(artifact_name, type='model').download(root=CHECKPOINT_DIR)
    return BaseModel.load_from_checkpoint(Path(downloaded_dir) / "model.ckpt")


modelA = _load_parent_model(run, checkpointA)
modelB = _load_parent_model(run, checkpointB)

## LMC Barrier before alignment

In [None]:
# Linear-mode-connectivity sweep between the parents exactly as trained (no alignment).
# NOTE(review): granularity=21 presumably sets the number of interpolation points -- confirm in run_lmc.
lmc_experiment.run_lmc(
    modelA=modelA,
    modelB=modelB,
    datamodule_type=datamodule_type,
    granularity=21,
)

## Model Test Accuracies

Here we compute the test accuracies of the parent models. We also create the vanilla-averaged and OT-fused models and evaluate their test accuracies. We ignore the test loss because we only consider the training loss, which is computed later in this notebook.

In [None]:
# Evaluate the parents and build/evaluate the vanilla averaging model. The
# returned model is reused for the loss, finetuning and curvature cells.
wandb_tag = f'baselines-{runA}-{runB}'

baseline_inputs = dict(
    datamodule_type=datamodule_type,
    datamodule_hparams=datamodule_hparams,
    model_type=model_type,
    model_hparams=model_hparams,
    modelA=modelA,
    modelB=modelB,
)
vanilla_averaging_model = baselines_experiment.run_baselines(**baseline_inputs, wandb_tag=wandb_tag)

In [None]:
wandb_tag = f"ot_model_fusion-{runA}-{runB}"

# run_otfusion yields the OT-fused model and model A aligned to model B; the
# aligned copy is what the post-alignment LMC sweep interpolates from.
fusion_result = otfusion_experiment.run_otfusion(
    modelA=modelA,
    modelB=modelB,
    batch_size=batch_size,
    datamodule_type=datamodule_type,
    datamodule_hparams=datamodule_hparams,
    model_type=model_type,
    model_hparams=model_hparams,
    wandb_tag=wandb_tag,
)
ot_fused_model, modelA_aligned = fusion_result

## LMC Barrier after alignment

In [None]:
# Same LMC sweep as before alignment, but starting from model A aligned to
# model B, so the two curves can be compared directly.
lmc_experiment.run_lmc(
    modelA=modelA_aligned,
    modelB=modelB,
    datamodule_type=datamodule_type,
    granularity=21,
)

## Training Loss of the averaged models

In [None]:
# Loss of the vanilla-averaged and OT-fused models before any finetuning.
# A dedicated data module is built from the run's data hparams, overridden with
# a large evaluation batch size, the local BASE_DATA_DIR and no augmentation.
# `datamodule_lmc` is also reused when evaluating the finetuned models, so all
# losses are computed on the same data.
datamodule_hparams_lmc = {
    **datamodule_hparams,  # keep any other keys get_data_module may require
    'batch_size': 1024,
    'data_dir': BASE_DATA_DIR,
    'data_augmentation': False,
}
# Pass the merged dict (not `datamodule_hparams`) so the overrides take effect.
datamodule_lmc = datamodule_type.get_data_module(**datamodule_hparams_lmc)
datamodule_lmc.prepare_data()
# 'fit' sets up the training split; NOTE(review): assumes lmc_utils.compute_loss
# iterates the training dataloader -- confirm.
datamodule_lmc.setup('fit')

vanilla_loss = lmc_utils.compute_loss(vanilla_averaging_model, datamodule_lmc)
fused_loss = lmc_utils.compute_loss(ot_fused_model, datamodule_lmc)

print(f"Vanilla loss pre fine-tuning: {vanilla_loss}")
print(f"Fused loss pre fine-tuning: {fused_loss}")

## Finetuning

### Finetuning the OT fused model

In [None]:
# Optionally finetune the OT-fused model: SGD with a plateau LR scheduler,
# early stopping on val_loss, and checkpointing of the best val_accuracy.
# The cell does nothing unless `enable_finetuning` is True. It is skipped with a
# message when the selected config lacks the OT finetuning hyperparameters,
# instead of building a model / data module from None values.
otf_finetuning_hparams = {
    'finetuning_batch_size': finetuning_batch_size,
    'otf_finetuning_lr': otf_finetuning_lr,
    'otf_finetuning_momentum': otf_finetuning_momentum,
}
missing_otf_hparams = [key for key, value in otf_finetuning_hparams.items() if value is None]

if enable_finetuning and missing_otf_hparams:
    print(f"Skipping OT fused model finetuning: experiment config is missing {missing_otf_hparams}")
elif enable_finetuning:
    min_epochs = 50
    max_epochs = 100
    # Finetuning uses its own batch size and enables augmentation. Work on a
    # copy so the shared, un-augmented `datamodule_hparams` used by other cells
    # (and by re-runs of earlier cells) is left untouched.
    finetuning_datamodule_hparams = {
        **datamodule_hparams,
        'batch_size': finetuning_batch_size,
        'data_augmentation': True,
    }

    datamodule = datamodule_type.get_data_module(**finetuning_datamodule_hparams)
    lightning_params = {'optimizer': 'sgd', 'lr': otf_finetuning_lr, 'momentum': otf_finetuning_momentum, 'weight_decay': 0.0001, 'lr_scheduler': 'plateau', 'lr_decay_factor': 0.5, 'lr_monitor_metric': 'val_loss'}
    # Deep-copy so training does not modify `ot_fused_model`, which later cells still use.
    otfused_lit_model = BaseModel(model_type=model_type, model_hparams=model_hparams, model=copy.deepcopy(ot_fused_model.model), **lightning_params)

    logger_config = {
        'model_hparams': model_hparams,
        'datamodule_hparams': finetuning_datamodule_hparams,
        'lightning_params': lightning_params,
        'min_epochs': min_epochs,
        'max_epochs': max_epochs,
        'model_type': model_type,
        'datamodule_type': datamodule_type,
        'early_stopping': True,
    }
    logger = get_wandb_logger("otfusion finetuning", logger_config, [])

    checkpoint_callback = ModelCheckpoint(monitor="val_accuracy", mode="max")
    callbacks = [EarlyStopping(monitor='val_loss', patience=20), checkpoint_callback]
    trainer = lightning.Trainer(min_epochs=min_epochs, max_epochs=max_epochs, logger=logger, callbacks=callbacks, deterministic='warn')

    datamodule.prepare_data()
    datamodule.setup('fit')
    trainer.fit(otfused_lit_model, train_dataloaders=datamodule.train_dataloader(), val_dataloaders=datamodule.val_dataloader())

    # NOTE(review): this tests the in-memory weights after the last epoch, not
    # necessarily the best-val_accuracy checkpoint kept by `checkpoint_callback`.
    datamodule.setup('test')
    trainer.test(otfused_lit_model, dataloaders=datamodule.test_dataloader())

    wandb.finish()

### Finetuning the vanilla averaging model

In [None]:
# Optionally finetune the vanilla-averaged model with the same recipe as the
# OT-fused model: SGD + plateau LR scheduler, early stopping on val_loss, and
# checkpointing of the best val_accuracy.
# Skipped with a message when the selected config lacks the vanilla finetuning
# hyperparameters (most configs only define OT finetuning settings).
vanilla_finetuning_hparams = {
    'finetuning_batch_size': finetuning_batch_size,
    'vanilla_finetuning_lr': vanilla_finetuning_lr,
    'vanilla_finetuning_momentum': vanilla_finetuning_momentum,
}
missing_vanilla_hparams = [key for key, value in vanilla_finetuning_hparams.items() if value is None]

if enable_finetuning and missing_vanilla_hparams:
    print(f"Skipping vanilla averaging model finetuning: experiment config is missing {missing_vanilla_hparams}")
elif enable_finetuning:
    min_epochs = 50
    max_epochs = 100
    # Work on a copy so the shared, un-augmented `datamodule_hparams` used by
    # other cells is left untouched.
    finetuning_datamodule_hparams = {
        **datamodule_hparams,
        'batch_size': finetuning_batch_size,
        'data_augmentation': True,
    }

    datamodule = datamodule_type.get_data_module(**finetuning_datamodule_hparams)
    lightning_params = {'optimizer': 'sgd', 'lr': vanilla_finetuning_lr, 'momentum': vanilla_finetuning_momentum, 'weight_decay': 0.0001, 'lr_scheduler': 'plateau', 'lr_decay_factor': 0.5, 'lr_monitor_metric': 'val_loss'}

    # Deep-copy so training does not modify `vanilla_averaging_model`, which later cells still use.
    vanilla_averaged_lit_model = BaseModel(model_type=model_type, model_hparams=model_hparams, model=copy.deepcopy(vanilla_averaging_model.model), **lightning_params)

    logger_config = {
        'model_hparams': model_hparams,
        'datamodule_hparams': finetuning_datamodule_hparams,
        'lightning_params': lightning_params,
        'min_epochs': min_epochs,
        'max_epochs': max_epochs,
        'model_type': model_type,
        'datamodule_type': datamodule_type,
        'early_stopping': True,
    }
    logger = get_wandb_logger("vanilla finetuning", logger_config, [])

    checkpoint_callback = ModelCheckpoint(monitor="val_accuracy", mode="max")
    callbacks = [EarlyStopping(monitor='val_loss', patience=20), checkpoint_callback]
    trainer = lightning.Trainer(min_epochs=min_epochs, max_epochs=max_epochs, logger=logger, callbacks=callbacks, deterministic='warn')

    datamodule.prepare_data()
    datamodule.setup('fit')
    trainer.fit(vanilla_averaged_lit_model, train_dataloaders=datamodule.train_dataloader(), val_dataloaders=datamodule.val_dataloader())

    # NOTE(review): this tests the in-memory weights after the last epoch, not
    # necessarily the best-val_accuracy checkpoint kept by `checkpoint_callback`.
    datamodule.setup('test')
    trainer.test(vanilla_averaged_lit_model, dataloaders=datamodule.test_dataloader())

    wandb.finish()

## Evaluating Finetuned models

### OT Fused model

In [None]:
# Evaluate the finetuned OT-fused model stored on W&B:
# - reload its `best_k` checkpoint artifact;
# - report test metrics;
# - compute its loss on `datamodule_lmc`, the data module used for the
#   pre-finetuning losses.
# The finetuning run's settings are kept in `ft_*` names so the parent
# experiment's globals (batch_size, datamodule_hparams, model_hparams,
# model_type, datamodule_type, run) are not overwritten.
if ot_fused_model_finetuning_run is not None:
    api = wandb.Api()
    ft_run = api.run(f'model-fusion/Model Fusion/{ot_fused_model_finetuning_run}')

    print(ft_run.config)

    ft_datamodule_type = DataModuleType(ft_run.config['datamodule_type'].split('.')[1].lower())
    # Copy with augmentation disabled for evaluation.
    ft_datamodule_hparams = {**ft_run.config['datamodule_hparams'], 'data_augmentation': False}
    ft_model_type = ModelType(ft_run.config['model_type'].split('.')[1].lower())
    ft_model_hparams = ft_run.config['model_hparams']

    print(ft_datamodule_hparams)
    print(ft_model_hparams)

    checkpointFT = f'model-fusion/Model Fusion/model-{ot_fused_model_finetuning_run}:best_k'

    ft_wandb_run = wandb.init()
    ft_artifact_dir = ft_wandb_run.use_artifact(checkpointFT, type='model').download(root=CHECKPOINT_DIR)
    # Global name kept: the curvature-analysis cell uses `otfused_lit_model`.
    otfused_lit_model = BaseModel.load_from_checkpoint(Path(ft_artifact_dir) / "model.ckpt")
    wandb_tags = [f"{ft_model_type.value}", f"{ft_datamodule_type.value}"]

    ft_datamodule, ft_trainer = setup_testing(f'eval finetuning {ot_fused_model_finetuning_run}', ft_model_type, ft_model_hparams, ft_datamodule_type, ft_datamodule_hparams, wandb_tags)

    ft_datamodule.prepare_data()
    ft_datamodule.setup('test')

    ft_trainer.test(otfused_lit_model, dataloaders=ft_datamodule.test_dataloader())

    wandb.finish()

    finetuned_loss = lmc_utils.compute_loss(otfused_lit_model, datamodule_lmc)

    print(f"Finetuned otfused loss: {finetuned_loss}")

### Vanilla Averaging model

In [None]:
# Evaluate the finetuned vanilla-averaged model stored on W&B:
# - reload its `best_k` checkpoint artifact;
# - report test metrics;
# - compute its loss on `datamodule_lmc`, the data module used for the
#   pre-finetuning losses.
# The finetuning run's settings are kept in `ft_*` names so the parent
# experiment's globals (batch_size, datamodule_hparams, model_hparams,
# model_type, datamodule_type, run) are not overwritten.
if vanilla_averaging_model_finetuning_run is not None:
    api = wandb.Api()
    ft_run = api.run(f'model-fusion/Model Fusion/{vanilla_averaging_model_finetuning_run}')

    print(ft_run.config)

    ft_datamodule_type = DataModuleType(ft_run.config['datamodule_type'].split('.')[1].lower())
    # Copy with augmentation disabled for evaluation.
    ft_datamodule_hparams = {**ft_run.config['datamodule_hparams'], 'data_augmentation': False}
    ft_model_type = ModelType(ft_run.config['model_type'].split('.')[1].lower())
    ft_model_hparams = ft_run.config['model_hparams']

    print(ft_datamodule_hparams)
    print(ft_model_hparams)

    checkpointFT = f'model-fusion/Model Fusion/model-{vanilla_averaging_model_finetuning_run}:best_k'

    ft_wandb_run = wandb.init()
    ft_artifact_dir = ft_wandb_run.use_artifact(checkpointFT, type='model').download(root=CHECKPOINT_DIR)
    # Global name kept: the curvature-analysis cell uses `vanilla_averaged_lit_model`.
    vanilla_averaged_lit_model = BaseModel.load_from_checkpoint(Path(ft_artifact_dir) / "model.ckpt")
    wandb_tags = [f"{ft_model_type.value}", f"{ft_datamodule_type.value}"]

    ft_datamodule, ft_trainer = setup_testing(f'eval finetuning {vanilla_averaging_model_finetuning_run}', ft_model_type, ft_model_hparams, ft_datamodule_type, ft_datamodule_hparams, wandb_tags)

    ft_datamodule.prepare_data()
    ft_datamodule.setup('test')

    ft_trainer.test(vanilla_averaged_lit_model, dataloaders=ft_datamodule.test_dataloader())

    wandb.finish()

    finetuned_loss = lmc_utils.compute_loss(vanilla_averaged_lit_model, datamodule_lmc)

    print(f"Finetuned vanilla loss: {finetuned_loss}")

## Curvature Analysis

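We use PyHessian to compute the sharpness of the parent models, model A aligned to B, the vanilla-averaged model, the OT-fused model and, when available, the finetuned solutions. The eigenvalue density (eigenspectrum) is only computed for the OT-fused model.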

In [None]:
def _report_sharpness(label, model, figure_name, compute_density=False):
    """Print a section header, then run PyHessian on `model` and return its result.

    The eigenvalue density is only computed when `compute_density` is True; the
    figure is written under `figure_name`.
    """
    print(f"------- {label} -------")
    return pyhessian_experiment.run_pyhessian(
        datamodule_type=datamodule_type,
        model=model,
        compute_density=compute_density,
        figure_name=figure_name,
    )


print("------- Computing sharpness -------")

hessian_comp = _report_sharpness("Model A", modelA, 'modelA.pdf')
hessian_comp = _report_sharpness("Model B", modelB, 'modelB.pdf')
hessian_comp = _report_sharpness("Model A aligned to B", modelA_aligned, 'modelA_aligned.pdf')
hessian_comp = _report_sharpness("Vanilla avg model", vanilla_averaging_model, 'vanilla_avg.pdf')
# Only the OT-fused model also gets its eigenvalue density computed.
hessian_comp = _report_sharpness("OT fusion model", ot_fused_model, 'otmodel.pdf', compute_density=True)

# Finetuned models exist only when the corresponding W&B run was configured.
if vanilla_averaging_model_finetuning_run is not None:
    hessian_comp = _report_sharpness("Vanilla avg model (finetuned)", vanilla_averaged_lit_model, 'vanilla_avg_ft.pdf')

if ot_fused_model_finetuning_run is not None:
    hessian_comp = _report_sharpness("OT fusion model (finetuned)", otfused_lit_model, 'otmodel_ft.pdf')