## Imports

In [None]:
import copy
import logging
from pathlib import Path
from typing import Dict
import math

import hydra
import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import seaborn as sns
import torch  # noqa
import wandb
from hydra.utils import instantiate
from matplotlib import tri
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from omegaconf import DictConfig
from pytorch_lightning import LightningModule
from scipy.stats import qmc
from torch.utils.data import DataLoader
from tqdm import tqdm
from ccmm.matching.utils import perm_indices_to_perm_matrix
from ccmm.utils.utils import normalize_unit_norm, project_onto

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import seed_index_everything
from nn_core.model_logging import NNLogger
from ccmm.utils.utils import fuse_batch_norm_into_conv
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler

import ccmm  # noqa
from ccmm.matching.utils import (
    apply_permutation_to_statedict,
    get_all_symbols_combinations,
    plot_permutation_history_animation,
    restore_original_weights,
)
from ccmm.utils.utils import (
    load_model_from_info,
    map_model_seed_to_symbol,
    save_factored_permutations,
)

from ccmm.matching.utils import load_permutations

from ccmm.utils.utils import vector_to_state_dict
import pytorch_lightning

In [None]:
# Plot styling: serif fonts rendered through LaTeX, seaborn "talk" context, shared colormap name.
matplotlib.rcParams["font.family"] = "serif"
sns.set_context("talk")
matplotlib.rcParams["text.usetex"] = True
cmap_name = "coolwarm_r"

# Raise these library loggers to WARNING so that INFO-level messages are not shown.
for _noisy_logger_name in ("lightning.pytorch", "torch", "pytorch_lightning.accelerators.cuda"):
    logging.getLogger(_noisy_logger_name).setLevel(logging.WARNING)

pylogger = logging.getLogger(__name__)

## Configuration

In [None]:
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose
from typing import Dict, List

# Drop any Hydra instance left over from a previous execution of this cell (initialize() fails if one exists),
# then point Hydra at the project's config directory.
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path="../conf", job_name="matching_n_models")

In [None]:
# Compose the top-level "matching_n_tasks" config without any overrides.
cfg = compose(
    config_name="matching_n_tasks",
    overrides=[],
)

In [None]:
# Keep a handle on the full composed config (later cells read core_cfg.nn, core_cfg.core, core_cfg.model, ...)
# and narrow `cfg` to its "matching" section.
core_cfg = cfg  # NOQA
cfg = core_cfg.matching

# Seed everything from the matching config for reproducibility.
seed_index_everything(cfg)

## Hyperparameters

In [None]:
# Sample budgets. NOTE(review): neither name is referenced by the cells shown here; confirm they are still needed.
num_test_samples, num_train_samples = 5000, 5000

## Load dataset

In [None]:
import pytorch_lightning as pl

datamodule: pl.LightningDataModule = hydra.utils.instantiate(core_cfg.nn.data, _recursive_=False)

# One test/train loader per split index 0..num_tasks. Index 0 is later used as the "global" loader and
# indices 1.. are paired with the per-task models (presumably index 0 = whole dataset; confirm in the datamodule).
test_dataloaders, train_dataloaders = [], []

for task_ind in range(datamodule.num_tasks + 1):
    # The datamodule is reconfigured in place for each split before its loaders are built.
    datamodule.task_ind = task_ind
    datamodule.transform_func = hydra.utils.instantiate(core_cfg.dataset.transform_func, _recursive_=True)
    datamodule.setup()

    test_dataloaders.append(datamodule.test_dataloader()[0])
    train_dataloaders.append(datamodule.train_dataloader())

In [None]:
# Intentionally empty scratch cell. Evaluating an undefined name here would raise NameError and stop "Run All".

In [None]:
# Lightning trainer used by every `trainer.test(...)` call below. Progress bars and model summaries are off
# to keep the output compact.
trainer = instantiate(
    cfg.trainer,
    enable_progress_bar=False,
    enable_model_summary=False,
)

## Load models

In [None]:
from ccmm.utils.utils import load_model_from_artifact

run = wandb.init(project=core_cfg.core.project_name, entity=core_cfg.core.entity, job_type="matching")

# {a: 1, b: 2, c: 3, ..}: model symbol -> training seed
symbols_to_seed: Dict[str, int] = {map_model_seed_to_symbol(seed): seed for seed in cfg.model_seeds}


