## Imports

In [None]:
import copy
import itertools
import logging
import math
from functools import partial
from pathlib import Path
from typing import Dict
import json
import timm

import hydra
import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import pytorch_lightning
import seaborn as sns
import torch  # noqa
import wandb
from hydra.utils import instantiate
from matplotlib import tri
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from omegaconf import DictConfig
from pytorch_lightning import LightningModule
from scipy.stats import qmc
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler
from tqdm import tqdm

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import seed_index_everything
from nn_core.model_logging import NNLogger

import ccmm  # noqa
from ccmm.matching.utils import (
    apply_permutation_to_statedict,
    get_all_symbols_combinations,
    load_permutations,
    perm_indices_to_perm_matrix,
    plot_permutation_history_animation,
    restore_original_weights,
)
from ccmm.utils.utils import (
    fuse_batch_norm_into_conv,
    get_interpolated_loss_acc_curves,
    l2_norm_models,
    linear_interpolate,
    load_model_from_info,
    map_model_seed_to_symbol,
    normalize_unit_norm,
    project_onto,
    save_factored_permutations,
    vector_to_state_dict,
)

In [None]:
# LaTeX-rendered serif text for every figure, seaborn "talk" sizing.
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "serif"
sns.set_context("talk")

# Colormap name (not referenced in the cells shown here).
cmap_name = "coolwarm_r"

from ccmm.utils.plot import Palette

palette = Palette(f"{PROJECT_ROOT}/misc/palette2.json")
# Bare expression: lets the notebook render the loaded palette.
palette

In [None]:
# Keep only WARNING-and-above from these noisy Lightning/torch loggers.
for _quiet_name in ("lightning.pytorch", "torch", "pytorch_lightning.accelerators.cuda"):
    logging.getLogger(_quiet_name).setLevel(logging.WARNING)
del _quiet_name

# Logger for this notebook's own messages.
pylogger = logging.getLogger(__name__)

## Heatmap

In [None]:
# Load the pairwise merging results. The cell below reads them as
# data[model_a][model_b][0]; the exact schema is assumed, not checked here —
# confirm against the script that writes results.json.
# A context manager closes the file handle instead of leaking it.
with open(f"{PROJECT_ROOT}/results.json", encoding="utf-8") as results_file:
    data = json.load(results_file)

In [None]:
# Build a symmetric model-by-model matrix of merging accuracies. A pair may be
# stored under either key order in `data`, so both orders are looked up; the
# first element of each stored entry is taken as the accuracy (assumed — the
# remaining elements are not used here). Pairs absent in both orders stay 0.
models = list(data.keys())
n = len(models)

symmetric_matrix = np.zeros((n, n))

for i, model_a in enumerate(models):
    for j, model_b in enumerate(models):
        if model_b in data[model_a]:
            symmetric_matrix[i, j] = data[model_a][model_b][0]
        elif model_a in data[model_b]:
            symmetric_matrix[i, j] = data[model_b][model_a][0]

plt.figure(figsize=(15, 12))

cmap = sns.light_palette("seagreen", as_cmap=True)

# seaborn hides every cell where the mask is non-zero. np.tril(..., k=0) covers
# the diagonal and everything below it, so only the strict upper triangle is
# drawn (the lower triangle mirrors it when each pair is stored once).
mask = np.tril(np.ones_like(symmetric_matrix), k=0)

sns.heatmap(symmetric_matrix, xticklabels=models, yticklabels=models, annot=True, fmt=".2f", cmap=cmap, mask=mask)

plt.title("Merging Accuracies")
plt.xlabel("Models")
plt.ylabel("Models")

# savefig does not create missing directories, so make sure the target exists.
heatmap_path = Path(f"{PROJECT_ROOT}/notebooks/plots/resnet50_merging_accuracies_heatmap.pdf")
heatmap_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(heatmap_path)
plt.show()

In [None]:
from itertools import combinations

threshold = 0.440

