## Imports

In [None]:
import copy
import logging
from pathlib import Path
from typing import Dict
import math
import itertools
from ccmm.utils.utils import l2_norm_models
import hydra
import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import seaborn as sns
import torch  # noqa
import wandb
from hydra.utils import instantiate
from matplotlib import tri
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from omegaconf import DictConfig
from pytorch_lightning import LightningModule
from scipy.stats import qmc
from torch.utils.data import DataLoader
from tqdm import tqdm
from ccmm.matching.utils import perm_indices_to_perm_matrix
from ccmm.utils.utils import normalize_unit_norm, project_onto
from functools import partial

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import seed_index_everything
from nn_core.model_logging import NNLogger
from ccmm.utils.utils import fuse_batch_norm_into_conv
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler
from scipy.optimize import linear_sum_assignment
import ccmm  # noqa
from ccmm.matching.utils import (
    apply_permutation_to_statedict,
    get_all_symbols_combinations,
    plot_permutation_history_animation,
    restore_original_weights,
)
from ccmm.utils.utils import (
    linear_interpolate,
    load_model_from_info,
    map_model_seed_to_symbol,
    save_factored_permutations,
)
from ccmm.pl_modules.pl_module import MyLightningModule

from ccmm.matching.utils import load_permutations

from ccmm.utils.utils import vector_to_state_dict, get_interpolated_loss_acc_curves
import pytorch_lightning

In [None]:
import autograd.numpy as anp

In [None]:
import pymanopt
import pymanopt.manifolds
import pymanopt.optimizers

In [None]:
import torch
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors
import numpy as np
from scipy.linalg import eig
from numpy.linalg import svd
from scipy.optimize import linear_sum_assignment
import scipy
import json


def build_laplacian(knn_graph, normalized=True):
    """Build the graph Laplacian of a kNN graph and compute its eigendecomposition.

    The adjacency is symmetrized as ``knn_graph + knn_graph.T``, so an edge
    that is present in both directions ends up with twice its weight.

    Args:
        knn_graph: square adjacency matrix, either a SciPy sparse matrix
            (anything exposing ``toarray()``) or a dense array-like.
        normalized: if True, return the symmetrically normalized Laplacian
            ``D^{-1/2} (D - A) D^{-1/2}``; a small epsilon is added to the
            square-rooted degrees so isolated nodes do not divide by zero.

    Returns:
        Tuple ``(A, L, evals, evecs)``: the dense symmetrized adjacency, the
        dense Laplacian, its eigenvalues in ascending order, and the matching
        eigenvectors stored as columns of ``evecs``.
    """
    A = (knn_graph + knn_graph.T).astype(float)
    # Sparse inputs are densified; dense inputs are accepted unchanged.
    A = A.toarray() if hasattr(A, "toarray") else np.asarray(A)

    degrees = A.sum(axis=1)
    L = np.diag(degrees) - A

    if normalized:
        # Row/column scaling is equivalent to D^{-1/2} @ L @ D^{-1/2} without
        # materializing and multiplying dense diagonal matrices.
        d_inv_sqrt = 1 / (np.sqrt(degrees) + 1e-6)
        L = d_inv_sqrt[:, None] * L * d_inv_sqrt[None, :]

    # L is symmetric, so use the symmetric solver: it guarantees real
    # eigenvalues in ascending order and orthonormal eigenvectors, even for
    # repeated eigenvalues, which a general solver does not.
    evals, evecs = np.linalg.eigh(L)

    return A, L, evals, evecs

In [None]:
# Plot styling shared by the figures in this notebook.
matplotlib.rcParams["font.family"] = "serif"
sns.set_context("talk")
cmap_name = "coolwarm_r"

# Raise the threshold of the framework loggers so only warnings and above are emitted.
for _noisy_logger in ("lightning.pytorch", "torch", "pytorch_lightning.accelerators.cuda"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)

# Logger for messages emitted by this notebook itself.
pylogger = logging.getLogger(__name__)

