In [None]:
import wandb
import seaborn as sns

import matplotlib.pyplot as plt
import pandas as pd

# W&B public-API client; the cells below use api.run(<entity/project/run_id>).history()
# to fetch each run's logged metrics.
# NOTE(review): presumably authenticates from the local `wandb login` / WANDB_API_KEY — confirm.
api = wandb.Api()

In [None]:
# W&B run paths ("entity/project/run_id") and the display label for each run.
RUNS = ["philippe_phd/mol-distill/mj3al0lc", "philippe_phd/mol-distill/aypnwcyu", "philippe_phd/mol-distill/rv16elr0"]
NAMES = ["5-layers-kernel", "2-layers-kernel", "3-layers-kernel"]
# Loss components to plot. "Sum" maps to the un-suffixed train_loss / eval_loss
# columns; every other key maps to train_loss_<key> / eval_loss_<key>.
KEYS = ["Sum", "ThreeDInfomax", "FRAD_QM9"]

# Reshape each run's wide history (one column per logged loss) into long format:
# one row per (history row, teacher, split) with columns
# loss / split / teacher / epoch / name, which is what the seaborn plots expect.
# Frames are collected in a list and concatenated once (growing a DataFrame with
# pd.concat inside the loop is quadratic).
frames = []
for run_id, name in zip(RUNS, NAMES):
    # NOTE(review): Run.history() returns a sampled subset of rows by default;
    # confirm that is acceptable for these curves (scan_history() returns all rows).
    history = api.run(run_id).history()
    for key in KEYS:
        suffix = "" if key == "Sum" else f"_{key}"
        for split, prefix in (("train", "train_loss"), ("val", "eval_loss")):
            frames.append(
                pd.DataFrame(
                    {
                        "loss": history[prefix + suffix],
                        "split": split,
                        "teacher": key,
                        "epoch": history["epoch"],
                        "name": name,
                    }
                )
            )

# pd.concat([]) raises, so fall back to an empty frame with the expected columns.
df = (
    pd.concat(frames)
    if frames
    else pd.DataFrame(columns=["loss", "split", "teacher", "epoch", "name"])
)
df

In [None]:
# Loss curves for the runs in RUNS/NAMES: one panel per loss component (KEYS),
# one colour per run, and line style distinguishing the train vs. eval split.
KERNEL_FIG_PATH = (
    "/home/philippe/Distill/latex/Distillation-MI-ICLR/figures/molecules/kernel_train_curve.pdf"
)

fig, axes = plt.subplots(
    1, len(KEYS), figsize=(len(KEYS) * 2.7, 2.3), sharex=True, squeeze=False
)
# squeeze=False + ravel() keeps `axes` a 1-D array even when len(KEYS) == 1.
axes = axes.ravel()

# Build the palette once; each run keeps the same colour in every panel.
palette = sns.color_palette("husl", len(NAMES))
cmap = dict(zip(NAMES, palette))

for i, (ax, key) in enumerate(zip(axes, KEYS)):
    # Rows of this loss component that have both an epoch and a loss value.
    subset = df[df.teacher == key].dropna()
    # Readable split names for the style legend (solid: Train, dashed: Eval).
    subset = subset.assign(
        split=subset["split"].replace({"train": "Train", "val": "Eval"})
    )
    sns.lineplot(
        data=subset,
        x="epoch",
        y="loss",
        hue="name",
        style="split",
        palette=cmap,
        alpha=0.8,
        ax=ax,
        legend=i == len(KEYS) - 1,  # only the last panel carries the legend
    )

    # Clip the top of the y-range at the 95th percentile so outlying high losses
    # don't flatten the rest of the curves. Limits come from the plotted rows;
    # skip degenerate/empty ranges (NaN comparisons are False).
    lo, hi = subset["loss"].min(), subset["loss"].quantile(0.95)
    if lo < hi:
        ax.set_ylim(lo, hi)

    ax.set_title(key)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("")
    ax.set_xlim(0, 190)

# Panels show both train and eval curves, so the axis is just "Loss".
axes[0].set_ylabel("Loss")

# Move seaborn's legend outside the last panel. sns.move_legend re-creates the
# existing legend (same handles/labels) at the new location, instead of building
# a new one from the implicit current axes as plt.legend() does.
if axes[-1].get_legend() is not None:
    sns.move_legend(axes[-1], "center left", bbox_to_anchor=(1.05, 0.5))

fig.tight_layout()
fig.savefig(KERNEL_FIG_PATH, bbox_inches="tight")

In [None]:
# W&B run paths ("entity/project/run_id") and the display label for each run.
RUNS = ["philippe_phd/mol-distill/33msn9sy", "philippe_phd/mol-distill/wow4guql", "philippe_phd/mol-distill/27pq9iwq", "philippe_phd/mol-distill/wm0onriy", "philippe_phd/mol-distill/r099cum7"]
NAMES = ["GINE-student", "GAT-student", "GCN-student", "TAG-student", "SAGE-student"]
# Loss components to plot. "Sum" maps to the un-suffixed train_loss / eval_loss
# columns; every other key maps to train_loss_<key> / eval_loss_<key>.
KEYS = ["Sum", "GraphMVP", "ChemBertMTR-77M", "FRAD_QM9", "ThreeDInfomax", "GraphCL"]

