In [None]:
from nn_core.common import PROJECT_ROOT
import json
import numpy as np


def adjust_cmap_alpha(cmap, alpha=1.0):
    """Return a new colormap with the colors of ``cmap`` and a uniform alpha.

    Args:
        cmap: a matplotlib colormap; it is sampled at all of its ``cmap.N`` entries.
        alpha: opacity in [0, 1] written into the alpha channel of every color.

    Returns:
        A new ``matplotlib.colors.ListedColormap``; ``cmap`` itself is not modified.

    Raises:
        ValueError: if ``alpha`` is outside [0, 1] (matplotlib would otherwise only
            fail later, the first time the returned colormap is evaluated).
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha should be between 0 and 1 inclusive.")

    # Imported here so the helper does not depend on `plt` having been imported by
    # another notebook cell.
    from matplotlib.colors import ListedColormap

    # Sample every entry of the colormap (one RGBA row per entry), then overwrite
    # the last (alpha) column.
    colors = cmap(np.arange(cmap.N))
    colors[:, -1] = alpha

    return ListedColormap(colors)


def rgba_to_rgb(rgba, background=(1, 1, 1)):
    """Convert an RGBA color to an RGB color, blending over a specified background color.

    Args:
        rgba: sequence ``(r, g, b, a)``; components are expected in [0, 1].
        background: RGB color the translucent color is composited over (white by default).

    Returns:
        List ``[r, g, b]`` of the alpha-blended color.
    """
    alpha = rgba[3]
    return [rgba[i] * alpha + background[i] * (1 - alpha) for i in range(3)]

## Experiment 1: part-shared, part-novel classes

### CKA

In [None]:
file_path = PROJECT_ROOT / "paper_results" / "exp_1_cka_analysis.json"
# JSON is UTF-8; be explicit so loading does not depend on the platform's locale encoding.
with open(file_path, encoding="utf-8") as json_file:
    data = json.load(json_file)

In [None]:
# Inspect the raw CKA results, indexed as data[dataset][model]["S{s}"]["N{n}"][metric].
data

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt


def decimal_to_rgb_color(decimal_value, cmap="viridis"):
    """
    Convert a decimal value (between 0 and 1) to the corresponding RGB color in the given colormap.

    Args:
        decimal_value: number in [0, 1] to look up in the colormap.
        cmap: colormap name (resolved with ``plt.get_cmap``) or a colormap object;
            any callable is used directly, exactly as ``plt.get_cmap`` would return it.

    Returns:
        Tuple ``(r, g, b)`` of plain Python floats rounded to 2 decimals. The alpha
        channel produced by the colormap is discarded.

    Raises:
        ValueError: if ``decimal_value`` is outside [0, 1].
    """
    if not (0 <= decimal_value <= 1):
        raise ValueError("decimal_value should be between 0 and 1 inclusive.")

    # Colormap objects are callable; only names need resolving through pyplot.
    colormap = cmap if callable(cmap) else plt.get_cmap(cmap)
    rgb = colormap(decimal_value)[:3]

    # Convert to float before rounding: rounding a numpy scalar keeps the numpy type,
    # whose repr under NumPy 2 is "np.float64(...)" and would leak into LaTeX built
    # from str() of the returned tuple.
    return tuple(round(float(c), 2) for c in rgb)


def compute_num_tasks(C, S, N):
    """Number of tasks obtained by splitting the C - S non-shared classes into groups of N.

    Args:
        C: total number of classes in the dataset.
        S: number of shared classes.
        N: number of novel classes per task.

    Returns:
        ``floor((C - S) / N)``; any leftover classes that do not fill a task are dropped.
    """
    novel_classes = C - S
    num_tasks, _remainder = divmod(novel_classes, N)
    return num_tasks

In [None]:
# Number of classes in each dataset; key order also fixes the dataset order of the tables.
num_classes = dict(cifar100=100, tiny_imagenet=200)
# Architectures, in the order their columns appear in the tables.
models = ["vanilla_cnn", "efficient_net"]
# Sweeps over S (number of shared classes) and N (novel classes per task), per dataset.
S = dict(cifar100=[80, 60, 40, 20], tiny_imagenet=[100, 50])
N = dict(cifar100=[10, 5], tiny_imagenet=[25])
datasets = list(num_classes)
# Display names used in the LaTeX output.
dataset_names = dict(cifar100="CIFAR100", tiny_imagenet="TINY")
model_names = dict(vanilla_cnn="VanillaCNN", efficient_net="EfficientNet")

In [None]:
# Diverging palette used to shade CKA cells (only the final value is kept).
cmap = sns.color_palette("vlag", as_cmap=True)

# Set to True to also shade the non-shared / shared columns; by default only "total" is shaded.
SHADE_ALL_CKA_CELLS = False

num_cols = 10  # must match the 10-column spec of the tabular below
# Rows spanned by each rotated dataset label: one model-name row plus one row per (N, S) pair.
num_rows = {dataset: 1 + len(N[dataset]) * len(S[dataset]) for dataset in datasets}


def format_cka_cell(value, colormap=None):
    """Render one CKA value as a LaTeX table cell.

    Args:
        value: CKA score in [0, 1], or None when the metric is missing from the results.
        colormap: when given, the cell is shaded with the color of ``value`` in it.

    Returns:
        ``"---"`` for a missing value, otherwise the value with two decimals, prefixed by
        a ``\\cellcolor[rgb]{r, g, b}`` command when ``colormap`` is given.
    """
    if value is None:
        return "---"
    if colormap is None:
        return f"{value:.2f}"
    # Join the components explicitly instead of stripping the parentheses of str(tuple).
    rgb = ", ".join(str(c) for c in decimal_to_rgb_color(value, colormap))
    return rf"\cellcolor[rgb]{{{rgb}}}{value:.2f}"


# Four leading header cells (Dataset, S, N, # tasks), then "CKA Measure" over columns 5-10.
header = r"""
\begin{table}
\centering
\begin{tabular}{cccccccccc}
    \toprule
    & & & & \multicolumn{6}{c}{CKA Measure} \\
    \cmidrule(lr){5-10}
    Dataset & S & N & \# tasks & non-shared & shared & total & non-shared & shared & total \\
    \midrule
