In [None]:
from nn_core.common import PROJECT_ROOT
import json
import plotly.graph_objects as go

In [None]:
# Palette (hex RGB strings) shared by every plot in this notebook.
# NOTE: the first key is spelled "charchoal"; it is kept verbatim in case other code looks it up by that name.
colors = dict(
    charchoal="#264653",
    persian_green="#2A9D8F",
    saffron="#E9C46A",
    sandy_brown="#F4A261",
    burnt_sienna="#e76f51",
)

## Experiment 2: Same classes, disjoint samples

In [None]:
file_path = PROJECT_ROOT / "paper_results" / "exp_2_classification_results.json"
# JSON is UTF-8 by specification; pass the encoding explicitly so loading does not
# depend on the platform's locale encoding (e.g. cp1252 on Windows).
with open(file_path, encoding="utf-8") as json_file:
    data = json.load(json_file)

In [None]:
# Plot ordering: datasets and models define the bar groups, modalities the bars inside each group.
datasets = ["cifar100", "tiny_imagenet"]
models = ["from_scratch_cnn", "efficient_net"]
modalities = ["end-to-end", "merged", "jumble"]

# Legend label shown for each modality (same order as `modalities`).
modality_names = dict(zip(modalities, ["end-to-end", "relative", "naive"]))

# End-to-end baseline accuracies for Experiment 2, keyed by dataset then model.
end_to_end_results = dict(
    cifar100=dict(efficient_net=0.7043, from_scratch_cnn=0.3933),
    tiny_imagenet=dict(efficient_net=0.6863, from_scratch_cnn=0.2222),
)

In [None]:
# Attach the hard-coded end-to-end baseline to each (dataset, model) entry of the loaded results,
# as a "end-to-end" modality with the same {"total_acc": ...} shape as the other modalities.
for dataset in datasets:
    dataset_results = data[dataset]
    for model in models:
        baseline_acc = end_to_end_results[dataset][model]
        dataset_results[model]["end-to-end"] = {"total_acc": baseline_acc}

In [None]:
def grouped_barplot_accuracy(data, run_key=None):
    """Draw a grouped bar chart of accuracy per (dataset, model) pair and modality.

    One group of bars is drawn for every ``(dataset, model)`` combination, in the order of
    the notebook-level ``datasets`` and ``models`` lists. Each group holds one bar per entry
    of ``modalities``, colored per modality, with the value printed above the bar. Legend
    labels come from ``modality_names``; each dataset name (upper-cased, underscores turned
    into spaces) is written below the groups that belong to it.

    Args:
        data: Nested results mapping. Each bar's height is read from
            ``data[dataset][model][modality]["total_acc"]``, or from
            ``data[dataset][model][run_key][modality]["total_acc"]`` when ``run_key`` is given.
        run_key: Optional extra key between the model level and the modality level.
            ``None`` (default) reads the modalities directly under the model.

    Returns:
        The ``plotly.graph_objects.Figure``; it is neither shown nor saved here.
    """
    bar_width = 0.15
    bar_step = 0.18  # distance between the left edges of neighbouring bars in a group
    # Offset from a group's x position to the horizontal centre of its bars.
    group_center = ((len(modalities) - 1) * bar_step + bar_width) / 2

    modality_colors = {
        "end-to-end": colors["persian_green"],
        "merged": colors["saffron"],
        "jumble": colors["burnt_sienna"],
    }
    positions = {category: i for i, category in enumerate(modalities)}

    fig = go.Figure()
    added_legends = set()  # modalities whose legend entry has already been emitted
    x_labels = []
    tick_positions = []
    counter = 0  # x position of the current group

    for dataset in datasets:
        first_group = counter
        for model in models:
            x_labels.append(f"{model}")
            tick_positions.append(counter + group_center)

            model_results = data[dataset][model]
            if run_key is not None:
                model_results = model_results[run_key]

            for modality in modalities:
                value = model_results[modality]["total_acc"]
                # Show each modality in the legend only once, on its first bar.
                show_legend = modality not in added_legends
                added_legends.add(modality)
                fig.add_trace(
                    go.Bar(
                        x=[counter],
                        y=[value],
                        name=modality_names[modality],
                        width=bar_width,
                        # Explicit offsets place the modality bars side by side within the group.
                        offset=positions[modality] * bar_step,
                        marker_color=modality_colors[modality],
                        showlegend=show_legend,
                        text=f"{value:.2f}",  # value printed on top of the bar
                        textposition="outside",
                    )
                )
            counter += 1

        # Centre the dataset label under the groups of this dataset.
        if counter > first_group:
            fig.add_annotation(
                text=dataset.upper().replace("_", " "),
                xref="x",
                yref="paper",
                x=(first_group + counter - 1) / 2 + group_center,
                y=-0.15,
                showarrow=False,
                font=dict(size=28, color="black"),
                bgcolor="rgba(255, 255, 255, 0.7)",
            )

    fig.update_layout(
        yaxis_title="Accuracy",
        barmode="group",
        bargap=0.15,
        bargroupgap=0.1,
        legend=dict(x=0.9, y=1, font=dict(size=28)),  # legend inside the plot area, top right
        xaxis=dict(
            tickvals=tick_positions,  # one tick centred under each group
            ticktext=x_labels,
            # Small left padding; extra room on the right for the in-plot legend.
            range=[-0.25, counter - 1 + 0.9],
            title_font=dict(size=28),
            tickfont=dict(size=28),
            layer="below traces",
            showgrid=True,
            gridcolor="rgba(0,0,0,0.1)",  # very light grey gridlines
            gridwidth=1,
            zeroline=False,
        ),
        yaxis=dict(
            title_font=dict(size=28),
            tickfont=dict(size=28),
            range=[0, 1],
            layer="below traces",
            showgrid=True,
            gridcolor="rgba(0,0,0,0.1)",  # very light grey gridlines
            gridwidth=1,
            zeroline=False,
        ),
        height=700,
        margin=dict(l=0, r=0, t=50, b=100),  # bottom margin leaves room for the dataset labels
        font=dict(size=28),  # global default font size
        # Faint peach tint. The previous string had an out-of-range red channel (277),
        # presumably rendered clamped to 255, so 255 keeps the intended appearance.
        plot_bgcolor="rgba(255,218,201,0.2)",
        # Previously "rgba(255,255,255)", a malformed rgba string with only three components.
        paper_bgcolor="rgb(255,255,255)",
    )

    return fig

