## Imports

In [None]:
import copy
import logging
from pathlib import Path
from typing import Dict

import hydra
import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import seaborn as sns
import torch  # noqa
import wandb
from hydra.utils import instantiate
from matplotlib import tri
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from omegaconf import DictConfig
from pytorch_lightning import LightningModule
from scipy.stats import qmc
from torch.utils.data import DataLoader
from tqdm import tqdm

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import seed_index_everything
from nn_core.model_logging import NNLogger


import ccmm  # noqa
from ccmm.matching.utils import (
    apply_permutation_to_statedict,
    get_all_symbols_combinations,
    plot_permutation_history_animation,
    restore_original_weights,
)
from ccmm.utils.utils import (
    linear_interpolation,
    load_model_from_info,
    load_permutations,
    map_model_seed_to_symbol,
    save_factored_permutations,
)

from ccmm.utils.utils import vector_to_state_dict
import pytorch_lightning

In [None]:
# Figure styling: serif fonts, seaborn's "talk" context and LaTeX-rendered text
# (text.usetex needs a working LaTeX installation).
matplotlib.rcParams["font.family"] = "serif"
sns.set_context("talk")
matplotlib.rcParams["text.usetex"] = True

# Only show WARNING and above from the training/inference stack.
for noisy_logger in ("lightning.pytorch", "torch", "pytorch_lightning.accelerators.cuda"):
    logging.getLogger(noisy_logger).setLevel(logging.WARNING)

pylogger = logging.getLogger(__name__)

In [None]:
# IPython magics: reload edited project modules automatically before executing code.
%load_ext autoreload
%autoreload 2

# NOTE: `hydra` and `Dict` were already imported in the first cell; `initialize`,
# `compose` and `List` are new here (`List` is not used in the cells below).
import hydra
from hydra import initialize, compose
from typing import Dict, List

# Clear any Hydra instance left over from a previous execution of this cell
# (Hydra refuses to be initialized twice), then initialize it on the config dir.
# NOTE(review): "../conf" is relative; presumably resolved against the notebook's
# location — confirm when running from elsewhere.
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path=str("../conf"), job_name="matching_n_models")

In [None]:
cfg = compose(config_name="matching_n_models", overrides=[])

In [None]:
core_cfg = cfg  # NOQA
cfg = cfg.matching

seed_index_everything(cfg)

## Hyperparameters

In [None]:
# Presumably the checkpoint epoch (not referenced in the cells below).
epoch: int = 99
# Number of Sobol points at which the loss is evaluated. Sobol balance properties
# need a power of two (e.g. 512 or 2048); SciPy emits a warning for other values.
num_sampled_points: int = 500
# Size of the fixed test-set prefix used for every evaluation.
num_test_samples: int = 1000

## Load models

In [None]:
# symbol -> seed, e.g. {"a": 1, "b": 2, "c": 3, ...}
symbols_to_seed: Dict[str, int] = {map_model_seed_to_symbol(seed): seed for seed in cfg.model_seeds}

# symbol -> model loaded from `cfg.model_info_path`. Built from `symbols_to_seed`
# so the seed -> symbol mapping is computed once and both dicts share their keys.
models: Dict[str, LightningModule] = {
    symbol: load_model_from_info(cfg.model_info_path, seed) for symbol, seed in symbols_to_seed.items()
}

pylogger.info(f"Using model {core_cfg.model.name}")

In [None]:
# Within each pair, the model with the larger symbol is the one that gets permuted
# (c -> b, b -> a, ...).
symbols = set(symbols_to_seed)
sorted_symbols = sorted(symbols)

# Symbol pairs as produced by get_all_symbols_combinations (presumably both
# orientations, e.g. (a, b) and (b, a) — the filter below keeps one per pair).
all_combinations = get_all_symbols_combinations(symbols)
# Canonical orientation only: (a, b), (a, c), (b, c), ... never (b, a), (c, a), ...
canonical_combinations = [(src, tgt) for src, tgt in all_combinations if src < tgt]

pylogger.info(f"Matching the following model pairs: {canonical_combinations}")

### Load permutation specification

In [None]:
# Build the architecture's permutation spec; its `layer_and_axes_to_perm` is keyed by
# state-dict entry names (the check below requires exactly the model's keys).
permutation_spec_builder = instantiate(core_cfg.model.permutation_spec_builder)
permutation_spec = permutation_spec_builder.create_permutation()

# Sanity check with an explicit error (an `assert` is stripped under `python -O` and
# would not say which keys disagree).
ref_model = list(models.values())[0]
spec_keys = set(permutation_spec.layer_and_axes_to_perm.keys())
model_keys = set(ref_model.model.state_dict().keys())
if spec_keys != model_keys:
    raise ValueError(
        "Permutation spec does not match the model state dict. "
        f"Missing from spec: {sorted(model_keys - spec_keys)}; "
        f"not in model: {sorted(spec_keys - model_keys)}"
    )

In [None]:
# Matching algorithm selected by the config, bound to this architecture's spec.
matcher = instantiate(cfg.matcher, permutation_spec=permutation_spec)
pylogger.info(f"Matcher: {matcher.name}")

In [None]:
# Jointly match all models over the canonical pairs. `permutations` is indexed by model
# symbol (presumably each model's permutation to the shared "universe" — confirm);
# `perm_history` is not used in the cells below.
permutations, perm_history = matcher(models, symbols=sorted_symbols, combinations=canonical_combinations)

In [None]:
# Move every model to CPU (nn.Module.to moves the module in place and returns it).
models = {symb: model.to("cpu") for symb, model in models.items()}

In [None]:
# Flatten each model's parameters (not buffers) into a single 1-D vector.
flat_models = {symbol: torch.nn.utils.parameters_to_vector(model.parameters()) for symbol, model in models.items()}

## Load dataset

In [None]:
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler

# NOTE: the test-split transform is applied to both splits (presumably so that the
# two are preprocessed identically for evaluation — confirm).
transform = instantiate(core_cfg.dataset.test.transform)

train_dataset = instantiate(core_cfg.dataset.train, transform=transform)
test_dataset = instantiate(core_cfg.dataset.test, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers)

# Every landscape point is evaluated on the same fixed prefix of the test set. Clamp
# the size so that a test set smaller than `num_test_samples` yields no invalid indices.
test_subset = Subset(test_dataset, list(range(min(num_test_samples, len(test_dataset)))))

test_loader = DataLoader(test_subset, batch_size=1000, num_workers=cfg.num_workers)

### Sample evaluation points in the 2-D parameter plane

In [None]:
# Scrambled Sobol (quasi-random) points over the square [-0.5, 1.5]^2 of plane
# coordinates, seeded from the config for reproducibility.
sobol_sampler = qmc.Sobol(d=2, scramble=True, seed=cfg.seed_index)
unit_square_points = sobol_sampler.random(num_sampled_points)
eval_points = qmc.scale(unit_square_points, [-0.5, -0.5], [1.5, 1.5])

pylogger.info(eval_points[:10])

In [None]:
models_permuted_to_universe = {symbol: copy.deepcopy(model) for symbol, model in models.items()}


for symbol, model in models_permuted_to_universe.items():
    permuted_params = apply_permutation_to_statedict(permutation_spec, permutations[symbol], model.model.state_dict())
    models_permuted_to_universe[symbol].model.load_state_dict(permuted_params)

In [None]:
flat_models_permuted_to_universe = {
    symbol: torch.nn.utils.parameters_to_vector(model.parameters())
    for symbol, model in models_permuted_to_universe.items()
}

In [None]:
from ccmm.utils.utils import unfactor_permutations

# Fresh copies of the original models; only model "b" is re-parameterized below,
# directly towards the reference model "a".
models_permuted_to_ref_model = {symbol: copy.deepcopy(model) for symbol, model in models.items()}

ref_model_symb = "a"
# Presumably turns the per-model permutations into pairwise ones, indexed as
# [target][source] judging by the lookup below — TODO confirm against ccmm.
pairwise_permutations = unfactor_permutations(permutations)

permuted_params = apply_permutation_to_statedict(
    permutation_spec, pairwise_permutations[ref_model_symb]["b"], models_permuted_to_ref_model["b"].model.state_dict()
)
models_permuted_to_ref_model["b"].model.load_state_dict(permuted_params)

In [None]:
# Flatten each (possibly permuted) model's parameters into a single 1-D vector.
flat_models_permuted_to_ref_model = {
    symbol: torch.nn.utils.parameters_to_vector(model.parameters())
    for symbol, model in models_permuted_to_ref_model.items()
}

In [None]:
def proj(a, b):
    """Return the orthogonal projection of 1-D tensor ``a`` onto 1-D tensor ``b``."""
    return torch.dot(a, b) / torch.dot(b, b) * b


def norm(a):
    """Return the Euclidean (L2) norm of 1-D tensor ``a`` as a 0-d tensor."""
    return torch.sqrt(torch.dot(a, a))


def normalize(a):
    """Return ``a`` rescaled to unit L2 norm."""
    return a / norm(a)

In [None]:
# Flat parameter vectors of the two original models.
model_a_flat, model_b_flat = flat_models["a"], flat_models["b"]

# Model B mapped to the shared universe space (not used in the cells below).
model_b_flat_permuted = flat_models_permuted_to_universe["b"]

# Model B permuted directly onto model A via the pairwise permutation.
model_b_flat_permuted_pairwise = flat_models_permuted_to_ref_model["b"]

In [None]:
# Creating basis vectors

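# model_a_flat is the origin
# basis1 is the (normalized) direction from model_a_flat to model_b_flat
# 2 orthonormal basis vectors: one points towards theta_b, the other towards pi(theta_b)
# (after removing the component of pi(theta_b) - theta_a that lies along the first one)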


def get_basis_vectors(origin_model, model_b_flat, model_b_flat_permuted):
    """Build an orthonormal 2-D basis for the plane through three flat parameter vectors.

    The first basis vector points from ``origin_model`` towards ``model_b_flat``. The
    second is the part of ``model_b_flat_permuted - origin_model`` that is orthogonal
    to the first (one Gram-Schmidt step), normalized.

    The norm/projection expressions are the same as the notebook's ``norm``/``proj``
    helpers, written inline so this function does not depend on another cell.

    Args:
        origin_model: 1-D tensor used as the origin of the plane.
        model_b_flat: 1-D tensor defining the first direction.
        model_b_flat_permuted: 1-D tensor defining the second direction.

    Returns:
        ``(basis1_normed, basis2_normed, scale)``. ``scale`` is the distance from the
        origin to ``model_b_flat`` (a 0-d tensor), so ``model_b_flat`` has plane
        coordinates ``(1, 0)`` once coordinates are divided by ``scale``.

    Raises:
        ValueError: if ``model_b_flat`` equals ``origin_model``, or if the orthogonal
            component of ``model_b_flat_permuted`` is exactly zero. In either case the
            basis would otherwise silently contain NaNs.
    """
    # First axis: from the origin towards model_b_flat.
    basis1 = model_b_flat - origin_model
    scale = torch.sqrt(torch.dot(basis1, basis1))
    if scale == 0:
        raise ValueError("model_b_flat coincides with origin_model; the first basis vector is undefined")
    basis1_normed = basis1 / scale

    # a_to_pi_b is the vector from origin_model to pi(theta_b) (model_b_flat_permuted).
    a_to_pi_b = model_b_flat_permuted - origin_model
    # Make the basis orthogonal by discarding the component of a_to_pi_b along basis1.
    basis2 = a_to_pi_b - torch.dot(a_to_pi_b, basis1) / torch.dot(basis1, basis1) * basis1
    basis2_norm = torch.sqrt(torch.dot(basis2, basis2))
    if basis2_norm == 0:
        raise ValueError(
            "model_b_flat_permuted lies on the line through origin_model and model_b_flat; "
            "the second basis vector is undefined"
        )
    basis2_normed = basis2 / basis2_norm

    return basis1_normed, basis2_normed, scale


# Plotting plane: origin at model A, first axis towards B, second axis towards B
# permuted pairwise onto A (arguments in get_basis_vectors' positional order).
basis1_normed, basis2_normed, scale = get_basis_vectors(model_a_flat, model_b_flat, model_b_flat_permuted_pairwise)

In [None]:
def project2d(theta):
    """Return the plane coordinates of flat parameter vector ``theta`` as a NumPy array.

    The plane has origin ``model_a_flat`` and orthonormal axes ``basis1_normed`` and
    ``basis2_normed``. Coordinates are divided by ``scale``; with the basis built from
    models A and B, A maps to (0, 0) and B to (1, 0). The module-level names are looked
    up at call time, as with the previous lambda. The result has shape ``(2,)`` and is
    detached from autograd.
    """
    offset = theta - model_a_flat  # computed once and reused for both axes
    coords = torch.stack([torch.dot(offset, basis1_normed), torch.dot(offset, basis2_normed)]) / scale
    return coords.detach().cpu().numpy()

In [None]:
import math


def get_pentagon_vertices(center_x, center_y, radius, num_vertices=5):
    """Return the vertices of a regular polygon (a pentagon by default).

    Vertices are evenly spaced counter-clockwise, starting at angle 0, so the first
    vertex is ``(center_x + radius, center_y)``.

    Args:
        center_x, center_y: Coordinates of the polygon centre.
        radius: Distance from the centre to every vertex.
        num_vertices: Number of vertices (>= 1); 5 gives the original pentagon.

    Returns:
        ``np.ndarray`` of shape ``(num_vertices, 2)`` with one ``(x, y)`` row per vertex.

    Raises:
        ValueError: if ``num_vertices`` is smaller than 1.
    """
    if num_vertices < 1:
        raise ValueError(f"num_vertices must be >= 1, got {num_vertices}")
    step_deg = 360 / num_vertices  # 72 degrees between consecutive pentagon vertices
    angles_rad = [math.radians(step_deg * i) for i in range(num_vertices)]
    return np.array(
        [(radius * math.cos(angle) + center_x, radius * math.sin(angle) + center_y) for angle in angles_rad]
    )

In [None]:
def represent_barycentric_coordinates(x, vertices=None):
    """Express 2-D point ``x`` in generalized barycentric coordinates w.r.t. polygon vertices.

    Solves for weights ``z`` with ``sum_i z_i * vertices[i] == x`` AND ``sum_i z_i == 1``.
    Without the second (affine) constraint the weights would not be barycentric. With
    more than three vertices the system is underdetermined, and ``np.linalg.lstsq``
    returns the minimum-norm solution.

    Args:
        x: Point of shape ``(2,)``.
        vertices: Array of shape ``(k, 2)``; defaults to ``get_pentagon_vertices(0.5, 0.5, 0.9)``.

    Returns:
        ``np.ndarray`` of shape ``(k,)`` holding the barycentric weights.

    Raises:
        ValueError: if no weights reproduce ``x`` (e.g. all vertices collinear).
    """
    if vertices is None:
        vertices = get_pentagon_vertices(0.5, 0.5, 0.9)
    vertices = np.asarray(vertices, dtype=float)
    x = np.asarray(x, dtype=float)

    # Stack the sum-to-one row under the position constraints vertices^T z == x.
    A = np.vstack([vertices.T, np.ones(len(vertices))])
    b = np.append(x, 1.0)

    z, *_ = np.linalg.lstsq(A, b, rcond=None)

    if not np.allclose(A @ z, b):
        raise ValueError("x cannot be expressed as an affine combination of the given vertices")

    return z

In [None]:
# Scratch value: the first evaluation point (`x` is reassigned in the plotting cell).
x = eval_points[0]

In [None]:
# TODO: also solve for the barycentric coordinates of the models in the
# high-dimensional parameter space, scaling the sum-to-one constraint row by a
# large scalar so that the least-squares solution (approximately) enforces it.

In [None]:
# Lightning trainer used for the `trainer.test` evaluations below; the progress bar
# and model summary are disabled because it is called once per evaluation point.
trainer = instantiate(cfg.trainer, enable_progress_bar=False, enable_model_summary=False)

In [None]:
def eval_one(xy, model_flat, basis1_normed, basis2_normed, scale, model, trainer, test_loader):
    """Test the model whose parameters sit at plane coordinates ``xy``.

    The parameter vector is ``model_flat`` shifted by ``scale`` times the combination
    ``xy[0] * basis1_normed + xy[1] * basis2_normed``. It is loaded into ``model.model``
    IN PLACE, overwriting the passed model's weights, and the model is then evaluated
    with ``trainer.test`` on ``test_loader``.

    Returns:
        The list returned by ``trainer.test`` (one metrics dict per test dataloader).
    """
    # Operand order kept as `tensor * scalar`, matching the original expression.
    direction = basis1_normed * xy[0] + basis2_normed * xy[1]
    point_params = model_flat + scale * direction
    model.model.load_state_dict(vector_to_state_dict(point_params, model.model))
    return trainer.test(model, test_loader, verbose=False)


# Evaluate on a private copy of model A, the origin of the plane. `eval_one` overwrites
# the given model's weights in place. The previous code used whatever `model` an
# earlier loop had left bound: an arbitrary model, whose dict entry was then corrupted.
# NOTE(review): presumably only the flat *parameters* are overwritten, so any
# non-parameter buffers (e.g. normalization running statistics, if the architecture has
# them) keep model A's values — confirm this is the intended choice.
eval_model = copy.deepcopy(models["a"])

eval_results = np.array(
    [
        eval_one(xy, model_a_flat, basis1_normed, basis2_normed, scale, eval_model, trainer, test_loader)
        for xy in tqdm(eval_points)
    ]
)

In [None]:
# Test loss at every evaluation point: each result holds one metrics dict per test
# dataloader, and a single test dataloader is used here.
test_losses = np.array([metrics_per_loader[0]["loss/test"] for metrics_per_loader in eval_results])

In [None]:
# Regular grid over the sampled square [-0.5, 1.5]^2 (np.linspace defaults to 50 points).
xi = np.linspace(-0.5, 1.5)
yi = np.linspace(-0.5, 1.5)

# Linearly interpolate the scattered losses at `eval_points` onto the (xi, yi) grid,
# using a Delaunay triangulation of the sample points.
triang = tri.Triangulation(eval_points[:, 0], eval_points[:, 1])
# Cap the loss at 0.55 so that a few very large losses do not saturate the contouring.
interpolator = tri.LinearTriInterpolator(triang, np.clip(test_losses, None, 0.55))

# Disabled alternative (log-scaled losses; uses JAX's `jnp`, which is not imported here):
# interpolator = tri.LinearTriInterpolator(triang, jnp.log(jnp.minimum(1.5, eval_results[:, 0])))
# Grid nodes outside the triangulation come back masked.
zi = interpolator(*np.meshgrid(xi, yi))

## Plot

In [None]:
plt.figure()
# Number of contour levels, shared by the line and the filled contour plots.
num_levels = 13
# Thin, semi-transparent grey iso-loss lines.
plt.contour(xi, yi, zi, levels=num_levels, linewidths=0.25, colors="grey", alpha=0.5)
# Alternative colormaps (disabled):
# cmap_name = "RdGy"
# cmap_name = "RdYlBu"
# cmap_name = "Spectral"
cmap_name = "coolwarm_r"

# cmap_name = "YlOrBr_r"
# cmap_name = "RdBu"
# See https://stackoverflow.com/a/18926541/3880977
def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """Return a colormap built from ``n`` evenly spaced samples of ``cmap`` on [minval, maxval].

    Bounds outside [0, 1] are passed straight to ``cmap``, which maps out-of-range values
    to its under/over colors.
    """
    sample_positions = np.linspace(minval, maxval, n)
    sampled_colors = cmap(sample_positions)
    cmap_label = f"trunc({cmap.name},{minval:.2f},{maxval:.2f})"
    return colors.LinearSegmentedColormap.from_list(cmap_label, sampled_colors)


# maxval=1.5 lies beyond the colormap's [0, 1] domain: the top third of the samples
# map to the colormap's "over" color (its last color by default), forming a flat band.
cmap = truncate_colormap(plt.get_cmap(cmap_name), 0.0, 1.5)  # 0.9)
plt.contourf(xi, yi, zi, levels=num_levels, cmap=cmap, extend="both")

# White crosses at the plane positions of A, B and pi(B), drawn in that order.
for theta in (model_a_flat, model_b_flat, model_b_flat_permuted_pairwise):
    x, y = project2d(theta)
    plt.scatter([x], [y], marker="x", color="white", zorder=10)

# Shared label style: white text on a translucent, rounded grey box.
label_bboxes = dict(facecolor="tab:grey", boxstyle="round", edgecolor="none", alpha=0.5)
# Label for model A, placed below-left of its plane position (0, 0).
plt.text(
    -0.075,
    -0.1,
    r"${\bf \Theta_A}$",
    color="white",
    fontsize=24,
    bbox=label_bboxes,
    horizontalalignment="right",
    verticalalignment="top",
)
# Label for model B, placed below-right of its plane position (1, 0).
plt.text(
    1.075,
    -0.1,
    r"${\bf \Theta_B}$",
    color="white",
    fontsize=24,
    bbox=label_bboxes,
    horizontalalignment="left",
    verticalalignment="top",
)
# Label for pi(B), placed above-left of its projected position; (x, y) is also used
# as the pi(B) endpoint by the arrows drawn next.
x, y = project2d(model_b_flat_permuted_pairwise)
plt.text(
    x - 0.075,
    y + 0.1,
    r"${\bf \pi(\Theta_B)}$",
    color="white",
    fontsize=24,
    bbox=label_bboxes,
    horizontalalignment="right",
    verticalalignment="bottom",
)

# https://github.com/matplotlib/matplotlib/issues/17284#issuecomment-772820638
# Curved dashed arrow between B at (1, 0) and pi(B) at (x, y), drawn in two passes
# (dashed body without head, then the head alone) so the dash pattern does not
# distort the arrow head. With "<|-" the head sits at the `xytext` end, i.e. at pi(B).
# Draw line only
connectionstyle = "arc3,rad=-0.3"
plt.annotate(
    "",
    xy=(1, 0),
    xytext=(x, y),
    arrowprops=dict(
        arrowstyle="-",
        edgecolor="white",
        facecolor="none",
        linewidth=5,
        linestyle=(0, (5, 3)),
        shrinkA=20,
        shrinkB=15,
        connectionstyle=connectionstyle,
    ),
)
# Draw arrow head only
plt.annotate(
    "",
    xy=(1, 0),
    xytext=(x, y),
    arrowprops=dict(
        arrowstyle="<|-",
        edgecolor="none",
        facecolor="white",
        mutation_scale=40,
        linewidth=0,
        shrinkA=12.5,
        shrinkB=15,
        connectionstyle=connectionstyle,
    ),
)

# Faint straight segment between A at (0, 0) and pi(B).
plt.annotate(
    "",
    xy=(0, 0),
    xytext=(x, y),
    arrowprops=dict(
        arrowstyle="-",
        edgecolor="white",
        alpha=0.5,
        facecolor="none",
        linewidth=2,
        linestyle="-",
        shrinkA=10,
        shrinkB=10,
    ),
)
# Faint straight segment between A at (0, 0) and B at (1, 0).
plt.annotate(
    "",
    xy=(0, 0),
    xytext=(1, 0),
    arrowprops=dict(
        arrowstyle="-",
        edgecolor="white",
        alpha=0.5,
        facecolor="none",
        linewidth=2,
        linestyle="-",
        shrinkA=10,
        shrinkB=10,
    ),
)

# Disabled: check-mark / cross-mark emoji markers fetched from a remote URL.
# plt.gca().add_artist(
#     AnnotationBbox(
#         OffsetImage(
#             plt.imread(
#                 "https://emojipedia-us.s3.dualstack.us-west-1.amazonaws.com/thumbs/240/apple/325/check-mark-button_2705.png"
#             ),
#             zoom=0.1,
#         ),
#         (x / 2, y / 2),
#         frameon=False,
#     )
# )
# plt.gca().add_artist(
#     AnnotationBbox(
#         OffsetImage(
#             plt.imread(
#                 "https://emojipedia-us.s3.dualstack.us-west-1.amazonaws.com/thumbs/240/apple/325/cross-mark_274c.png"
#             ),
#             zoom=0.1,
#         ),
#         (0.5, 0),
#         frameon=False,
#     )
# )

# Method-name label box (layout taken from a "Git Re-Basin" figure, hence the
# variable name). Earlier placements relative to an arrow midpoint, kept for reference:
#   box_x = 0.5 * (arrow_start[0] + arrow_stop[0])
#   box_y = 0.5 * (arrow_start[1] + arrow_stop[1])
# box_x = 0.5 * (arrow_start[0] + arrow_stop[0]) + 0.325
# box_y = 0.5 * (arrow_start[1] + arrow_stop[1]) + 0.2

box_x = 0.5
box_y = 1.3
git_rebasin_text = r"$C^2M^2$"

# Draw box only (text color is fully transparent; only the bbox is visible)
plt.text(
    box_x,
    box_y,
    git_rebasin_text,
    color=(0.0, 0.0, 0.0, 0.0),
    fontsize=24,
    horizontalalignment="center",
    verticalalignment="center",
    bbox=dict(boxstyle="round", fc=(1, 1, 1, 1), ec="black", pad=0.4),
)
# Draw text only, nudged down slightly to look vertically centered inside the box
plt.text(
    box_x,
    box_y - 0.0115,
    git_rebasin_text,
    color=(0.0, 0.0, 0.0, 1.0),
    fontsize=24,
    horizontalalignment="center",
    verticalalignment="center",
)

# plt.colorbar()
plt.xlim(-0.4, 1.4)
plt.ylim(-0.45, 1.3)
#   plt.xlim(-0.9, 1.9)
#   plt.ylim(-0.9, 1.9)
# plt.xticks([])
# plt.yticks([])
plt.tight_layout()
# NOTE(review): the output name is fixed regardless of `core_cfg.model.name`, so runs
# with other models/datasets overwrite this file — consider deriving it from the config.
plt.savefig("resnet_cifar_loss_contour.png", dpi=300)
# plt.savefig("resnet_cifar_mlp_loss_contour.pdf")