In [None]:
import json
from pathlib import Path
from typing import Mapping, Any, Optional, Set, Sequence, Collection

import numpy as np
import torch
from nn_core.common import PROJECT_ROOT
from torch import cosine_similarity
from tqdm import tqdm

from rae.modules.attention import RelativeAttention
from rae.modules.enumerations import (
    NormalizationMode,
    RelativeEmbeddingMethod,
    ValuesMethod,
    SimilaritiesQuantizationMode,
    AttentionOutput,
)
from rae.openfaiss import FaissIndex

In [None]:
# Three-letter ISO 639 language codes, as they appear in the synset_info.tsv header,
# mapped to the two-letter codes used for the fastText files and the rest of this notebook.
iso_map = dict(
    eng="en",
    spa="es",
    fra="fr",
    jpn="ja",
)

In [None]:
def _read_synset_info(
    pos_filter: Optional[Set[str]] = None,
    synset_info_path: Optional[Path] = None,
) -> Mapping[str, Mapping[str, Any]]:
    """Read the multilingual synset table and keep the synsets usable as anchors.

    The file is tab-separated. The header row is ``synset_id, pos, <lang_1>, ..., <lang_k>``.
    Each following row holds one synset. Each language column is a comma-separated list
    of lemmas. Header language codes are translated through ``iso_map``.

    A synset is kept only if every language in ``iso_map.values()`` still has at least one
    lemma after filtering. The filter drops lemmas containing ``"_"`` and lemmas shorter
    than 4 characters.

    Args:
        pos_filter: if given, only rows whose POS column is in this set are kept.
        synset_info_path: TSV file to read; defaults to ``PROJECT_ROOT / "data" / "synset_info.tsv"``.

    Returns:
        Mapping ``synset_id -> {"pos", "lang2lemmas", "synset_id"}``. ``lang2lemmas`` maps
        each language of ``iso_map.values()`` to its filtered lemmas, in file order.

    Raises:
        ValueError: if a row does not have exactly one lemma column per header language.
    """
    if synset_info_path is None:
        synset_info_path = PROJECT_ROOT / "data" / "synset_info.tsv"
    target_langs = list(iso_map.values())

    with synset_info_path.open("r", encoding="utf-8") as fr:
        header = next(fr).strip().split("\t")
        langs = [iso_map.get(lang, lang) for lang in header[2:]]
        synset2info = {}

        # Data rows start on line 2 of the file (line 1 is the header).
        for line_number, line in enumerate(tqdm(fr, desc="Reading synset info"), start=2):
            synset_id, pos, *lemma_columns = line.strip("\n").split("\t")
            if pos_filter is not None and pos not in pos_filter:
                continue

            # Explicit check instead of `assert`, which is removed when running with `python -O`.
            if len(lemma_columns) != len(langs):
                raise ValueError(
                    f"{synset_info_path}:{line_number}: expected {len(langs)} lemma columns, "
                    f"found {len(lemma_columns)}"
                )
            lang2column = dict(zip(langs, lemma_columns))

            # Keep only lemmas without "_" (presumably multi-word expressions) and at least 4 characters.
            lang2lemmas: Mapping[str, Sequence[str]] = {
                lang: [lemma for lemma in lang2column[lang].split(",") if "_" not in lemma and len(lemma) >= 4]
                for lang in target_langs
            }
            if any(len(lemmas) == 0 for lemmas in lang2lemmas.values()):
                continue

            synset2info[synset_id] = dict(pos=pos, lang2lemmas=lang2lemmas, synset_id=synset_id)

    return synset2info

In [None]:
def build_anchors(
    lang2word2embedding: Mapping[str, Mapping[str, np.ndarray]],
    target_candidates: Optional[int] = 3_000,
    synset_info: Optional[Mapping[str, Mapping[str, Any]]] = None,
) -> Mapping[str, Sequence[str]]:
    """Select parallel anchor words across languages from the synset table.

    A synset is selected when the first lemma of each of its languages has an embedding
    (not None) in ``lang2word2embedding``. Synsets are scanned in the iteration order of
    ``synset_info``. Scanning stops once ``target_candidates`` synsets have been selected.

    Args:
        lang2word2embedding: per-language ``word -> embedding`` lookup.
        target_candidates: maximum number of anchors to collect; ``None`` means no limit,
            and values ``<= 0`` yield no anchors.
        synset_info: output of ``_read_synset_info``. When ``None``, the synset table is read
            from disk. Passing it avoids re-reading the file on repeated calls.

    Returns:
        ``lang -> [anchor word, ...]``. The i-th word of every language comes from the same
        synset. When every synset has the same languages (as produced by
        ``_read_synset_info``), all lists have the same length and are aligned.
    """
    if synset_info is None:
        synset_info = _read_synset_info()

    candidates = []
    for info in tqdm(synset_info.values(), desc="Iterating synset info"):
        # Check the limit before processing, so target_candidates <= 0 selects nothing.
        if target_candidates is not None and len(candidates) >= target_candidates:
            break
        # TODO: only the first lemma of each synset is considered for now.
        lang2lemma: Mapping[str, str] = {lang: lemmas[0] for lang, lemmas in info["lang2lemmas"].items()}
        if all(lang2word2embedding.get(lang, {}).get(lemma, None) is not None for lang, lemma in lang2lemma.items()):
            candidates.append(lang2lemma)

    lang2anchors = {}
    for lang2lemma in candidates:
        for lang, lemma in lang2lemma.items():
            lang2anchors.setdefault(lang, []).append(lemma)

    return lang2anchors

