In [None]:
import copy
from pathlib import Path

import lightning
import numpy as np
import torch
import wandb
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import WandbLogger

from model_fusion.config import BASE_DATA_DIR, CHECKPOINT_DIR
from model_fusion.datasets import DataModuleType
from model_fusion.models import ModelType
from model_fusion.models.lightning import BaseModel
from model_fusion.train import setup_training
from model_fusion.train import get_wandb_logger
from Experiments import lmc_experiment
from Experiments import baselines_experiment
from Experiments import otfusion_experiment
from Experiments import pyhessian_experiment


# Seed every RNG the notebook relies on so that "Restart Kernel -> Run All"
# is repeatable. The constant keeps its original name for compatibility, but
# it now seeds PyTorch (used by the fine-tuning cell) as well as NumPy.
NUMPY_SEED = 100
# torch first: manual_seed returns a Generator, and we do not want that repr
# to become the cell's displayed output.
torch.manual_seed(NUMPY_SEED)
np.random.seed(NUMPY_SEED)

In [None]:
print("------- Loading models -------")

# W&B run ids of the two parent models that will be fused
runA = 'xzolnqeo'
runB = 'djhsbo0l'


def _enum_value(qualified_name):
    """Return the lower-cased second dot-separated field of the string, e.g. 'Type.MEMBER' -> 'member'."""
    return qualified_name.split('.')[1].lower()


def _load_checkpoint(checkpoint_name):
    """Download a model artifact through the active W&B run and load its `model.ckpt`.

    Returns the artifact handle, the download directory and the loaded
    `BaseModel`, in that order.
    """
    handle = run.use_artifact(checkpoint_name, type='model')
    target_dir = handle.download(root=CHECKPOINT_DIR)
    return handle, target_dir, BaseModel.load_from_checkpoint(Path(target_dir)/"model.ckpt")


# The training configuration is read from run A via the public API.
api = wandb.Api()
run = api.run(f'model-fusion/Model Fusion/{runA}')

print(run.config)

datamodule_hparams = run.config['datamodule_hparams']
batch_size = datamodule_hparams.get('batch_size')

datamodule_type_str = _enum_value(run.config['datamodule_type'])
datamodule_type = DataModuleType(datamodule_type_str)

model_type_str = _enum_value(run.config['model_type'])
model_type = ModelType(model_type_str)

# Alternative kept for reference, for configs that store flattened 'model_hparams/<key>' entries:
# model_hparams = list(filter(lambda x: 'model_hparams/' in x[0], run.config.items()))
# model_hparams = {k.split('/')[1]: v for k, v in model_hparams}
model_hparams = run.config['model_hparams']

print(datamodule_hparams)
print(model_hparams)

checkpointA, checkpointB = (
    f'model-fusion/Model Fusion/model-{run_id}:best_k' for run_id in (runA, runB)
)

# From here on `run` is a live W&B run (needed for `use_artifact`), no longer the API handle.
run = wandb.init()

# Model A: fetch, load, then dump the bare state dict next to the notebook.
artifact, artifact_dir, modelA = _load_checkpoint(checkpointA)
torch.save(modelA.model.state_dict(), 'model0.ckpt')

# Model B: same steps. Both artifacts are downloaded into CHECKPOINT_DIR.
artifact, artifact_dir, modelB = _load_checkpoint(checkpointB)
torch.save(modelB.model.state_dict(), 'model1.ckpt')


In [None]:
# LMC barrier between the two parent models
print("------- Computing LMC barrier -------")

lmc_kwargs = {
    'datamodule_type': datamodule_type,
    'modelA': modelA,
    'modelB': modelB,
    # NOTE(review): presumably the number of interpolation points between A and B -- confirm in lmc_experiment
    'granularity': 3,
}
lmc_experiment.run_lmc(**lmc_kwargs)

In [None]:
# Baselines for the fusion experiment: prediction ensembling and vanilla weight averaging
print("------- Computing baselines -------")

# Tag identifying this pair of parent runs in W&B
wandb_tag = f'baselines-{runA}-{runB}'

baseline_kwargs = {
    'datamodule_type': datamodule_type,
    'datamodule_hparams': datamodule_hparams,
    'model_type': model_type,
    'model_hparams': model_hparams,
    'modelA': modelA,
    'modelB': modelB,
    'wandb_tag': wandb_tag,
}
# The returned model is kept under the name used by the (currently commented-out) later cells.
vanilla_averaging_model = baselines_experiment.run_baselines(**baseline_kwargs)

In [None]:
# OT model fusion; also yields model A aligned to model B
print("------- Computing model fusion -------")

# Tag identifying this pair of parent runs in W&B
wandb_tag = f"ot_model_fusion-{runA}-{runB}"

otfusion_kwargs = {
    'batch_size': batch_size,
    'datamodule_type': datamodule_type,
    'datamodule_hparams': datamodule_hparams,
    'model_type': model_type,
    'model_hparams': model_hparams,
    'modelA': modelA,
    'modelB': modelB,
    'wandb_tag': wandb_tag,
}
ot_fused_model, modelA_aligned = otfusion_experiment.run_otfusion(**otfusion_kwargs)

