## Imports

In [None]:
import copy
import logging
from pathlib import Path
from typing import Dict
import math
import itertools
from ccmm.utils.utils import l2_norm_models
import hydra
import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import seaborn as sns
import torch  # noqa
import wandb
from hydra.utils import instantiate
from matplotlib import tri
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from omegaconf import DictConfig
from pytorch_lightning import LightningModule
from scipy.stats import qmc
from torch.utils.data import DataLoader
from tqdm import tqdm
from ccmm.matching.utils import perm_indices_to_perm_matrix
from ccmm.utils.utils import normalize_unit_norm, project_onto
from functools import partial

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import seed_index_everything
from nn_core.model_logging import NNLogger
from ccmm.utils.utils import fuse_batch_norm_into_conv
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler
from scipy.optimize import linear_sum_assignment
import ccmm  # noqa
from ccmm.matching.utils import (
    apply_permutation_to_statedict,
    get_all_symbols_combinations,
    plot_permutation_history_animation,
    restore_original_weights,
)
from ccmm.utils.utils import (
    linear_interpolate,
    load_model_from_info,
    map_model_seed_to_symbol,
    save_factored_permutations,
)
from ccmm.pl_modules.pl_module import MyLightningModule

from ccmm.matching.utils import load_permutations

from ccmm.utils.utils import vector_to_state_dict, get_interpolated_loss_acc_curves
import pytorch_lightning

In [None]:
import autograd.numpy as anp

In [None]:
import pymanopt
import pymanopt.manifolds
import pymanopt.optimizers

In [None]:
import torch
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors
import numpy as np
from scipy.linalg import eig
from numpy.linalg import svd
from scipy.optimize import linear_sum_assignment
import scipy
import json


def build_laplacian(knn_graph, normalized=True):
    """Build the Laplacian of a neighbourhood graph and its eigendecomposition.

    The (possibly directed) graph is symmetrized as ``A = G + G^T`` and the
    Laplacian is ``L = D - A`` with ``D`` the diagonal degree matrix. When
    ``normalized`` is true the symmetric normalization ``D^{-1/2} L D^{-1/2}`` is
    returned instead; a 1e-6 epsilon keeps isolated (degree-0) nodes finite.

    Args:
        knn_graph: square adjacency matrix, e.g. the sparse matrix returned by
            ``NearestNeighbors.kneighbors_graph``. Dense arrays are accepted too.
        normalized: whether to return the symmetric normalized Laplacian.

    Returns:
        ``(A, L, evals, evecs)``: the dense symmetric adjacency, the Laplacian, its
        real eigenvalues in ascending order and the matching orthonormal
        eigenvectors stored as the columns of ``evecs``.
    """
    G = knn_graph + knn_graph.T
    # sparse graphs (kneighbors_graph output) are densified; dense input passes through
    A = np.asarray(G.toarray() if hasattr(G, "toarray") else G, dtype=float)

    degrees = A.sum(axis=1)
    L = np.diag(degrees) - A

    if normalized:
        d_inv_sqrt = 1.0 / (np.sqrt(degrees) + 1e-6)
        # same as D^{-1/2} @ L @ D^{-1/2}, without two dense diagonal matrix products
        L = d_inv_sqrt[:, None] * L * d_inv_sqrt[None, :]

    # L is real symmetric, so use the symmetric solver. eigh returns real eigenvalues in
    # ascending order and an orthonormal eigenbasis even for repeated eigenvalues (e.g.
    # one zero eigenvalue per connected component). The general eig guarantees neither.
    evals, evecs = np.linalg.eigh(L)

    return A, L, evals, evecs

In [None]:
# Figure styling: serif fonts, seaborn's "talk" context and a reversed diverging colormap
matplotlib.rcParams["font.family"] = "serif"
sns.set_context("talk")
cmap_name = "coolwarm_r"

