In [None]:
import json

import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objects as go
import seaborn as sns
from nn_core.common import PROJECT_ROOT
from plotly.subplots import make_subplots

In [None]:
# Shared five-colour palette (hex codes); plots look colours up by these names.
colors = dict(
    charchoal="#264653",
    persian_green="#2A9D8F",
    saffron="#E9C46A",
    sandy_brown="#F4A261",
    burnt_sienna="#e76f51",
)

## Experiment 2: Same classes, disjoint samples

In [None]:
# Experiment 2 classification results, parsed from JSON into nested dicts.
file_path = PROJECT_ROOT / "paper_results" / "exp_2_classification_results.json"
with open(file_path, mode="r") as results_file:
    data = json.load(results_file)

In [None]:
# Axes of the Experiment 2 comparison; these strings are also used as lookup keys.
datasets = ["cifar100", "tiny_imagenet"]
models = ["from_scratch_cnn", "efficient_net"]
modalities = ["end-to-end", "merged", "jumble"]

# Legend label shown for each modality key.
modality_names = dict(zip(modalities, ["end-to-end", "relative", "naive"]))

# Hard-coded end-to-end accuracies, indexed as end_to_end_results[dataset][model].
end_to_end_results = dict(
    cifar100=dict(efficient_net=0.7043, from_scratch_cnn=0.3933),
    tiny_imagenet=dict(efficient_net=0.6863, from_scratch_cnn=0.2222),
)

In [None]:
# Store the hard-coded end-to-end accuracy next to the loaded modalities so the
# plotting code can read every modality through the same "total_acc" key.
for dataset in datasets:
    for model in models:
        baseline_acc = end_to_end_results[dataset][model]
        data[dataset][model]["end-to-end"] = dict(total_acc=baseline_acc)

In [None]:
def grouped_barplot_accuracy(data):
    """Build the grouped accuracy bar chart for Experiment 2.

    One group of bars is drawn per (dataset, model) pair and one bar per
    modality inside each group. The function reads the notebook-level globals
    ``datasets``, ``models``, ``modalities``, ``modality_names`` and ``colors``.
    Tick positions, the x-axis range and the dataset captions are hard-coded
    for four groups (two datasets times two models).

    Args:
        data: Nested mapping read as
            ``data[dataset][model][modality]["total_acc"]``.

    Returns:
        The configured ``plotly.graph_objects.Figure``; it is neither shown
        nor saved here.
    """
    modality_colors = {
        "end-to-end": colors["persian_green"],
        "merged": colors["saffron"],
        "jumble": colors["burnt_sienna"],
    }

    # Slot of each modality inside a group; scaled by 0.18 to shift the bar sideways.
    positions = {modality: slot for slot, modality in enumerate(modalities)}

    fig = go.Figure()
    legend_seen = set()  # modalities that already own a legend entry
    x_labels = []
    group_index = 0  # integer x position of the current (dataset, model) group

    for dataset in datasets:
        for model in models:
            x_labels.append(f"{model}")

            for modality in modalities:
                value = data[dataset][model][modality]["total_acc"]
                fig.add_trace(
                    go.Bar(
                        x=[group_index],
                        y=[value],
                        name=modality_names[modality],
                        width=0.15,
                        offset=positions[modality] * 0.18,
                        marker_color=modality_colors[modality],
                        # Only the first trace of each modality gets a legend entry.
                        showlegend=modality not in legend_seen,
                        text=f"{value:.2f}",
                        textposition="outside",
                    )
                )
                legend_seen.add(modality)
            group_index += 1

    # Dataset captions below the axis, centred between the two model groups of each dataset.
    for caption, x_center in (("CIFAR100", 0.8), ("TINY IMAGENET", 2.8)):
        fig.add_annotation(
            text=caption,
            xref="x",
            yref="paper",
            x=x_center,
            y=-0.15,
            showarrow=False,
            font=dict(size=28, color="black"),
            bgcolor="rgba(255, 255, 255, 0.7)",
        )

    # Font and gridline styling shared by both axes.
    axis_style = dict(
        title_font=dict(size=28),
        tickfont=dict(size=28),
        layer="below traces",
        showgrid=True,
        gridcolor="rgba(0,0,0,0.1)",  # faint grey gridlines
        gridwidth=1,
        zeroline=False,
    )

    fig.update_layout(
        yaxis_title="Accuracy",
        barmode="group",
        bargap=0.15,
        bargroupgap=0.1,
        legend=dict(x=0.9, y=1, font=dict(size=28)),  # legend placed inside the plot area
        xaxis=dict(
            # One tick per group, shifted right so the label sits under the bars.
            tickvals=[0.27, 1.27, 2.27, 3.27],
            ticktext=x_labels,
            range=[-0.25, 3.9],  # limits the inner padding
            **axis_style,
        ),
        yaxis=dict(range=[0, 1], **axis_style),
        height=700,
        margin=dict(l=0, r=0, t=50, b=100),  # bottom margin leaves room for the dataset captions
        font=dict(size=28),  # global font size
        plot_bgcolor="rgba(255,218,201, 0.2)",  # light warm tint at 20% opacity
        paper_bgcolor="rgb(255,255,255)",  # opaque white
    )

    return fig