In [None]:
fig = grouped_barplot_accuracy(data)
fig.show()

# Export at a layout width of 8 * 300 px. No height is passed, so the figure's own
# layout height is used; `scale=2` then doubles both dimensions of the saved PNG.
dpi = 300
width_pixels = 8 * dpi

output_path = PROJECT_ROOT / "results" / "generated_plots" / "EXP2_accuracy.png"
# write_image does not create missing directories, so make sure the target folder exists.
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.write_image(output_path, format="png", width=width_pixels, scale=2)

## Experiment 3: Totally disjoint

In [None]:
file_path = PROJECT_ROOT / "paper_results" / "exp_3_classification_results.json"
# JSON is UTF-8 by specification; pass the encoding explicitly so loading does not
# depend on the platform's locale encoding (e.g. cp1252 on Windows).
with open(file_path, encoding="utf-8") as json_file:
    data = json.load(json_file)

In [None]:
# End-to-end baseline accuracies for Experiment 3, keyed by dataset then model.
end_to_end_results = dict(
    cifar100=dict(efficient_net=0.6977, from_scratch_cnn=0.4001),
    tiny_imagenet=dict(efficient_net=0.6903, from_scratch_cnn=0.2233),
)

In [None]:
# Attach the hard-coded end-to-end baseline to the "1" entry of each (dataset, model) result,
# as an "end-to-end" modality with the same {"total_acc": ...} shape as the other modalities.
for dataset in datasets:
    for model in models:
        run_results = data[dataset][model]["1"]
        run_results["end-to-end"] = {"total_acc": end_to_end_results[dataset][model]}

