## Imports

In [None]:
import copy

import itertools
import logging
import math
from functools import partial
from pathlib import Path
from typing import Dict
import json

import hydra
import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import pytorch_lightning
import seaborn as sns
import torch  # noqa
import wandb
from hydra.utils import instantiate
from matplotlib import tri
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from omegaconf import DictConfig
from pytorch_lightning import LightningModule
from scipy.stats import qmc
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler
from tqdm import tqdm

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import seed_index_everything
from nn_core.model_logging import NNLogger
from ccmm.utils.perm_graph import get_perm_dict

import ccmm  # noqa
from ccmm.matching.utils import (
    apply_permutation_to_statedict,
    get_all_symbols_combinations,
    load_permutations,
    perm_indices_to_perm_matrix,
    plot_permutation_history_animation,
    restore_original_weights,
)
from ccmm.utils.utils import (
    fuse_batch_norm_into_conv,
    get_interpolated_loss_acc_curves,
    l2_norm_models,
    linear_interpolate,
    load_model_from_info,
    map_model_seed_to_symbol,
    normalize_unit_norm,
    project_onto,
    save_factored_permutations,
    vector_to_state_dict,
)

In [None]:
# Global plotting style applied to matplotlib/seaborn for the rest of the notebook.
plt.rcParams.update(
    {
        "text.usetex": True,  # have matplotlib typeset text through LaTeX
        "font.family": "serif",
    }
)
sns.set_context("talk")

# Colormap name kept as a notebook-level constant; not consumed in the visible cells.
cmap_name = "coolwarm_r"

from ccmm.utils.plot import Palette

# Project palette built from a JSON file located under the repository root.
palette = Palette(f"{PROJECT_ROOT}/misc/palette2.json")
palette  # bare expression: shown as the cell output

In [None]:
# Raise the threshold of the chatty framework loggers to WARNING.
for _noisy_logger_name in (
    "lightning.pytorch",
    "torch",
    "pytorch_lightning.accelerators.cuda",
):
    logging.getLogger(_noisy_logger_name).setLevel(logging.WARNING)

# Logger used by the notebook's own cells.
pylogger = logging.getLogger(__name__)

## Configuration

In [None]:
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose
from typing import Dict, List

# Clear Hydra's global instance before initializing, so this cell can be re-run in the same kernel.
hydra.core.global_hydra.GlobalHydra.instance().clear()
# The config path is given relative to this notebook; the literal is already a str.
initialize(version_base=None, config_path="../conf", job_name="dev")

In [None]:
# Architecture under study; also used to pick the dataset and to name output files.
model_name = "vit"

# Only the MLP is paired with MNIST; every other architecture falls back to CIFAR-10.
dataset = {"mlp": "mnist"}.get(model_name, "cifar10")

In [None]:
# Compose the "matching" Hydra config, overriding the `model` and `dataset` groups with the selected names.
cfg = compose(config_name="matching", overrides=[f"model={model_name}", f"dataset={dataset}"])

In [None]:
# Keep the full composed config as `core_cfg`; from here on `cfg` is only its `matching` section.
# NOTE(review): re-running this cell without re-running the compose cell applies `.matching`
# to the already-narrowed config.
core_cfg = cfg  # NOQA
cfg = cfg.matching

# presumably seeds the RNGs from the config — confirm against nn_core.common.utils
seed_index_everything(cfg)

## Hyperparameters

In [None]:
# Number of leading examples kept from the test and train datasets when the Subsets are built.
num_test_samples = 5000
num_train_samples = 5000

## Load dataset

In [None]:
# The *test* transform is instantiated once and passed to both splits.
# NOTE(review): presumably intentional so the training subset is evaluated without
# train-time augmentation — confirm.
transform = instantiate(core_cfg.dataset.test.transform)

train_dataset = instantiate(core_cfg.dataset.train, transform=transform)
test_dataset = instantiate(core_cfg.dataset.test, transform=transform)

# Keep only the first `num_train_samples` training examples, loaded in batches of up to 5000.
train_subset = Subset(train_dataset, list(range(num_train_samples)))
train_loader = DataLoader(train_subset, batch_size=5000, num_workers=cfg.num_workers)

# Keep only the first `num_test_samples` test examples, loaded in batches of up to 1000.
test_subset = Subset(test_dataset, list(range(num_test_samples)))

test_loader = DataLoader(test_subset, batch_size=1000, num_workers=cfg.num_workers)

In [None]:
# Build the trainer from the matching config, with the progress bar and model summary turned off.
trainer = instantiate(cfg.trainer, enable_progress_bar=False, enable_model_summary=False)