# Keep every triple of models whose three pairwise merging accuracies all
# exceed the threshold (indices refer to rows/columns of symmetric_matrix).
valid_triplets = [
    tuple(models[idx] for idx in triple)
    for triple in combinations(range(n), 3)
    if all(symmetric_matrix[p, q] > threshold for p, q in combinations(triple, 2))
]

valid_triplets

## Configuration

In [None]:
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose
from typing import Dict, List

hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path=str("../conf"), job_name="matching")

In [None]:
# Compose the "matching" config with the ResNet-50 model and tiny_imagenet
# dataset overrides.
# NOTE(review): the images evaluated below are read from data/imagenet-mini with
# num_classes = 1000, not from this dataset config — confirm the override is intended.
cfg = compose(config_name="matching", overrides=["model=resnet50", "dataset=tiny_imagenet"])

In [None]:
# Keep the full config (core.project_name / core.entity / model.* are read from
# it later) and narrow `cfg` to its `matching` section (trainer settings).
core_cfg = cfg  # NOQA
cfg = cfg.matching

# Seeding helper from nn_core, driven by the matching config (presumably seeds
# all RNGs — see nn_core.common.utils).
seed_index_everything(cfg)

## Hyperparameters

In [None]:
# Sample budgets. NOTE(review): neither name is used in the cells shown here;
# the full dataset is passed to every loader.
num_test_samples = num_train_samples = 5000

## Load dataset

In [None]:
# Root directory of the ImageNet-mini image tree.
data_path = PROJECT_ROOT / "data" / "imagenet-mini"

In [None]:
import os

# Collect every *.JPEG file under data_path. Its label is the name of the
# directory that directly contains it (class-per-folder layout).
# os.walk order depends on the filesystem, so directories and files are sorted
# to give the same sample order (and hence the same unshuffled batches) everywhere.
paths = []
labels = []
for dirname, subdirs, filenames in os.walk(data_path):
    subdirs.sort()  # in-place sort makes the top-down walk descend in sorted order
    for filename in sorted(filenames):
        if filename.endswith("JPEG"):
            paths.append(os.path.join(dirname, filename))
            # basename rather than split("/") so Windows separators work too
            labels.append(os.path.basename(dirname))

In [None]:
# Alphabetically ordered class names, mapped to contiguous integer ids and back.
class_names = sorted(set(labels))
N = list(range(len(class_names)))
normal_mapping = {name: idx for idx, name in enumerate(class_names)}
reverse_mapping = dict(enumerate(class_names))

In [None]:
import pandas as pd

# One row per image: its file path and integer class id (via normal_mapping).
df = pd.DataFrame({"path": paths, "label": labels})
df["label"] = df["label"].map(normal_mapping)

In [None]:
from PIL import Image
from torch.utils.data import Dataset

import torchvision.transforms as transforms


class CustomDataset(Dataset):
    """Map-style dataset over a DataFrame of image paths and labels.

    Each item is ``(image, label)``:

    - ``image`` is the file opened as RGB, resized to 224x224 and converted to a
      tensor. It is then normalized per channel with the given mean/std (the
      usual ImageNet statistics).
    - ``label`` is the row's ``"label"`` value.

    Args:
        dataframe: table with a ``"path"`` column (image file paths) and a
            ``"label"`` column. Rows are addressed by position, so the frame's
            index does not need to be a default RangeIndex.
    """

    def __init__(self, dataframe):
        self.dataframe = dataframe
        # The pipeline is identical for every sample, so build it once here
        # instead of on every __getitem__ call.
        self.transform = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, index):
        # DataLoader/Subset pass positional indices in [0, len); .iloc honours
        # that even when the frame's index labels differ from positions.
        path = self.dataframe["path"].iloc[index]
        label = self.dataframe["label"].iloc[index]
        # Image.open is lazy and keeps the file handle open; the context manager
        # releases it once convert() has loaded the pixels into a new image.
        with Image.open(path) as img:
            image = img.convert("RGB")
        image = self.transform(image)
        return image, label

