## Imports

In [None]:
import copy
import logging
from pathlib import Path
from typing import Dict
import math

import hydra
import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import seaborn as sns
import torch  # noqa
import wandb
from hydra.utils import instantiate
from matplotlib import tri
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from omegaconf import DictConfig
from pytorch_lightning import LightningModule
from scipy.stats import qmc
from torch.utils.data import DataLoader
from tqdm import tqdm
from ccmm.matching.utils import perm_indices_to_perm_matrix
from ccmm.utils.utils import normalize_unit_norm, project_onto

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import seed_index_everything
from nn_core.model_logging import NNLogger
from ccmm.utils.utils import fuse_batch_norm_into_conv
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler

import ccmm  # noqa
from ccmm.matching.utils import (
    apply_permutation_to_statedict,
    get_all_symbols_combinations,
    plot_permutation_history_animation,
    restore_original_weights,
)
from ccmm.utils.utils import (
    linear_interpolate_state_dicts,
    load_model_from_info,
    map_model_seed_to_symbol,
    save_factored_permutations,
)

from ccmm.matching.utils import load_permutations

from ccmm.utils.utils import vector_to_state_dict
import pytorch_lightning

In [None]:
# Plot styling: serif fonts rendered through LaTeX, seaborn "talk" sizing, and a shared colormap name.
matplotlib.rcParams["font.family"] = "serif"
sns.set_context("talk")
matplotlib.rcParams["text.usetex"] = True
cmap_name = "coolwarm_r"

# Raise these third-party loggers to WARNING so INFO/DEBUG messages do not flood the cell output.
for noisy_logger_name in ("lightning.pytorch", "torch", "pytorch_lightning.accelerators.cuda"):
    logging.getLogger(noisy_logger_name).setLevel(logging.WARNING)
pylogger = logging.getLogger(__name__)

## Configuration

In [None]:
# Reload edited library modules automatically while iterating in the notebook.
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose
from typing import Dict, List

# Clear any previously initialised Hydra instance so this cell can be re-run,
# then initialise Hydra with the "../conf" config directory (path resolution presumably
# relative to the notebook's location — TODO confirm).
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path=str("../conf"), job_name="matching_n_models")

In [None]:
# Compose the `matching_n_models` config without any overrides.
cfg = compose(config_name="matching_n_models", overrides=[])

In [None]:
# Keep the full config as `core_cfg`; from here on `cfg` refers only to its `matching` sub-tree.
core_cfg = cfg  # NOQA
cfg = cfg.matching

# Seed everything from the matching config (see nn_core.common.utils.seed_index_everything).
seed_index_everything(cfg)

## Hyperparameters

In [None]:
# Number of leading samples taken from each split (used by the Subset(...) calls below).
num_test_samples = 5000
num_train_samples = 5000

## Load dataset

In [None]:
# Both splits are built with the *test* transform, so the train subset gets the same
# preprocessing as the test subset.
transform = instantiate(core_cfg.dataset.test.transform)

train_dataset = instantiate(core_cfg.dataset.train, transform=transform)
test_dataset = instantiate(core_cfg.dataset.test, transform=transform)

# First `num_train_samples` training samples; with num_train_samples = 5000 and
# batch_size = 5000 this loader yields a single batch.
train_subset = Subset(train_dataset, list(range(num_train_samples)))
train_loader = DataLoader(train_subset, batch_size=5000, num_workers=cfg.num_workers)

# First `num_test_samples` test samples, evaluated in batches of 1000.
test_subset = Subset(test_dataset, list(range(num_test_samples)))

test_loader = DataLoader(test_subset, batch_size=1000, num_workers=cfg.num_workers)

In [None]:
# Trainer used for every evaluation below; progress bar and model summary are disabled to keep output short.
trainer = instantiate(cfg.trainer, enable_progress_bar=False, enable_model_summary=False)

## Load models

In [None]:
from ccmm.utils.utils import load_model_from_artifact

run = wandb.init(project=core_cfg.core.project_name, entity=core_cfg.core.entity, job_type="matching")

