## Imports

In [None]:
import wandb
import logging
from nn_core.common import PROJECT_ROOT

from tqdm import tqdm
from wandb.sdk.wandb_run import Run
import numpy as np
import pandas as pd
import plotly.express as px
import numpy as np
import plotly.graph_objs as go
import plotly.io as pio

import matplotlib.pyplot as plt

In [None]:
from ccmm.utils.plot import Palette

# Colour palette shared by every figure in this notebook.
palette = Palette(f"{PROJECT_ROOT}/misc/palette2.json")

# Typeset all figure text with LaTeX in a serif font.
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "serif"

# Notebook-wide logger.
pylogger = logging.getLogger(__name__)

# Last expression of the cell: display the palette.
palette

## Get runs

In [None]:
# Public W&B API client (used by get_runs).
api = wandb.Api()

# W&B entity and project holding the runs; change these to your own.
entity = "gladia"
project = "cycle-consistent-model-merging"

In [None]:
def get_runs(entity, project, positive_tags, negative_tags=None, api_client=None):
    """Fetch the W&B runs of ``entity/project`` filtered by tags.

    Args:
        entity: W&B entity (user or team) that owns the project.
        project: W&B project name.
        positive_tags: tags that every returned run must carry.
        negative_tags: optional tags that no returned run may carry.
        api_client: object exposing ``runs(path, filters=...)``; defaults to
            the module-level ``api`` client.

    Returns:
        Whatever ``api_client.runs`` returns for the built filter.
    """
    if api_client is None:
        api_client = api

    # MongoDB-style conditions, all of which must hold.
    conditions = [{"tags": {"$eq": tag}} for tag in positive_tags]
    if negative_tags:
        # Exclude runs carrying any of the negative tags.
        conditions.append({"tags": {"$nin": list(negative_tags)}})

    # MongoDB-style query engines reject an empty "$and", so with no
    # conditions use an empty filter (match every run) instead.
    filters = {"$and": conditions} if conditions else {}

    print(filters)
    runs = api_client.runs(f"{entity}/{project}", filters=filters)

    print(f"There are {len(runs)} runs respecting these conditions.")
    return runs

In [None]:
# Architecture whose scaling runs are analysed: "ResNet" or "MLP".
considered_model = "ResNet"

In [None]:
# W&B tag identifying the scaling runs of each supported architecture.
model_tag_by_name = {"ResNet": "4x", "MLP": "mlp"}
if considered_model not in model_tag_by_name:
    # Fail loudly instead of silently falling back to the MLP runs on a typo.
    raise ValueError(
        f"considered_model must be one of {sorted(model_tag_by_name)}, got {considered_model!r}"
    )
considered_model_tag = model_tag_by_name[considered_model]
tags = ["scaling", considered_model_tag]

In [None]:
# Fetch every run that carries all of the selected tags.
runs = get_runs(entity=entity, project=project, positive_tags=tags)

In [None]:
# W&B config keys read from each run.
seed_key = "matching/seed_index"
model_pair_key = "matching/model_seeds"
merger_key = "matching/merger/_target_"
model_key = "model/name"

# Fully-qualified class names of the merger implementations.
_merger_module = "ccmm.matching.merger"
gitrebasin_classname = _merger_module + ".GitRebasinMerger"
frankwolfe_classname = _merger_module + ".FrankWolfeSynchronizedMerger"
naive_classname = _merger_module + ".DummyMerger"
git_rebasin_pairwise_classname = _merger_module + ".GitRebasinPairwiseMerger"

# Merger class name -> short name used throughout the notebook.
merger_mapping = dict(
    [
        (gitrebasin_classname, "git_rebasin"),
        (frankwolfe_classname, "frank_wolfe"),
        (git_rebasin_pairwise_classname, "git_rebasin_pairwise"),
        (naive_classname, "naive"),
    ]
)

# Mergers whose results are collected; the naive baseline is left out.
mergers = [short_name for short_name in merger_mapping.values() if short_name != "naive"]

In [None]:
# Largest number of merged models a result slot is reserved for.
max_num_models = 20