In [None]:
# Render the Experiment 2 accuracy chart inline.
fig = grouped_barplot_accuracy(data)
fig.show()

# Export width: 8 inches at 300 dpi.
# NOTE(review): height_pixels is computed but never passed to write_image.
dpi = 300
width_pixels = dpi * 8
height_pixels = dpi * 1

fig.write_image(
    PROJECT_ROOT / "results" / "generated_plots" / "EXP2_accuracy.png",
    width=width_pixels,
    format="png",
    scale=2,
)

## Experiment 3: Totally disjoint

In [None]:
# Experiment 3 classification results, parsed from JSON into nested dicts.
# This rebinds the notebook-level `data`.
file_path = PROJECT_ROOT / "paper_results" / "exp_3_classification_results.json"
with open(file_path, mode="r") as results_file:
    data = json.load(results_file)

In [None]:
# Experiment 3 compares four modalities ("task_embed_abs" is the extra one).
modalities = ["end-to-end", "merged", "jumble", "task_embed_abs"]

# Legend label shown for each modality key.
modality_names = dict(zip(modalities, ["end-to-end", "relative", "naive", "task embedding"]))

# Hard-coded end-to-end accuracies, indexed as end_to_end_results[dataset][model].
end_to_end_results = dict(
    cifar100=dict(efficient_net=0.6977, from_scratch_cnn=0.4001),
    tiny_imagenet=dict(efficient_net=0.6903, from_scratch_cnn=0.2233),
)

In [None]:
# Store the hard-coded end-to-end accuracy under the "1" sub-key, next to the
# loaded modalities, so the plotting code reads every modality the same way.
for dataset in datasets:
    for model in models:
        baseline_acc = end_to_end_results[dataset][model]
        data[dataset][model]["1"]["end-to-end"] = dict(total_acc=baseline_acc)

In [None]:
def grouped_barplot_accuracy(data):
    """Build the grouped accuracy bar chart for Experiment 3.

    NOTE: this rebinds the name of the Experiment 2 function defined earlier in
    the notebook. It differs in the data path (an extra ``"1"`` level), the
    fourth modality colour, the legend position and the tick offsets.

    One group of bars is drawn per (dataset, model) pair and one bar per
    modality inside each group. The function reads the notebook-level globals
    ``datasets``, ``models``, ``modalities``, ``modality_names`` and ``colors``.
    Tick positions, the x-axis range and the dataset captions are hard-coded
    for four groups (two datasets times two models).

    Args:
        data: Nested mapping read as
            ``data[dataset][model]["1"][modality]["total_acc"]``.

    Returns:
        The configured ``plotly.graph_objects.Figure``; it is neither shown
        nor saved here.
    """
    modality_colors = {
        "end-to-end": colors["persian_green"],
        "merged": colors["saffron"],
        "jumble": colors["burnt_sienna"],
        "task_embed_abs": colors["charchoal"],
    }

    # Slot of each modality inside a group; scaled by 0.18 to shift the bar sideways.
    positions = {modality: slot for slot, modality in enumerate(modalities)}

    fig = go.Figure()
    legend_seen = set()  # modalities that already own a legend entry
    x_labels = []
    group_index = 0  # integer x position of the current (dataset, model) group

    for dataset in datasets:
        for model in models:
            x_labels.append(f"{model}")

            for modality in modalities:
                value = data[dataset][model]["1"][modality]["total_acc"]
                fig.add_trace(
                    go.Bar(
                        x=[group_index],
                        y=[value],
                        name=modality_names[modality],
                        width=0.15,
                        offset=positions[modality] * 0.18,
                        marker_color=modality_colors[modality],
                        # Only the first trace of each modality gets a legend entry.
                        showlegend=modality not in legend_seen,
                        text=f"{value:.2f}",
                        textposition="outside",
                    )
                )
                legend_seen.add(modality)
            group_index += 1

    # Dataset captions below the axis, centred between the two model groups of each dataset.
    for caption, x_center in (("CIFAR100", 0.8), ("TINY IMAGENET", 2.8)):
        fig.add_annotation(
            text=caption,
            xref="x",
            yref="paper",
            x=x_center,
            y=-0.15,
            showarrow=False,
            font=dict(size=28, color="black"),
            bgcolor="rgba(255, 255, 255, 0.7)",
        )

    # Font and gridline styling shared by both axes.
    axis_style = dict(
        title_font=dict(size=28),
        tickfont=dict(size=28),
        layer="below traces",
        showgrid=True,
        gridcolor="rgba(0,0,0,0.1)",  # faint grey gridlines
        gridwidth=1,
        zeroline=False,
    )

    fig.update_layout(
        yaxis_title="Accuracy",
        barmode="group",
        bargap=0.15,
        bargroupgap=0.1,
        legend=dict(x=1, y=1, font=dict(size=28)),  # legend anchored at the plot's right edge
        xaxis=dict(
            # One tick per group, shifted right so the label sits under the bars.
            tickvals=[0.3, 1.3, 2.3, 3.3],
            ticktext=x_labels,
            range=[-0.25, 3.9],  # limits the inner padding
            **axis_style,
        ),
        yaxis=dict(range=[0, 1], **axis_style),
        height=700,
        margin=dict(l=0, r=0, t=50, b=100),  # bottom margin leaves room for the dataset captions
        font=dict(size=28),  # global font size
        plot_bgcolor="rgba(255,218,201, 0.2)",  # light warm tint at 20% opacity
        paper_bgcolor="rgb(255,255,255)",  # opaque white
    )

    return fig

