In [10]:
from engine.dataset2vec.data import get_preprocessing_pipeline

from engine.dataset2vec.train import LightningWrapper as D2vWrapper
from liltab.train.utils import LightningWrapper as LiltabWrapper
from pathlib import Path
from sklearn.metrics import calinski_harabasz_score
from torch import Tensor
from tqdm import tqdm

import numpy as np
import pandas as pd
import torch

In [2]:
data_path = Path("data/uci/raw")
liltab_encoder_path = "models/liltab.ckpt"
d2v_encoder_path = "models/d2v.ckpt"

In [3]:
dataframes = [pd.read_csv(data_path) for data_path in data_path.iterdir()]
datasets = [
    (
        (
            Tensor(
                get_preprocessing_pipeline()
                .fit_transform(df.iloc[:, :-1])
                .values
            ),
            Tensor(df.iloc[:, -1].values).reshape(-1, 1),
        )
    )
    for df in dataframes
]
n_datasets = len(datasets)

In [4]:
def get_sample(dataset):
    X, y = dataset
    rows_idx = sample_with_random_size(X.shape[0]).tolist()
    cols_idx = sample_with_random_size(X.shape[1]).tolist()
    return index_tensor(X, rows_idx, cols_idx), y[rows_idx]


def sample_with_random_size(arr):
    if isinstance(arr, int):
        arr = np.arange(arr)
    size = np.random.choice(np.arange(1, len(arr)))
    return np.random.choice(arr, size=size)

def index_tensor(tensor, row_idx, col_idx):
    return tensor[row_idx].T[col_idx].T

In [5]:
liltab_encoder = LiltabWrapper.load_from_checkpoint(liltab_encoder_path).model
d2v_encoder = D2vWrapper.load_from_checkpoint(d2v_encoder_path).encoder

/home/dawid/miniconda3/envs/encoders/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:198: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.
/home/dawid/miniconda3/envs/encoders/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:198: Attribute 'encoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['encoder'])`.


In [6]:
encoders = {
    "liltab": lambda X, y: liltab_encoder.encode_support_set(X, y).mean(dim=0),
    "d2v": d2v_encoder,
}

In [7]:
NUMBER_OF_POINTS_PER_DATASET = 100
NUMBER_OF_DATASETS = 103

In [9]:
encoder = encoders["d2v"]
embeddings_all = []
labels = []
with torch.no_grad():
    for i, dataset in tqdm(enumerate(datasets)):
        for _ in range(NUMBER_OF_POINTS_PER_DATASET):
            X, y = get_sample(dataset)
            embeddings_all.append(encoder(X, y))
        labels.extend([i]*NUMBER_OF_POINTS_PER_DATASET)
embeddings_all = torch.stack(embeddings_all).numpy()

103it [00:28,  3.64it/s]


In [11]:
calinski_harabasz_score(embeddings_all, labels)

244.63536816174695

In [12]:
encoder = encoders["liltab"]
embeddings_all = []
labels = []
with torch.no_grad():
    for i, dataset in tqdm(enumerate(datasets)):
        for _ in range(NUMBER_OF_POINTS_PER_DATASET):
            X, y = get_sample(dataset)
            embeddings_all.append(encoder(X, y))
        labels.extend([i]*NUMBER_OF_POINTS_PER_DATASET)
embeddings_all = torch.stack(embeddings_all).numpy()

103it [01:08,  1.49it/s]


In [13]:
calinski_harabasz_score(embeddings_all, labels)

131.92549155749902