In [None]:
import random

import pandas as pd
import torch
from typing import Mapping, Tuple

from rae.openfaiss import FaissIndex
from tqdm import tqdm
import itertools
from pytorch_lightning import seed_everything
import numpy as np
from rae.modules.attention import *
from rae.utils.utils import StrEnum
from sklearn.cluster import KMeans
from torchmetrics.functional import pairwise_cosine_similarity
from nn_core.common import PROJECT_ROOT

In [None]:
DEVICE: str = "cuda"  # device for embedding tensors and the CKA helper
MODELS = ("local_fasttext", "word2vec-google-news-300")  # gensim model names to compare (keys of ENCODERS)

In [None]:
from rae.modules.text.encoder import GensimEncoder

# One gensim-backed encoder per model name, keyed by that name (insertion order follows MODELS).
ENCODERS = dict(
    (model_name, GensimEncoder(language="en", lemmatize=False, model_name=model_name)) for model_name in MODELS
)

In [None]:
# Candidate vocabulary, taken from the first encoder. Its first 400 entries are skipped
# (mostly stopwords); only alphabetic words of at least 4 characters are kept.
# Words unknown to any encoder are dropped: every LatentSpace must hold one row per search
# word, and gensim's `vectors_for_all` silently skips keys a model does not know.
reference_vocab = next(iter(ENCODERS.values())).model.key_to_index
all_words = [
    word
    for word in list(reference_vocab.keys())[400:]
    if word.isalpha()
    and len(word) >= 4
    and all(word in encoder.model.key_to_index for encoder in ENCODERS.values())
]
SEARCH_WORDS = all_words[:20_000]
SEARCH_WORDS[:10]

In [None]:
NUM_SEEDS: int = 10  # seeds evaluated for every anchor-selection strategy
RETRIEVAL_ANCHORS_NUM = 300  # number of anchor words per relative space

In [None]:
def get_latents(words, encoder: GensimEncoder):
    """Return the embeddings of ``words`` as a tensor on DEVICE, one row per word in input order.

    The vectors are stacked on the host and moved to DEVICE in a single transfer. This avoids
    one small host-to-device copy per word, which matters for the 20k-word calls made during
    anchor selection.

    Args:
        words: sequence of words to look up with ``encoder.model.get_vector``.
        encoder: encoder whose gensim model provides the vectors.

    Returns:
        A tensor of shape ``(len(words), dim)`` on DEVICE.
    """
    host_latents = np.stack([encoder.model.get_vector(word) for word in words], axis=0)
    return torch.as_tensor(host_latents, device=DEVICE)

In [None]:
from rae.modules.attention import *
from torch_cluster import fps

# Relative projection shared by every LatentSpace.to_relative call: inputs are L2-normalized,
# compared to the anchors with an inner product, and the similarities themselves are returned
# as the relative coordinates.
rel_proj = RelativeAttention(
    n_classes=None,
    n_anchors=RETRIEVAL_ANCHORS_NUM,
    normalization_mode=NormalizationMode.L2,
    similarity_mode=RelativeEmbeddingMethod.INNER,
    values_mode=ValuesMethod.SIMILARITIES,
)


