In [None]:
import logging

import pandas as pd
import torch
import torch.nn.functional as F
from rae import PROJECT_ROOT
import random

try:
    # StrEnum is in the standard library from Python 3.11 on; older versions use the backport.
    from enum import StrEnum
except ImportError:
    from backports.strenum import StrEnum
from pytorch_lightning import seed_everything
import matplotlib.pyplot as plt
from tueplots import bundles
from tueplots import figsizes
import random

# Device the encoders run on and the CKA metric is computed on (see `image_encode` / `evaluate_retrieval`).
DEVICE: str = "cuda"

In [None]:
from pathlib import Path

# Root directory of this experiment; the encoded dataset is cached underneath it.
EXPERIMENT_DIR: Path = PROJECT_ROOT / "experiments" / "fig:latent-rotation-comparison"
# Dataset identifier passed to `load_dataset`, the split to load, and the fraction of that split to keep.
dataset_name: str = "cifar10"
dataset_split: str = "train"
dataset_perc: float = 0.4
# Name of the label feature in the dataset.
label_key: str = "label"
# Key used in cache directory and output file names.
DATASET_KEY = f"{dataset_name}_{dataset_split}_{dataset_perc}"

In [None]:
import timm

# timm model names to encode the images with; each encoder's output is stored
# in a dataset column named after the model.
ENCODERS = (
    "vit_base_patch16_224",
    #     "rexnet_100",
    #     "vit_base_patch16_384",
    "vit_small_patch16_224",
)

In [None]:
VISION_DATASET_DIR: Path = EXPERIMENT_DIR / "encoded_data" / DATASET_KEY
VISION_DATASET_DIR

In [None]:
from typing import Sequence, List
from PIL.Image import Image


def encode_field(batch, src_field: str, tgt_field: str, transformation):
    """Transform one field of ``batch`` and return it under a new key.

    Args:
        batch: mapping-like object indexed by field name.
        src_field: key of the field read from ``batch``.
        tgt_field: key under which the transformed value is returned.
        transformation: callable applied to ``batch[src_field]``.

    Returns:
        A single-entry dict ``{tgt_field: transformation(batch[src_field])}``.
    """
    return {tgt_field: transformation(batch[src_field])}


@torch.no_grad()
def image_encode(images: Sequence[Image], transform, encoder):
    """Encode a batch of PIL images with ``encoder``.

    Every image is converted to RGB and passed through ``transform``; the
    results are stacked along a new leading dimension, moved to ``DEVICE`` and
    fed to ``encoder``. Gradient tracking is disabled for the whole call.

    Returns:
        A list built by iterating over the first axis of the encoder output
        after it was moved to the CPU and converted to NumPy.
    """
    batch = torch.stack([transform(image.convert("RGB")) for image in images], dim=0)
    encoding = encoder(batch.to(DEVICE))
    return list(encoding.cpu().numpy())

In [None]:
from tqdm import tqdm
import functools
from timm.data import resolve_data_config
from datasets import load_dataset, load_from_disk, Dataset

from timm.data import create_transform

# Reuse the dataset cached on disk when it exists; set USE_CACHED to False to
# rebuild the subset from the original dataset.
USE_CACHED: bool = True
if not VISION_DATASET_DIR.exists() or not USE_CACHED:

    def get_dataset(split: str, perc: float):
        """Load `dataset_name` and keep a random fraction `perc` (0 < perc <= 1) of `split`."""
        assert 0 < perc <= 1
        dataset = load_dataset(
            dataset_name,
            split=split,
            use_auth_token=True,
        )
        # Seed before shuffling so that the selected subset is reproducible.
        seed_everything(42)

        # Select a random subset
        indices = list(range(len(dataset)))
        random.shuffle(indices)
        # Sanity check on the shuffle: the expected values are hard-coded, so this
        # presumably only holds for the dataset/split it was recorded with.
        assert indices[:10] == [9926, 35283, 27382, 36541, 23204, 30508, 24079, 48121, 43668, 43464], indices[:10]
        indices = indices[: int(len(indices) * perc)]
        dataset = dataset.select(indices)

        return dataset

    data: Dataset = get_dataset(split=dataset_split, perc=dataset_perc)
else:
    data: Dataset = load_from_disk(dataset_path=str(VISION_DATASET_DIR))

print(data)