# exps[merger][k] will hold the metrics of the run that merged k models;
# counts without a run stay as empty dicts.
exps = {name: [dict() for _ in range(max_num_models + 1)] for name in mergers}

## Collect runs

In [None]:
def _first_logged_value(run, key):
    """Return the first value of metric ``key`` in the history of ``run``."""
    # ``iloc`` makes the lookup positional rather than label-based.
    return run.history(keys=[key])[key].iloc[0]


for run in tqdm(runs):
    run: Run
    cfg = run.config

    # An empty config is treated as a run that is still in progress.
    if len(cfg) == 0:
        pylogger.warning("Runs are still running, skipping")
        continue

    num_models = len(cfg[model_pair_key])
    # Left as a global on purpose: the figure cells use it in their file names.
    model_name = cfg[model_key]
    merger_target = cfg[merger_key]
    merger_mapped = merger_mapping.get(merger_target)

    # Skip runs whose merger is unknown or not collected (e.g. the naive
    # baseline), and model counts beyond the preallocated slots, instead of
    # aborting the whole loop with a KeyError / IndexError.
    if merger_mapped not in exps:
        pylogger.warning("Skipping run %s: merger %s is not collected", run.id, merger_target)
        continue
    if num_models >= len(exps[merger_mapped]):
        pylogger.warning(
            "Skipping run %s: %d models exceed the maximum of %d",
            run.id,
            num_models,
            len(exps[merger_mapped]) - 1,
        )
        continue

    exps[merger_mapped][num_models] = {
        "train_acc": _first_logged_value(run, "acc/train"),
        "test_acc": _first_logged_value(run, "acc/test"),
        "train_loss": _first_logged_value(run, "loss/train"),
        "test_loss": _first_logged_value(run, "loss/test"),
        "runtime": run.summary["_runtime"],
    }

In [None]:
# Inspect the collected metrics: exps[merger][k] holds the run that merged k models.
exps

In [None]:
# Flatten ``exps`` into one row per (merger, number of merged models).
# ``exps[merger_name][k]`` holds the metrics of the model merged from ``k``
# models; slots without a collected run are empty dicts and are skipped.
metric_columns = ["train_acc", "test_acc", "train_loss", "test_loss", "runtime"]
records = []

for merger_name, merger_data in exps.items():
    for n_merged, results in enumerate(merger_data):
        if not results:
            continue

        # Record the model count explicitly so missing runs cannot shift the x axis.
        record = {"merger": merger_name, "num_models": n_merged}
        for column in metric_columns:
            record[column] = results[column]
        records.append(record)

# Explicit columns keep the schema (including "merger") even when nothing was collected.
df = pd.DataFrame(records, columns=["merger", "num_models", *metric_columns])

In [None]:
# Inspect the flattened results table (one row per collected merger result).
df

In [None]:
# Mergers compared in the figures. A tuple (rather than a set) keeps the plot
# and legend order deterministic, since str hashing is randomized per process.
merger_subset = ("git_rebasin", "frank_wolfe")

In [None]:
# Split the results table into one DataFrame per plotted merger.
merger_dfs = {name: df.loc[df["merger"].eq(name)] for name in merger_subset}

In [None]:
# Index each merger's rows by the number of merged models (the plots' x axis).
for merger_df in merger_dfs.values():
    if "num_models" in merger_df.columns:
        # Use the recorded counts so gaps from missing runs keep the right x position.
        merger_df.index = merger_df["num_models"].to_numpy()
    else:
        # No recorded count: assume consecutive counts starting at two models.
        merger_df.index = range(2, len(merger_df) + 2)

In [None]:
from pathlib import Path

# Display name of each metric, used for panel titles and y labels.
pretty_metric = {
    "acc": "Accuracy",
    "loss": "Loss",
}

# Line colour per merger and split (train/test share a colour; linestyle separates them).
color_map = {
    "git_rebasin": {
        "train": palette["light red"],
        "test": palette["light red"],
    },
    "frank_wolfe": {
        "train": palette["green"],
        "test": palette["green"],
    },
    "git_rebasin_pairwise": {
        "train": palette["dark blue"],
        "test": palette["dark blue"],
    },
}

