## Imports

In [None]:
import copy
import logging
import math
from pathlib import Path
from typing import Dict
from functools import partial
from pprint import pprint

import hydra
import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import pytorch_lightning
import seaborn as sns
import torch  # noqa
import wandb
from hydra.utils import instantiate
from matplotlib import tri
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from omegaconf import DictConfig
from pytorch_lightning import LightningModule
from scipy.stats import qmc
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler
from tqdm import tqdm

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import seed_index_everything
from nn_core.model_logging import NNLogger

import ccmm  # noqa
from ccmm.matching.frank_wolfe_sync_matching import frank_wolfe_synchronized_matching
from ccmm.matching.utils import (
    apply_permutation_to_statedict,
    get_all_symbols_combinations,
    load_permutations,
    perm_indices_to_perm_matrix,
    perm_matrix_to_perm_indices,
    plot_permutation_history_animation,
    restore_original_weights,
)
from ccmm.utils.utils import (
    fuse_batch_norm_into_conv,
    load_model_from_info,
    map_model_seed_to_symbol,
    normalize_unit_norm,
    project_onto,
    save_factored_permutations,
    vector_to_state_dict,
)
from ccmm.matching.weight_matching import solve_linear_assignment_problem
from ccmm.matching.utils import unfactor_permutations

In [None]:
matplotlib.rcParams["font.family"] = "serif"
sns.set_context("talk")
# NOTE(review): text.usetex routes all figure text through LaTeX, so a LaTeX installation
# must be available wherever this notebook runs.
matplotlib.rcParams["text.usetex"] = True
# Colormap used by the permutation-matrix plots.
cmap_name = "coolwarm_r"

# Only surface WARNING and above from Lightning / torch loggers (trainer.test is called many times).
logging.getLogger("lightning.pytorch").setLevel(logging.WARNING)
logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("pytorch_lightning.accelerators.cuda").setLevel(logging.WARNING)
pylogger = logging.getLogger(__name__)

from ccmm.utils.plot import Palette

# Named colors loaded from the project's misc/ folder; the bare expression displays the palette.
palette = Palette(f"{PROJECT_ROOT}/misc/palette2.json")
palette

In [None]:
from typing import List


def average_models(model_params, reduction_fn, coeffs=None):
    """Combine several state dicts key by key with a reduction function.

    Args:
        model_params: a list/tuple of state dicts, or any object with a
            ``.values()`` method (e.g. a ``{symbol: state_dict}`` dict), in which
            case only its values are used.
        reduction_fn: callable applied to the tensor obtained by stacking the
            (optionally scaled) entries of all models along dim 0, e.g.
            ``partial(torch.mean, dim=0)``.
        coeffs: optional per-model multipliers applied to each model's entries
            before stacking; defaults to 1 for every model. Note that with a mean
            reduction the result is ``sum_i coeffs[i] * p_i / num_models``.

    Returns:
        dict mapping every key of the first state dict to the reduced tensor.

    Raises:
        ValueError: if ``coeffs`` does not have one entry per model.
    """
    if not isinstance(model_params, (list, tuple)):
        # e.g. a {symbol: state_dict} mapping
        model_params = list(model_params.values())

    if coeffs is None:
        coeffs = [1] * len(model_params)
    elif len(coeffs) != len(model_params):
        raise ValueError(f"Expected {len(model_params)} coefficients, got {len(coeffs)}")

    # NOTE(review): torch.mean only accepts floating tensors; integer entries in a state
    # dict (if any) would make a mean reduction fail.
    return {
        k: reduction_fn(torch.stack([p[k] * c for p, c in zip(model_params, coeffs)]))
        for k in model_params[0].keys()
    }


def trimmed_mean(tensors, trim_ratio=0.1):
    """Mean along dim 0 after dropping the smallest and largest values.

    Args:
        tensors: tensor whose dim 0 indexes the values to average (e.g. stacked models).
        trim_ratio: fraction of values removed from *each* end.

    Returns:
        Tensor of shape ``tensors.shape[1:]``.

    The number of trimmed values is capped so that at least one value (two for an
    even count) always survives; a large ``trim_ratio`` therefore degrades to the
    median-like central value(s) instead of averaging an empty slice (NaN).
    """
    num_values = tensors.size(0)
    max_trim = (num_values - 1) // 2
    num_to_trim = max(0, min(int(trim_ratio * num_values), max_trim))

    sorted_tensors = tensors.sort(dim=0).values
    trimmed_tensors = sorted_tensors[num_to_trim : num_values - num_to_trim]
    return trimmed_tensors.mean(dim=0)


def winsorize(tensor, limits=(0.2, 0.8)):
    """Mean along dim 0 after clipping every column to a quantile range.

    Args:
        tensor: floating-point tensor; quantiles and the mean are taken along dim 0.
        limits: ``(lower_q, upper_q)`` quantile levels in [0, 1]. A tuple default
            avoids a shared mutable default; lists are still accepted.

    Returns:
        Tensor of shape ``tensor.shape[1:]`` with the mean of the clipped values.
    """
    # torch.quantile requires q to have the same dtype (and device) as the input.
    q = torch.tensor(limits, dtype=tensor.dtype, device=tensor.device)
    lower, upper = torch.quantile(tensor, q, dim=0)
    clipped = torch.clamp(tensor, min=lower, max=upper)
    return clipped.mean(dim=0)


def robust_mean(tensor, threshold=3.5):
    """Mean along dim 0 that ignores outliers flagged by the modified z-score.

    Each value's modified z-score is ``0.6745 * (x - median) / MAD``, where MAD is
    the median absolute deviation along dim 0 (a zero MAD is replaced by 1).
    Values with ``|z| >= threshold`` are excluded from the average.
    """
    center = torch.median(tensor, dim=0).values
    deviations = tensor - center

    mad = torch.median(deviations.abs(), dim=0).values
    # a zero MAD would divide by zero: use 1 there instead
    mad = torch.where(mad == 0, torch.ones_like(mad), mad)

    modified_z = 0.6745 * deviations / mad
    inliers = (modified_z.abs() < threshold).float()

    # average only over the entries that survived the outlier filter
    return (tensor * inliers).sum(dim=0) / inliers.sum(dim=0)