# Only let WARNING and above through from these chatty framework loggers
for _noisy_logger in ("lightning.pytorch", "torch", "pytorch_lightning.accelerators.cuda"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)
pylogger = logging.getLogger(__name__)

## Configuration

In [None]:
# IPython magics: automatically reload edited project modules before each cell executes
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose
from typing import Dict, List

# Drop any Hydra instance left over from a previous run of this cell, so that re-running
# initialize() below does not fail on an already-initialized Hydra
hydra.core.global_hydra.GlobalHydra.instance().clear()
# config_path is relative; presumably resolved against this notebook's directory — confirm
initialize(version_base=None, config_path=str("../conf"), job_name="matching_n_models")

In [None]:
# Compose the "func_maps" config with no overrides
cfg = compose(config_name="func_maps", overrides=[])

In [None]:
# Keep the full config as core_cfg; from here on `cfg` is only its `matching` sub-config
core_cfg = cfg  # NOQA
cfg = cfg.matching

# presumably seeds all RNGs from cfg.seed_index — confirm in nn_core
seed_index_everything(cfg)

## Hyperparameters

In [None]:
# Number of samples to use from each split; -1 means "the whole split"
num_test_samples = num_train_samples = -1

## Load dataset

In [None]:
transform = instantiate(core_cfg.dataset.test.transform)

train_dataset = instantiate(core_cfg.dataset.train, transform=transform)
test_dataset = instantiate(core_cfg.dataset.test, transform=transform)

num_train_samples = len(train_dataset) if num_train_samples < 0 else num_train_samples

train_subset = Subset(train_dataset, list(range(num_train_samples)))
train_loader = DataLoader(train_subset, batch_size=1000, num_workers=cfg.num_workers)

num_test_samples = len(test_dataset) if num_test_samples < 0 else num_test_samples
test_subset = Subset(test_dataset, list(range(num_test_samples)))

test_loader = DataLoader(test_subset, batch_size=1000, num_workers=cfg.num_workers)

In [None]:
trainer = instantiate(cfg.trainer, enable_progress_bar=False, enable_model_summary=False, max_epochs=10)

## Train models

In [None]:
import torch.nn as nn


class MLP(nn.Module):
    """Five-layer fully connected classifier with hidden widths 512-512-512-256.

    ``forward`` returns the log-probabilities together with the list of every
    layer's output: the four post-ReLU hidden activations followed by the raw
    logits. Exposing them lets per-layer representations be compared across models.
    """

    def __init__(self, input=28 * 28, num_classes=10):
        super().__init__()
        self.input = input
        # Layers are registered as layer0..layer4, in this order: these names are the
        # state_dict keys (e.g. "layer1.weight") that are indexed elsewhere.
        widths = [input, 512, 512, 512, 256, num_classes]
        for depth, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            setattr(self, f"layer{depth}", nn.Linear(fan_in, fan_out))

    def forward(self, x):
        hidden = x.view(-1, self.input)
        per_layer = []
        last = 4
        for depth in range(last + 1):
            hidden = getattr(self, f"layer{depth}")(hidden)
            # every layer except the output one is followed by a ReLU
            if depth != last:
                hidden = nn.functional.relu(hidden)
            per_layer.append(hidden)

        return nn.functional.log_softmax(hidden, dim=-1), per_layer

In [None]:
from ccmm.matching.permutation_spec import MLPPermutationSpecBuilder

# Permutation spec for the MLP above; the argument 4 presumably is the number of
# permutable hidden layers (layer0..layer3) — confirm against MLPPermutationSpecBuilder
permutation_spec_builder = MLPPermutationSpecBuilder(4)
permutation_spec = permutation_spec_builder.create_permutation_spec()

In [None]:
cfg.seed_index = 0
seed_index_everything(cfg)
model_a = MyLightningModule(MLP(), num_classes=10)

trainer = instantiate(cfg.trainer, enable_progress_bar=True, enable_model_summary=False, max_epochs=20)
trainer.fit(model_a, train_loader)