# {a: 1, b: 2, c: 3, ..}: symbol (str) -> seed from cfg.model_seeds.
# (The annotation was previously Dict[int, str], i.e. key/value types swapped.)
symbols_to_seed: Dict[str, int] = {map_model_seed_to_symbol(seed): seed for seed in cfg.model_seeds}


def artifact_path(seed):
    """Return the W&B artifact path (version v0) of the model trained with `seed`."""
    return f"{core_cfg.core.entity}/{core_cfg.core.project_name}/{core_cfg.model.model_identifier}_{seed}:v0"


# {a: model_a, b: model_b, c: model_c, ..}
models: Dict[str, LightningModule] = {
    map_model_seed_to_symbol(seed): load_model_from_artifact(run, artifact_path(seed)) for seed in cfg.model_seeds
}

num_models = len(models)

pylogger.info(f"Using {num_models} models with architecture {core_cfg.model.model_identifier}")

In [None]:
# Symbols of all loaded models, plus a deterministic (sorted) ordering of them.
# Original intent note: the model with the larger symbol is the one permuted (c -> b, b -> a, ...).
symbols = set(symbols_to_seed)
sorted_symbols = sorted(symbols)

# All ordered symbol pairs returned by the helper ...
all_combinations = get_all_symbols_combinations(symbols)
# ... reduced to canonical pairs (source < target): (a, b), (a, c), (b, c), ... but never (b, a).
canonical_combinations = [pair for pair in all_combinations if pair[0] < pair[1]]

## Matching

In [None]:
# Report which canonical (source < target) pairs will be matched.
pylogger.info(f"Matching the following model pairs: {canonical_combinations}")

### Load permutation specification

In [None]:
# Build the permutation spec for the configured architecture.
permutation_spec_builder = instantiate(core_cfg.model.permutation_spec_builder)
permutation_spec = permutation_spec_builder.create_permutation()

# Sanity check: the spec must cover exactly the keys of the inner model's state dict.
ref_model = list(models.values())[0]
assert set(permutation_spec.layer_and_axes_to_perm.keys()) == set(ref_model.model.state_dict().keys())

In [None]:
# Instantiate the configured matcher with the permutation spec built above.
matcher = instantiate(cfg.matcher, permutation_spec=permutation_spec)
pylogger.info(f"Matcher: {matcher.name}")

In [None]:
# Run the matcher on all models; `permutations[symbol]` is used below as a {perm_name: perm} dict per model.
permutations, perm_history = matcher(models, symbols=sorted_symbols, combinations=canonical_combinations)

In [None]:
# Display the matcher output.
permutations

In [None]:
# Move every model to CPU (nn.Module.to moves the module in place and returns it).
models = {symb: model.to("cpu") for symb, model in models.items()}

### Permute models to universe

In [None]:
from ccmm.matching.utils import perm_matrix_to_perm_indices

# Independent deep copies: the models in `models` keep their original neuron ordering.
models_permuted_to_universe = {symbol: copy.deepcopy(model) for symbol, model in models.items()}

for symbol, universe_model in models_permuted_to_universe.items():
    # Transposing a permutation matrix inverts the permutation; the inverse of each
    # matcher permutation is applied as the permutation *to* the universe.
    perms_to_universe = {
        perm_name: perm_matrix_to_perm_indices(perm_indices_to_perm_matrix(perm).T)
        for perm_name, perm in permutations[symbol].items()
    }

    source_params = universe_model.model.state_dict()
    universe_model.model.load_state_dict(
        apply_permutation_to_statedict(permutation_spec, perms_to_universe, source_params)
    )

### Permute models pairwise

In [None]:
from ccmm.matching.utils import unfactor_permutations

# models_permuted_pairwise[fixed][permutee]: model `permutee` with
# pairwise_permutations[fixed][permutee] applied (i.e. aligned to `fixed`).
# `set(symbols) - {symbol}` removes the symbol as a whole; `.difference(symbol)` would
# iterate its characters and break for multi-character symbols.
models_permuted_pairwise = {
    symbol: {other_symb: None for other_symb in set(symbols) - {symbol}} for symbol in symbols
}

