In [None]:
import wandb
import logging

# Module-level logger for warnings emitted while collecting runs.
pylogger = logging.getLogger(name=__name__)

In [None]:
from tqdm import tqdm
from wandb.sdk.wandb_run import Run
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go

# Shared W&B public-API client; `get_runs` below queries runs through it.
api = wandb.Api()
# W&B entity (user or team) and project the runs are read from — set to your own.
entity, project = "gladia", "cycle-consistent-model-merging"

In [None]:
def get_runs(entity, project, positive_tags, negative_tags=None, client=None):
    """Fetch the W&B runs of ``entity/project`` selected by their tags.

    Args:
        entity: W&B entity (user or team) owning the project.
        project: W&B project name.
        positive_tags: tags that every returned run must carry.
        negative_tags: optional tags; runs carrying any of them are excluded.
        client: object exposing ``runs(path, filters=...)`` (e.g. ``wandb.Api()``);
            defaults to the module-level ``api``.

    Returns:
        Whatever ``client.runs`` returns for the built filter.
    """
    # One clause per required tag, combined with "$and", so a run must carry every positive tag.
    clauses = [{"tags": {"$eq": tag}} for tag in positive_tags]
    if negative_tags:
        # Exclude runs carrying any of the unwanted tags.
        clauses.append({"tags": {"$nin": list(negative_tags)}})

    # MongoDB rejects an empty "$and" array (and W&B filters are Mongo-style), so with no tags at all
    # send no filter instead.
    filters = {"$and": clauses} if clauses else {}

    if client is None:
        client = api
    runs = client.runs(entity + "/" + project, filters=filters)

    print(f"There are {len(runs)} runs respecting these conditions.")
    return runs

In [None]:
# Runs must carry all of these tags; other tags used in this project: 2x, 4x, 8x, cifar100, vgg.
tags = "merge_n_models 8x cifar100".split()

In [None]:
# Fetch every run tagged with all of `tags`.
runs = get_runs(entity=entity, project=project, positive_tags=tags)

In [None]:
# Short names of the merging methods compared in the table.
mergers = "frank_wolfe git_rebasin naive".split()

In [None]:
# One initially empty metrics slot per merger and repair state, filled by the collection loop below.
exps = {name: {state: {} for state in ("repaired", "untouched")} for name in mergers}
print(exps)

In [None]:
# Keys read from each run's W&B config.
seed_key = "matching/seed_index"
model_pair_key = "matching/model_seeds"
merger_key = "matching/merger/_target_"
model_key = "model/name"

# Fully qualified merger class paths (the values stored under `merger_key`).
gitrebasin_classname = "ccmm.matching.merger.GitRebasinMerger"
frankwolfe_classname = "ccmm.matching.merger.FrankWolfeSynchronizedMerger"
naive_classname = "ccmm.matching.merger.DummyMerger"

# Class path -> short merger name used as a key of `exps`.
merger_mapping = dict(
    zip(
        (gitrebasin_classname, frankwolfe_classname, naive_classname),
        ("git_rebasin", "frank_wolfe", "naive"),
    )
)

## Collect runs

In [None]:
# Output field name -> W&B history key it is read from.
metric_keys = {
    "train_acc": "acc/train",
    "test_acc": "acc/test",
    "train_loss": "loss/train",
    "test_loss": "loss/test",
}


def _first_logged_value(run: Run, key: str):
    """Return the first value logged under `key` in `run`'s history, or None if it was never logged."""
    # One query per key. NOTE(review): W&B's `history(keys=[...])` may drop rows that lack any of the
    # requested keys, so metrics logged at different steps are fetched separately.
    history = run.history(keys=[key])
    if history.empty or key not in history.columns:
        return None
    return history[key].iloc[0]


for run in tqdm(runs):
    run: Run
    cfg = run.config

    if len(cfg) == 0:
        pylogger.warning("Runs are still running, skipping")
        continue

    run_tags = cfg["core/tags"]
    # Check "repaired" first so a run tagged both "merged" and "repaired" is not misfiled as untouched.
    if "repaired" in run_tags:
        repaired_key = "repaired"
    elif "merged" in run_tags:
        repaired_key = "untouched"
    else:
        pylogger.warning("Run is neither merged nor repaired, skipping")
        continue

    seed = cfg[seed_key]
    model_pair = cfg[model_pair_key]

    merger_target = cfg[merger_key]
    if merger_target not in merger_mapping:
        pylogger.warning("Unknown merger %s in run %s, skipping", merger_target, run.name)
        continue
    merger_mapped = merger_mapping[merger_target]

    metrics = {name: _first_logged_value(run, key) for name, key in metric_keys.items()}
    missing = [metric_keys[name] for name, value in metrics.items() if value is None]
    if missing:
        pylogger.warning("Run %s has no logged %s, skipping", run.name, ", ".join(missing))
        continue

    # Only one run is kept per (merger, repair state): make silent overwrites visible.
    # (Re-running this cell without re-creating `exps` also triggers this warning.)
    if exps[merger_mapped][repaired_key]:
        pylogger.warning(
            "Overwriting %s/%s results with run %s (seed=%s, model_seeds=%s)",
            merger_mapped,
            repaired_key,
            run.name,
            seed,
            model_pair,
        )

    exps[merger_mapped][repaired_key] = metrics

In [None]:
# Show the collected metrics, keyed by merger and then by repair state ("repaired" / "untouched").
exps