In [None]:
cfg.seed_index = 1
seed_index_everything(cfg)

model_b = MyLightningModule(MLP(), num_classes=10)
trainer = instantiate(cfg.trainer, enable_progress_bar=True, enable_model_summary=False, max_epochs=20)
trainer.fit(model_b, train_loader)

## Matching

In [None]:
from ccmm.matching.weight_matching import weight_matching

# Weight matching between the two trained MLPs; the result is indexed with keys of the
# form "P_<layer index>"
state_dict_a = model_a.model.state_dict()
state_dict_b = model_b.model.state_dict()
permutations = weight_matching(permutation_spec, state_dict_a, state_dict_b)

## Get activations

In [None]:
num_activations = 10000
train_loader = DataLoader(train_subset, batch_size=num_activations, num_workers=cfg.num_workers)

In [None]:
for batch in train_loader:

    x, y = batch
    features_a = model_a.model(x)[-1]
    features_b = model_b.model(x)[-1]
    break

## Focus on a single layer

In [None]:
layer_idx = 1

perm_gt = permutations[f"P_{layer_idx}"]

In [None]:
# (descriptor_dim, num_neurons), where descriptor_dim is the number of samples for which we are considering the neuron activation
layer_a = features_a[layer_idx]
layer_b = features_b[layer_idx]

In [None]:
# normalize to have unit norm

layer_a = layer_a / (torch.norm(layer_a, dim=0) + 1e-6)
layer_b = layer_b / (torch.norm(layer_b, dim=0) + 1e-6)

In [None]:
print(layer_a.shape, layer_b.shape)

In [None]:
from ccmm.utils.utils import to_np

layer_a = to_np(layer_a)
layer_b = to_np(layer_b)

## Git Re-Basin style matching

In [None]:
layer_a_weights = model_a.model.state_dict()[f"layer{layer_idx}.weight"]
layer_b_weights = model_b.model.state_dict()[f"layer{layer_idx}.weight"]

sim_matrix_git_rebasin = layer_a_weights @ layer_b_weights.T

In [None]:
_, ci = linear_sum_assignment(sim_matrix_git_rebasin.detach().numpy(), maximize=True)

In [None]:
perm_matrix = perm_indices_to_perm_matrix(torch.tensor(ci)).numpy()

In [None]:
b_perm = perm_matrix @ layer_b_weights.detach().numpy()

In [None]:
np.trace(layer_a @ b_perm.T)

## SVD of the layer activations

In [None]:
# Display the activation matrix shape: (num_samples, num_neurons) — samples are rows,
# neurons are columns
layer_a.shape

In [None]:
import numpy as np


def svd_threshold(matrix, variance_threshold=0.99):
    """Truncated SVD keeping the fewest components that reach a target energy fraction.

    "Variance" here is the squared singular-value mass ``S**2`` of ``matrix`` as given
    (the matrix is not centered).

    Args:
        matrix: 2-D array to decompose.
        variance_threshold: fraction of ``sum(S**2)`` to retain. Values above the
            largest cumulative fraction (e.g. >= 1) keep every component.

    Returns:
        ``(U_reduced, S_reduced, Vt_reduced, explained_variance)``. The first three are
        the truncated thin-SVD factors. ``explained_variance`` is the cumulative
        explained fraction over *all* singular values.
    """
    U, S, Vt = np.linalg.svd(matrix, full_matrices=False)

    energy = S**2
    total_variance = energy.sum()
    if total_variance > 0:
        explained_variance = np.cumsum(energy) / total_variance
    else:
        # all-zero matrix: there is nothing to explain; avoid a 0/0 -> NaN ratio
        explained_variance = np.ones_like(S)

    # explained_variance is non-decreasing, so searchsorted gives the first index whose
    # cumulative fraction reaches the threshold. Unlike argmax over a boolean mask, it
    # returns len(S) rather than 0 when the threshold is never reached (e.g. 1.0 vs. a
    # cumsum ending at 0.9999999 through rounding). That case must keep everything, not 1.
    num_components = min(int(np.searchsorted(explained_variance, variance_threshold, side="left")) + 1, len(S))

    U_reduced = U[:, :num_components]
    S_reduced = S[:num_components]
    Vt_reduced = Vt[:num_components, :]

    return U_reduced, S_reduced, Vt_reduced, explained_variance


