# M$^3$SB: Multi-Model Merging via Spherical Barycenters - Experiments
This notebook contains all the experiments mentioned in the project report.

## Importing and Installing the Package

### Kaggle-Specific

Make sure to add the [m3sb-package](https://www.kaggle.com/datasets/mehdiamlal/m3sb-package) dataset to the notebook and to enable a `T4 GPU` accelerator before executing the cells below.

In [None]:
%cd /kaggle/input/m3sb-package/m3sb

In [None]:
!pip install uv

In [None]:
!uv pip install -e .

Before running the cell below, restart the notebook: `Run > Restart & clear cell outputs`.

### Other Environments

The above only applies to Kaggle Notebooks, where the following experiments were originally executed. If you wish to run this notebook on another platform (e.g. Google Colab), you will have to upload the package's .zip file, go to its main directory, and:
1. Run `!pip install uv`
2. Run `!uv pip install -e .`
3. Restart the kernel.
4. Check that the runtime has a GPU (ideally an NVIDIA T4 or better).

## Running Automated Experiments

In [None]:
from m3sb.experiment import Experiment, TaskVectorExperiment
import torch 

## Experiment 1-a

In [None]:
# (checkpoint, dataset config) pairs, kept together so the two parallel lists
# built below always stay in the same order.
ex1a_pairs = [
    (
        "pkr7098/cifar100-vit-base-patch16-224-in21k",
        {"dataset_name": "cifar100", "split": "test", "image_col": "img", "label_col": "fine_label"},
    ),
    (
        "nateraw/food",
        {"dataset_name": "food101", "split": "validation", "image_col": "image", "label_col": "label"},
    ),
    (
        "MaulikMadhavi/vit-base-flowers102",
        {"dataset_name": "nkirschi/oxford-flowers", "split": "test", "image_col": "image", "label_col": "label"},
    ),
]

# Near-uniform weights, one per model; every merge method gets its own copy.
ex1a_weights = [0.3333, 0.3334, 0.3333]

ex1a_config = {
    "name": "Experiment 1a: Merging Common Image Models",
    "model_checkpoints": [checkpoint for checkpoint, _ in ex1a_pairs],
    "datasets_config": [dataset for _, dataset in ex1a_pairs],
    "merge_configs": {
        "barycenter": {"weights": list(ex1a_weights), "iterations": 20, "threshold": 1e-5},
        "linear": {"weights": list(ex1a_weights)},
        "pairwise_slerp": {"weights": list(ex1a_weights)},
    },
    "base_model_checkpoint": "google/vit-base-patch16-224-in21k",
}

ex1a = Experiment(**ex1a_config)

In [None]:
# Execute experiment 1-a as configured in the cell above.
ex1a.run()

In [None]:
# Save the results table as ex1a_results.csv in the current working directory.
ex1a.get_results_df().to_csv("ex1a_results.csv")

## Experiment 1-b

In [None]:
# (checkpoint, dataset config) pairs, kept together so the two parallel lists
# built below always stay in the same order.
ex1b_pairs = [
    (
        "Payoto/vit-base-patch16-224-in21k-finetuned-eurosat",
        {"dataset_name": "tanganke/eurosat", "split": "test", "image_col": "image", "label_col": "label"},
    ),
    (
        "akahana/vit-base-cats-vs-dogs",
        {"dataset_name": "Bingsu/Cat_and_Dog", "split": "test", "image_col": "image", "label_col": "labels"},
    ),
    (
        "nateraw/vit-base-beans",
        {"dataset_name": "AI-Lab-Makerere/beans", "split": "test", "image_col": "image", "label_col": "labels"},
    ),
]

# Near-uniform weights, one per model; every merge method gets its own copy.
ex1b_weights = [0.3333, 0.3334, 0.3333]

ex1b_config = {
    "name": "Experiment 1b: Merging Niche Models",
    "model_checkpoints": [checkpoint for checkpoint, _ in ex1b_pairs],
    "datasets_config": [dataset for _, dataset in ex1b_pairs],
    "merge_configs": {
        "barycenter": {"weights": list(ex1b_weights), "iterations": 20, "threshold": 1e-5},
        "linear": {"weights": list(ex1b_weights)},
        "pairwise_slerp": {"weights": list(ex1b_weights)},
    },
    "base_model_checkpoint": "google/vit-base-patch16-224-in21k",
}

ex1b = Experiment(**ex1b_config)

In [None]:
# Execute experiment 1-b as configured in the cell above.
ex1b.run()

In [None]:
# Save the results table as ex1b_results.csv in the current working directory.
ex1b.get_results_df().to_csv("ex1b_results.csv")

## Experiment 2

In [None]:
# (checkpoint, dataset config) pairs for the five models, kept together so the
# two parallel lists built below always stay in the same order.
ex2_pairs = [
    (
        "pkr7098/cifar100-vit-base-patch16-224-in21k",
        {"dataset_name": "cifar100", "split": "test", "image_col": "img", "label_col": "fine_label"},
    ),
    (
        "nateraw/food",
        {"dataset_name": "food101", "split": "validation", "image_col": "image", "label_col": "label"},
    ),
    (
        "MaulikMadhavi/vit-base-flowers102",
        {"dataset_name": "nkirschi/oxford-flowers", "split": "test", "image_col": "image", "label_col": "label"},
    ),
    (
        "farleyknight-org-username/vit-base-mnist",
        {"dataset_name": "ylecun/mnist", "split": "test", "image_col": "image", "label_col": "label"},
    ),
    (
        "Payoto/vit-base-patch16-224-in21k-finetuned-eurosat",
        {"dataset_name": "tanganke/eurosat", "split": "test", "image_col": "image", "label_col": "label"},
    ),
]

# Uniform weights, one per model; every merge method gets its own copy.
ex2_weights = [0.2, 0.2, 0.2, 0.2, 0.2]

ex2_config = {
    "name": "Experiment 2: Merging More Models in Parameter Space",
    "model_checkpoints": [checkpoint for checkpoint, _ in ex2_pairs],
    "datasets_config": [dataset for _, dataset in ex2_pairs],
    "merge_configs": {
        "barycenter": {"weights": list(ex2_weights), "iterations": 20, "threshold": 1e-5},
        "linear": {"weights": list(ex2_weights)},
        "pairwise_slerp": {"weights": list(ex2_weights)},
    },
    "base_model_checkpoint": "google/vit-base-patch16-224-in21k",
}

ex2 = Experiment(**ex2_config)

In [None]:
# Execute experiment 2 as configured in the cell above.
ex2.run()

In [None]:
# Build the results table once so that the CSV written to disk and the table
# displayed as the cell output come from the same call.
results_df = ex2.get_results_df()
results_df.to_csv("ex2_results.csv")
results_df

## Experiment 3-a

In [None]:
# Draw three uniform samples with a fixed seed, then rescale them so they sum to 1.
torch.manual_seed(42)
raw_draw = torch.rand(3)
random_weights = raw_draw / raw_draw.sum()

random_weights

In [None]:
# (checkpoint, dataset config) pairs, kept together so the two parallel lists
# built below always stay in the same order.
ex3a_pairs = [
    (
        "pkr7098/cifar100-vit-base-patch16-224-in21k",
        {"dataset_name": "cifar100", "split": "test", "image_col": "img", "label_col": "fine_label"},
    ),
    (
        "nateraw/food",
        {"dataset_name": "food101", "split": "validation", "image_col": "image", "label_col": "label"},
    ),
    (
        "MaulikMadhavi/vit-base-flowers102",
        {"dataset_name": "nkirschi/oxford-flowers", "split": "test", "image_col": "image", "label_col": "label"},
    ),
]

# Hard-coded, non-uniform weights (presumably the seed-42 values shown by the
# previous cell, rounded to four decimals); each merge method gets its own copy.
ex3a_weights = [0.4047, 0.4197, 0.1756]

ex3a_config = {
    "name": "Experiment 3a: Random Interpolation Coefficients With Seed 42",
    "model_checkpoints": [checkpoint for checkpoint, _ in ex3a_pairs],
    "datasets_config": [dataset for _, dataset in ex3a_pairs],
    "merge_configs": {
        "barycenter": {"weights": list(ex3a_weights), "iterations": 20, "threshold": 1e-5},
        "linear": {"weights": list(ex3a_weights)},
        "pairwise_slerp": {"weights": list(ex3a_weights)},
    },
    "base_model_checkpoint": "google/vit-base-patch16-224-in21k",
}

ex3a = Experiment(**ex3a_config)

In [None]:
# Execute experiment 3-a as configured in the cell above.
ex3a.run()

In [None]:
# Save the results table as ex3a_results.csv in the current working directory.
ex3a.get_results_df().to_csv("ex3a_results.csv")

In [None]:
# Show the experiment 3-a results table as the cell output.
ex3a.get_results_df()

## Experiment 3-b

In [None]:
# Draw three uniform samples with a fixed seed, then rescale them so they sum to 1.
torch.manual_seed(3)
raw_draw = torch.rand(3)
random_weights = raw_draw / raw_draw.sum()

random_weights

In [None]:
# (checkpoint, dataset config) pairs, kept together so the two parallel lists
# built below always stay in the same order.
ex3b_pairs = [
    (
        "pkr7098/cifar100-vit-base-patch16-224-in21k",
        {"dataset_name": "cifar100", "split": "test", "image_col": "img", "label_col": "fine_label"},
    ),
    (
        "nateraw/food",
        {"dataset_name": "food101", "split": "validation", "image_col": "image", "label_col": "label"},
    ),
    (
        "MaulikMadhavi/vit-base-flowers102",
        {"dataset_name": "nkirschi/oxford-flowers", "split": "test", "image_col": "image", "label_col": "label"},
    ),
]

# Hard-coded, non-uniform weights (presumably the seed-3 values shown by the
# previous cell, rounded to four decimals); each merge method gets its own copy.
ex3b_weights = [0.0108, 0.2668, 0.7224]

ex3b_config = {
    "name": "Experiment 3-b: Random Interpolation Weights With Seed 3",
    "model_checkpoints": [checkpoint for checkpoint, _ in ex3b_pairs],
    "datasets_config": [dataset for _, dataset in ex3b_pairs],
    "merge_configs": {
        "barycenter": {"weights": list(ex3b_weights), "iterations": 20, "threshold": 1e-5},
        "linear": {"weights": list(ex3b_weights)},
        "pairwise_slerp": {"weights": list(ex3b_weights)},
    },
    "base_model_checkpoint": "google/vit-base-patch16-224-in21k",
}

ex3b = Experiment(**ex3b_config)

In [None]:
# Execute experiment 3-b as configured in the cell above.
ex3b.run()

In [None]:
# Save the results table as ex3b_results.csv in the current working directory.
ex3b.get_results_df().to_csv("ex3b_results.csv")

In [None]:
# Show the experiment 3-b results table as the cell output.
ex3b.get_results_df()

## Experiment 4-a

In [None]:
# (checkpoint, dataset config) pairs in the order CIFAR-100, Food-101, Flowers;
# the two parallel lists below are derived from them so the order cannot drift.
ex4a_pairs = [
    (
        "pkr7098/cifar100-vit-base-patch16-224-in21k",
        {"dataset_name": "cifar100", "split": "test", "image_col": "img", "label_col": "fine_label"},
    ),
    (
        "nateraw/food",
        {"dataset_name": "food101", "split": "validation", "image_col": "image", "label_col": "label"},
    ),
    (
        "MaulikMadhavi/vit-base-flowers102",
        {"dataset_name": "nkirschi/oxford-flowers", "split": "test", "image_col": "image", "label_col": "label"},
    ),
]

ex4a_config = {
    "name": "Pairwise SLERP cycle consistency test (Model Parameters)",
    "model_checkpoints": [checkpoint for checkpoint, _ in ex4a_pairs],
    "datasets_config": [dataset for _, dataset in ex4a_pairs],
    # Only pairwise SLERP is run here; weights are listed in model order.
    "merge_configs": {
        "pairwise_slerp": {"weights": [0.5, 0.2, 0.3]},
    },
    "base_model_checkpoint": "google/vit-base-patch16-224-in21k",
}

ex4a = Experiment(**ex4a_config)

In [None]:
# Execute experiment 4-a as configured in the cell above.
ex4a.run()

In [None]:
# Save the results table as ex4a_results.csv in the current working directory.
ex4a.get_results_df().to_csv("ex4a_results.csv")

In [None]:
# Show the experiment 4-a results table as the cell output.
ex4a.get_results_df()

## Experiment 4-b

In [None]:
# Same three models as experiment 4-a, reordered to CIFAR-100, Flowers, Food-101.
# Checkpoints and dataset configs are paired so both lists share that order.
ex4b_pairs = [
    (
        "pkr7098/cifar100-vit-base-patch16-224-in21k",
        {"dataset_name": "cifar100", "split": "test", "image_col": "img", "label_col": "fine_label"},
    ),
    (
        "MaulikMadhavi/vit-base-flowers102",
        {"dataset_name": "nkirschi/oxford-flowers", "split": "test", "image_col": "image", "label_col": "label"},
    ),
    (
        "nateraw/food",
        {"dataset_name": "food101", "split": "validation", "image_col": "image", "label_col": "label"},
    ),
]

ex4b_config = {
    "name": "Pairwise SLERP cycle consistency test (Model Parameters)",
    "model_checkpoints": [checkpoint for checkpoint, _ in ex4b_pairs],
    "datasets_config": [dataset for _, dataset in ex4b_pairs],
    # Only pairwise SLERP is run here; weights are listed in model order.
    "merge_configs": {
        "pairwise_slerp": {"weights": [0.5, 0.3, 0.2]},
    },
    "base_model_checkpoint": "google/vit-base-patch16-224-in21k",
}

ex4b = Experiment(**ex4b_config)

In [None]:
# Execute experiment 4-b as configured in the cell above.
ex4b.run()

In [None]:
# Save the results table as ex4b_results.csv in the current working directory.
ex4b.get_results_df().to_csv("ex4b_results.csv")

## Experiment 4-c

In [None]:
# Same three models as experiment 4-a, reordered to Food-101, Flowers, CIFAR-100.
# Checkpoints and dataset configs are paired so both lists share that order.
ex4c_pairs = [
    (
        "nateraw/food",
        {"dataset_name": "food101", "split": "validation", "image_col": "image", "label_col": "label"},
    ),
    (
        "MaulikMadhavi/vit-base-flowers102",
        {"dataset_name": "nkirschi/oxford-flowers", "split": "test", "image_col": "image", "label_col": "label"},
    ),
    (
        "pkr7098/cifar100-vit-base-patch16-224-in21k",
        {"dataset_name": "cifar100", "split": "test", "image_col": "img", "label_col": "fine_label"},
    ),
]

ex4c_config = {
    "name": "Pairwise SLERP cycle consistency test (Model Parameters)",
    "model_checkpoints": [checkpoint for checkpoint, _ in ex4c_pairs],
    "datasets_config": [dataset for _, dataset in ex4c_pairs],
    # Only pairwise SLERP is run here; weights are listed in model order.
    "merge_configs": {
        "pairwise_slerp": {"weights": [0.2, 0.3, 0.5]},
    },
    "base_model_checkpoint": "google/vit-base-patch16-224-in21k",
}

ex4c = Experiment(**ex4c_config)

In [None]:
# Execute experiment 4-c as configured in the cell above.
ex4c.run()

In [None]:
# Save the results table as ex4c_results.csv in the current working directory.
ex4c.get_results_df().to_csv("ex4c_results.csv")

## Experiment 5-a

In [None]:
# (checkpoint, dataset config) pairs, kept together so the two parallel lists
# built below always stay in the same order.
ex5a_pairs = [
    (
        "pkr7098/cifar100-vit-base-patch16-224-in21k",
        {"dataset_name": "cifar100", "split": "test", "image_col": "img", "label_col": "fine_label"},
    ),
    (
        "nateraw/food",
        {"dataset_name": "food101", "split": "validation", "image_col": "image", "label_col": "label"},
    ),
    (
        "MaulikMadhavi/vit-base-flowers102",
        {"dataset_name": "nkirschi/oxford-flowers", "split": "test", "image_col": "image", "label_col": "label"},
    ),
]

# Near-uniform weights, one per model; every merge method gets its own copy.
ex5a_weights = [0.3333, 0.3334, 0.3333]

ex5a_config = {
    "name": "Merging 3 Task Vectors",
    "model_checkpoints": [checkpoint for checkpoint, _ in ex5a_pairs],
    "datasets_config": [dataset for _, dataset in ex5a_pairs],
    "merge_configs": {
        "barycenter": {"weights": list(ex5a_weights), "iterations": 20, "threshold": 1e-5},
        "linear": {"weights": list(ex5a_weights)},
        "pairwise_slerp": {"weights": list(ex5a_weights)},
    },
    "base_model_checkpoint": "google/vit-base-patch16-224-in21k",
}

# Uses TaskVectorExperiment rather than Experiment, with the same config layout.
ex5a = TaskVectorExperiment(**ex5a_config)

In [None]:
# Execute experiment 5-a as configured in the cell above.
ex5a.run()

In [None]:
# Save the results table as ex5a_results.csv in the current working directory.
ex5a.get_results_df().to_csv("ex5a_results.csv")

## Experiment 5-b

In [None]:
# (checkpoint, dataset config) pairs, kept together so the two parallel lists
# built below always stay in the same order.
ex5b_pairs = [
    (
        "Payoto/vit-base-patch16-224-in21k-finetuned-eurosat",
        {"dataset_name": "tanganke/eurosat", "split": "test", "image_col": "image", "label_col": "label"},
    ),
    (
        "akahana/vit-base-cats-vs-dogs",
        {"dataset_name": "Bingsu/Cat_and_Dog", "split": "test", "image_col": "image", "label_col": "labels"},
    ),
    (
        "nateraw/vit-base-beans",
        {"dataset_name": "AI-Lab-Makerere/beans", "split": "test", "image_col": "image", "label_col": "labels"},
    ),
]

# Near-uniform weights, one per model; every merge method gets its own copy.
ex5b_weights = [0.3333, 0.3334, 0.3333]

ex5b_config = {
    "name": "Merging 3 Task Vectors",
    "model_checkpoints": [checkpoint for checkpoint, _ in ex5b_pairs],
    "datasets_config": [dataset for _, dataset in ex5b_pairs],
    "merge_configs": {
        "barycenter": {"weights": list(ex5b_weights), "iterations": 20, "threshold": 1e-5},
        "linear": {"weights": list(ex5b_weights)},
        "pairwise_slerp": {"weights": list(ex5b_weights)},
    },
    "base_model_checkpoint": "google/vit-base-patch16-224-in21k",
}

# Uses TaskVectorExperiment rather than Experiment, with the same config layout.
ex5b = TaskVectorExperiment(**ex5b_config)

In [None]:
# Execute experiment 5-b as configured in the cell above.
ex5b.run()

In [None]:
# Save the results table to CSV in the current working directory.
# The file name previously read "ex5b_resuls.csv", which broke the
# exN_results.csv naming used by every other experiment.
ex5b.get_results_df().to_csv("ex5b_results.csv")

## Experiment 6

In [None]:
# (checkpoint, dataset config) pairs for the five models, kept together so the
# two parallel lists built below always stay in the same order.
ex6_pairs = [
    (
        "pkr7098/cifar100-vit-base-patch16-224-in21k",
        {"dataset_name": "cifar100", "split": "test", "image_col": "img", "label_col": "fine_label"},
    ),
    (
        "nateraw/food",
        {"dataset_name": "food101", "split": "validation", "image_col": "image", "label_col": "label"},
    ),
    (
        "MaulikMadhavi/vit-base-flowers102",
        {"dataset_name": "nkirschi/oxford-flowers", "split": "test", "image_col": "image", "label_col": "label"},
    ),
    (
        "farleyknight-org-username/vit-base-mnist",
        {"dataset_name": "ylecun/mnist", "split": "test", "image_col": "image", "label_col": "label"},
    ),
    (
        "Payoto/vit-base-patch16-224-in21k-finetuned-eurosat",
        {"dataset_name": "tanganke/eurosat", "split": "test", "image_col": "image", "label_col": "label"},
    ),
]

# Uniform weights, one per model; every merge method gets its own copy.
ex6_weights = [0.2, 0.2, 0.2, 0.2, 0.2]

ex6_config = {
    "name": "5-Model Scalability Test on Task Vector Merging",
    "model_checkpoints": [checkpoint for checkpoint, _ in ex6_pairs],
    "datasets_config": [dataset for _, dataset in ex6_pairs],
    "merge_configs": {
        "barycenter": {"weights": list(ex6_weights), "iterations": 20, "threshold": 1e-5},
        "linear": {"weights": list(ex6_weights)},
        "pairwise_slerp": {"weights": list(ex6_weights)},
    },
    "base_model_checkpoint": "google/vit-base-patch16-224-in21k",
}

# Uses TaskVectorExperiment rather than Experiment, with the same config layout.
ex6 = TaskVectorExperiment(**ex6_config)

In [None]:
# Run the experiment before collecting its results. Every other experiment in
# this notebook calls run() before get_results_df(), but this section had no
# run step, so the saved table would not reflect an executed experiment.
ex6.run()

# Save the results table as ex6_results.csv in the current working directory.
ex6.get_results_df().to_csv("ex6_results.csv")

## Experiment 7-a

In [None]:
# (checkpoint, dataset config) pairs in the order CIFAR-100, Food-101, Flowers;
# the two parallel lists below are derived from them so the order cannot drift.
ex7a_pairs = [
    (
        "pkr7098/cifar100-vit-base-patch16-224-in21k",
        {"dataset_name": "cifar100", "split": "test", "image_col": "img", "label_col": "fine_label"},
    ),
    (
        "nateraw/food",
        {"dataset_name": "food101", "split": "validation", "image_col": "image", "label_col": "label"},
    ),
    (
        "MaulikMadhavi/vit-base-flowers102",
        {"dataset_name": "nkirschi/oxford-flowers", "split": "test", "image_col": "image", "label_col": "label"},
    ),
]

# Weights in model order; each merge method gets its own copy.
ex7a_weights = [0.5, 0.2, 0.3]

ex7a_config = {
    "name": "Pairwise SLERP cycle consistency test - Task Vectors",
    "model_checkpoints": [checkpoint for checkpoint, _ in ex7a_pairs],
    "datasets_config": [dataset for _, dataset in ex7a_pairs],
    "merge_configs": {
        "linear": {"weights": list(ex7a_weights)},
        "pairwise_slerp": {"weights": list(ex7a_weights)},
    },
    "base_model_checkpoint": "google/vit-base-patch16-224-in21k",
}

# Uses TaskVectorExperiment rather than Experiment, with the same config layout.
ex7a = TaskVectorExperiment(**ex7a_config)

In [None]:
# Execute experiment 7-a as configured in the cell above.
ex7a.run()

In [None]:
# Save the results table as ex7a_results.csv in the current working directory.
ex7a.get_results_df().to_csv("ex7a_results.csv")

## Experiment 7-b

In [None]:
# Same three models as experiment 7-a, reordered to CIFAR-100, Flowers, Food-101.
# Checkpoints and dataset configs are paired so both lists share that order.
ex7b_pairs = [
    (
        "pkr7098/cifar100-vit-base-patch16-224-in21k",
        {"dataset_name": "cifar100", "split": "test", "image_col": "img", "label_col": "fine_label"},
    ),
    (
        "MaulikMadhavi/vit-base-flowers102",
        {"dataset_name": "nkirschi/oxford-flowers", "split": "test", "image_col": "image", "label_col": "label"},
    ),
    (
        "nateraw/food",
        {"dataset_name": "food101", "split": "validation", "image_col": "image", "label_col": "label"},
    ),
]

ex7b_config = {
    "name": "Pairwise SLERP cycle consistency test - Task Vectors",
    "model_checkpoints": [checkpoint for checkpoint, _ in ex7b_pairs],
    "datasets_config": [dataset for _, dataset in ex7b_pairs],
    # Only pairwise SLERP is run here; weights are listed in model order.
    "merge_configs": {
        "pairwise_slerp": {"weights": [0.5, 0.3, 0.2]},
    },
    "base_model_checkpoint": "google/vit-base-patch16-224-in21k",
}

# Uses TaskVectorExperiment rather than Experiment, with the same config layout.
ex7b = TaskVectorExperiment(**ex7b_config)

In [None]:
# Execute experiment 7-b as configured in the cell above.
ex7b.run()

In [None]:
# Save the results table as ex7b_results.csv in the current working directory.
ex7b.get_results_df().to_csv("ex7b_results.csv")

## Experiment 7-c

In [None]:
# Same three models as experiment 7-a, reordered to Food-101, Flowers, CIFAR-100.
# Checkpoints and dataset configs are paired so both lists share that order.
ex7c_pairs = [
    (
        "nateraw/food",
        {"dataset_name": "food101", "split": "validation", "image_col": "image", "label_col": "label"},
    ),
    (
        "MaulikMadhavi/vit-base-flowers102",
        {"dataset_name": "nkirschi/oxford-flowers", "split": "test", "image_col": "image", "label_col": "label"},
    ),
    (
        "pkr7098/cifar100-vit-base-patch16-224-in21k",
        {"dataset_name": "cifar100", "split": "test", "image_col": "img", "label_col": "fine_label"},
    ),
]

ex7c_config = {
    "name": "Pairwise SLERP cycle consistency test - Task Vectors",
    "model_checkpoints": [checkpoint for checkpoint, _ in ex7c_pairs],
    "datasets_config": [dataset for _, dataset in ex7c_pairs],
    # Only pairwise SLERP is run here; weights are listed in model order.
    "merge_configs": {
        "pairwise_slerp": {"weights": [0.2, 0.3, 0.5]},
    },
    "base_model_checkpoint": "google/vit-base-patch16-224-in21k",
}

# Uses TaskVectorExperiment rather than Experiment, with the same config layout.
ex7c = TaskVectorExperiment(**ex7c_config)

In [None]:
# Execute experiment 7-c as configured in the cell above.
ex7c.run()

In [None]:
# Save the results table as ex7c_results.csv in the current working directory.
ex7c.get_results_df().to_csv("ex7c_results.csv")

## Plotting Models on Hypersphere
This section of the notebook uses some functions from the `m3sb` package, so make sure it is installed in the environment by following the installation guide at the top of this notebook.

In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from m3sb.utils import load_model, get_task_vector, flatten_state_dict, run_pca
import plotly.graph_objects as go

In [3]:
def plot_finetuned_models_on_sphere_3d(
    finetuned_models_3d: list[torch.Tensor],
    finetuned_labels: list[str],
    title: str
):
    """
    Plots only the finetuned model vectors on a 3D unit sphere.

    Each vector is divided by its L2 norm, so only its direction is shown;
    its magnitude is discarded. The figure is displayed with `fig.show()`
    and nothing is returned.

    Args:
        finetuned_models_3d (list[torch.Tensor]): list of 3D tensors for finetuned models.
            Tensors may live on any device and may require grad.
        finetuned_labels (list[str]): list of labels for the finetuned models,
            in the same order as `finetuned_models_3d`.
        title (str): The title of the plot.

    Raises:
        ValueError: If the number of vectors and the number of labels differ.
    """
    # zip() would silently drop the unmatched tail, hiding a mislabelled plot.
    if len(finetuned_models_3d) != len(finetuned_labels):
        raise ValueError(
            f"Got {len(finetuned_models_3d)} model vectors but "
            f"{len(finetuned_labels)} labels; they must match one-to-one."
        )

    fig = go.Figure()

    # Create a translucent unit sphere for visual context
    u, v = np.mgrid[0:2*np.pi:20j, 0:np.pi:10j]
    x_sphere = np.cos(u) * np.sin(v)
    y_sphere = np.sin(u) * np.sin(v)
    z_sphere = np.cos(v)
    fig.add_trace(go.Surface(
        x=x_sphere, y=y_sphere, z=z_sphere,
        opacity=0.1, showscale=False, colorscale=[[0, 'lightblue'], [1, 'lightblue']]
    ))

    # Marker style shared by all finetuned models
    marker_config = dict(symbol='circle', color='blue', size=6)

    # Loop directly through and plot only the finetuned models
    for label, vec_3d in zip(finetuned_labels, finetuned_models_3d):
        # Tensor.numpy() only works for CPU tensors that do not require grad,
        # so detach and move to CPU first (both are no-ops when not needed).
        vec = vec_3d.detach().cpu().numpy()
        # Normalize the vector to project it onto the unit sphere's surface
        vec_normalized = vec / np.linalg.norm(vec)
        x, y, z = vec_normalized

        # Add the model's point to the plot
        fig.add_trace(go.Scatter3d(
            x=[x], y=[y], z=[z],
            mode='markers+text',
            name=label,
            text=[label],
            textposition="bottom center",
            marker=marker_config
        ))

    # Update layout for a clean 3D plot
    fig.update_layout(
        title=title,
        scene=dict(
            xaxis_title="PC 1",
            yaxis_title="PC 2",
            zaxis_title="PC 3",
            aspectmode='data' # Ensures the sphere looks like a sphere
        ),
        legend_title="Model",
        margin=dict(l=0, r=0, b=0, t=40),
        showlegend=True
    )

    fig.show()

In [None]:
# The pretrained base checkpoint is kept as a model object (not a state dict).
base_model = load_model("google/vit-base-patch16-224-in21k")

# State dicts of the seven finetuned checkpoints, loaded and unpacked in the
# order the checkpoints are listed.
(
    model1,
    model2,
    model3,
    model4,
    model5,
    model6,
    model7,
) = (
    load_model(checkpoint).state_dict()
    for checkpoint in (
        "pkr7098/cifar100-vit-base-patch16-224-in21k",
        "nateraw/food",
        "MaulikMadhavi/vit-base-flowers102",
        "farleyknight-org-username/vit-base-mnist",
        "Payoto/vit-base-patch16-224-in21k-finetuned-eurosat",
        "akahana/vit-base-cats-vs-dogs",
        "nateraw/vit-base-beans",
    )
)

In [None]:
# One task vector per finetuned state dict, each computed against a fresh
# state dict of the base model.
(
    task_vector1,
    task_vector2,
    task_vector3,
    task_vector4,
    task_vector5,
    task_vector6,
    task_vector7,
) = (
    get_task_vector(base_model.state_dict(), finetuned_state)
    for finetuned_state in (model1, model2, model3, model4, model5, model6, model7)
)

In [None]:
# Flatten the base model's state dict, then each finetuned state dict in order.
base_model_flat = flatten_state_dict(base_model.state_dict())
(
    model1_flat,
    model2_flat,
    model3_flat,
    model4_flat,
    model5_flat,
    model6_flat,
    model7_flat,
) = (
    flatten_state_dict(finetuned_state)
    for finetuned_state in (model1, model2, model3, model4, model5, model6, model7)
)

In [8]:
# Reduce the seven flattened finetuned models to 3 components with run_pca.
# The flattened base model is not part of the input.
compressed_models = run_pca([
    model1_flat,
    model2_flat,
    model3_flat,
    model4_flat,
    model5_flat,
    model6_flat,
    model7_flat,
], n_components=3)

In [10]:
# Labels are listed in the same order as the models passed to run_pca
# (model1 ... model7).
plot_finetuned_models_on_sphere_3d(
    finetuned_models_3d=compressed_models,
    finetuned_labels=[
        "CIFAR-100",
        "Food-101",
        "Flowers-102",
        "MNIST",
        "EuroSAT",
        "CatsVsDogs",
        "Beans",
    ],
    title="Model Parameters on Hypersphere",
)