In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from omegaconf import OmegaConf
from src.plots.latex import set_size, update_rcParams

# Hydra multirun output directories to aggregate, grouped by training state.
# CHANGE: the analysis requires runs over different model depths (n_layers).
runs = dict(
    trained=[
        Path("../multirun/2022-03-28/13-54-52"),
        Path("../multirun/2022-03-28/13-55-24"),
        Path("../multirun/2022-03-28/13-56-06"),
        Path("../multirun/2022-03-28/13-56-24"),
    ],
    untrained=[
        Path("../multirun/2022-03-31/08-25-34"),
        Path("../multirun/2022-03-31/10-58-10"),
    ],
)


In [None]:
# Build one row per (model comparison, layer): the CKA between identically indexed
# layers (the diagonal of each layer-by-layer CKA matrix), tagged with the
# experiment's dataset, model, depth and number of training epochs.
records = []
for roots in runs.values():
    for root in roots:
        # Sort so the row order of `df` does not depend on the filesystem's listing order.
        for experiment in sorted(p for p in root.iterdir() if p.is_dir()):
            cfg = OmegaConf.load(experiment / ".hydra" / "config.yaml")

            # NOTE(review): assumed shape (n_comparisons, n_layers, n_layers) — TODO confirm
            # against the code that writes ckas_test.npy.
            ckas = np.load(experiment / "cka" / "ckas_test.npy")
            if ckas.shape[-1] != ckas.shape[-2]:
                raise ValueError(
                    f"Expected square CKA matrices in {experiment}, got shape {ckas.shape}"
                )
            # Diagonal over the last two axes -> shape (n_comparisons, n_layers).
            for layer_ckas in np.diagonal(ckas, axis1=-2, axis2=-1):
                # `layer` is 0-based; figures relabel it 1-based for display.
                for layer, layer_cka in enumerate(layer_ckas):
                    records.append(
                        (layer, layer_cka, cfg.dataset.name, cfg.model.name, cfg.model.n_layers, cfg.n_epochs)
                    )

df = pd.DataFrame.from_records(
    records, columns=["Layer", "CKA", "Dataset", "Model", "N_Layers", "Epochs"]
)



In [None]:
df["Arch"] = df["Model"] + "_L" + df["N_Layers"].astype(str)  # e.g. "GCN2017_L3"

In [None]:
# Exploratory overview of GAT runs: for each dataset, box plots of the CKA between
# identically indexed layers, one facet column per epoch setting, hue = model depth.
# matplotlib 3.6 renamed the "seaborn" style to "seaborn-v0_8" (old name removed in
# 3.8); fall back to the old name on older matplotlib versions.
mpl_style = "seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn"
for dataset in df.Dataset.unique():
    with plt.style.context(mpl_style):
        g = sns.catplot(
            data=df[(df.Model == "GAT2017") & (df.Dataset == dataset)],
            x="Layer",
            y="CKA",
            hue="N_Layers",
            col="Epochs",
            row="Dataset",  # a single row; kept so facet titles include the dataset name
            kind="box",
            sharex=False,
            sharey=True,
        )

In [None]:
# Main-text figures: per-layer CKA of GCN models (untrained and trained on CiteSeer,
# trained on Pubmed), one PDF per figure.
font_scale = 2
width, height = set_size(fraction=1)
rc_overrides = {
    "axes.labelsize": 8 * font_scale,
    "font.size": 8 * font_scale,
    "legend.fontsize": 6 * font_scale,
    "xtick.labelsize": 6 * font_scale,
    "ytick.labelsize": 6 * font_scale,
}
# matplotlib 3.6 renamed the "seaborn" style to "seaborn-v0_8" (old name removed in 3.8).
mpl_style = "seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn"

# (dataset, epochs, title, output path, legend). legend is "keep" to leave seaborn's
# legend as drawn, None to remove it, or a dict of keyword arguments for ax.legend.
gcn_figure_specs = [
    ("CiteSeer", 0, "CiteSeer | Untrained GCN", "../reports/cka_init_gcn_citeseer.pdf", "keep"),
    ("CiteSeer", 500, "CiteSeer | GCN", "../reports/cka_gcn_citeseer.pdf", None),
    (
        "Pubmed",
        500,
        "Pubmed | GCN",
        "../reports/cka_gcn_pubmed.pdf",
        {"loc": "lower left", "ncol": 3, "frameon": True},
    ),
]