class LatentSpace:
    """Embeddings produced by one encoder, whose i-th row corresponds to SEARCH_WORDS[i].

    Lazily caches a FAISS index over its vectors and, per ``(seed, anchor_choice)``, the anchor
    words chosen by ``get_anchors``.
    """

    def __init__(
        self,
        encoding_type: str,
        encoder: str,
        vectors: torch.Tensor = None,
    ):
        """
        Args:
            encoding_type: label of the representation ("absolute" here, "relative" for RelativeSpace).
            encoder: key of ENCODERS naming the model behind ``vectors``.
            vectors: matrix whose i-th row embeds SEARCH_WORDS[i].
        """
        self.encoding_type: str = encoding_type
        self.vectors: torch.Tensor = vectors
        self.encoder: str = encoder

        # (seed, anchor_choice) -> sorted anchor words returned by get_anchors
        self._cached_anchors = {}
        # FAISS index over self.vectors, built on the first to_faiss() call
        self._cache_index = None

    def to_faiss(self) -> FaissIndex:
        """Return a FAISS index mapping SEARCH_WORDS to this space's vectors, building it only once.

        Raises:
            ValueError: if the number of rows differs from ``len(SEARCH_WORDS)``. ``zip`` would
                otherwise silently truncate and attach words to the wrong vectors.
        """
        if self._cache_index is not None:
            return self._cache_index

        if self.vectors.size(0) != len(SEARCH_WORDS):
            raise ValueError(
                f"LatentSpace has {self.vectors.size(0)} rows but SEARCH_WORDS has {len(SEARCH_WORDS)} words"
            )

        index: FaissIndex = FaissIndex(d=self.vectors.size(1))
        index.add_vectors(
            embeddings=list(zip(SEARCH_WORDS, self.vectors.cpu().numpy())),
            normalize=True,
        )

        self._cache_index = index
        return index

    def to_relative(
        self, anchor_choice: str = None, seed: int = None, anchors: Optional[Sequence[str]] = None
    ) -> "RelativeSpace":
        """Project this (absolute) space onto anchor words with ``rel_proj``.

        Args:
            anchor_choice: anchor strategy for ``get_anchors``; ignored when ``anchors`` is given.
            seed: seed for ``get_anchors``; ignored when ``anchors`` is given.
            anchors: explicit anchor words. Passing another space's anchors yields relative
                spaces whose coordinates refer to the same words.

        Returns:
            A RelativeSpace holding the similarity output of ``rel_proj`` and the anchors used.
        """
        assert self.encoding_type != "relative"  # TODO: for now
        if anchors is None:
            anchors = self.get_anchors(anchor_choice=anchor_choice, seed=seed)

        anchor_latents: torch.Tensor = torch.as_tensor(
            ENCODERS[self.encoder].model.vectors_for_all(keys=anchors).vectors
        )

        relative_vectors = rel_proj(x=self.vectors, anchors=anchor_latents.cpu())[AttentionOutput.SIMILARITIES].cpu()
        return RelativeSpace(
            vectors=relative_vectors,
            encoder=self.encoder,
            anchors=anchors,
        )

    def get_anchors(self, anchor_choice: str, seed: int) -> Sequence[str]:
        """Select anchor words from SEARCH_WORDS; results are cached per ``(seed, anchor_choice)``.

        Strategies:
            "uniform": ``random.sample`` of RETRIEVAL_ANCHORS_NUM words from all SEARCH_WORDS.
            "top_<N>": the same, restricted to the first N SEARCH_WORDS.
            "fps": torch_cluster ``fps`` on the L2-normalized embeddings of this encoder.
            "kmeans": for each k-means centroid of the normalized embeddings, the nearest SEARCH_WORD.

        Args:
            anchor_choice: one of the strategies above.
            seed: passed to ``seed_everything`` before any sampling.

        Returns:
            The selected anchor words, sorted lexicographically.

        Raises:
            NotImplementedError: if ``anchor_choice`` is not one of the strategies above.
        """
        key = (seed, anchor_choice)
        if key in self._cached_anchors:
            return self._cached_anchors[key]

        seed_everything(seed)
        if anchor_choice == "uniform" or (isinstance(anchor_choice, str) and anchor_choice.startswith("top_")):
            limit: int = len(SEARCH_WORDS) if anchor_choice == "uniform" else int(anchor_choice[4:])
            anchor_set: Sequence[str] = random.sample(SEARCH_WORDS[:limit], RETRIEVAL_ANCHORS_NUM)
        elif anchor_choice == "fps":
            anchor_fps = get_latents(words=SEARCH_WORDS, encoder=ENCODERS[self.encoder])
            anchor_fps = F.normalize(anchor_fps, p=2, dim=-1)
            # NOTE(review): the number of points fps returns for this ratio is not checked here;
            # confirm it equals RETRIEVAL_ANCHORS_NUM.
            anchor_fps = fps(anchor_fps, random_start=True, ratio=RETRIEVAL_ANCHORS_NUM / len(SEARCH_WORDS))
            anchor_set: Sequence[str] = [SEARCH_WORDS[word_index] for word_index in anchor_fps.cpu().tolist()]
        elif anchor_choice == "kmeans":
            vectors = F.normalize(get_latents(words=SEARCH_WORDS, encoder=ENCODERS[self.encoder]), p=2)
            clustered = KMeans(n_clusters=RETRIEVAL_ANCHORS_NUM).fit_predict(vectors.cpu().numpy())

            all_targets = sorted(set(clustered))
            cluster2embeddings = {target: vectors[clustered == target] for target in all_targets}
            cluster2centroid = {
                cluster: centroid.mean(dim=0).cpu().numpy() for cluster, centroid in cluster2embeddings.items()
            }
            centroids = np.array(list(cluster2centroid.values()), dtype="float32")

            index: FaissIndex = FaissIndex(d=vectors.shape[1])
            index.add_vectors(list(zip(SEARCH_WORDS, vectors.cpu().numpy())), normalize=False)
            centroids = index.search_by_vectors(query_vectors=centroids, k_most_similar=1, normalize=True)

            # Two centroids can share the same nearest word, so the list may contain duplicates.
            anchor_set = [list(word2score.keys())[0] for word2score in centroids]
        else:
            raise NotImplementedError(f"Unknown anchor_choice: {anchor_choice!r}")

        result = sorted(anchor_set)
        self._cached_anchors[key] = result

        return result