In [None]:
# Render the Experiment 3 accuracy chart inline.
fig = grouped_barplot_accuracy(data)
fig.show()

# Export width: 8 inches at 300 dpi.
# NOTE(review): height_pixels is computed but never passed to write_image.
dpi = 300
width_pixels = dpi * 8
height_pixels = dpi * 1

fig.write_image(
    PROJECT_ROOT / "results" / "generated_plots" / "EXP3_accuracy.png",
    width=width_pixels,
    format="png",
    scale=2,
)

## Experiment 1: Separability

In [None]:
# Colour assigned to each method label in the separability plots.
palette = {
    method: colors[color_name]
    for method, color_name in (
        ("end-to-end", "persian_green"),
        ("ours", "saffron"),
        ("naive", "burnt_sienna"),
    )
}

In [None]:
# Experiment 1 separability analysis, parsed from JSON into nested dicts.
# This rebinds the notebook-level `data`.
file_path = PROJECT_ROOT / "paper_results" / "exp_1_separability_analysis.json"
with open(file_path, mode="r") as results_file:
    data = json.load(results_file)

In [None]:
# Display names for the raw method keys found in the JSON.
names_mapping = {
    "mean_separability_ours": "ours",
    "mean_separability_naive": "naive",
    "mean_separability_original": "end-to-end",
}

# Flatten the nested JSON (dataset -> model -> S -> N -> method) into one record
# per leaf value, then build the DataFrame in a single call.
rows = [
    {
        "dataset": dataset,
        "model": model,
        "S_value": S_value,
        "N_value": N_value,
        "method": names_mapping[method],
        "mean_separability": value,
    }
    for dataset, dataset_data in data.items()
    for model, model_data in dataset_data.items()
    for S_value, S_data in model_data.items()
    for N_value, N_data in S_data.items()
    for method, value in N_data.items()
]

df = pd.DataFrame(rows)

print(df)

In [None]:
N = 10
# Keep only the rows for the selected N. The copy makes df_N10 an independent
# frame, so the column assignments below do not write into a slice of df.
df_N10 = df[df["N_value"] == f"N{N}"].copy()

# "<method> - <model>" label identifying each bar series.
df_N10["method_model"] = df_N10["method"] + " - " + df_N10["model"]

# Plotting order of the series: all EfficientNet methods, then all VanillaCNN methods.
ordered_categories = [
    "ours - efficient_net",
    "naive - efficient_net",
    "end-to-end - efficient_net",
    "ours - vanilla_cnn",
    "naive - vanilla_cnn",
    "end-to-end - vanilla_cnn",
]
df_N10["method_model"] = pd.Categorical(df_N10["method_model"], categories=ordered_categories, ordered=True)

# Sort rows by the categorical order defined above.
df_N10 = df_N10.sort_values(by="method_model")

fig = go.Figure()

# Methods that already own a legend entry.
added_to_legend = set()

for method_model in ordered_categories:
    subset = df_N10[df_N10["method_model"] == method_model]
    method = method_model.split(" - ")[0]

    fig.add_trace(
        go.Bar(
            x=subset["S_value"],
            y=subset["mean_separability"],
            name=method,
            marker_color=palette[method],
            text=subset["mean_separability"].round(2),
            textposition="outside",
            legendgroup=method,  # traces of one method toggle together
            showlegend=method not in added_to_legend,  # one legend entry per method
        )
    )
    added_to_legend.add(method)

# Styling shared by every caption drawn below the x axis.
caption_style = dict(
    xref="x",
    yref="paper",
    showarrow=False,
    font=dict(size=24, color="black"),
    bgcolor="rgba(255, 255, 255, 0.7)",
)

# Model captions: one EfficientNet / VanillaCNN pair around each integer x position.
model_captions = [
    ("EfficientNet", -0.2),
    ("VanillaCNN", 0.2),
    ("EfficientNet", 0.8),
    ("VanillaCNN", 1.2),
    ("EfficientNet", 1.8),
    ("VanillaCNN", 2.2),
    ("EfficientNet", 2.8),
    ("VanillaCNN", 3.2),
]
for caption, x_pos in model_captions:
    fig.add_annotation(text=caption, x=x_pos, y=-0.15, **caption_style)