# Only encoders whose column is not yet in the dataset are computed, unless
# FORCE_RECOMPUTE is set.
FORCE_RECOMPUTE: bool = False
missing_encoders = [encoder for encoder in ENCODERS if FORCE_RECOMPUTE or encoder not in data.column_names]
for encoder_name in tqdm(missing_encoders):
    # The encodings are stored in a column named after the encoder.
    tgt_field: str = encoder_name
    # Frozen, eval-mode pretrained model on DEVICE.
    encoder = timm.create_model(encoder_name, pretrained=True, num_classes=0).requires_grad_(False).eval().to(DEVICE)
    # Preprocessing transform built from the model's own data configuration.
    config = resolve_data_config({}, model=encoder)
    transform = create_transform(**config)

    # Encode the "img" field in batches of 64 and add the result as a new column.
    data = data.map(
        functools.partial(
            encode_field,
            src_field="img",
            tgt_field=tgt_field,
            transformation=functools.partial(
                image_encode,
                transform=transform,
                encoder=encoder,
            ),
        ),
        num_proc=1,
        batched=True,
        batch_size=64,
        desc=f"{encoder_name}",
    )
    # Move the model off DEVICE before the next encoder is loaded.
    encoder = encoder.cpu()

    # Persist after every encoder so finished encodings survive an interruption.
    data.save_to_disk(str(VISION_DATASET_DIR))

# Add each row's position in `data` as an "index" column (once), so that a
# sample can be traced back to its row after filtering/shuffling.
if "index" not in data.column_names:
    data = data.map(lambda x, index: {"index": index}, with_indices=True)
    data.save_to_disk(str(VISION_DATASET_DIR))

# Return the encoder columns as torch tensors while keeping the other columns available.
data.set_format(columns=ENCODERS, output_all_columns=True, type="torch")

In [None]:
data

In [None]:
# Number of anchors selected by `LatentSpace.get_anchors` and expected by the relative projection.
NUM_ANCHORS: int = 500

In [None]:
def get_latents(words, encoder):
    """Return the ``encoder`` column of the global ``data`` for the given rows.

    Args:
        words: iterable of row positions in ``data``; each entry is converted
            with ``int`` so string ids are accepted as well.
        encoder: name of the dataset column to read.
    """
    row_positions = list(map(int, words))
    selected_rows = data.select(row_positions)
    return selected_rows[encoder]

In [None]:
import numpy as np
from rae.openfaiss import FaissIndex
from rae.modules.attention import *
from torch_cluster import fps

# Module-level relative projection used by `LatentSpace.to_relative`: inner-product
# similarities against NUM_ANCHORS anchors with L2 normalization, returning the
# similarities themselves as values.
rel_proj = RelativeAttention(
    n_anchors=NUM_ANCHORS,
    n_classes=None,
    similarity_mode=RelativeEmbeddingMethod.INNER,
    values_mode=ValuesMethod.SIMILARITIES,
    normalization_mode=NormalizationMode.L2,
)