## Configuration

In [None]:
# IPython autoreload: re-import edited project modules before executing each cell.
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose
from typing import Dict, List

# Clear any global Hydra state before initializing (presumably so this cell can be re-run in the
# same kernel without Hydra complaining that it is already initialized).
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path=str("../conf"), job_name="merge_n_models")

In [None]:
cfg = compose(config_name="merge_n_models", overrides=[])  # full experiment config, no overrides

In [None]:
# Keep the full config as `core_cfg`; from here on `cfg` is only its `matching` sub-config.
core_cfg = cfg  # NOQA
cfg = cfg.matching

# nn_core helper, presumably seeding all RNGs from the config's seed index — confirm in nn_core.
seed_index_everything(cfg)

## Hyperparameters

In [None]:
# Sizes of the fixed prefixes of the train / test splits used throughout the notebook.
num_train_samples = 5000
num_test_samples = 5000

## Load dataset

In [None]:
# NOTE(review): the *test* transform is used for both splits, presumably so the training subset
# (used for repair / coefficient search) is seen without augmentation — confirm.
transform = instantiate(core_cfg.dataset.test.transform)

train_dataset = instantiate(core_cfg.dataset.train, transform=transform)
test_dataset = instantiate(core_cfg.dataset.test, transform=transform)

# The first `num_train_samples` training samples, served as a single batch
# (batch size tied to the subset size instead of a duplicated literal).
train_subset = Subset(train_dataset, list(range(num_train_samples)))
train_loader = DataLoader(train_subset, batch_size=num_train_samples, num_workers=cfg.num_workers)

# The first `num_test_samples` test samples, in batches of 1000.
test_subset = Subset(test_dataset, list(range(num_test_samples)))

test_loader = DataLoader(test_subset, batch_size=1000, num_workers=cfg.num_workers)

In [None]:
# Lightning trainer used via trainer.test; progress bar and model summary disabled to keep output short.
trainer = instantiate(cfg.trainer, enable_progress_bar=False, enable_model_summary=False)

## Load models

In [None]:
from ccmm.utils.utils import load_model_from_artifact

run = wandb.init(project=core_cfg.core.project_name, entity=core_cfg.core.entity, job_type="matching")

# {a: 1, b: 2, c: 3, ..}  i.e. symbol -> seed
symbols_to_seed: Dict[str, int] = {map_model_seed_to_symbol(seed): seed for seed in cfg.model_seeds}

# W&B artifact id (version v0) of the trained model for a given seed.
artifact_path = (
    lambda seed: f"{core_cfg.core.entity}/{core_cfg.core.project_name}/{core_cfg.dataset.name}_{core_cfg.model.model_identifier}_{seed}:v0"
)

# {a: model_a, b: model_b, c: model_c, ..}
models: Dict[str, LightningModule] = {
    map_model_seed_to_symbol(seed): load_model_from_artifact(run, artifact_path(seed)) for seed in cfg.model_seeds
}

num_models = len(models)

pylogger.info(f"Using {num_models} models with architecture {core_cfg.model.model_identifier}")

In [None]:
# always permute the model having larger character order, i.e. c -> b, b -> a and so on ...
# "o" is the model trained over all the dataset, so it takes no part in the matching.
symbols = set(symbols_to_seed) - {"o"}

sorted_symbols = sorted(symbols)

# pairs of symbols as produced by the project helper
all_combinations = get_all_symbols_combinations(symbols)
# keep only the canonical orientation (a, b), (a, c), (b, c), .. and not (b, a), (c, a) etc
canonical_combinations = [(src, tgt) for src, tgt in all_combinations if src < tgt]

## Matching

In [None]:
# Reductions accepted by `average_models`: each collapses the stacked per-model tensors along dim 0.
red_fns = {
    "mean": partial(torch.mean, dim=0),
    "trimmed": trimmed_mean,
    "winsor": winsorize,
    "filter": robust_mean,
}

# reduction used for the main merge
chosen_red_fn = red_fns["mean"]

In [None]:
pylogger.info(f"Matching the following model pairs: {canonical_combinations}")  # canonical (source < target) pairs

### Load permutation specification

In [None]:
permutation_spec_builder = instantiate(core_cfg.model.permutation_spec_builder)
permutation_spec = permutation_spec_builder.create_permutation()

# sanity check: the spec must describe exactly the entries of the model's state dict
ref_model = next(iter(models.values()))
assert set(permutation_spec.layer_and_axes_to_perm) == set(ref_model.model.state_dict())

### Run matching

In [None]:
# Configured merger, instantiated with the permutation spec.
# NOTE(review): only `model_merger.name` is used in the visible cells; the matching below calls
# frank_wolfe_synchronized_matching directly — confirm the merger is still needed.
model_merger = instantiate(cfg.merger, permutation_spec=permutation_spec)

pylogger.info(f"Merger: {model_merger.name}")

In [None]:
# Hyperparameters of the Frank-Wolfe synchronized matching.
max_iter = 200
initialization_method = "identity"
# presumably makes the solver return relaxed (soft) permutation matrices — confirm in
# frank_wolfe_synchronized_matching
keep_soft_perms = True

# NOTE(review): `symbols` is rebuilt here from every loaded model, while `canonical_combinations`
# was built without the "o" model — confirm "o" is never among the loaded models.
symbols = list(models.keys())

# Container whose weights are overwritten by the merge later on.
merged_model = copy.deepcopy(models[symbols[0]])

# perm_indices: {symbol: {perm_name: permutation}}
# opt_infos: solver diagnostics, e.g. "obj_values", "step_sizes", "perm_history".
perm_indices, opt_infos = frank_wolfe_synchronized_matching(
    models=models,
    perm_spec=permutation_spec,
    symbols=symbols,
    combinations=canonical_combinations,
    max_iter=max_iter,
    initialization_method=initialization_method,
    keep_soft_perms=keep_soft_perms,
    verbose=False,
)

### Plot intermediate permutations

