In [1]:
from engine.dataset2vec.data import get_preprocessing_pipeline

from engine.dataset2vec.train import LightningWrapper as D2vWrapper
from liltab.train.utils import LightningWrapper as LiltabWrapper
from pathlib import Path
from torch import Tensor
from tqdm import tqdm

import numpy as np
import pandas as pd
import torch

In [2]:
data_path = Path("data/mimic/mini_holdout")
liltab_encoder_path = "models/liltab.ckpt"
d2v_encoder_path = "models/d2v.ckpt"

In [3]:
tasks_paths = list(data_path.rglob("*test.csv"))
datasets = dict()

for task_path in tqdm(tasks_paths):
    task_name = str(task_path).split("/")[-3]
    if task_name not in datasets:
        datasets[task_name] = []
    df = pd.read_csv(task_path)
    X, y = Tensor(get_preprocessing_pipeline().fit_transform(df.iloc[:, :-1]).values), Tensor(
        df.iloc[:, -1].values
    ).reshape(-1, 1)
    datasets[task_name].append((X, y))

datasets = list(datasets.values())
n_datasets = len(datasets)

100%|██████████| 5917/5917 [01:19<00:00, 74.25it/s]


In [11]:
def get_sample(dataset):
    idx = np.random.choice(len(dataset))
    return dataset[idx]

In [12]:
liltab_encoder = LiltabWrapper.load_from_checkpoint(liltab_encoder_path).model
d2v_encoder = D2vWrapper.load_from_checkpoint(d2v_encoder_path).encoder

/mnt/linux/sync/research/rethinking_encoder_warmstart/.venv/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:198: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.
/mnt/linux/sync/research/rethinking_encoder_warmstart/.venv/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:198: Attribute 'encoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['encoder'])`.


In [13]:
encoders = {
    "liltab": lambda X, y: liltab_encoder.encode_support_set(X, y).mean(dim=0),
    "d2v": d2v_encoder,
}

In [14]:
NUMBER_OF_POINTS_PER_DATASET = 100
NUMBER_OF_DATASETS = 12

In [15]:
encoder = encoders["d2v"]
embeddings_all = []
with torch.no_grad():
    for i, dataset in tqdm(enumerate(datasets)):
        for j in range(NUMBER_OF_POINTS_PER_DATASET):
            X, y = get_sample(dataset)
            if j == 0:
                embeddings_all.append([])
            embeddings_all[i].append(encoder(X, y))

12it [00:02,  5.84it/s]


In [19]:
centroids = []
for i, embeddings in enumerate(embeddings_all):
    centroids.append(torch.stack(embeddings).mean(dim=0))

centroids_stacked = torch.stack(centroids)
global_centroid = centroids_stacked.mean(dim=0)
numerator = (
    ((centroids_stacked - global_centroid) ** 2).sum()
    * NUMBER_OF_POINTS_PER_DATASET
    / NUMBER_OF_DATASETS
)

denominator = 0
for centroid in centroids:
    for embeddings in embeddings_all:
        embeddings = torch.stack(embeddings)
        denominator += ((embeddings - centroid) ** 2).sum()

denominator /= (
    NUMBER_OF_POINTS_PER_DATASET * NUMBER_OF_DATASETS - NUMBER_OF_DATASETS
)
numerator / denominator

tensor(2.1774)

In [21]:
encoder = encoders["liltab"]
embeddings_all = []
with torch.no_grad():
    for i, dataset in tqdm(enumerate(datasets)):
        for j in range(100):
            X, y = get_sample(dataset)
            if j == 0:
                embeddings_all.append([])
            embeddings_all[i].append(encoder(X, y))

12it [00:03,  3.07it/s]


In [23]:
centroids = []
for i, embeddings in enumerate(embeddings_all):
    centroids.append(torch.stack(embeddings).mean(dim=0))

centroids_stacked = torch.stack(centroids)
global_centroid = centroids_stacked.mean(dim=0)
numerator = (
    ((centroids_stacked - global_centroid) ** 2).sum()
    * NUMBER_OF_POINTS_PER_DATASET
    / NUMBER_OF_DATASETS
)

denominator = 0
for centroid in centroids:
    for embeddings in embeddings_all:
        embeddings = torch.stack(embeddings)
        denominator += ((embeddings - centroid) ** 2).sum()

denominator /= (
    NUMBER_OF_POINTS_PER_DATASET * NUMBER_OF_DATASETS - NUMBER_OF_DATASETS
)
numerator / denominator

tensor(2.2849)