In [None]:
import itertools
import random
from typing import Mapping, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
import torch
from pytorch_lightning import seed_everything
from tqdm import tqdm

from rae.openfaiss import FaissIndex
from rae.modules.attention import *
from rae.utils.utils import StrEnum

In [None]:
# Prefer the GPU, but fall back to CPU so the notebook still runs on machines without CUDA.
DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
from rae.modules.text.encoder import GensimEncoder

# Word-embedding models under comparison, keyed by model name (both with lemmatize=False).
# Dict order matters: the first encoder's vocabulary is the one the search words are drawn from.
ENCODERS = dict(
    (name, GensimEncoder(language="en", lemmatize=False, model_name=name))
    for name in ["local_fasttext", "word2vec-google-news-300"]
)

In [None]:
N_SKIPPED_WORDS = 400  # the head of the vocabulary is mostly stopwords (assumes frequency-ordered vocab)
MIN_WORD_LENGTH = 4
NUM_SEARCH_WORDS = 20_000

# Candidates come from the first encoder's vocabulary, but every search word must also exist in
# *every* encoder: later cells look words up by key in each model (get_vector) and zip the returned
# vectors with SEARCH_WORDS positionally, so a word missing from one vocabulary would either raise
# or silently misalign words and vectors.
reference_vocabulary = list(list(ENCODERS.values())[0].model.key_to_index.keys())
encoder_vocabularies = [encoder.model.key_to_index for encoder in ENCODERS.values()]
all_words = [
    word
    for word in reference_vocabulary[N_SKIPPED_WORDS:]
    if word.isalpha()
    and len(word) >= MIN_WORD_LENGTH
    and all(word in vocabulary for vocabulary in encoder_vocabularies)
]
SEARCH_WORDS = all_words[:NUM_SEARCH_WORDS]
SEARCH_WORDS[:10]

In [None]:
# Number of anchor words sampled per seed; it is also the dimensionality of the relative space
# (the relative FAISS indexes are built with d=RETRIEVAL_ANCHORS_NUM).
RETRIEVAL_ANCHORS_NUM: int = 300
# How many seeds (i.e. independent anchor samplings) the whole retrieval experiment is repeated for.
NUM_SEEDS: int = 10

In [None]:
def get_latents(words, encoder: "GensimEncoder", device=None):
    """Embed ``words`` with ``encoder`` and stack the vectors into a single tensor.

    Args:
        words: words to embed, looked up one by one with ``encoder.model.get_vector``.
        encoder: wrapper whose ``.model`` exposes ``get_vector``.
        device: device for the returned tensor; defaults to the notebook-wide ``DEVICE``.

    Returns:
        A ``(len(words), dim)`` tensor (assuming 1-D word vectors) on ``device``, keeping the
        vectors' dtype. An empty ``words`` yields an empty 1-D float tensor.
    """
    if device is None:
        device = DEVICE
    vectors = [encoder.model.get_vector(word) for word in words]
    if not vectors:
        return torch.empty(0, device=device)
    # Stack in NumPy first: building a tensor from a Python list of ndarrays is very slow in PyTorch.
    return torch.as_tensor(np.stack(vectors), device=device)

In [None]:
K = 5  # number of nearest neighbours retrieved per query word