pairwise_permutations = unfactor_permutations(permutations)

for fixed, permutee in all_combinations:
    # Use a copy of the permutee itself as the container for its permuted weights instead of a
    # hard-coded `models["a"]`, which raised KeyError whenever no model had symbol "a".
    permuted_model = copy.deepcopy(models[permutee])

    permuted_params = apply_permutation_to_statedict(
        permutation_spec, pairwise_permutations[fixed][permutee], models[permutee].model.state_dict()
    )
    permuted_model.model.load_state_dict(permuted_params)

    models_permuted_pairwise[fixed][permutee] = permuted_model

### Check performance of models before and after permutation

In [None]:
# Test every model after (first) and before (second) permutation to the universe; the two runs
# are expected to agree if the permutation preserves the network's function.
for symbol in models_permuted_to_universe:
    trainer.test(models_permuted_to_universe[symbol], test_loader)
    trainer.test(models[symbol], test_loader)

## Analyze models as vectors

### Flatten models

In [None]:
# For each symbol, the set of all *other* symbols. Subtracting `{symbol}` removes the whole
# symbol; `.difference(symbol)` iterated its characters and was wrong for multi-character symbols.
other_symbs = {symbol: set(symbols) - {symbol} for symbol in symbols}
print(other_symbs)

In [None]:
flatten_params = torch.nn.utils.parameters_to_vector

# Every model's parameters concatenated into one 1-D tensor.
flat_models = {symbol: flatten_params(model.parameters()) for symbol, model in models.items()}
flat_models_permuted_to_universe = {
    symbol: flatten_params(model.parameters()) for symbol, model in models_permuted_to_universe.items()
}

# flat_models_permuted_pairwise[symbol][other]: flattened `other` model aligned to `symbol`.
flat_models_permuted_pairwise = {}
for symbol in symbols:
    flat_models_permuted_pairwise[symbol] = {
        other: flatten_params(models_permuted_pairwise[symbol][other].parameters()) for other in other_symbs[symbol]
    }

In [None]:
# Fill the diagonal: a model "aligned to itself" is just the original model.
for symbol in symbols:
    flat_models_permuted_pairwise[symbol][symbol] = flat_models[symbol]
    models_permuted_pairwise[symbol][symbol] = models[symbol]

## Analyzing variance

In [None]:
# First element of one train batch (presumably the input images — TODO confirm); read as a
# global by get_layerwise_vars.
inputs = next(iter(train_loader))[0]

In [None]:
def get_layerwise_vars(model_a, model_b, ref_model, layer_indices=(1, 4, 7, 9, 12, 14, 17, 19)):
    """Measure the average activation variance at several depths along the linear interpolation path.

    For each interpolation coefficient in ``(0, 0.5, 1)``, the state dicts of ``model_a`` and
    ``model_b`` are blended with ``linear_interpolate_state_dicts`` and loaded into ``ref_model``.
    The notebook-global ``inputs`` batch is then passed through every prefix
    ``ref_model.model.embedder[:i]`` for ``i`` in ``layer_indices``. The variance of each slice
    along dim 1 of the (rank-4) output is averaged into one scalar per prefix.

    Args:
        model_a: first endpoint; its ``state_dict()`` is passed as ``t1``.
        model_b: second endpoint; its ``state_dict()`` is passed as ``t2``.
        ref_model: scratch model, overwritten via ``load_state_dict`` at every coefficient.
        layer_indices: prefix lengths of ``ref_model.model.embedder`` to probe. The default keeps
            the previously hard-coded cut points (presumably chosen for the VGG embedder used in
            this notebook — TODO confirm for other architectures).

    Returns:
        Tuple ``(vv0, vva, vv1)`` of numpy arrays with one entry per ``layer_indices`` element,
        measured at ``lam`` = 0, 0.5 and 1 respectively.
    """
    vvs = []

    for lamb in [0, 0.5, 1]:
        interp_params = linear_interpolate_state_dicts(t1=model_a.state_dict(), t2=model_b.state_dict(), lam=lamb)
        ref_model.load_state_dict(interp_params)

        vv = []

        for i in layer_indices:
            subnet = ref_model.model.embedder[:i]

            with torch.no_grad():
                out = subnet(inputs)

            # Move dim 1 (presumably channels) first and flatten the rest, so each row holds
            # all activations of one channel across the batch and spatial positions.
            out = out.permute(1, 0, 2, 3).reshape(out.size(1), -1)
            avg_var = out.var(1).mean()
            vv.append(avg_var.item())

        vvs.append(np.array(vv))

    # layerwise variances at lam = 0, 0.5, 1
    vv0, vva, vv1 = vvs
    return vv0, vva, vv1


