## Imports

In [None]:
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from ccmm.utils.plot import Palette
from nn_core.common import PROJECT_ROOT

In [None]:
# Load the color palette used by the plots further down (looked up by name, e.g. "green").
palette = Palette("{}/misc/palette2.json".format(PROJECT_ROOT))
# Bare last expression: show the palette as this cell's output.
palette

In [None]:
# Serif fonts with text typeset by LaTeX (matplotlib's text.usetex delegates to an
# external LaTeX installation), plus seaborn's "talk" plotting context for sizing.
matplotlib.rcParams.update({"font.family": "serif", "text.usetex": True})
sns.set_context("talk")

## Merge using an arbitrary model as the reference point

In [None]:
# Accuracies (0-100 scale) when merging with an arbitrary model as the reference point;
# one value per run (presumably one per choice of reference model -- TODO confirm).
# Named `ref_*` so they do not share names with the per-pair dicts `train_accs` /
# `test_accs` defined further down; reusing those names made the later cells depend
# on execution order.
ref_test_accs = np.array([78.55, 65.32, 78.81, 77.72, 78.61])
ref_train_accs = np.array([82.05, 67.28, 82.63, 81.36, 82.02])

# np.std defaults to the population standard deviation (ddof=0).
print("train mean/std:", ref_train_accs.mean(), ref_train_accs.std())
print("test mean/std:", ref_test_accs.mean(), ref_test_accs.std())

## Git Re-Basin matching accuracies across seeds

### Table

In [None]:
# Per-seed train accuracies (fractions in [0, 1], 9 seeds per entry) of Git Re-Basin
# merges, keyed by an index pair -- presumably the pair of models matched; TODO confirm.
train_accs = {
    pair: np.array(per_seed)
    for pair, per_seed in (
        (
            (1, 2),
            [
                0.7619400024414062,
                0.7817599773406982,
                0.7826399803161621,
                0.7992200255393982,
                0.7707399725914001,
                0.7551400065422058,
                0.7847200036048889,
                0.7534199953079224,
                0.8099200129508972,
            ],
        ),
        (
            (1, 3),
            [
                0.6714800000190735,
                0.6920400261878967,
                0.6909800171852112,
                0.6874200105667114,
                0.624239981174469,
                0.6920199990272522,
                0.6646000146865845,
                0.7092000246047974,
                0.6822999715805054,
            ],
        ),
        (
            (2, 3),
            [
                0.7508599758148193,
                0.7400599718093872,
                0.7481200098991394,
                0.7238600254058838,
                0.7573999762535095,
                0.741919994354248,
                0.7041199803352356,
                0.7325999736785889,
                0.7776399850845337,
            ],
        ),
    )
}

# Per-seed test accuracies (fractions in [0, 1], 9 seeds per entry), keyed by the same
# index pairs as `train_accs` (presumably the pair of models matched; TODO confirm).
test_accs = {
    pair: np.array(per_seed)
    for pair, per_seed in (
        (
            (1, 2),
            [
                0.727400004863739,
                0.7450000047683716,
                0.7450000047683716,
                0.765500009059906,
                0.7368999719619751,
                0.7258999943733215,
                0.7419000267982483,
                0.7160000205039978,
                0.7760000228881836,
            ],
        ),
        (
            (1, 3),
            [
                0.6434999704360962,
                0.6574000120162964,
                0.6657999753952026,
                0.6507999897003174,
                0.6031000018119812,
                0.6621999740600586,
                0.6305000185966492,
                0.6744999885559082,
                0.6492999792098999,
            ],
        ),
        (
            (2, 3),
            [
                0.7014999985694885,
                0.7060999870300293,
                0.711899995803833,
                0.6819000244140625,
                0.7226999998092651,
                0.7037000060081482,
                0.6658999919891357,
                0.6990000009536743,
                0.7368999719619751,
            ],
        ),
    )
}

In [None]:
# Key into `train_accs` / `test_accs` selecting which pair the LaTeX rows are built for.
current_tuple = 2, 3

In [None]:
def format_latex_row(accs, row_end="\\\\"):
    """Format per-seed accuracies plus summary statistics as one LaTeX table row.

    Args:
        accs: Array supporting iteration and ``.mean()/.std()/.max()/.min()``
            (here the per-seed arrays from ``train_accs`` / ``test_accs``).
        row_end: Text appended after the last cell; defaults to the LaTeX row
            break (two backslashes).

    Returns:
        ``"a_1 & ... & a_n & mean & std & max-min" + row_end``. Per-seed values and
        the mean use 2 decimals; the std (numpy default, ddof=0) and the max-min
        range use 3 decimals.
    """
    cells = [f"{acc:.2f}" for acc in accs]
    cells.append(f"{accs.mean():.2f}")
    cells.append(f"{accs.std():.3f}")
    cells.append(f"{accs.max() - accs.min():.3f}")
    return " & ".join(cells) + row_end


# Train row first; its row break is followed by " \n" so the test row prints on the next line.
latex_row_str = format_latex_row(train_accs[current_tuple], row_end="\\\\ \n") + format_latex_row(
    test_accs[current_tuple]
)

In [None]:
# print() instead of a bare expression so backslashes and the newline render literally,
# ready to paste into the LaTeX table.
print(latex_row_str, end="\n")

### Plot

In [None]:
# Seed-wise Git Re-Basin accuracies for one model pair, derived from `train_accs` /
# `test_accs` instead of being retyped by hand, so the plot cannot drift from the data.
# Rounded to the 2 decimals used in the table (and by the Frank-Wolfe values below).
# NOTE: independent of `current_tuple`, which only drives the LaTeX rows.
PLOT_PAIR = (1, 2)
train_accuracy = np.round(train_accs[PLOT_PAIR], 2)
test_accuracy = np.round(test_accs[PLOT_PAIR], 2)
seeds = list(range(1, len(train_accuracy) + 1))

# Frank-Wolfe results are consistent across seeds: repeat one value per seed.
FW_TRAIN_ACCURACY = 0.78
FW_TEST_ACCURACY = 0.75
fw_train_accuracy = [FW_TRAIN_ACCURACY] * len(seeds)
fw_test_accuracy = [FW_TEST_ACCURACY] * len(seeds)

In [None]:
# Per-seed Git Re-Basin accuracies (solid lines with markers) against the constant
# Frank-Wolfe accuracies (dashed lines); color encodes the split (train vs. test).
fig, ax = plt.subplots()

ax.plot(seeds, train_accuracy, marker="o", linestyle="-", color=palette["light red"], label="Git Re-Basin - train")
ax.plot(seeds, test_accuracy, marker="o", linestyle="-", color=palette["green"], label="Git Re-Basin - test")

ax.plot(seeds, fw_train_accuracy, linestyle="--", color=palette["light red"], label="Frank-Wolfe - train")
ax.plot(seeds, fw_test_accuracy, linestyle="--", color=palette["green"], label="Frank-Wolfe - test")

# Axis labels only: the figure is saved for a paper, where the caption serves as the title.
ax.set_xlabel("Seed")
ax.set_ylabel("Accuracy")
# Seeds are discrete integers: tick every seed instead of matplotlib's automatic ticks.
ax.set_xticks(seeds)
# Legend below the axes in two columns so it never covers the curves.
ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=2)

# savefig does not create missing directories, so make sure the output folder exists.
figures_dir = Path("figures")
figures_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(figures_dir / "git-re-basin-variance.pdf", bbox_inches="tight")

plt.show()