## Imports

In [None]:
import copy
import logging
from pathlib import Path
from typing import Dict
import math

import hydra
import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import seaborn as sns
import torch  # noqa
import wandb
from hydra.utils import instantiate
from matplotlib import tri
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from omegaconf import DictConfig
from pytorch_lightning import LightningModule
from scipy.stats import qmc
from torch.utils.data import DataLoader
from tqdm import tqdm
from ccmm.matching.utils import perm_indices_to_perm_matrix
from ccmm.utils.utils import normalize_unit_norm, project_onto

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import seed_index_everything
from nn_core.model_logging import NNLogger

from torch.utils.data import DataLoader, Subset, SubsetRandomSampler

import ccmm  # noqa
from ccmm.matching.utils import (
    apply_permutation_to_statedict,
    get_all_symbols_combinations,
    plot_permutation_history_animation,
    restore_original_weights,
)
from ccmm.utils.utils import (
    linear_interpolation,
    load_model_from_info,
    load_permutations,
    map_model_seed_to_symbol,
    save_factored_permutations,
)

from ccmm.utils.utils import vector_to_state_dict
import pytorch_lightning

In [None]:
# Global plot styling: serif fonts, seaborn "talk" context, LaTeX text rendering.
matplotlib.rcParams["font.family"] = "serif"
sns.set_context("talk")
# NOTE(review): usetex presumably requires a LaTeX installation on the machine running the notebook.
matplotlib.rcParams["text.usetex"] = True
# Name of the colormap used by the filled loss contour plot later in the notebook.
cmap_name = "coolwarm_r"

# Raise the log level of the Lightning / torch loggers to WARNING to keep cell output readable.
logging.getLogger("lightning.pytorch").setLevel(logging.WARNING)
logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("pytorch_lightning.accelerators.cuda").setLevel(logging.WARNING)
# Logger used throughout the notebook.
pylogger = logging.getLogger(__name__)

## Configuration

In [None]:
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose
from typing import Dict, List

# Clear any Hydra instance left over from a previous run of this cell, then
# initialise Hydra against the config directory "../conf".
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path=str("../conf"), job_name="matching_n_models")

In [None]:
# Compose the "matching_n_models" config without any overrides.
cfg = compose(config_name="matching_n_models", overrides=[])

In [None]:
# Keep the full config as `core_cfg`; from here on `cfg` is only its `matching` section.
core_cfg = cfg  # NOQA
cfg = cfg.matching

# Seed the run from the matching config.
seed_index_everything(cfg)

## Hyperparameters

In [None]:
# Number of quasi-random points sampled in the 2D plane for the loss landscape.
num_sampled_points = 1000  # 2048
# Number of test examples kept in the evaluation subset.
num_test_samples = 500

## Load dataset

In [None]:
# NOTE(review): the *test* transform is applied to the training set as well — confirm this is intended.
transform = instantiate(core_cfg.dataset.test.transform)

train_dataset = instantiate(core_cfg.dataset.train, transform=transform)
test_dataset = instantiate(core_cfg.dataset.test, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers)

# Only the first `num_test_samples` test examples are used for evaluation.
test_subset = Subset(test_dataset, list(range(num_test_samples)))

# The test batch size is the literal 1000 rather than cfg.batch_size.
test_loader = DataLoader(test_subset, batch_size=1000, num_workers=cfg.num_workers)

In [None]:
# Trainer used for every evaluation below; progress bar and model summary are disabled.
trainer = instantiate(cfg.trainer, enable_progress_bar=False, enable_model_summary=False)

## Load models

In [None]:
# Mapping symbol -> seed, e.g. {a: 1, b: 2, c: 3, ..}
symbols_to_seed: Dict[str, int] = {map_model_seed_to_symbol(seed): seed for seed in cfg.model_seeds}

# Mapping symbol -> trained model, one model per seed in cfg.model_seeds.
models: Dict[str, LightningModule] = {
    map_model_seed_to_symbol(seed): load_model_from_info(cfg.model_info_path, seed) for seed in cfg.model_seeds
}

num_models = len(models)
pylogger.info(f"Using model {core_cfg.model.name}")

In [None]:
# always permute the model having larger character order, i.e. c -> b, b -> a and so on ...
symbols = set(symbols_to_seed.keys())
# Ascending order of the symbols (sets are unordered, so sort explicitly).
sorted_symbols = sorted(symbols, reverse=False)

# (a, b), (a, c), (b, c), ...
all_combinations = get_all_symbols_combinations(symbols)
# combinations of the form (a, b), (a, c), (b, c), .. and not (b, a), (c, a) etc
# i.e. only pairs whose first symbol sorts strictly before the second are kept.
canonical_combinations = [(source, target) for (source, target) in all_combinations if source < target]

## Matching

In [None]:
# Log the canonical (source < target) pairs that will be matched.
pylogger.info(f"Matching the following model pairs: {canonical_combinations}")

### Load permutation specification

In [None]:
# Build the permutation specification for the configured architecture.
permutation_spec_builder = instantiate(core_cfg.model.permutation_spec_builder)
permutation_spec = permutation_spec_builder.create_permutation()

# Sanity check against the first loaded model: the spec must list exactly the
# keys of the wrapped network's state dict (no missing and no extra entries).
ref_model = list(models.values())[0]
assert set(permutation_spec.layer_and_axes_to_perm.keys()) == set(ref_model.model.state_dict().keys())

In [None]:
# Instantiate the configured matcher with the permutation spec built above.
matcher = instantiate(cfg.matcher, permutation_spec=permutation_spec)
pylogger.info(f"Matcher: {matcher.name}")

In [None]:
# Run the matching; `permutations` is indexed by model symbol further down
# (permutations[symbol]), `perm_history` is not used again in this notebook.
permutations, perm_history = matcher(models, symbols=sorted_symbols, combinations=canonical_combinations)