In [None]:
def grouped_barplot_accuracy(data, run_key="1"):
    """Draw a grouped bar chart of accuracy per (dataset, model) pair and modality.

    One group of bars is drawn for every ``(dataset, model)`` combination, in the order of
    the notebook-level ``datasets`` and ``models`` lists. Each group holds one bar per entry
    of ``modalities``, colored per modality, with the value printed above the bar. Legend
    labels come from ``modality_names``; each dataset name (upper-cased, underscores turned
    into spaces) is written below the groups that belong to it.

    Args:
        data: Nested results mapping. Each bar's height is read from
            ``data[dataset][model][run_key][modality]["total_acc"]``, or from
            ``data[dataset][model][modality]["total_acc"]`` when ``run_key`` is ``None``.
        run_key: Key between the model level and the modality level. Defaults to ``"1"``,
            the entry this experiment's results are stored under.

    Returns:
        The ``plotly.graph_objects.Figure``; it is neither shown nor saved here.
    """
    bar_width = 0.15
    bar_step = 0.18  # distance between the left edges of neighbouring bars in a group
    # Offset from a group's x position to the horizontal centre of its bars.
    group_center = ((len(modalities) - 1) * bar_step + bar_width) / 2

    modality_colors = {
        "end-to-end": colors["persian_green"],
        "merged": colors["saffron"],
        "jumble": colors["burnt_sienna"],
    }
    positions = {category: i for i, category in enumerate(modalities)}

    fig = go.Figure()
    added_legends = set()  # modalities whose legend entry has already been emitted
    x_labels = []
    tick_positions = []
    counter = 0  # x position of the current group

    for dataset in datasets:
        first_group = counter
        for model in models:
            x_labels.append(f"{model}")
            tick_positions.append(counter + group_center)

            model_results = data[dataset][model]
            if run_key is not None:
                model_results = model_results[run_key]

            for modality in modalities:
                value = model_results[modality]["total_acc"]
                # Show each modality in the legend only once, on its first bar.
                show_legend = modality not in added_legends
                added_legends.add(modality)
                fig.add_trace(
                    go.Bar(
                        x=[counter],
                        y=[value],
                        name=modality_names[modality],
                        width=bar_width,
                        # Explicit offsets place the modality bars side by side within the group.
                        offset=positions[modality] * bar_step,
                        marker_color=modality_colors[modality],
                        showlegend=show_legend,
                        text=f"{value:.2f}",  # value printed on top of the bar
                        textposition="outside",
                    )
                )
            counter += 1

        # Centre the dataset label under the groups of this dataset.
        if counter > first_group:
            fig.add_annotation(
                text=dataset.upper().replace("_", " "),
                xref="x",
                yref="paper",
                x=(first_group + counter - 1) / 2 + group_center,
                y=-0.15,
                showarrow=False,
                font=dict(size=28, color="black"),
                bgcolor="rgba(255, 255, 255, 0.7)",
            )

    fig.update_layout(
        yaxis_title="Accuracy",
        barmode="group",
        bargap=0.15,
        bargroupgap=0.1,
        legend=dict(x=0.9, y=1, font=dict(size=28)),  # legend inside the plot area, top right
        xaxis=dict(
            tickvals=tick_positions,  # one tick centred under each group
            ticktext=x_labels,
            # Small left padding; extra room on the right for the in-plot legend.
            range=[-0.25, counter - 1 + 0.9],
            title_font=dict(size=28),
            tickfont=dict(size=28),
            layer="below traces",
            showgrid=True,
            gridcolor="rgba(0,0,0,0.1)",  # very light grey gridlines
            gridwidth=1,
            zeroline=False,
        ),
        yaxis=dict(
            title_font=dict(size=28),
            tickfont=dict(size=28),
            range=[0, 1],
            layer="below traces",
            showgrid=True,
            gridcolor="rgba(0,0,0,0.1)",  # very light grey gridlines
            gridwidth=1,
            zeroline=False,
        ),
        height=700,
        margin=dict(l=0, r=0, t=50, b=100),  # bottom margin leaves room for the dataset labels
        font=dict(size=28),  # global default font size
        # Faint peach tint. The previous string had an out-of-range red channel (277),
        # presumably rendered clamped to 255, so 255 keeps the intended appearance.
        plot_bgcolor="rgba(255,218,201,0.2)",
        # Previously "rgba(255,255,255)", a malformed rgba string with only three components.
        paper_bgcolor="rgb(255,255,255)",
    )

    return fig

In [None]:
fig = grouped_barplot_accuracy(data)
fig.show()

# Export at a layout width of 8 * 300 px. No height is passed, so the figure's own
# layout height is used; `scale=2` then doubles both dimensions of the saved PNG.
dpi = 300
width_pixels = 8 * dpi

output_path = PROJECT_ROOT / "results" / "generated_plots" / "EXP3_accuracy.png"
# write_image does not create missing directories, so make sure the target folder exists.
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.write_image(output_path, format="png", width=width_pixels, scale=2)