## Load models

In [None]:
from ccmm.utils.utils import load_model_from_artifact

# W&B run handle; it is passed to `load_model_from_artifact` below.
run = wandb.init(project=core_cfg.core.project_name, entity=core_cfg.core.entity, job_type="matching")

# symbol -> seed, e.g. {a: 1, b: 2, c: 3, ..}
symbols_to_seed: Dict[str, int] = {map_model_seed_to_symbol(seed): seed for seed in cfg.model_seeds}


def artifact_path(seed):
    """Return the artifact reference ``<entity>/<project>/<dataset>_<model>_<seed>:latest`` for ``seed``."""
    return (
        f"{core_cfg.core.entity}/{core_cfg.core.project_name}/"
        f"{core_cfg.dataset.name}_{core_cfg.model.model_identifier}_{seed}:latest"
    )


# symbol -> model, e.g. {a: model_a, b: model_b, c: model_c, ..}
# Built from `symbols_to_seed` so both dicts share exactly the same keys and order.
models: Dict[str, LightningModule] = {
    symbol: load_model_from_artifact(run, artifact_path(seed)) for symbol, seed in symbols_to_seed.items()
}
# Deep copies of the state dicts as loaded, taken before any later cell can modify the weights.
model_orig_weights = {symbol: copy.deepcopy(model.model.state_dict()) for symbol, model in models.items()}

num_models = len(models)

pylogger.info(f"Using {num_models} models with architecture {core_cfg.model.model_identifier}")

### Load permutation specification

In [None]:
# Hand-written permutation specification for the configured architecture.
permutation_spec_builder = instantiate(core_cfg.model.permutation_spec_builder)
permutation_spec = permutation_spec_builder.create_permutation_spec()

# First loaded model (dict insertion order); used as the reference for graph extraction below.
ref_model = list(models.values())[0]
# assert set(permutation_spec.layer_and_axes_to_perm.keys()) == set(ref_model.model.state_dict().keys())

### Git Re-Basin

In [None]:
# matcher = instantiate(cfg.matcher, permutation_spec=permutation_spec)
# pylogger.info(f"Matcher: {matcher.name}")

In [None]:
# always permute the model having larger character order, i.e. c -> b, b -> a and so on ...
from ccmm.matching.utils import get_inverse_permutations

symbols = set(symbols_to_seed)
sorted_symbols = sorted(symbols)
fixed, permutee = "a", "b"
fixed_model, permutee_model = models[fixed], models[permutee]

# dicts for permutations and permuted params, D[a][b] refers to the permutation/params to map b -> a
# `symbols - {symb}` removes the symbol itself. `symbols.difference(symb)` would treat the
# string as an iterable of characters and is only correct for one-character symbols.
permutations = {symb: {other_symb: None for other_symb in symbols - {symb}} for symb in symbols}

# matcher = instantiate(cfg.matcher, permutation_spec=permutation_spec)
# permutations[fixed][permutee], perm_history = matcher(
#     fixed=fixed_model.model, permutee=permutee_model.model
# )

# permutations[permutee][fixed] = get_inverse_permutations(permutations[fixed][permutee])

## Get perm dict

In [None]:
# import graphviz
import copy
import functools
import warnings
from collections import defaultdict

import graphviz
from numpy import arange, argmax, unique
from torchviz import make_dot

In [None]:
# Random single-sample input used to trace the model: 1x1x28x28 for MNIST, 1x3x32x32 otherwise.
x = torch.randn(1, 1, 28, 28) if dataset == "mnist" else torch.randn(1, 3, 32, 32)

### Backpropagation graph

In [None]:
# Device of the model's first parameter, read before the model is moved.
# NOTE(review): `device` is not used again in the visible cells, so the model stays on the CPU
# unless a later cell moves it back.
device = next(ref_model.parameters()).device

ref_model.to("cpu")
# NOTE(review): `input` shadows the Python builtin for the rest of the notebook session.
input = x.to("cpu")

# No `no_grad` context here: torchviz draws the autograd graph attached to the output.
model_output = ref_model(input)

dot = make_dot(model_output, params=dict(ref_model.named_parameters()))

# Render the graph into ./graphviz and open it in the system viewer.
dot.view(directory="graphviz", filename=f"{model_name}_model_graph")

In [None]:
from ccmm.utils.perm_graph import TorchVizPermutationGraph


# Populate the project's permutation-graph wrapper from the torchviz dot graph built above.
g = TorchVizPermutationGraph()
g.from_dot(dot)