class RelativeSpace(LatentSpace):
    """LatentSpace whose vectors are the output of ``rel_proj`` against a list of anchor words.

    The anchor words are kept so that another absolute space can be projected onto exactly the
    same anchors (see ``LatentSpace.to_relative(anchors=...)``).
    """

    def __init__(
        self,
        vectors: torch.Tensor,
        anchors: Sequence[str],
        encoder: str = None,
    ):
        super().__init__(
            encoder=encoder,
            vectors=vectors,
            encoding_type="relative",
        )
        # Anchor words (in the order given by the caller) that produced these relative vectors.
        self.anchors: Sequence[str] = anchors

In [None]:
# Anchor-selection strategies handled by LatentSpace.get_anchors ("top_<N>" samples among the first N words).
ANCHOR_CHOICES = (
    "uniform",
    "top_1000",
    "top_5000",
    "top_10000",
    "fps",
    "kmeans",
)

In [None]:
from rae.cka import CudaCKA as CKA

# (source encoder name, target encoder name) pair.
EncPair = Tuple[str, str]


@torch.no_grad()
def evaluate_retrieval(
    latent_space1: LatentSpace, latent_space2: LatentSpace, k: int = 5, chunk_size: int = 5000
) -> pd.DataFrame:
    """Compare two latent spaces whose i-th rows both embed SEARCH_WORDS[i].

    Retrieval metrics use the FAISS index of ``latent_space2``:
      * topk_jaccard: mean Jaccard overlap between the top-k neighbours of each word queried
        with its space-2 vector and with its space-1 vector.
      * mrr: mean reciprocal rank of each word among the top-k neighbours retrieved with its
        space-1 vector (0 when the word is not in the top-k).

    Representation metrics are computed on row chunks (about ``chunk_size`` rows each) and
    averaged over chunks:
      * linear_cka: linear CKA between the two chunks.
      * mse: element-wise mean squared error.
      * cosine_sim: mean row-wise cosine similarity.
    Kernel CKA is currently disabled and reported as 0.0 in ``rbf_kernel_cka``.

    Args:
        latent_space1: space whose vectors are used as queries.
        latent_space2: space that is indexed and used as reference.
        k: neighbours retrieved per query.
        chunk_size: target number of rows per chunk for the CKA / MSE / cosine metrics.

    Returns:
        One-row DataFrame with columns src_enc, tgt_enc, topk_jaccard, mrr, linear_cka,
        rbf_kernel_cka, mse, cosine_sim, enc1_type, enc2_type.

    Raises:
        ValueError: if the two spaces do not have the same number of rows.
    """
    n_rows = latent_space1.vectors.size(0)
    if latent_space2.vectors.size(0) != n_rows:
        raise ValueError(f"Latent spaces are not aligned: {n_rows} vs {latent_space2.vectors.size(0)} rows")

    index2: FaissIndex = latent_space2.to_faiss()

    # Reference neighbourhoods: each word's own space-2 vector queried against space 2.
    target_topk = index2.search_by_vectors(
        query_vectors=latent_space2.vectors.cpu().numpy(), k_most_similar=k, normalize=True
    )
    # Cross-space neighbourhoods: each word's space-1 vector queried against space 2.
    actual_topk = index2.search_by_vectors(
        query_vectors=latent_space1.vectors.cpu().numpy(), k_most_similar=k, normalize=True
    )
    target_neighbors: Mapping[str, Mapping[str, float]] = dict(zip(SEARCH_WORDS, target_topk))
    actual_neighbors: Mapping[str, Mapping[str, float]] = dict(zip(SEARCH_WORDS, actual_topk))

    jaccards = []
    for search_word in SEARCH_WORDS:
        target_words = set(target_neighbors[search_word].keys())
        actual_words = set(actual_neighbors[search_word].keys())
        jaccards.append(len(target_words & actual_words) / len(target_words | actual_words))
    topk_jaccard = np.mean(jaccards)

    reciprocal_ranks = []
    for search_word, word2sim in actual_neighbors.items():
        ranked_words = list(word2sim.keys())
        reciprocal_ranks.append(1 / (ranked_words.index(search_word) + 1) if search_word in ranked_words else 0)
    mrr = np.mean(reciprocal_ranks)

    # torch.chunk spreads rows evenly, avoiding a tiny trailing chunk on which CKA degenerates.
    num_chunks: int = (n_rows + chunk_size - 1) // chunk_size
    cka = CKA(device=DEVICE)  # loop-invariant helper, built once
    linear_cka, mse, cosine_sim = [], [], []
    for chunk1, chunk2 in zip(latent_space1.vectors.chunk(num_chunks), latent_space2.vectors.chunk(num_chunks)):
        chunk1 = chunk1.to(DEVICE)
        chunk2 = chunk2.to(DEVICE)
        linear_cka.append(cka.linear_CKA(chunk1, chunk2).cpu())
        cosine_sim.append(F.cosine_similarity(chunk1, chunk2).mean().cpu())
        # Element-wise mean, so the metric is a true MSE and does not depend on the chunking.
        mse.append(F.mse_loss(chunk1, chunk2, reduction="mean").cpu())

    def _chunk_mean(values) -> float:
        """Average per-chunk scalar tensors into a Python float."""
        return torch.stack(values).mean(dim=0).item()

    return pd.DataFrame(
        {
            "src_enc": [latent_space1.encoder],
            "tgt_enc": [latent_space2.encoder],
            "topk_jaccard": [topk_jaccard],
            "mrr": [mrr],
            "linear_cka": [_chunk_mean(linear_cka)],
            "rbf_kernel_cka": [0.0],  # placeholder: kernel CKA is not computed
            "mse": [_chunk_mean(mse)],
            "cosine_sim": [_chunk_mean(cosine_sim)],
            "enc1_type": [latent_space1.encoding_type],
            "enc2_type": [latent_space2.encoding_type],
        }
    )