In [None]:
# Toggle for the permutation-history plot.
# NOTE(review): this rebinds the name of the `plot_permutation_history_animation` helper imported
# from ccmm.matching.utils, so that function is no longer reachable under its name afterwards.
plot_permutation_history_animation = False

In [None]:
if plot_permutation_history_animation:
    # One figure per permutation group: one row per model, one column per stored iterate
    # of opt_infos["perm_history"] (iterates k .. K-1).
    perm_symbols = list(perm_indices[symbols[0]].keys())
    # perm_symbols = ['P_bg0', 'P_blockgroup3.block2_inner']
    k = 0
    K = 5
    for perm_symb in perm_symbols:
        perms = {symb: [perm[symb][perm_symb] for perm in opt_infos["perm_history"][k:K]] for symb in symbols}

        # the history may contain fewer than K - k iterates
        num_shown = len(perms[symbols[0]])
        if num_shown == 0:
            continue

        # one row per model (not a fixed 5); squeeze=False keeps `ax` 2-D for a single row/column
        fig, ax = plt.subplots(len(symbols), num_shown, figsize=(20, 20), squeeze=False)

        for i in range(num_shown):
            for j, symb in enumerate(symbols):
                ax[j, i].imshow(perms[symb][i].cpu(), cmap="gray")
                ax[j, i].axis("off")
                ax[j, i].set_title(symb)
        plt.show()

In [None]:
# names of the permutation groups, read from model "a"
perm_symbols = list(perm_indices["a"])
print(perm_symbols)

### Plot convergence values

In [None]:
# Make sure the output folder used by every savefig in this notebook exists.
Path("figures").mkdir(exist_ok=True)

# Objective values recorded by the solver, plotted against their index.
plt.plot(opt_infos["obj_values"], color=palette["light red"])


plt.xlabel("Iteration")
plt.ylabel("Objective")
plt.title("Objective curve")

plt.savefig("figures/convergence_obj.pdf", bbox_inches="tight")

In [None]:
# Step size recorded at every iteration of the solver.
step_sizes = list(opt_infos["step_sizes"])
plt.plot(step_sizes, color=palette["light red"])

plt.xlabel("Iteration")
plt.ylabel("Step size")
plt.title("Step size")

plt.savefig("figures/convergence_step_sizes.pdf", bbox_inches="tight")

In [None]:
# Step sizes of the odd iterations only, plotted against their true iteration index
# (not their position in the filtered list) so the "Iteration" axis is correct.
step_sizes = list(opt_infos["step_sizes"])
odd_iterations = list(range(1, len(step_sizes), 2))
plt.plot(odd_iterations, [step_sizes[i] for i in odd_iterations], color=palette["light red"])

plt.xlabel("Iteration")
plt.ylabel("Step size")
plt.title("Step size for odd iterations")

plt.savefig("figures/convergence_step_sizes_odd.pdf", bbox_inches="tight")

In [None]:
# Step sizes of the even iterations only, plotted against their true iteration index
# (not their position in the filtered list) so the "Iteration" axis is correct.
step_sizes = list(opt_infos["step_sizes"])
even_iterations = list(range(0, len(step_sizes), 2))
plt.plot(even_iterations, [step_sizes[i] for i in even_iterations], color=palette["light red"])
plt.ylabel("Step size")
plt.xlabel("Iteration")
plt.title("Step size for even iterations")
plt.savefig("figures/convergence_step_sizes_even.pdf", bbox_inches="tight")

### Plot soft vs hard perms

In [None]:
# Round each solver permutation with solve_linear_assignment_problem, and keep an untouched
# deep copy of the solver output as `soft_perms`.
hard_perms = {}
for symb in symbols:
    hard_perms[symb] = {name: solve_linear_assignment_problem(p) for name, p in perm_indices[symb].items()}

soft_perms = copy.deepcopy(perm_indices)

In [None]:
plot_perms = False
if plot_perms:
    perm_symbols = list(perm_indices[symbols[0]].keys())

    # Pin the color scale to [0, 1] and derive each colorbar from its image, so the colorbar
    # actually describes what is drawn (a detached 0..1 colorbar next to an autoscaled image
    # misrepresents soft matrices whose entries do not span the full range).
    perm_norm = colors.Normalize(vmin=0, vmax=1)

    for perm_symbol in perm_symbols:
        fig, axs = plt.subplots(2, len(symbols), figsize=(40, 20), squeeze=False)
        fig.suptitle(f"Permutation: {perm_symbol}", fontsize=16)
        for i, symbol in enumerate(symbols):

            # top row: soft permutation as returned by the solver
            ax0 = axs[0][i]
            ax0.set_title(symbol)
            im0 = ax0.imshow(soft_perms[symbol][perm_symbol].cpu(), cmap=cmap_name, norm=perm_norm)
            fig.colorbar(im0, ax=ax0)

            # bottom row: its rounded (hard) counterpart
            ax1 = axs[1][i]
            ax1.set_title(symbol)
            im1 = ax1.imshow(perm_indices_to_perm_matrix(hard_perms[symbol][perm_symbol]), cmap=cmap_name, norm=perm_norm)
            fig.colorbar(im1, ax=ax1)
        plt.show()

### Inspect the fractional entries (strictly between 0 and 1) of a soft permutation

In [None]:
perm_name = "P_blockgroup2.block3_inner"
# Mask of the entries of model "a"'s soft permutation strictly between 0 and 1, i.e. not binary.
# (Despite the name, entries equal to 1 are excluded as well as zeros.)
nonzero_idxs = (soft_perms["a"][perm_name] > 0) & (soft_perms["a"][perm_name] < 1)

# display those fractional values
soft_perms["a"][perm_name][nonzero_idxs]

### Map to universe