def build_neighborhood(
    enc_type2enc_name2faiss_index: Mapping[str, Mapping[str, "FaissIndex"]],
    search_words: Optional[Sequence[str]] = None,
    k: Optional[int] = None,
) -> Mapping[str, Mapping[Tuple[str, str], Mapping[str, Mapping[str, float]]]]:
    """Retrieve the top-``k`` neighbours of every search word for every ordered pair of encoders.

    For each representation type and each ordered pair ``(src, tgt)`` of encoders, the vectors
    stored for ``search_words`` in the ``src`` index are used as queries against the ``tgt`` index.
    The ``(tgt, tgt)`` pair therefore holds each encoder's own neighbourhoods.

    Args:
        enc_type2enc_name2faiss_index: enc_type -> encoder name -> index containing every search word.
        search_words: query words; defaults to ``SEARCH_WORDS``.
        k: neighbours per query; defaults to ``K``.

    Returns:
        enc_type -> (src_enc, tgt_enc) -> search word -> neighbours, where the neighbours are whatever
        ``search_by_vectors`` returns per query (presumably an ordered word -> similarity mapping;
        confirm against FaissIndex). The "absolute" and "relative" keys are always present.
    """
    if search_words is None:
        search_words = SEARCH_WORDS
    if k is None:
        k = K

    enc_type2enc_names2word2topk = {enc_type: {} for enc_type in ("absolute", "relative")}

    for enc_type, enc_name2faiss_index in enc_type2enc_name2faiss_index.items():
        enc_names2word2topk = enc_type2enc_names2word2topk.setdefault(enc_type, {})

        # Reconstruct each encoder's query vectors once; they are reused against every target index.
        enc_name2vectors = {
            enc_name: np.asarray(faiss_index.reconstruct_n(search_words), dtype="float32")
            for enc_name, faiss_index in enc_name2faiss_index.items()
        }

        for enc_name1, enc_name2 in itertools.product(enc_name2faiss_index.keys(), repeat=2):
            # normalize=False: the vectors were normalised when added to the index (normalize=True in
            # the experiment cell), so the reconstructed queries are not normalised again.
            enc2_neighbors = enc_name2faiss_index[enc_name2].search_by_vectors(
                query_vectors=enc_name2vectors[enc_name1], k_most_similar=k, normalize=False
            )
            enc_names2word2topk[(enc_name1, enc_name2)] = dict(zip(search_words, enc2_neighbors))

    return enc_type2enc_names2word2topk

In [None]:
def _jaccard(target_words: set, actual_words: set) -> float:
    """Jaccard similarity |A ∩ B| / |A ∪ B|; two empty sets count as identical (1.0)."""
    union = target_words | actual_words
    if not union:
        return 1.0
    return len(target_words & actual_words) / len(union)


def _reciprocal_rank(query_word: str, neighbors: Mapping[str, float]) -> float:
    """1 / (1-based position of ``query_word`` among the keys of ``neighbors``), or 0 if absent."""
    for rank, neighbor in enumerate(neighbors, start=1):
        if neighbor == query_word:
            return 1 / rank
    return 0.0


def evaluate(
    enc_type2enc_names2word2topk: Mapping[str, Mapping[Tuple[str, str], Mapping[str, Mapping[str, float]]]],
    search_words: Optional[Sequence[str]] = None,
) -> pd.DataFrame:
    """Score cross-encoder retrieval against each target encoder's own neighbourhoods.

    For every representation type and every (source, target) encoder pair:

    * ``topk_jaccard``: mean over ``search_words`` of the Jaccard similarity between the neighbours
      retrieved with the source vectors and those retrieved by the target with its own vectors
      (the ``(tgt, tgt)`` entry).
    * ``mrr``: mean reciprocal rank of each query word within its own neighbour list for the pair
      (0 when the word itself was not retrieved).

    Args:
        enc_type2enc_names2word2topk: enc_type -> (src_enc, tgt_enc) -> word -> neighbour mapping,
            whose key order is taken as the retrieval rank. Every ``(tgt, tgt)`` pair must exist.
        search_words: words the Jaccard score is averaged over; defaults to ``SEARCH_WORDS``.

    Returns:
        One row per (enc_type, src_enc, tgt_enc) with columns
        ``src_enc, tgt_enc, enc_type, topk_jaccard, mrr``.
    """
    if search_words is None:
        search_words = SEARCH_WORDS

    performance = {key: [] for key in ("src_enc", "tgt_enc", "enc_type", "topk_jaccard", "mrr")}

    for enc_type, enc_names2word2topk in enc_type2enc_names2word2topk.items():
        for (enc_name1, enc_name2), word2topk in enc_names2word2topk.items():
            # Reference neighbourhoods: the target index queried with its own vectors.
            target_word2topk = enc_names2word2topk[(enc_name2, enc_name2)]

            topk_jaccard = np.mean(
                [
                    _jaccard(set(target_word2topk[search_word].keys()), set(word2topk[search_word].keys()))
                    for search_word in search_words
                ]
            )
            mrr = np.mean([_reciprocal_rank(search_word, topk) for search_word, topk in word2topk.items()])

            performance["src_enc"].append(enc_name1)
            performance["tgt_enc"].append(enc_name2)
            performance["enc_type"].append(enc_type)
            performance["topk_jaccard"].append(topk_jaccard)
            performance["mrr"].append(mrr)

    return pd.DataFrame(performance)