In [None]:
# Move every model to the CPU before copying / flattening them.
models = {symb: model.to("cpu") for symb, model in models.items()}

### Permute models to universe

In [None]:
# For every model, take an independent copy and load into it the weights obtained
# by applying that model's own permutation (permutations[symbol]) to its state dict.
models_permuted_to_universe = {}

for symbol, source_model in models.items():
    model = copy.deepcopy(source_model)
    permuted_params = apply_permutation_to_statedict(permutation_spec, permutations[symbol], model.model.state_dict())
    model.model.load_state_dict(permuted_params)
    models_permuted_to_universe[symbol] = model

### Permute models pairwise

In [None]:
from ccmm.utils.utils import unfactor_permutations

# models_permuted_pairwise[fixed][permutee] holds the model `permutee` aligned to
# the model `fixed`. Every entry starts out as an independent copy of the
# *permutee's own* model, so that each slot holds the right network even before
# its permuted weights are loaded. (The previous nested comprehension re-bound
# `symbol` and iterated over all models, leaving each slot with a copy of an
# unrelated model and performing far more deep copies than needed.)
models_permuted_pairwise = {
    symbol: {other_symb: copy.deepcopy(models[other_symb]) for other_symb in symbols if other_symb != symbol}
    for symbol in symbols
}
pairwise_permutations = unfactor_permutations(permutations)

# Only canonical pairs (fixed < permutee) receive permuted weights; the reverse
# direction [permutee][fixed] keeps the unpermuted copy created above.
for fixed, permutee in canonical_combinations:
    permuted_params = apply_permutation_to_statedict(
        permutation_spec, pairwise_permutations[fixed][permutee], models[permutee].model.state_dict()
    )
    models_permuted_pairwise[fixed][permutee].model.load_state_dict(permuted_params)

### Check performance of models before and after permutation

In [None]:
# Evaluate each permuted model right before its original counterpart so the two
# test reports appear next to each other in the output.
for symbol, model in models_permuted_to_universe.items():
    trainer.test(model, test_loader)
    trainer.test(models[symbol], test_loader)

## Analyze models as vectors

### Flatten models

In [None]:
# Flatten every network into a single 1-D parameter vector.

# symbol -> flat vector of the original model
flat_models = {
    symb: torch.nn.utils.parameters_to_vector(net.parameters()) for symb, net in models.items()
}

# symbol -> flat vector of the model permuted to the universe
flat_models_permuted_to_universe = {
    symb: torch.nn.utils.parameters_to_vector(net.parameters()) for symb, net in models_permuted_to_universe.items()
}

# fixed symbol -> permutee symbol -> flat vector, mirroring models_permuted_pairwise
flat_models_permuted_pairwise = {
    fixed_symb: {
        permutee_symb: torch.nn.utils.parameters_to_vector(permuted_net.parameters())
        for permutee_symb, permuted_net in permuted_by_symbol.items()
    }
    for fixed_symb, permuted_by_symbol in models_permuted_pairwise.items()
}

### Analyze the norms

In [None]:
from ccmm.utils.utils import to_np


# Per model (in the iteration order of `models`): norm of the original vector,
# norm of the universe-permuted vector, and norm of their difference.
norms = {
    "model": [to_np(torch.norm(flat_models[symb])) for symb in models],
    "permuted": [to_np(torch.norm(flat_models_permuted_to_universe[symb])) for symb in models],
    "diff": [to_np(torch.norm(flat_models[symb] - flat_models_permuted_to_universe[symb])) for symb in models],
}

In [None]:
import pandas as pd

# One row per model symbol, one column per norm kind ("model", "permuted", "diff").
df = pd.DataFrame(norms, index=models.keys())

# The cells were produced by to_np(); coerce them to plain numeric columns for plotting.
df = df.apply(pd.to_numeric)

plt.figure(figsize=(5, 5))
sns.heatmap(df, annot=True, cmap="viridis")
plt.title("Model Norms Comparison")
plt.ylabel("Model Symbol")
plt.show()

In [None]:
# Matrix of cosine similarities: entry (i, j) compares the universe-permuted
# model i with the original model j.
cosine_matrix = np.zeros((len(models), len(models)))

for i, symbol_i in enumerate(models):
    permuted_i = flat_models_permuted_to_universe[symbol_i]
    for j, symbol_j in enumerate(models):
        original_j = flat_models[symbol_j]
        similarity = permuted_i.dot(original_j) / (torch.norm(permuted_i) * torch.norm(original_j))
        # .item() yields a plain Python float, so the NumPy assignment does not rely
        # on implicit coercion of a 0-d tensor that may be attached to the autograd graph.
        cosine_matrix[i, j] = similarity.item()

In [None]:
# Plot the cosine-similarity matrix: rows are universe-permuted models, columns
# are original models, both in the iteration order of `models`.

plt.figure(figsize=(5, 5))
sns.heatmap(cosine_matrix, annot=True, cmap="viridis")
plt.title("Cosine Similarity Matrix")
plt.ylabel("Model Symbol")
plt.show()

In [None]:
# Histogram of all parameter values of model "a".
x = to_np(flat_models["a"])
# Newer Matplotlib releases ship the seaborn style sheets under a "seaborn-v0_8"
# prefix and no longer accept the bare "seaborn" name; pick whichever exists.
seaborn_style = "seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn"
plt.style.use(seaborn_style)
fig, ax = plt.subplots(figsize=(5, 5))

ax.hist(x, bins=500)

#### Experiment: sparsify models

In [None]:
# Toggle for the sparsification experiment in the next cell (disabled by default).
sparsify_models = False

