## Imports

In [None]:
import copy
import logging
from pathlib import Path
from typing import Dict
import math

import hydra
import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import seaborn as sns
import torch  # noqa
import wandb
from hydra.utils import instantiate
from matplotlib import tri
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from omegaconf import DictConfig
from pytorch_lightning import LightningModule
from scipy.stats import qmc
from torch.utils.data import DataLoader
from tqdm import tqdm
from ccmm.matching.utils import perm_indices_to_perm_matrix
from ccmm.utils.utils import normalize_unit_norm, project_onto

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import seed_index_everything
from nn_core.model_logging import NNLogger

from torch.utils.data import DataLoader, Subset, SubsetRandomSampler

import ccmm  # noqa
from ccmm.matching.utils import (
    apply_permutation_to_statedict,
    get_all_symbols_combinations,
    plot_permutation_history_animation,
    restore_original_weights,
)
from ccmm.utils.utils import (
    load_model_from_info,
    map_model_seed_to_symbol,
    save_factored_permutations,
)

from ccmm.matching.utils import load_permutations

from ccmm.utils.utils import vector_to_state_dict
import pytorch_lightning

In [None]:
# Plot styling: serif fonts, seaborn "talk" scaling, text rendered through LaTeX.
matplotlib.rc("font", family="serif")
sns.set_context("talk")
matplotlib.rc("text", usetex=True)
cmap_name = "coolwarm_r"

# Only warnings and above from the framework loggers, so the notebook output stays readable.
for noisy_logger_name in ("lightning.pytorch", "torch", "pytorch_lightning.accelerators.cuda"):
    logging.getLogger(noisy_logger_name).setLevel(logging.WARNING)
pylogger = logging.getLogger(__name__)

## Configuration

In [None]:
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose
from typing import Dict, List

# Clear any global Hydra state first so this cell can be re-run without
# `initialize` being called on an already-initialized instance.
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(
    version_base=None,
    config_path="../conf",
    job_name="matching_n_models",
)

In [None]:
# Compose the `matching_n_models` configuration without any command-line style overrides.
cfg = compose(config_name="matching_n_models", overrides=[])

In [None]:
# `core_cfg` keeps the full composed configuration, while `cfg` is narrowed to
# its `matching` section (the part used by the matching-specific cells).
core_cfg, cfg = cfg, cfg.matching  # NOQA

# Seed everything from the matching configuration.
seed_index_everything(cfg)

## Hyperparameters

In [None]:
# Number of leading test-set samples kept in the test subset.
num_test_samples = 5000
# Number of leading training-set samples kept in the train subset.
num_train_samples = 5000

## Load dataset

In [None]:
# The test-time transform is applied to BOTH splits.
transform = instantiate(core_cfg.dataset.test.transform)

# Train split: keep only the first `num_train_samples` items.
train_dataset = instantiate(core_cfg.dataset.train, transform=transform)
train_subset = Subset(train_dataset, range(num_train_samples))
train_loader = DataLoader(train_subset, batch_size=2000, num_workers=cfg.num_workers)

# Test split: keep only the first `num_test_samples` items.
test_dataset = instantiate(core_cfg.dataset.test, transform=transform)
test_subset = Subset(test_dataset, range(num_test_samples))
test_loader = DataLoader(test_subset, batch_size=1000, num_workers=cfg.num_workers)

In [None]:
# Build the trainer from the matching config, with the progress bar and the
# model summary disabled.
trainer = instantiate(cfg.trainer, enable_progress_bar=False, enable_model_summary=False)

## Load models

In [None]:
from ccmm.utils.utils import load_model_from_artifact

# W&B run used to fetch the model artifacts.
run = wandb.init(project=core_cfg.core.project_name, entity=core_cfg.core.entity, job_type="matching")

# Symbol -> seed, e.g. {a: 1, b: 2, c: 3, ..}
symbols_to_seed: Dict[str, int] = {map_model_seed_to_symbol(seed): seed for seed in cfg.model_seeds}


def artifact_path(seed):
    """Return the W&B artifact reference (version ``v0``) of the model identified by ``seed``."""
    return f"{core_cfg.core.entity}/{core_cfg.core.project_name}/{core_cfg.model.model_identifier}_{seed}:v0"


# Symbol -> model, e.g. {a: model_a, b: model_b, c: model_c, ..}
# Built from `symbols_to_seed` so both dictionaries share exactly the same keys.
models: Dict[str, LightningModule] = {
    symbol: load_model_from_artifact(run, artifact_path(seed)) for symbol, seed in symbols_to_seed.items()
}

num_models = len(models)

pylogger.info(f"Using {num_models} models with architecture {core_cfg.model.model_identifier}")

In [None]:
# Convention: the model whose symbol comes later alphabetically is the one that
# gets permuted, i.e. c -> b, b -> a and so on ...
symbols = set(symbols_to_seed)
sorted_symbols = sorted(symbols)