# Second caption row: the "S" label followed by one value per x position.
fig.add_annotation(text="S", x=-0.5, y=-0.3, **caption_style)

# NOTE(review): these labels are hard-coded and not derived from the S_value column;
# confirm that x positions 0..3 really correspond to S = 80, 60, 40, 20.
vals = [80, 60, 40, 20]
for ind, val in enumerate(vals):
    fig.add_annotation(text=str(val), x=ind, y=-0.3, **caption_style)

# Font and gridline styling shared by both axes.
axis_style = dict(
    title_font=dict(size=28),
    tickfont=dict(size=28),
    layer="below traces",
    showgrid=True,
    gridcolor="rgba(0,0,0,0.1)",  # faint grey gridlines
    gridwidth=1,
    zeroline=False,
)

fig.update_layout(
    barmode="group",
    bargap=0.2,
    bargroupgap=0.1,
    legend_title_text="",
    margin=dict(l=0, r=0, t=50, b=100),  # bottom margin leaves room for the captions
    font=dict(size=28),  # global font size
    plot_bgcolor="rgba(255,218,201, 0.2)",  # light warm tint at 20% opacity
    paper_bgcolor="rgb(255,255,255)",  # opaque white
    legend=dict(x=0.9, y=1, font=dict(size=28)),  # legend placed inside the plot area
    xaxis=dict(
        tickvals=[0.27, 1.27, 2.27, 3.27],
        ticktext=["", "", "", ""],  # tick labels blanked; the annotations above act as labels
        range=[-0.5, 3.9],  # limits the inner padding
        **axis_style,
    ),
    yaxis=dict(range=[0, 2.5], **axis_style),
)

fig.show()

In [None]:
# import plotly.express as px

# # Assuming df is your DataFrame
# fig = px.line(
#     df,
#     x='method',
#     y='mean_separability',
#     color='dataset',       # Differentiates lines by dataset
#     line_dash='model',     # Differentiates lines by model using dashes
#     markers=True,          # Adds markers to the line for each data point
#     title='Comparison of Methods across Datasets and Models',
#     labels={'mean_separability': 'Mean Separability'}
# )

# fig.show()

In [None]:
# df_cifar = df[df['dataset'] == 'cifar100'].copy()  # Create a copy to avoid SettingWithCopyWarning

# label_mapping = {}

# for column in ['model', 'S_value', 'N_value', 'method']:
#     # Update categories for 'S_value' and 'N_value' after subsetting
#     df_cifar[column] = df_cifar[column].astype('category')

#     if column in ['S_value', 'N_value']:
#         df_cifar[column] = df_cifar[column].cat.set_categories(df_cifar[column].unique())

#     label_mapping[column] = dict(enumerate(df_cifar[column].cat.categories))
#     df_cifar[column] = df_cifar[column].astype('category')

# fig = go.Figure(data=
#     go.Parcoords(
#         line=dict(color=df_cifar['mean_separability'],
#                   colorscale='Viridis',
#                   showscale=True,
#                   cmin=df_cifar['mean_separability'].min(),
#                   cmax=df_cifar['mean_separability'].max()),
#         dimensions=[
#             dict(label='Model', values=df_cifar['model'].cat.codes, tickvals=list(label_mapping['model'].keys()), ticktext=list(label_mapping['model'].values())),
#             dict(label='S_value', values=df_cifar['S_value'].cat.codes, tickvals=list(label_mapping['S_value'].keys()), ticktext=list(label_mapping['S_value'].values())),
#             dict(label='N_value', values=df_cifar['N_value'].cat.codes, tickvals=list(label_mapping['N_value'].keys()), ticktext=list(label_mapping['N_value'].values())),
#             dict(label='Method', values=df_cifar['method'].cat.codes, tickvals=list(label_mapping['method'].keys()), ticktext=list(label_mapping['method'].values())),
#             dict(label='Mean Separability', values=df_cifar['mean_separability'])
#         ]
#     )
# )

# fig.show()

In [None]:
# Subset data for CIFAR100.
# NOTE(review): rows are not filtered by N_value, so `.values[0]` below takes the
# first matching row for each (S, model, method). The export cell puts N in the
# file name -- confirm whether a filter on f"N{N}" was intended here.
df_cifar = df[df["dataset"] == "cifar100"]

# Unique (S_value, model) pairs; each pair becomes one horizontal row of dots.
s_model_combinations = df_cifar.drop_duplicates(subset=["S_value", "model"])[["S_value", "model"]]

# Colour per method.
palette = {
    "end-to-end": colors["persian_green"],
    "ours": colors["saffron"],
    "naive": colors["burnt_sienna"],
}

fig = go.Figure()