In [None]:
def get_models_permuted_to_univ(perms, models, symbols, keep_soft_perms=False):
    """Return copies of ``models`` with each listed model mapped into the shared universe space.

    Every model in ``models`` is deep-copied; only those in ``symbols`` have their weights
    replaced. For a model, each permutation in ``perms[symbol]`` is transposed (for hard
    permutations this is the inverse) and applied to its state dict via the module-level
    ``permutation_spec``.

    Args:
        perms: ``{symbol: {perm_name: permutation}}``; soft matrices when
            ``keep_soft_perms`` is true, otherwise index-form permutations.
        models: ``{symbol: LightningModule}``.
        symbols: the symbols to permute.
        keep_soft_perms: apply the transposed matrices as they are instead of
            converting them back to index form.

    Returns:
        ``{symbol: LightningModule}`` with the permuted copies.
    """
    permuted = {symbol: copy.deepcopy(model) for symbol, model in models.items()}

    for symbol in symbols:
        if keep_soft_perms:
            to_apply = {name: p.T for name, p in perms[symbol].items()}
        else:
            # indices -> matrix, transpose, and back to indices
            to_apply = {
                name: perm_matrix_to_perm_indices(perm_indices_to_perm_matrix(p).T)
                for name, p in perms[symbol].items()
            }

        new_state = apply_permutation_to_statedict(permutation_spec, to_apply, models[symbol].model.state_dict())
        permuted[symbol].model.load_state_dict(new_state)

    return permuted


models_permuted_to_universe = get_models_permuted_to_univ(soft_perms, models, symbols, keep_soft_perms=True)  # soft perms applied as-is

In [None]:
# models_permuted_pairwise[fixed][permutee]: model `permutee` permuted into the basin of `fixed`,
# using pairwise permutations derived (by unfactor_permutations) from the rounded universe ones.
# `difference({symbol})` removes the symbol itself; `difference(symbol)` would remove its characters.
models_permuted_pairwise = {
    symbol: {other_symb: None for other_symb in set(symbols).difference({symbol})} for symbol in symbols
}
pairwise_permutations = unfactor_permutations(hard_perms)

for fixed, permutee in all_combinations:
    # copy of the permutee as container; all its state-dict entries are replaced below
    permuted_model = copy.deepcopy(models[permutee])

    permuted_params = apply_permutation_to_statedict(
        permutation_spec, pairwise_permutations[fixed][permutee], models[permutee].model.state_dict()
    )

    permuted_model.model.load_state_dict(permuted_params)
    models_permuted_pairwise[fixed][permutee] = permuted_model

In [None]:
models = {symb: model.to("cpu") for symb, model in models.items()}  # Module.to moves in place; same objects, new dict

### PLOT: weight-space similarities (cosine or Euclidean) in original and universe space

In [None]:
# Flatten every model's parameters into one vector, in the original and in the universe space.
# These are read-only snapshots, so no autograd graph is recorded for them.
with torch.no_grad():
    flat_models = {symbol: torch.nn.utils.parameters_to_vector(model.parameters()) for symbol, model in models.items()}
    flat_models_permuted_to_universe = {
        symbol: torch.nn.utils.parameters_to_vector(model.parameters())
        for symbol, model in models_permuted_to_universe.items()
    }

In [None]:
def plot_similarities(models, filename, dist="cosine"):
    """Plot and save a lower-triangular heatmap comparing flattened weight vectors pairwise.

    Args:
        models: mapping ``symbol -> 1-D tensor`` with each model's flattened
            parameters (e.g. from ``torch.nn.utils.parameters_to_vector``).
        filename: name of the saved figure, written under ``figures/``.
        dist: ``"cosine"`` (cosine similarity) or ``"euclidean"`` (L2 distance).

    Raises:
        ValueError: if ``dist`` is not a supported metric.
    """
    # Resolve the metric up front: unknown names fail immediately and the plot
    # settings are always defined, even for an empty input.
    if dist == "cosine":
        title, fmt, vmin, vmax = "Cosine Similarity", ".2g", 0, 1

        def pair_value(u, v):
            return (u @ v) / (torch.norm(u) * torch.norm(v))

    elif dist == "euclidean":
        # NOTE: hard-coded color range for this experiment's models
        title, fmt, vmin, vmax = "Euclidean Distance", ".4g", 0, 85

        def pair_value(u, v):
            return torch.norm(u - v)

    else:
        raise ValueError(f"Unknown distance metric: {dist}")

    symbols = list(models.keys())
    vectors = list(models.values())
    dist_matrix = np.zeros((len(vectors), len(vectors)))

    # Both metrics are symmetric and the upper triangle is masked in the plot,
    # so only j <= i is computed. `.item()` turns the 0-d result into a float
    # regardless of device or grad tracking.
    with torch.no_grad():
        for i, vec_i in enumerate(vectors):
            for j in range(i + 1):
                dist_matrix[i, j] = pair_value(vec_i, vectors[j]).item()

    plt.figure(figsize=(5, 5))

    cmap = sns.light_palette("seagreen", as_cmap=True)

    mask = np.triu(np.ones_like(dist_matrix), k=1)

    sns.heatmap(
        dist_matrix,
        annot=True,
        cmap=cmap,
        cbar=False,
        mask=mask,
        vmin=vmin,
        vmax=vmax,
        fmt=fmt,
        xticklabels=symbols,
        yticklabels=symbols,
    )
    plt.ylabel("Model Symbol")
    plt.title(title)
    plt.savefig(f"figures/{filename}")

    plt.show()

In [None]:
# Toggle for the weight-similarity plots. It must not be called `plot_similarities`: that would
# overwrite the plotting function and make the guarded calls fail with "bool is not callable".
do_plot_similarities = False
if do_plot_similarities:
    dist = "euclidean"
    plot_similarities(flat_models, filename=f"similarities_orig_{dist}.pdf", dist=dist)
    plot_similarities(flat_models_permuted_to_universe, filename=f"similarities_univ_{dist}.pdf", dist=dist)

### PLOT: representation similarity (CKA or Euclidean) in original and universe space

In [None]:
from latentis.measure.functional.cka import linear_cka

