## Imports

In [None]:
import copy
import itertools
import logging
import math
from functools import partial
from pathlib import Path
from typing import Dict
import json
from ccmm.matching.frank_wolfe_matching import frank_wolfe_weight_matching
from ccmm.matching.weight_matching import solve_linear_assignment_problem


import hydra
import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import pytorch_lightning
import seaborn as sns
import torch  # noqa
import wandb
from hydra.utils import instantiate
from matplotlib import tri
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from omegaconf import DictConfig
from pytorch_lightning import LightningModule
from scipy.stats import qmc
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler
from tqdm import tqdm

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import seed_index_everything
from nn_core.model_logging import NNLogger

import ccmm  # noqa
from ccmm.matching.utils import (
    apply_permutation_to_statedict,
    get_all_symbols_combinations,
    load_permutations,
    perm_indices_to_perm_matrix,
    plot_permutation_history_animation,
    restore_original_weights,
)
from ccmm.utils.utils import (
    fuse_batch_norm_into_conv,
    get_interpolated_loss_acc_curves,
    l2_norm_models,
    linear_interpolate,
    load_model_from_info,
    map_model_seed_to_symbol,
    normalize_unit_norm,
    project_onto,
    save_factored_permutations,
    vector_to_state_dict,
)

from ccmm.matching.matcher import FrankWolfeMatcher
from scripts.evaluate_matched_models import evaluate_pair_of_models

In [None]:
# Typeset all figure text through LaTeX, using a serif font family.
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "serif"

# Apply seaborn's "talk" plotting context.
sns.set_context("talk")

# Name of the colormap used for plots (the "_r" suffix selects the reversed map).
cmap_name = "coolwarm_r"

from ccmm.utils.plot import Palette

# Project palette, read from the JSON file under <PROJECT_ROOT>/misc.
palette = Palette(f"{PROJECT_ROOT}/misc/palette2.json")
palette

In [None]:
# Raise the threshold of the third-party loggers below to WARNING so their
# INFO/DEBUG records do not show up in the notebook output.
for _noisy_logger_name in (
    "lightning.pytorch",
    "torch",
    "pytorch_lightning.accelerators.cuda",
):
    logging.getLogger(_noisy_logger_name).setLevel(logging.WARNING)

# Logger used by the cells of this notebook.
pylogger = logging.getLogger(__name__)

## Configuration

In [None]:
# IPython autoreload in mode 2: imported modules are reloaded before each cell
# runs, so edits to project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose
from typing import Dict, List

# Clear the global Hydra instance before initializing it again
# (presumably so this cell can be re-run within the same kernel).
hydra.core.global_hydra.GlobalHydra.instance().clear()
# Initialize Hydra with "../conf" as the config path and the job name below.
initialize(version_base=None, config_path=str("../conf"), job_name="matching_n_models")

In [None]:
# Compose the "matching_n_models" config, overriding model.widen_factor to 2.
cfg = compose(config_name="matching_n_models", overrides=["model.widen_factor=2"])

In [None]:
# Keep a handle on the full composed config; from here on `cfg` is only its
# `matching` section.
# NOTE(review): this rebinding is not idempotent - re-running this cell without
# re-composing `cfg` first would look up `.matching` on the already-narrowed config.
core_cfg = cfg  # NOQA
cfg = cfg.matching

# Hand the matching config to nn_core's seeding helper.
seed_index_everything(cfg)

## Hyperparameters

In [None]:
# Sizes of the test / train subsets used for evaluation (same size for both).
num_test_samples, num_train_samples = 5000, 5000

## Load dataset

In [None]:
# The transform configured for the *test* split is instantiated once and passed
# to both the train and the test dataset.
transform = instantiate(core_cfg.dataset.test.transform)

train_dataset = instantiate(core_cfg.dataset.train, transform=transform)
test_dataset = instantiate(core_cfg.dataset.test, transform=transform)

# Keep only the first `num_train_samples` training examples.
train_subset = Subset(train_dataset, list(range(num_train_samples)))
# NOTE(review): batch_size is hard-coded to 5000 rather than tied to
# num_train_samples - confirm whether a single full batch is intended.
train_loader = DataLoader(train_subset, batch_size=5000, num_workers=cfg.num_workers)

# Keep only the first `num_test_samples` test examples.
test_subset = Subset(test_dataset, list(range(num_test_samples)))

