# Match and analyze many models
---
We load a set of pretrained models and match them jointly, so that the resulting permutations are cycle-consistent.

## Imports

In [2]:
import copy
import itertools
import logging
import math
from functools import partial
from pathlib import Path
from typing import Dict
import json

import hydra
import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import pytorch_lightning
import seaborn as sns
import torch  # noqa
import wandb
from hydra.utils import instantiate
from matplotlib import tri
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from omegaconf import DictConfig
from pytorch_lightning import LightningModule
from scipy.stats import qmc
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler
from tqdm import tqdm

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import seed_index_everything
from nn_core.model_logging import NNLogger

import ccmm  # noqa
from ccmm.matching.utils import (
    apply_permutation_to_statedict,
    get_all_symbols_combinations,
    load_permutations,
    perm_indices_to_perm_matrix,
    plot_permutation_history_animation,
    restore_original_weights,
)
from ccmm.utils.utils import (
    fuse_batch_norm_into_conv,
    get_interpolated_loss_acc_curves,
    l2_norm_models,
    linear_interpolate,
    load_model_from_info,
    map_model_seed_to_symbol,
    normalize_unit_norm,
    project_onto,
    save_factored_permutations,
    vector_to_state_dict,
)

  from .autonotebook import tqdm as notebook_tqdm
  import pkg_resources


In [None]:
# Render all figure text through LaTeX (text.usetex) with a serif font.
plt.rcParams.update(
    {
        "text.usetex": True,
        "font.family": "serif",
    }
)
# Seaborn "talk" context: larger fonts/lines than the default "notebook" context.
sns.set_context("talk")

# NOTE(review): `cmap_name` is not referenced in this part of the notebook — confirm it is still needed.
cmap_name = "coolwarm_r"

from ccmm.utils.plot import Palette

# Project color palette read from misc/palette2.json; the bare expression below displays it in the notebook.
palette = Palette(f"{PROJECT_ROOT}/misc/palette2.json")
palette

In [None]:
# Only show WARNING and above from Lightning / torch to keep the notebook output readable.
logging.getLogger("lightning.pytorch").setLevel(logging.WARNING)
logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("pytorch_lightning.accelerators.cuda").setLevel(logging.WARNING)
# Logger for this notebook's own messages.
pylogger = logging.getLogger(__name__)

## Configuration

In [None]:
# Automatically reload edited Python modules (e.g. the `ccmm` package) before executing each cell.
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose
from typing import Dict, List

# Clear any Hydra instance left over from a previous run so this cell can be re-executed.
hydra.core.global_hydra.GlobalHydra.instance().clear()
# `config_path` is relative; Hydra resolves it relative to the caller (presumably the notebook's directory) — confirm.
initialize(version_base=None, config_path=str("../conf"), job_name="matching_n_models")

In [None]:
# Compose the "matching_n_models" config, overriding the model group with resnet20 and widen_factor=2.
cfg = compose(config_name="matching_n_models", overrides=["model=resnet20", "model.widen_factor=2"])

In [None]:
# Keep the full config as `core_cfg`; from here on `cfg` is only its `matching` sub-config.
core_cfg = cfg  # NOQA
cfg = cfg.matching

# Seed the RNGs (nn_core helper) from the matching config for reproducibility.
seed_index_everything(cfg)

## Hyperparameters

In [None]:
# Sizes of the deterministic subsets (first N examples of each split) used below.
num_test_samples = 5000
num_train_samples = 5000

## Load dataset

In [None]:
# Both splits use the *test* transform (presumably so no training-time augmentation is applied).
transform = instantiate(core_cfg.dataset.test.transform)

train_dataset = instantiate(core_cfg.dataset.train, transform=transform)
test_dataset = instantiate(core_cfg.dataset.test, transform=transform)

# First `num_train_samples` training examples; batch_size=5000 makes this a single batch with the current setting.
train_subset = Subset(train_dataset, list(range(num_train_samples)))
train_loader = DataLoader(train_subset, batch_size=5000, num_workers=cfg.num_workers)

