In [None]:
# Move the working directory two levels up (presumably to the project root — confirm).
# NOTE: the path is relative, so re-running this cell without restarting the
# kernel moves up two more levels and changes what LATEX_PATH points to.
%cd ../..

# Root of the LaTeX project, relative to the working directory set above;
# figures are written under f"{LATEX_PATH}/figures/molecules/".
LATEX_PATH = "../latex/Distillation-MI-ICLR"

In [None]:
import wandb
import seaborn as sns

import matplotlib.pyplot as plt
import pandas as pd

# Shared W&B API client; the loading cells fetch runs through api.run(<run path>).
# Presumably authenticates with the locally stored W&B login / WANDB_API_KEY — confirm.
api = wandb.Api()

In [None]:
# W&B run paths for the kernel-depth comparison and their display names
# (same order: RUNS[i] is plotted under NAMES[i]).
RUNS = ["/mol-distill/56rx297n", "/mol-distill/jxk2dqi1", "/mol-distill/33msn9sy"]
NAMES = ["5-layers-kernel", "2-layers-kernel", "3-layers-kernel"]
# "Sum" maps to the columns train_loss / eval_loss; every other key maps to
# train_loss_<key> / eval_loss_<key>.
KEYS = ["Sum", "GraphMVP", "ChemBertMTR-77M", "FRAD_QM9", "ThreeDInfomax", "GraphCL"]

# History rows with epoch > MAX_EPOCH are discarded.
MAX_EPOCH = 450

# Long-format pieces (one per run x teacher x split), concatenated once at the
# end instead of growing a DataFrame inside the loops.
frames = []
for run_id, name in zip(RUNS, NAMES):
    run = api.run(run_id)
    # Download the history a single time: calling run.history() twice issues
    # two requests and builds the epoch mask from a different download than
    # the rows it is applied to.
    history = run.history()
    history = history[history.epoch <= MAX_EPOCH]

    for key in KEYS:
        suffix = "" if key == "Sum" else f"_{key}"
        for split, column in (("train", f"train_loss{suffix}"), ("val", f"eval_loss{suffix}")):
            # Build a fresh frame per (teacher, split) rather than overwriting
            # columns on a filtered slice (avoids SettingWithCopyWarning).
            frames.append(
                history[["epoch", column]]
                .rename(columns={column: "loss"})
                .assign(split=split, teacher=key, name=name)
                [["loss", "split", "teacher", "epoch", "name"]]
            )

# Columns: loss, split ("train"/"val"), teacher, epoch, name.
df = pd.concat(frames)
df