for row_index, (_, row) in enumerate(s_model_combinations.iterrows()):
    subset = df_cifar[(df_cifar["S_value"] == row["S_value"]) & (df_cifar["model"] == row["model"])]

    # Legend entries are emitted only for the first (S_value, model) pair.
    is_first_pair = bool(
        row["S_value"] == s_model_combinations["S_value"].iloc[0]
        and row["model"] == s_model_combinations["model"].iloc[0]
    )

    # One dot per method, placed at its separability value on this row.
    for method in ["ours", "naive", "end-to-end"]:
        method_value = subset[subset["method"] == method]["mean_separability"].values[0]
        fig.add_trace(
            go.Scatter(
                x=[method_value],
                y=[row_index],
                mode="markers",
                name=method,
                marker=dict(color=palette[method], size=15, line=dict(color="black", width=1)),
                legendgroup=method,
                showlegend=is_first_pair,
            )
        )

# Font and gridline styling shared by both axes.
axis_style = dict(
    title_font=dict(size=28),
    tickfont=dict(size=28),
    layer="below traces",
    showgrid=True,
    gridcolor="rgba(0,0,0,0.1)",  # faint grey gridlines
    gridwidth=1,
    zeroline=False,
)

fig.update_layout(
    barmode="group",
    bargap=0.2,
    bargroupgap=0.1,
    legend_title_text="",
    font=dict(size=28),  # global font size
    plot_bgcolor="rgba(255,218,201, 0.2)",  # light warm tint at 20% opacity
    paper_bgcolor="rgb(255,255,255)",  # opaque white
    margin=dict(l=300, r=10, t=70, b=0),  # wide left margin leaves room for the model captions
    xaxis=dict(**axis_style),
    yaxis=dict(
        # One tick per (S_value, model) row, labelled with S_value minus its first character.
        tickvals=list(range(len(s_model_combinations))),
        ticktext=[s_value[1:] for s_value in s_model_combinations["S_value"]],
        **axis_style,
    ),
)

# Captions positioned in paper coordinates, outside the plotting area.
caption_style = dict(
    xref="paper",
    yref="paper",
    showarrow=False,
    font=dict(size=28, color="black"),
    bgcolor="rgba(255, 255, 255, 0.7)",
)
fig.add_annotation(text="S", x=-0.020, y=1.15, **caption_style)
# NOTE(review): the vertical positions below are hand-tuned and assume a fixed row order.
fig.add_annotation(text="VanillaCNN", x=-0.15, y=0.73, **caption_style)
fig.add_annotation(text="EfficientNet", x=-0.15, y=0.22, **caption_style)

fig.show()

In [None]:
# Export width: 8 inches at 300 dpi.
# NOTE(review): height_pixels is computed but never passed to write_image.
dpi = 300
width_pixels = dpi * 8
height_pixels = dpi * 1

# Saves whatever `fig` currently holds; the file name embeds the notebook-level N.
fig.write_image(
    PROJECT_ROOT / "results" / "generated_plots" / f"EXP1_separability_{N}.png",
    width=width_pixels,
    format="png",
    scale=2,
)

## Experiment 2: Separability

In [None]:
# Experiment 2 separability analysis, parsed from JSON into nested dicts.
# This rebinds the notebook-level `data`.
file_path = PROJECT_ROOT / "paper_results" / "exp_2_separability_analysis.json"
with open(file_path, mode="r") as results_file:
    data = json.load(results_file)

In [None]:
# Display names for the raw method keys and model keys found in the JSON.
names_mapping = {
    "mean_separability_ours": "ours",
    "mean_separability_naive": "naive",
    "mean_separability_original": "end-to-end",
}
model_mapping = {"from_scratch_cnn": "vanilla_cnn", "efficient_net": "efficient_net"}

# NOTE(review): this rebinds the notebook-level `models` list used by the accuracy
# cells to a dict view of display names; re-running those cells afterwards will
# see these values instead.
models = model_mapping.values()
methods = names_mapping.values()

# Flatten dataset -> model -> method into one record per value.
rows = [
    {
        "dataset": dataset,
        "model": model_mapping[model],
        "method": names_mapping[method],
        "mean_separability": value,
    }
    for dataset, dataset_data in data.items()
    for model, model_data in dataset_data.items()
    for method, value in model_data.items()
]

df = pd.DataFrame(rows)

print(df)

In [None]:
# 2x2 grid: rows are datasets, columns are models.
fig = make_subplots(rows=2, cols=2)

# Methods that already own a legend entry.
added_to_legend = set()

for row, dataset in enumerate(["cifar100", "tiny_imagenet"], start=1):
    for col, model in enumerate(["vanilla_cnn", "efficient_net"], start=1):
        filtered_df = df[(df["dataset"] == dataset) & (df["model"] == model)]
        for method in ["ours", "naive", "end-to-end"]:
            y = filtered_df[filtered_df["method"] == method]["mean_separability"].values[0]
            fig.add_trace(
                go.Bar(
                    name=method,
                    x=[method],
                    y=[y],
                    marker_color=palette[method],
                    showlegend=method not in added_to_legend,  # one legend entry per method
                    text=y.round(2),
                    textposition="inside",
                ),
                row=row,
                col=col,
            )
            added_to_legend.add(method)