In [None]:
# Pyhessian I: sharpness of the parent models and of the OT-fused solution.
# (The header originally also listed the eigenspectrum and the vanilla-average
# model; density computation is switched off and those entries are disabled.)
print("------- Computing sharpness -------")

# (section title, model, output figure name), evaluated in this order
sharpness_targets = [
    ("------- Model A -------", modelA, 'modelA.pdf'),
    ("------- Model B -------", modelB, 'modelB.pdf'),
    # disabled: "------- Model A aligned to B -------" -> run_pyhessian(datamodule_type=datamodule_type, model=modelA_aligned)
    ("------- OT fusion model -------", ot_fused_model, 'otmodel.pdf'),
    # disabled: "------- Vanilla avg model -------" -> run_pyhessian(datamodule_type=datamodule_type, model=vanilla_averaging_model)
]

for section_title, target_model, pdf_name in sharpness_targets:
    print(section_title)
    # `hessian_comp` is overwritten on every pass; after the loop it holds the OT-fusion result.
    hessian_comp = pyhessian_experiment.run_pyhessian(
        datamodule_type=datamodule_type,
        model=target_model,
        num_batches=10,
        compute_density=False,
        figure_name=pdf_name,
    )


In [None]:
# Fine-tuning: train the OT-fused model further, then evaluate it on the test split.
# TODO: do the same for the vanilla-averaged baseline (see the commented-out lines).
# Imports used here (copy, EarlyStopping, ModelCheckpoint, get_wandb_logger)
# live in the notebook's top import cell.

min_epochs = 5
max_epochs = 10

datamodule = datamodule_type.get_data_module(**datamodule_hparams)
lightning_params = {
    'optimizer': 'sgd',
    'lr': 0.05,
    'momentum': 0.9,
    'weight_decay': 0.0001,
    'lr_scheduler': 'plateau',
    'lr_decay_factor': 0.1,
    'lr_monitor_metric': 'val_loss',
}

# Fine-tune a *copy* of the fused network. Assuming BaseModel trains the module
# passed via `model=` in place, wrapping `ot_fused_model.model` directly would
# overwrite the un-finetuned fused weights, and re-running the earlier
# sharpness cell would then silently measure the finetuned model.
otfused_lit_model = BaseModel(
    model_type=model_type,
    model_hparams=model_hparams,
    lightning_params=lightning_params,
    model=copy.deepcopy(ot_fused_model.model),
)
#vanilla_averaged_lit_model = BaseModel(model_type=model_type, model_hparams=model_hparams, lightning_params=lightning_params, model=vanilla_averaging_model.model)

# Configuration logged alongside the fine-tuning run.
logger_config = {
    'model_hparams': model_hparams,
    'datamodule_hparams': datamodule_hparams,
    'lightning_params': lightning_params,
    'min_epochs': min_epochs,
    'max_epochs': max_epochs,
    'model_type': model_type,
    'datamodule_type': datamodule_type,
    'early_stopping': True,
}
# NOTE(review): placeholder run name -- consider something like f"finetune-{runA}-{runB}".
logger = get_wandb_logger("testy test test", logger_config, [])

monitor = 'val_loss'
# NOTE(review): patience (15) exceeds max_epochs (10), so early stopping can
# never trigger with the current settings -- confirm this is intended.
patience = 15
checkpoint_callback = ModelCheckpoint(monitor="val_accuracy", mode="max")
callbacks = [
    EarlyStopping(monitor=monitor, patience=patience),
    checkpoint_callback,
]
trainer = lightning.Trainer(min_epochs=min_epochs, max_epochs=max_epochs, logger=logger, callbacks=callbacks, deterministic='warn')

datamodule.prepare_data()

datamodule.setup('fit')

trainer.fit(otfused_lit_model, train_dataloaders=datamodule.train_dataloader(), val_dataloaders=datamodule.val_dataloader())
#trainer.fit(vanilla_averaged_lit_model, train_dataloaders=datamodule.train_dataloader(), val_dataloaders=datamodule.val_dataloader())

datamodule.setup('test')
# Evaluates the model object as it is after `fit`; the best checkpoint tracked
# by `checkpoint_callback` is not reloaded here.
trainer.test(otfused_lit_model, dataloaders=datamodule.test_dataloader())
#trainer.test(vanilla_averaged_lit_model, dataloaders=datamodule.test_dataloader())

# Close the active W&B run.
wandb.finish()

In [None]:
# Comparison II (compute sharpness of finetuned solutions)
print("------- Computing sharpness of finetuned solutions -------")

print("------- OT fusion model (finetuned) -------")
# Use a figure name distinct from the pre-finetuning run ('otmodel.pdf'), so
# that any figure produced here does not overwrite the earlier one.
hessian_comp = pyhessian_experiment.run_pyhessian(
    datamodule_type=datamodule_type,
    model=otfused_lit_model,
    num_batches=10,
    compute_density=False,
    figure_name='otmodel_finetuned.pdf',
)

#print("------- Vanilla avg model (finetuned) -------")
#hessian_comp = pyhessian_experiment.run_pyhessian(datamodule_type=datamodule_type,model=vanilla_averaging_model)
