In [None]:
import os

# Move from the notebook directory (presumably one level below the repository
# root) to the root, so the relative paths used in this notebook
# ("data/uci/...", "results/...") resolve. Guarded so that re-running this cell
# does not keep walking further up the directory tree.
if not os.path.isdir("data/uci"):
    os.chdir("..")
assert os.path.isdir("data/uci"), f"expected data/uci under {os.getcwd()}"

In [None]:
import pickle as pkl
from gbdsim.data.data_pairs_generator import DatasetsPairsGenerator
from pathlib import Path
import json
import tqdm
import torch
import pytorch_lightning as pl
import numpy as np
from sklearn.metrics import confusion_matrix
from itertools import chain
from collections import Counter


# Global seed (Python `random`, NumPy, PyTorch) so the sampled pairs are repeatable.
SEED = 123
pl.seed_everything(SEED)

In [None]:
# Pickled model object used for inference in this notebook.
# NOTE: unpickling can execute arbitrary code — only load files you trust.
MODEL_PATH = Path("results/uci/gbdsim/2025_04_13__17_31_46/final_model.pkl")
with MODEL_PATH.open("rb") as model_file:
    model = pkl.load(model_file)

In [None]:
with open("data/uci/meta_split.json", "r") as f:
    meta_split = json.load(f)

# Number of dataset pairs to draw for the error analysis.
N_OBSERVATIONS = 1000

val_dataset_names = set(meta_split["val"])
# Sort the candidates: `Path.iterdir()` yields entries in arbitrary,
# filesystem-dependent order, so without sorting the (seeded) pair sampling
# would pick different pairs on different machines.
val_paths_with_data = sorted(
    path
    for path in Path("data/uci/raw").iterdir()
    if path.stem in val_dataset_names
)
assert val_paths_with_data, "no validation datasets found in data/uci/raw"

# Build the generator once and reuse it, instead of rebuilding it from the same
# paths for every sample. NOTE(review): assumes `from_paths` has no per-call
# side effects that the sampling relies on — confirm.
pairs_generator = DatasetsPairsGenerator.from_paths(val_paths_with_data)
observations = [
    pairs_generator.generate_pair_of_datasets_with_label(
        return_datasets_paths=True
    )
    for _ in tqdm.trange(N_OBSERVATIONS)
]

In [None]:
# CUDA when available (the original hard-coded `.cuda()`), otherwise CPU;
# assumes the unpickled model lives on that same device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A pickled network keeps its `training` flag; switch dropout / batch-norm to
# inference behaviour if the wrapped model is a torch module.
if isinstance(model.model, torch.nn.Module):
    model.model.eval()

with torch.no_grad():
    # The first four elements of each observation are the model inputs.
    probability_batches = [
        model.model.calculate_dataset_origin_probability(
            *(part.to(DEVICE) for part in obs[:4])
        )
        for obs in tqdm.tqdm(observations)
    ]

# Flatten to exactly one probability per observation so later comparisons with
# the 1-D `labels` array cannot silently broadcast into a 2-D matrix.
probabilities = (
    torch.concat(probability_batches, dim=0).cpu().numpy().reshape(-1)
)
assert probabilities.shape == (len(observations),), probabilities.shape

In [None]:
# Threshold the origin probabilities into hard 0/1 predictions.
DECISION_THRESHOLD = 0.5
predictions = (probabilities > DECISION_THRESHOLD).astype(int)

# Ground-truth label (element 4) and the file stems of both datasets'
# paths (elements 5 and 6) for every observation.
labels = np.array([observation[4] for observation in observations])
dataset_names = [
    (observation[5].stem, observation[6].stem) for observation in observations
]
# The first four elements of each observation are the tensors fed to the model.
datasets = [observation[:4] for observation in observations]

#### Confusion matrix

In [None]:
# Rows are true labels, columns are predictions, both ordered [0, 1]. Passing
# `labels` keeps the matrix 2x2 even if a class is absent from this sample.
confusion_matrix(labels, predictions, labels=[0, 1])

In [None]:
# Boolean mask of correct predictions; its complement marks the errors.
is_correct = predictions == labels
missclassified_idx = np.nonzero(~is_correct)[0]
properly_classified_idx = np.nonzero(is_correct)[0]

In [None]:
missclassified_dataset_pairs = [
    dataset_names[idx] for idx in missclassified_idx
]

# Partition the misclassified (name, name) pairs by whether both sides share
# the same dataset file stem, preserving the original order in each bucket.
missclassified_pairs_where_labels_are_different = []
missclassified_pairs_where_labels_are_same = []
for name_pair in missclassified_dataset_pairs:
    first_name, second_name = name_pair
    bucket = (
        missclassified_pairs_where_labels_are_same
        if first_name == second_name
        else missclassified_pairs_where_labels_are_different
    )
    bucket.append(name_pair)

#### Most problematic datasets