"""

partial_cmap = cmap if SHADE_ALL_CKA_CELLS else None

rows = []
for dataset in datasets:
    dataset_str = rf"\texttt{{{dataset_names[dataset]}}}"
    rows.append(
        rf"        \parbox[t]{{2mm}}{{\multirow{{{num_rows[dataset]}}}{{*}}{{ \rotatebox[origin=c]{{90}}{{{dataset_str}}} }}}}"
    )

    models_str = [rf"\texttt{{{model_names[model]}}}" for model in models]
    rows.append(
        rf"& & & & \multicolumn{{3}}{{c}}{{{models_str[0]}}} & \multicolumn{{3}}{{c}}{{{models_str[1]}}}    \\    \cmidrule(lr){{5-7}} \cmidrule(lr){{8-10}}"
    )
    for n in N[dataset]:
        for s in S[dataset]:
            cells = [f"${s}$", f"${n}$", str(compute_num_tasks(num_classes[dataset], s, n))]

            for model in models:
                run_results = data[dataset][model][f"S{s}"][f"N{n}"]
                cells.append(format_cka_cell(run_results.get("cka_non_shared"), partial_cmap))
                cells.append(format_cka_cell(run_results.get("cka_shared"), partial_cmap))
                cells.append(format_cka_cell(run_results.get("cka_orig_aggr_rel_rel"), cmap))

            rows.append("& " + " & ".join(cells) + r"\\")
        rows.append(f"        \\cmidrule(lr){{2-{num_cols}}}")

# remove the \cmidrule emitted after the very last block
rows.pop()

footer = r"""
    \bottomrule