def get_layerwise_ratios(model_a, model_b):
    """
    Per-layer ratio of the interpolation midpoint's activation variance to the mean of the two
    endpoint variances, as measured by ``get_layerwise_vars``.

    A deep copy of ``model_a`` is used as the scratch model for the interpolation.
    """
    scratch_model = copy.deepcopy(model_a)

    var_start, var_mid, var_end = get_layerwise_vars(model_a, model_b, scratch_model)

    endpoint_mean = (var_start + var_end) / 2
    return var_mid / endpoint_mean

In [None]:
# layerwise_ratios_naive = get_layerwise_ratios(models["a"], models["c"])
# layerwise_ratios_permuted = get_layerwise_ratios(models_permuted_pairwise["a"]["c"], models["a"])

In [None]:
# plt.plot(layerwise_ratios_naive, label="Without neuron matching")
# plt.plot(layerwise_ratios_permuted, label="With neuron matching")

# plt.ylim([0, 1])

# plt.xlabel("Layer index")
# plt.ylabel(r"$ \frac{ \sigma_{0.5} } { ( \sigma_0 + \sigma_1 ) / 2}$")
# plt.title("VGG11 layerwise variance ratios")

# plt.legend()

# plt.show()

## Repair

In [None]:
from ccmm.matching.repair import (
    replace_conv_layers,
    make_tracked_net,
    reset_bn_stats,
    ResetConv,
    compute_goal_statistics_two_models,
    compute_goal_statistics,
)

### Wrap networks

In [None]:
# TODO: use 'b' instead when it's available

# Deep copies, so in-place operations on these endpoints cannot affect the shared `models` dicts.
model_a = copy.deepcopy(models["a"])
# NOTE(review): `model_b` is not used by any later visible cell; the repair experiment uses `model_b_perm`.
model_b = copy.deepcopy(models["c"])
# model "c" aligned to model "a" (entry of models_permuted_pairwise["a"])
model_b_perm = copy.deepcopy(models_permuted_pairwise["a"]["c"])

In [None]:
## wrap the endpoint networks with make_tracked_net so that the statistics of their hidden units
## can be tracked (they feed compute_goal_statistics_two_models further below); move them to GPU

model_a_wrapped = make_tracked_net(model_a).cuda()
model_b_perm_wrapped = make_tracked_net(model_b_perm).cuda()

Check that wrapping the models leaves their test results unchanged.

In [None]:
# Set to True to compare each endpoint's test metrics before and after wrapping.
check = False

if check:
    trainer.test(model_a, test_loader)
    trainer.test(model_a_wrapped, test_loader)

    trainer.test(model_b_perm, test_loader)
    trainer.test(model_b_perm_wrapped, test_loader)

### Reset batch-norm statistics

In [None]:
# Reset the batch-norm statistics of both wrapped endpoints (.cuda() moves a module in place
# and returns it; these wrappers were already moved to GPU when created).
reset_bn_stats(model_a_wrapped.cuda())
reset_bn_stats(model_b_perm_wrapped.cuda())

Check that the test results are still the same after resetting the batch-norm statistics.

In [None]:
# Set to True to compare each endpoint's test metrics with its wrapped, BN-reset counterpart.
check = False

if check:
    trainer.test(model_a, test_loader)
    trainer.test(model_a_wrapped, test_loader)

    trainer.test(model_b_perm, test_loader)
    trainer.test(model_b_perm_wrapped, test_loader)