In [None]:
# One panel per entry of KEYS, arranged on two rows with a shared x axis.
fig, axes = plt.subplots(2, len(KEYS)//2, figsize=(len(KEYS)/2*2.7, 4.3), sharex=True)
axes = axes.flatten()

for i, key in enumerate(KEYS):
    ax = axes[i]
    # Subset once per teacher; reused for the curves and for the y-limits.
    df_key = df[df.teacher == key]

    sns.lineplot(
        data=df_key.dropna(),
        x="epoch",
        y="loss",
        hue="name",      # one colour per run
        style="split",   # line style distinguishes train from val
        ax=ax,
        palette="husl",
        alpha=0.8,
        legend=i == len(KEYS) - 1,  # only the last panel carries the legend
    )

    # Cap the y-range at the 99th percentile so a few very large loss values
    # do not flatten the rest of the curve.
    ax.set_ylim(df_key["loss"].min(), df_key["loss"].quantile(0.99))

    ax.set_ylabel("")
    ax.set_xlabel("Epoch")
    ax.set_title(key)
    ax.set_xlim(0, 400)
axes[0].set_ylabel("Train Loss")

# Move the seaborn-generated legend (run names + split styles) outside the
# last panel. It is addressed through its axes explicitly instead of relying
# on pyplot's "current axes"; the handles already carry the right labels, so
# no manual label list is passed.
axes[-1].legend(loc='center left', bbox_to_anchor=(1.05, 0.8))

fig.savefig(f"{LATEX_PATH}/figures/molecules/kernel_train_curve.pdf", bbox_inches="tight")

In [None]:
# W&B run paths for the student-architecture comparison and their display
# names (same order: RUNS[i] is plotted under NAMES[i]).
RUNS = ["/mol-distill/33msn9sy", "/mol-distill/wow4guql", "/mol-distill/27pq9iwq", "/mol-distill/wm0onriy", "/mol-distill/r099cum7", "/mol-distill/he9vr7df"]
NAMES = ["GINE-student", "GAT-student", "GCN-student", "TAG-student", "SAGE-student", "GIN-student"]
# "Sum" maps to the columns train_loss / eval_loss; every other key maps to
# train_loss_<key> / eval_loss_<key>.
KEYS = ["Sum", "ChemBertMTR-77M", "FRAD_QM9"]

# History rows with epoch > MAX_EPOCH are discarded.
MAX_EPOCH = 450

# Long-format pieces (one per run x teacher x split), concatenated once at the
# end instead of growing a DataFrame inside the loops.
frames = []
for run_id, name in zip(RUNS, NAMES):
    run = api.run(run_id)
    # Download the history a single time: calling run.history() twice issues
    # two requests and builds the epoch mask from a different download than
    # the rows it is applied to.
    history = run.history()
    history = history[history.epoch <= MAX_EPOCH]

    for key in KEYS:
        suffix = "" if key == "Sum" else f"_{key}"
        for split, column in (("train", f"train_loss{suffix}"), ("val", f"eval_loss{suffix}")):
            # Build a fresh frame per (teacher, split) rather than overwriting
            # columns on a filtered slice (avoids SettingWithCopyWarning).
            frames.append(
                history[["epoch", column]]
                .rename(columns={column: "loss"})
                .assign(split=split, teacher=key, name=name)
                [["loss", "split", "teacher", "epoch", "name"]]
            )

# Columns: loss, split ("train"/"val"), teacher, epoch, name.
df = pd.concat(frames)

In [None]:
# Architecture tag = the part of the run name before the first "-"
# (e.g. "GINE-student" -> "GINE"); used to group students when plotting.
df["archi"] = df["name"].str.split("-").str[0]
df

In [None]:
# Two rows of panels (one per architecture group), one column per teacher key.
fig, axes = plt.subplots(
    nrows=2,
    ncols=len(KEYS),
    figsize=(len(KEYS) * 2.7, 4.3),
    sharex=True,
)

# NOTE(review): get_color is called without this value and falls back to its
# own default (cmap_offset=2), so this assignment does not affect the colours.
cmap_offset = 4
# Architectures shown on the first row; every other architecture found in
# NAMES (prefix before the first "-") goes to the second row.
to_iso = ["GINE", "GAT"]
not_to_iso = [
    arch
    for arch in (student.split("-")[0] for student in NAMES)
    if arch not in to_iso
]


def get_color(name, desat=0.7, cmap_offset=2, cmap="icefire"):
    """Pick the plotting colour for a student run.

    The palette `cmap` is sampled with len(NAMES) + cmap_offset colours.
    Architectures listed in `to_iso` take full-saturation colours counted
    from the end of the palette (skipping cmap_offset // 2 entries); all
    other architectures take colours from the start of the palette,
    desaturated by `desat`.

    Args:
        name: run display name; the architecture is the text before the first "-".
        desat: desaturation factor passed to seaborn for non-`to_iso` architectures.
        cmap_offset: number of extra palette entries beyond len(NAMES).
        cmap: seaborn palette name.

    Returns:
        The selected entry of the seaborn colour palette.

    Raises:
        ValueError: if the architecture is in neither `to_iso` nor `not_to_iso`.
    """
    architecture = name.split("-")[0]
    n_colors = len(NAMES) + cmap_offset

    if architecture not in to_iso:
        position = not_to_iso.index(architecture)
        muted = sns.color_palette(cmap, n_colors, desat=desat)
        return muted[position]

    # 1-based position counted from the end of the palette.
    from_end = to_iso.index(architecture) + 1 + cmap_offset // 2
    vivid = sns.color_palette(cmap, n_colors)
    return vivid[-from_end]

# Fixed colour per student so a run keeps the same colour in every panel.
cmap = {
    name: get_color(name)
    for name in NAMES
}

# Row 0 shows the architectures in to_iso, row 1 all the others; one column
# per teacher key.
for i, models in enumerate([to_iso, not_to_iso]):
    for j, key in enumerate(KEYS):
        ax = axes[i, j]
        # Subset once per teacher; reused for the curves and for the y-limits.
        df_key = df[df.teacher == key]

        sns.lineplot(
            data=df_key[df_key.archi.isin(models)].dropna(),
            x="epoch",
            y="loss",
            hue="name",
            ax=ax,
            palette=cmap,
            legend=j == len(KEYS) - 1,  # only the last column carries a legend
            alpha=1,
            style="split",  # line style distinguishes train from val
        )

        # y-limits come from all architectures of this teacher, so both rows
        # of a column share the same scale; the 99th-percentile cap keeps a
        # few very large loss values from flattening the curves.
        ax.set_ylim(df_key["loss"].min(), df_key["loss"].quantile(0.99))

        ax.set_ylabel("")
        ax.set_xlabel("Epoch")
        # Title with the teacher key of this column (not the stale `name`
        # variable left over from the loading loop).
        ax.set_title(key)
        ax.set_xlim(0, 400)
fig.supylabel("Train Loss")

# Move the seaborn-generated legends of the last column outside the panels.
# The handles already carry the right labels, so no manual label list is passed.
for legend_ax in axes[:, -1]:
    legend_ax.legend(loc='center left', bbox_to_anchor=(1.05, 0.5))

fig.savefig(f"{LATEX_PATH}/figures/molecules/archi_train_curve.pdf", bbox_inches="tight")