In [1]:
from engine.dataset2vec.data import get_preprocessing_pipeline

from engine.dataset2vec.train import LightningWrapper as D2vWrapper
from liltab.train.utils import LightningWrapper as LiltabWrapper
from pathlib import Path
from torch import Tensor
from tqdm import tqdm

import numpy as np
import pandas as pd
import torch

In [2]:
data_path = Path("data/mimic/mini_holdout")
liltab_encoder_path = "models/liltab.ckpt"
d2v_encoder_path = "models/d2v.ckpt"

In [52]:
tasks_paths = list(data_path.rglob("*test.csv"))
datasets = dict()

for task_path in tqdm(tasks_paths):
    task_name = str(task_path).split("/")[-3]
    if task_name not in datasets:
        datasets[task_name] = []
    df = pd.read_csv(task_path)
    X, y = Tensor(get_preprocessing_pipeline().fit_transform(df.iloc[:, :-1]).values), Tensor(
        df.iloc[:, -1].values
    ).reshape(-1, 1)
    datasets[task_name].append((X, y))

datasets = list(datasets.values())
n_datasets = len(datasets)

100%|██████████| 5917/5917 [01:09<00:00, 85.23it/s]


In [53]:
def get_sample(datasets):
    if np.random.uniform() <= 0.5:
        return sample_from_same_datasets(datasets)
    else:
        return sample_from_two_datasets(datasets)

def sample_from_same_datasets(datasets):
    dataset_idx = np.random.choice(n_datasets)
    dataset = datasets[dataset_idx]
    task_1_idx, task_2_idx = np.random.choice(len(dataset), size=2)
    return dataset[task_1_idx], dataset[task_2_idx], 1


def sample_from_two_datasets(datasets):
    dataset_1_idx, dataset_2_idx = np.random.choice(n_datasets, size=2)
    dataset_1, dataset_2 = datasets[dataset_1_idx], datasets[dataset_2_idx]
    task_1_idx = np.random.choice(len(dataset_1))
    task_2_idx = np.random.choice(len(dataset_2))

    return dataset_1[task_1_idx], dataset_2[task_2_idx], 0

def sample_with_random_size(arr):
    if isinstance(arr, int):
        arr = np.arange(arr)
    size = np.random.choice(np.arange(1, len(arr)))
    return np.random.choice(arr, size=size)

def index_tensor(tensor, row_idx, col_idx):
    return tensor[row_idx].T[col_idx].T

In [54]:
liltab_encoder = LiltabWrapper.load_from_checkpoint(liltab_encoder_path).model
d2v_encoder = D2vWrapper.load_from_checkpoint(d2v_encoder_path).encoder

In [55]:
(X1, y1), (X2, y2), label = sample_from_two_datasets(datasets)

In [56]:
encoders = {
    "liltab": lambda X, y: liltab_encoder.encode_support_set(X, y).mean(dim=0),
    "d2v": d2v_encoder,
}

In [57]:
encoder = encoders["d2v"]
correctness = []
for i in tqdm(range(10000)):
    (X1, y1), (X2, y2), label = get_sample(datasets)
    encoding_1 = encoder(X1, y1)
    encoding_2 = encoder(X2, y2)
    prediction = int(torch.exp(-torch.sqrt(((encoding_1 - encoding_2)**2).sum())).item() >= 0.5)
    correctness.append(prediction == label)
np.mean(correctness)

100%|██████████| 10000/10000 [00:56<00:00, 175.64it/s]


0.5614

In [58]:
import warnings
from sklearn.isotonic import IsotonicRegression

warnings.simplefilter("ignore")

In [59]:
encoder = encoders["liltab"]

probas = []
labels = []
for i in tqdm(range(1000)):
    (X1, y1), (X2, y2), label = get_sample(datasets)
    encoding_1 = encoder(X1, y1)
    encoding_2 = encoder(X2, y2)
    probas.append(
        torch.exp(-torch.sqrt(((encoding_1 - encoding_2) ** 2).sum())).item()
    )
    labels.append(label)
    correctness.append(prediction == label)
calib = IsotonicRegression().fit(np.array(probas).reshape(-1, 1), labels)

correctness = []
for i in tqdm(range(10000)):
    (X1, y1), (X2, y2), label = get_sample(datasets)
    encoding_1 = encoder(X1, y1)
    encoding_2 = encoder(X2, y2)
    prediction = int(
        calib.predict([torch.exp(-torch.sqrt(((encoding_1 - encoding_2) ** 2).sum())).item()])
        >= 0.5
    )
    correctness.append(prediction == label)
np.mean(correctness)

  1%|          | 8/1000 [00:00<00:12, 76.53it/s]

100%|██████████| 1000/1000 [00:09<00:00, 110.92it/s]
100%|██████████| 10000/10000 [01:31<00:00, 109.29it/s]


0.5605