def artifact_path(task):
    """Return the W&B artifact path (version v0) of the model trained on ``task`` for the current seed index."""
    return (
        f"{core_cfg.core.entity}/{core_cfg.core.project_name}/"
        f"{core_cfg.model.model_identifier}_T{task}_{cfg.seed_index}:v0"
    )


# {a: model_a, b: model_b, c: model_c, ..}
# NOTE(review): these keys come from applying map_model_seed_to_symbol to *task indices*, while symbols_to_seed
# is built from cfg.model_seeds. The two key sets agree only if both map to the same symbols; confirm.
models: Dict[str, LightningModule] = {
    map_model_seed_to_symbol(task): load_model_from_artifact(run, artifact_path(task))
    for task in range(cfg.num_tasks + 1)
}

num_models = len(models)

pylogger.info(f"Using {num_models} models with architecture {core_cfg.model.model_identifier}")

In [None]:
# always permute the model having larger character order, i.e. c -> b, b -> a and so on ...
# Pair orientation: the model with the larger symbol is permuted onto the smaller one (c -> b, b -> a, ...).
# "o" is the model trained over the whole dataset and is excluded from matching.
symbols = set(symbols_to_seed) - {"o"}

sorted_symbols = sorted(symbols)

# Presumably every ordered pair: (a, b), (b, a), (a, c), ...
all_combinations = get_all_symbols_combinations(symbols)
# Keep only pairs with source < target: (a, b), (a, c), (b, c), ...
canonical_combinations = [(src, tgt) for src, tgt in all_combinations if src < tgt]

## Matching

In [None]:
pylogger.info("Matching the following model pairs: %s", canonical_combinations)

### Load permutation specification

In [None]:
from ccmm.matching.sinkhorn_matching import get_perm_dict


# get_perm_dict receives a copy of model "a" plus a single dummy batch and returns the permutation groups
# and the parameter -> permutation index maps used to build the PermutationSpec below.
# NOTE(review): the (1, 3, 32, 32) shape assumes 32x32 RGB inputs; confirm against the dataset.
ref_model = copy.deepcopy(models["a"])
dummy_shape = (1, 3, 32, 32)
dummy_input = torch.randn(*dummy_shape)

perm_dict, map_param_index, map_prev_param_index = get_perm_dict(ref_model, dummy_input)

In [None]:
from pprint import pprint

# map each layer to the matrix permuting its rows
# pprint(map_prev_param_index)

In [None]:
# map each layer to the matrix permuting its columns
# pprint(map_param_index)

In [None]:
# Inspect the discovered permutation groups.
print(f"{perm_dict}")

In [None]:
from ccmm.utils.graph import graph_permutations_to_perm_spec

# Convert the graph-derived permutation groups into a PermutationSpec for the reference model.
perm_spec = graph_permutations_to_perm_spec(
    ref_model,
    perm_dict,
    map_param_index,
    map_prev_param_index,
)

In [None]:
# ref_model = list(models.values())[0]
# assert set(perm_spec.layer_and_axes_to_perm.keys()) == set(ref_model.state_dict().keys())

In [None]:
# sorted(set(perm_spec.layer_and_axes_to_perm.keys()).difference(set(ref_model.state_dict().keys())))

# sorted(set(ref_model.state_dict().keys()).difference(set(perm_spec.layer_and_axes_to_perm.keys())))

In [None]:
from ccmm.matching.permutation_spec import PermutationSpecBuilder
from ccmm.matching.permutation_spec import PermutationSpec

# The graph-derived spec is keyed by the LightningModule's parameter names (presumably "model.<param>"),
# but permutations are later applied to `model.model.state_dict()`, whose keys lack that prefix. Strip it
# explicitly so that a key without the prefix is kept as-is rather than having its first 6 characters cut off.
wrapper_prefix = "model."
layer_and_axes_to_perm = {
    (k[len(wrapper_prefix):] if k.startswith(wrapper_prefix) else k): v
    for k, v in perm_spec.layer_and_axes_to_perm.items()
}
perm_spec_builder = PermutationSpecBuilder()
perm_spec = perm_spec_builder.permutation_spec_from_axes_to_perm(layer_and_axes_to_perm)

In [None]:
# Build the configured matcher around the permutation spec computed above.
matcher = instantiate(
    cfg.matcher,
    permutation_spec=perm_spec,
)
pylogger.info("Matcher: %s", matcher.name)