In [None]:
# Wrap the whole image table. The same loader is used both for matching/repair
# and for every trainer.test() call below.
# NOTE(review): evaluation therefore runs on the data used for matching — the
# reported numbers are not held-out test accuracies.
train_dataset = CustomDataset(df)

In [None]:
# No shuffle argument, so batches follow the DataFrame row order.
train_loader = DataLoader(train_dataset, batch_size=100, num_workers=8)

## Load models

In [None]:
# Start a W&B run for this matching session; project and entity come from the
# core config. NOTE(review): `run` is not used in the cells shown here.
run = wandb.init(project=core_cfg.core.project_name, entity=core_cfg.core.entity, job_type="matching")

In [None]:
# Symbol -> model mapping: {"a": model_a, "b": model_b, "c": model_c, ...}

from ccmm.models.resnet50 import PretrainedResNet50
from ccmm.pl_modules.pl_module import MyLightningModule
from torchvision.models import resnet50, ResNet50_Weights  # NOTE(review): unused in the cells shown here

num_classes = 1000

model_seeds = [1, 2, 3]
# Keys are model symbols, values are the seeds they were derived from.
symbols_to_seed: Dict[str, int] = {map_model_seed_to_symbol(seed): seed for seed in model_seeds}

# Three ResNet-50 models differing only in the `weights` tag passed to
# PretrainedResNet50. Construction order (a, b, c) matches the original.
# NOTE(review): the symbols are hard-coded and assumed to match those produced by
# map_model_seed_to_symbol for model_seeds — confirm.
# NOTE(review): this rebinds `models`, which the heatmap/triplet cells used as the
# list of model names; re-running those cells after this one will break them.
weights_per_symbol = {"a": "c1", "b": "a1", "c": "ram"}
models: Dict[str, LightningModule] = {
    symbol: MyLightningModule(
        model=PretrainedResNet50(num_classes=num_classes, weights=weights),
        num_classes=num_classes,
    )
    for symbol, weights in weights_per_symbol.items()
}
# Deep-copied snapshot of each model's original weights, so later modifications
# of the models cannot alter it.
model_orig_weights = {symbol: copy.deepcopy(model.model.state_dict()) for symbol, model in models.items()}

num_models = len(models)

In [None]:
# Model symbols, as a set and in ascending order.
symbols = set(symbols_to_seed)
sorted_symbols = sorted(symbols)
# Reference ("fixed") model and the model to be permuted (reading from the names).
fixed_symbol = "a"
permutee_symbol = "b"

## Load permutation specification

In [None]:
from ccmm.utils.perm_graph import get_perm_dict

# NOTE(review): neither `get_perm_dict` nor `x` is used in the cells shown here;
# `x` is also 256x256 while CustomDataset resizes images to 224x224.
x = torch.randn(1, 3, 256, 256)
# Work on a deep copy so building the spec cannot modify the fixed model itself.
ref_model = copy.deepcopy(models[fixed_symbol])

# Build the permutation spec for this architecture, using the builder named in
# the model config.
permutation_spec_builder = instantiate(core_cfg.model.permutation_spec_builder)
permutation_spec = permutation_spec_builder.create_permutation_spec(ref_model)

In [None]:
# Parameter names (state_dict keys) covered by the permutation spec.
permutation_spec.layer_and_axes_to_perm.keys()

In [None]:
from ccmm.utils.utils import get_model


# Unwrap the network from its Lightning wrapper (presumably ref_model.model —
# see get_model) so its state_dict keys are comparable with the spec's keys.
inner_ref_model = get_model(ref_model)

In [None]:
# Spec keys with no matching entry in the model's state_dict. A non-empty result
# points to a spec/architecture mismatch.
set(permutation_spec.layer_and_axes_to_perm.keys()).difference(set(inner_ref_model.state_dict().keys()))

In [None]:
# state_dict entries the spec does not mention (TODO confirm these are meant to
# be left unpermuted).
set(inner_ref_model.state_dict().keys()).difference(set(permutation_spec.layer_and_axes_to_perm.keys()))

## Test endpoint models