# Hide tick labels on every x and y axis of the grid.
for axis in fig.layout:
    if "xaxis" in axis or "yaxis" in axis:
        fig.layout[axis].showticklabels = False

# Row captions, placed left of the grid.
# NOTE(review): "CIFAR100" uses yref="y" (data coordinates) while "Tiny" uses
# yref="paper"; kept as-is because the positions look hand-tuned -- confirm intent.
fig.add_annotation(
    text="CIFAR100",
    xref="paper",
    yref="y",
    x=-0.07,
    y=0.65,
    showarrow=False,
    font=dict(size=28, color="black"),
)
fig.add_annotation(
    text="Tiny",
    xref="paper",
    yref="paper",
    x=-0.07,
    y=0.15,
    showarrow=False,
    font=dict(size=28, color="black"),
)

# Column captions, placed above the grid.
for caption, x_pos in (("VanillaCNN", 0.17), ("EfficientNet", 0.82)):
    fig.add_annotation(
        text=caption,
        xref="paper",
        yref="paper",
        x=x_pos,
        y=1.2,
        showarrow=False,
        font=dict(size=28, color="black"),
        bgcolor="rgba(255, 255, 255, 0.7)",
    )

fig.update_layout(
    barmode="group",
    bargap=0.5,
    legend_title_text="",
    font=dict(size=28),  # global font size
    plot_bgcolor="rgba(255,218,201, 0.2)",  # light warm tint at 20% opacity
    paper_bgcolor="rgb(255,255,255)",  # opaque white
    # These two entries style only the first subplot's axes ("xaxis" / "yaxis").
    xaxis=dict(
        title_font=dict(size=28),
        tickfont=dict(size=28),
        layer="below traces",
        showgrid=True,
        gridcolor="rgba(0,0,0,0.1)",  # faint grey gridlines
        gridwidth=1,
        zeroline=False,
    ),
    yaxis=dict(
        title_font=dict(size=28),
        tickfont=dict(size=28),
    ),
    margin=dict(l=150, r=0, t=100, b=100),  # left/top margins leave room for the captions
)

fig.show()

In [None]:
# Export width: 8 inches at 300 dpi.
# NOTE(review): height_pixels is computed but never passed to write_image.
dpi = 300
width_pixels = dpi * 8
height_pixels = dpi * 1

fig.write_image(
    PROJECT_ROOT / "results" / "generated_plots" / "EXP2_separability.png",
    width=width_pixels,
    format="png",
    scale=2,
)

## Experiment 3: Separability

In [None]:
# Experiment 3 separability analysis, parsed from JSON into nested dicts.
# This rebinds the notebook-level `data`.
file_path = PROJECT_ROOT / "paper_results" / "exp_3_separability_analysis.json"
with open(file_path, mode="r") as results_file:
    data = json.load(results_file)

In [None]:
# Display names for the raw method keys and model keys found in the JSON.
names_mapping = {
    "mean_separability_ours": "ours",
    "mean_separability_naive": "naive",
    "mean_separability_original": "end-to-end",
}
model_mapping = {"from_scratch_cnn": "vanilla_cnn", "efficient_net": "efficient_net"}

# NOTE(review): this rebinds the notebook-level `models` list used by the accuracy
# cells to a dict view of display names; re-running those cells afterwards will
# see these values instead.
models = model_mapping.values()
methods = names_mapping.values()

# Flatten dataset -> model -> "1" -> method into one record per value.
rows = [
    {
        "dataset": dataset,
        "model": model_mapping[model],
        "method": names_mapping[method],
        "mean_separability": value,
    }
    for dataset, dataset_data in data.items()
    for model, model_data in dataset_data.items()
    for method, value in model_data["1"].items()
]

df = pd.DataFrame(rows)

print(df)

In [None]:
from plotly.subplots import make_subplots

# 2x2 grid of bar charts: one row per dataset, one column per model,
# and one bar per method inside each panel.
panel_datasets = ["cifar100", "tiny_imagenet"]
panel_models = ["vanilla_cnn", "efficient_net"]
panel_methods = ["ours", "naive", "end-to-end"]

fig = make_subplots(rows=2, cols=2)

# Methods that already own a legend entry; each method is listed in the legend only once.
added_to_legend = set()