class LatentSpace:
    """Latent vectors produced by one encoder, together with their sample ids.

    The FAISS index built by `to_faiss` and the anchor sets chosen by
    `get_anchors` are cached on the instance.
    """

    def __init__(
        self,
        encoding_type: str,
        encoder: str,
        vectors: torch.Tensor,
        ids: Sequence[int],
    ):
        """
        Args:
            encoding_type: label of the space kind ("absolute" or "relative" in this notebook).
            encoder: name of the encoder, i.e. the dataset column the latents are read from.
            vectors: latent vectors, one row per entry of ``ids``.
            ids: sample ids aligned with the rows of ``vectors``.
        """
        self.encoding_type: str = encoding_type
        self.vectors: torch.Tensor = vectors
        self.ids: Sequence[int] = ids
        self.encoder: str = encoder

        # Anchor sets keyed by (seed, anchor_choice), and the lazily built FAISS index.
        self._cached_anchors = {}
        self._cache_index = None

    def to_faiss(self) -> FaissIndex:
        """Return a FAISS index over ``self.vectors`` keyed by ``str(id)``; built once, then cached."""
        if self._cache_index is not None:
            return self._cache_index
        index: FaissIndex = FaissIndex(d=self.vectors.size(1))

        index.add_vectors(
            embeddings=list(zip([str(sample_id) for sample_id in self.ids], self.vectors.cpu().numpy())),
            normalize=True,
        )

        self._cache_index = index
        return index

    # The `anchors` annotation is quoted so that defining the class does not
    # require `Optional` to be in scope yet (it is only star-imported later on).
    def to_relative(
        self, anchor_choice: str = None, seed: int = None, anchors: "Optional[Sequence[str]]" = None
    ) -> "RelativeSpace":
        """Project this space onto a set of anchors.

        Args:
            anchor_choice: anchor selection strategy, used only when ``anchors`` is None.
            seed: seed for the anchor selection, used only when ``anchors`` is None.
            anchors: explicit anchor ids; when given, no selection takes place.

        Returns:
            A `RelativeSpace` holding the similarities of every vector to the anchors.
        """
        assert self.encoding_type != "relative"  # TODO: for now
        anchors = self.get_anchors(anchor_choice=anchor_choice, seed=seed) if anchors is None else anchors

        # The anchor latents are always read from this space's own encoder column.
        anchor_latents: torch.Tensor = get_latents(words=anchors, encoder=self.encoder)

        relative_vectors = rel_proj(x=self.vectors, anchors=anchor_latents.cpu())[AttentionOutput.SIMILARITIES].cpu()
        return RelativeSpace(vectors=relative_vectors, encoder=self.encoder, anchors=anchors, ids=self.ids)

    def get_anchors(self, anchor_choice: str, seed: int) -> Sequence[str]:
        """Select ``NUM_ANCHORS`` anchor ids, cached per ``(seed, anchor_choice)``.

        Supported values of ``anchor_choice``:
            * ``"uniform"``: random sample over all ids.
            * ``"top_<n>"``: random sample over the first ``n`` ids.
            * ``"fps"``: `torch_cluster.fps` on the L2-normalized latents.
            * ``"kmeans"``: k-means with ``NUM_ANCHORS`` clusters; each centroid is
              replaced by its most similar indexed sample.

        Returns:
            The sorted list of selected anchor ids.

        Raises:
            NotImplementedError: if ``anchor_choice`` is not one of the values above.
        """
        key = (seed, anchor_choice)
        if key in self._cached_anchors:
            return self._cached_anchors[key]

        # Select anchors
        seed_everything(seed)
        if anchor_choice == "uniform" or anchor_choice.startswith("top_"):
            limit: int = len(self.ids) if anchor_choice == "uniform" else int(anchor_choice[4:])
            anchor_set: Sequence[int] = random.sample(self.ids[:limit], NUM_ANCHORS)
        elif anchor_choice == "fps":
            anchor_fps = get_latents(words=self.ids, encoder=self.encoder)
            anchor_fps = F.normalize(anchor_fps, p=2, dim=-1)
            anchor_fps = fps(anchor_fps, random_start=True, ratio=NUM_ANCHORS / len(self.ids))
            anchor_set: Sequence[int] = [self.ids[word_index] for word_index in anchor_fps.cpu().tolist()]
        elif anchor_choice == "kmeans":
            # Imported here because the notebook only imports KMeans in a later
            # cell; this keeps the method independent of cell execution order.
            from sklearn.cluster import KMeans

            vectors = F.normalize(get_latents(words=self.ids, encoder=self.encoder), p=2)
            clustered = KMeans(n_clusters=NUM_ANCHORS).fit_predict(vectors.cpu().numpy())

            all_targets = sorted(set(clustered))
            cluster2embeddings = {target: vectors[clustered == target] for target in all_targets}
            cluster2centroid = {
                cluster: centroid.mean(dim=0).cpu().numpy() for cluster, centroid in cluster2embeddings.items()
            }
            centroids = np.array(list(cluster2centroid.values()), dtype="float32")

            # Centroids are not actual samples: look up the closest sample to each one.
            index: FaissIndex = FaissIndex(d=vectors.shape[1])
            index.add_vectors(
                list(zip([str(sample_id) for sample_id in self.ids], vectors.cpu().numpy())), normalize=False
            )
            centroids = index.search_by_vectors(query_vectors=centroids, k_most_similar=1, normalize=True)

            anchor_set = [list(word2score.keys())[0] for word2score in centroids]
        else:
            # `assert NotImplementedError` never fails (the class object is truthy),
            # so unknown choices used to fall through to an unbound `anchor_set`.
            raise NotImplementedError(f"Unknown anchor_choice: {anchor_choice!r}")

        result = sorted(anchor_set)
        self._cached_anchors[key] = result

        return result