In [None]:
if sparsify_models:
    # NOTE(review): this empty dict is immediately replaced by the comprehension below.
    sparsified_models = {}

    # Zero every entry whose absolute value exceeds one standard deviation of the vector.
    # NOTE(review): the mean is not subtracted, so the test is |w| > std, not |w - mean| > std.
    sparsified_models = {
        symbol: torch.where(torch.abs(flat_model) > torch.std(flat_model), torch.zeros_like(flat_model), flat_model)
        for symbol, flat_model in flat_models.items()
    }
    x = flat_models["a"]
    print(x.shape)
    # Count the entries with absolute value below 1e-2 (the result is computed but not stored or shown).
    torch.sum(torch.abs(x) < 1e-2)

    # Second variant: zero the entries with absolute value below 1e-3, for both the
    # original and the universe-permuted vectors (working on clones).
    flat_models_sparse = {}
    flat_models_perm_sparse = {}
    for symb, model in flat_models.items():

        flat_models_sparse[symb] = torch.clone(model)
        flat_models_sparse[symb][torch.abs(model) < 1e-3] = 0.0

        flat_models_perm_sparse[symb] = torch.clone(flat_models_permuted_to_universe[symb])
        flat_models_perm_sparse[symb][torch.abs(flat_models_permuted_to_universe[symb]) < 1e-3] = 0.0

## Going from 2D to high dimensions and back

### Sampling the 2D plane

In [None]:
# Draw `num_sampled_points` scrambled Sobol points in 2D (seeded with cfg.seed_index)
# and rescale them to the square [-0.5, 0.5] x [-0.5, 0.5].
random_points_plane = qmc.scale(
    qmc.Sobol(d=2, scramble=True, seed=cfg.seed_index).random(num_sampled_points),
    [-0.5, -0.5],
    [0.5, 0.5],
)

# Log the first ten sampled points as a quick sanity check.
pylogger.info(random_points_plane[:10])

In [None]:
# boundaries[0][0] is the lower limit and boundaries[1][0] the upper limit, shared by both axes.
boundaries = [[-0.5], [0.5]]
lower_bounds = np.array([boundaries[0][0], boundaries[0][0]])
upper_bounds = np.array([boundaries[1][0], boundaries[1][0]])

pylogger.info(f"Lower bounds: {lower_bounds}")
pylogger.info(f"Upper bounds: {upper_bounds}")

### Method 1: Barycentric coordinates

#### Utils

In [None]:
def represent_2D_point_pentagon_barycentric_coordinates(x):
    """
    Express a 2D point as a weighted combination of the vertices of the pentagon
    of radius 0.5 centred at the origin.

    x: point in the plane (2, )

    Returns a float32 tensor with one weight per pentagon vertex (5, ). The
    weights are the least-squares solution of the system "weights reproduce x
    and sum to one"; the sum-to-one property is asserted before returning.
    """
    pentagon = get_pentagon_vertices(0.0, 0.0, 0.5)

    # (3, 5): vertex x-coordinates, vertex y-coordinates, and a row of ones that
    # encodes the sum-to-one constraint
    system = np.vstack([pentagon.T, np.ones(5)])

    # (3, ): the point extended with the constant 1 matching the row of ones
    target = np.append(x, 1)

    solution = np.linalg.lstsq(system, target, rcond=None)[0]
    weights = torch.from_numpy(solution)

    assert torch.allclose(torch.sum(weights).float(), torch.tensor(1.0).float())

    return weights.float()


def get_pentagon_vertices(center_x, center_y, radius):
    """
    Return the five vertices of a regular pentagon as an array of shape (5, 2).

    The pentagon is centred at (center_x, center_y); vertex k lies at an angle
    of 72 * k degrees (counter-clockwise from the positive x-axis) at distance
    `radius` from the centre.
    """
    # 72 degrees between consecutive vertices, converted to radians
    angles = [math.radians(72 * k) for k in range(5)]

    return np.array(
        [(radius * math.cos(angle) + center_x, radius * math.sin(angle) + center_y) for angle in angles]
    )


import numpy as np


def create_regular_polygon_vertices(sides, radius=1):
    """Return the vertices of a regular polygon centred at the origin.

    Vertex k lies at angle 2*pi*k/sides and distance `radius` from the origin;
    the result has one [x, y] row per side.
    """
    corners = []
    for k in range(sides):
        theta = 2 * np.pi * k / sides
        corners.append([radius * np.cos(theta), radius * np.sin(theta)])
    return np.array(corners)


def wachspress_coordinates(vertices, p):
    """Calculate Wachspress-style barycentric weights for a point inside a polygon.

    NOTE: a second function with the same name is defined further down in this
    cell and replaces this one once the cell has run.

    Each vertex weight is the (epsilon-padded) area of the triangle spanned by
    the vertex and its two neighbours divided by the product of the
    (epsilon-padded) distances from `p` to the two edges meeting at the vertex.
    Weights whose distance product is not above 1e-5 are set to 0. The weights
    are normalised to sum to one.
    """
    count = len(vertices)
    weights = np.zeros(count)
    for idx in range(count):
        # index -1 wraps around to the last vertex for idx == 0
        prev_v, cur_v, next_v = vertices[idx - 1], vertices[idx], vertices[(idx + 1) % count]
        tri_area = area_of_triangle(prev_v, cur_v, next_v) + 1e-6
        dist_prev = distance_to_edge(p, prev_v, cur_v) + 1e-6
        dist_next = distance_to_edge(p, cur_v, next_v) + 1e-6
        dist_product = dist_prev * dist_next
        if dist_product > 1e-5:
            weights[idx] = tri_area / dist_product

    return weights / np.sum(weights)


def area_of_triangle(v1, v2, v3):
    """Calculate the area of a triangle given its vertices.

    v1, v2, v3: array-likes holding 2D (or 3D) coordinates.

    Returns the triangle area plus a constant 1e-6 offset (kept from the
    original implementation so the result is never exactly zero).
    """
    edge_a = np.asarray(v2, dtype=float) - np.asarray(v1, dtype=float)
    edge_b = np.asarray(v3, dtype=float) - np.asarray(v1, dtype=float)

    if edge_a.shape[-1] == 2:
        # z-component of the cross product, written out explicitly because
        # np.cross on 2-element vectors is deprecated in NumPy 2.0
        cross = edge_a[0] * edge_b[1] - edge_a[1] * edge_b[0]
    else:
        cross = np.cross(edge_a, edge_b)

    return 0.5 * np.linalg.norm(cross) + 1e-6


