In [None]:
import os

# All paths used later in this notebook are relative; the original cell
# unconditionally stepped one directory up (presumably to the repository
# root -- confirm against the notebook's location).
# Only step up when the config file read later is not already reachable, so
# that re-running this cell (or starting the kernel in the root directory)
# does not walk out of the project.
if not os.path.exists("config/tabrepo/gbdsim.yaml"):
    os.chdir("..")

In [None]:
from pathlib import Path
import pytorch_lightning as pl
import yaml
from pytorch_lightning import seed_everything
from torch.utils.data import DataLoader
import pickle as pkl
import torch
import numpy as np

from gbdsim.data.data_pairs_with_landmarkers_distance_generator import (
    DatasetsPairsWithLandmarkersGenerator,
)
from gbdsim.data.generator_dataset import GeneratorDataset
from gbdsim.experiment_config import ExperimentConfig
import warnings
from itertools import chain
import pandas as pd
import seaborn as sns
from scipy.stats import ks_2samp
import matplotlib.pyplot as plt

# Suppress every warning for the remainder of the session.
warnings.simplefilter(action="ignore")

# Seed the random number generators through Lightning's helper
# (the same function that is exposed as ``pl.seed_everything``).
seed_everything(1)

In [None]:
# Read the experiment YAML, then validate the parsed content with
# ExperimentConfig.
# NOTE(review): yaml.CLoader is not the safe loader variant (CSafeLoader);
# acceptable for a trusted in-repo file, but do not point it at untrusted YAML.
with Path("config/tabrepo/gbdsim.yaml").open() as config_file:
    raw_config = yaml.load(config_file, Loader=yaml.CLoader)
config = ExperimentConfig.model_validate(raw_config)

In [None]:
def _make_split(split, batch_size):
    """Build the ``(dataset, loader)`` pair for one split.

    ``split`` is forwarded verbatim to
    ``DatasetsPairsWithLandmarkersGenerator.from_paths``; ``batch_size`` is
    the batch size handed to the ``DataLoader``.
    """
    pair_generator = DatasetsPairsWithLandmarkersGenerator.from_paths(
        list(Path("data/tabrepo/datasets").iterdir()),
        Path("data/tabrepo/raw_ranks.csv"),
        Path("data/tabrepo/selected_pipelines.json"),
        split,
    )
    # NOTE(review): both splits are sized with ``train_dataset_size`` --
    # confirm this is intended for the validation split as well.
    dataset = GeneratorDataset(
        pair_generator.generate_pair_of_datasets_with_label,
        config.data.train_dataset_size,
        False,
    )
    loader = DataLoader(
        dataset,
        batch_size,
        # Identity collate: a batch stays a plain list of observations.
        collate_fn=lambda batch: batch,
        num_workers=7,
        pin_memory=True,
        # Each worker process is seeded with its own worker id.
        worker_init_fn=lambda worker_id: seed_everything(  # type: ignore
            worker_id, verbose=False
        ),
    )
    return dataset, loader


train_dataset, train_loader = _make_split(
    "train", config.data.train_batch_size
)
val_dataset, val_loader = _make_split("test", config.data.val_batch_size)

In [None]:
# Flatten the batches of each loader into one list of observations
# (train loader is fully consumed before the validation loader).
train_data = [obs for batch in train_loader for obs in batch]
val_data = [obs for batch in val_loader for obs in batch]

# The last element of every observation is converted to a Python scalar
# and used as the similarity label.
train_similarities = [observation[-1].item() for observation in train_data]
val_similarities = [observation[-1].item() for observation in val_data]

### Out-of-distribution (OOD) analysis - similarities only

In [None]:
# Compare the label (similarity) distributions of the train and test samples:
# overlaid KDE plot plus a two-sample Kolmogorov-Smirnov test.
histogram_data = pd.DataFrame(
    {
        "similarity": train_similarities + val_similarities,
        # Every sample gets as many labels as it has values. The previous
        # version used len(val_similarities) for the "train" labels, which
        # mislabels rows (or fails on a length mismatch) as soon as the two
        # samples differ in size.
        "sample": ["train"] * len(train_similarities)
        + ["test"] * len(val_similarities),
    }
)
sns.kdeplot(data=histogram_data, x="similarity", hue="sample", fill=True)
pval = ks_2samp(train_similarities, val_similarities).pvalue
plt.title(
    f"Comparison of train-test similarity distributions \n Kolmogorov-Smirnoff test p-value = {pval:.10f}"
)

### Landmarker distance vs residuals

In [None]:
def _load_final_model(pickle_path):
    """Unpickle the object stored at ``pickle_path`` and return its ``.model``.

    NOTE(review): unpickling executes arbitrary code -- only load result
    files that were produced by trusted runs.
    """
    with open(pickle_path, "rb") as handle:
        return pkl.load(handle).model


gbdsim = _load_final_model(
    "results/tabrepo/gbdsim/2025_05_24__12_58_44/final_model.pkl"
)
dataset2vec = _load_final_model(
    "results/tabrepo/dataset2vec/2025_05_24__13_07_49/final_model.pkl"
)

In [None]:
def _predict_distances(model, observations, device):
    """Run ``model.calculate_dataset_distance`` on every observation.

    The first four entries of each observation are moved to ``device`` and
    passed to the model in order; element ``[0]`` of the returned value is
    converted to a Python scalar with ``.item()``. Gradient tracking is
    disabled for the whole pass.
    """
    with torch.no_grad():
        return [
            model.calculate_dataset_distance(
                *(obs[idx].to(device) for idx in range(4))
            )[0].item()
            for obs in observations
        ]


gbdsim_similarities = _predict_distances(gbdsim, val_data, gbdsim.device)

# Inputs must be placed on the device of the model that consumes them; the
# previous version moved dataset2vec's inputs to gbdsim's device, which only
# works while both models share a device. Fall back to gbdsim's device when
# the dataset2vec object does not expose a ``device`` attribute.
dataset2vec_similarities = _predict_distances(
    dataset2vec, val_data, getattr(dataset2vec, "device", gbdsim.device)
)

In [None]:
# Apply seaborn's "whitegrid" style to the plots that follow.
sns.set_style(style="whitegrid")

In [None]:
# Residual = ground-truth label minus the GBDSim prediction, element-wise.
gbdsim_residuals = np.asarray(val_similarities) - np.asarray(
    gbdsim_similarities
)
residual_ax = sns.scatterplot(x=val_similarities, y=gbdsim_residuals)
residual_ax.set_xlabel("Ground-truth label")
residual_ax.set_ylabel("Residual")

In [None]:
# Residual = ground-truth label minus the Dataset2Vec prediction, element-wise.
dataset2vec_residuals = np.asarray(val_similarities) - np.asarray(
    dataset2vec_similarities
)
residual_ax = sns.scatterplot(x=val_similarities, y=dataset2vec_residuals)
residual_ax.set_xlabel("Ground-truth label")
residual_ax.set_ylabel("Residual")

In [None]:
# Overlay the distribution of ground-truth labels with the distribution of
# GBDSim predictions on the validation observations.
origin_labels = ["Ground-truth"] * len(val_similarities) + [
    "Predicted"
] * len(gbdsim_similarities)
hist_data = pd.DataFrame(
    {
        "similarity": [*val_similarities, *gbdsim_similarities],
        "origin": origin_labels,
    }
)
sns.kdeplot(x="similarity", hue="origin", fill=True, data=hist_data)