In [None]:
# Trainer built from the matching config, with progress bar and model summary
# turned off to keep notebook output short.
trainer = instantiate(cfg.trainer, enable_progress_bar=False, enable_model_summary=False)

In [None]:
# Baseline: test metrics of each unmerged endpoint model on train_loader.
trainer.test(models["a"], train_loader)
trainer.test(models["b"], train_loader)
trainer.test(models["c"], train_loader)

## MergeMany: Frank-Wolfe synchronized matching

In [None]:
# Convention: always permute the model whose symbol sorts later onto the one
# that sorts earlier (c -> b, b -> a, ...).
symbols = set(symbols_to_seed)
sorted_symbols = sorted(symbols)

# Symbol pairs from the helper. The filter below implies both orientations may
# be present, e.g. (a, b) and (b, a).
all_combinations = get_all_symbols_combinations(symbols)
# Keep one orientation per pair: (a, b), (a, c), (b, c), ... but not (b, a), (c, a).
canonical_combinations = [
    (first, second)
    for first, second in all_combinations
    if first < second
]

In [None]:
from ccmm.matching.merger import FrankWolfeSynchronizedMerger


# Synchronized matching of all models with Frank-Wolfe. Per the constructor
# arguments: Sinkhorn initialization, max_iter=100.
fw_merger = FrankWolfeSynchronizedMerger(
    name="frank_wolfe_sync", permutation_spec=permutation_spec, initialization_method="sinkhorn", max_iter=100
)

In [None]:
# Returns the merged model without and with repair, plus `models_to_univ`.
# `models_to_univ` is indexed by model symbol; presumably each model is mapped
# into a shared "universe" permutation space — confirm in
# FrankWolfeSynchronizedMerger. train_loader is presumably consumed by the repair step.
fw_non_repaired, fw_repaired, models_to_univ = fw_merger(models, train_loader=train_loader)

In [None]:
# Fresh trainer with the same arguments as before.
# NOTE(review): presumably re-created to reset trainer state between runs;
# reusing the previous instance may be equivalent — confirm.
trainer = instantiate(cfg.trainer, enable_progress_bar=False, enable_model_summary=False)

In [None]:
# Evaluate each model after mapping into the universe space. If that mapping
# only permutes units, these should match the endpoint results (assumption).
trainer.test(models_to_univ["a"], train_loader)
trainer.test(models_to_univ["b"], train_loader)
trainer.test(models_to_univ["c"], train_loader)

In [None]:
# Frank-Wolfe merged model, before repair.
trainer.test(fw_non_repaired, train_loader)

In [None]:
# Frank-Wolfe merged model, after repair.
trainer.test(fw_repaired, train_loader)

## MergeMany Git Re-Basin

In [None]:
from ccmm.matching.merger import GitRebasinMerger


# Git Re-Basin merger over the same permutation spec, for comparison.
mm_merger = GitRebasinMerger(name="git_rebasin", permutation_spec=permutation_spec)

In [None]:
# Merged model without and with repair.
# NOTE(review): unprefixed names, unlike the fw_* / naive_* results elsewhere.
non_repaired, repaired = mm_merger(models, train_loader=train_loader)

In [None]:
trainer = instantiate(cfg.trainer, enable_progress_bar=False, enable_model_summary=False)
# Git Re-Basin merged model, after repair.
trainer.test(repaired, train_loader)

In [None]:
# Git Re-Basin merged model, before repair.
trainer.test(non_repaired, train_loader)

### Naive merging baseline (DummyMerger)

In [None]:
from ccmm.matching.merger import DummyMerger


# Baseline merger named "naive"; presumably merges the models without any
# matching step — confirm in DummyMerger.
merger = DummyMerger(name="naive", permutation_spec=permutation_spec)

In [None]:
naive_non_repaired, naive_repaired = merger(models, train_loader=train_loader)

In [None]:
# Naive merge, before repair.
trainer.test(naive_non_repaired, train_loader)

In [None]:
# Naive merge, after repair.
trainer.test(naive_repaired, train_loader)