def svd_num_components(matrix, num_components=10):
    """Thin SVD of ``matrix`` truncated to its leading ``num_components`` singular triplets.

    Args:
        matrix: 2-D array of shape ``(num_samples, num_neurons)``.
        num_components: number K of leading components to keep. Values above
            ``min(num_samples, num_neurons)`` return every component.

    Returns:
        ``(U_reduced, S_reduced, Vt_reduced)`` with shapes ``(num_samples, K)``, ``(K,)``
        and ``(K, num_neurons)``, so that ``U_reduced @ np.diag(S_reduced) @ Vt_reduced``
        is the best rank-K approximation of ``matrix``.

    Raises:
        ValueError: if ``num_components`` is negative.
    """
    if num_components < 0:
        raise ValueError(f"num_components must be non-negative, got {num_components}")

    num_samples, num_neurons = matrix.shape
    # a thin SVD has min(num_samples, num_neurons) components, whichever side is larger
    rank_bound = min(num_samples, num_neurons)
    K = min(num_components, rank_bound)

    U, S, Vt = np.linalg.svd(matrix, full_matrices=False)

    assert U.shape == (num_samples, rank_bound)
    assert S.shape == (rank_bound,)
    assert Vt.shape == (rank_bound, num_neurons)

    U_reduced = U[:, :K]
    S_reduced = S[:K]
    Vt_reduced = Vt[:K, :]

    return U_reduced, S_reduced, Vt_reduced

In [None]:
# (num_neurons, num_comps), (num_comps), (num_comps, num_samples)
U_a, S_a, Vt_a = svd_num_components(layer_a, num_components=128)
U_b, S_b, Vt_b = svd_num_components(layer_b, num_components=128)

In [None]:
print(U_a.shape, S_a.shape, Vt_a.shape)
print(U_b.shape, S_b.shape, Vt_b.shape)

In [None]:
eigenneurons_a = 1 / ((np.diag(S_a) ** 0.5) + 1 - 6) @ Vt_a
eigenneurons_b = 1 / ((np.diag(S_b) ** 0.5) + 1 - 6) @ Vt_b

eigenneurons_a = eigenneurons_a.T
eigenneurons_b = eigenneurons_b.T

In [None]:
# express each layer as a linear combination of the singular vectors
layer_a_reconstructed = U_a @ np.diag(S_a) @ Vt_a
layer_b_reconstructed = U_b @ np.diag(S_b) @ Vt_b

In [None]:
layer_a_reconstructed.shape

### Reconstruction error of the truncated SVD

In [None]:
# check if the reconstruction is close to the original layer by computing the norm
# (Frobenius norm of the residual between model A's activations and their rank-K
# reconstruction; 0 would mean an exact reconstruction)
np.linalg.norm(layer_a_reconstructed - layer_a)

In [None]:
# check if the reconstruction is close to the original layer by computing the norm
# (same residual norm for model B)
np.linalg.norm(layer_b_reconstructed - layer_b)

In [None]:
# compute the norm of the two models for comparison
# (Frobenius distance between the two models' activations with no neuron alignment
# applied, as a reference scale for the residuals above)
np.linalg.norm(layer_a - layer_b)

### Neighbourhood graphs and functional maps

In [None]:
# Incoming weight matrices of layer{layer_idx} as NumPy arrays. nn.Linear stores them as
# (out_features, in_features), so each row describes one neuron of this layer.
W_a = layer_a_weights.detach().numpy()
W_b = layer_b_weights.detach().numpy()

W_a.shape

In [None]:
# (num_samples, num_neurons)
layer_a.shape