# Every pair of symbols: (a, b), (a, c), (b, c), ...
all_combinations = get_all_symbols_combinations(symbols)
# Keep only the ordered pairs (a, b), (a, c), (b, c), ..; drop (b, a), (c, a), etc.
canonical_combinations = [(first, second) for first, second in all_combinations if first < second]

## Matching

In [None]:
# Record which pairs are going to be matched.
pylogger.info(f"Matching the following model pairs: {canonical_combinations}")

### Load permutation specification

In [None]:
# Build the permutation specification for the configured architecture.
permutation_spec_builder = instantiate(core_cfg.model.permutation_spec_builder)
permutation_spec = permutation_spec_builder.create_permutation()

# Sanity check on the first loaded model: the spec must name exactly the
# entries of its state dict.
ref_model = next(iter(models.values()))
spec_layer_names = set(permutation_spec.layer_and_axes_to_perm.keys())
model_layer_names = set(ref_model.model.state_dict().keys())
assert spec_layer_names == model_layer_names

In [None]:
# Instantiate the configured matcher, handing it the permutation specification.
matcher = instantiate(cfg.matcher, permutation_spec=permutation_spec)
pylogger.info(f"Matcher: {matcher.name}")

In [None]:
# Run the matcher on all models. `permutations` is later indexed by model
# symbol and then by permutation name; `perm_history` is the second value
# returned by the matcher.
permutations, perm_history = matcher(models, symbols=sorted_symbols, combinations=canonical_combinations)

In [None]:
# Move every model to the CPU, keeping the same symbol -> model mapping.
models = {symb: model.to("cpu") for symb, model in models.items()}

### Permute models to universe

In [None]:
from ccmm.matching.utils import perm_matrix_to_perm_indices

# Start from independent copies so the original models stay untouched.
models_permuted_to_universe = {symbol: copy.deepcopy(model) for symbol, model in models.items()}

for symbol, model in models_permuted_to_universe.items():
    # For every permutation of this model: index form -> matrix, transpose,
    # and back to index form.
    perms_to_universe = {
        perm_name: perm_matrix_to_perm_indices(perm_indices_to_perm_matrix(perm_indices).T)
        for perm_name, perm_indices in permutations[symbol].items()
    }

    # Apply the transposed permutations to the copy's weights and load them back.
    permuted_params = apply_permutation_to_statedict(permutation_spec, perms_to_universe, model.model.state_dict())
    model.model.load_state_dict(permuted_params)

### Permute models pairwise

In [None]:
from ccmm.matching.utils import unfactor_permutations

# models_permuted_pairwise[fixed][permutee] starts as a private copy of model
# `permutee`; for every canonical pair it is then overwritten with `permutee`'s
# weights permuted to match `fixed`.
#
# Each inner entry is copied from `models[permutee]` directly. Iterating over
# `models.items()` in the inner comprehension would re-bind the outer symbol,
# copy unrelated models into the slots and create a self-entry. Symbols are
# compared with `!=` so multi-character symbols are handled as well.
models_permuted_pairwise = {
    fixed: {permutee: copy.deepcopy(models[permutee]) for permutee in symbols if permutee != fixed}
    for fixed in symbols
}
pairwise_permutations = unfactor_permutations(permutations)

# Only the canonical (fixed, permutee) pairs are permuted here; the remaining
# entries stay plain copies of the corresponding original model.
for fixed, permutee in canonical_combinations:
    permuted_params = apply_permutation_to_statedict(
        permutation_spec, pairwise_permutations[fixed][permutee], models[permutee].model.state_dict()
    )
    models_permuted_pairwise[fixed][permutee].model.load_state_dict(permuted_params)

### Check performance of models before and after permutation

In [None]:
# For each symbol, evaluate the universe-permuted copy first and then the
# original model on the same test loader, so the two results can be compared.
for symbol, permuted_model in models_permuted_to_universe.items():
    original_model = models[symbol]
    trainer.test(permuted_model, test_loader)
    trainer.test(original_model, test_loader)

## Analyze models as vectors

### Flatten models

In [None]:
def _flatten(module):
    """Flatten all parameters of ``module`` into a single vector."""
    return torch.nn.utils.parameters_to_vector(module.parameters())


# symbol -> flattened original model
flat_models = {symbol: _flatten(model) for symbol, model in models.items()}

# symbol -> flattened model permuted to the universe
flat_models_permuted_to_universe = {
    symbol: _flatten(model) for symbol, model in models_permuted_to_universe.items()
}

# symbol -> other symbol -> flattened pairwise model
flat_models_permuted_pairwise = {
    symbol: {other_symb: _flatten(pairwise_model) for other_symb, pairwise_model in pairwise_models.items()}
    for symbol, pairwise_models in models_permuted_pairwise.items()
}