In [None]:
def _embed(model, x):
    """Return ``model.model(x, return_embeddings=True)`` computed in eval mode without autograd.

    The module's previous train/eval flag is restored afterwards.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            return model.model(x, return_embeddings=True)
    finally:
        model.train(was_training)


def plot_repr_similarities(models, dataset, filename, dist="cka", num_activations=1000):
    """Plot and save a heatmap comparing the models' representations of one fixed batch.

    Args:
        models: ``{symbol: LightningModule}``; ``model.model(x, return_embeddings=True)``
            must return the representation to compare.
        dataset: dataset yielding ``(x, y)``; the first ``num_activations`` samples
            (unshuffled) form the shared batch.
        filename: name of the saved figure, written under ``figures/``.
        dist: ``"cka"`` (linear CKA) or ``"euclidean"`` (norm of the difference).
        num_activations: number of samples embedded by every model.

    Raises:
        ValueError: if ``dist`` is not a supported metric.
    """
    if dist == "cka":
        title, fmt, vmin, vmax = "Centered Kernel Alignment", ".2g", 0, 1
        pair_value = linear_cka
    elif dist == "euclidean":
        title, fmt, vmin, vmax = "Euclidean distance between representations", ".4g", 0, 300

        def pair_value(f_i, f_j):
            return torch.norm(f_i - f_j)

    else:
        raise ValueError(f"Unknown distance metric: {dist}")

    # Build the shared batch once (it was previously rebuilt for every pair).
    loader = DataLoader(dataset, batch_size=num_activations, num_workers=0)
    x, _ = next(iter(loader))

    # One forward pass per model instead of two per pair.
    symbols = list(models.keys())
    features = [_embed(model, x) for model in models.values()]

    dist_matrix = np.zeros((len(features), len(features)))
    for i, feat_i in enumerate(features):
        for j, feat_j in enumerate(features):
            dist_matrix[i, j] = float(pair_value(feat_i, feat_j))

    plt.figure(figsize=(5, 5))

    cmap = sns.light_palette("seagreen", as_cmap=True)

    mask = np.triu(np.ones_like(dist_matrix), k=1)

    sns.heatmap(
        dist_matrix,
        annot=True,
        cmap=cmap,
        cbar=False,
        mask=mask,
        vmin=vmin,
        vmax=vmax,
        fmt=fmt,
        xticklabels=symbols,
        yticklabels=symbols,
    )
    plt.ylabel("Model Symbol")
    plt.title(title)
    plt.savefig(f"figures/{filename}")

    plt.show()

In [None]:
# Toggle for the representation-similarity plots. It must not be called `plot_repr_similarities`:
# that would overwrite the plotting function and make the guarded calls fail.
do_plot_repr_similarities = False

if do_plot_repr_similarities:

    dist = "euclidean"
    plot_repr_similarities(models, dataset=train_dataset, filename=f"similarities_orig_repr_{dist}.pdf", dist=dist)
    plot_repr_similarities(
        models_permuted_to_universe, dataset=train_dataset, filename=f"similarities_univ_repr_{dist}.pdf", dist=dist
    )

### PLOT: performance

In [None]:
def plot_performance_matrix(models, loader, filename):
    """Plot and save the test accuracy of every pairwise average of ``models``.

    Entry (i, j) is the ``"acc/test"`` metric from ``trainer.test`` on ``loader`` for the
    model whose state dict is the element-wise mean of models i and j (so the diagonal
    is each model itself). Only the lower triangle is shown.

    Args:
        models: ``{symbol: LightningModule}`` sharing one architecture.
        loader: evaluation dataloader.
        filename: name of the saved figure, written under ``figures/``.
    """
    symbols = list(models.keys())
    endpoints = list(models.values())
    dist_matrix = np.zeros((len(endpoints), len(endpoints)))

    # A single scratch model suffices: load_state_dict (strict) replaces all of its entries for each pair.
    interp_model = copy.deepcopy(endpoints[0])

    # Averaging is symmetric and the upper triangle is masked in the plot,
    # so only pairs with j <= i are evaluated (about half the trainer.test calls).
    for i, model_i in enumerate(endpoints):
        for j in range(i + 1):
            model_params = [model_i.model.state_dict(), endpoints[j].model.state_dict()]
            interp_model.model.load_state_dict(average_models(model_params, reduction_fn=red_fns["mean"]))

            res = trainer.test(interp_model, loader)[0]
            dist_matrix[i, j] = res["acc/test"]

    plt.figure(figsize=(5, 5))

    cmap = sns.light_palette("seagreen", as_cmap=True)

    mask = np.triu(np.ones_like(dist_matrix), k=1)

    vmin, vmax = 0, 1
    fmt = ".2g"
    title = "Interpolated models accuracy"

    sns.heatmap(
        dist_matrix,
        annot=True,
        cmap=cmap,
        cbar=False,
        mask=mask,
        vmin=vmin,
        vmax=vmax,
        fmt=fmt,
        xticklabels=symbols,
        yticklabels=symbols,
    )
    plt.ylabel("Model Symbol")
    plt.title(title)
    plt.savefig(f"figures/{filename}")

    plt.show()

In [None]:
plot_performance = False

if plot_performance:
    # pairwise-interpolation accuracy, before and after mapping to the universe space
    for space, models_to_plot in (("orig", models), ("univ", models_permuted_to_universe)):
        plot_performance_matrix(models_to_plot, loader=test_loader, filename=f"performance_{space}.pdf")

## Merging

In [None]:
# Merge: reduce the universe-space state dicts of all models with `chosen_red_fn`.
model_params = [model.model.state_dict() for model in models_permuted_to_universe.values()]

merged_params = average_models(model_params, reduction_fn=chosen_red_fn)
merged_model.model.load_state_dict(merged_params)

#### Evaluate

In [None]:
# Evaluate the merged model on the test subset.
loader = test_loader
trainer.test(merged_model, loader)

### Repair

In [None]:
from ccmm.matching.repair import repair_model

# Apply ccmm's repair_model to the merged model, given the universe-space endpoints and the
# training loader; the exact correction it performs is defined in ccmm.matching.repair.
repaired_model = repair_model(merged_model, models_permuted_to_universe, train_loader)

In [None]:
trainer.test(repaired_model, loader)  # evaluate the repaired merge on the same loader

## Search for the best merge in the simplex

In [None]:
len(models_permuted_to_universe)  # number of models taking part in the coefficient search

In [None]:
import numpy as np
from scipy.optimize import minimize
from scipy.optimize import differential_evolution

ref_model = copy.deepcopy(models_permuted_to_universe["a"])  # template whose weights each candidate merge replaces


def evaluate_model(alphas, models, loader):
    """Objective of the coefficient search: negated accuracy of an alpha-weighted merge.

    Args:
        alphas: one coefficient per model, in the iteration order of ``models``.
        models: ``{symbol: LightningModule}`` already mapped to a common space.
        loader: dataloader passed to ``trainer.test``.

    Returns:
        ``-acc`` where ``acc`` is the ``"acc/test"`` metric, so minimising maximises accuracy.

    NOTE(review): the reduction is ``red_fns["mean"]``, so the candidate is
    ``sum_i alphas[i] * theta_i / len(models)``; the alphas are not renormalised.
    """
    state_dicts = {symb: model.model.state_dict() for symb, model in models.items()}
    combined = average_models(state_dicts, reduction_fn=red_fns["mean"], coeffs=alphas)

    # the global template `ref_model` provides the architecture; its weights are replaced
    candidate = copy.deepcopy(ref_model)
    candidate.model.load_state_dict(combined)

    metrics = trainer.test(candidate, loader)[0]

    # penalty = 1e6 * abs(np.sum(alphas) - 1)

    return -metrics["acc/test"]


# # Constraints: alphas should sum to 1
def constraint_eq(alphas):
    """Equality-constraint residual: zero exactly when the alphas sum to one."""
    total = np.sum(alphas)
    return total - 1


n = len(models)
# Bounds: alphas should be between 0 and 1
bounds = [(0, 1) for _ in range(n)]

# Perform the optimization
# NOTE(review): `loss_fn` is not used below (differential_evolution receives `evaluate_model`
# plus explicit args), and the commented-out `minimize` calls refer to an undefined `initial_guess`.
loss_fn = partial(evaluate_model, loader=train_loader)
# result = minimize(loss_fn, initial_guess,
#                   bounds=bounds)

# Perform the optimization
# result = minimize(loss_fn, initial_guess, method='SLSQP', bounds=bounds, constraints={'type': 'eq', 'fun': constraint_eq},
#                   options={'ftol': 1e-2, 'disp': True})

# Derivative-free search over one coefficient per model; every candidate costs one trainer.test
# pass over train_loader (see evaluate_model).
# NOTE(review): the sum-to-one constraint (constraint_eq) is not enforced here and no explicit
# `seed` is passed to differential_evolution.
result = differential_evolution(
    evaluate_model,
    args=(models_permuted_to_universe, train_loader),
    bounds=bounds,
    strategy="best1bin",
    maxiter=10,
    popsize=20,
    tol=1e-3,
    mutation=(0.5, 1),
    recombination=0.7,
    disp=True,
)

# Get the optimal alphas
optimal_alphas = result.x

print(f"Optimal alphas: {optimal_alphas}")

In [None]:
print(f"Optimal alphas: {optimal_alphas}")

# Rebuild the merge with the best coefficients found by the search and evaluate it on `loader`.
# NOTE(review): average_models applies the mean after scaling, so this is
# sum_i alpha_i * theta_i / num_models — the alphas are not renormalised to sum to one.
merged_model_params = average_models(
    {symb: model.model.state_dict() for symb, model in models_permuted_to_universe.items()},
    reduction_fn=red_fns["mean"],
    coeffs=optimal_alphas,
)

merged_model = copy.deepcopy(ref_model)

merged_model.model.load_state_dict(merged_model_params)

loss = trainer.test(merged_model, loader)[0]["loss/test"]

## Fisher averaging

In [None]:
# Move the universe-space models to CPU (Module.to works in place; the dict holds the same objects).
models_permuted_to_universe = {symb: model.to("cpu") for symb, model in models_permuted_to_universe.items()}
# reset the Fisher cache: {symbol: {param_name: fisher tensor}}
fishers = {}

In [None]:
from ccmm.matching.fisher_merging import compute_fisher_for_model

# Fisher estimates are computed on the first `num_fisher_samples` training samples, in batches of 8.
num_fisher_samples = 500
fisher_train_subset = Subset(train_dataset, list(range(num_fisher_samples)))

fisher_train_loader = DataLoader(fisher_train_subset, batch_size=8)

num_classes = core_cfg.dataset.num_classes

# {symbol: {param_name: fisher tensor}} for every universe-space model
fishers = {
    symbol: compute_fisher_for_model(model.model, fisher_train_loader, num_classes)
    for symbol, model in models_permuted_to_universe.items()
}

In [None]:
# Move every universe-space model to CPU (in place).
for model in models_permuted_to_universe.values():
    model.to("cpu")

In [None]:
# Container for the Fisher merge. It is copied from model "a" *in universe space* so that any
# non-parameter state (buffers, if the architecture has any) is already in the universe basis;
# its parameters are zeroed here and filled by the merge.
merged_model = copy.deepcopy(models_permuted_to_universe["a"])
with torch.no_grad():
    for param in merged_model.model.parameters():
        param.zero_()

In [None]:
target_model = copy.deepcopy(models_permuted_to_universe["a"])  # model "a" in universe space: reference / fallback for the Fisher merge

In [None]:
# Fisher-weighted merge in universe space, element-wise for every parameter:
#     theta = sum_i F_i * theta_i / sum_i F_i
# where F_i is model i's Fisher estimate. The weights F_i / sum_j F_j already sum to one, so no
# extra 1 / num_models factor is applied. Entries whose total Fisher mass is below `fisher_tol`
# carry no usable signal from any model and take the value of `target_model` instead.
fisher_tol = 1e-12

# Build every state dict once instead of once per parameter.
universe_state_dicts = {symb: model.model.state_dict() for symb, model in models_permuted_to_universe.items()}
target_state_dict = target_model.model.state_dict()

with torch.no_grad():
    for param_name, merged_param in merged_model.model.named_parameters():
        stacked_params = torch.stack([universe_state_dicts[symb][param_name] for symb in symbols])
        stacked_fishers = torch.stack([fishers[symb][param_name] for symb in symbols]).to(stacked_params)

        total_fisher = stacked_fishers.sum(dim=0)
        informative = total_fisher >= fisher_tol

        # the clamp only guards the division; non-informative entries are replaced just below
        weighted = (stacked_fishers * stacked_params).sum(dim=0) / total_fisher.clamp_min(fisher_tol)
        fallback = target_state_dict[param_name].to(weighted)

        merged_param.copy_(torch.where(informative, weighted, fallback))

        num_fallback = int((~informative).sum())
        pylogger.info(
            f"{param_name}: {num_fallback} entries with negligible Fisher mass "
            f"(ratio {num_fallback / total_fisher.numel():.4f}), taken from the target model"
        )

In [None]:
# Evaluate the Fisher-merged model.
trainer.test(merged_model, loader)

In [None]:
# Baseline: model "a" alone, in universe space.
trainer.test(target_model, loader)

In [None]:
# Repair the Fisher-merged model using all universe-space endpoints and the training subset.
repaired_model = repair_model(merged_model, models_permuted_to_universe, train_loader)

In [None]:
# Evaluate the repaired Fisher merge.
trainer.test(repaired_model, loader)

## Merge in the basins of the endpoint models 

For each model with symbol s, align every other model to s using the pairwise permutations, then average the aligned models together with s itself.

In [None]:
results = {symbol: {"vanilla": None, "repaired": None} for symbol in symbols}

for symbol in symbols:

    # all the other models permuted into the basin of the current model
    mapped_models = {other_symb: models_permuted_pairwise[symbol][other_symb] for other_symb in models if other_symb != symbol}

    # the basin owner enters the average unpermuted; the repair must see the same set of
    # endpoints that was averaged, so it gets all of them, not only the n-1 permuted ones
    basin_models = {**mapped_models, symbol: models[symbol]}
    mapped_params = {symb: model.model.state_dict() for symb, model in basin_models.items()}

    merged_model = copy.deepcopy(models[symbol])

    mean_model = average_models(mapped_params, reduction_fn=red_fns["mean"])

    merged_model.model.load_state_dict(mean_model)

    vanilla_res = trainer.test(merged_model, loader)[0]

    repaired_model = repair_model(merged_model, basin_models, train_loader)

    repair_res = trainer.test(repaired_model, loader)[0]

    results[symbol]["vanilla"] = vanilla_res
    results[symbol]["repaired"] = repair_res

In [None]:
pprint(results)  # {symbol: {"vanilla": test metrics, "repaired": test metrics}}

In [None]:
# Extracting the models and their accuracies
vanilla_accuracies = [results[symbol]["vanilla"]["acc/test"] for symbol in symbols]
repaired_accuracies = [results[symbol]["repaired"]["acc/test"] for symbol in symbols]

# Set the width of the bars
bar_width = 0.35
index = np.arange(len(symbols))

# Plotting the bar chart
fig, ax = plt.subplots(figsize=(10, 6))

bars1 = ax.bar(index, vanilla_accuracies, bar_width, label="Vanilla", color=palette["light red"])
bars2 = ax.bar(index + bar_width, repaired_accuracies, bar_width, label="Repaired", color=palette["green"])

# Adding labels, title, and legend
ax.set_xlabel("Basin")
ax.set_ylabel("Accuracy")
ax.set_title("Accuracy when averaging in different basins")
ax.set_xticks(index + bar_width / 2)
ax.set_xticklabels(models)
ax.legend(loc="lower center", bbox_to_anchor=(0.5, -0.3), ncol=2)
ax.set_ylim(0, max(repaired_accuracies) + 0.1)  # Adding some space above the tallest bar

# Adding the accuracy values on top of the bars
def add_labels(bars):
    for bar in bars:
        height = bar.get_height()
        ax.annotate(
            f"{height:.2f}",
            xy=(bar.get_x() + bar.get_width() / 2, height),
            xytext=(0, 3),  # 3 points vertical offset
            textcoords="offset points",
            ha="center",
            va="bottom",
        )


add_labels(bars1)
add_labels(bars2)

plt.savefig("figures/accuracy_in_different_basins.pdf", bbox_inches="tight")

plt.tight_layout()
plt.show()

## Merge subsets

### Matching the N-1 models in each subset

For each (n-1)-subset of the n models, rematch the models of that subset among themselves and merge them.

In [None]:
# for each symbol, collect the subset of the remaining n-1 models
symbol_subsets = {symbol: set(symbols).difference(symbol) for symbol in symbols}

merged_model = copy.deepcopy(models[symbols[0]])

matched_subsets_results = {}
merged_models = {}

for symbol_subset in symbol_subsets.values():

    combinations = get_all_symbols_combinations(symbol_subset)
    canonical_combinations = [(source, target) for (source, target) in combinations if source < target]
    model_subset = {symb: models[symb] for symb in symbol_subset}

    perm_indices, _ = frank_wolfe_synchronized_matching(
        models=model_subset,
        perm_spec=permutation_spec,
        symbols=list(symbol_subset),
        combinations=canonical_combinations,
        max_iter=max_iter,
        initialization_method=initialization_method,
        keep_soft_perms=keep_soft_perms,
    )

    pylogger.info(f"Symbol subset: {symbol_subset}")

    models_to_univ_subset = get_models_permuted_to_univ(perm_indices, model_subset, symbol_subset, keep_soft_perms)

    model_params = [model.model.state_dict() for model in models_to_univ_subset.values()]

    merged_params = average_models(model_params, reduction_fn=red_fns["mean"])
    merged_model.model.load_state_dict(merged_params)

    merged_results = trainer.test(merged_model, loader)[0]

    repaired_model = repair_model(merged_model, models_to_univ_subset, train_loader)

    repair_results = trainer.test(repaired_model, loader)[0]

    matched_subsets_results[tuple(symbol_subset)] = {"merged": merged_results, "repaired": repair_results}
    merged_models[tuple(symbol_subset)] = {"merged": merged_model, "repaired": repaired_model}

In [None]:
matched_subsets_results  # {subset tuple: {"merged": test metrics, "repaired": test metrics}}

#### Accuracy histogram when merging matched subsets

In [None]:
combinations = list(matched_subsets_results.keys())
merged_accuracies = [matched_subsets_results[combo]["merged"]["acc/test"] for combo in combinations]
repaired_accuracies = [matched_subsets_results[combo]["repaired"]["acc/test"] for combo in combinations]

combination_strings = ["(" + ",".join(sorted(combo)) + ")" for combo in combinations]

# Set the width of the bars
bar_width = 0.35
index = np.arange(len(combination_strings))

# Plotting the bar chart
fig, ax = plt.subplots(figsize=(12, 8))

bars1 = ax.bar(index, merged_accuracies, bar_width, label="Vanilla", color=palette["light red"])
bars2 = ax.bar(index + bar_width, repaired_accuracies, bar_width, label="Repaired", color=palette["green"])

ax.set_ylabel("Accuracy")
ax.set_title("Accuracy when merging model subsets")
ax.set_xticks(index + bar_width / 2)
ax.set_xticklabels(combination_strings, rotation=45, ha="right")
ax.legend(loc="lower center", bbox_to_anchor=(0.5, -0.3), ncol=2)
# leave room above the tallest bar of *either* series for its value label
ax.set_ylim(0, max(merged_accuracies + repaired_accuracies) + 0.1)


def add_labels(bars):
    """Write each bar's height (2 decimals) just above it, on the current `ax`."""
    for bar in bars:
        height = bar.get_height()
        ax.annotate(
            f"{height:.2f}",
            xy=(bar.get_x() + bar.get_width() / 2, height),
            xytext=(0, 3),  # 3 points vertical offset
            textcoords="offset points",
            ha="center",
            va="bottom",
        )