In [None]:
# make a 7x7 grid of subplots, one with each func map


def plot_func_maps(func_maps, fig_name, vmin, vmax):
    fig, axs = plt.subplots(7, 7, figsize=(20, 20))

    for i in range(7):
        for j in range(7):

            ax = axs[i, j]
            ax.imshow(func_maps[i * 7 + j], cmap=cmap_name, vmin=vmin, vmax=vmax)
            ax.axis("off")

    plt.savefig(f"figures/{fig_name}.png")

In [None]:
def compute_func_map(
    X, Y, P, radius=None, num_neighbors=None, mode="distance", normalize_lap=True, num_eigenvectors=50
):
    """Functional map between the neighbourhood graphs of two point clouds.

    Builds a neighbourhood graph on the rows of ``X`` and on the rows of ``Y``. It then
    takes the leading Laplacian eigenvectors of each graph and expresses the
    correspondence ``P`` in those reduced bases:
    ``C = Phi_X[:, :k].T @ P @ Phi_Y[:, :k]`` with ``k = num_eigenvectors``.

    Args:
        X, Y: arrays with one descriptor row per point (here: per neuron).
        P: ``(len(X), len(Y))`` correspondence matrix, e.g. a permutation matrix.
        radius: if given, connect all pairs of points closer than ``radius``.
            Takes precedence over ``num_neighbors``.
        num_neighbors: otherwise, connect each point to its ``num_neighbors`` nearest
            points. The query point counts as one of them, because the graph is built
            by querying the fitted data itself.
        mode: ``"connectivity"`` (0/1 edges) or ``"distance"`` (edge weight = distance),
            as in scikit-learn.
        normalize_lap: use the symmetric normalized Laplacian.
        num_eigenvectors: number of lowest-frequency eigenvectors kept per basis.

    Returns:
        The ``(num_eigenvectors, num_eigenvectors)`` functional map ``C`` (smaller if a
        graph has fewer nodes).

    Raises:
        ValueError: if neither ``radius`` nor ``num_neighbors`` is given.
    """
    if radius is None and num_neighbors is None:
        raise ValueError("Either radius or num_neighbors must be provided")

    def _neighbourhood_graph(points):
        # sparse (num_points, num_points) adjacency of the requested neighbourhood type
        if radius is not None:
            # radius_neighbors_graph honours `radius`. kneighbors_graph would silently
            # ignore it and fall back to the estimator's default of 5 neighbours.
            return NearestNeighbors(radius=radius).fit(points).radius_neighbors_graph(points, mode=mode)
        return NearestNeighbors(n_neighbors=num_neighbors).fit(points).kneighbors_graph(points, mode=mode)

    _, _, _, X_evecs = build_laplacian(_neighbourhood_graph(X), normalize_lap)
    _, _, _, Y_evecs = build_laplacian(_neighbourhood_graph(Y), normalize_lap)

    # np.real is a no-op for real eigenvectors and drops any imaginary part otherwise
    X_basis = np.real(X_evecs)[:, :num_eigenvectors]
    Y_basis = np.real(Y_evecs)[:, :num_eigenvectors]

    return X_basis.T @ P @ Y_basis

In [None]:
# Per-neuron descriptors for the neighbourhood graphs; one of
# "weights", "features", "features_denoised", "eigenneurons"
descriptors = "eigenneurons"

# Every option yields one descriptor row per neuron: X for model A, Y for model B.
if descriptors == "weights":
    X, Y = W_a, W_b
elif descriptors == "features":
    X, Y = layer_a.T, layer_b.T
elif descriptors == "features_denoised":
    X, Y = layer_a_reconstructed.T, layer_b_reconstructed.T
elif descriptors == "eigenneurons":
    X, Y = eigenneurons_a, eigenneurons_b
else:
    raise ValueError(
        f"Invalid value for descriptors: {descriptors!r} "
        "(expected 'weights', 'features', 'features_denoised' or 'eigenneurons')"
    )

