In [2]:
from engine.dataset2vec.data import get_preprocessing_pipeline

from engine.dataset2vec.train import LightningWrapper as D2vWrapper
from liltab.train.utils import LightningWrapper as LiltabWrapper
from sklearn.metrics import calinski_harabasz_score
from pathlib import Path
from torch import Tensor
from tqdm import tqdm

import numpy as np
import pandas as pd
import torch

In [3]:
# Root directory that is searched recursively for per-task `*test.csv` files.
data_path = Path("data/mimic/mini_holdout")
# Lightning checkpoint files from which the two encoders are restored.
liltab_encoder_path = "models/liltab.ckpt"
d2v_encoder_path = "models/d2v.ckpt"

In [4]:
# Collect every test split below `data_path`. `rglob` yields files in
# filesystem order, so sort to make the order of tasks (and therefore the
# integer labels assigned to them later) deterministic across machines.
tasks_paths = sorted(data_path.rglob("*test.csv"))
datasets_by_task = {}

for task_path in tqdm(tasks_paths):
    # The task name is the third path component from the end. Using `parts`
    # instead of splitting on "/" keeps this independent of the OS separator.
    task_name = task_path.parts[-3]
    df = pd.read_csv(task_path)
    # Every column except the last is preprocessed and used as X;
    # the last column becomes y as a single-column tensor.
    X = Tensor(get_preprocessing_pipeline().fit_transform(df.iloc[:, :-1]).values)
    y = Tensor(df.iloc[:, -1].values).reshape(-1, 1)
    datasets_by_task.setdefault(task_name, []).append((X, y))

# One entry per task: the list of (X, y) tensor pairs loaded for that task.
datasets = list(datasets_by_task.values())
n_datasets = len(datasets)

100%|██████████| 5915/5915 [00:22<00:00, 264.05it/s]


In [5]:
def get_sample(dataset):
    """Return one element of `dataset`, picked uniformly at random.

    `dataset` must support `len()` and integer indexing. The draw uses
    NumPy's global random state.
    """
    return dataset[np.random.choice(len(dataset))]

In [6]:
# Restore both trained Lightning wrappers and keep only the encoder networks.
# The modules are used purely for inference below, so switch them to eval mode:
# layers such as dropout or batch norm (if the models contain any) then behave
# deterministically. `nn.Module.eval()` returns the module itself.
liltab_encoder = LiltabWrapper.load_from_checkpoint(liltab_encoder_path).model.eval()
d2v_encoder = D2vWrapper.load_from_checkpoint(d2v_encoder_path).encoder.eval()

/home/dawid/miniconda3/envs/encoders/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:198: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.
/home/dawid/miniconda3/envs/encoders/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:198: Attribute 'encoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['encoder'])`.


In [7]:
def _encode_with_liltab(X, y):
    """Encode (X, y) with the liltab model and average the result over dim 0."""
    support_encoding = liltab_encoder.encode_support_set(X, y)
    return support_encoding.mean(dim=0)


# Maps an encoder name to a callable that takes (X, y) and returns an embedding.
encoders = dict(liltab=_encode_with_liltab, d2v=d2v_encoder)

In [8]:
# Number of embeddings sampled for each dataset group in the evaluation loops.
NUMBER_OF_POINTS_PER_DATASET = 100
# NOTE(review): not referenced in any visible cell; presumably the expected
# number of dataset groups (the loops below report 12 iterations) — confirm or remove.
NUMBER_OF_DATASETS = 12

In [9]:

# Embed NUMBER_OF_POINTS_PER_DATASET randomly drawn (X, y) pairs per dataset
# group with the d2v encoder; the group index is used as the cluster label.
encoder = encoders["d2v"]
embeddings_all = []
labels = []
with torch.no_grad():
    # Wrap the list itself (not the enumerate iterator) so tqdm knows the
    # total and can render a real progress bar instead of a bare counter.
    for i, dataset in enumerate(tqdm(datasets)):
        for _ in range(NUMBER_OF_POINTS_PER_DATASET):
            X, y = get_sample(dataset)
            embeddings_all.append(encoder(X, y))
        labels.extend([i] * NUMBER_OF_POINTS_PER_DATASET)
# Stack the per-sample embeddings into one NumPy array for scikit-learn.
embeddings_all = torch.stack(embeddings_all).numpy()

12it [00:00, 21.11it/s]


In [14]:
# Calinski-Harabasz score of the current `embeddings_all`, using the source
# dataset index of each embedding as its cluster label.
# NOTE(review): in the cell order shown these are the d2v embeddings — confirm,
# since the execution counts are not consecutive.
calinski_harabasz_score(embeddings_all, labels)

61.861518255224055

In [16]:
# Embed NUMBER_OF_POINTS_PER_DATASET randomly drawn (X, y) pairs per dataset
# group with the liltab encoder; the group index is used as the cluster label.
encoder = encoders["liltab"]
embeddings_all = []
labels = []
with torch.no_grad():
    # Wrap the list itself (not the enumerate iterator) so tqdm knows the
    # total and can render a real progress bar instead of a bare counter.
    for i, dataset in enumerate(tqdm(datasets)):
        # Use the shared constant rather than a literal 100 so the number of
        # embeddings always matches the number of labels appended below.
        for _ in range(NUMBER_OF_POINTS_PER_DATASET):
            X, y = get_sample(dataset)
            embeddings_all.append(encoder(X, y))
        labels.extend([i] * NUMBER_OF_POINTS_PER_DATASET)
# Stack the per-sample embeddings into one NumPy array for scikit-learn.
embeddings_all = torch.stack(embeddings_all).numpy()

12it [00:01, 10.09it/s]


In [17]:
# Calinski-Harabasz score of the current `embeddings_all`, using the source
# dataset index of each embedding as its cluster label.
# NOTE(review): in the cell order shown these are the liltab embeddings — confirm,
# since `embeddings_all` is overwritten by each embedding cell.
calinski_harabasz_score(embeddings_all, labels)

70.9527387974541