In [None]:
import wandb
import numpy as np
from model_fusion.config import BASE_DATA_DIR, CHECKPOINT_DIR
from pathlib import Path
from model_fusion.datasets import DataModuleType
from model_fusion.models import ModelType
from model_fusion.models.lightning import BaseModel 
from Experiments import lmc_experiment
from Experiments import baselines_experiment
from Experiments import otfusion_experiment


# Seed NumPy's global (legacy) random state once, up front, so that
# numpy-based calculations in the cells below start from a fixed state.
NUMPY_SEED = 100
np.random.seed(seed=NUMPY_SEED)

In [None]:
print("------- Loading models -------")

# W&B run IDs of the two models to compare / fuse.
# NOTE(review): both IDs are currently identical, so modelA and modelB are
# loaded from the same checkpoint artifact -- confirm this is intended.
runA = '95shkeqb'
runB = '95shkeqb'


def _second_component_lower(dotted):
    """Return the second dot-separated component of ``dotted``, lowercased."""
    return dotted.split('.')[1].lower()


def _load_model(wandb_run, checkpoint):
    """Fetch the ``checkpoint`` model artifact via ``wandb_run`` and load it.

    The artifact is downloaded with CHECKPOINT_DIR as root, and the model is
    restored from the ``model.ckpt`` file inside the returned directory.
    """
    artifact = wandb_run.use_artifact(checkpoint, type='model')
    artifact_dir = artifact.download(root=CHECKPOINT_DIR)
    return BaseModel.load_from_checkpoint(Path(artifact_dir) / "model.ckpt")


# Look up run A through the public API. Only this run's config is consulted
# for the batch size, data module type, model type and model hparams below.
api = wandb.Api()
source_config = api.run(f'model-fusion/Model Fusion/{runA}').config

batch_size = source_config['batch_size']

datamodule_type_str = _second_component_lower(source_config['datamodule_type'])
datamodule_type = DataModuleType(datamodule_type_str)
datamodule_hparams = {'batch_size': batch_size, 'data_dir': BASE_DATA_DIR}

model_type_str = _second_component_lower(source_config['model_type'])
model_type = ModelType(model_type_str)

# Keep the config entries whose key contains 'model_hparams/', re-keyed by the
# second '/'-separated component of the original key.
model_hparams = {
    key.split('/')[1]: value
    for key, value in source_config.items()
    if 'model_hparams/' in key
}

checkpointA = f'model-fusion/Model Fusion/model-{runA}:best'
checkpointB = f'model-fusion/Model Fusion/model-{runB}:best'

# Start a new W&B run; both artifacts are requested through it.
run = wandb.init()

# Both downloads use the same root directory; model A is loaded before the
# download for model B happens.
modelA = _load_model(run, checkpointA)
modelB = _load_model(run, checkpointB)

In [None]:
# LMC barrier between the two loaded models.
# NOTE(review): unlike the baselines and OT-fusion calls in this notebook, no
# datamodule_hparams is passed here -- confirm run_lmc does not need them.
print("------- Computing LMC barrier -------")

lmc_kwargs = {
    'datamodule_type': datamodule_type,
    'modelA': modelA,
    'modelB': modelB,
}
lmc_experiment.run_lmc(**lmc_kwargs)

In [None]:
# Baselines (prediction ensembling, vanilla averaging)
print("------- Computing baselines -------")

# The tag embeds both run IDs.
wandb_tag = f'baselines-{runA}-{runB}'

baseline_kwargs = {
    'datamodule_type': datamodule_type,
    'datamodule_hparams': datamodule_hparams,
    'model_type': model_type,
    'model_hparams': model_hparams,
    'modelA': modelA,
    'modelB': modelB,
    'wandb_tag': wandb_tag,
}
baselines_experiment.run_baselines(**baseline_kwargs)

In [None]:
# OT model fusion + evaluation of the aligned model
print("------- Computing model fusion -------")

# NOTE(review): this tag is a fixed string, whereas the baselines tag embeds
# both run IDs -- confirm whether the two should follow the same scheme.
wandb_tag = "ot model fusion"

otfusion_kwargs = {
    'batch_size': batch_size,
    'datamodule_type': datamodule_type,
    'datamodule_hparams': datamodule_hparams,
    'model_type': model_type,
    'model_hparams': model_hparams,
    'modelA': modelA,
    'modelB': modelB,
    'wandb_tag': wandb_tag,
}
otfusion_experiment.run_otfusion(**otfusion_kwargs)

In [None]:
# TODO: Comparison I - compute sharpness and eigenspectrum of the vanilla-averaging and OT-fusion solutions

In [None]:
# TODO: fine-tuning (not implemented yet)


In [None]:
# Comparison II (compute sharpness of the fine-tuned solutions)