# First `num_test_samples` test examples (DataLoader does not shuffle by default).
test_subset = Subset(test_dataset, list(range(num_test_samples)))

test_loader = DataLoader(test_subset, batch_size=1000, num_workers=cfg.num_workers)

In [None]:
# Lightning trainer used below for `trainer.test` evaluations; progress bar and model summary disabled.
trainer = instantiate(cfg.trainer, enable_progress_bar=False, enable_model_summary=False)

## Load models

In [None]:
from ccmm.utils.utils import load_model_from_artifact

run = wandb.init(project=core_cfg.core.project_name, entity=core_cfg.core.entity, job_type="matching")

# Symbol of each model -> seed it was trained with: {a: 1, b: 2, c: 3, ..}
symbols_to_seed: Dict[str, int] = {map_model_seed_to_symbol(seed): seed for seed in cfg.model_seeds}


# TODO: remove ln from artifact path
def artifact_path(seed):
    """Return the W&B artifact reference (version ``v0``) of the model trained with ``seed``."""
    return f"{core_cfg.core.entity}/{core_cfg.core.project_name}/{core_cfg.dataset.name}_{core_cfg.model.model_identifier}_ln_{seed}:v0"


# {a: model_a, b: model_b, c: model_c, ..}
# Built from `symbols_to_seed` so the symbol <-> seed mapping is computed once and both dicts stay aligned.
models: Dict[str, LightningModule] = {
    symbol: load_model_from_artifact(run, artifact_path(seed)) for symbol, seed in symbols_to_seed.items()
}

num_models = len(models)

pylogger.info(f"Using {num_models} models with architecture {core_cfg.model.model_identifier}")

In [None]:
# always permute the model having larger character order, i.e. c -> b, b -> a and so on ...
# `symbols` is an unordered set; `sorted_symbols` is its alphabetical order (reverse=False is the default).
symbols = set(symbols_to_seed.keys())
sorted_symbols = sorted(symbols, reverse=False)

# Pairs of distinct symbols; judging by the filter below this presumably contains both orders,
# e.g. (a, b) and (b, a) — confirm against get_all_symbols_combinations.
all_combinations = get_all_symbols_combinations(symbols)
# combinations of the form (a, b), (a, c), (b, c), .. and not (b, a), (c, a) etc
canonical_combinations = [(source, target) for (source, target) in all_combinations if source < target]

## Cycle-Consistent Matching 

In [None]:
# Log the canonical (source < target) pairs handed to the matcher.
pylogger.info(f"Matching the following model pairs: {canonical_combinations}")

### Load permutation specification

In [None]:
# Build the model's permutation spec; judging by `layer_and_axes_to_perm`, it maps each
# state_dict key to the permutations acting on its axes.
permutation_spec_builder = instantiate(core_cfg.model.permutation_spec_builder)
permutation_spec = permutation_spec_builder.create_permutation_spec()

# Sanity check: the spec must cover exactly the keys of the model's state_dict.
ref_model = list(models.values())[0]
assert set(permutation_spec.layer_and_axes_to_perm.keys()) == set(ref_model.model.state_dict().keys())

In [None]:
# Instantiate the configured matcher for this permutation spec.
matcher = instantiate(cfg.matcher, permutation_spec=permutation_spec)
pylogger.info(f"Matcher: {matcher.name}")

In [None]:
# Match all models jointly over the canonical pairs. `permutations[symbol]` maps permutation
# names to index permutations (see the universe cell); `perm_history` is not used in this section.
permutations, perm_history = matcher(models, symbols=sorted_symbols, combinations=canonical_combinations)

In [None]:
# Move every model to CPU. nn.Module.to works in place, so other references to these modules move too.
models = {symb: model.to("cpu") for symb, model in models.items()}

### Permute models to universe

In [None]:
from ccmm.matching.utils import perm_matrix_to_perm_indices

# Deep copies of the models whose weights are mapped into the shared "universe".
models_permuted_to_universe = {symbol: copy.deepcopy(model) for symbol, model in models.items()}

