In [None]:
import wandb
import logging

# Logger named after the current module; the run-collection loop below uses it for warnings.
pylogger = logging.getLogger(__name__)

In [None]:
from tqdm import tqdm
from wandb.sdk.wandb_run import Run
import numpy as np

# Client for the W&B public API; `get_runs` below issues its run queries through this object.
api = wandb.Api()
# The two values are joined as "<entity>/<project>" to form the query path.
entity, project = "crisostomi", "cycle-consistent-model-merging"  # set to your entity and project

In [None]:
def get_runs(entity, project, positive_tags, negative_tags):
    """Query the W&B project for runs selected by tag.

    Uses the module-level ``api`` client.

    Args:
        entity: W&B entity owning the project.
        project: Project name; the query path is ``"<entity>/<project>"``.
        positive_tags: A tag or an iterable of tags. One ``$eq`` clause is
            emitted per tag and all clauses are AND-ed together.
        negative_tags: A tag or an iterable of tags to exclude. A run carrying
            any of them is filtered out. ``None`` or an empty iterable
            disables the exclusion.

    Returns:
        The object returned by ``api.runs`` for the combined filter.
    """

    def _as_tag_list(tags):
        # A bare string would otherwise be iterated character by character.
        if tags is None:
            return []
        if isinstance(tags, str):
            return [tags]
        return list(tags)

    required_tags = _as_tag_list(positive_tags)
    excluded_tags = _as_tag_list(negative_tags)

    clauses = [{"tags": {"$eq": pos_tag}} for pos_tag in required_tags]
    if excluded_tags:
        # Previously this argument was accepted but never applied to the query.
        clauses.append({"tags": {"$nin": excluded_tags}})

    # An empty "$and" list is not a meaningful query, so fall back to "no filter".
    filters = {"$and": clauses} if clauses else {}

    print(filters)
    runs = api.runs(entity + "/" + project, filters=filters)

    print(f"There are {len(runs)} runs respecting these conditions.")
    return runs

In [None]:
# Tags passed as `positive_tags` to `get_runs`: a run must carry all of them to be selected.
tags = ["matching", "pairwise", "final"]

In [None]:
# Fetch the runs to analyse.
# NOTE(review): `negative_tags` is given as a bare string while `positive_tags` is a list,
# and "git_rebasin" is also one of the matcher names analysed below — confirm the intended exclusion.
runs = get_runs(entity, project, positive_tags=tags, negative_tags="git_rebasin")

In [None]:
# Experiment grid used to pre-populate `exps`: every matcher x model pair x seed gets a slot.
model_pairs = [(1, 2), (2, 3), (1, 3)]
# Seed indices 1 through 4 (the upper bound of `range` is exclusive).
all_seeds = range(1, 5)
# Matcher names; these are values of `matcher_mapping`, not class paths.
matchers = ["git_rebasin", "frank_wolfe"]

In [None]:
# Pre-allocate an empty result slot for every (matcher, model pair, seed) combination,
# nested as exps[matcher][pair][seed].
exps = {}
for matcher in matchers:
    exps[matcher] = {}
    for pair in model_pairs:
        exps[matcher][pair] = {seed: {} for seed in all_seeds}
print(exps)

In [None]:
# Keys looked up in each run's config by the collection loop.
seed_key = "matching/seed_index"
model_pair_key = "matching/model_seeds"

# Config key holding the fully qualified class path of the matcher used by a run.
matcher_key = "matching/matcher/_target_"

# Fully qualified matcher class paths, as they are expected to appear under `matcher_key`.
alternating_diff_classname = "ccmm.matching.matcher.AlternatingDiffusionMatcher"
gitrebasin_classname = "ccmm.matching.matcher.GitRebasinMatcher"
quadratic_classname = "ccmm.matching.matcher.QuadraticMatcher"
frankwolfe_classname = "ccmm.matching.matcher.FrankWolfeMatcher"

# Config key holding the model name; compared against `model` to select runs.
model_key = "model/name"
# Class path -> short matcher name used as the top-level key of `exps` and as the "algorithm" label.
matcher_mapping = {
    alternating_diff_classname: "alternating_diffusion",
    gitrebasin_classname: "git_rebasin",
    quadratic_classname: "quadratic",
    frankwolfe_classname: "frank_wolfe",
}

In [None]:
# Only runs whose config value under `model_key` equals this name are collected.
model = "MLP"

In [None]:
def remove_nones(array):
    """Return a new NumPy array with the entries of ``array`` that are not ``None``, in order."""
    kept_entries = list(filter(lambda entry: entry is not None, array))
    return np.array(kept_entries)

