In [None]:
import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F
from rae import PROJECT_ROOT
import math

In [None]:
def encoder_factory(encoder_type, num_layers: int, in_channels: int, out_channels: int, **params):
    """Build a stack of ``num_layers`` graph-convolution layers of a single type.

    Args:
        encoder_type: Layer family to instantiate; one of ``"GCN2Conv"``,
            ``"GCNConv"``, ``"GATConv"`` or ``"GINConv"``.
        num_layers: Number of layers to create; must be positive.
        in_channels: Input width of the first layer. Not used for
            ``"GCN2Conv"``, whose layers only receive ``channels=out_channels``.
        out_channels: Output width of every layer, and the input width of
            every layer after the first.
        **params: Extra keyword arguments forwarded unchanged to every layer
            constructor.

    Returns:
        An ``nn.ModuleList`` holding the layers in order.

    Raises:
        AssertionError: If ``num_layers`` is not positive.
        NotImplementedError: If ``encoder_type`` is not one of the supported names.

    NOTE(review): ``nn``, ``GCN2Conv``, ``GCNConv``, ``GATConv`` and ``GINConv``
    are not imported in the notebook cells visible here; they must be in scope
    before this function is called.
    NOTE(review): for ``"GATConv"``, every layer after the first is given
    ``in_channels=out_channels``; if ``params`` requests several concatenated
    attention heads the real output width is presumably larger — confirm
    against the layer's documentation.
    """
    assert num_layers > 0

    # Input width per layer: the first layer consumes `in_channels`, every
    # subsequent layer consumes the previous layer's `out_channels`.
    layer_in_channels = [in_channels] + [out_channels] * (num_layers - 1)

    if encoder_type == "GCN2Conv":
        # GCN2Conv layers are numbered from 1 and keep a constant width.
        convs = [GCN2Conv(layer=layer + 1, channels=out_channels, **params) for layer in range(num_layers)]

    elif encoder_type in ("GCNConv", "GATConv"):
        # Both families share the same constructor signature here, so only the
        # class differs. The class is looked up lazily so that only the one
        # actually requested needs to be importable.
        conv_cls = GCNConv if encoder_type == "GCNConv" else GATConv
        convs = [
            conv_cls(in_channels=layer_in, out_channels=out_channels, **params) for layer_in in layer_in_channels
        ]

    elif encoder_type == "GINConv":
        # Each GIN layer wraps its own linear map.
        convs = [
            GINConv(nn=nn.Linear(in_features=layer_in, out_features=out_channels), **params)
            for layer_in in layer_in_channels
        ]

    else:
        raise NotImplementedError(f"Unsupported encoder type: {encoder_type!r}")

    return nn.ModuleList(convs)

In [None]:
# Load the serialized list of experiment runs for the Cora dataset.
# NOTE(review): torch.load unpickles arbitrary Python objects; only load files
# that come from a trusted source.
experiments_file = PROJECT_ROOT / "experiments" / "sec:data-manifold" / f"{'Cora'}_data_manifold_experiments.pt"
experiments = torch.load(experiments_file)
len(experiments)

In [None]:
# Statistics table (tab-separated, first column used as the index).
stats_file = PROJECT_ROOT / "experiments" / "sec:data-manifold" / f"{'Cora'}_data_manifold_stats.tsv"
stats = pd.read_csv(stats_file, sep="\t", index_col=0)
stats

In [None]:
# Keep only experiments whose best validation accuracy exceeds VAL_ACC_LOWER_BOUND.
VAL_ACC_LOWER_BOUND = 0.5

# Per-experiment maximum of `val_acc`. The column is named explicitly because
# the label generated by `.agg([np.max])` ("amax" vs "max") depends on the
# installed NumPy/pandas versions.
df_max_acc = stats.groupby("experiment")["val_acc"].max().to_frame("amax")
best_experiments = df_max_acc.loc[df_max_acc["amax"] > VAL_ACC_LOWER_BOUND]
best_experiments = best_experiments.reset_index().experiment
df_filtered = stats[stats["experiment"].isin(best_experiments)]
df_filtered, len(set(df_filtered.experiment))

In [None]:
# Pearson correlation between validation accuracy and reference distance,
# computed separately for every experiment that passed the accuracy filter and
# then averaged over experiments.
experiments_valacc_similarity_correlation = []
# Iterate over the filtered frame itself (in sorted, deterministic group order)
# instead of masking it with a boolean Series built on the unfiltered `stats`.
for exp, d_exp in df_filtered.groupby("experiment"):
    # Correlate only the two columns of interest so that non-numeric columns
    # in the table cannot break the computation.
    corr = d_exp["val_acc"].corr(d_exp["reference_distance"], method="pearson")
    # NaN arises e.g. when a column is constant or has too few observations.
    if not math.isnan(corr):
        experiments_valacc_similarity_correlation.append(corr)