In [None]:
# Metric fields copied from each non-empty `exps` entry into the table rows.
metric_columns = ["train_acc", "test_acc", "train_loss", "test_loss"]

# One row per (merger, repair state) that has results, labelled "<merger>_<repaired|untouched>".
records = [
    {"merger": merger_name + "_" + repaired_flag, **{column: metrics[column] for column in metric_columns}}
    for merger_name, merger_repaired_data in exps.items()
    for repaired_flag, metrics in merger_repaired_data.items()
    if metrics
]

# Explicit columns keep the "merger" column present even when no run produced results, so later
# cells that index df["merger"] do not fail with a KeyError on an empty frame.
df = pd.DataFrame(records, columns=["merger", *metric_columns])

In [None]:
# Show the flattened results, one row per merger / repair state.
df

In [None]:
# LaTeX label for each "<merger>_<repair state>" name; a dagger marks repaired results.
matcher_to_latex_map = {
    "frank_wolfe_repaired": r"\texttt{Frank-Wolfe}$^\dagger$",
    "git_rebasin_repaired": r"\texttt{Git-Rebasin}$^\dagger$",
    "naive_untouched": r"\texttt{Naive}",
    "naive_repaired": r"\texttt{Naive}$^\dagger$",
    "frank_wolfe_untouched": r"\texttt{Frank-Wolfe}",
    "git_rebasin_untouched": r"\texttt{Git-Rebasin}",
}

# Table row order: each merger's untouched result followed by its repaired one.
ordering = [
    "naive_untouched",
    "naive_repaired",
    "git_rebasin_untouched",
    "git_rebasin_repaired",
    "frank_wolfe_untouched",
    "frank_wolfe_repaired",
]

# Names outside `ordering` would silently become NaN categories and break the LaTeX lookup later.
unexpected = set(df["merger"]) - set(ordering)
assert not unexpected, f"mergers missing from `ordering`: {sorted(unexpected)}"

# A categorical column sorts by category order (i.e. `ordering`) instead of alphabetically.
# Build a new frame rather than mutating in place, so re-running the cell is idempotent.
df = df.assign(merger=pd.Categorical(df["merger"], categories=ordering)).sort_values(
    by="merger", ascending=True
)

In [None]:
# Show the rows again, now sorted by the `ordering` list.
df

In [None]:
import seaborn as sns

# Light-to-"seagreen" colormaps used to shade the LaTeX table cells; the second runs in the opposite
# direction and is used below for the loss columns.
cmap, cmap_reverse = (
    sns.light_palette("seagreen", as_cmap=True, reverse=flip) for flip in (False, True)
)

In [None]:
from ccmm.utils.plot import decimal_to_rgb_color

# Losses are divided by this before being mapped onto `cmap_reverse`.
# NOTE(review): a loss above this value maps past 1.0; how it is coloured depends on
# `decimal_to_rgb_color` — confirm it clips.
max_loss_value = 6.0


def _latex_rgb(color) -> str:
    """Format an (r, g, b) sequence as the comma-separated spec used by xcolor's `rgb` model."""
    # Format each channel explicitly instead of stripping parentheses from the tuple's repr: with
    # NumPy >= 2, numpy scalars repr as e.g. "np.float64(0.5)", which that stripping would mangle.
    return ",".join(f"{float(channel):.3f}" for channel in color)


def _colored_cell(value, fraction, colormap) -> str:
    """Render `value` with two decimals in a table cell shaded by `colormap` evaluated at `fraction`."""
    color = decimal_to_rgb_color(fraction, colormap)[:3]  # first three channels for the rgb model
    return f"\\cellcolor[rgb]{{{_latex_rgb(color)}}}{value:.2f}"


# Five columns: matcher name, then train/test accuracy and train/test loss — matching the rows below.
header = r"""
\begin{table}
    \begin{center}
        \begin{tabular}{lcccc}
        \toprule
        \textbf{Matcher} & \multicolumn{2}{c}{\textbf{Accuracy}} & \multicolumn{2}{c}{\textbf{Loss}} \\
        \cmidrule(lr){2-3} \cmidrule(lr){4-5}
                         & \textbf{Train} & \textbf{Test}        & \textbf{Train} & \textbf{Test} \\
        \midrule
        """

rows = []
for _, row in df.iterrows():
    merger = row["merger"]

    # The repaired naive merge is not shown in the table.
    if merger == "naive_repaired":
        continue

    cells = [
        # Accuracies are fed to the colormap directly (presumably already in [0, 1]).
        _colored_cell(row["train_acc"], row["train_acc"], cmap),
        _colored_cell(row["test_acc"], row["test_acc"], cmap),
        # Losses are rescaled by `max_loss_value` and use the reversed colormap.
        _colored_cell(row["train_loss"], row["train_loss"] / max_loss_value, cmap_reverse),
        _colored_cell(row["test_loss"], row["test_loss"] / max_loss_value, cmap_reverse),
    ]
    rows.append(f"\n        {matcher_to_latex_map[merger]} & " + " & ".join(cells) + r" \\")

body = "".join(rows)

# NOTE(review): the label looks inherited from an MLP loss-barrier table; rename it if both tables
# end up in the same document.
footer = r"""
        \bottomrule
        \end{tabular}
    \end{center}
    \caption{Train and test accuracy and loss for each matcher; $^\dagger$ marks results after repair.}
    \label{tab:MLP_loss_barrier}
\end{table}"""

table = header + body + footer

In [None]:
# print (not a bare `table`) so the LaTeX source is shown verbatim instead of as an escaped repr.
print(table)