test_loader = DataLoader(test_subset, batch_size=1000, num_workers=cfg.num_workers)

In [None]:
# Instantiate the trainer described by the matching config, with the progress
# bar and the model summary switched off.
trainer = instantiate(cfg.trainer, enable_progress_bar=False, enable_model_summary=False)

## Load models

In [None]:
from ccmm.utils.utils import load_model_from_artifact

# wandb run (job type "matching"); it is handed to the artifact loader below.
run = wandb.init(project=core_cfg.core.project_name, entity=core_cfg.core.entity, job_type="matching")

# symbol -> seed, e.g. {a: 1, b: 2, c: 3, ..}
symbols_to_seed: Dict[str, int] = {map_model_seed_to_symbol(seed): seed for seed in cfg.model_seeds}


# TODO: remove ln stuff
def artifact_path(seed) -> str:
    """Return the wandb artifact reference (version ``v0``) for the given model seed.

    The reference has the form
    ``<entity>/<project>/<dataset name>_<model identifier>_ln_<seed>:v0``,
    with every component except ``seed`` read from ``core_cfg``.
    """
    return (
        f"{core_cfg.core.entity}/{core_cfg.core.project_name}/"
        f"{core_cfg.dataset.name}_{core_cfg.model.model_identifier}_ln_{seed}:v0"
    )


# symbol -> model, e.g. {a: model_a, b: model_b, c: model_c, ..}
# Built from `symbols_to_seed` so both dictionaries always share the same keys.
models: Dict[str, LightningModule] = {
    symbol: load_model_from_artifact(run, artifact_path(seed)) for symbol, seed in symbols_to_seed.items()
}

num_models = len(models)

pylogger.info(f"Using {num_models} models with architecture {core_cfg.model.model_identifier}")

In [None]:
# Convention: within a pair, the model whose symbol sorts later is the one that
# gets permuted, i.e. c -> b, b -> a and so on ...
symbols = set(symbols_to_seed)
sorted_symbols = sorted(symbols)

# (a, b), (a, c), (b, c), ...
all_combinations = get_all_symbols_combinations(symbols)

# Keep only pairs whose first symbol sorts before the second:
# (a, b), (a, c), (b, c), .. but never (b, a), (c, a) etc
canonical_combinations = []
for source, target in all_combinations:
    if source < target:
        canonical_combinations.append((source, target))

## Matching 

In [None]:
# Log the canonical model pairs computed above.
pylogger.info(f"Matching the following model pairs: {canonical_combinations}")

### Load permutation specification

In [None]:
# Build the permutation specification for the configured architecture.
permutation_spec_builder = instantiate(core_cfg.model.permutation_spec_builder)
permutation_spec = permutation_spec_builder.create_permutation_spec()

# Sanity check against the first loaded model: the specification must name
# exactly the entries of that model's state dict - no more, no fewer.
ref_model = list(models.values())[0]
spec_layer_names = set(permutation_spec.layer_and_axes_to_perm.keys())
model_layer_names = set(ref_model.model.state_dict().keys())
assert spec_layer_names == model_layer_names

### Identity initialization

In [None]:
# Values passed to the evaluation helper as `lambdas`
# (presumably the interpolation coefficients between the two models).
lambdas = [0.0, 0.5, 1.0]
# Number of "sinkhorn"-initialized matching trials per model pair.
num_trials = 10
# (fixed_symb, permutee_symb) -> test loss barriers for each initialization method.
all_results = {}
# Snapshot of every model's weights, used to put the models back in their
# original state between evaluations.
model_orig_weights = {symbol: copy.deepcopy(model.model.state_dict()) for symbol, model in models.items()}

# Loaders, lambdas and config are identical for every evaluation, so they are
# bound once here instead of once per pair.
evaluate = partial(
    evaluate_pair_of_models, train_loader=train_loader, test_loader=test_loader, lambdas=lambdas, cfg=core_cfg
)