p_corr = np.mean(experiments_valacc_similarity_correlation)

print("Pearson correlation val_acc - ref_similarity: ", p_corr)

In [None]:
# Load the reference ("best") run and keep its best-epoch latents in a
# one-element list, the form expected by `get_distance`'s `latents_ref`.
# NOTE(review): torch.load unpickles arbitrary Python objects; only load files
# that come from a trusted source.
best_run_file = PROJECT_ROOT / "experiments" / "sec:data-manifold" / f"{'Cora'}_best_run.pt"
best_run = torch.load(best_run_file)
best_run_latents = [best_run["best_epoch"]["rel_x"]]
best_run

In [None]:
from typing import *


def get_distance(latents1: torch.Tensor, latents_ref: Sequence[torch.Tensor]):
    """Average cosine similarity between ``latents1`` and each reference tensor.

    Despite the name, the returned value is a similarity: for every tensor in
    ``latents_ref`` the cosine similarity with ``latents1`` is computed, reduced
    to its mean, and the resulting per-reference means are averaged.

    Args:
        latents1: Latent tensor to compare.
        latents_ref: A sequence (e.g. a list) of reference tensors. A bare
            array or tensor is rejected, so a single reference must be wrapped
            in a list.

    Returns:
        The mean of the per-reference mean cosine similarities.

    Raises:
        AssertionError: If ``latents_ref`` is itself an ``np.ndarray`` or
            ``torch.Tensor``.
    """
    assert not isinstance(latents_ref, (np.ndarray, torch.Tensor))

    per_reference_means = []
    for reference in latents_ref:
        similarity = F.cosine_similarity(latents1, reference)
        per_reference_means.append(similarity.mean().item())
    return np.mean(per_reference_means)

In [None]:
import math

# Keep a run only when BOTH its best-epoch loss and its best-epoch latents are
# free of NaNs. With `or`, a run was dropped only if both were NaN, so runs
# with a NaN loss or NaN latents still slipped through.
filtered_experiments = [
    x
    for x in experiments
    if not math.isnan(x["best_epoch"]["loss"]) and not np.isnan(x["best_epoch"]["rel_x"]).any()
]
len(filtered_experiments)

In [None]:
import json

# Collect one scatter point per run: its best-epoch validation accuracy, the
# similarity of its latents to the reference run, and descriptive metadata.
points = {"score": [], "similarity": [], "hyperparams": [], "optimizer": [], "encoder": [], "color": []}
for run in filtered_experiments:
    # Latents are L2-normalised along the last dimension before being compared
    # against the reference latents.
    distance = get_distance(latents1=F.normalize(run["best_epoch"]["rel_x"], dim=-1, p=2), latents_ref=best_run_latents)
    if np.isnan(distance):
        continue
    score = run["best_epoch"]["val_acc"]
    points["score"].append(score)
    points["similarity"].append(distance)

    # Reduce every hyperparameter to a JSON-serialisable description.
    hyperparams = {}
    for key in ("encoder", "lr", "hidden_fn", "conv_fn", "optimizer"):
        run_value = run[key]
        if key == "encoder":
            # Only the first element of the stored encoder entry is kept.
            run_value = run_value[0]
        elif "_fn" in key:
            # "*_fn" entries are described by the name of their type.
            run_value = type(run_value).__name__
        elif key == "optimizer":
            # The optimizer entry is described by its own `__name__`.
            run_value = run_value.__name__

        hyperparams[key] = run_value
    points["optimizer"].append(hyperparams["optimizer"])
    points["encoder"].append(hyperparams["encoder"])
    points["hyperparams"].append(json.dumps(hyperparams))
    # Use THIS run's optimizer/encoder for the label; formatting
    # points["optimizer"] / points["encoder"] would embed the whole accumulated
    # lists in every label.
    points["color"].append(f'{hyperparams["optimizer"]}_{hyperparams["encoder"]}')

In [None]:
# Largest similarity value collected in `points` (quick look at the value range).
max(points["similarity"])

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
from tueplots import bundles
from tueplots import figsizes


# Single-panel, square figure layout.
N_ROWS, N_COLS, RATIO = 1, 1, 1

# Apply the ICLR 2023 tueplots style bundle, then the matching figure size for
# the chosen grid; the second update takes precedence on any shared keys.
plt.rcParams.update(bundles.iclr2023(usetex=True))
plt.rcParams.update(
    figsizes.iclr2023(
        ncols=N_COLS,
        nrows=N_ROWS,
        height_to_width_ratio=RATIO,
    )
)