for symbol, model in models_permuted_to_universe.items():
    perms_to_universe = {}

    for perm_name, perm in permutations[symbol].items():
        perm = perm_indices_to_perm_matrix(perm)
        # The transpose of a permutation matrix is its inverse, so this applies the inverse of
        # the matcher's permutation for this symbol.
        perm_to_universe = perm.T
        perm_to_universe = perm_matrix_to_perm_indices(perm_to_universe)
        perms_to_universe[perm_name] = perm_to_universe

    permuted_params = apply_permutation_to_statedict(permutation_spec, perms_to_universe, model.model.state_dict())
    # `model` and `models_permuted_to_universe[symbol]` are the same (copied) object.
    models_permuted_to_universe[symbol].model.load_state_dict(permuted_params)

### Permute models pairwise

In [None]:
from ccmm.matching.utils import unfactor_permutations

# models_permuted_pairwise[fixed][permutee]: model `permutee` with `pairwise_permutations[fixed][permutee]` applied.
# `set(...) - {symbol}` removes the symbol itself; the former `.difference(symbol)` iterated over
# the characters of the string and would misbehave for multi-character symbols.
models_permuted_pairwise = {
    symbol: {other_symb: None for other_symb in set(symbols) - {symbol}} for symbol in symbols
}
pairwise_permutations = unfactor_permutations(permutations)

for fixed, permutee in all_combinations:
    # A deep copy of model "a" only serves as a container: all its weights are replaced by
    # load_state_dict below. NOTE: this (re)binds the notebook-global `ref_model`, which later
    # cells pass to `get_curves`; the name is kept so that behavior is preserved.
    ref_model = copy.deepcopy(models["a"])

    permuted_params = apply_permutation_to_statedict(
        permutation_spec, pairwise_permutations[fixed][permutee], models[permutee].model.state_dict()
    )

    ref_model.model.load_state_dict(permuted_params)
    models_permuted_pairwise[fixed][permutee] = ref_model

### Check performance of models before and after permutation

In [None]:
# Test accuracies before/after mapping to the universe. Re-run this cell before re-running the
# next one, otherwise results are appended to the old lists.
before_perms = []
after_perms = []

In [None]:
# Permuting hidden units should leave test accuracy unchanged; evaluate each model and its universe version.
for symbol, model in models_permuted_to_universe.items():
    after_perm = trainer.test(models_permuted_to_universe[symbol], test_loader)[0]["acc/test"]
    before_perm = trainer.test(models[symbol], test_loader)[0]["acc/test"]

    before_perms.append(before_perm)
    after_perms.append(after_perm)

In [None]:
# Both lists follow the iteration order of `models_permuted_to_universe`.
print(after_perms)
print(before_perms)

### Check that pairwise permutation doesn't change performance

In [None]:
# A notebook cell only displays its last expression, so compute both accuracies and print them
# together. Model "a" permuted into the space of "b" should keep the same accuracy as "a" itself.
acc_a = trainer.test(models["a"], test_loader)[0]["acc/test"]
acc_a_permuted_to_b = trainer.test(models_permuted_pairwise["b"]["a"], test_loader)[0]["acc/test"]
print(f"acc(a) = {acc_a}, acc(a permuted to b) = {acc_a_permuted_to_b}")

## Git Re-Basin

In [None]:
from ccmm.matching.weight_matching import PermutationSpec, weight_matching
from ccmm.matching.utils import get_inverse_permutations


# Git Re-Basin baseline: every canonical pair is matched independently with weight matching, so
# composing permutations around a cycle need not give the identity.
# pairwise_perms_gitrebasin[fixed][permutee] is the permutation applied to `permutee` to align it with `fixed`.
# `set(...) - {symb}` removes the symbol itself (`.difference(symb)` would iterate over its characters).
pairwise_perms_gitrebasin = {
    symb: {other_symb: None for other_symb in set(symbols) - {symb}} for symb in symbols
}