\end{tabular}
\caption{(\texttt{Experiment 1}). CKA values for different configurations spanning model, dataset, number of shared classes and number of novel classes per task.}\label{tab:cka-part-shared-part-novel}
\end{table}
"""

full_table = header + "\n".join(rows) + footer

print(full_table)

### Classification analysis

In [None]:
file_path = PROJECT_ROOT / "paper_results" / "exp_1_classification_results.json"
# JSON is UTF-8; be explicit so loading does not depend on the platform's locale encoding.
with open(file_path, encoding="utf-8") as json_file:
    data = json.load(json_file)

In [None]:
# Total accuracy of the end-to-end trained baselines, keyed by dataset and then model.
end_to_end_results = dict(
    cifar100=dict(efficient_net=0.7043, vanilla_cnn=0.3933),
    tiny_imagenet=dict(efficient_net=0.6863, vanilla_cnn=0.2222),
)

In [None]:
# Quick look at the datasets and the CIFAR100 S / N sweep, one per line.
print(datasets, S["cifar100"], N["cifar100"], sep="\n")

In [None]:
def _best_merged_total_acc(dataset, model):
    """Highest merged-space total accuracy over the (S, N) sweep for one dataset/model."""
    return max(
        data[dataset][model][f"S{s}"][f"N{n}"]["merged"]["total_acc"]
        for s in S[dataset]
        for n in N[dataset]
    )


# Gap between the best merged-space accuracy and the end-to-end baseline, per dataset/model.
# NOTE(review): this can be negative when merging never beats the baseline.
max_alphas = {
    dataset: {
        model: _best_merged_total_acc(dataset, model) - end_to_end_results[dataset][model]
        for model in models
    }
    for dataset in datasets
}

In [None]:
num_cols = 14  # must match the 14-column spec of the tabular below

header = r"""
\begin{table}
        \resizebox{\textwidth}{!}{%

        \centering
        \begin{tabular}{cccccccccccccc} %
                % HEADER

                \toprule
                Dataset & $S$                                       & $N$ & tasks & vanilla & non-shared     & shared         & total          & improv & vanilla & non-shared     & shared         & total          & improv \\
                \midrule
                % DATASET
"""

# Base palette for the "improv" column. Kept under its own name so this cell does not
# overwrite the global `cmap` that other cells use.
improv_base_cmap = sns.light_palette("seagreen", as_cmap=True)


def format_acc(value):
    """Format an accuracy with two decimals, or ``---`` when it is missing (None)."""
    return "---" if value is None else f"{value:.2f}"


# Rows spanned by each rotated dataset label: one model-name row plus one row per (N, S) pair.
row_per_dataset = {dataset: 1 + len(N[dataset]) * len(S[dataset]) for dataset in datasets}
rows = []
for dataset in datasets:
    dataset_str = rf"\texttt{{{dataset_names[dataset]}}}"
    rows.append(
        rf"        \parbox[t]{{2mm}}{{\multirow{{{row_per_dataset[dataset]}}}{{*}}{{ \rotatebox[origin=c]{{90}}{{{dataset_str}}} }}}}"
    )

    models_str = [rf"\texttt{{{model_names[model]}}}" for model in models]
    baselines = [end_to_end_results[dataset][model] for model in models]
    rows.append(
        rf"& & & & \multicolumn{{4}}{{c}}{{{models_str[0]}}} & {baselines[0]:.2f} & \multicolumn{{4}}{{c}}{{{models_str[1]}}} & {baselines[1]:.2f}   \\    \cmidrule(lr){{5-9}} \cmidrule(lr){{10-14}}"
    )

    # The improvement colormap depends only on (dataset, model), so build it once per pair.
    # The alpha is clamped to [0, 1] because the max gain can be negative.
    # NOTE(review): decimal_to_rgb_color keeps only the RGB channels, so this alpha does not
    # change the emitted \cellcolor; blend with rgba_to_rgb if transparency is intended.
    improv_cmaps = {
        model: adjust_cmap_alpha(improv_base_cmap, alpha=min(max(max_alphas[dataset][model], 0.0), 1.0))
        for model in models
    }

    for n in N[dataset]:
        for s in S[dataset]:
            cells = [f"${s}$", f"${n}$", str(compute_num_tasks(num_classes[dataset], s, n))]

            for model in models:
                run_results = data[dataset][model][f"S{s}"][f"N{n}"]
                merged = run_results["merged"]
                total = merged.get("total_acc")
                cells += [
                    format_acc(run_results["jumble_abs"].get("total_acc")),
                    format_acc(merged.get("non_shared_class_acc")),
                    format_acc(merged.get("shared_class_acc")),
                    format_acc(total),
                ]

                if total is None:
                    cells.append("---")
                    continue
                # Signed change w.r.t. the end-to-end baseline: a drop is printed as negative
                # and left unshaded instead of being reported as a "+" gain.
                improvement = total - end_to_end_results[dataset][model]
                improv_col = decimal_to_rgb_color(max(improvement, 0.0), improv_cmaps[model])
                rgb = ", ".join(str(c) for c in improv_col)
                cells.append(rf"\cellcolor[rgb]{{{rgb}}}{improvement:+.2f}")

            rows.append("& " + " & ".join(cells) + r"\\")
        rows.append(f"        \\cmidrule(lr){{2-{num_cols}}}")

# remove the \cmidrule emitted after the very last block
rows.pop()

# No "\\" after \bottomrule: it would add an empty row below the rule.
# NOTE(review): the caption describes a per-line layout (absolute / relative / merged rows)
# that does not match the per-column layout built above — confirm the wording.
footer = r"""
                \bottomrule
        \end{tabular}
        }
        \caption{(\texttt{Experiment 1}). Accuracy obtained by a simple classifier trained on the original absolute space (first line of each block), trained on the original relative space (second line of each block) and trained on the merged spaces (following lines).}\label{tab:part-shared-part-novel}
\end{table}
"""

full_table = header + "\n".join(rows) + footer

print(full_table)

## Experiment 2: CKA 

In [None]:
# Experiment 2 CKA values (2x2 grid, rendered row by row in the next cells).
values = [[0.9085, 0.9158], [0.8410, 0.8772]]

# Build the palette here so this cell does not depend on whichever `cmap` an earlier cell
# last assigned.
# NOTE(review): vlag matches the Experiment 1 CKA table; confirm it is the intended palette.
cka_cmap = sns.color_palette("vlag", as_cmap=True)
colors = [[decimal_to_rgb_color(value, cmap=cka_cmap) for value in row] for row in values]

In [None]:
# Inspect the RGB tuple computed for each entry of `values`.
colors

In [None]:
# Build the LaTeX rows for Experiment 2: one line per row of `values`, with every cell
# shaded by the matching entry of `colors`.
table_lines = []
for values_row, colors_row in zip(values, colors):
    cells = []
    for value, color in zip(values_row, colors_row):
        # Join the RGB components explicitly instead of stripping the parentheses of str(tuple).
        rgb = ", ".join(str(c) for c in color[:3])
        cells.append(f"\\cellcolor[rgb]{{{rgb}}}{value:.4f}")
    # Joining leaves no trailing "&" to slice off; the old [:-2] slice also cut the
    # last digit of the final value in every row.
    table_lines.append(" & ".join(cells) + " \\\\ \n")
table_str = "".join(table_lines)

print(table_str)