class RelativeSpace(LatentSpace):
    """Latent space whose vectors are expressed relative to a set of anchors."""

    def __init__(
        self,
        vectors: torch.Tensor,
        ids: Sequence[int],
        anchors: Sequence[str],
        encoder: str = None,
    ):
        """
        Args:
            vectors: relative representations, one row per entry of ``ids``.
            ids: sample ids aligned with the rows of ``vectors``.
            anchors: ids of the anchors the representation is relative to.
            encoder: name of the encoder the underlying latents come from.
        """
        # The encoding type is fixed to "relative" for this subclass.
        super().__init__(
            encoding_type="relative",
            encoder=encoder,
            vectors=vectors,
            ids=ids,
        )
        self.anchors: Sequence[str] = anchors

In [None]:
# Encoders compared in the qualitative plot and the quantitative benchmark.
BENCHMARK_ENCODERS = (
    "vit_base_patch16_224",
    "vit_small_patch16_224",
    #     "vit_base_resnet50_384",
    #     "rexnet_100"
)
# BENCHMARK_ENCODERS = list(ENCODERS)
# Every benchmarked encoder must already have its encodings stored as a dataset column.
assert all(x in data.column_names for x in BENCHMARK_ENCODERS)

In [None]:
data.features[label_key].names

In [None]:
SEED: int = 42
seed_everything(SEED)

# Maximum number of samples kept per target class.
NUM_TARGETS = 200
ALL_IDS = list(range(len(data)))
TARGET_CLASS_IDS = [
    data.features[label_key].str2int(x) for x in ["bird", "ship", "cat", "frog"]
]  # classes to draw the benchmark samples from
print(f"{TARGET_CLASS_IDS=}")

# Space over the whole dataset for the first benchmark encoder; only used to pick the anchors.
pivot_space = LatentSpace(
    encoding_type="absolute",
    encoder=BENCHMARK_ENCODERS[0],
    vectors=data.select(ALL_IDS)[BENCHMARK_ENCODERS[0]],
    ids=ALL_IDS,
)

# Maps the "index" of each selected sample to the position of its class in TARGET_CLASS_IDS.
id2target = {}
for i_target, target_class in enumerate(TARGET_CLASS_IDS):
    target_data = data.filter(lambda x: x[label_key] == target_class).shuffle(seed=SEED)
    # Drop samples that were already assigned to a previous class.
    target_data = [sample for sample in target_data if sample["index"] not in id2target]
    for sample in target_data[:NUM_TARGETS]:
        id2target[sample["index"]] = i_target
    print(target_class, len(target_data))
# TARGETS and BENCHMARK_IDS are aligned: both follow the insertion order of id2target.
TARGETS = list(id2target.values())
BENCHMARK_IDS = list(id2target.keys())
print(len(BENCHMARK_IDS))
ANCHOR_IDS = pivot_space.get_anchors(anchor_choice="uniform", seed=SEED)  # TODO: stratified
ANCHOR_IDS[:10]

In [None]:
# from umap import UMAP
from sklearn.manifold import TSNE
from enum import auto
from sklearn.decomposition import PCA


def latents_distance(latents):
    """Mean pairwise L2 distance between two row-aligned latent tensors.

    Args:
        latents: sequence holding exactly two tensors.

    Returns:
        The mean distance formatted as a string with two decimals.
    """
    assert len(latents) == 2
    first, second = latents[0], latents[1]
    row_distances = F.pairwise_distance(first, second, p=2)
    dist = row_distances.mean()
    return f"{dist:.2f}"


class Reduction(StrEnum):
    PCA = auto()
    TSNE = auto()
    UMAP = auto()
    FIRST_DIMS = auto()


def to_df(latents, mode: Reduction = "pca"):
    """Reduce ``latents`` to two dimensions and pair them with the global ``TARGETS``.

    Args:
        latents: tensor of latent vectors, one row per entry of ``TARGETS``.
        mode: reduction strategy; PCA, TSNE and FIRST_DIMS are implemented.

    Returns:
        DataFrame with columns ``x``, ``y`` and ``target``.

    Raises:
        NotImplementedError: for any other ``mode`` (including UMAP).
    """
    if mode == Reduction.PCA:
        reducer = PCA(n_components=2, random_state=1)
        latents2d = reducer.fit_transform(latents.cpu())
    elif mode == Reduction.FIRST_DIMS:
        # No projection: keep the first two coordinates as they are.
        latents2d = latents[:, [0, 1]].cpu().numpy()
    elif mode == Reduction.TSNE:
        reducer = TSNE(n_components=2, init="pca", learning_rate="auto", random_state=42)
        latents2d = reducer.fit_transform(latents.cpu())
    else:
        raise NotImplementedError

    xs = latents2d[:, 0].tolist()
    ys = latents2d[:, 1].tolist()
    return pd.DataFrame({"x": xs, "y": ys, "target": TARGETS})