## Configuration

In [None]:
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose
from typing import Dict, List

# Clear the global Hydra instance first so this cell can be re-run without
# `initialize` failing on an already-initialized Hydra.
hydra.core.global_hydra.GlobalHydra.instance().clear()
# Point Hydra at the `../conf` directory for subsequent `compose` calls.
initialize(version_base=None, config_path=str("../conf"), job_name="matching_n_models")

In [None]:
# Compose the "func_maps" configuration with no command-line style overrides.
cfg = compose(config_name="func_maps", overrides=[])

In [None]:
# Keep the full composed config reachable as `core_cfg`, then narrow `cfg`
# to its `matching` section, which the rest of the notebook reads.
core_cfg = cfg  # NOQA
cfg = cfg.matching

# Seed from the matching config (see nn_core's `seed_index_everything` for what exactly is seeded).
seed_index_everything(cfg)

## Hyperparameters

In [None]:
# Number of samples to keep from each split; a negative value is replaced by
# the full split length when the datasets are loaded.
num_test_samples = -1
num_train_samples = -1

## Load dataset

In [None]:
# Both splits are instantiated with the transform configured for the test split.
transform = instantiate(core_cfg.dataset.test.transform)
train_dataset = instantiate(core_cfg.dataset.train, transform=transform)
test_dataset = instantiate(core_cfg.dataset.test, transform=transform)

# A negative sample budget means "use the whole split".
if num_train_samples < 0:
    num_train_samples = len(train_dataset)
if num_test_samples < 0:
    num_test_samples = len(test_dataset)

# Keep only the first `num_*_samples` items of each split.
train_subset = Subset(train_dataset, list(range(num_train_samples)))
test_subset = Subset(test_dataset, list(range(num_test_samples)))

# NOTE(review): the training loader is created without `shuffle=True`, so
# batches are served in dataset order — confirm this is intended.
train_loader = DataLoader(train_subset, batch_size=512, num_workers=cfg.num_workers)
test_loader = DataLoader(test_subset, batch_size=1000, num_workers=cfg.num_workers)

In [None]:
# NOTE(review): `trainer` is re-instantiated in the model-training cells before
# any visible use, so this instance (max_epochs=10) appears to be unused — confirm.
trainer = instantiate(cfg.trainer, enable_progress_bar=False, enable_model_summary=False, max_epochs=10)

## Train models

In [None]:
import torch.nn as nn


class MLP(nn.Module):
    """Fully connected classifier with five linear layers that also returns every layer's output.

    Args:
        input: number of features each sample is flattened to.
        num_classes: size of the final (output) layer.
    """

    def __init__(self, input=28 * 28, num_classes=10):
        super().__init__()
        self.input = input

        widths = (input, 512, 512, 512, 256, num_classes)
        # Register layer0 .. layer4 in order so parameter names and the
        # initialization order stay exactly as if written out one by one.
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            setattr(self, f"layer{idx}", nn.Linear(fan_in, fan_out))

    def forward(self, x):
        """Return ``(log_probs, embeddings)`` for a batch.

        ``embeddings`` holds the four post-ReLU hidden activations followed by
        the raw output of the last layer; ``log_probs`` is the log-softmax of
        that last output.
        """
        activation = x.view(-1, self.input)

        embeddings = []
        for idx in range(4):
            hidden_layer = getattr(self, f"layer{idx}")
            activation = nn.functional.relu(hidden_layer(activation))
            embeddings.append(activation)

        logits = self.layer4(activation)
        embeddings.append(logits)

        return nn.functional.log_softmax(logits, dim=-1), embeddings

In [None]:
from ccmm.matching.permutation_spec import MLPPermutationSpecBuilder

# Build the permutation specification used later by weight matching.
# NOTE(review): the argument 4 presumably matches the number of hidden layers of `MLP` — confirm against MLPPermutationSpecBuilder.
permutation_spec_builder = MLPPermutationSpecBuilder(4)
permutation_spec = permutation_spec_builder.create_permutation_spec()