In [None]:
# Match all models jointly. permutations[symbol] is consumed below (inverted to map each model into the universe).
permutations, perm_history = matcher(
    models,
    symbols=sorted_symbols,
    combinations=canonical_combinations,
)

In [None]:
# Move every model to CPU (nn.Module.to modifies the module in place and returns it).
models = {sym: net.to("cpu") for sym, net in models.items()}

### Permute models to universe

In [None]:
def permute_batchnorm(model, perm, perm_dict, map_param_index):
    """Reorder BatchNorm running statistics in place so they follow a channel permutation.

    Only ``running_mean`` and ``running_var`` are reordered; the affine weight/bias are not touched here.

    Args:
        model: module whose ``named_modules()`` are scanned for BatchNorm layers.
        perm: indexable collection of permutation matrices. For matrix ``P``, ``argmax(P, dim=1)``
            gives the new channel order.
        perm_dict: maps a permutation identifier to an index into ``perm``, or to ``None`` for the identity.
        map_param_index: maps ``"<module name>.weight"`` to a permutation identifier. BatchNorm layers
            whose weight is not a key are skipped.
    """
    for module_name, bn in model.named_modules():
        weight_key = f"{module_name}.weight"
        is_batchnorm = "BatchNorm" in str(type(bn))
        if not is_batchnorm or weight_key not in map_param_index:
            continue

        # Layers created with track_running_stats=False have no statistics to reorder.
        if bn.running_mean is None and bn.running_var is None:
            continue

        perm_idx = perm_dict[map_param_index[weight_key]]
        if perm_idx is None:
            order = torch.arange(bn.running_mean.shape[0])
        else:
            order = torch.argmax(perm[perm_idx], dim=1)

        # Advanced indexing materializes a copy first, so copying back into the same buffer is safe.
        bn.running_mean.copy_(bn.running_mean[order, ...])
        bn.running_var.copy_(bn.running_var[order, ...])

In [None]:
from ccmm.matching.utils import perm_matrix_to_perm_indices

# Map every model into the shared "universe" by applying the inverse of its matcher permutations.
# NOTE(review): this assumes permutations[symbol] maps universe -> model, so its inverse maps model -> universe.
models_permuted_to_universe = {symbol: copy.deepcopy(models[symbol]) for symbol in symbols}

for symbol, model in models_permuted_to_universe.items():
    perms_to_universe = {}

    for perm_name, perm in permutations[symbol].items():
        # The inverse of a permutation matrix is its transpose.
        perm_matrix = perm_indices_to_perm_matrix(perm)
        perms_to_universe[perm_name] = perm_matrix_to_perm_indices(perm_matrix.T)

    permuted_params = apply_permutation_to_statedict(perm_spec, perms_to_universe, model.model.state_dict())
    model.model.load_state_dict(permuted_params)

### Permute models pairwise

In [None]:
from ccmm.matching.utils import unfactor_permutations

# models_permuted_pairwise[fixed][permutee]: copy of `permutee` whose weights were permuted with
# pairwise_permutations[fixed][permutee] (presumably aligning it to `fixed`).
# difference({symbol}) removes the symbol itself. difference(symbol) would instead remove each of its
# characters, which only coincides for one-character symbols.
models_permuted_pairwise = {
    symbol: {other_symb: None for other_symb in set(symbols).difference({symbol})} for symbol in symbols
}
pairwise_permutations = unfactor_permutations(permutations)

for fixed, permutee in all_combinations:
    # Fresh template module that receives the permuted weights.
    # NOTE: this rebinds the notebook-level `ref_model`; after the loop it aliases the last stored entry.
    ref_model = copy.deepcopy(models["a"])

    permuted_params = apply_permutation_to_statedict(
        perm_spec, pairwise_permutations[fixed][permutee], models[permutee].model.state_dict()
    )
    ref_model.model.load_state_dict(permuted_params)
    models_permuted_pairwise[fixed][permutee] = ref_model

### Check performance of models before and after permutation

In [None]:
# loader = train_dataloaders[0]
# for symbol, model in models_permuted_to_universe.items():
#     trainer.test(models_permuted_to_universe[symbol], loader)
#     trainer.test(models[symbol], loader)

## Analyze models as vectors

### Flatten models