### Create interpolated model

In [None]:
# Midpoint (lam = 0.5) of the linear path between model_a and the aligned model_b, before repair.
interp_model = copy.deepcopy(model_a)
interp_model.load_state_dict(
    linear_interpolate_state_dicts(t1=model_a.state_dict(), t2=model_b_perm.state_dict(), lam=0.5)
)

In [None]:
# Wrap the interpolated model the same way as the endpoints and move it to GPU.
model_interp_wrapped = make_tracked_net(interp_model).cuda()

### Compute statistics

In [None]:
# Derive the target ("goal") statistics for the interpolated model from the two wrapped endpoints
# (exact semantics live in ccmm.matching.repair).
compute_goal_statistics_two_models(model_a_wrapped, model_interp_wrapped, model_b_perm_wrapped)

### Fuse batch norm layers into convolutions

In [None]:
# reset the tracked mean/var and fuse rescalings back into conv layers
# (this cell only resets the BN statistics of the interpolated wrapper; fusing happens in the next cell)
reset_bn_stats(model_interp_wrapped.cuda())

In [None]:
# NOTE(review): `ref_model` is not used in this cell.
ref_model = copy.deepcopy(model_a)

# fuse the rescaling+shift coefficients back into conv layers
# NOTE(review): `fuse_tracked_net` is not imported anywhere in the visible notebook (the
# ccmm.matching.repair import cell does not list it), so this line raises NameError unless it
# is defined elsewhere. TODO confirm where it lives and import it explicitly.
fused_interp = fuse_tracked_net(model_interp_wrapped)

In [None]:
# evaluate model_a
model_a_wrapped.eval()
trainer.test(model_a_wrapped, test_loader)

# evaluate model_b_perm
model_b_perm_wrapped.eval()
trainer.test(model_b_perm_wrapped, test_loader)

# evaluate fused_interp
# NOTE(review): `fused_interp` is put in eval mode, but the test below runs on
# `model_interp_wrapped` (the unfused tracked net), so `repaired_results` describes that
# model. Confirm whether `fused_interp` was meant to be tested here.
fused_interp.eval()
repaired_results = trainer.test(model_interp_wrapped, test_loader)[0]

In [None]:
def evaluate_interpolated_model(lambd, model_a, model_b, ref_model):
    """Test the network obtained by linearly interpolating the inner ``.model`` weights.

    The state dicts of ``model_a.model`` and ``model_b.model`` are blended with coefficient
    ``lambd`` and loaded into ``ref_model.model``, which is overwritten in place. The result is
    evaluated with the notebook-global ``trainer`` on ``test_loader``.

    Returns:
        The first metrics dict returned by ``trainer.test``.
    """
    blended_params = linear_interpolate_state_dicts(
        t1=model_a.model.state_dict(),
        t2=model_b.model.state_dict(),
        lam=lambd,
    )
    ref_model.model.load_state_dict(blended_params)

    return trainer.test(ref_model, test_loader, verbose=False)[0]

In [None]:
# Scratch model whose `.model` weights evaluate_interpolated_model overwrites; 0.5 = midpoint.
ref_model = copy.deepcopy(models["a"])
lambd = 0.5

In [None]:
# Midpoint of models "a" and "c" without any alignment.
naive_interp_results = evaluate_interpolated_model(lambd, model_a=models["a"], model_b=models["c"], ref_model=ref_model)

In [None]:
# Midpoint of model "a" and model "c" aligned to "a".
perm_interp_results = evaluate_interpolated_model(
    lambd, model_a=models["a"], model_b=models_permuted_pairwise["a"]["c"], ref_model=ref_model
)

In [None]:
# Log the three midpoint evaluations side by side: first test loss, then test accuracy.
for metric_key in ("loss/test", "acc/test"):
    pylogger.info(
        f"naive: {naive_interp_results[metric_key]}, matched: {perm_interp_results[metric_key]}, repaired: {repaired_results[metric_key]}"
    )

## Repair over N models