In [None]:
from ccmm.utils.utils import to_np

# Per model (in `models` order): norm of the original weights, norm of the
# universe-permuted weights, and norm of their difference.
norms = {"model": [], "permuted": [], "diff": []}
for symbol in models:
    original_vec = flat_models[symbol]
    universe_vec = flat_models_permuted_to_universe[symbol]

    norms["model"].append(to_np(torch.norm(original_vec)))
    norms["permuted"].append(to_np(torch.norm(universe_vec)))
    norms["diff"].append(to_np(torch.norm(original_vec - universe_vec)))

In [None]:
import pandas as pd

# One row per model symbol, one column per norm kind; values coerced to numeric.
df = pd.DataFrame(norms, index=list(models)).apply(pd.to_numeric)

# Annotated heatmap of the table.
fig, ax = plt.subplots(figsize=(5, 5))
sns.heatmap(df, annot=True, cmap="viridis", ax=ax)
ax.set_title("Model Norms Comparison")
ax.set_ylabel("Model Symbol")
plt.show()

## Compute gradients of the merged (universe) model and of model a

In [None]:
# Independent GPU copies: model `a` permuted to the universe, the original
# model `a`, and a second copy of model `a` that will receive the merged weights.
model_a_universe, model_a, merged_model = (
    copy.deepcopy(source_model).cuda()
    for source_model in (models_permuted_to_universe["a"], models["a"], models["a"])
)

In [None]:
# model_a_parameters = torch.nn.utils.parameters_to_vector(model_a.parameters())
# print(model_a_parameters)
# model_a_univ_parameters = torch.nn.utils.parameters_to_vector(model_a_universe.parameters())
# print(model_a_univ_parameters)

In [None]:
# Display the universe-permuted flat models stacked along a new first dimension.
torch.stack(list(flat_models_permuted_to_universe.values()))

In [None]:
# Display the element-wise mean of the universe-permuted flat models.
torch.stack(list(flat_models_permuted_to_universe.values())).mean(dim=0)

In [None]:
# Merge = element-wise mean of all models after permuting them to the universe.
stacked_universe_models = torch.stack(list(flat_models_permuted_to_universe.values()))
merged_model_params = stacked_universe_models.mean(dim=0)

# Convert the flat vector back into a state dict for `merged_model` and load it.
merged_model_universe = vector_to_state_dict(merged_model_params, merged_model)
merged_model.load_state_dict(merged_model_universe)

In [None]:
# Flatten both models, then print the merged one followed by model `a`
# for a visual comparison.
parameters = torch.nn.utils.parameters_to_vector(merged_model.parameters())
model_a_parameters = torch.nn.utils.parameters_to_vector(model_a.parameters())

print(parameters)
print(model_a_parameters)

In [None]:
def _mean_loss_grad_norm(module, num_samples):
    """Return the L2 norm of the gradient accumulated in ``module``, divided by ``num_samples``.

    Parameters whose ``.grad`` is ``None`` (they received no gradient) are skipped.
    """
    squared_norm = sum(p.grad.norm().item() ** 2 for p in module.parameters() if p.grad is not None)
    # max(..., 1) guards against an empty loader (no samples seen).
    return math.sqrt(squared_norm) / max(num_samples, 1)


loss_fn = torch.nn.CrossEntropyLoss()

# Follow the device the models already live on instead of hard-coding "cuda".
device = next(merged_model.parameters()).device

merged_model.zero_grad()
model_a.zero_grad()

# NOTE(review): the models are used in whatever train/eval mode they are
# currently in; confirm whether `.eval()` is wanted before measuring gradients.
num_samples = 0

for batch in tqdm(test_loader):
    x, y = batch

    x = x.to(device)
    y = y.to(device)

    output_merged = merged_model(x)
    loss_universe = loss_fn(output_merged, y)

    output = model_a(x)

    # Per-batch check of whether the merged model reproduces model a's outputs.
    print(torch.allclose(output_merged, output, atol=1e-4))
    loss = loss_fn(output, y)

    # CrossEntropyLoss averages over the batch by default. Weighting each batch
    # loss by its size makes the accumulated gradient the gradient of the SUMMED
    # loss; dividing by the sample count afterwards gives the gradient of the
    # mean loss over the whole loader, independent of batch size and count.
    batch_size = y.shape[0]
    num_samples += batch_size

    (loss_universe * batch_size).backward()
    (loss * batch_size).backward()

# True L2 norms (square root of the summed squared per-parameter norms).
gradient_norm_merged = _mean_loss_grad_norm(merged_model, num_samples)
gradient_norm = _mean_loss_grad_norm(model_a, num_samples)

In [None]:
# Report the gradient statistic of the merged model, then that of model `a`.
pylogger.info(f"Gradient norm merged: {gradient_norm_merged}")
pylogger.info(f"Gradient norm: {gradient_norm}")