In [None]:
from nn_core.common import PROJECT_ROOT

RESULTS_PATH = PROJECT_ROOT / "experiments" / "fig:latent-rotation-comparison" / "retrieval_performance.tsv"

# Relative projection: inner-product similarities (normalization_mode=L2) between each word vector and
# the anchor vectors, used directly as the relative embedding (values_mode=SIMILARITIES).
rel_attention = RelativeAttention(
    n_anchors=RETRIEVAL_ANCHORS_NUM,
    n_classes=None,
    similarity_mode=RelativeEmbeddingMethod.INNER,
    values_mode=ValuesMethod.SIMILARITIES,
    normalization_mode=NormalizationMode.L2,
)

seed_performances = []
for seed in (pbar := tqdm(range(NUM_SEEDS))):
    # Absolute indexes store the raw embeddings; relative ones store one similarity per anchor.
    enc_type2enc_name2faiss_index = {
        enc_type: {
            enc_name: FaissIndex(d=encoder.model.vector_size if enc_type == "absolute" else RETRIEVAL_ANCHORS_NUM)
            for enc_name, encoder in ENCODERS.items()
        }
        for enc_type in ("absolute", "relative")
    }

    # Sample the anchors once per seed so every encoder is projected onto the *same* anchor words;
    # otherwise the relative spaces of different encoders would not be comparable.
    seed_everything(seed)
    seed_anchors: Sequence[str] = sorted(random.sample(SEARCH_WORDS, RETRIEVAL_ANCHORS_NUM))

    for enc_name, encoder in ENCODERS.items():
        pbar.set_description(f"seed={seed}/{NUM_SEEDS} encoder={enc_name}")

        latents = encoder.model.vectors_for_all(keys=SEARCH_WORDS).vectors
        # gensim's vectors_for_all may drop keys it cannot produce (and deduplicates), which would
        # silently misalign the positional zip with SEARCH_WORDS below.
        assert len(latents) == len(SEARCH_WORDS), (
            f"{enc_name}: got {len(latents)} vectors for {len(SEARCH_WORDS)} search words"
        )
        enc_type2enc_name2faiss_index["absolute"][enc_name].add_vectors(
            embeddings=list(zip(SEARCH_WORDS, latents)), normalize=True
        )

        anchors = get_latents(words=seed_anchors, encoder=encoder).cpu()
        latents = torch.as_tensor(latents, dtype=torch.float32)
        with torch.no_grad():  # pure inference: no autograd graph needed
            relative_representation = rel_attention(x=latents, anchors=anchors)

        enc_type2enc_name2faiss_index["relative"][enc_name].add_vectors(
            embeddings=list(zip(SEARCH_WORDS, relative_representation[AttentionOutput.SIMILARITIES].cpu().numpy())),
            normalize=True,
        )

    seed_neighborhoods = build_neighborhood(enc_type2enc_name2faiss_index=enc_type2enc_name2faiss_index)
    seed_performance = evaluate(enc_type2enc_names2word2topk=seed_neighborhoods)
    seed_performance["seed"] = seed
    seed_performances.append(seed_performance)

performance = pd.concat(seed_performances, ignore_index=True)
RESULTS_PATH.parent.mkdir(parents=True, exist_ok=True)
# index=False: the row index carries no information and would reload as an "Unnamed: 0" column.
performance.to_csv(RESULTS_PATH, sep="\t", index=False)
performance

In [None]:
performance_df = pd.read_csv(
    PROJECT_ROOT / "experiments" / "fig:latent-rotation-comparison" / "retrieval_performance.tsv", sep="\t"
)
# Summarise each (source encoder, target encoder, representation) combination across seeds. Only the
# metric columns are aggregated, so bookkeeping columns ("seed", or an "Unnamed: 0" index column in
# files written together with their row index) do not show up as meaningless statistics.
performance_df.groupby(["src_enc", "tgt_enc", "enc_type"])[["topk_jaccard", "mrr"]].agg(["mean", "std", "count"])