In [None]:
merged_model = copy.deepcopy(models["a"])

# Element-wise mean over the flattened, universe-aligned parameter vectors.
stacked_universe_params = torch.stack(list(flat_models_permuted_to_universe.values()))
mean_model_params = stacked_universe_params.mean(dim=0)

merged_model.load_state_dict(vector_to_state_dict(mean_model_params, merged_model))

In [None]:
# Evaluate the plain weight average (before repair) through a tracked wrapper on GPU.
merged_model_wrapped = make_tracked_net(merged_model).cuda()

merged_model_wrapped.eval()
trainer.test(merged_model_wrapped, test_loader)

In [None]:
# Wrap every universe-aligned model as an endpoint and reset its batch-norm statistics.
# NOTE(review): the objects in `models_permuted_to_universe` are wrapped directly (no deepcopy);
# if make_tracked_net/reset_bn_stats act in place, those models are modified as well
# (.cuda() itself already moves them to GPU in place).
wrapped_models = [make_tracked_net(models_permuted_to_universe[symbol]).cuda() for symbol in symbols]

for model in wrapped_models:
    reset_bn_stats(model.cuda())

# Goal statistics for the merged model, derived from the aligned endpoints.
compute_goal_statistics(merged_model_wrapped, wrapped_models)

In [None]:
# Reset the merged wrapper's BN statistics after computing goal statistics, then evaluate it.
reset_bn_stats(merged_model_wrapped.cuda())
merged_model_wrapped.eval()
trainer.test(merged_model_wrapped, test_loader)

In [None]:
from ccmm.utils.utils import average_models


for ref_symbol in symbols:
    merged_model = copy.deepcopy(models[ref_symbol])

    # every model (including the reference itself) aligned to `ref_symbol`
    all_models_permuted_to_ref = {symb: models_permuted_pairwise[ref_symbol][symb] for symb in symbols}

    model_params = {symbol: model.state_dict() for symbol, model in all_models_permuted_to_ref.items()}

    mean_params = average_models(model_params)

    merged_model.load_state_dict(mean_params)

    # metrics of the plain average, before repair
    results_before_repair = trainer.test(merged_model, test_loader, verbose=True)

    merged_model_wrapped = make_tracked_net(merged_model).cuda()

    # The merged model lives in the basis of `ref_symbol`, so the endpoint statistics must come
    # from the models aligned to that basis. The raw `models[symbol]` used previously are
    # unaligned for every symbol except the reference. Deep copies keep the shared models from
    # being moved to GPU or having their BN stats reset, in case make_tracked_net wraps in place.
    wrapped_models = [
        make_tracked_net(copy.deepcopy(all_models_permuted_to_ref[symbol])).cuda() for symbol in symbols
    ]

    for model in wrapped_models:
        reset_bn_stats(model.cuda())

    compute_goal_statistics(merged_model_wrapped, wrapped_models)
    reset_bn_stats(merged_model_wrapped.cuda())

    # metrics after repair
    results = trainer.test(merged_model_wrapped, test_loader, verbose=True)

## Git Re-Basin: merge many models

In [None]:
from ccmm.matching.merger import GitRebasinMerger

# Git Re-Basin merger that uses the same permutation spec as the matching above.
git_rebasin_merger = GitRebasinMerger(name="git_rebasin_merger", permutation_spec=permutation_spec)

In [None]:
%%capture
# Merge all models with the Git Re-Basin merger (cell output is suppressed by %%capture).
merged_model = git_rebasin_merger(models)

In [None]:
# Repair the Git Re-Basin merge: wrap merged model and endpoints, compute goal statistics, evaluate.
# NOTE(review): the endpoints are the raw `models[symbol]`. Goal statistics are only meaningful
# if these are in the same neuron basis as `merged_model`; confirm which basis
# GitRebasinMerger returns. The originals are also wrapped without deepcopy, and `.cuda()` moves
# them to GPU in place.
merged_model_wrapped = make_tracked_net(merged_model).cuda()
wrapped_models = [make_tracked_net(models[symbol]).cuda() for symbol in symbols]