In [None]:
def plot_bg(ax, df, cmap, norm, size, bg_alpha):
    """Scatter every point of ``df`` on ``ax``, coloured by its target.

    Intended as a background layer (drawn with opacity ``bg_alpha``) on top of
    which selected clusters can be highlighted.

    Args:
        ax: axes to draw on.
        df: frame with ``x``, ``y`` and ``target`` columns.
        cmap: callable mapping normalized targets to colours.
        norm: callable normalizing the target values.
        size: marker size.
        bg_alpha: opacity of the points.

    Returns:
        The same ``ax``.
    """
    point_colors = cmap(norm(df["target"]))
    ax.scatter(df.x, df.y, c=point_colors, alpha=bg_alpha, s=size)
    return ax


def hightlight_cluster(ax, df, target, alpha, cmap, norm, size=0.5):
    """Scatter only the rows of ``df`` whose ``target`` equals ``target``.

    Args:
        ax: axes to draw on.
        df: frame with ``x``, ``y`` and ``target`` columns.
        target: target value selecting the rows to draw.
        alpha: opacity of the points.
        cmap: callable mapping normalized targets to colours.
        norm: callable normalizing the target values.
        size: marker size.
    """
    is_member = df["target"] == target
    members = df[is_member]
    member_colors = cmap(norm(members["target"]))
    ax.scatter(members.x, members.y, c=member_colors, alpha=alpha, s=size)


def plot_latent_space(ax, df, targets, size, cmap, norm, bg_alpha, alpha):
    """Draw all points of ``df`` as background, then redraw the clusters in ``targets``.

    Args:
        ax: axes to draw on.
        df: frame with ``x``, ``y`` and ``target`` columns.
        targets: target values to highlight on top of the background.
        size: marker size for both layers.
        cmap: callable mapping normalized targets to colours.
        norm: callable normalizing the target values.
        bg_alpha: opacity of the background layer.
        alpha: opacity of the highlighted clusters.

    Returns:
        The axes returned by ``plot_bg``.
    """
    shared_style = dict(cmap=cmap, norm=norm, size=size)
    ax = plot_bg(ax, df, bg_alpha=bg_alpha, **shared_style)
    for highlighted in targets:
        hightlight_cluster(ax, df, highlighted, alpha=alpha, **shared_style)
    return ax

In [None]:
# Per-encoder latents of the benchmark samples and of the anchors, keyed by encoder name.
ae_latents = {}
anchors_latents = {}
for encoder in BENCHMARK_ENCODERS:
    ae_latents[encoder] = get_latents(words=BENCHMARK_IDS, encoder=encoder)
    anchors_latents[encoder] = get_latents(words=ANCHOR_IDS, encoder=encoder)

import copy

# Independent copies of both dictionaries as they are at this point.
original_ae_latents = copy.deepcopy(ae_latents)
original_anchor_latents = copy.deepcopy(anchors_latents)

In [None]:
# 2-D reduction used for the qualitative figure; also part of the output file name.
REDUCTION = Reduction.TSNE

In [None]:
from sklearn.decomposition import PCA
from rae.modules.attention import *


# Figure layout: one column per encoder. Row 0 shows the absolute latent spaces
# and row 1 + i shows the relative spaces computed for col_config[i].
# (The grid used to be sized as encoders x (configs + 1), which only matched the
# plotting code below because both counts happen to be 2.)
col_config = ((None, None),)
N_ROWS = len(col_config) + 1
N_COLS = len(BENCHMARK_ENCODERS)
print(
    N_ROWS,
    N_COLS,
)
plt.rcParams.update(bundles.iclr2023())
plt.rcParams.update(figsizes.iclr2023(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=1.0))

import matplotlib as mpl

# One colour per target class; targets are normalized to the colormap range.
num_colors = len(TARGET_CLASS_IDS)
cmap = mpl.colors.ListedColormap(plt.cm.get_cmap("Set1", 10).colors[:num_colors], name="rgb", N=num_colors)
norm = plt.Normalize(min(TARGETS), max(TARGETS))