def distance_to_edge(p, v1, v2):
    """Calculate the distance from point p to the edge (v1, v2).

    The distance is measured to the infinite line through v1 and v2. If the two
    end points coincide, the distance from p to that single point is returned.
    In both cases a constant 1e-6 is added to the result so that callers can
    divide by it safely.

    p, v1, v2: array-likes holding 2D (or 3D) coordinates.
    """
    p = np.asarray(p, dtype=float)
    v1 = np.asarray(v1, dtype=float)
    v2 = np.asarray(v2, dtype=float)

    if np.all(v1 == v2):
        # Degenerate edge: the epsilon is added to the distance itself (it used
        # to be added to every coordinate inside the norm).
        return np.linalg.norm(p - v1) + 1e-6

    edge = v2 - v1
    offset = v1 - p
    if edge.shape[-1] == 2:
        # explicit 2D cross product; np.cross on 2-element vectors is deprecated in NumPy 2.0
        cross = edge[0] * offset[1] - edge[1] * offset[0]
    else:
        cross = np.cross(edge, offset)

    return np.linalg.norm(cross) / np.linalg.norm(edge) + 1e-6


def wachspress_coordinates(vertices, p):
    """Calculate Wachspress-style barycentric weights for a point inside a polygon.

    If `p` coincides exactly with one of the vertices, the one-hot vector for
    that vertex is returned. Otherwise each vertex weight is the area of the
    triangle spanned by the vertex and its two neighbours divided by the
    product of the distances from `p` to the two edges meeting at the vertex,
    and the weights are normalised to sum to one.

    vertices: sequence of polygon vertices, indexable and ordered around the polygon
    p: the query point
    """
    count = len(vertices)

    # A point lying exactly on a vertex gets that vertex's one-hot coordinates.
    for idx, vertex in enumerate(vertices):
        if np.all(p == vertex):
            one_hot = np.zeros(count)
            one_hot[idx] = 1
            return one_hot

    weights = np.zeros(count)
    for idx in range(count):
        # index -1 wraps around to the last vertex for idx == 0
        prev_v, cur_v, next_v = vertices[idx - 1], vertices[idx], vertices[(idx + 1) % count]
        tri_area = area_of_triangle(prev_v, cur_v, next_v)
        dist_prev = distance_to_edge(p, prev_v, cur_v)
        dist_next = distance_to_edge(p, cur_v, next_v)
        dist_product = dist_prev * dist_next
        weights[idx] = tri_area / dist_product if dist_product != 0 else 0

    return weights / np.sum(weights)

In [None]:
# Visual check: the vertices of a regular decagon of radius 0.5 ...
vertices = create_regular_polygon_vertices(10, 0.5)

plt.scatter(vertices[:, 0], vertices[:, 1], c="black", s=100)

# ... together with the circle of the same radius centred at the origin.
circle = plt.Circle((0, 0), 0.5, color="black", fill=False)

plt.gca().add_patch(circle)

#### Represent permuted models as barycentric coordinates wrt the n models


In [None]:
def represent_wrt_models_barycentric(model_to_repr, flat_models):
    """Express a flat model as a least-squares combination of reference models.

    model_to_repr: 1-D tensor with the parameters of the model to represent.
    flat_models: the reference models, either a sequence of 1-D tensors or a
        dict mapping symbol -> 1-D tensor (the dict values are used, in
        insertion order).

    Returns a NumPy array of shape (num_models, 1) with the coefficients. The
    sum-to-one condition is added as one extra equation of the least-squares
    system, so it is encouraged rather than enforced exactly.
    """
    # The notebook keeps most collections as symbol -> vector dicts; iterating a
    # dict directly would yield its keys, so take the values explicitly.
    if isinstance(flat_models, dict):
        flat_models = flat_models.values()

    # (num_params_per_model, num_models)
    A = torch.stack(list(flat_models), dim=1)

    scaling = 1
    # Augment A with an additional row for the sum-to-one constraint
    # (created with A's dtype and device so the concatenation never mixes them)
    ones_row = torch.ones(1, A.shape[1], dtype=A.dtype, device=A.device) * scaling

    # (num_params_per_model + 1, num_models)
    A_augmented = torch.cat([A, ones_row], dim=0)

    # Augment the target model with an additional element for the sum-to-one constraint
    # (num_params_per_model + 1,)
    constraint_target = torch.tensor([scaling], dtype=model_to_repr.dtype, device=model_to_repr.device)
    target_augmented = torch.cat([model_to_repr, constraint_target])

    # Solve the linear system (least squares): find z such that A z ~= x,
    # where x is the (augmented) target model
    barycentric_coords = torch.linalg.lstsq(A_augmented, target_augmented.unsqueeze(1)).solution

    pylogger.info(barycentric_coords)

    return barycentric_coords.cpu().detach().numpy()