for fixed, permutee in canonical_combinations:
    permutation = weight_matching(
        permutation_spec,
        fixed=models[fixed].model.state_dict(),
        permutee=models[permutee].model.state_dict(),
    )
    pairwise_perms_gitrebasin[fixed][permutee] = permutation
    # The opposite direction is the inverse permutation, so only canonical pairs need matching.
    pairwise_perms_gitrebasin[permutee][fixed] = get_inverse_permutations(permutation)

### Analyze models as vectors

In [None]:
# For each symbol, the set of all *other* symbols. `- {symbol}` removes exactly that element,
# whereas `.difference(symbol)` would remove each character of a multi-character symbol.
other_symbs = {symbol: set(symbols) - {symbol} for symbol in symbols}
print(other_symbs)

In [None]:
# Flatten each model's parameters (buffers such as BN running stats are not included by
# `parameters()`) into one vector: original models, universe models, and pairwise-permuted models.
flat_models = {symbol: torch.nn.utils.parameters_to_vector(model.parameters()) for symbol, model in models.items()}
flat_models_permuted_to_universe = {
    symbol: torch.nn.utils.parameters_to_vector(model.parameters())
    for symbol, model in models_permuted_to_universe.items()
}

# flat_models_permuted_pairwise[symbol][other]: flattened `models_permuted_pairwise[symbol][other]`.
flat_models_permuted_pairwise = {
    symbol: {
        other_symb: torch.nn.utils.parameters_to_vector(models_permuted_pairwise[symbol][other_symb].parameters())
        for other_symb in other_symbs[symbol]
    }
    for symbol in symbols
}

In [None]:
# Diagonal entries: a model "permuted into its own space" is just the model itself.
for symbol in symbols:
    flat_models_permuted_pairwise[symbol][symbol] = flat_models[symbol]
    models_permuted_pairwise[symbol][symbol] = models[symbol]

### Check that the permutations are cycle-consistent

In [None]:
# Cycle consistency: for every permutation group, the product of the pairwise permutation
# matrices (a, b), (b, c), (c, a) must be the identity. Requires models "a", "b" and "c".
perm_names = list(pairwise_permutations["a"]["b"].keys())

for perm_name in perm_names:
    P1 = perm_indices_to_perm_matrix(pairwise_permutations["a"]["b"][perm_name])
    P2 = perm_indices_to_perm_matrix(pairwise_permutations["b"]["c"][perm_name])
    P3 = perm_indices_to_perm_matrix(pairwise_permutations["c"]["a"][perm_name])

    cyclic_composition = P1 @ P2 @ P3
    # torch.allclose requires both tensors to share a dtype (and device); build the identity to match
    # instead of relying on torch.eye's default float32 on CPU.
    identity = torch.eye(P1.shape[0], dtype=cyclic_composition.dtype, device=cyclic_composition.device)
    assert torch.allclose(cyclic_composition, identity), f"Permutation '{perm_name}' is not cycle consistent"

## Plots and tables

In [None]:
# Legend labels for our method and the Git Re-Basin baseline.
label_ours = r"$C^2M^3$"
label_gitrebasin = "Git Re-Basin"

# 25 evenly spaced interpolation coefficients in [0, 1].
lambdas = np.linspace(0, 1, 25)

# get_curves(model_a=..., model_b=...) returns (losses, accuracies) along the interpolation between
# the two models, evaluated on `test_loader`.
# NOTE(review): `ref_model` is whatever the global holds now — after the pairwise cell it is the last
# deep copy made there. It is presumably only used as a template module; confirm.
get_curves = partial(
    get_interpolated_loss_acc_curves, lambdas=lambdas, ref_model=ref_model, trainer=trainer, loader=test_loader
)

