## Imports

In [None]:
import copy
import itertools
import logging
import math
from functools import partial
from pathlib import Path
from typing import Dict
import json

import hydra
import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import pytorch_lightning
import seaborn as sns
import torch  # noqa
import wandb
from hydra.utils import instantiate
from matplotlib import tri
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from omegaconf import DictConfig
from pytorch_lightning import LightningModule
from scipy.stats import qmc
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler
from tqdm import tqdm

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import seed_index_everything
from nn_core.model_logging import NNLogger

import ccmm  # noqa
from ccmm.matching.utils import (
    apply_permutation_to_statedict,
    get_all_symbols_combinations,
    load_permutations,
    perm_indices_to_perm_matrix,
    plot_permutation_history_animation,
    restore_original_weights,
)
from ccmm.utils.utils import (
    fuse_batch_norm_into_conv,
    get_interpolated_loss_acc_curves,
    l2_norm_models,
    linear_interpolate,
    load_model_from_info,
    map_model_seed_to_symbol,
    normalize_unit_norm,
    project_onto,
    save_factored_permutations,
    vector_to_state_dict,
)

In [None]:
# Typeset every figure label through LaTeX using a serif font.
plt.rcParams.update({"text.usetex": True, "font.family": "serif"})
# Seaborn "talk" context scales fonts/lines up for presentation-sized plots.
sns.set_context("talk")

# Colormap name shared by the plotting cells of this notebook.
cmap_name = "coolwarm_r"

# Imported after the rcParams/seaborn setup on purpose: keep the original ordering
# in case the plot module adjusts matplotlib state on import.
from ccmm.utils.plot import Palette

palette = Palette(f"{PROJECT_ROOT}/misc/palette2.json")
# Bare expression so the notebook renders the palette.
palette

In [None]:
# Silence INFO-level chatter from Lightning / torch; only warnings and above get through.
for _noisy_logger_name in ("lightning.pytorch", "torch", "pytorch_lightning.accelerators.cuda"):
    logging.getLogger(_noisy_logger_name).setLevel(logging.WARNING)
del _noisy_logger_name

# Logger for this notebook's own messages.
pylogger = logging.getLogger(__name__)

## Configuration

In [None]:
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose
from typing import Dict, List

hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path=str("../conf"), job_name="matching")

In [None]:
# Build the "matching" config without any overrides.
cfg = compose(overrides=[], config_name="matching")

In [None]:
# Keep the full config as `core_cfg` (its dataset/model/core sections are read later)
# and narrow `cfg` down to the `matching` sub-config.
core_cfg, cfg = cfg, cfg.matching  # NOQA

seed_index_everything(cfg)

## Hyperparameters

In [None]:
# Number of leading examples taken from each split for evaluation.
num_train_samples, num_test_samples = 5_000, 5_000

## Load dataset

In [None]:
# Both splits use the *test* transform; presumably so no train-time augmentation
# is applied during evaluation — TODO confirm.
transform = instantiate(core_cfg.dataset.test.transform)

train_dataset = instantiate(core_cfg.dataset.train, transform=transform)
test_dataset = instantiate(core_cfg.dataset.test, transform=transform)

# Restrict each split to its first `num_*_samples` examples.
train_subset = Subset(dataset=train_dataset, indices=list(range(num_train_samples)))
test_subset = Subset(dataset=test_dataset, indices=list(range(num_test_samples)))

train_loader = DataLoader(dataset=train_subset, batch_size=5000, num_workers=cfg.num_workers)
test_loader = DataLoader(dataset=test_subset, batch_size=1000, num_workers=cfg.num_workers)

In [None]:
# Lightning trainer from the matching config, with progress bar and model summary
# turned off to keep notebook output short.
trainer = instantiate(
    cfg.trainer,
    enable_model_summary=False,
    enable_progress_bar=False,
)

## Load models

In [None]:
from ccmm.utils.utils import load_model_from_artifact

run = wandb.init(project=core_cfg.core.project_name, entity=core_cfg.core.entity, job_type="matching")

# {a: 1, b: 2, c: 3, ..} -- model symbol -> training seed
symbols_to_seed: Dict[str, int] = {map_model_seed_to_symbol(seed): seed for seed in cfg.model_seeds}


def artifact_path(seed: int) -> str:
    """Return the W&B artifact path (version ``v0``) of the model trained with ``seed``."""
    return (
        f"{core_cfg.core.entity}/{core_cfg.core.project_name}/"
        f"{core_cfg.dataset.name}_{core_cfg.model.model_identifier}_{seed}:v0"
    )


# {a: model_a, b: model_b, c: model_c, ..}
models: Dict[str, LightningModule] = {
    symbol: load_model_from_artifact(run, artifact_path(seed)) for symbol, seed in symbols_to_seed.items()
}
# Deep copies of each model's original weights, later passed to `restore_original_weights`
# to reset the models.
model_orig_weights = {symbol: copy.deepcopy(model.model.state_dict()) for symbol, model in models.items()}

num_models = len(models)

pylogger.info(f"Using {num_models} models with architecture {core_cfg.model.model_identifier}")

### Load permutation specification