In [None]:
# USE WHEN TRYING TO REPRESENT PERMUTED MODELS WRT ORIGIN MODELS
def get_model_2D_coordinates_barycentric(flat_models):
    """Place the original models on a regular polygon and the permuted models inside it.

    flat_models: the reference flat models, forwarded unchanged to
        represent_wrt_models_barycentric.

    Reads the notebook globals `num_models`, `symbols_to_seed` and
    `flat_models_permuted_to_universe`.

    Returns two dicts keyed by symbol: the polygon vertex assigned to each
    original model, and the 2D position of each universe-permuted model,
    computed as the combination of the vertices weighted by its coefficients
    with respect to the reference models.
    """
    polygon = create_regular_polygon_vertices(num_models, 0.5)

    original_positions = dict.fromkeys(symbols_to_seed.keys())
    permuted_positions = dict.fromkeys(symbols_to_seed.keys())

    for vertex_index, (symbol, permuted_vector) in enumerate(flat_models_permuted_to_universe.items()):
        # (num_models, 1) coefficients of the permuted model wrt the reference models
        coefficients = represent_wrt_models_barycentric(permuted_vector, flat_models)

        pylogger.info(f"{symbol} baryc coords: {coefficients.sum()}")

        original_positions[symbol] = polygon[vertex_index]
        # weight every vertex by its coefficient and sum over the vertices -> (2, )
        permuted_positions[symbol] = (coefficients * polygon).sum(axis=0)

    return original_positions, permuted_positions

In [None]:
# USE WHEN CONSIDERING ALL THE 2*N MODELS (PERM AND ORIGINS) TO REPRESENT EACH MODEL AS A VERTEX OF THE 2*N REGULAR POLYGON
def get_model_2D_coordinates_barycentric_all_models(all_flat_models, num_models, num_pairwise_perms=0):
    """Assign every model (original, permuted, pairwise-permuted) to a polygon vertex.

    all_flat_models: sequence of flat models ordered as the originals first, then
        the universe-permuted models, then the pairwise-permuted models.
    num_models: number of original models.
    num_pairwise_perms: number of pairwise-permuted models appended at the end.

    Returns a dict mapping the LaTeX label of each model to the 2D vertex it is
    assigned on a regular polygon with 2 * num_models + num_pairwise_perms
    vertices and radius 0.45.

    Raises ValueError if more models are passed than labels can be generated for.
    """
    # A, B, C, ... (no longer limited to five models)
    model_letters = [chr(ord("A") + idx) for idx in range(num_models)]

    symbol_names = [r"\Theta_" + letter for letter in model_letters]
    perm_symbol_names = [r"\pi(\Theta_" + letter + ")" for letter in model_letters]

    # Labels for consecutive pairs B->A, C->B, ...; only as many as requested are
    # kept, so the default num_pairwise_perms=0 yields an empty list and no label
    # is left without a model (a `None` entry would break the plotting code).
    pairwise_perm_symbol_names = [
        r"\pi_{" + f"{letter}->{previous}" + r"}(\Theta_" + letter + ")"
        for previous, letter in zip(model_letters, model_letters[1:])
    ][: max(num_pairwise_perms, 0)]

    all_symbol_names = symbol_names + perm_symbol_names + pairwise_perm_symbol_names

    if len(all_flat_models) > len(all_symbol_names):
        raise ValueError(
            f"Got {len(all_flat_models)} models but only {len(all_symbol_names)} labels "
            f"for num_models={num_models}, num_pairwise_perms={num_pairwise_perms}"
        )

    origins = create_regular_polygon_vertices(2 * num_models + num_pairwise_perms, 0.45)

    model_2D_repr = {symbol_name: None for symbol_name in all_symbol_names}

    for model_num, flat_model in enumerate(all_flat_models):

        # Coefficients are computed for logging only; the 2D position is the vertex itself.
        model_baryc_coordinates = represent_wrt_models_barycentric(flat_model, all_flat_models)

        pylogger.info(f"{all_symbol_names[model_num]} baryc coords: {model_baryc_coordinates.sum()}")

        model_2D_repr[all_symbol_names[model_num]] = origins[model_num]

    return model_2D_repr


# Originals first, then the universe-permuted models, then one pairwise-permuted model.
all_flat_models = [*flat_models.values(), *flat_models_permuted_to_universe.values()]
# B permuted onto A. The pairwise dict is indexed [fixed][permutee] and only
# canonical pairs (fixed < permutee) are loaded with permuted weights, so the
# model matching the label "B->A" lives under ["a"]["b"]; ["b"]["a"] never
# receives permuted weights.
all_flat_models.append(flat_models_permuted_pairwise["a"]["b"])

model_2D_repr = get_model_2D_coordinates_barycentric_all_models(all_flat_models, num_models, num_pairwise_perms=1)

### Collect the test loss for random samples in the 2D plane 

In [None]:
import numpy as np


def evaluate_model_interp_on_point(point, flat_models, model, trainer, test_loader):
    """Evaluate the test loss of the model obtained at a 2D point of the plane.

    The flat models are placed on the vertices of a regular polygon of radius
    0.45 (one vertex per model, in the given order). The point is converted to
    barycentric weights with respect to those vertices, the flat models are
    combined with these weights, and the result is loaded into `model` (which
    is modified in place) and tested.

    point: 2D coordinates of the point to evaluate.
    flat_models: sequence of flat parameter vectors of equal length.
    model: module whose `.model` receives the combined parameters.

    Returns the "loss/test" entry of the first test result.
    """
    polygon = create_regular_polygon_vertices(len(flat_models), radius=0.45)

    # (num_models, 1) column of weights, broadcast over the parameter axis below
    weights = torch.tensor(wachspress_coordinates(polygon, point)).unsqueeze(1)

    # (num_models, num_params_per_model) -> weighted sum over the models
    combined_params = (torch.stack(flat_models) * weights).sum(dim=0)

    model.model.load_state_dict(vector_to_state_dict(combined_params, model.model))

    return trainer.test(model, test_loader, verbose=False)[0]["loss/test"]

In [None]:
# Scratch copy of model "a": its weights are overwritten for every sampled point.
model = copy.deepcopy(models["a"])

# One test loss per sampled 2D point, in the order of `random_points_plane`.
test_loss_for_random_pts_plane = np.array(
    [
        evaluate_model_interp_on_point(point, all_flat_models, model, trainer, test_loader)
        for point in tqdm(random_points_plane)
    ]
)

# Show the first ten losses as a sanity check.
test_loss_for_random_pts_plane[:10]