add_labels(bars1)
add_labels(bars2)

# lay the figure out before saving so the PDF matches what is shown
plt.tight_layout()
plt.savefig("figures/accuracy_model_subsets_matched.pdf", bbox_inches="tight")
plt.show()

### Matching all the N models and only averaging the N-1 models in each subset

In [None]:
# Recompute the pair lists over all models (`canonical_combinations` was overwritten by the subset loop).
all_combinations = get_all_symbols_combinations(symbols)
canonical_combinations = [(source, target) for (source, target) in all_combinations if source < target]

# Match every model jointly once; the per-subset averages below reuse these permutations.
perm_indices, _ = frank_wolfe_synchronized_matching(
    models=models,
    perm_spec=permutation_spec,
    symbols=symbols,
    combinations=canonical_combinations,
    max_iter=max_iter,
    initialization_method=initialization_method,
    keep_soft_perms=keep_soft_perms,
)

In [None]:
# Map all models to the universe found by the joint matching, then average each (n-1)-subset
# there without rematching it.
models_permuted_to_universe = get_models_permuted_to_univ(perm_indices, models, symbols, keep_soft_perms)

results_norealign = {}
merged_models_norealign = {}

for symbol_subset in symbol_subsets.values():
    pylogger.info(f"Symbol subset: {symbol_subset}")
    # deterministic key: iteration order of a set of strings depends on the per-run hash seed
    subset_key = tuple(sorted(symbol_subset))

    models_to_univ_subset = {symb: models_permuted_to_universe[symb] for symb in subset_key}

    model_params = [model.model.state_dict() for model in models_to_univ_subset.values()]

    merged_params = average_models(model_params, reduction_fn=red_fns["mean"])

    # a fresh container per subset: `merged_models_norealign` keeps a reference to it, so reusing a
    # single object would leave every stored entry holding the weights of the last subset
    merged_model = copy.deepcopy(models[symbols[0]])
    merged_model.model.load_state_dict(merged_params)

    merged_results = trainer.test(merged_model, loader)[0]

    repaired_model = repair_model(merged_model, models_to_univ_subset, train_loader)

    repair_results = trainer.test(repaired_model, loader)[0]

    results_norealign[subset_key] = {"merged": merged_results, "repaired": repair_results}
    merged_models_norealign[subset_key] = {"merged": merged_model, "repaired": repaired_model}