In [None]:
# Model A: re-seed with seed index 0 *before* building the network so its
# initialization is determined by that seed.
cfg.seed_index = 0
seed_index_everything(cfg)
model_a = MyLightningModule(MLP(), num_classes=10)

# Fresh trainer for this model; fits for up to 50 epochs on the training loader.
trainer = instantiate(cfg.trainer, enable_progress_bar=True, enable_model_summary=False, max_epochs=50)
trainer.fit(model_a, train_loader)

# Evaluate the trained model on the test loader.
trainer.test(model_a, test_loader)

In [None]:
# Model B: same architecture as model A, but seeded with seed index 1 before
# the network is built so the two models start from different initializations.
cfg.seed_index = 1
seed_index_everything(cfg)

model_b = MyLightningModule(MLP(), num_classes=10)
# NOTE(review): model B trains for max_epochs=20 while model A uses max_epochs=50 —
# confirm the different training budgets are intentional.
trainer = instantiate(cfg.trainer, enable_progress_bar=True, enable_model_summary=False, max_epochs=20)
trainer.fit(model_b, train_loader)

# Evaluate the trained model on the test loader.
trainer.test(model_b, test_loader)

## Matching

In [None]:
from ccmm.matching.weight_matching import weight_matching

# Run weight matching between the two trained networks (the inner `.model`
# state dicts, not the Lightning wrappers) under the MLP permutation spec.
permutations = weight_matching(permutation_spec, model_a.model.state_dict(), model_b.model.state_dict())

In [None]:
# Bare expression: displays the computed permutations as the notebook cell output.
permutations

In [None]:
from ccmm.matching.utils import apply_permutation_to_statedict

# Apply the matched permutations to model B's parameters; the result is loaded into a copy of model B below.
updated_params = apply_permutation_to_statedict(permutation_spec, permutations, model_b.model.state_dict())

In [None]:
import copy

# Work on a deep copy so the original model B stays untouched, then replace
# the copy's inner network weights with the permuted parameters.
model_b_perm = copy.deepcopy(model_b)
model_b_perm.model.load_state_dict(updated_params)

In [None]:
# Ten evenly spaced interpolation coefficients between the two endpoints.
lambdas = np.linspace(0, 1, 10)

# Test results per coefficient: "naive" interpolates A with the original B,
# "matched" interpolates A with the permuted B.
all_results = dict(naive=[], matched=[])

for lambd in lambdas:
    # Naive curve: interpolate A with the unpermuted B, loaded into a copy of B.
    naive_interp_params = linear_interpolate(model_a=model_a, model_b=model_b, lambd=lambd)
    model_naive = copy.deepcopy(model_b)
    model_naive.load_state_dict(naive_interp_params)

    # Matched curve: interpolate A with the permuted B, loaded into another copy of B.
    model_interp_params = linear_interpolate(model_a=model_a, model_b=model_b_perm, lambd=lambd)
    model_interp = copy.deepcopy(model_b)
    model_interp.load_state_dict(model_interp_params)

    # A fresh trainer per coefficient; the matched model is tested first.
    trainer = instantiate(cfg.trainer, enable_progress_bar=True, enable_model_summary=False, max_epochs=20)
    results = trainer.test(model_interp, test_loader)
    results_naive = trainer.test(model_naive, test_loader)

    all_results["naive"] += [results_naive]
    all_results["matched"] += [results]

In [None]:
# Plot the interpolation coefficient (x) against the test loss (y) for the
# naive and the matched interpolation. The plotted metric is "loss/test", so
# the axis is labelled as loss rather than accuracy.

plt.plot(lambdas, [x[0]["loss/test"] for x in all_results["naive"]], label="naive")
plt.plot(lambdas, [x[0]["loss/test"] for x in all_results["matched"]], label="matched")

plt.legend()
# Raw string: "\l" is not a valid escape sequence in a normal string literal.
plt.xlabel(r"$\lambda$")
plt.ylabel("Test Loss")