fig, ax = plt.subplots(nrows=N_ROWS, ncols=N_COLS, dpi=150)


def plot_points(ax, pts, s=5):
    """Scatter score against similarity on ``ax`` and overlay a linear trend.

    Args:
        ax: Axes-like object to draw on.
        pts: Mapping convertible to a ``pandas.DataFrame`` that provides at
            least the ``similarity`` and ``score`` columns.
        s: Marker size forwarded to ``ax.scatter``.
    """
    frame = pd.DataFrame(pts)
    ax.set_aspect("auto")

    ax.scatter(frame.similarity, frame.score, s=s)

    # Degree-1 least-squares fit of score as a function of similarity.
    coefficients = np.polyfit(frame.similarity, frame.score, 1)
    trend_line = np.poly1d(coefficients)

    # Evaluate the fit on the sorted x values and draw it as a dashed "C3" line.
    sorted_similarity = np.asarray(sorted(frame.similarity))
    ax.plot(sorted_similarity, trend_line(sorted_similarity), "C3--")


#     ax.set_xlabel('Similarity')
#     ax.set_ylabel('Score')

# Draw the score-vs-similarity scatter and its linear trend on `ax`.
plot_points(ax, points)

In [None]:
# Export `fig` as SVG, convert it to PDF with the external `rsvg-convert`
# command-line tool, then delete the intermediate SVG.
fig.savefig("score_vs_distance.svg", bbox_inches="tight", pad_inches=0)
!rsvg-convert -f pdf -o score_vs_distance.pdf score_vs_distance.svg
!rm score_vs_distance.svg

In [None]:
import plotly.express as px
from plotly.subplots import make_subplots
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Correlation over time

In [None]:
from pytorch_lightning import seed_everything
import random


# Keep only experiments whose best validation accuracy exceeds VAL_ACC_LOWER_BOUND.
VAL_ACC_LOWER_BOUND = 0.9

# Per-experiment maximum of `val_acc`. The column is named explicitly because
# the label generated by `.agg([np.max])` ("amax" vs "max") depends on the
# installed NumPy/pandas versions.
df_max_acc = stats.groupby("experiment")["val_acc"].max().to_frame("amax")
best_experiments = df_max_acc.loc[df_max_acc["amax"] > VAL_ACC_LOWER_BOUND]
best_experiments = best_experiments.reset_index().experiment
df = stats[stats["experiment"].isin(best_experiments)]
available_experiments = sorted(set(df.experiment))


# Single-panel, square figure layout.
N_ROWS = 1
N_COLS = 1
RATIO = 1

# ICLR 2023 tueplots style bundle plus the matching figure size for the grid.
plt.rcParams.update(bundles.iclr2023(usetex=True))
plt.rcParams.update(figsizes.iclr2023(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=RATIO))


def plot_score_dist_over_time(ax, df):
    """Plot validation accuracy and reference distance against epoch.

    ``val_acc`` is drawn on ``ax`` (style ``"C0-"``) and ``reference_distance``
    on a twin y-axis sharing the same x-axis (style ``"C1-"``).

    Args:
        ax: Axes-like object to draw on.
        df: Table providing the ``epoch``, ``val_acc`` and
            ``reference_distance`` columns.

    NOTE: a later cell of this notebook defines another function with the same
    name that additionally strips ticks and tick labels; that later definition
    replaces this one once its cell has been run.
    """
    ax.set_aspect("auto")
    twin_ax = ax.twinx()

    epochs = df.epoch
    ax.plot(epochs, df.val_acc, "C0-")
    twin_ax.plot(epochs, df.reference_distance, "C1-")


#     ax.set_xlabel("epochs")


# Create the figure and plot the curves of a single experiment. The index 5
# into the sorted `available_experiments` list is hard-coded, so at least six
# experiments must have passed the accuracy filter for this to work.
fig, axes = plt.subplots(nrows=N_ROWS, ncols=N_COLS, dpi=150)
plot_score_dist_over_time(axes, df.loc[df["experiment"] == available_experiments[5]])

In [None]:
# Export `fig` as SVG, convert it to PDF with the external `rsvg-convert`
# command-line tool, then delete the intermediate SVG.
fig.savefig("correlation_over_time.svg", bbox_inches="tight", pad_inches=0)
!rsvg-convert -f pdf -o correlation_over_time.pdf correlation_over_time.svg
!rm correlation_over_time.svg

In [None]:
# Plot both figures!

In [None]:
# Two panels side by side, each 0.8 times as tall as it is wide.
N_ROWS, N_COLS, RATIO = 1, 2, 0.8