In [None]:
def plot_lmc(values, lambdas, labels, colors, axis=None):
    """
    Plot one linear-mode-connectivity curve per entry of ``values`` against ``lambdas``.

    Curves later in ``values`` are drawn with increasing opacity (alpha spaced from 0.5 to 1)
    and line width (spaced from 2.0 to 4.0).

    Args:
        values: sequence of curves, each holding one value per lambda.
        lambdas: interpolation coefficients used as x coordinates.
        labels: legend label of each curve (zipped with ``values``).
        colors: color of each curve, indexed by the curve's position.
        axis: matplotlib Axes to draw on; when None, ``pyplot`` (the current axes) is used.
    """
    # Resolve the drawing target once instead of re-checking it on every iteration.
    target = plt if axis is None else axis

    num_curves = len(values)
    transparencies = np.linspace(0.5, 1, num_curves)
    linewidths = np.linspace(2.0, 4.0, num_curves)

    for i, (val, label) in enumerate(zip(values, labels)):
        target.plot(lambdas, val, label=label, alpha=transparencies[i], linewidth=linewidths[i], color=colors[i])


# Bind the current 25-point `lambdas`; a later rebinding of the global `lambdas` does not affect this partial.
plot_lmc = partial(plot_lmc, lambdas=lambdas)

In [None]:
def plot_loss_and_acc_curves(losses, accuracies, labels, output_name):
    """
    Plot accuracy (left) and loss (right) interpolation curves side by side and save them as a PDF.

    Args:
        losses: one loss curve per label, sampled at the ``lambdas`` bound into ``plot_lmc``.
        accuracies: one accuracy curve per label.
        labels: legend labels; their count also selects the palette colors and the legend offset.
        output_name: file stem; the figure is written to ``figures/<output_name>.pdf``
            (the directory is created if missing).
    """
    colors = palette.get_colors(len(labels))
    fig, axes = plt.subplots(1, 2, figsize=(10, 3))

    for ax, curves, title in ((axes[0], accuracies, "Accuracy"), (axes[1], losses, "Loss")):
        plot_lmc(curves, axis=ax, labels=labels, colors=colors)
        ax.set_title(title)
        ax.set_xlabel(r"$\lambda$")
        ax.grid(True, alpha=0.3, linestyle="--")

    fig.subplots_adjust(bottom=0.1, right=0.8, top=0.9, wspace=0.4)

    # One shared legend, attached to the loss axes and anchored below/between the two panels;
    # more than three labels need two legend rows and therefore more vertical room.
    legend_y = -0.5 if len(labels) > 3 else -0.4
    legend_x = -0.2
    axes[1].legend(bbox_to_anchor=(legend_x, legend_y), loc="center", ncol=3)

    # savefig raises if the output directory does not exist yet.
    output_dir = Path("figures")
    output_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_dir / f"{output_name}.pdf", bbox_inches="tight")

### PLOT: LMC between $A$ and its cycled version $A \rightarrow B \rightarrow C \rightarrow A$

In [None]:
def cyclic_permute(pairwise_perms, symbols, models):
    """
    Send a model's weights around a closed cycle of models and return the resulting state dict.

    The symbols are always visited in *sorted* order, whatever order they are given in: for
    ["a", "b", "c"] the weights of model "a" travel a -> b -> c -> a, and each hop from ``prev``
    to ``nxt`` applies ``pairwise_perms[nxt][prev]``. Uses the notebook-level ``permutation_spec``.

    Returns:
        The permuted state dict of the alphabetically-first model (not a LightningModule).
    """
    cycle = sorted(symbols)
    # Pair every symbol with its successor, wrapping the last one back to the first.
    hops = zip(cycle, cycle[1:] + cycle[:1])

    state = models[cycle[0]].model.state_dict()
    for prev_symb, next_symb in hops:
        print("next: {} -- prev: {}".format(next_symb, prev_symb))
        hop_perm = pairwise_perms[next_symb][prev_symb]
        state = apply_permutation_to_statedict(permutation_spec, hop_perm, state)

    return state

In [None]:
# Send model "a" around the cycle a -> b -> c -> a with our permutations and with Git Re-Basin's.
a_cycle_ours = cyclic_permute(pairwise_permutations, ["a", "b", "c"], models)
a_cycle_gitr = cyclic_permute(pairwise_perms_gitrebasin, ["a", "b", "c"], models)

