In [None]:
# Shared colour palette (name -> hex). Insertion order is significant: the bar
# chart below assigns list(color_dict.values())[i] to the i-th method.
color_dict = dict(
    [
        ("Eggshell", "#f4f1de"),
        ("Burnt sienna", "#e07a5f"),
        ("Delft Blue", "#3d405b"),
        ("Cambridge blue", "#81b29a"),
        ("Sunset", "#f2cc8f"),
    ]
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Use LaTeX fonts for a professional look.
# NOTE: "text.usetex" needs a working LaTeX installation wherever this cell runs,
# and these rcParams stay in effect for every later figure in this kernel.
plt.rcParams.update(
    {
        "text.usetex": True,  # Use LaTeX for text rendering
        "font.family": "serif",  # Use serif fonts
        "font.size": 18,  # Set default font size
    }
)

# Set seaborn style for better aesthetics
sns.set_style("whitegrid")

# Data: one row per method (same order as `methods`), one column per setting in `tasks`.
methods = ["Zeroshot", "Weight Averaging", "Task Vector", "Consensus TA", r"\textbf{TSV-M (Ours)}"]
tasks = ["8 tasks", "14 tasks", "20 tasks"]

accuracy_data = np.array(
    [
        [64.70, 68.20, 65.23],  # Zeroshot
        [79.56, 76.73, 71.60],  # Weight Averaging
        [84.93, 79.41, 74.01],  # Task Vector
        [86.34, 82.22, 79.00],  # Consensus TA
        [92.98, 89.17, 87.72],  # TSV-M (Ours)
    ]
)

n_tasks = len(tasks)
n_methods = len(methods)
# Fail fast if a method or task is added without a matching row/column of numbers.
assert accuracy_data.shape == (n_methods, n_tasks), "accuracy_data must have shape (n_methods, n_tasks)"

x = np.arange(n_tasks)  # the label locations
total_width = 0.8  # total width for all bars at one x location
width = total_width / n_methods  # the width of each bar
# Horizontal offset of each method's bar from its group's x location, so every
# group of n_methods bars is centred on its tick.
offsets = (np.arange(n_methods) - n_methods / 2 + 0.5) * width
label_pad = 0.7  # gap (in accuracy points) between the top of a bar and its value label

# One colour per method, taken in insertion order from color_dict (defined in the first cell).
colors = list(color_dict.values())
assert len(colors) >= n_methods, "color_dict must provide at least one colour per method"

fig, ax = plt.subplots(figsize=(14, 7))

# Plot one bar series per method and write its exact value above each bar.
# Text artists sit above patches by default zorder, so labels stay on top of later bars.
for i, (method, offset) in enumerate(zip(methods, offsets)):
    ax.bar(
        x + offset,
        accuracy_data[i],
        width=width,
        label=method,
        color=colors[i],
        edgecolor="black",
    )
    for j, value in enumerate(accuracy_data[i]):
        ax.text(
            x[j] + offset,
            value + label_pad,
            f"{value:.2f}",
            ha="center",
            va="bottom",
            fontsize=18,
        )

# Set labels and ticks. Raw string: "\%" must reach LaTeX intact (a bare "%" starts a
# LaTeX comment), and "\%" in a normal string literal is an invalid Python escape.
ax.set_ylabel(r"Accuracy (\%)", fontsize=24)
ax.set_xticks(x)
ax.set_xticklabels(tasks, fontsize=18)
# Fixed y-range for the paper figure; bars below 60 would be cut off, so widen it if the data changes.
ax.set_ylim(60, 100)

# Place the legend outside the plot, below the axes, on a single row.
ax.legend(loc="center", bbox_to_anchor=(0.5, -0.15), fontsize=18, ncol=5)

# Add gridlines for better readability
ax.yaxis.grid(True, linestyle="--", which="major", color="grey", alpha=0.7)

# Adjust layout to accommodate the legend
fig.tight_layout()
plt.show()

In [None]:
# Export the accuracy bar chart (`fig` from the previous cell) to PDF,
# with the bounding box tightened around the drawn content.
bar_chart_pdf = "main_results.pdf"
fig.savefig(bar_chart_pdf, bbox_inches="tight")

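## Procrustes error

Distribution of the per-layer approximation errors for $U$ and $V$ (`error_u`, `error_v`), comparing the full-rank (`errors_lossless.json`) and low-rank (`errors_TSV.json`) results.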

In [None]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Replace these with your actual JSON file paths
file_names = ["errors_lossless.json", "errors_TSV.json"]
map_file_name = {"errors_lossless.json": "full rank", "errors_TSV.json": "low rank"}

# Flat, row-aligned lists: entry k of each list describes the same layer.
errors_u = []
errors_v = []
file_labels = []

# NOTE(review): each JSON file is assumed to map layer names to objects carrying
# numeric "error_u" and "error_v" fields — inferred from the keys read below.
for file_name in file_names:
    label = map_file_name[file_name]  # KeyError here means the file has no display label
    with open(file_name, "r") as f:
        data = json.load(f)
    for layer in data.values():
        errors_u.append(layer["error_u"])
        errors_v.append(layer["error_v"])
        file_labels.append(label)

# Create a DataFrame for seaborn (one row per layer per file)
df = pd.DataFrame({"Error U": errors_u, "Error V": errors_v, "File": file_labels})

# Define custom colors
color_dict = {
    "Eggshell": "#f4f1de",
    "Burnt sienna": "#e07a5f",
    "Delft Blue": "#3d405b",
    "Cambridge blue": "#81b29a",
    "Sunset": "#f2cc8f",
}

# Specify the order of the categories
file_order = ["full rank", "low rank"]

# Colours for each plot, matched to file_order by position
palette_u = [color_dict["Burnt sienna"], color_dict["Cambridge blue"]]
palette_v = [color_dict["Delft Blue"], color_dict["Sunset"]]

# Set the seaborn context *before* creating the axes so its font sizes and line widths
# apply to everything drawn below. It is global and persists for later cells.
sns.set_context("talk", font_scale=1.5)  # Adjust the font scale as needed
fig, (ax_u, ax_v) = plt.subplots(1, 2, figsize=(14, 7))

# One violin panel per factor (U left, V right).
for panel_ax, column, palette, symbol in (
    (ax_u, "Error U", palette_u, "U"),
    (ax_v, "Error V", palette_v, "V"),
):
    # NOTE(review): newer seaborn (>= 0.13) may warn about `palette` without `hue`;
    # confirm against the installed version before switching to hue="File".
    sns.violinplot(
        x="File",
        y=column,
        data=df,
        inner="quartile",
        order=file_order,
        palette=palette,
        ax=panel_ax,
    )
    panel_ax.set_title(f"Approximation error for ${symbol}$", fontsize=24)
    panel_ax.set_xlabel("")

fig.tight_layout()
plt.show()

In [None]:
# Export the two-panel violin figure (`fig` from the previous cell) to PDF,
# with the bounding box tightened around the drawn content.
violin_pdf = "approximation_errors.pdf"
fig.savefig(violin_pdf, bbox_inches="tight")