In [None]:
import os

# Later cells open paths relative to the working directory
# ("config/synthetic/..." and "results/synthetic/..."), so step up to the
# parent directory where those folders are expected to live.
# The guard keeps the cell idempotent: re-running it, or starting the kernel
# in a directory that already contains "config", no longer climbs one level
# too far.
if not os.path.isdir("config"):
    os.chdir("..")

In [None]:
import pytorch_lightning as pl
import pickle as pkl
import yaml
from pytorch_lightning import seed_everything
from torch.utils.data import DataLoader
from gbdsim.experiment_config import ExperimentConfig
import torch
import numpy as np

from gbdsim.causal.data import generate_synthetic_causal_data_example
from gbdsim.data.generator_dataset import GeneratorDataset
from itertools import chain
from scipy.stats import ks_2samp
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from gbdsim.causal.factory import MlpCausalGraphFractory
from gbdsim.causal.similarity import calculate_graph_distance
from tqdm import tqdm

# Fix the global seed once, up front, so the randomly generated data below is
# repeatable between runs of the notebook.
seed_everything(seed=1)

In [None]:
# Load and validate the experiment configuration.
# PyYAML only exposes the C-accelerated CLoader when it was built against
# libyaml; fall back to the pure-Python Loader (same semantics) instead of
# failing with AttributeError on such installations.
# NOTE: neither loader is a "safe" loader -- both can construct arbitrary
# Python objects, so only point this at trusted config files.
yaml_loader = getattr(yaml, "CLoader", yaml.Loader)
with open("config/synthetic/gbdsim.yaml") as f:
    config = ExperimentConfig.model_validate(yaml.load(f, Loader=yaml_loader))

### OOD Analysis - similarities only

In [None]:
# Train and validation loaders share every DataLoader option except the batch
# size; the two GeneratorDataset instances differ only in their third
# positional argument (False for train, True for validation).
# NOTE(review): with num_workers=1 the single worker of *both* loaders is
# seeded with worker id 0 -- confirm that GeneratorDataset's third argument is
# what makes the two splits differ.
shared_loader_kwargs = dict(
    collate_fn=lambda batch: batch,  # identity: batches are passed through as-is
    num_workers=1,
    pin_memory=True,
    worker_init_fn=lambda worker_id: seed_everything(worker_id, verbose=False),  # type: ignore # noqa: E501
)

train_dataset = GeneratorDataset(
    generate_synthetic_causal_data_example, 1000, False
)
train_loader = DataLoader(
    train_dataset,
    batch_size=config.data.train_batch_size,
    **shared_loader_kwargs,
)

val_dataset = GeneratorDataset(
    generate_synthetic_causal_data_example, 1000, True
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config.data.val_batch_size,
    **shared_loader_kwargs,
)

In [None]:
# Drain each loader once and flatten its batches into a single list of
# observations.
train_data = list(chain.from_iterable(train_loader))
val_data = list(chain.from_iterable(val_loader))

# The last element of every observation is read out as a Python scalar; it is
# used as the ground-truth similarity label in the plots below.
train_similarities = [sample[-1].item() for sample in train_data]
val_similarities = [sample[-1].item() for sample in val_data]

In [None]:
# Overlay the KDEs of the train and validation similarity labels and report
# the two-sample Kolmogorov-Smirnov p-value in the figure title.
# The validation split is labelled "test" in the legend, matching the title.
histogram_data = pd.DataFrame(
    {
        "similarity": train_similarities + val_similarities,
        "sample": ["train"] * len(train_similarities)
        + ["test"] * len(val_similarities),
    }
)
sns.kdeplot(data=histogram_data, x="similarity", hue="sample", fill=True)
pval = ks_2samp(train_similarities, val_similarities).pvalue
# Title spelling fixed: the test is named after Smirnov, not "Smirnoff".
plt.title(
    "Comparison of train-test similarity distributions \n"
    f" Kolmogorov-Smirnov test p-value = {pval:.3f}"
)

### OOD Analysis - graph distance between samples

In [None]:
# Draw (causal graph, generated data) pairs for a "train" and a "val" sample.
# NOTE(review): 2000 pairs are generated per sample but only the first 100 are
# kept. The over-generation is preserved here on purpose so the sequence of
# generator calls is unchanged; generating only 100 would be much cheaper if
# exact reproducibility of the kept graphs is not required.
n_generated = 2000
n_kept = 100


def _sample_graph_with_data():
    """Generate one causal graph and return it with a dataset drawn from it."""
    graph = MlpCausalGraphFractory.generate_causal_graph()
    return graph, graph.generate_data()


train_datasets = [_sample_graph_with_data() for _ in range(n_generated)]
val_datasets = [_sample_graph_with_data() for _ in range(n_generated)]
train_datasets = train_datasets[:n_kept]
val_datasets = val_datasets[:n_kept]

In [None]:
# Pairwise graph distances within the train sample (upper triangle, i <= j).
# NOTE(review): i <= j keeps the i == j self-pairs; if
# calculate_graph_distance(g, g) is 0 this adds one zero per graph to the
# intra-sample distributions -- confirm that this is intended.
intra_train_distances = [
    calculate_graph_distance(
        train_datasets[i][0].nx_graph, train_datasets[j][0].nx_graph
    ).item()
    for i in tqdm(range(len(train_datasets)))
    for j in range(len(train_datasets))
    if i <= j
]

# Same computation within the validation ("test") sample.
intra_test_distances = [
    calculate_graph_distance(
        val_datasets[i][0].nx_graph, val_datasets[j][0].nx_graph
    ).item()
    for i in tqdm(range(len(val_datasets)))
    for j in range(len(val_datasets))
    if i <= j
]