In [None]:
# from functools import partial
# from multiprocessing import Pool

# pool = Pool() #defaults to number of available CPU's

# eval_func = partial(evaluate_model_interp_on_point, all_flat_models, model, trainer, test_loader)

# results = np.zeros(len(random_points_plane))
# for ind, res in enumerate(tqdm(pool.imap(eval_func, iter(random_points_plane)), total=len(random_points_plane))):
#     results[ind] = res

### Method 2: Reference models as basis

In [None]:
def get_basis_vectors(origin_model, basis_model_1, basis_model_2):
    """Build two unit basis vectors for the plane through three flat models.

    origin_model: flat vector used as the origin of the plane.
    basis_model_1, basis_model_2: flat vectors that, relative to the origin,
        span the plane.

    Returns (basis1_normed, basis2_normed, scale1, scale2) where
    - basis1_normed is the normalised direction origin -> basis_model_1,
    - basis2_normed is the normalised direction origin -> basis_model_2 after
      subtracting `project_onto(basis2, basis1_normed)` (presumably its
      component along basis1 — confirm against `project_onto`),
    - scale1 / scale2 are the norms of the two raw difference vectors; scale2
      is measured before the projection is removed.
    """
    # torch.norm is used explicitly: no `norm` function is imported in this
    # notebook, and an earlier cell binds the name `norm` to a tensor.
    basis1 = basis_model_1 - origin_model
    scale1 = torch.norm(basis1)
    basis1_normed = normalize_unit_norm(basis1)

    basis2 = basis_model_2 - origin_model
    scale2 = torch.norm(basis2)
    basis2 = basis2 - project_onto(basis2, basis1_normed)
    basis2_normed = normalize_unit_norm(basis2)

    return basis1_normed, basis2_normed, scale1, scale2


# basis_model_1, basis_model_2, scale_1, scale_2 = get_basis_vectors(origin_model=flat_models['a'], basis_model_1=flat_models['b'], basis_model_2=flat_models['c'])

In [None]:
import numpy as np
from multiprocessing import Pool


def evaluate_model_interp_on_point(
    point, basis_model_1, basis_model_2, origin_model, ref_model, trainer, test_loader, scale_1, scale_2
):
    """Evaluate the test loss of the model located at a 2D point of the basis plane.

    NOTE: this definition replaces the barycentric function of the same name
    defined earlier in the notebook once this cell has run.

    The parameters are reconstructed as
        origin_model + scale_1 * basis_model_1 * point[0] + scale_2 * basis_model_2 * point[1]
    loaded into `ref_model` (modified in place) and tested.

    Returns the "loss/test" entry of the first test result.
    """
    # displacement from the origin model within the plane
    plane_offset = scale_1 * basis_model_1 * point[0] + scale_2 * basis_model_2 * point[1]
    reconstructed = origin_model + plane_offset

    ref_model.model.load_state_dict(vector_to_state_dict(reconstructed, ref_model.model))

    return trainer.test(ref_model, test_loader, verbose=False)[0]["loss/test"]


# ref_model = copy.deepcopy(models['a'])
# origin_model = flat_models['a']

# eval_results = np.array([evaluate_model_interp_on_point(point, scale_1=scale_1, scale_2=scale_2, basis_model_1=basis_model_1, basis_model_2=basis_model_2, origin_model=origin_model, ref_model=ref_model, trainer=trainer, test_loader=test_loader) for point in tqdm(random_points_plane)])

#### Represent models 2D

In [None]:
def represent_wrt_models(model_to_repr, origin_model, basis1, basis2, scale_1, scale_2):
    """Return the 2D coordinates of a flat model in the plane spanned by two basis vectors.

    The displacement of `model_to_repr` from `origin_model` is projected (dot
    product) onto `basis1` and `basis2`, and each projection is divided by the
    corresponding scale.

    Returns a NumPy array [x, y].
    """
    displacement = model_to_repr - origin_model

    coordinates = [
        torch.dot(displacement, axis) / scale for axis, scale in ((basis1, scale_1), (basis2, scale_2))
    ]

    return torch.stack(coordinates).detach().cpu().numpy()

In [None]:
def get_model_2D_coordinates(flat_models, flat_perm_models, basis_model_1, basis_model_2, scale_1, scale_2):
    """Compute the 2D coordinates of every original and permuted model.

    flat_models: dict symbol -> flat original model; must contain "a", which is
        used as the origin of the plane.
    flat_perm_models: dict symbol -> flat permuted model.
    basis_model_1, basis_model_2, scale_1, scale_2: plane basis and scales
        forwarded to represent_wrt_models.

    Reads the notebook global `symbols_to_seed` to initialise the result keys.

    Returns two dicts keyed by symbol with the 2D coordinates of the original
    and of the permuted models; symbols absent from `flat_perm_models` keep None.
    """

    def _to_plane(flat_vector):
        # coordinates relative to model "a" in the given basis
        return represent_wrt_models(
            flat_vector,
            origin_model=flat_models["a"],
            basis1=basis_model_1,
            basis2=basis_model_2,
            scale_1=scale_1,
            scale_2=scale_2,
        )

    original_coords = dict.fromkeys(symbols_to_seed.keys())
    permuted_coords = dict.fromkeys(symbols_to_seed.keys())

    for symbol, permuted_vector in flat_perm_models.items():
        original_coords[symbol] = _to_plane(flat_models[symbol])
        permuted_coords[symbol] = _to_plane(permuted_vector)

    return original_coords, permuted_coords


# model_2D_repr, universe_model_2D_repr = get_model_2D_coordinates(flat_models, flat_models_permuted_to_universe, basis_model_1, basis_model_2, scale_1, scale_2)

In [None]:
# pylogger.info(model_2D_repr)
# pylogger.info(universe_model_2D_repr)

## Plot

### Create the 2D grid of points and their corresponding losses