In [None]:
# api.artifact(f'run-{run.id}-train_acc_interpolations_table:v0', type='table')

# Fill `exps` with the accuracy/loss curves and loss barriers of every run that
# matches `model`, stored as exps[matcher name][model pair][seed].
history_keys = ("train_acc", "test_acc", "train_loss", "test_loss")

for run in tqdm(runs):
    run: Run
    cfg = run.config

    if len(cfg) == 0:
        pylogger.warning("Runs are still running, skipping")
        continue

    if cfg[model_key] != model:
        continue

    seed = cfg[seed_key]
    # Config stores the pair as a sequence; `exps` is keyed by tuples.
    model_pair = tuple(cfg[model_pair_key])

    matcher_classname = cfg[matcher_key]
    matcher_mapped = matcher_mapping.get(matcher_classname)

    # `matcher_mapping` lists more matchers than the `exps` grid was built for, so a
    # selected run may fall outside the grid. Skip it with a warning instead of
    # aborting the whole collection with a KeyError.
    if matcher_mapped not in exps or model_pair not in exps[matcher_mapped]:
        pylogger.warning(
            f"Skipping run {run.id}: matcher {matcher_classname!r} with model pair "
            f"{model_pair} is not part of the experiment grid"
        )
        continue

    # runtime = run.summary.get("_runtime")
    # steps = run.summary.get("trainer/global_step")

    # Walk the history once and gather all four curves together, rather than
    # scanning the full history once per metric. Rows that lack a key, or hold
    # None for it, contribute nothing to that curve.
    curves = {key: [] for key in history_keys}
    for row in run.scan_history():
        for key in history_keys:
            value = row.get(key)
            if value is not None:
                curves[key].append(value)

    # The barriers are fetched separately per key and the first logged value is taken.
    test_loss_barrier = run.history(keys=["test_loss_barrier"])["test_loss_barrier"][0]
    train_loss_barrier = run.history(keys=["train_loss_barrier"])["train_loss_barrier"][0]

    exps[matcher_mapped][model_pair][seed] = {
        "train_acc_curve": np.array(curves["train_acc"]),
        "test_acc_curve": np.array(curves["test_acc"]),
        "train_loss_curve": np.array(curves["train_loss"]),
        "test_loss_curve": np.array(curves["test_loss"]),
        "test_loss_barrier": test_loss_barrier,
        "train_loss_barrier": train_loss_barrier,
    }

In [None]:
# Show the populated results dictionary as the cell output.
exps

In [None]:
import pandas as pd
import plotly.express as px

# Flatten the nested `exps` dictionary into one row per filled (algorithm, pair, seed) slot.
# NOTE(review): the collection loop in this notebook stores keys such as "test_acc_curve"
# and "test_loss_barrier"; confirm that "acc" and "loss" really exist in `metrics`,
# otherwise these lookups raise KeyError.
records = [
    {
        "algorithm": algo,
        "pair": pair,
        "seed": seed,
        "accuracy": metrics["acc"],
        "loss": metrics["loss"],
    }
    for algo, algo_data in exps.items()
    for pair, pair_data in algo_data.items()
    for seed, metrics in pair_data.items()
    if metrics  # slots that were never filled are still empty dicts
]

df = pd.DataFrame(records)

In [None]:
# Show the flattened results table as the cell output.
df

In [None]:
# Mean accuracy and loss over seeds for each (algorithm, pair).
# Only the metric columns are aggregated: `seed` is an identifier, so averaging it
# would add a meaningless column to the summary.
mean_metrics = df.groupby(["algorithm", "pair"])[["accuracy", "loss"]].mean().reset_index()
mean_metrics.head(n=6)

In [None]:
# One record per filled (algorithm, pair, seed) slot; pair and seed are merged into a
# single label that serves as the x-axis category.
# NOTE(review): the collection loop in this notebook stores keys such as "test_acc_curve"
# and "test_loss_barrier"; confirm that "acc" and "loss" really exist in `metrics`,
# otherwise these lookups raise KeyError.
records = [
    {
        "algorithm": algo,
        "pair_seed": f"Pair {pair} - seed {seed}",
        "accuracy": metrics["acc"],
        "loss": metrics["loss"],
    }
    for algo, algo_data in exps.items()
    for pair, pair_data in algo_data.items()
    for seed, metrics in pair_data.items()
    if metrics  # slots that were never filled are still empty dicts
]

df = pd.DataFrame(records)

# Both charts share the same grouping: one bar per algorithm within each pair/seed label.
bar_kwargs = {"x": "pair_seed", "color": "algorithm", "barmode": "group"}