# Matplotlib linestyle per merger and split: solid for train, dotted for test.
dash_map = {
    "git_rebasin": {
        "train": "-",
        "test": ":",
    },
    "frank_wolfe": {
        "train": "-",
        "test": ":",
    },
    "git_rebasin_pairwise": {
        "train": "-",
        "test": ":",
    },
}

# Legend name of each merger (rendered with LaTeX).
merger_map = {
    "git_rebasin": r"MergeMany",
    "frank_wolfe": r"$C^2M^3$",
    "git_rebasin_pairwise": r"MergePairwise",
}

metrics = ["acc", "loss"]

# One panel per metric: accuracy on the left, loss on the right.
fig, axes = plt.subplots(1, 2, figsize=(6, 3))

for merger_name, merger_df in merger_dfs.items():
    for metric_ind, metric in enumerate(metrics):
        ax = axes[metric_ind]
        for split in ("train", "test"):
            ax.plot(
                merger_df.index,
                merger_df[f"{split}_{metric}"],
                linestyle=dash_map[merger_name][split],
                color=color_map[merger_name][split],
                label=f"{merger_map[merger_name]} ({split})",
            )

# Axis decoration is per panel, so set it once rather than once per merger.
for ax, metric in zip(axes, metrics):
    ax.set_title(pretty_metric[metric])
    ax.set_xlabel("Number of models")
    ax.set_ylabel(pretty_metric[metric])
    # Model counts are whole numbers: keep ticks on integers (as in the runtime plot).
    ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))

# Shared legend below both panels; the accuracy panel holds a handle for every line.
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles, labels, loc="lower center", bbox_to_anchor=(0.55, -0.2), ncol=2)

plt.tight_layout(rect=[0, 0, 1, 0.95])
# savefig does not create missing directories.
Path("figures").mkdir(parents=True, exist_ok=True)
# ``model_name`` is the model of the last run read in the collection loop.
plt.savefig(f"figures/scaling_exp_{model_name}.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
from pathlib import Path

# Runtime of each merger (``runtime`` column, in seconds) against the number of merged models.
fig, ax = plt.subplots(1, 1, figsize=(6, 3))

for merger_name in merger_subset:
    merger_df = merger_dfs[merger_name]
    ax.plot(
        merger_df.index,
        merger_df["runtime"],
        linestyle="-",
        color=color_map[merger_name]["train"],
        label=merger_map[merger_name],
    )

ax.legend()
ax.set_title("Runtime")
ax.set_xlabel("Number of models")
# Model counts are whole numbers: keep ticks on integers.
ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))
ax.set_ylabel("Runtime (s)")
plt.tight_layout(rect=[0, 0, 1, 0.95])

# savefig does not create missing directories.
Path("figures").mkdir(parents=True, exist_ok=True)
plt.savefig(f"figures/runtime_exp_{model_name}.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Hard-coded runtimes of C2M3 and MM for 2..11 merged models.
# NOTE(review): the source and unit of these numbers are not recorded here — confirm.
num_models = list(range(2, 12))
C2M3 = [40.0, 87.2, 95.2, 142.9, 278.15, 275.0, 289.59, 398.26, 524.09, 735.81]
MM = [2.53, 16.5, 28.0, 28.9, 63.88, 52.97, 85.39, 66.4, 167.88, 113.62]

# (values, marker, linestyle, colour, legend label) for each curve.
curve_specs = [
    (C2M3, "o", "-", "blue", "C2M3"),
    (MM, "x", "--", "green", "MM"),
]

plt.figure(figsize=(10, 6))
for values, marker, style, colour, label in curve_specs:
    plt.plot(num_models, values, marker=marker, linestyle=style, color=colour, label=label)

plt.xlabel("Number of Models")
plt.ylabel("Runtime")
plt.title("Runtime vs. Number of Models for C2M3 and MM")
plt.legend()
# Light dashed grid on the y axis only.
plt.grid(axis="y", linestyle="--")

plt.show()