In [None]:
# Regular grid coordinates spanning the sampled square on both axes
# (np.linspace's default number of samples is used).
xi = np.linspace(boundaries[0][0], boundaries[1][0])
yi = np.linspace(boundaries[0][0], boundaries[1][0])

# Linearly interpolate the data (x, y) on a grid defined by (xi, yi).
triang = tri.Triangulation(random_points_plane[:, 0], random_points_plane[:, 1])

# We need to cap the maximum loss value so that the contouring is not completely saturated by wildly large losses
# (losses above 5 are clipped to 5; there is no lower clip).
interpolator = tri.LinearTriInterpolator(triang, np.clip(test_loss_for_random_pts_plane, None, 5))

# interpolator = tri.LinearTriInterpolator(triang, jnp.log(jnp.minimum(1.5, eval_results[:, 0])))
# Interpolated loss on the grid.
zi = interpolator(*np.meshgrid(xi, yi))

In [None]:
# Loss landscape figure: thin grey contour lines over a filled contour plot,
# with every model marked and labelled at its 2D position.
plt.figure(figsize=(10, 10))
num_levels = 13

plt.contour(xi, yi, zi, levels=num_levels, linewidths=0.25, colors="grey", alpha=0.5)

# cmap = truncate_colormap(plt.get_cmap(cmap_name), 0.0, 1)

plt.contourf(xi, yi, zi, levels=num_levels, cmap=plt.get_cmap(cmap_name), extend="both")

# Semi-transparent rounded grey box drawn behind each model label.
label_bboxes = dict(facecolor="tab:grey", boxstyle="round", edgecolor="none", alpha=0.5)

print(model_2D_repr)
# Every value of model_2D_repr must be a 2D point: a None entry fails on point[0].
for symbol, point in model_2D_repr.items():
    plt.scatter(point[0], point[1], marker="x", color="black", zorder=10)
    # The dict keys are LaTeX fragments; wrap them in bold math mode for the label.
    plt.text(
        point[0] - 0.05,
        point[1] + 0.05,
        r"${\bf " + symbol + r"}$",
        color="white",
        fontsize=24,
        bbox=label_bboxes,
        horizontalalignment="right",
        verticalalignment="top",
    )


# Title box at the top centre of the plot.
box_x = 0
box_y = 0.5
title_text = r"$C^2M^2$"

# Draw box only
# (the text itself is fully transparent; it only sizes the box)
plt.text(
    box_x,
    box_y,
    title_text,
    color=(0.0, 0.0, 0.0, 0.0),
    fontsize=24,
    horizontalalignment="center",
    verticalalignment="center",
    bbox=dict(boxstyle="round", fc=(1, 1, 1, 1), ec="black", pad=0.4),
)
# Draw text only
# (shifted slightly downwards relative to the box)
plt.text(
    box_x,
    box_y - 0.0115,
    title_text,
    color=(0.0, 0.0, 0.0, 1.0),
    fontsize=24,
    horizontalalignment="center",
    verticalalignment="center",
)


# plt.colorbar()
plt.xlim(-0.5, 0.5)
plt.ylim(-0.5, 0.5)
#   plt.xlim(-0.9, 1.9)
#   plt.ylim(-0.9, 1.9)
plt.xticks([])
plt.yticks([])
# NOTE(review): axis("equal") is applied after the limits were set and may adjust them — confirm the final extent.
plt.axis("equal")
# plt.tight_layout()
plt.savefig("resnet_cifar_loss_contour.png", dpi=300)

In [None]:
def linear_interpolation(model_a, model_b, lamb):
    """Return the convex combination (1 - lamb) * model_a + lamb * model_b.

    lamb = 0 gives model_a and lamb = 1 gives model_b.

    NOTE: this definition shadows the `linear_interpolation` imported from
    ccmm.utils.utils at the top of the notebook.
    """
    weight_a = 1 - lamb
    return weight_a * model_a + lamb * model_b


def get_interp_loss_curve(lambdas, model_a, model_b, ref_model):
    """Return the test loss along the linear path between two flat models.

    lambdas: iterable of interpolation coefficients.
    model_a, model_b: flat parameter vectors (the end points of the path).
    ref_model: module whose `.model` is overwritten in place for every coefficient.

    Uses the notebook globals `trainer` and `test_loader`.

    Returns a list with the "loss/test" value for each coefficient, in order.
    """

    def _loss_at(lamb):
        # load the interpolated parameters into the reference model and test it
        flat_params = linear_interpolation(model_a, model_b, lamb)
        ref_model.model.load_state_dict(vector_to_state_dict(flat_params, ref_model.model))
        return trainer.test(ref_model, test_loader, verbose=False)[0]["loss/test"]

    return [_loss_at(lamb) for lamb in lambdas]

In [None]:
# Scratch model whose weights are overwritten for every interpolation step.
ref_model = copy.deepcopy(models["a"])

# 25 evenly spaced interpolation coefficients in [0, 1].
lambdas = np.linspace(0, 1, 25)
# A, B
interp_ab = get_interp_loss_curve(lambdas, flat_models["a"], flat_models["b"], ref_model)

# A, P_AB(B)
# (pairwise dict is indexed [fixed][permutee]: B aligned to A)
interp_a_bperm_to_a = get_interp_loss_curve(
    lambdas, flat_models["a"], flat_models_permuted_pairwise["a"]["b"], ref_model
)

# P(B), P_AB(B)
interp_b_uni_bperm_to_a = get_interp_loss_curve(
    lambdas, flat_models_permuted_to_universe["b"], flat_models_permuted_pairwise["a"]["b"], ref_model
)

# P(A), P(B)
# (the path starts at P(B) for lambda = 0 and ends at P(A) for lambda = 1)
interp_b_uni_a_uni = get_interp_loss_curve(
    lambdas, flat_models_permuted_to_universe["b"], flat_models_permuted_to_universe["a"], ref_model
)