In [None]:
P = perm_indices_to_perm_matrix(perm_gt).numpy()
normalize_lap = True
mode = "connectivity"  # connectivity or distance

In [None]:
func_maps_neighbors = [
    compute_func_map(X, Y, P, num_neighbors=k, mode=mode, normalize_lap=normalize_lap) for k in range(1, 100, 2)
]

In [None]:
plot_name = f"func_maps_{descriptors}_{mode}_normalizeLap_{normalize_lap}"
plot_func_maps(func_maps_neighbors, plot_name, vmin=-0.3, vmax=0.3)

In [None]:
# compute diameter of X

# compute cidst with numpy
from scipy.spatial.distance import cdist

diameter = np.max(cdist(X.T, X.T, metric="euclidean"))
print(diameter)

In [None]:
func_maps_radius = [compute_func_map(X, Y, P, radius=r) for r in np.linspace(0.01, 100, 50)]

In [None]:
plot_func_maps(func_maps_radius)

In [None]:
k = 10

# Descriptors for the graph plots: one row per neuron
# X, Y = layer_a_reconstructed.T, layer_b_reconstructed.T
X, Y = W_a, W_b

# Sparse (num_neurons, num_neurons) k-NN connectivity graph of each model's neurons
Xneigh = NearestNeighbors(n_neighbors=k).fit(X)
X_knn_graph = Xneigh.kneighbors_graph(X, mode="connectivity")

Yneigh = NearestNeighbors(n_neighbors=k).fit(Y)
Y_knn_graph = Yneigh.kneighbors_graph(Y, mode="connectivity")

In [None]:
def _pca_coords(points, n_components=3):
    """Per-neuron coordinates from the PCA components of ``points.T``.

    ``points`` has one row per neuron. Fitting PCA on ``points.T`` treats descriptor
    dimensions as samples and neurons as features, so ``components_`` has shape
    ``(n_components, num_neurons)``: row c holds coordinate c of every neuron.
    """
    return PCA(n_components=n_components).fit(points.T).components_


Xx, Xy, Xz = _pca_coords(X)
Yx, Yy, Yz = _pca_coords(Y)

# Create the 3-D axes directly. Adding them over plt.subplots' 2-D axes would leave the
# empty 2-D frames (spines and ticks) drawn underneath.
fig = plt.figure(figsize=(10, 5))
ax = [fig.add_subplot(121, projection="3d"), fig.add_subplot(122, projection="3d")]

ax[0].scatter(Xx, Xy, Xz, c="tab:blue")
ax[1].scatter(Yx, Yy, Yz, c="tab:red")

plt.show()

In [None]:
def _draw_knn_graph(axis, knn_graph, xs, ys, zs, point_color):
    """Draw every positive-weight edge of a sparse graph between 3-D points, then the points."""
    # Visit only the stored entries. Probing all num_neurons**2 positions of a sparse
    # matrix one element at a time is very slow.
    coo = knn_graph.tocoo()
    has_edge = coo.data > 0
    for i, j in zip(coo.row[has_edge], coo.col[has_edge]):
        axis.plot([xs[i], xs[j]], [ys[i], ys[j]], [zs[i], zs[j]], "b-", alpha=0.5)
    axis.scatter(xs, ys, zs, c=point_color)


num_neurons = W_a.shape[0]

# Create the 3-D axes directly instead of stacking them on top of 2-D subplots
fig = plt.figure(figsize=(10, 5))
ax = [fig.add_subplot(121, projection="3d"), fig.add_subplot(122, projection="3d")]

_draw_knn_graph(ax[0], X_knn_graph, Xx, Xy, Xz, "tab:blue")
_draw_knn_graph(ax[1], Y_knn_graph, Yx, Yy, Yz, "tab:red")

plt.show()