for fixed_symb, permutee_symb in [("a", "b"), ("b", "c"), ("a", "c")]:
    updated_params = {fixed_symb: {permutee_symb: None}}

    fixed_model, permutee_model = models[fixed_symb], models[permutee_symb]

    # --- Identity initialization (single trial) ---
    permutations, best_perm_matrices_history = frank_wolfe_weight_matching(
        ps=permutation_spec,
        fixed=fixed_model.model.state_dict(),
        permutee=permutee_model.model.state_dict(),
        num_trials=1,
        initialization_method="identity",
    )

    updated_params[fixed_symb][permutee_symb] = apply_permutation_to_statedict(
        permutation_spec, permutations, models[permutee_symb].model.state_dict()
    )

    identity_results = evaluate(
        models=models,
        fixed_id=fixed_symb,
        permutee_id=permutee_symb,
        updated_params=updated_params,
    )

    restore_original_weights(models, model_orig_weights)

    # --- Bistochastic-barycenter initialization (single trial) ---
    permutations, best_perm_matrices_history = frank_wolfe_weight_matching(
        ps=permutation_spec,
        fixed=fixed_model.model.state_dict(),
        permutee=permutee_model.model.state_dict(),
        num_trials=1,
        initialization_method="bistochastic_barycenter",
    )

    updated_params[fixed_symb][permutee_symb] = apply_permutation_to_statedict(
        permutation_spec, permutations, models[permutee_symb].model.state_dict()
    )

    barycenter_results = evaluate(
        models=models,
        fixed_id=fixed_symb,
        permutee_id=permutee_symb,
        updated_params=updated_params,
    )

    restore_original_weights(models, model_orig_weights)

    # --- Sinkhorn initialization: `num_trials` trials, keeping every trial's matrices ---
    _, _, all_trial_perm_matrices = frank_wolfe_weight_matching(
        ps=permutation_spec,
        fixed=fixed_model.model.state_dict(),
        permutee=permutee_model.model.state_dict(),
        num_trials=num_trials,
        initialization_method="sinkhorn",
        return_all_trial_perm_matrices=True,
    )

    trial_loss_barriers = []

    for trial in range(num_trials):

        restore_original_weights(models, model_orig_weights)

        # Each stored matrix is passed through the linear-assignment solver
        # before being applied to the permutee's weights.
        perms = {p: solve_linear_assignment_problem(perm) for p, perm in all_trial_perm_matrices[trial].items()}

        updated_params = {fixed_symb: {permutee_symb: None}}

        updated_params[fixed_symb][permutee_symb] = apply_permutation_to_statedict(
            permutation_spec, perms, models[permutee_symb].model.state_dict()
        )

        # Same evaluation entry point as the identity / barycenter branches.
        trial_results = evaluate(
            models=models,
            fixed_id=fixed_symb,
            permutee_id=permutee_symb,
            updated_params=updated_params,
        )

        trial_loss_barriers.append(trial_results["test_loss_barrier"])

    # The trial loop only restores at the top of each iteration, so restore once
    # more here: the next pair must be matched from the original weights, not
    # from whatever the last trial's evaluation left in the models.
    restore_original_weights(models, model_orig_weights)

    all_results[(fixed_symb, permutee_symb)] = {
        "sinkhorn": trial_loss_barriers,
        "identity": identity_results["test_loss_barrier"],
        "barycenter": barycenter_results["test_loss_barrier"],
    }

In [None]:
# Summarise each pair: the single identity / barycenter barriers are carried
# over unchanged, the per-trial sinkhorn barriers are reduced to mean and std.
extracted_results = {
    pair: {
        "identity": pair_results["identity"],
        "sinkhorn_mean": np.array(pair_results["sinkhorn"]).mean(),
        "sinkhorn_std": np.array(pair_results["sinkhorn"]).std(),
        "barycenter": pair_results["barycenter"],
    }
    for pair, pair_results in all_results.items()
}

In [None]:
# Show the per-pair summary as the cell output.
extracted_results

In [None]:
# One LaTeX table row per model pair, with five "&"-separated cells:
#   pair & identity & barycenter & sinkhorn mean & sinkhorn std
# Every value sits in its own column; each row is terminated by "\\" and a newline.
latex_table_str = "".join(
    f"{fixed}-{permutee} & {stats['identity']:.3f} & {stats['barycenter']:.3f} & "
    f"{stats['sinkhorn_mean']:.3f} & {stats['sinkhorn_std']:.3f}  \\\\ \n"
    for (fixed, permutee), stats in extracted_results.items()
)

In [None]:
# Print the rows with escapes rendered, i.e. as they would be pasted into a LaTeX table.
print(latex_table_str)

In [None]:
# Show the raw string as the cell output (escape sequences are left visible).
latex_table_str