In [None]:
# Each model flattened into a single 1-D tensor of its parameters (in `.parameters()` order).
flat_models = {
    symb: torch.nn.utils.parameters_to_vector(net.parameters())
    for symb, net in models.items()
}
flat_models_permuted_to_universe = {
    symb: torch.nn.utils.parameters_to_vector(net.parameters())
    for symb, net in models_permuted_to_universe.items()
}

# Same nesting as models_permuted_pairwise: [fixed][permutee] -> flat parameter vector.
flat_models_permuted_pairwise = {
    fixed: {
        permutee: torch.nn.utils.parameters_to_vector(net.parameters())
        for permutee, net in permuted_by_target.items()
    }
    for fixed, permuted_by_target in models_permuted_pairwise.items()
}

## Interpolation curves

In [None]:
def linear_interpolation(model_a, model_b, lamb):
    """Return the point at fraction ``lamb`` along the segment from ``model_a`` to ``model_b``.

    ``lamb == 0`` yields ``model_a`` and ``lamb == 1`` yields ``model_b``. Works for any operands that
    support scalar multiplication and addition (floats, tensors, flattened parameter vectors).
    """
    weight_b = lamb
    weight_a = 1 - weight_b
    return weight_a * model_a + weight_b * model_b


def get_interp_curve(lambdas, model_a, model_b, ref_model, test_loader, test_trainer=None):
    """Evaluate test loss and accuracy along the straight line between two flattened models.

    Args:
        lambdas: interpolation coefficients; 0 gives ``model_a`` and 1 gives ``model_b``.
        model_a, model_b: flattened parameter vectors, presumably from ``parameters_to_vector``
            on models with ``ref_model``'s architecture (TODO confirm).
        ref_model: LightningModule used as a scratch container. Its ``.model`` weights are
            overwritten at every step.
        test_loader: dataloader passed to ``trainer.test``.
        test_trainer: Lightning trainer used for evaluation. Defaults to the notebook-level ``trainer``.

    Returns:
        Tuple ``(interp_losses, interp_accs)`` with one entry per lambda, read from the
        ``"loss/test"`` and ``"acc/test"`` metrics returned by ``trainer.test``.
    """
    # Resolved at call time so existing callers keep using the notebook-level trainer.
    evaluator = trainer if test_trainer is None else test_trainer

    interp_losses = []
    interp_accs = []

    for lamb in lambdas:
        interp_params = linear_interpolation(model_a=model_a, model_b=model_b, lamb=lamb)
        interp_state_dict = vector_to_state_dict(interp_params, ref_model.model)
        ref_model.model.load_state_dict(interp_state_dict)

        results = evaluator.test(ref_model, test_loader, verbose=False)
        interp_losses.append(results[0]["loss/test"])
        interp_accs.append(results[0]["acc/test"])

    return interp_losses, interp_accs

## Average the models in the universe

In [None]:
# Merge by taking the coordinate-wise mean of all parameter vectors after they were permuted to the universe.
# Only parameters are overwritten. Buffers (e.g. BatchNorm running statistics) remain those of model "a",
# which is presumably what the repair step further down is meant to fix.
merged_model = copy.deepcopy(models["a"])

universe_vectors = list(flat_models_permuted_to_universe.values())
vec = torch.stack(universe_vectors).mean(dim=0)

torch.nn.utils.vector_to_parameters(vec, merged_model.parameters())

## Pre-repair evaluation

In [None]:
# Per task: evaluate the task-specific model and the (not yet repaired) merged model.
# NOTE(review): assumes the i-th symbol in sorted order was trained on task i + 1, i.e. it pairs with
# test_dataloaders[i + 1] (index 0 being the global loader). Confirm.
symbols = sorted(symbols)

task_test_loaders = test_dataloaders[1:]
if len(symbols) != len(task_test_loaders):
    pylogger.warning(
        f"Got {len(symbols)} symbols but {len(task_test_loaders)} per-task test loaders; "
        f"only the first {min(len(symbols), len(task_test_loaders))} pairs will be evaluated."
    )

for symbol, loader in zip(symbols, task_test_loaders):
    pylogger.info(f"Symbol: {symbol}")

    task_spec_model = models[symbol]
    pylogger.info("Task specific")
    trainer.test(task_spec_model, loader)

    pylogger.info("Merged model")
    trainer.test(merged_model, loader)

## Repair