# Distances for every (train graph, validation graph) pair.
# Each index now ranges over the list it actually indexes: previously i ran
# over len(val_datasets) while indexing train_datasets (and j the reverse),
# which raises IndexError or silently drops pairs as soon as the two samples
# differ in size. For equally sized samples the result is unchanged.
inter_distances = [
    calculate_graph_distance(
        train_datasets[i][0].nx_graph, val_datasets[j][0].nx_graph
    ).item()
    for i in tqdm(range(len(train_datasets)))
    for j in range(len(val_datasets))
]

In [None]:
# Long-format table for the distance KDE plot. Only the first half of the
# between-sample distances is included.
# Note that the value column is still called "similarity" although it holds
# graph distances; the plotting cell refers to it by that name.
inter_subset = inter_distances[: len(inter_distances) // 2]
group_labels = (
    ["train"] * len(intra_train_distances)
    + ["test"] * len(intra_test_distances)
    + ["between_samples"] * len(inter_subset)
)
histogram_data = pd.DataFrame(
    {
        "similarity": [
            *intra_train_distances,
            *intra_test_distances,
            *inter_subset,
        ],
        "distance": group_labels,
    }
)

In [None]:
# One filled KDE per distance group (train / test / between_samples).
kde_options = {"x": "similarity", "hue": "distance", "fill": True}
sns.kdeplot(data=histogram_data, **kde_options)

In [None]:
# Two-sample KS test: within-train distances vs. within-test distances.
intra_ks_result = ks_2samp(intra_train_distances, intra_test_distances)
intra_ks_result.pvalue

In [None]:
# Two-sample KS test: between-sample distances vs. within-test distances.
# Uses the full inter_distances list, not the half shown in the KDE plot.
inter_ks_result = ks_2samp(inter_distances, intra_test_distances)
inter_ks_result.pvalue

### GED vs performance

In [None]:
def _load_final_model(path):
    """Unpickle the object stored at ``path`` and return its ``model`` attribute."""
    # NOTE: pickle.load can execute arbitrary code embedded in the file; only
    # use it on result files you produced yourself or otherwise trust.
    with open(path, "rb") as handle:
        return pkl.load(handle).model


gbdsim = _load_final_model(
    "results/synthetic/gbdsim/2025_04_20__16_50_11/final_model.pkl"
)
dataset2vec = _load_final_model(
    "results/synthetic/dataset2vec/2025_04_20__16_20_36/final_model.pkl"
)

In [None]:
# Score every validation observation with both models (no gradients needed).
# The first four elements of an observation are the model inputs; the last
# one is the ground-truth label extracted earlier.
# NOTE(review): the models are used in whatever train/eval mode they were
# pickled in -- confirm whether an explicit .eval() is needed.
with torch.no_grad():
    # gbdsim's result is indexed [0][0] before .item(), dataset2vec's only [0].
    gbdsim_similarities = [
        gbdsim.calculate_dataset_distance(
            obs[0].to(gbdsim.device),
            obs[1].to(gbdsim.device),
            obs[2].to(gbdsim.device),
            obs[3].to(gbdsim.device),
        )[0][0].item()
        for obs in val_data
    ]

    # Move the inputs to dataset2vec's own device. They were previously sent
    # to gbdsim.device, which breaks when the two models live on different
    # devices. If dataset2vec exposes no `device` attribute, keep the old
    # behaviour and reuse gbdsim's device.
    dataset2vec_device = getattr(dataset2vec, "device", gbdsim.device)
    dataset2vec_similarities = [
        dataset2vec.calculate_dataset_distance(
            obs[0].to(dataset2vec_device),
            obs[1].to(dataset2vec_device),
            obs[2].to(dataset2vec_device),
            obs[3].to(dataset2vec_device),
        )[0].item()
        for obs in val_data
    ]

In [None]:
# Switch seaborn to the "whitegrid" style for the figures drawn from here on.
sns.set_style(style="whitegrid")

In [None]:
# Residuals of the gbdsim predictions (label minus prediction) plotted against
# the ground-truth label.
gbdsim_residuals = np.asarray(val_similarities) - np.asarray(gbdsim_similarities)
ax = sns.scatterplot(x=val_similarities, y=gbdsim_residuals)
ax.set_xlabel("Ground-truth label")
ax.set_ylabel("Residual")

In [None]:
# Compare the distribution of ground-truth labels with the distribution of
# gbdsim predictions on the validation observations.
origin_labels = ["Ground-truth"] * len(val_similarities) + ["Predicted"] * len(
    gbdsim_similarities
)
hist_data = pd.DataFrame(
    {
        "similarity": val_similarities + gbdsim_similarities,
        "origin": origin_labels,
    }
)
sns.kdeplot(data=hist_data, x="similarity", hue="origin", fill=True)

In [None]:
# gbdsim prediction plotted directly against the ground-truth label.
gbdsim_predictions = np.asarray(gbdsim_similarities)
ax = sns.scatterplot(x=val_similarities, y=gbdsim_predictions)
ax.set_xlabel("Ground-truth label")
ax.set_ylabel("Prediction")

In [None]:
# Residuals of the dataset2vec predictions (label minus prediction) plotted
# against the ground-truth label.
dataset2vec_residuals = np.asarray(val_similarities) - np.asarray(
    dataset2vec_similarities
)
ax = sns.scatterplot(x=val_similarities, y=dataset2vec_residuals)
ax.set_xlabel("Ground-truth label")
ax.set_ylabel("Residual")