# squeeze=False keeps `axes` two-dimensional even when there is a single encoder,
# so that indexing by row below always yields a sequence of axes.
fig, axes = plt.subplots(dpi=300, nrows=N_ROWS, ncols=N_COLS, sharey="row", sharex="row", squeeze=False)

S = 7
BG_ALPHA = 0.35
ALPHA = 0.5

# Positions (in TARGET_CLASS_IDS) of the classes redrawn on top of the background.
TARGET_HIGHLIGHT = [1]
# First row: absolute latent space of every encoder.
for ax_encoders, (_, latents) in zip(axes[0], ae_latents.items()):
    plot_latent_space(
        ax_encoders,
        to_df(latents, mode=REDUCTION),
        targets=TARGET_HIGHLIGHT,
        size=S,
        bg_alpha=BG_ALPHA,
        alpha=ALPHA,
        cmap=cmap,
        norm=norm,
    )

# distances = {"absolute": latents_distance(list(ae_latents.values()))}

# Following rows: relative spaces, one row per configuration.
for col_i, (quant_mode, bin_size) in enumerate(col_config):
    rel_attention = RelativeAttention(
        n_anchors=NUM_ANCHORS,
        n_classes=None,
        similarity_mode=RelativeEmbeddingMethod.INNER,
        values_mode=ValuesMethod.SIMILARITIES,
        normalization_mode=NormalizationMode.L2,
    )
    # The projection must be parameter-free.
    assert sum(x.numel() for x in rel_attention.parameters()) == 0
    rels = []
    # Each configuration gets its own row instead of always drawing on axes[1].
    for row_ax, (enc_name, latents), (a_enc_name, a_latents) in zip(
        axes[1 + col_i], ae_latents.items(), anchors_latents.items()
    ):
        print(enc_name, a_enc_name)
        # Samples and anchors must come from the same encoder.
        assert enc_name == a_enc_name
        rel = rel_attention(x=latents, anchors=a_latents)[AttentionOutput.SIMILARITIES]
        rels.append(rel)
        plot_latent_space(
            row_ax,
            to_df(rel, mode=REDUCTION),
            targets=TARGET_HIGHLIGHT,
            size=S,
            bg_alpha=BG_ALPHA,
            alpha=ALPHA,
            cmap=cmap,
            norm=norm,
        )
#     distances[f"relative({quant_mode}, {bin_size})"] = latents_distance(rels)

# distances

In [None]:
set(norm(TARGETS))

In [None]:
# Output file stem: encodes the dataset key and the reduction used for the figure.
file_name: str = f"embeddings-qualitative-{DATASET_KEY}-{REDUCTION}"
file_name_pdf: str = f"{file_name}.pdf"
file_name_svg: str = f"{file_name}.svg"

# The figure is written as SVG; the shell commands that follow convert it to PDF and delete the SVG.
fig.savefig(f"{file_name}.svg", bbox_inches="tight")
!rsvg-convert -f pdf -o $file_name_pdf $file_name_svg
!rm $file_name_svg

In [None]:
# Anchor selection strategies benchmarked below (see `LatentSpace.get_anchors`).
ANCHOR_CHOICES = ("uniform", "fps", "kmeans")

In [None]:
from typing import *

In [None]:
# String ids of every row of `data`, in row order.
SEARCH_IDS = list(map(str, range(len(data))))

In [None]:
from rae.cka import CudaCKA as CKA

EncPair = Tuple[str, str]