for row_idx, dataset_name in enumerate(panel_datasets, start=1):
    for col_idx, model_name in enumerate(panel_models, start=1):
        # Rows of the flattened table that belong to this panel.
        panel_mask = (df["dataset"] == dataset_name) & (df["model"] == model_name)
        panel_df = df[panel_mask]

        for method_name in panel_methods:
            first_occurrence = method_name not in added_to_legend
            added_to_legend.add(method_name)

            # First matching value; raises IndexError if the combination is absent.
            separability = panel_df[panel_df["method"] == method_name]["mean_separability"].values[0]

            method_bar = go.Bar(
                name=method_name,
                x=[method_name],
                y=[separability],
                marker_color=palette[method_name],
                showlegend=first_occurrence,
                text=separability.round(2),  # value label drawn on the bar
                textposition="inside",
            )
            fig.add_trace(method_bar, row=row_idx, col=col_idx)

# Hide tick labels on every x-axis and y-axis entry that iterating the layout yields.
for layout_key in fig.layout:
    if "xaxis" in layout_key or "yaxis" in layout_key:
        fig.layout[layout_key].showticklabels = False

# for annotation in fig['layout']['annotations']:
# annotation['font'] = dict(size=28)  # You can adjust '20' to the desired font size

# Hand-positioned captions for the 2x2 grid: dataset names on the left, model names on top.
caption_font = dict(size=28, color="black")

# Row captions, given as (text, yref, y).
# NOTE(review): "CIFAR100" is anchored with yref="y" (coordinates of the `y` axis) while
# "Tiny" uses paper coordinates -- confirm this asymmetry is intended.
row_captions = [
    ("CIFAR100", "y", 0.65),
    ("Tiny", "paper", 0.15),
]
for caption_text, y_reference, y_position in row_captions:
    fig.add_annotation(
        text=caption_text,
        xref="paper",
        yref=y_reference,
        x=-0.07,
        y=y_position,
        showarrow=False,
        font=caption_font,
    )

# Column captions, given as (text, x); drawn on a translucent white box.
column_captions = [
    ("VanillaCNN", 0.17),
    ("EfficientNet", 0.82),
]
for caption_text, x_position in column_captions:
    fig.add_annotation(
        text=caption_text,
        xref="paper",
        yref="paper",
        x=x_position,
        y=1.2,
        showarrow=False,
        font=caption_font,
        bgcolor="rgba(255, 255, 255, 0.7)",
    )


# Global styling for the separability figure, then render it.
fig.update_layout(
    barmode="group",
    bargap=0.5,
    legend_title_text="",
    font=dict(size=28),  # global font size for the plot
    # Faint peach tint at 20% opacity; every RGB channel must lie within 0-255.
    plot_bgcolor="rgba(255,218,201, 0.2)",
    # Opaque white page background (three components, hence `rgb`, not `rgba`).
    paper_bgcolor="rgb(255,255,255)",
    # NOTE(review): `xaxis=` / `yaxis=` address `layout.xaxis` / `layout.yaxis` only; if the
    # other subplot panels should share these settings they need to be updated as well.
    xaxis=dict(
        title_font=dict(size=28),  # axis title font size (if a title is set)
        tickfont=dict(size=28),  # tick label font size
        layer="below traces",
        showgrid=True,  # keep the x-axis gridlines on
        gridcolor="rgba(0,0,0,0.1)",  # very light gray gridlines
        gridwidth=1,
        zeroline=False,
    ),
    yaxis=dict(
        title_font=dict(size=28),  # y-axis title font size
        tickfont=dict(size=28),  # y-axis tick label font size
    ),
    margin=dict(l=150, r=0, t=100, b=100),  # leave room for the dataset annotations
)

# Render the figure
fig.show()

In [None]:
# Export the current figure as a PNG whose width corresponds to 8 inches at 300 dpi.
dpi = 300
width_pixels = 8 * dpi  # 8 inches

output_path = PROJECT_ROOT / "results" / "generated_plots" / "EXP3_separability.png"
# The target folder may not exist yet (e.g. on a fresh checkout); create it up front
# so the export does not fail with a missing-directory error.
output_path.parent.mkdir(parents=True, exist_ok=True)

# No explicit height is passed; `scale=2` renders at twice the nominal resolution.
fig.write_image(output_path, format="png", width=width_pixels, scale=2)

## Class subsets

In [None]:
# Load the class-subset experiment results from the `paper_results` folder.
file_path = PROJECT_ROOT / "paper_results" / "exp_class_subsets.json"
# JSON is UTF-8 by specification; do not rely on the platform's locale encoding.
with open(file_path, encoding="utf-8") as json_file:
    data = json.load(json_file)

In [None]:
# Build the per-subset accuracy table from the JSON in `data`, instead of reusing
# whatever `df` an earlier section left in the kernel.
# NOTE(review): assumes `data` is either records/column-oriented with the two accuracy
# fields as columns, or a mapping {experiment_id: {metric: value}} -- confirm the schema.
df = pd.DataFrame(data)
if "task_specific_test_acc" not in df.columns and isinstance(data, dict):
    # Experiment IDs are the outer keys: make them the index, metrics the columns.
    df = pd.DataFrame.from_dict(data, orient="index")