# History rows with epoch > MAX_EPOCH are dropped.
MAX_EPOCH = 450

# Reshape each run's wide history into long format (loss / split / teacher /
# epoch / name), collecting frames in a list and concatenating once.
frames = []
for run_id, name in zip(RUNS, NAMES):
    # Fetch the history once: building the mask from a second history() call
    # costs another API request and need not return the same rows/index.
    # NOTE(review): Run.history() returns a sampled subset of rows by default;
    # confirm that is acceptable for these curves.
    history = api.run(run_id).history()
    history = history[history["epoch"] <= MAX_EPOCH]
    for key in KEYS:
        suffix = "" if key == "Sum" else f"_{key}"
        for split, prefix in (("train", "train_loss"), ("val", "eval_loss")):
            frames.append(
                pd.DataFrame(
                    {
                        "loss": history[prefix + suffix],
                        "split": split,
                        "teacher": key,
                        "epoch": history["epoch"],
                        "name": name,
                    }
                )
            )

# pd.concat([]) raises, so fall back to an empty frame with the expected columns.
df = (
    pd.concat(frames)
    if frames
    else pd.DataFrame(columns=["loss", "split", "teacher", "epoch", "name"])
)

In [None]:
# Preview a diverging palette with len(NAMES) + 4 colours.
# NOTE(review): the architecture figure below colours runs with "PRGn" via get_color().
sns.color_palette(palette="BrBG", n_colors=len(NAMES) + 4)

In [None]:
# Loss curves for the student-architecture runs: a 2-row grid with one panel per
# loss component (KEYS), one colour per run (NAMES), line style per split.
ARCHI_FIG_PATH = (
    "/home/philippe/Distill/latex/Distillation-MI-ICLR/figures/molecules/archi_train_curve.pdf"
)

n_rows = 2
# Round up so an odd number of loss components still gets one panel each.
n_cols = (len(KEYS) + 1) // 2
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(n_cols * 2.7, 4.3), sharex=True, squeeze=False
)
axes = axes.flatten()

# NOTE(review): this module-level value is not read by get_color(), whose own
# `cmap_offset` default is 2 — confirm which offset was intended.
cmap_offset = 4
# Architectures drawn with saturated colours taken from the end of the palette;
# every other architecture gets a desaturated colour from the palette's start.
to_iso = ["GINE", "GAT"]
not_to_iso = [n.split("-")[0] for n in NAMES if n.split("-")[0] not in to_iso]


def get_color(name, desat=0.7, cmap_offset=2, cmap="PRGn"):
    """Return the plotting colour for a run label such as "GINE-student".

    The architecture is the part of `name` before the first "-".

    Args:
        name: run label whose architecture prefix is in `to_iso` or `not_to_iso`.
        desat: desaturation applied to architectures not in `to_iso`.
        cmap_offset: number of palette entries beyond len(NAMES); half of it is
            used to step back from the palette's last colour for `to_iso`.
        cmap: seaborn/matplotlib palette name.

    Raises:
        ValueError: if the architecture prefix is in neither list.
    """
    arch = name.split("-")[0]
    n_colors = len(NAMES) + cmap_offset
    if arch in to_iso:
        # Count back from the end of the palette, skipping cmap_offset // 2 entries.
        idx = to_iso.index(arch) + 1 + cmap_offset // 2
        return sns.color_palette(cmap, n_colors)[-idx]
    idx = not_to_iso.index(arch)
    return sns.color_palette(cmap, n_colors, desat=desat)[idx]


cmap = {name: get_color(name) for name in NAMES}

for i, (ax, key) in enumerate(zip(axes, KEYS)):
    # Rows of this loss component that have both an epoch and a loss value.
    subset = df[df.teacher == key].dropna()
    # Readable split names for the style legend (solid: Train, dashed: Eval).
    subset = subset.assign(
        split=subset["split"].replace({"train": "Train", "val": "Eval"})
    )
    sns.lineplot(
        data=subset,
        x="epoch",
        y="loss",
        hue="name",
        style="split",
        palette=cmap,
        alpha=1,
        ax=ax,
        legend=i == len(KEYS) - 1,  # only the last used panel carries the legend
    )

    # Clip the top of the y-range at the 99th percentile; limits come from the
    # plotted rows, and degenerate/empty ranges are skipped (NaN compares False).
    lo, hi = subset["loss"].min(), subset["loss"].quantile(0.99)
    if lo < hi:
        ax.set_ylim(lo, hi)

    ax.set_title(key)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("")
    # NOTE(review): histories are loaded up to MAX_EPOCH, but only epochs 0-400 are shown.
    ax.set_xlim(0, 400)

# Hide the spare panel left over when len(KEYS) is odd.
for ax in axes[len(KEYS):]:
    ax.set_visible(False)

# One y-label per row; panels show both train and eval curves.
for ax in axes[::n_cols]:
    ax.set_ylabel("Loss")

# Move seaborn's legend outside the last used panel. sns.move_legend re-creates
# the existing legend (same handles/labels) at the new location.
legend_ax = axes[len(KEYS) - 1]
if legend_ax.get_legend() is not None:
    sns.move_legend(legend_ax, "center left", bbox_to_anchor=(1.05, 0.8))

# bbox_inches="tight" crops to include the outside legend.
fig.savefig(ARCHI_FIG_PATH, bbox_inches="tight")