In [None]:
results_norealign  # {subset tuple: {"merged": test metrics, "repaired": test metrics}}

#### Accuracy histogram when merging subsets

In [None]:
combinations = list(results_norealign.keys())
merged_accuracies = [results_norealign[combo]["merged"]["acc/test"] for combo in combinations]
repaired_accuracies = [results_norealign[combo]["repaired"]["acc/test"] for combo in combinations]

combination_strings = ["(" + ",".join(sorted(combo)) + ")" for combo in combinations]

# Set the width of the bars
bar_width = 0.35
index = np.arange(len(combination_strings))

# Plotting the bar chart
fig, ax = plt.subplots(figsize=(12, 8))

bars1 = ax.bar(index, merged_accuracies, bar_width, label="Vanilla", color=palette["light red"])
bars2 = ax.bar(index + bar_width, repaired_accuracies, bar_width, label="Repaired", color=palette["green"])

ax.set_ylabel("Accuracy")
ax.set_title("Accuracy when merging model subsets")

# leave room above the tallest bar of *either* series for its value label
ax.set_ylim(0, max(merged_accuracies + repaired_accuracies) + 0.1)

ax.set_xticks(index + bar_width / 2)
ax.set_xticklabels(combination_strings, rotation=45, ha="right")
ax.legend(loc="lower center", bbox_to_anchor=(0.5, -0.3), ncol=2)


def add_labels(bars):
    """Write each bar's height (2 decimals) just above it, on the current `ax`."""
    for bar in bars:
        height = bar.get_height()
        ax.annotate(
            f"{height:.2f}",
            xy=(bar.get_x() + bar.get_width() / 2, height),
            xytext=(0, 3),  # 3 points vertical offset
            textcoords="offset points",
            ha="center",
            va="bottom",
        )


add_labels(bars1)
add_labels(bars2)

# lay the figure out before saving so the PDF matches what is shown
plt.tight_layout()
plt.savefig("figures/accuracy_model_subsets.pdf", bbox_inches="tight")
plt.show()