# Short tick labels in place of the long experiment IDs: "Class subset 1", "Class subset 2", ...
df["Readable_ID"] = [f"Class subset {i + 1}" for i in range(len(df))]

# Grouped bar plot: one group per class subset, one bar per accuracy metric.
fig = go.Figure()

# (column, legend label, bar colour) for each metric, in drawing order.
accuracy_series = [
    ("task_specific_test_acc", "Accuracy over space A", colors["burnt_sienna"]),
    ("restricted_test_acc", "Accuracy over space B", colors["persian_green"]),
]
for column, legend_label, bar_color in accuracy_series:
    fig.add_trace(
        go.Bar(
            x=df["Readable_ID"],
            y=df[column],
            name=legend_label,
            marker_color=bar_color,
            text=df[column].round(2),  # value label, rounded to two decimals
            textposition="outside",  # drawn above the bar
        )
    )

# Global styling for the class-subset bar chart, then render it.
fig.update_layout(
    barmode="group",
    font=dict(size=28),  # global font size for the plot
    yaxis_title="Accuracy",
    xaxis=dict(
        title_font=dict(size=28),  # axis title font size (if a title is set)
        tickfont=dict(size=28),  # tick label font size
        layer="below traces",
        showgrid=True,  # keep the x-axis gridlines on
    ),
    # Faint peach tint at 20% opacity; every RGB channel must lie within 0-255.
    plot_bgcolor="rgba(255,218,201, 0.2)",
    # Opaque white page background (three components, hence `rgb`, not `rgba`).
    paper_bgcolor="rgb(255,255,255)",
    yaxis=dict(
        title_font=dict(size=28),  # y-axis title font size
        tickfont=dict(size=28),  # y-axis tick label font size
        range=[0, 1],  # fixed y-axis range from 0 to 1
    ),
)

fig.show()

In [None]:
# Export the current figure as a PNG whose width corresponds to 8 inches at 300 dpi.
dpi = 300
width_pixels = 8 * dpi  # 8 inches

output_path = PROJECT_ROOT / "results" / "generated_plots" / "EXP_class_subsets.png"
# The target folder may not exist yet (e.g. on a fresh checkout); create it up front
# so the export does not fail with a missing-directory error.
output_path.parent.mkdir(parents=True, exist_ok=True)

# No explicit height is passed; `scale=2` renders at twice the nominal resolution.
fig.write_image(output_path, format="png", width=width_pixels, scale=2)

## Matrioska

In [None]:
# Load the matrioska experiment results from the `paper_results` folder.
file_path = PROJECT_ROOT / "paper_results" / "matrioska.json"
# JSON is UTF-8 by specification; do not rely on the platform's locale encoding.
with open(file_path, encoding="utf-8") as json_file:
    data = json.load(json_file)

In [None]:
# Tabulate the loaded JSON; the plotting cell below reads its
# `num_train_classes` and `score` columns.
df = pd.DataFrame(data=data)

In [None]:
# Display the table for a quick visual check.
df

In [None]:
# Marker-only scatter of `score` against `num_train_classes`.
fig = go.Figure()

score_points = go.Scatter(
    x=df["num_train_classes"],
    y=df["score"],
    mode="markers",
    name="Test Accuracy",
    # Green markers with a thin black outline.
    marker=dict(
        color=colors["persian_green"],
        size=15,
        line=dict(color="black", width=1),
    ),
)
fig.add_trace(score_points)

# Axis titles and global styling for the scatter plot, applied in a single call,
# then render the figure.
fig.update_layout(
    font=dict(size=28),  # global font size for the plot
    xaxis_title="Number of Train Classes",
    yaxis_title="Accuracy",
    xaxis=dict(
        title_font=dict(size=28),  # x-axis title font size
        tickfont=dict(size=28),  # tick label font size
        layer="below traces",
        showgrid=True,  # keep the x-axis gridlines on
    ),
    # Faint peach tint at 20% opacity; every RGB channel must lie within 0-255.
    plot_bgcolor="rgba(255,218,201, 0.2)",
    # Opaque white page background (three components, hence `rgb`, not `rgba`).
    paper_bgcolor="rgb(255,255,255)",
    yaxis=dict(
        title_font=dict(size=28),  # y-axis title font size
        tickfont=dict(size=28),  # y-axis tick label font size
    ),
)

fig.show()

In [None]:
# Export the current figure as a PNG whose width corresponds to 8 inches at 300 dpi.
dpi = 300
width_pixels = 8 * dpi  # 8 inches

output_path = PROJECT_ROOT / "results" / "generated_plots" / "matrioska.png"
# The target folder may not exist yet (e.g. on a fresh checkout); create it up front
# so the export does not fail with a missing-directory error.
output_path.parent.mkdir(parents=True, exist_ok=True)

# No explicit height is passed; `scale=2` renders at twice the nominal resolution.
fig.write_image(output_path, format="png", width=width_pixels, scale=2)