In [None]:
# Normalized Laplacians of the two k-NN graphs, with their eigendecompositions
XA, XL, Xevals, Xevecs = build_laplacian(X_knn_graph, normalized=True)
YA, YL, Yevals, Yevecs = build_laplacian(Y_knn_graph, normalized=True)

In [None]:
# P = np.eye(num_neurons)

In [None]:
# K = 50
# C = Xevecs[:, :K].T @ P.T @ Yevecs[:, :K]
# C_Pt = Xevecs[:, :K].T @ P @ Yevecs[:, :K]

# fig, ax = plt.subplots(1, 3, figsize=(16, 8))
# ax[0].imshow(P)
# ax[1].imshow(C[1:,1:], cmap='bwr')
# ax[2].imshow(C_Pt[1:, 1:], cmap='bwr')

In [None]:
# n = num_neurons
# k = 24

# def project_sinkhorn(P, max_iters=100):
#     Q = P
#     for _ in range(max_iters):
#         r = np.sum(Q, axis=0)
#         Q = Q @ np.diag(1 / r)
#         c = np.sum(Q, axis=1)
#         Q = np.diag(1 / c) @ Q

#     return Q

# manifold = pymanopt.manifolds.Positive(n, n)

# @pymanopt.function.autograd(manifold)
# def cost(point):
#     C = Xevecs[:,1:k].T @ point @ Yevecs[:,1:k]
#     return anp.sum((C - anp.diag(anp.diag(C))) ** 2)

# optimizer = pymanopt.optimizers.SteepestDescent()
# problem = pymanopt.Problem(manifold, cost)

# P = None
# for outer_iter in range(10):
#     result = optimizer.run(problem, initial_point=P)
#     P = project_sinkhorn(result.point)

# C = Xevecs[:,1:k].T @ P @ Yevecs[:,1:k]
# fig, ax = plt.subplots(1, 2, figsize=(4,2))
# ax[0].imshow(P)
# ax[1].imshow(C, cmap='bwr')

### Solve a LAP in the reduced space

In [None]:
from scipy.optimize import linear_sum_assignment

# _, ci = linear_sum_assignment(U_a.T @ U_b + Vt_a.T @ Vt_b.T, maximize=True)
_, ci = linear_sum_assignment(layer_a_reconstructed.T @ layer_b_reconstructed, maximize=True)

In [None]:
perm_matrix = perm_indices_to_perm_matrix(torch.tensor(ci)).numpy()

In [None]:
perm_matrix.shape

In [None]:
# U_sigma_b_perm =  U_b @ perm_matrix
# Vt_b = perm_matrix @ Vt_b
# S_b = perm_matrix @ S_b

# layer_b_reconstructed_perm = U_sigma_b_perm @ np.diag(S_b) @ Vt_b

In [None]:
layer_b_reconstructed_perm = perm_matrix @ layer_b_reconstructed.T

layer_b_reconstructed_perm = layer_b_reconstructed_perm.T

In [None]:
layer_b_recon_perm_norm = layer_b_reconstructed_perm / (np.linalg.norm(layer_b_reconstructed_perm, axis=0) + 1e-6)
layer_a_norm = layer_a / (np.linalg.norm(layer_a, axis=0) + 1e-6)
layer_b_norm = layer_b / (np.linalg.norm(layer_b, axis=0) + 1e-6)

In [None]:
np.trace(layer_b_recon_perm_norm.T @ layer_a_norm)

In [None]:
np.trace(layer_b_norm.T @ layer_a_norm)

### LAP in the original space

In [None]:
sim_matrix_orig_space = layer_a @ layer_b.T

_, ci = linear_sum_assignment(-sim_matrix_orig_space, maximize=True)
perm_matrix = perm_indices_to_perm_matrix(torch.tensor(ci)).numpy()

In [None]:
layer_b_perm = perm_matrix @ layer_b

layer_b_perm_norm = layer_b_perm / (np.linalg.norm(layer_b_perm, axis=0) + 1e-6)
np.trace(layer_a_norm @ layer_b_perm_norm.T)