In [None]:
# B, P(B)
# NOTE(review): despite the "a_uni" in its name, this curve goes from B to the universe-permuted B.
interp_b_a_uni = get_interp_loss_curve(lambdas, flat_models["b"], flat_models_permuted_to_universe["b"], ref_model)

In [None]:
# Overlay all interpolation loss curves computed above on one figure.
plt.figure()
plt.plot(lambdas, interp_ab, marker="o", label="A, B")
plt.plot(lambdas, interp_a_bperm_to_a, marker="o", label="A, P_AB(B)")
plt.plot(lambdas, interp_b_uni_bperm_to_a, marker="o", label="P_univ(B), P_AB(B)")
plt.plot(lambdas, interp_b_uni_a_uni, marker="o", label="P_univ(B), P_univ(A)")
plt.plot(lambdas, interp_b_a_uni, marker="o", label="B, P_univ(B)")
plt.legend()

## Hic sunt leones

In [None]:
# Scratch check: compare interpolating between models "a" and "e" directly in
# parameter space with interpolating their 2D coordinates and mapping back.
# NOTE(review): basis_model_1, basis_model_2, scale_1, scale_2 and origin_model are
# only assigned in commented-out lines earlier in this notebook; they must be
# defined before this cell can run. Model "e" requires at least five models.
lambdas = np.linspace(0, 1, 25)

interp_results = []
interp_params = []

model_a_2D = represent_wrt_models(
    flat_models["a"],
    origin_model=flat_models["a"],
    basis1=basis_model_1,
    basis2=basis_model_2,
    scale_1=scale_1,
    scale_2=scale_2,
)
model_e_2D = represent_wrt_models(
    flat_models["e"],
    origin_model=flat_models["a"],
    basis1=basis_model_1,
    basis2=basis_model_2,
    scale_1=scale_1,
    scale_2=scale_2,
)

norms = []
# "N_interp": loss of the model interpolated in the full parameter space.
# "2D_interp": loss of the model reconstructed from the interpolated 2D point.
results = {"2D_interp": [], "N_interp": []}
for lamb in lambdas:
    # interpolation in the full parameter space (lamb = 0 gives "e", lamb = 1 gives "a")
    interp_model = flat_models["a"] * lamb + flat_models["e"] * (1 - lamb)
    interp_params.append(interp_model)

    new_params = vector_to_state_dict(interp_model, ref_model.model)
    ref_model.model.load_state_dict(new_params)
    res = trainer.test(ref_model, test_loader, verbose=False)[0]["loss/test"]
    # the two result keys used to be swapped
    results["N_interp"].append(res)

    # interpolation of the 2D coordinates, mapped back to parameter space
    interp_point = model_a_2D * lamb + model_e_2D * (1 - lamb)
    new_params_reconstructed = origin_model + (
        scale_1 * basis_model_1 * interp_point[0] + scale_2 * basis_model_2 * interp_point[1]
    )

    ref_model.model.load_state_dict(vector_to_state_dict(new_params_reconstructed, ref_model.model))
    res = trainer.test(ref_model, test_loader, verbose=False)[0]["loss/test"]
    results["2D_interp"].append(res)

    # distance between the two parameter vectors at this coefficient
    norms.append(torch.norm(new_params_reconstructed - interp_model).detach().cpu().numpy())

plt.figure()
plt.plot(lambdas, norms, marker="o")

In [None]:
# Convert the collected losses to arrays and plot both curves against lambda
# ("2D_interp" with circle markers, "N_interp" with cross markers).
results["2D_interp"] = np.array(results["2D_interp"])
results["N_interp"] = np.array(results["N_interp"])

plt.figure()
plt.plot(lambdas, results["2D_interp"], marker="o")
plt.plot(lambdas, results["N_interp"], marker="x")

In [None]:
# Cosine similarity between models "c" and "b". torch.norm is spelled out because
# no `norm` function is imported in this notebook (and an earlier cell binds the
# name `norm` to a tensor).
(flat_models["c"] / torch.norm(flat_models["c"])) @ (flat_models["b"] / torch.norm(flat_models["b"]))

In [None]:
# NOTE(review): model_c_2D is not assigned anywhere in the visible notebook; it
# must be computed (e.g. like model_a_2D) before this cell can run.
pylogger.info(model_a_2D)
pylogger.info(model_c_2D)

# 2D points along the segment between the 2D coordinates of "a" and "c".
interps_on_plane = []
for lamb in lambdas:
    interps_on_plane.append(model_a_2D * lamb + model_c_2D * (1 - lamb))

    # NOTE(review): computed but never used; it also weights "a" with (1 - lamb),
    # the opposite convention of the 2D interpolation on the line above.
    model_interp = flat_models["a"] * (1 - lamb) + flat_models["c"] * lamb


interps_on_plane

In [None]:
# Project every interpolated parameter vector onto the 2D plane.
# enumerate() supplies the index: zip() alone yields (par, lambd) pairs, which
# cannot be unpacked into `ind, (par, lambd)`.
for ind, (par, lambd) in enumerate(zip(interp_params, lambdas)):

    interp_par_2D = represent_wrt_models(
        par, origin_model=flat_models["a"], basis1=basis_model_1, basis2=basis_model_2, scale_1=scale_1, scale_2=scale_2
    )

In [None]:
# NOTE(review): mean_abc is not assigned anywhere in the visible notebook; it
# must be defined (presumably the average of several flat models — confirm) before this cell can run.
new_params = vector_to_state_dict(mean_abc, ref_model.model)

ref_model.model.load_state_dict(new_params)

# NOTE(review): this rebinds test_loss_for_random_pts_plane, replacing the array of
# per-point losses computed earlier with the raw return value of trainer.test.
test_loss_for_random_pts_plane = trainer.test(ref_model, test_loader, verbose=False)