# Load each cycled state dict into a copy of model "a".
initial_model = models["a"]
permuted_model_ours = copy.deepcopy(initial_model)
permuted_model_ours.model.load_state_dict(a_cycle_ours)
permuted_model_gitr = copy.deepcopy(initial_model)
permuted_model_gitr.model.load_state_dict(a_cycle_gitr)

# Interpolate between "a" and its cycled versions. If the cycle composes to the identity,
# the cycled model equals "a" and the curve is flat.
loss_cycle_ours, acc_cycle_ours = get_curves(
    model_a=initial_model,
    model_b=permuted_model_ours,
)
loss_cycle_gitr, acc_cycle_gitr = get_curves(
    model_a=initial_model,
    model_b=permuted_model_gitr,
)

In [None]:
# Git Re-Basin first, ours last: plot_lmc draws later curves more opaque and thicker.
labels = [label_gitrebasin, label_ours]
losses = [loss_cycle_gitr, loss_cycle_ours]
accuracies = [acc_cycle_gitr, acc_cycle_ours]

plot_loss_and_acc_curves(losses, accuracies, labels, "lmc_a_cycled_a")

### PLOT: LMC in the original space and in the universe

In [None]:
# Interpolation curves for each pair of models (a-c, b-c, a-b), both in the original weight space
# and after mapping both endpoints into the universe.
loss_ac, acc_ac = get_curves(
    model_a=models["a"],
    model_b=models["c"],
)

loss_a_univ_c_univ, acc_a_univ_c_univ = get_curves(
    model_a=models_permuted_to_universe["a"],
    model_b=models_permuted_to_universe["c"],
)

loss_bc, acc_bc = get_curves(
    model_a=models["b"],
    model_b=models["c"],
)

loss_b_univ_c_univ, acc_b_univ_c_univ = get_curves(
    model_a=models_permuted_to_universe["b"],
    model_b=models_permuted_to_universe["c"],
)

loss_ab, acc_ab = get_curves(
    model_a=models["a"],
    model_b=models["b"],
)

loss_a_univ_b_univ, acc_a_univ_b_univ = get_curves(
    model_a=models_permuted_to_universe["a"],
    model_b=models_permuted_to_universe["b"],
)

In [None]:
# Legend entries "<left> - <right>" naming the two interpolation endpoints, once in the original
# space (e.g. $A$ - $C$) and once after mapping both models to the universe ($P_A^{\top} A$ - ...).
# All labels are in math mode, and the transposes are written consistently as P_X^{\top}.
left_endpoints = [r"$A$", r"$P_A^{\top} A$", r"$A$", r"$P_A^{\top} A$", r"$B$", r"$P_B^{\top} B$"]
right_endpoints = [r"$C$", r"$P_C^{\top} C$", r"$B$", r"$P_B^{\top} B$", r"$C$", r"$P_C^{\top} C$"]

labels = [f"{left} - {right}" for left, right in zip(left_endpoints, right_endpoints)]

losses = [loss_ac, loss_a_univ_c_univ, loss_ab, loss_a_univ_b_univ, loss_bc, loss_b_univ_c_univ]
accuracies = [acc_ac, acc_a_univ_c_univ, acc_ab, acc_a_univ_b_univ, acc_bc, acc_b_univ_c_univ]

plot_loss_and_acc_curves(losses, accuracies, labels, "interp_curves")

In [None]:
losses = {
    "ac": loss_ac,
    "a_univ_c_univ": loss_a_univ_c_univ,
    "ab": loss_ab,
    "a_univ_b_univ": loss_a_univ_b_univ,
    "bc": loss_bc,
    "b_univ_c_univ": loss_b_univ_c_univ,
}

# Create the output directory if needed, and close the file deterministically so the JSON is fully
# flushed. np.asarray(...).tolist() turns lists/arrays of numpy (or python) numbers into plain
# JSON-serializable floats.
Path("results").mkdir(parents=True, exist_ok=True)
with open("results/losses.json", "w") as losses_file:
    json.dump({name: np.asarray(curve).tolist() for name, curve in losses.items()}, losses_file)

### TABLE: Accumulated error in cycles