In [None]:
def read_embeddings(
    langs: Optional[Collection[str]] = None,
    max_index: Optional[int] = 20_000,
    fasttext_dir: Optional[Path] = None,
) -> Mapping[str, Mapping[str, np.ndarray]]:
    """Load (a prefix of) the word-vector text files found in ``fasttext_dir``.

    A file's language is its first suffix without the dot. For example, a file named
    ``cc.en.300.vec`` maps to ``"en"``. Sub-directories, files without any suffix
    (e.g. dot-files) and files with a ``.gz`` suffix are skipped. Files are visited in sorted
    order, so if two files map to the same language the last one in that order wins,
    deterministically.

    Args:
        langs: if given, only these language codes are loaded.
        max_index: after the first (header) line, read the rows with index ``0..max_index``
            inclusive, i.e. ``max_index + 1`` words. ``None`` reads the whole file.
        fasttext_dir: directory to scan; defaults to ``PROJECT_ROOT / "fasttext"``.

    Returns:
        ``lang -> {word -> embedding}``, with each embedding a float numpy array.
    """

    def read_file(file_path: Path, max_index: Optional[int] = 10_000) -> Mapping[str, np.ndarray]:
        """Parse one space-separated ``word v_1 ... v_d`` file, skipping its first line."""
        word2embedding = {}
        with file_path.open("r", encoding="utf-8") as fr:
            # Skip the header line (presumably "<n_words> <dim>" in the .vec format).
            # The default avoids a StopIteration on an empty file.
            next(fr, None)
            for i, line in enumerate(tqdm(fr, desc=f"Reading {file_path}")):
                if max_index is not None and i > max_index:
                    break
                word, *embedding = line.strip().split(" ")
                word2embedding[word] = np.array([float(x) for x in embedding])
        return word2embedding

    if fasttext_dir is None:
        fasttext_dir = PROJECT_ROOT / "fasttext"

    lang2word2embedding = {}
    # sorted(): iterdir() order is filesystem-dependent.
    for file in sorted(fasttext_dir.iterdir()):
        # Directories and suffix-less files would make `file.suffixes[0]` fail.
        if not file.is_file() or not file.suffixes or ".gz" in file.suffixes:
            continue
        lang = file.suffixes[0].strip(".")
        if langs is not None and lang not in langs:
            continue
        lang2word2embedding[lang] = read_file(file_path=file, max_index=max_index)

    return lang2word2embedding

In [None]:
# Fall back to CPU so the notebook also runs on machines without a CUDA GPU.
device: str = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Load word vectors only for the target languages (the two-letter codes in iso_map).
lang2word2embedding = read_embeddings(langs={*iso_map.values()})

In [None]:
# One FaissIndex per language over the absolute (raw) word vectors.
# normalize=True asks FaissIndex to normalize the vectors on insertion.
abs_lang2faiss_index = {}
for lang, word2embedding in lang2word2embedding.items():
    if not word2embedding:
        raise ValueError(f"No embeddings were loaded for language {lang!r}")
    # Take the index dimensionality from the data instead of hardcoding 300.
    embedding_dim = len(next(iter(word2embedding.values())))
    faiss_index = FaissIndex(d=embedding_dim)
    # A list of (word, vector) pairs, the same shape passed for the relative indexes.
    faiss_index.add_vectors(embeddings=list(word2embedding.items()), normalize=True)
    abs_lang2faiss_index[lang] = faiss_index

In [None]:
lang2anchors = build_anchors(lang2word2embedding=lang2word2embedding, target_candidates=1000)
# Persist the selected anchors so the run can be inspected later.
(PROJECT_ROOT / "lang2anchors.json").write_text(json.dumps(lang2anchors, indent=4), encoding="utf-8")

# Anchor i of every language comes from the same synset, so all languages must
# have the same number of anchors; fail with a readable message otherwise.
anchor_counts = {lang: len(anchors) for lang, anchors in lang2anchors.items()}
assert anchor_counts, "build_anchors returned no anchors"
assert len(set(anchor_counts.values())) == 1, f"Misaligned anchors across languages: {anchor_counts}"
n_anchors: int = next(iter(anchor_counts.values()))
n_anchors