# @torch.no_grad()
# def evaluate(
#     enc_type2enc_name2faiss_index: Mapping[str, Mapping[str, FaissIndex]],
#     enc_type2enc_names2word2topk: Mapping[str, Mapping[Tuple[str, str], Mapping[str, Sequence[str]]]],
# ):
#     performance = {key: [] for key in ("enc_name", "linear_cka", "rbf_kernel_cka")}
#
#     for enc_name in ("word2vec-google-news-300", "local_fasttext"):
#         faiss_abs = enc_type2enc_name2faiss_index["absolute"][enc_name]
#         faiss_rel = enc_type2enc_name2faiss_index["relative"][enc_name]
#
#         linear_cka, rbf_kernel_cka = [], []
#         for chunk in chunk_iterable(SEARCH_WORDS, chunk_size=5_000):
#             chunk_latents_enc1 = torch.as_tensor(faiss_abs.reconstruct_n(keys=chunk), device=DEVICE)
#             chunk_latents_enc2 = torch.as_tensor(faiss_rel.reconstruct_n(keys=chunk), device=DEVICE)
#
#             cka = CKA(device=DEVICE)
#             linear_cka.append(cka.linear_CKA(chunk_latents_enc1, chunk_latents_enc2))
#             rbf_kernel_cka.append(cka.kernel_CKA(chunk_latents_enc1, chunk_latents_enc2))
#
#         linear_cka = torch.stack(linear_cka).mean(dim=0).cpu().item()
#         rbf_kernel_cka = torch.stack(rbf_kernel_cka).mean(dim=0).cpu().item()
#
#         performance["enc_name"].append(enc_name)
#         performance["linear_cka"].append(linear_cka)
#         performance["rbf_kernel_cka"].append(rbf_kernel_cka)
#
#     return pd.DataFrame(performance)