with plt.style.context(mpl_style), update_rcParams(rc_overrides):
    for dataset, epochs, title, out_path, legend in gcn_figure_specs:
        # Rename on the subset (instead of adding a column to the shared `df`) so the
        # hue variable, and hence the legend, is called "Layers".
        subset = df[
            (df.Model == "GCN2017") & (df.Dataset == dataset) & (df.Epochs == epochs)
        ].rename(columns={"N_Layers": "Layers"})

        fig, ax = plt.subplots(1, 1, figsize=(width, height))
        sns.boxplot(data=subset, x="Layer", y="CKA", hue="Layers", ax=ax)

        # seaborn orders numeric categories ascending at x = 0..k-1; show them 1-based,
        # deriving the labels from the data instead of hard-coding five layers.
        layers = sorted(subset["Layer"].unique())
        ax.set_xticks(range(len(layers)))
        ax.set_xticklabels([layer + 1 for layer in layers])

        if legend is None:
            ax.legend_.remove()
        elif isinstance(legend, dict):
            ax.legend(**legend)

        ax.set_title(title)
        ax.set_ylabel("Similarity (CKA)")
        fig.savefig(out_path, bbox_inches="tight")

In [None]:
# Main-text figure: per-layer CKA of trained (500 epochs) GAT models on CS.
font_scale = 2
width, height = set_size(fraction=1)
# matplotlib 3.6 renamed the "seaborn" style to "seaborn-v0_8" (old name removed in 3.8).
mpl_style = "seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn"
with plt.style.context(mpl_style), update_rcParams(
    {
        "axes.labelsize": 8 * font_scale,
        "font.size": 8 * font_scale,
        "legend.fontsize": 6 * font_scale,
        "xtick.labelsize": 6 * font_scale,
        "ytick.labelsize": 6 * font_scale,
    }
):
    # Rename on the subset rather than mutating the shared `df`.
    subset = df[
        (df.Model == "GAT2017") & (df.Dataset == "CS") & (df.Epochs == 500)
    ].rename(columns={"N_Layers": "Layers"})

    fig, ax = plt.subplots(1, 1, figsize=(width, height))
    sns.boxplot(data=subset, x="Layer", y="CKA", hue="Layers", ax=ax)

    # seaborn orders numeric categories ascending at x = 0..k-1; show them 1-based,
    # deriving the labels from the data instead of hard-coding five layers.
    layers = sorted(subset["Layer"].unique())
    ax.set_xticks(range(len(layers)))
    ax.set_xticklabels([layer + 1 for layer in layers])

    ax.legend(loc="lower left", ncol=3, frameon=True)
    ax.set_title("CS | GAT")
    ax.set_ylabel("Similarity (CKA)")
    fig.savefig("../reports/cka_gat_cs.pdf", bbox_inches="tight")

In [None]:
# Appendix figures: per-layer CKA of trained (500 epochs) runs for every
# model/dataset combination, saved as ../reports/appendix/cka_<model>_<dataset>.pdf.
model_display_names = {"GAT2017": "GAT", "GCN2017": "GCN"}
trained_epochs = 500
appendix_dir = Path("../reports/appendix")
appendix_dir.mkdir(parents=True, exist_ok=True)  # savefig does not create directories

# Figure geometry and font sizes are the same for every figure; compute them once.
font_scale = 2
width, height = set_size(fraction=1)
rc_overrides = {
    "axes.labelsize": 8 * font_scale,
    "font.size": 8 * font_scale,
    "legend.fontsize": 6 * font_scale,
    "xtick.labelsize": 6 * font_scale,
    "ytick.labelsize": 6 * font_scale,
}
# matplotlib 3.6 renamed the "seaborn" style to "seaborn-v0_8" (old name removed in 3.8).
mpl_style = "seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn"

with plt.style.context(mpl_style), update_rcParams(rc_overrides):
    for model in df.Model.unique():
        # Fail before plotting anything for a model we have no display name for.
        if model not in model_display_names:
            raise ValueError(f"No display name configured for model {model!r}")
        model_name = model_display_names[model]

        for dataset in df.Dataset.unique():
            subset = df[
                (df.Model == model) & (df.Dataset == dataset) & (df.Epochs == trained_epochs)
            ].rename(columns={"N_Layers": "Layers"})
            if subset.empty:
                # No trained runs for this model/dataset pair in the loaded results.
                continue

            fig, ax = plt.subplots(1, 1, figsize=(width, height))
            sns.boxplot(data=subset, x="Layer", y="CKA", hue="Layers", ax=ax)

            # seaborn orders numeric categories ascending at x = 0..k-1; show them
            # 1-based, deriving the labels from the data instead of hard-coding five.
            layers = sorted(subset["Layer"].unique())
            ax.set_xticks(range(len(layers)))
            ax.set_xticklabels([layer + 1 for layer in layers])

            ax.legend(loc="lower left", ncol=3, frameon=True)
            ax.set_title(f"{dataset} | {model_name}")
            ax.set_ylabel("Similarity (CKA)")
            fig.savefig(appendix_dir / f"cka_{model}_{dataset}.pdf", bbox_inches="tight")