for model in wrapped_models:
    reset_bn_stats(model.cuda())

compute_goal_statistics(merged_model_wrapped, wrapped_models)
reset_bn_stats(merged_model_wrapped.cuda())

results = trainer.test(merged_model_wrapped, test_loader, verbose=True)

## Git Re-Basin: merge many models with respect to a reference model

In [None]:
from ccmm.matching.weight_matching import PermutationSpec, weight_matching


def merge_wrt_model(ref_model_id, models):
    """Merge all models in the neuron basis of one reference model.

    Every non-reference model is weight-matched to the reference (``weight_matching`` with
    ``fixed`` = reference params), permuted with the resulting permutation, and all state dicts
    (reference included) are then averaged with ``average_models``. Uses the notebook-global
    ``permutation_spec``.

    Args:
        ref_model_id: position of the reference model in ``models.values()``.
        models: mapping symbol -> LightningModule; only ``.model.state_dict()`` is read, the
            inputs themselves are not modified.

    Returns:
        ``(merged_model, permutations)``: a deep copy of the first model holding the averaged
        weights, and the list of permutations in the order the other models were matched.
    """
    # One private deep copy per model, so the input models are never modified.
    model_params = [copy.deepcopy(model.model.state_dict()) for model in models.values()]

    num_models = len(model_params)
    ref_model_params = model_params[ref_model_id]

    other_model_ids = [i for i in range(num_models) if i != ref_model_id]
    permutations = []

    for other_model_id in other_model_ids:
        # Already a private copy (see above), so a second deepcopy is unnecessary.
        other_model_params = model_params[other_model_id]

        permutation = weight_matching(
            permutation_spec,
            fixed=ref_model_params,
            permutee=other_model_params,
        )

        permutations.append(permutation)

        model_params[other_model_id] = apply_permutation_to_statedict(
            permutation_spec, permutation, other_model_params
        )

    mean_params = average_models(model_params)
    # Any model works as the container: all share the same architecture.
    merged_model = copy.deepcopy(next(iter(models.values())))
    merged_model.model.load_state_dict(mean_params)

    return merged_model, permutations

In [None]:
%%capture

all_permutations = []
all_merged_models = []

# One merge per choice of reference model. The per-reference permutations are bound to
# `ref_permutations` rather than `permutations`: the old name overwrote the matcher output
# computed in the Matching section. Iterating `range(len(models))` replaces `enumerate(symbols)`
# over an unordered set whose symbol was never used.
for ref_model_id in range(len(models)):
    merged_model, ref_permutations = merge_wrt_model(ref_model_id=ref_model_id, models=copy.deepcopy(models))

    all_permutations.append(ref_permutations)
    all_merged_models.append(merged_model)

    results = trainer.test(merged_model, test_loader, verbose=True)

In [None]:
# Test every merged model (one per reference choice) and keep its metrics dict.
# The previous unused `permutations = all_permutations[...]` assignment is dropped: it
# overwrote the global matcher output `permutations`.
all_results = []
for merged_model in all_merged_models:
    results = trainer.test(merged_model, test_loader, verbose=True)[0]
    all_results.append(results)

In [None]:
# Mean and standard deviation of test accuracy across the merged models, logged as LaTeX "$m \pm s$".
accuracies = [result["acc/test"] for result in all_results]
mean_acc = np.mean(accuracies)
std_acc = np.std(accuracies)
# "\\pm" emits a literal backslash; a bare "\p" is an invalid escape sequence (SyntaxWarning on 3.12+).
pylogger.info(f"${round(mean_acc, 4)} \\pm {round(std_acc, 4)}$")

In [None]:
# Mean and standard deviation of test loss across the merged models, logged as LaTeX "$m \pm s$".
losses = [result["loss/test"] for result in all_results]
mean_loss = np.mean(losses)
std_loss = np.std(losses)
# "\\pm" emits a literal backslash; a bare "\p" is an invalid escape sequence (SyntaxWarning on 3.12+).
pylogger.info(f"${round(mean_loss, 4)} \\pm {round(std_loss, 4)}$")