In [None]:
# Coarser grid (lambda in {0, 0.5, 1}) for the accumulated-error table; `get_curves` is rebound
# to use it. The earlier `plot_lmc` partial still holds the previous 25-point grid.
lambdas = np.linspace(0, 1, 3)

get_curves = partial(
    get_interpolated_loss_acc_curves, lambdas=lambdas, ref_model=ref_model, trainer=trainer, loader=test_loader
)

In [None]:
from ccmm.utils.utils import cosine_models


def _permute_along_cycle(pairwise_perms, cycle, models):
    """Map the weights of model ``cycle[0]`` through every hop of ``cycle`` (in the given order) and back."""
    state = models[cycle[0]].model.state_dict()
    # Each hop from `prev_symb` to `next_symb` applies pairwise_perms[next_symb][prev_symb];
    # the last hop wraps around to the starting model.
    for prev_symb, next_symb in zip(cycle, cycle[1:] + cycle[:1]):
        state = apply_permutation_to_statedict(permutation_spec, pairwise_perms[next_symb][prev_symb], state)
    return state


def get_accumulated_error(pairwise_perms, models, cycle_len=3, distance="l2"):
    """
    Measure how far a model drifts when its weights are permuted around closed cycles of models.

    For every ordered cycle (s_0, ..., s_{k-1}) of ``cycle_len`` distinct symbols taken from
    ``pairwise_perms``, the weights of model s_0 are mapped s_0 -> s_1 -> ... -> s_0 with the
    pairwise permutations and compared with the original s_0. The weight-space distance (or
    similarity) is printed, and loss/accuracy are evaluated along the interpolation between the
    two models with the notebook-level ``get_curves``. For perfectly cycle-consistent permutations
    the cycled model equals the original.

    Args:
        pairwise_perms: nested mapping; ``pairwise_perms[fixed][permutee]`` is the permutation
            applied to model ``permutee`` to align it with model ``fixed``.
        models: mapping from symbol to a LightningModule exposing ``.model``.
        cycle_len: number of distinct models per cycle (at least 2).
        distance: "l2" (prints ``l2_norm_models``) or "cosine" (prints ``cosine_models``).

    Returns:
        Dict keyed by the cycle's symbols joined in visiting order (e.g. "bca"); each value is
        ``{"x": lambdas, "loss": np.ndarray, "acc": np.ndarray}``.

    Raises:
        ValueError: for an unknown ``distance`` or ``cycle_len < 2``, before any expensive work.
    """
    if distance not in ("l2", "cosine"):
        raise ValueError(f"Unknown distance metric: {distance}")
    if cycle_len < 2:
        raise ValueError(f"cycle_len must be at least 2, got {cycle_len}")

    # Sorted so the output order does not depend on dict/set insertion order.
    symbols = sorted(pairwise_perms.keys())

    output = {}

    for c in itertools.permutations(symbols, cycle_len):
        print(f"Cycle: {c}")

        key = "".join(c)
        # Follow the cycle in its own order and start from its own first model. (`cyclic_permute`
        # sorts its symbols, so calling it here evaluated the same cycle for every permutation.)
        model_c = _permute_along_cycle(pairwise_perms, c, models)

        initial_model = models[c[0]]

        permuted_model = copy.deepcopy(initial_model)
        permuted_model.model.load_state_dict(model_c)

        if distance == "l2":
            print(f"Model distance: {l2_norm_models(model_c, initial_model.model.state_dict())}")
        else:
            print(f"Model similarity: {cosine_models(model_c, initial_model.model.state_dict())}")

        losses, accs = get_curves(model_a=initial_model, model_b=permuted_model)

        output[key] = {"x": lambdas, "loss": np.array(losses), "acc": np.array(accs)}

    return output

In [None]:
# Accumulated cycle error for Git Re-Basin's independently matched pairwise permutations (cosine similarity).
get_accumulated_error(pairwise_perms_gitrebasin, models, cycle_len=3, distance="cosine")

In [None]:
# Same measurement for our jointly matched permutations.
get_accumulated_error(pairwise_permutations, models, cycle_len=3, distance="cosine")