In [None]:
# Each misclassified cross-dataset pair adds one count to each of its two
# datasets; show the 30 datasets involved in the most such errors.
Counter(
    chain.from_iterable(missclassified_pairs_where_labels_are_different)
).most_common(30)

In [None]:
# Both names in these pairs are identical, so count each misclassified pair
# once (by its first name). Chaining both elements would double every count
# and make the numbers incomparable with the cross-dataset cell above.
Counter(
    first_name for first_name, _ in missclassified_pairs_where_labels_are_same
).most_common(30)

#### Dimension difference (row-count ratio)

In [None]:
missclassified_datasets = [datasets[idx] for idx in missclassified_idx]
# Row counts of the two model-input tensors at elements 0 and 2 of each pair;
# larger / smaller gives a ratio >= 1 that ignores the pair's order.
row_counts = [
    (pair[0].shape[0], pair[2].shape[0]) for pair in missclassified_datasets
]
row_count_ratios = [max(counts) / min(counts) for counts in row_counts]
np.mean(row_count_ratios), np.std(row_count_ratios)

In [None]:
properly_classified_datasets = [
    datasets[idx] for idx in properly_classified_idx
]
# Row counts of the two model-input tensors at elements 0 and 2 of each pair;
# larger / smaller gives a ratio >= 1 that ignores the pair's order.
row_counts = [
    (pair[0].shape[0], pair[2].shape[0])
    for pair in properly_classified_datasets
]
row_count_ratios = [max(counts) / min(counts) for counts in row_counts]
np.mean(row_count_ratios), np.std(row_count_ratios)

#### Statistics difference (ratio of mean values)

In [None]:
missclassified_datasets = [datasets[idx] for idx in missclassified_idx]

# Symmetric ratio of the absolute pooled means (`.mean()` over all elements) of
# the tensors at elements 0 and 2 (presumably the features). Absolute values
# keep the ratio meaningful when a mean is negative: the signed version can
# flip sign or blow up near -eps. Identical to the signed version for
# non-negative means. eps guards against division by a (near-)zero mean.
MEAN_RATIO_EPS = 1e-2
mean_pairs = [
    (abs(float(pair[0].mean())), abs(float(pair[2].mean())))
    for pair in missclassified_datasets
]
mean_ratios = [
    max(a / (b + MEAN_RATIO_EPS), b / (a + MEAN_RATIO_EPS))
    for a, b in mean_pairs
]
np.mean(mean_ratios), np.std(mean_ratios)

In [None]:
properly_classified_datasets = [
    datasets[idx] for idx in properly_classified_idx
]

# Symmetric ratio of the absolute pooled means (`.mean()` over all elements) of
# the tensors at elements 0 and 2 (presumably the features). Absolute values
# keep the ratio meaningful when a mean is negative: the signed version can
# flip sign or blow up near -eps. Identical to the signed version for
# non-negative means. eps guards against division by a (near-)zero mean.
MEAN_RATIO_EPS = 1e-2
mean_pairs = [
    (abs(float(pair[0].mean())), abs(float(pair[2].mean())))
    for pair in properly_classified_datasets
]
mean_ratios = [
    max(a / (b + MEAN_RATIO_EPS), b / (a + MEAN_RATIO_EPS))
    for a, b in mean_pairs
]
np.mean(mean_ratios), np.std(mean_ratios)

#### Priors difference

In [None]:
missclassified_datasets = [datasets[idx] for idx in missclassified_idx]

# Symmetric ratio of the absolute pooled means of the tensors at elements 1
# and 3 (presumably the targets, i.e. a proxy for the label priors). Absolute
# values keep the ratio meaningful for negative means and are a no-op for
# non-negative ones. eps guards against division by a (near-)zero mean.
MEAN_RATIO_EPS = 1e-2
prior_pairs = [
    (abs(float(pair[1].mean())), abs(float(pair[3].mean())))
    for pair in missclassified_datasets
]
prior_ratios = [
    max(a / (b + MEAN_RATIO_EPS), b / (a + MEAN_RATIO_EPS))
    for a, b in prior_pairs
]
np.mean(prior_ratios), np.std(prior_ratios)

In [None]:
properly_classified_datasets = [
    datasets[idx] for idx in properly_classified_idx
]

# Symmetric ratio of the absolute pooled means of the tensors at elements 1
# and 3 (presumably the targets, i.e. a proxy for the label priors). Absolute
# values keep the ratio meaningful for negative means and are a no-op for
# non-negative ones. eps guards against division by a (near-)zero mean.
MEAN_RATIO_EPS = 1e-2
prior_pairs = [
    (abs(float(pair[1].mean())), abs(float(pair[3].mean())))
    for pair in properly_classified_datasets
]
prior_ratios = [
    max(a / (b + MEAN_RATIO_EPS), b / (a + MEAN_RATIO_EPS))
    for a, b in prior_pairs
]
np.mean(prior_ratios), np.std(prior_ratios)