In [None]:
permutation_spec_builder = instantiate(core_cfg.model.permutation_spec_builder)
permutation_spec = permutation_spec_builder.create_permutation()

# All models are loaded with the same `model_identifier`, so checking the first one suffices:
# the spec must cover exactly the entries of its state dict.
ref_model = next(iter(models.values()))
spec_keys = set(permutation_spec.layer_and_axes_to_perm.keys())
model_keys = set(ref_model.model.state_dict().keys())
assert spec_keys == model_keys, (
    "Permutation spec and model state dict disagree. "
    f"In model but not in spec: {model_keys - spec_keys}; "
    f"in spec but not in model: {spec_keys - model_keys}"
)

### Git Re-Basin

In [None]:
# matcher = instantiate(cfg.matcher, permutation_spec=permutation_spec)
# pylogger.info(f"Matcher: {matcher.name}")

In [None]:
# always permute the model having larger character order, i.e. c -> b, b -> a and so on ...
from ccmm.matching.utils import get_inverse_permutations

symbols = set(symbols_to_seed)
sorted_symbols = sorted(symbols, reverse=False)
fixed, permutee = "a", "b"
fixed_model, permutee_model = models[fixed], models[permutee]

# dicts for permutations and permuted params, D[a][b] refers to the permutation/params to map b -> a.
# Exclude `symb` with a one-element set: `symbols.difference(symb)` would iterate the string
# character by character and only works for single-character symbols.
permutations = {symb: {other_symb: None for other_symb in symbols - {symb}} for symb in symbols}

# Git Re-Basin alternative (currently disabled):
# matcher = instantiate(cfg.matcher, permutation_spec=permutation_spec)
# permutations[fixed_symbol][permutee_symbol], perm_history = matcher(
#     fixed=fixed_model.model, permutee=permutee_model.model
# )

# permutations[permutee_symbol][fixed_symbol] = get_inverse_permutations(permutations[fixed_symbol][permutee_symbol])

### Frank-Wolfe

In [None]:
from ccmm.matching.frank_wolfe_matching import collect_perm_sizes, frank_wolfe_weight_matching_trial
from ccmm.matching.matcher import FrankWolfeMatcher

# Initialization strategy handed to the Frank-Wolfe solver in the next cell.
initialization_method = "identity"

# State dicts of the fixed (reference) model and of the model to be permuted.
params_a, params_b = fixed_model.model.state_dict(), permutee_model.model.state_dict()
# Presumably the size of each permutation in the spec, read from the fixed model's params.
perm_sizes = collect_perm_sizes(permutation_spec, params_a)

In [None]:
# Presumably the (maximum) number of Frank-Wolfe iterations — confirm against the
# signature of `frank_wolfe_weight_matching_trial`; passed positionally as before.
num_fw_iterations = 200
# Use the GPU when present, otherwise fall back to CPU instead of crashing.
fw_device = "cuda" if torch.cuda.is_available() else "cpu"

perm_matrices, perm_matrices_history, new_obj, all_step_sizes = frank_wolfe_weight_matching_trial(
    params_a,
    params_b,
    perm_sizes,
    initialization_method,
    permutation_spec,
    num_fw_iterations,
    device=fw_device,
    return_step_sizes=True,
    global_step_size=True,
)

# Reset all models to their saved original weights after matching.
restore_original_weights(models, model_orig_weights)

In [None]:
# Step size used by Frank-Wolfe at each iteration.
plt.plot(list(all_step_sizes), color=palette["light red"])

plt.xlabel("Iteration")
plt.ylabel("Step size")
plt.title("Step size")

# savefig does not create missing directories, so ensure the output folder exists.
figures_dir = Path("figures")
figures_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(figures_dir / "convergence_step_sizes.pdf", bbox_inches="tight")

In [None]:
from ccmm.matching.weight_matching import solve_linear_assignment_problem

# Turn each Frank-Wolfe matrix (presumably a relaxed, non-binary permutation) into a hard
# permutation via a linear assignment problem; the result maps permutee -> fixed.
permutations[fixed][permutee] = dict(
    zip(perm_matrices.keys(), map(solve_linear_assignment_problem, perm_matrices.values()))
)

In [None]:
# updated_params[fixed][permutee] will hold the permutee's parameters mapped onto `fixed`.
updated_params = {fixed: dict.fromkeys([permutee])}

In [None]:
from scripts.evaluate_matched_models import evaluate_pair_of_models

# Four evenly spaced coefficients in [0, 1], presumably the interpolation weights
# between the two models — confirm in `evaluate_pair_of_models`.
lambdas = np.linspace(0, 1, num=4)

restore_original_weights(models, model_orig_weights)

pylogger.info(f"Permuting model {permutee} into {fixed}.")

# perms[a, b] maps b -> a
updated_params[fixed][permutee] = apply_permutation_to_statedict(
    permutation_spec,
    permutations[fixed][permutee],
    models[permutee].model.state_dict(),
)
restore_original_weights(models, model_orig_weights)

results = evaluate_pair_of_models(
    models, fixed, permutee, updated_params, train_loader, test_loader, lambdas, core_cfg
)

In [None]:
# Display the test-loss barrier reported by the evaluation.
test_loss_barrier = results["test_loss_barrier"]
test_loss_barrier