# ICLR 2023 tueplots style bundle plus the matching figure size for the grid.
plt.rcParams.update(bundles.iclr2023(usetex=True))
plt.rcParams.update(
    figsizes.iclr2023(
        ncols=N_COLS,
        nrows=N_ROWS,
        height_to_width_ratio=RATIO,
    )
)

fig, (col1, col2) = plt.subplots(nrows=N_ROWS, ncols=N_COLS, dpi=150)

# Left panel: score-vs-similarity scatter with smaller markers.
plot_points(col1, points, s=1)
# Right panel: curves of the experiment at the hard-coded index 5.
plot_score_dist_over_time(col2, df.loc[df["experiment"] == available_experiments[5]])
# Widen the horizontal gap between the two panels.
plt.subplots_adjust(wspace=0.4)

In [None]:
# Export `fig` as SVG, convert it to PDF with the external `rsvg-convert`
# command-line tool, then delete the intermediate SVG.
fig.savefig("correlation_subfigure.svg", bbox_inches="tight", pad_inches=0)
!rsvg-convert -f pdf -o correlation_subfigure.pdf correlation_subfigure.svg
!rm correlation_subfigure.svg

# Correlation grid (supmat)

In [None]:
# Keep only experiments whose best validation accuracy exceeds VAL_ACC_LOWER_BOUND.
VAL_ACC_LOWER_BOUND = 0.5

# Per-experiment maximum of `val_acc`. The column is named explicitly because
# the label generated by `.agg([np.max])` ("amax" vs "max") depends on the
# installed NumPy/pandas versions.
df_max_acc = stats.groupby("experiment")["val_acc"].max().to_frame("amax")
best_experiments = df_max_acc.loc[df_max_acc["amax"] > VAL_ACC_LOWER_BOUND]
best_experiments = best_experiments.reset_index().experiment
df = stats[stats["experiment"].isin(best_experiments)]
available_experiments = sorted(set(df.experiment))
df

In [None]:
from pytorch_lightning import seed_everything
import random

seed_everything(0)
# `random.shuffle` permutes the list in place. Restore the sorted order first
# so that re-running this cell always yields the same permutation instead of
# shuffling an already-shuffled list. On a first run the list is already
# sorted, so the result is unchanged.
available_experiments = sorted(available_experiments)
random.shuffle(available_experiments)

# 10 x 10 grid of square panels.
N_ROWS = 10
N_COLS = 10
RATIO = 1

# ICLR 2023 tueplots style bundle plus the matching figure size for the grid.
plt.rcParams.update(bundles.iclr2023(usetex=True))
plt.rcParams.update(figsizes.iclr2023(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=RATIO))


# The explicit `figsize` passed here takes precedence over the rcParams default.
fig, axes = plt.subplots(nrows=N_ROWS, ncols=N_COLS, dpi=200, figsize=(15, 15))


def plot_score_dist_over_time(ax, df):
    """Plot accuracy and reference distance against epoch without any ticks.

    Compact variant for dense grids of small panels: ``val_acc`` is drawn on
    ``ax`` (style ``"C0-"``) and ``reference_distance`` on a twin y-axis (style
    ``"C1-"``), after which tick labels and ticks are cleared on both axes.

    Args:
        ax: Axes-like object to draw on.
        df: Table providing the ``epoch``, ``val_acc`` and
            ``reference_distance`` columns.
    """
    twin_ax = ax.twinx()

    epochs = df.epoch
    ax.plot(epochs, df.val_acc, "C0-")
    twin_ax.plot(epochs, df.reference_distance, "C1-")

    # Clear tick labels first, then the ticks themselves, on both y-axes and
    # on the shared x-axis.
    ax.set_yticklabels([])
    twin_ax.set_yticklabels([])
    ax.set_xticklabels([])
    ax.set_xticks([])
    ax.set_yticks([])
    twin_ax.set_yticks([])

    ax.set_aspect("auto")


# Fill the grid in row-major order, one experiment per panel. Pairing panels
# and experiments with `zip` avoids an IndexError when fewer experiments than
# panels are available; panels left without an experiment are hidden.
grid_axes = list(axes.flat)
for ax, experiment in zip(grid_axes, available_experiments):
    df_plot = df.loc[df["experiment"] == experiment]
    plot_score_dist_over_time(ax, df_plot)
for ax in grid_axes[len(available_experiments) :]:
    ax.axis("off")

In [None]:
# Export `fig` as SVG, convert it to PDF with the external `rsvg-convert`
# command-line tool, then delete the intermediate SVG.
fig.savefig("correlation_grid.svg", bbox_inches="tight", pad_inches=0)
!rsvg-convert -f pdf -o correlation_grid.pdf correlation_grid.svg
!rm correlation_grid.svg