In [None]:
# Evaluate every (seed, anchor strategy) pair on every ordered pair of encoders, in both the
# absolute spaces and the relative spaces built on shared anchor words, then persist the results.
OUTPUT_DIR = PROJECT_ROOT / "experiments" / "fig:latent-rotation-comparison"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)  # to_csv / torch.save fail if the folder is missing

# Built once: each LatentSpace caches its FAISS index and its anchors per (seed, anchor_choice),
# so recreating the spaces every iteration only rebuilt identical indexes.
absolute_spaces: Sequence[LatentSpace] = [
    LatentSpace(
        encoding_type="absolute",
        vectors=torch.as_tensor(encoder.model.vectors_for_all(keys=SEARCH_WORDS).vectors),
        encoder=enc_name,
    )
    for enc_name, encoder in ENCODERS.items()
]

# Absolute-vs-absolute metrics do not depend on the seed or the anchors: compute each pair once.
absolute_results = {}
performance_frames = []
anchor_infos = []
for seed, anchor_choice in (pbar := tqdm(list(itertools.product(range(NUM_SEEDS), ANCHOR_CHOICES)))):
    pbar.set_postfix(seed=seed, anchor_choice=anchor_choice)

    for absolute_space in absolute_spaces:
        words: Sequence[str] = absolute_space.get_anchors(anchor_choice=anchor_choice, seed=seed)
        anchors: torch.Tensor = get_latents(words=words, encoder=ENCODERS[absolute_space.encoder])
        anchor_infos.append(
            {
                "seed": seed,
                "anchor_choice": anchor_choice,
                # Saved on CPU so anchor_infos.pt can be loaded without a GPU.
                "anchors": anchors.cpu(),
                "words": words,
                "encoder": absolute_space.encoder,
                # Pairwise cosine *similarities* between the anchors (key name kept for compatibility).
                "dists": pairwise_cosine_similarity(anchors, zero_diagonal=False).cpu(),
            }
        )

    for abs_space1, abs_space2 in itertools.product(absolute_spaces, repeat=2):
        # absolute
        encoder_pair = (abs_space1.encoder, abs_space2.encoder)
        if encoder_pair not in absolute_results:
            absolute_results[encoder_pair] = evaluate_retrieval(latent_space1=abs_space1, latent_space2=abs_space2)
        # assign() returns a copy, so the cached frame is never mutated.
        performance_frames.append(absolute_results[encoder_pair].assign(anchor_choice=anchor_choice, seed=seed))

        # relative: both spaces use the anchor words chosen for abs_space1, so their coordinates align
        rel_space1: RelativeSpace = abs_space1.to_relative(seed=seed, anchor_choice=anchor_choice)
        rel_space2: RelativeSpace = abs_space2.to_relative(anchors=rel_space1.anchors)
        relative_performance = evaluate_retrieval(latent_space1=rel_space1, latent_space2=rel_space2)
        relative_performance["anchor_choice"] = anchor_choice
        relative_performance["seed"] = seed
        performance_frames.append(relative_performance)

performances = pd.concat(performance_frames, ignore_index=True)
performances.to_csv(OUTPUT_DIR / "quantitative_analysis.tsv", sep="\t", index=False)
torch.save(anchor_infos, OUTPUT_DIR / "anchor_infos.pt")
performances

In [None]:
performance_df = pd.read_csv(
    PROJECT_ROOT / "experiments" / "fig:latent-rotation-comparison" / "quantitative_analysis.tsv", sep="\t"
)
# Average every metric over seeds for each (anchor strategy, space types, encoder pair).
# The "seed" column is dropped because its mean carries no information.
performance_df.drop(columns="seed").groupby(
    ["anchor_choice", "enc1_type", "enc2_type", "src_enc", "tgt_enc"]
).aggregate(["mean"])

In [None]:
# map_location="cpu" lets the saved anchors (possibly stored as CUDA tensors) load on machines without a GPU.
anchor_infos = torch.load(
    PROJECT_ROOT / "experiments" / "fig:latent-rotation-comparison" / "anchor_infos.pt", map_location="cpu"
)