@torch.no_grad()
def evaluate_retrieval(latent_space1: LatentSpace, latent_space2: LatentSpace, k: int = 5):
    """Compare two latent spaces through retrieval and similarity metrics.

    The rows of both spaces are paired with the module-level ``SEARCH_IDS`` (in
    order), so both spaces must be row-aligned with it and with each other.

    Args:
        latent_space1: source space; its vectors are used as queries.
        latent_space2: target space; its FAISS index is the one searched.
        k: number of neighbours retrieved per query.

    Returns:
        A one-row DataFrame with the columns ``src_enc``, ``tgt_enc``,
        ``topk_jaccard``, ``mrr``, ``linear_cka``, ``rbf_kernel_cka``, ``mse``,
        ``cosine_sim``, ``enc1_type`` and ``enc2_type``.
    """
    index2: FaissIndex = latent_space2.to_faiss()

    # Both query sets are searched in the index of the second space: its own
    # vectors give the reference neighbours, the first space's vectors give the
    # neighbours actually retrieved across spaces.
    target_hits = index2.search_by_vectors(
        query_vectors=latent_space2.vectors.cpu().numpy(), k_most_similar=k, normalize=True
    )
    actual_hits = index2.search_by_vectors(
        query_vectors=latent_space1.vectors.cpu().numpy(), k_most_similar=k, normalize=True
    )
    target_neighbors = dict(zip(SEARCH_IDS, target_hits))
    actual_neighbors = dict(zip(SEARCH_IDS, actual_hits))

    jaccard_scores, reciprocal_ranks = [], []
    for search_id in SEARCH_IDS:
        target_ids = set(target_neighbors[search_id].keys())
        ranked_ids = list(actual_neighbors[search_id].keys())
        actual_ids = set(ranked_ids)

        # Jaccard overlap between the reference and the retrieved top-k sets.
        jaccard_scores.append(len(target_ids & actual_ids) / len(target_ids | actual_ids))
        # Reciprocal rank (1-based) of the query's own id among the retrieved
        # neighbours; 0 when it is not retrieved at all.
        reciprocal_ranks.append(1 / (ranked_ids.index(search_id) + 1) if search_id in actual_ids else 0)
    topk_jaccard = np.mean(jaccard_scores)
    mrr = np.mean(reciprocal_ranks)

    # The remaining metrics are computed chunk by chunk and averaged over chunks.
    chunk_size: int = 10_000
    num_chunks: int = (len(SEARCH_IDS) + chunk_size - 1) // chunk_size
    linear_cka, mse, cosine_sim = [], [], []
    for chunk_latents1, chunk_latents2 in zip(
        latent_space1.vectors.chunk(num_chunks), latent_space2.vectors.chunk(num_chunks)
    ):
        # Use the configured DEVICE rather than a hard-coded `.cuda()`.
        chunk_latents1 = chunk_latents1.to(DEVICE)
        chunk_latents2 = chunk_latents2.to(DEVICE)
        cka = CKA(device=DEVICE)

        linear_cka.append(cka.linear_CKA(chunk_latents1, chunk_latents2).cpu())
        cosine_sim.append(F.cosine_similarity(chunk_latents1, chunk_latents2).mean().cpu())
        # NOTE(review): reduction="sum" makes this the summed squared error of the
        # chunk (later averaged over chunks), not a per-element mean; kept as is
        # so that previously reported numbers do not change.
        mse.append(F.mse_loss(chunk_latents1, chunk_latents2, reduction="sum").cpu())

    linear_cka = torch.stack(linear_cka).mean(dim=0).cpu().item()
    # The RBF-kernel CKA is not computed; the column is kept as a constant placeholder.
    rbf_kernel_cka = 0.0
    mse = torch.stack(mse).mean(dim=0).cpu().item()
    cosine_sim = torch.stack(cosine_sim).mean(dim=0).cpu().item()

    performance = pd.DataFrame(
        {
            "src_enc": [latent_space1.encoder],
            "tgt_enc": [latent_space2.encoder],
            "topk_jaccard": [topk_jaccard],
            "mrr": [mrr],
            "linear_cka": [linear_cka],
            "rbf_kernel_cka": [rbf_kernel_cka],
            "mse": [mse],
            "cosine_sim": [cosine_sim],
        }
    )
    performance["enc1_type"] = latent_space1.encoding_type
    performance["enc2_type"] = latent_space2.encoding_type

    return performance


# @torch.no_grad()
# def evaluate(
#     enc_type2enc_name2faiss_index: Mapping[str, Mapping[str, FaissIndex]],
#     enc_type2enc_names2word2topk: Mapping[str, Mapping[Tuple[str, str], Mapping[str, Sequence[str]]]],
# ):
#     performance = {key: [] for key in ("enc_name", "linear_cka", "rbf_kernel_cka")}
#
#     for enc_name in ("word2vec-google-news-300", "local_fasttext"):
#         faiss_abs = enc_type2enc_name2faiss_index["absolute"][enc_name]
#         faiss_rel = enc_type2enc_name2faiss_index["relative"][enc_name]
#
#         linear_cka, rbf_kernel_cka = [], []
#         for chunk in chunk_iterable(SEARCH_WORDS, chunk_size=5_000):
#             chunk_latents_enc1 = torch.as_tensor(faiss_abs.reconstruct_n(keys=chunk), device=DEVICE)
#             chunk_latents_enc2 = torch.as_tensor(faiss_rel.reconstruct_n(keys=chunk), device=DEVICE)
#
#             cka = CKA(device=DEVICE)
#             linear_cka.append(cka.linear_CKA(chunk_latents_enc1, chunk_latents_enc2))
#             rbf_kernel_cka.append(cka.kernel_CKA(chunk_latents_enc1, chunk_latents_enc2))
#
#         linear_cka = torch.stack(linear_cka).mean(dim=0).cpu().item()
#         rbf_kernel_cka = torch.stack(rbf_kernel_cka).mean(dim=0).cpu().item()
#
#         performance["enc_name"].append(enc_name)
#         performance["linear_cka"].append(linear_cka)
#         performance["rbf_kernel_cka"].append(rbf_kernel_cka)
#
#     return pd.DataFrame(performance)