In [None]:
# Relative projection block: with values_mode=SIMILARITIES each input vector is presumably
# mapped to its (quantized) similarities with the n_anchors anchors of its language.
# The relative FAISS indexes below are built with d=n_anchors.
attention_block: RelativeAttention = RelativeAttention(
    # --- sizes ---
    n_anchors=n_anchors,
    in_features=300,  # dimensionality of the absolute word vectors (the abs indexes were built with d=300)
    hidden_features=None,
    n_classes=None,
    # --- how similarities to the anchors are computed ---
    similarity_mode=RelativeEmbeddingMethod.INNER,
    normalization_mode=NormalizationMode.L2,
    values_mode=ValuesMethod.SIMILARITIES,
    # --- quantization of the similarity values ---
    # Alternative previously tried: similarities_quantization_mode=None, similarities_bin_size=None.
    similarities_bin_size=0.5,
    similarities_quantization_mode=SimilaritiesQuantizationMode.DIFFERENTIABLE_ROUND,
    # --- features not used in this notebook ---
    n_anchors_sampling_per_class=None,
    anchors_sampling_mode=None,
    similarities_aggregation_n_groups=None,
    similarities_aggregation_mode=None,
).to(device)

In [None]:
# Anchor embeddings per language, stacked into one (n_anchors, dim) tensor each.
lang2anchor_embeddings = {
    lang: torch.from_numpy(np.stack([lang2word2embedding[lang][anchor] for anchor in anchors])).to(device)
    for lang, anchors in lang2anchors.items()
}

# Relative representation of every loaded word, indexed per language.
rel_lang2faiss_index = {}
# Inference only: avoid building an autograd graph (and a possible .numpy() failure on
# tensors that require grad).
with torch.no_grad():
    for lang, word2embedding in lang2word2embedding.items():
        words = list(word2embedding.keys())
        # Stack with numpy first: torch.tensor() on a sequence of ndarrays is very slow.
        abs_embeddings = torch.from_numpy(np.stack(list(word2embedding.values()))).to(device)
        rel_embeddings = attention_block(x=abs_embeddings, anchors=lang2anchor_embeddings[lang])[
            AttentionOutput.OUTPUT
        ]
        faiss_index = FaissIndex(d=n_anchors)
        faiss_index.add_vectors(embeddings=list(zip(words, rel_embeddings.cpu().numpy())), normalize=True)
        rel_lang2faiss_index[lang] = faiss_index

In [None]:
# Word pair to compare: the same surface form looked up in two languages.
first_lang: str = "en"
word_first_lang: str = "gel"
second_lang: str = "es"
word_second_lang: str = "gel"

# Vectors as stored in the relative indexes...
word_first_lang_rel_vector = rel_lang2faiss_index[first_lang].reconstruct(word_first_lang)
word_second_rel_vector = rel_lang2faiss_index[second_lang].reconstruct(word_second_lang)
# ...and in the absolute indexes.
word_first_lang_abs_vector = abs_lang2faiss_index[first_lang].reconstruct(word_first_lang)
word_second_abs_vector = abs_lang2faiss_index[second_lang].reconstruct(word_second_lang)

In [None]:
# Nearest neighbours of the first-language word in its own absolute space.
abs_lang2faiss_index[first_lang].search_by_keys(
    k_most_similar=10,
    query=[word_first_lang],
)

In [None]:
# Nearest neighbours of the first-language word in its own relative space.
rel_lang2faiss_index[first_lang].search_by_keys(
    k_most_similar=10,
    query=[word_first_lang],
)

In [None]:
# Nearest neighbours of the second-language word in its own relative space.
rel_lang2faiss_index[second_lang].search_by_keys(
    k_most_similar=10,
    query=[word_second_lang],
)

In [None]:
# Cross-lingual lookup: search the second language's relative index with the
# first-language word's relative vector. It is not re-normalized; it was taken from an
# index built with normalize=True.
rel_lang2faiss_index[second_lang].search_by_vectors(
    k_most_similar=10,
    normalize=False,
    query_vectors=np.array([word_first_lang_rel_vector], dtype="float32"),
)

In [None]:
# Reverse cross-lingual lookup: search the first language's relative index with the
# second-language word's relative vector. It is not re-normalized; it was taken from an
# index built with normalize=True.
rel_lang2faiss_index[first_lang].search_by_vectors(
    k_most_similar=10,
    normalize=False,
    query_vectors=np.array([word_second_rel_vector], dtype="float32"),
)

In [None]:
# Cosine similarity between the two words' relative vectors.
cosine_similarity(
    torch.tensor(word_first_lang_rel_vector),
    torch.tensor(word_second_rel_vector),
    dim=-1,
)

In [None]:
# Cosine similarity between the two words' absolute vectors.
cosine_similarity(
    torch.tensor(word_first_lang_abs_vector),
    torch.tensor(word_second_abs_vector),
    dim=-1,
)

In [None]:
# L1 distance between the two relative vectors: total, and averaged over the anchors.
rel_diff = torch.abs(torch.tensor(word_first_lang_rel_vector) - torch.tensor(word_second_rel_vector)).sum()
(rel_diff, rel_diff / n_anchors)

In [None]:
# L1 distance between the two absolute vectors: total, and averaged over the vector
# dimensions. Divide by the actual vector length instead of a hardcoded 300.
abs_diff_vector = torch.tensor(word_first_lang_abs_vector) - torch.tensor(word_second_abs_vector)
abs_diff = abs_diff_vector.abs().sum()
abs_diff, abs_diff / abs_diff_vector.numel()