In [1]:
from engine.dataset2vec.data import get_preprocessing_pipeline

from engine.dataset2vec.train import LightningWrapper as D2vWrapper
from liltab.train.utils import LightningWrapper as LiltabWrapper
from pathlib import Path
from torch import Tensor
from tqdm import tqdm

import numpy as np
import pandas as pd
import torch

In [2]:
data_path = Path("data/uci/raw")
liltab_encoder_path = "models/liltab.ckpt"
d2v_encoder_path = "models/d2v.ckpt"

In [3]:
dataframes = [pd.read_csv(data_path) for data_path in data_path.iterdir()]
datasets = [
    (
        (
            Tensor(
                get_preprocessing_pipeline()
                .fit_transform(df.iloc[:, :-1])
                .values
            ),
            Tensor(df.iloc[:, -1].values).reshape(-1, 1),
        )
    )
    for df in dataframes
]
n_datasets = len(datasets)

In [4]:
def get_sample(dataset):
    X, y = dataset
    rows_idx = sample_with_random_size(X.shape[0]).tolist()
    cols_idx = sample_with_random_size(X.shape[1]).tolist()
    return index_tensor(X, rows_idx, cols_idx), y[rows_idx]


def sample_with_random_size(arr):
    if isinstance(arr, int):
        arr = np.arange(arr)
    size = np.random.choice(np.arange(1, len(arr)))
    return np.random.choice(arr, size=size)

def index_tensor(tensor, row_idx, col_idx):
    return tensor[row_idx].T[col_idx].T

In [5]:
liltab_encoder = LiltabWrapper.load_from_checkpoint(liltab_encoder_path).model
d2v_encoder = D2vWrapper.load_from_checkpoint(d2v_encoder_path).encoder

/mnt/linux/sync/research/rethinking_encoder_warmstart/.venv/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:198: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.
/mnt/linux/sync/research/rethinking_encoder_warmstart/.venv/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:198: Attribute 'encoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['encoder'])`.


In [6]:
encoders = {
    "liltab": lambda X, y: liltab_encoder.encode_support_set(X, y).mean(dim=0),
    "d2v": d2v_encoder,
}

In [47]:
NUMBER_OF_POINTS_PER_DATASET = 100
NUMBER_OF_DATASETS = 103

In [42]:
encoder = encoders["d2v"]
embeddings_all = []
with torch.no_grad():
    for i, dataset in tqdm(enumerate(datasets)):
        for j in range(NUMBER_OF_POINTS_PER_DATASET):
            X, y = get_sample(dataset)
            if j == 0:
                embeddings_all.append([])
            embeddings_all[i].append(encoder(X, y))

103it [01:39,  1.04it/s]


In [53]:
centroids = []
for i, embeddings in enumerate(embeddings_all):
    centroids.append(torch.stack(embeddings).mean(dim=0))

centroids_stacked = torch.stack(centroids)
global_centroid = centroids_stacked.mean(dim=0)
numerator = ((centroids_stacked - global_centroid)**2).sum()*NUMBER_OF_POINTS_PER_DATASET/NUMBER_OF_DATASETS

denominator = 0
for centroid in centroids:
    for embeddings in embeddings_all:
        embeddings = torch.stack(embeddings)
        denominator += ((embeddings - centroid)**2).sum()

denominator /= (NUMBER_OF_POINTS_PER_DATASET*NUMBER_OF_DATASETS - NUMBER_OF_DATASETS)
numerator / denominator

tensor(0.3974)

In [None]:
mean_intra_distances = []
for i, embeddings in enumerate(embeddings_all):
    mean_intra_distances.append(
        torch.cdist(torch.stack(embeddings), torch.stack(embeddings))
        .mean()
        .item()
        / 2
    )
mean_intra_distance = np.mean(mean_intra_distances)

centroids = []
for i, embeddings in enumerate(embeddings_all):
    centroids.append(torch.stack(embeddings).mean(dim=0))
mean_inter_distance = (
    torch.cdist(torch.stack(centroids), torch.stack(centroids)).mean().item()
    / 2
)

mean_inter_distance / mean_intra_distance

1.9153724200362352

In [54]:
encoder = encoders["liltab"]
embeddings_all = []
with torch.no_grad():
    for i, dataset in tqdm(enumerate(datasets)):
        for j in range(100):
            X, y = get_sample(dataset)
            if j == 0:
                embeddings_all.append([])
            embeddings_all[i].append(encoder(X, y))

103it [03:31,  2.05s/it]


In [57]:
mean_intra_distances = []
for i, embeddings in enumerate(embeddings_all):
    mean_intra_distances.append(
        torch.cdist(torch.stack(embeddings), torch.stack(embeddings))
        .mean()
        .item()
        / 2
    )
mean_intra_distance = np.mean(mean_intra_distances)

centroids = []
for i, embeddings in enumerate(embeddings_all):
    centroids.append(torch.stack(embeddings).mean(dim=0))
mean_inter_distance = (
    torch.cdist(torch.stack(centroids), torch.stack(centroids)).mean().item()
    / 2
)

mean_inter_distance / mean_intra_distance

1.701639624095202

In [58]:
centroids = []
for i, embeddings in enumerate(embeddings_all):
    centroids.append(torch.stack(embeddings).mean(dim=0))

centroids_stacked = torch.stack(centroids)
global_centroid = centroids_stacked.mean(dim=0)
numerator = (
    ((centroids_stacked - global_centroid) ** 2).sum()
    * NUMBER_OF_POINTS_PER_DATASET
    / NUMBER_OF_DATASETS
)

denominator = 0
for centroid in centroids:
    for embeddings in embeddings_all:
        embeddings = torch.stack(embeddings)
        denominator += ((embeddings - centroid) ** 2).sum()

denominator /= (
    NUMBER_OF_POINTS_PER_DATASET * NUMBER_OF_DATASETS - NUMBER_OF_DATASETS
)
numerator / denominator

tensor(0.3513)