# Accuracy chart
fig_acc = px.bar(df, y="accuracy", title="Accuracy by Algorithm, Pair, and seed", **bar_kwargs)
fig_acc.show()

# Loss chart
fig_loss = px.bar(df, y="loss", title="Loss by Algorithm, Pair, and seed", **bar_kwargs)
fig_loss.show()

In [None]:
# Print the mean and variance of the accuracy column for every matcher name in `matcher_mapping`.
for matcher in matcher_mapping.values():
    matcher_accuracy = df.loc[df["algorithm"] == matcher, "accuracy"]
    mean_approach = matcher_accuracy.mean()
    var_approach = matcher_accuracy.var()

    print(f"{matcher} diff: mean {mean_approach}, var {var_approach}")

In [None]:
# One record per filled (algorithm, pair, seed) slot, with the pair rendered as "a-b".
# NOTE(review): the collection loop in this notebook stores keys such as "test_acc_curve"
# and "test_loss_barrier"; confirm that "acc" and "loss" really exist in `metrics`,
# otherwise these lookups raise KeyError.
records = [
    {
        "algorithm": algo,
        "pair": f"{pair[0]}-{pair[1]}",
        "seed": seed,
        "accuracy": metrics["acc"],
        "loss": metrics["loss"],
    }
    for algo, algo_data in exps.items()
    for pair, pair_data in algo_data.items()
    for seed, metrics in pair_data.items()
    if metrics  # slots that were never filled are still empty dicts
]

df = pd.DataFrame(records)

# Combined "<pair>-Seed<seed>" label, used both for ordering and as the x-axis category.
df["sort_key"] = df["pair"] + "-Seed" + df["seed"].astype(str)
df = df.sort_values(by="sort_key")

# One row per sort_key, one accuracy column per algorithm.
pivot_df = df.pivot_table(index="sort_key", columns="algorithm", values="accuracy")

# The difference is positional: first algorithm column minus second.
# NOTE(review): this was written assuming the first column is 'alternating_diffusion';
# confirm which algorithms actually occupy columns 0 and 1.
first_column_accuracy = pivot_df.iloc[:, 0]
second_column_accuracy = pivot_df.iloc[:, 1]
pivot_df["accuracy_diff"] = first_column_accuracy - second_column_accuracy

pivot_df = pivot_df.reset_index()


def _diff_color(diff):
    """Green for a strictly positive difference, red otherwise."""
    return "green" if diff > 0 else "red"


# Extra "Total" row holding the mean difference over all pair/seed rows.
total_diff = pivot_df["accuracy_diff"].mean()
total_diff_row = pd.DataFrame(
    [{"sort_key": "Total", "accuracy_diff": total_diff, "color": _diff_color(total_diff)}]
)
pivot_df = pd.concat([pivot_df, total_diff_row], ignore_index=True)

# Colour every row (including "Total") by the sign of its difference.
pivot_df["color"] = pivot_df["accuracy_diff"].apply(_diff_color)

# Keep only the columns the plots need.
plot_data = pivot_df[["sort_key", "accuracy_diff", "color"]]

# Bar chart of the differences, with the x-axis treated as categories in table order.
fig = px.bar(plot_data, x="sort_key", y="accuracy_diff", color="color", title="Performance Difference in Accuracy")
fig.update_xaxes(type="category")
fig.update_layout(xaxis={"categoryorder": "array", "categoryarray": plot_data["sort_key"]})
fig.show()

In [None]:
import plotly.graph_objects as go

# Draw one bar trace per row of `plot_data`, adding an empty-valued spacer category
# after every 4th bar to leave a visual gap.
fig = go.Figure()

for current_index, (_, row) in enumerate(plot_data.iterrows(), start=1):
    fig.add_trace(
        go.Bar(
            x=[row["sort_key"]],
            y=[row["accuracy_diff"]],
            marker_color=row["color"],
            name=row["sort_key"],
        )
    )

    if current_index % 4 != 0:
        continue

    # Spacer: a fully transparent bar with no value, hidden from the legend.
    fig.add_trace(
        go.Bar(
            x=[f"Space-{current_index // 4}"],
            y=[None],
            marker={"color": "rgba(255, 255, 255, 0)"},
            showlegend=False,
        )
    )

# Bar outline and width apply to every trace, spacers included.
fig.update_traces(marker_line_width=1.5, width=0.4)
fig.update_layout(
    title="Performance Difference in Accuracy",
    xaxis_title="Pair-Seed",
    yaxis_title="Accuracy Difference",
    barmode="group",
)

fig.show()