In [None]:
# For every named parameter, look up its graph id and record the result of `closer_perm` for it.
param_to_node_id_map = {
    param_name: g.closer_perm(g.paramid(param_name)) for param_name, _ in ref_model.named_parameters()
}

# Flat list of the looked-up values (one per parameter, in parameter order) and an empty visited set.
permutation_list = list(param_to_node_id_map.values())
visited = set()

In [None]:
# Display the parameter-name -> node mapping as the cell output.
param_to_node_id_map

### Permutation graph

In [None]:
# NOTE(review): `solve_graph` is imported but not used in the visible cells.
from ccmm.utils.perm_graph import build_permutation_graph, solve_graph


# Build the permutation graph for the reference model from the dummy input `x`.
perm_graph, parameter_map = build_permutation_graph(ref_model, x)

In [None]:
# View the permutation graph, written under ./graphviz with a model-specific file name.
perm_graph.view(directory="graphviz", filename=f"{model_name}_perm_graph")

In [None]:
# Display the second value returned by `build_permutation_graph` as the cell output.
parameter_map

In [None]:
# Same view call as the earlier cell but with no explicit `filename`.
# NOTE(review): looks like a leftover duplicate of the named view above — confirm it is still needed.
perm_graph.view(directory="graphviz")

In [None]:
from ccmm.utils.perm_graph import perm_graph_to_perm_dict

# Derive the permutation dict from the graph.
# NOTE(review): `perm_dict` is rebound by the later `get_perm_dict(ref_model, x)` cell.
perm_dict = perm_graph_to_perm_dict(perm_graph)

In [None]:
# Display the permutation dict as the cell output.
perm_dict

In [None]:
# One-call variant that takes the model and dummy input directly; rebinds `perm_dict` and also
# returns two parameter-level maps that are passed to `graph_permutations_to_perm_spec` below.
perm_dict, param_to_perm_map, param_to_prev_perm_map = get_perm_dict(ref_model, x)

In [None]:
# Display the second value returned by `get_perm_dict` as the cell output.
param_to_perm_map

In [None]:
# Display the third value returned by `get_perm_dict` as the cell output.
param_to_prev_perm_map

In [None]:
# perm_dict, map_param_index, map_prev_param_index = get_perm_dict(ref_model, input=x)

In [None]:
from ccmm.utils.perm_graph import graph_permutations_to_perm_spec

# Build a permutation spec from the graph-derived permutations, to compare against the
# hand-written `permutation_spec`.
perm_spec_from_graph = graph_permutations_to_perm_spec(ref_model, perm_dict, param_to_perm_map, param_to_prev_perm_map)

In [None]:
# NOTE: despite the variable names, `perm_spec_names` holds the keys of the *graph-derived* spec
# and `graph_perm_spec_names` holds the keys of the hand-written `permutation_spec`.
perm_spec_names = list(perm_spec_from_graph.perm_to_layers_and_axes.keys())
graph_perm_spec_names = list(permutation_spec.perm_to_layers_and_axes.keys())

# Pair the two key lists positionally (graph-derived name -> hand-written name).
# `zip` stops at the shorter list, so surplus names on either side are dropped silently.
perm_name_mapping = dict(zip(perm_spec_names, graph_perm_spec_names))

In [None]:
# for perm_spec_name, graph_perm_spec_name in perm_name_mapping.items():
#     assert sorted(perm_spec_from_graph.perm_to_layers_and_axes[perm_spec_name]) == sorted(permutation_spec.perm_to_layers_and_axes[graph_perm_spec_name])

In [None]:
# assert len(permutation_spec.perm_to_layers_and_axes) == len(perm_spec_from_graph.perm_to_layers_and_axes)

In [None]:
# Display the layers/axes stored under key 2 of the graph-derived spec.
perm_spec_from_graph.perm_to_layers_and_axes[2]

In [None]:
# Display the layers/axes of "P_in" in the hand-written spec, for visual comparison with the graph-derived spec.
permutation_spec.perm_to_layers_and_axes["P_in"]

In [None]:
# Display the graph-derived entry stored under key `len - 1`.
# NOTE(review): assumes the keys are consecutive integers starting at 0 — confirm.
perm_spec_from_graph.perm_to_layers_and_axes[len(perm_spec_from_graph.perm_to_layers_and_axes) - 1]

In [None]:
# Display the layers/axes of "P_last" in the hand-written spec, for visual comparison with the graph-derived spec.
permutation_spec.perm_to_layers_and_axes["P_last"]