In [None]:
# Number of seeds (0 .. NUM_SEEDS - 1) the benchmark is repeated with.
NUM_SEEDS: int = 1

In [None]:
import itertools
from sklearn.cluster import KMeans
from torchmetrics.functional import pairwise_cosine_similarity
from nn_core.common import PROJECT_ROOT


# Benchmark every (seed, anchor_choice) combination over all ordered encoder pairs.
performances = []
anchor_infos = []
for seed, anchor_choice in (pbar := tqdm(list(itertools.product(range(NUM_SEEDS), ANCHOR_CHOICES)))):
    # Fresh absolute spaces over the whole dataset, one per encoder.
    absolute_spaces: Sequence[LatentSpace] = [
        LatentSpace(
            encoding_type="absolute",
            vectors=data[encoder],
            ids=SEARCH_IDS,
            encoder=encoder,
        )
        for encoder in BENCHMARK_ENCODERS
    ]

    # Record, for each encoder, the anchors it selects and their pairwise cosine similarities.
    for absolute_space in absolute_spaces:
        words: Sequence[str] = absolute_space.get_anchors(anchor_choice=anchor_choice, seed=seed)
        anchors: torch.Tensor = get_latents(words=words, encoder=absolute_space.encoder)
        anchor_info = {
            "seed": seed,
            "anchor_choice": anchor_choice,
            "anchors": anchors,
            "words": words,
            "encoder": absolute_space.encoder,
            "dists": pairwise_cosine_similarity(anchors, zero_diagonal=False),
        }
        anchor_infos.append(anchor_info)

    for abs_space1, abs_space2 in itertools.product(absolute_spaces, repeat=2):
        # absolute
        absolute_performance = evaluate_retrieval(latent_space1=abs_space1, latent_space2=abs_space2)

        absolute_performance["anchor_choice"] = anchor_choice
        absolute_performance["seed"] = seed
        performances.append(absolute_performance)

        # relative
        # The anchors are chosen on the first space and reused for the second one,
        # so that both relative spaces refer to the same anchor samples.
        rel_space1: RelativeSpace = abs_space1.to_relative(seed=seed, anchor_choice=anchor_choice)
        rel_space2: RelativeSpace = abs_space2.to_relative(anchors=rel_space1.anchors)
        relative_performance = evaluate_retrieval(latent_space1=rel_space1, latent_space2=rel_space2)
        relative_performance["anchor_choice"] = anchor_choice
        relative_performance["seed"] = seed

        performances.append(relative_performance)

# Persist the metrics table and the anchor information under the experiment directory.
performances = pd.concat(performances)
performances.to_csv(
    PROJECT_ROOT / "experiments" / "fig:latent-rotation-comparison" / f"quantitative_analysis_{DATASET_KEY}.tsv",
    sep="\t",
    index=False,
)
torch.save(
    anchor_infos, PROJECT_ROOT / "experiments" / "fig:latent-rotation-comparison" / f"anchor_infos_{DATASET_KEY}.pt"
)
performances

In [None]:
# Reload the metrics written by the benchmark cell.
performance_df = pd.read_csv(
    PROJECT_ROOT / "experiments" / "fig:latent-rotation-comparison" / f"quantitative_analysis_{DATASET_KEY}.tsv",
    sep="\t",
)
# Average the metrics per anchor strategy, space types and encoder pair.
# "enc2_type" must be a grouping key: the list used to name "enc1_type" twice,
# which left the string column "enc2_type" among the columns to be averaged.
# The aggregator is given by name because passing `np.mean` is deprecated in pandas.
performance_df.groupby(["anchor_choice", "enc1_type", "enc2_type", "src_enc", "tgt_enc"]).aggregate(
    [
        "mean",
    ]
)