In [None]:
# One entry per symbol, each referring to the same index-0 training loader (presumably the whole dataset).
train_dataloaders_repeated = [train_dataloaders[0] for _ in range(len(symbols))]

In [None]:
from ccmm.matching.repair import repair_model

# Run repair_model on the merged model. It receives the universe-aligned models and one training loader per
# model; see ccmm.matching.repair for exactly what is adjusted.
repaired_model = repair_model(
    merged_model,
    models_permuted_to_universe,
    train_dataloaders_repeated,
)

### Evaluation: repaired merged model vs task-specific models

In [None]:
# Per task: evaluate the task-specific model and the repaired merged model.
# NOTE(review): assumes the i-th symbol in sorted order was trained on task i + 1, i.e. it pairs with
# test_dataloaders[i + 1] (index 0 being the global loader). Confirm.
symbols = sorted(symbols)

task_test_loaders = test_dataloaders[1:]
if len(symbols) != len(task_test_loaders):
    pylogger.warning(
        f"Got {len(symbols)} symbols but {len(task_test_loaders)} per-task test loaders; "
        f"only the first {min(len(symbols), len(task_test_loaders))} pairs will be evaluated."
    )

for symbol, loader in zip(symbols, task_test_loaders):
    pylogger.info(f"Symbol: {symbol}")

    task_spec_model = models[symbol]
    pylogger.info("Task specific")
    trainer.test(task_spec_model, loader)

    pylogger.info("Merged model")
    trainer.test(repaired_model, loader)

### Evaluation: merged model vs task-specific models on the whole dataset

In [None]:
# Evaluate every loaded model, then the repaired merged model, on the index-0 test loader.
global_loader = test_dataloaders[0]

for symbol, model in models.items():
    pylogger.info("Symbol: %s", symbol)
    trainer.test(model, global_loader)

pylogger.info("Merged model")
trainer.test(repaired_model, global_loader)

### Evaluation: merged model vs task-specific models on tasks different from the training tasks

In [None]:
# for symbol, model in models.items():
#     pylogger.info(f'Symbol: {symbol}')
#     trainer.test(model, global_loader)

# pylogger.info('Merged model')
# trainer.test(repaired_model, global_loader)

### Plot LMC

In [None]:
def plot_lmc(values, lambdas, labels, axis=None):
    """Plot one or more curves against interpolation coefficients and add a legend.

    Args:
        values: sequence of curves. Each curve has one value per entry of ``lambdas``.
        lambdas: x-axis interpolation coefficients.
        labels: legend label for each curve, aligned with ``values``.
        axis: matplotlib Axes to draw on. When None, pyplot's current axes are used.
    """
    if axis is None:
        axis = plt

    num_curves = len(values)
    # Later curves are drawn more opaque and thicker so overlapping curves stay distinguishable.
    transparencies = np.linspace(0.5, 1, num_curves)
    linewidths = np.linspace(2.0, 4.0, num_curves)

    for alpha, linewidth, val, label in zip(transparencies, linewidths, values, labels):
        axis.plot(lambdas, val, label=label, alpha=alpha, linewidth=linewidth)

    # Put the legend on the axis that was drawn on, not unconditionally on pyplot's current axes.
    axis.legend()

In [None]:
from ccmm.utils.utils import get_interpolated_loss_acc_curves

# Model "b" after applying the pairwise permutation towards "a".
p_perm_to_a = models_permuted_pairwise["a"]["b"]
lambdas = np.linspace(0, 1, 3)

# Use a dedicated scratch model. The notebook-level `ref_model` is left aliasing the last model stored in
# `models_permuted_pairwise` by the pairwise loop. If get_interpolated_loss_acc_curves loads interpolated
# weights into `ref_model` (as the local get_interp_curve does), passing that alias could overwrite a stored
# permuted model, possibly `p_perm_to_a` itself.
scratch_model = copy.deepcopy(models["a"])

loss, acc = get_interpolated_loss_acc_curves(
    model_a=models["a"],
    model_b=p_perm_to_a,
    lambdas=lambdas,
    ref_model=scratch_model,
    trainer=trainer,
    loader=test_dataloaders[0],
)

In [None]:
# Loss along the interpolation path.
values, labels = [loss], ["Loss"]
plot_lmc(values, lambdas, labels)

In [None]:
# Accuracy along the interpolation path.
values, labels = [acc], ["Acc"]
plot_lmc(values, lambdas, labels)