In [None]:
# Enable IPython's autoreload extension in mode 2, so that imported modules
# are reloaded before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
import logging

import pandas as pd
import torch
import torch.nn.functional as F
from rae import PROJECT_ROOT
from rae.modules.enumerations import Output
from rae.pl_modules.pl_gautoencoder import LightningAutoencoder

try:
    # `enum.StrEnum` is part of the standard library starting with Python 3.11
    from enum import StrEnum
except ImportError:
    # older interpreters: fall back to the `backports.strenum` package
    from backports.strenum import StrEnum

import matplotlib.pyplot as plt
from tueplots import bundles
from tueplots import figsizes

# Only show log records of level ERROR or above from the root logger.
logging.getLogger().setLevel(logging.ERROR)

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA (a hard-coded "cuda" makes `get_latents` fail there).
DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
from rae.modules.text.encoder import GensimEncoder

# One GensimEncoder per model name; the model name doubles as the dictionary
# key that the remaining cells use to look the encoder up.
ENCODERS = dict(
    (name, GensimEncoder(language="en", lemmatize=False, model_name=name))
    for name in ["local_fasttext", "word2vec-google-news-300"]
)

In [None]:
# Later cells take the word list from one encoder and look the same words up in
# every encoder, so all vocabularies have to be identical.
assert len({frozenset(encoder.model.key_to_index.keys()) for encoder in ENCODERS.values()}) == 1, (
    "all encoders are expected to expose exactly the same vocabulary"
)

In [None]:
import random
from pytorch_lightning import seed_everything

# Fix the RNG seeds so that the anchor sample drawn below is reproducible.
seed_everything(4)

NUM_ANCHORS = 300
NUM_TARGETS = 200
NUM_CLUSTERS = 4
WORDS = sorted(ENCODERS["local_fasttext"].model.key_to_index.keys())
WORDS = [word for word in WORDS if word.isalpha() and len(word) >= 4]
TARGET_WORDS = ["school", "ferrari", "water", "martial"]  # words to take the neighborhoods from
# TARGET_WORDS = random.sample(WORDS, NUM_CLUSTERS)
print(f"{TARGET_WORDS=}")
word2index = {word: i for i, word in enumerate(WORDS)}

# For every target word, its NUM_TARGETS most similar words in the fastText space.
target_clusters = [
    [word for word, sim in ENCODERS["local_fasttext"].model.most_similar(target_word, topn=NUM_TARGETS)]
    for target_word in TARGET_WORDS
]

# Flatten the clusters into parallel lists: each kept word and its 1-based cluster id.
# The target word itself always opens its cluster; neighbours are kept only when
# they survived the alphabetic / minimum-length filter above.
valid_words, valid_targets = [], []
for cluster_id, (target_word, cluster_words) in enumerate(zip(TARGET_WORDS, target_clusters), start=1):
    valid_words.append(target_word)
    valid_targets.append(cluster_id)
    for word in cluster_words:
        if word in word2index:
            valid_words.append(word)
            valid_targets.append(cluster_id)

# From here on WORDS / TARGETS describe only the clustered subset of the vocabulary.
WORDS = valid_words
TARGETS = valid_targets

ANCHOR_WORDS = sorted(random.sample(WORDS, NUM_ANCHORS))  # TODO: stratified

ANCHOR_WORDS[:10]

In [None]:
from sklearn.decomposition import PCA


def latents_distance(latents):
    """Return the mean row-wise Euclidean distance between two latent spaces.

    Args:
        latents: sequence of exactly two tensors; each must have 300 columns
            (checked through ``shape[1]``).

    Returns:
        The mean L2 distance between corresponding rows, formatted as a string
        with two decimals.
    """
    assert len(latents) == 2
    assert all(space.shape[1] == 300 for space in latents)

    first, second = latents
    mean_distance = F.pairwise_distance(first, second, p=2).mean()
    return "{:.2f}".format(mean_distance)


def get_latents(words, encoder: GensimEncoder):
    """Look up the vector of every word and stack them into one tensor on ``DEVICE``.

    Args:
        words: iterable of words; each one is passed to ``encoder.model.get_vector``.
        encoder: encoder whose ``model`` provides the word vectors.

    Returns:
        A tensor on ``DEVICE`` with one row per word, in the order of ``words``.
    """
    import numpy as np  # local import: numpy is only imported further down in the notebook

    vectors = [encoder.model.get_vector(word) for word in words]
    if not vectors:
        # np.stack rejects an empty list; keep returning an empty tensor as before
        return torch.tensor(vectors, device=DEVICE)
    # Stack into a single ndarray first: building a tensor straight from a Python
    # list of ndarrays is the slow path that PyTorch warns about.
    return torch.tensor(np.stack(vectors), device=DEVICE)


def to_df(latents, fit_pca: bool = True):
    """Build the 2-D plotting frame for a latent space.

    Args:
        latents: 2-D tensor with one row per word.
        fit_pca: when True, project the latents to two dimensions with a freshly
            fitted PCA; otherwise simply keep the first two columns.

    Returns:
        DataFrame with the columns ``x``, ``y`` and ``target``; ``target`` is
        taken from the module-level ``TARGETS``.
    """
    projected = PCA(n_components=2).fit_transform(latents.cpu()) if fit_pca else latents[:, [0, 1]]
    xs = projected[:, 0].tolist()
    ys = projected[:, 1].tolist()
    return pd.DataFrame({"x": xs, "y": ys, "target": TARGETS})

# Plotting helpers

In [None]:
def plot_bg(
    ax,
    df,
    cmap,
    norm,
    size,
    bg_alpha,
):
    """Scatter every point of ``df`` on ``ax`` as the background layer.

    The points are coloured by ``cmap(norm(df["target"]))`` and drawn with
    opacity ``bg_alpha`` and marker size ``size``. Drawing the whole space
    first gives context to the clusters highlighted on top of it afterwards.

    Returns:
        The same ``ax`` that was passed in.
    """
    point_colors = cmap(norm(df["target"]))
    ax.scatter(df["x"], df["y"], c=point_colors, alpha=bg_alpha, s=size)
    return ax


def hightlight_cluster(
    ax,
    df,
    target,
    alpha,
    cmap,
    norm,
    size=0.5,
):
    """Scatter only the rows of ``df`` whose ``target`` equals ``target``.

    The selected points are coloured by ``cmap(norm(...))`` of their target
    value and drawn on ``ax`` with opacity ``alpha`` and marker size ``size``.
    Nothing is returned.
    """
    in_cluster = df["target"] == target
    selected = df.loc[in_cluster]
    selected_colors = cmap(norm(selected["target"]))
    ax.scatter(selected["x"], selected["y"], c=selected_colors, alpha=alpha, s=size)


def plot_latent_space(ax, df, targets, size, cmap, norm, bg_alpha, alpha):
    """Draw a latent space: every point as background, then the chosen clusters on top.

    Args:
        ax: axes to draw on.
        df: frame with ``x``, ``y`` and ``target`` columns.
        targets: iterable of target values whose points are drawn again on top.
        size: marker size shared by both layers.
        cmap, norm: colour map and normalisation applied to the target values.
        bg_alpha: opacity of the background layer.
        alpha: opacity of the highlighted clusters.

    Returns:
        The axes returned by ``plot_bg``.
    """
    shared_style = dict(size=size, cmap=cmap, norm=norm)
    ax = plot_bg(ax, df, bg_alpha=bg_alpha, **shared_style)
    for highlighted in targets:
        hightlight_cluster(ax, df, highlighted, alpha=alpha, **shared_style)
    return ax

## AE

In [None]:
# Per encoder: the vectors of every clustered word and of the anchor words.
ae_latents = {name: get_latents(words=WORDS, encoder=enc) for name, enc in ENCODERS.items()}
anchors_latents = {name: get_latents(words=ANCHOR_WORDS, encoder=enc) for name, enc in ENCODERS.items()}

import copy

# Independent copies of the freshly extracted latents.
original_ae_latents = copy.deepcopy(ae_latents)
original_anchor_latents = copy.deepcopy(anchors_latents)

## Relative Attention — No Quantization

In [None]:
from sklearn.decomposition import PCA
from rae.modules.attention import *


# A single configuration: relative representations without any quantization.
col_config = ((None, None),)
# Grid layout: one column per encoder. Row 0 shows the absolute spaces and row
# i + 1 the relative spaces obtained with col_config[i]. (Sizing the grid as
# encoders x configurations while indexing it the other way round only worked
# because both happened to be 2.)
N_ROWS = len(col_config) + 1
N_COLS = len(ENCODERS)
print(
    N_ROWS,
    N_COLS,
)
plt.rcParams.update(bundles.icml2022())
plt.rcParams.update(figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=1.0))

import matplotlib as mpl

# One colour per target word, taken from the first entries of the "Set1" palette.
num_colors = len(TARGET_WORDS)
cmap = mpl.colors.ListedColormap(plt.cm.get_cmap("Set1", 10).colors[:num_colors], name="rgb", N=num_colors)
norm = plt.Normalize(min(TARGETS), max(TARGETS))

# squeeze=False keeps `axes` two-dimensional even with a single encoder.
fig, axes = plt.subplots(dpi=300, nrows=N_ROWS, ncols=N_COLS, sharey=True, sharex=True, squeeze=False)

S = 7
BG_ALPHA = 0.35
ALPHA = 0.5

TARGETS_HIGHTLIGHT = [1]

# Row 0: the absolute latent space of every encoder.
for ax_encoders, (_, latents) in zip(axes[0], ae_latents.items()):

    plot_latent_space(
        ax_encoders,
        to_df(latents),
        targets=TARGETS_HIGHTLIGHT,
        size=S,
        bg_alpha=BG_ALPHA,
        alpha=ALPHA,
        cmap=cmap,
        norm=norm,
    )

distances = {"absolute": latents_distance(list(ae_latents.values()))}

# Rows 1..: the relative representation of every encoder, one row per configuration.
for config_i, (quant_mode, bin_size) in enumerate(col_config):
    rel_attention = RelativeAttention(
        n_anchors=NUM_ANCHORS,
        n_classes=len(set(TARGETS)),
        similarity_mode=RelativeEmbeddingMethod.INNER,
        values_mode=ValuesMethod.SIMILARITIES,
        normalization_mode=NormalizationMode.L2,
    )
    # the relative projection must not contain trainable parameters
    assert sum(x.numel() for x in rel_attention.parameters()) == 0
    rels = []
    for row_ax, (enc_name, latents), (a_enc_name, a_latents) in zip(
        axes[config_i + 1], ae_latents.items(), anchors_latents.items()
    ):
        assert enc_name == a_enc_name
        rel = rel_attention(x=latents, anchors=a_latents)[AttentionOutput.SIMILARITIES]
        rels.append(rel)
        plot_latent_space(
            row_ax,
            to_df(rel),
            targets=TARGETS_HIGHTLIGHT,
            size=S,
            bg_alpha=BG_ALPHA,
            alpha=ALPHA,
            cmap=cmap,
            norm=norm,
        )
    distances[f"relative({quant_mode}, {bin_size})"] = latents_distance(rels)

distances

In [None]:
# Save the figure as SVG, convert it to PDF with `rsvg-convert`, then delete
# the intermediate SVG file.
fig.savefig("word-embeddings-spaces-no-quant.svg", bbox_inches="tight")
!rsvg-convert -f pdf -o 'word-embeddings-spaces-no-quant.pdf' 'word-embeddings-spaces-no-quant.svg'
!rm 'word-embeddings-spaces-no-quant'.svg

## Relative Attention — Quantized

In [None]:
from sklearn.decomposition import PCA
from rae.modules.attention import *


# Each entry is (quantization mode, bin size / number of clusters) and produces
# one column of the figure; (None, None) is the un-quantized baseline.
col_config = (
    (None, None),
    #     (SimilaritiesQuantizationMode.DIFFERENTIABLE_ROUND, 0.0001),
    #     (SimilaritiesQuantizationMode.DIFFERENTIABLE_ROUND, 0.05),
    #     (SimilaritiesQuantizationMode.DIFFERENTIABLE_ROUND, 0.1),
    #     (SimilaritiesQuantizationMode.DIFFERENTIABLE_ROUND, 0.3),
    ("cluster", 1),
    #     ("cluster", 0.5),
    ("cluster", 1.5),
    ("cluster", 2),
    # ('kmeans', 3),
    #     (SimilaritiesQuantizationMode.DIFFERENTIABLE_ROUND, 0.6),
    #     (SimilaritiesQuantizationMode.DIFFERENTIABLE_ROUND, 0.7),
    #     (SimilaritiesQuantizationMode.DIFFERENTIABLE_ROUND, 0.8),
    #    (SimilaritiesQuantizationMode.DIFFERENTIABLE_ROUND, 0.9),
)
# Grid layout: one row per encoder; column 0 shows the absolute space and
# column i + 1 the relative space obtained with col_config[i].
N_ROWS = len(ENCODERS)
N_COLS = len(col_config) + 1

plt.rcParams.update(bundles.icml2022())
plt.rcParams.update(figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=1.0))

import matplotlib as mpl

# One colour per target word, taken from the first entries of the "Set1" palette.
# NOTE(review): `plt.cm.get_cmap` is deprecated in recent Matplotlib releases — confirm the pinned version.
num_colors = len(TARGET_WORDS)
cmap = mpl.colors.ListedColormap(plt.cm.get_cmap("Set1", 10).colors[:num_colors], name="rgb", N=num_colors)
# cmap = plt.cm.get_cmap("Set1", 5)
norm = plt.Normalize(min(TARGETS), max(TARGETS))

fig, axes = plt.subplots(dpi=150, nrows=N_ROWS, ncols=N_COLS, sharey=True, sharex=True, squeeze=True)


TARGETS_HIGHTLIGHT = [1]
# Column 0: the absolute latent space of every encoder (one encoder per row).
for ax_encoders, (_, latents) in zip(axes, ae_latents.items()):

    plot_latent_space(
        ax_encoders[0],
        to_df(latents),
        targets=TARGETS_HIGHTLIGHT,
        size=0.75,
        bg_alpha=0.25,
        alpha=1,
        cmap=cmap,
        norm=norm,
    )

# Mean distance between the two encoders' spaces, first for the absolute latents.
distances = {"absolute": latents_distance(list(ae_latents.values()))}

# Columns 1..: the relative representation under each quantization configuration.
for col_i, (quant_mode, bin_size) in enumerate(col_config):
    # NOTE(review): `bin_size` is passed both as the bin size and as the number
    # of clusters; presumably only one of them is used per mode — confirm.
    rel_attention = RelativeAttention(
        n_anchors=NUM_ANCHORS,
        n_classes=len(set(TARGETS)),
        similarity_mode=RelativeEmbeddingMethod.INNER,
        values_mode=ValuesMethod.SIMILARITIES,
        normalization_mode=NormalizationMode.L2,
        #  output_normalization_mode=OutputNormalization.L2,
        #         similarities_quantization_mode=quant_mode,
        #         similarities_bin_size=bin_size,
        #         similarities_num_clusters=bin_size,
        absolute_quantization_mode=quant_mode,
        absolute_bin_size=bin_size,
        absolute_num_clusters=bin_size,
    )
    # the relative projection must not contain trainable parameters
    assert sum(x.numel() for x in rel_attention.parameters()) == 0
    rels = []
    for row_axes, (enc_name, latents), (a_enc_name, a_latents) in zip(
        axes, ae_latents.items(), anchors_latents.items()
    ):
        # latents and anchors must come from the same encoder
        assert enc_name == a_enc_name
        rel = rel_attention(x=latents, anchors=a_latents)[AttentionOutput.SIMILARITIES]
        rels.append(rel)
        plot_latent_space(
            row_axes[col_i + 1],
            to_df(rel),
            targets=TARGETS_HIGHTLIGHT,
            size=0.75,
            bg_alpha=0.25,
            alpha=1,
            cmap=cmap,
            norm=norm,
        )
    # distance between the two encoders' relative spaces for this configuration
    distances[f"relative({quant_mode}, {bin_size})"] = latents_distance(rels)

distances

In [None]:
# Save the figure as SVG, convert it to PDF with `rsvg-convert`, then delete
# the intermediate SVG file.
fig.savefig("word-embeddings-spaces.svg", bbox_inches="tight")
!rsvg-convert -f pdf -o 'word-embeddings-spaces.pdf' 'word-embeddings-spaces.svg'
!rm 'word-embeddings-spaces'.svg

In [None]:
# Candidate vocabulary for the retrieval experiment, read from the first encoder.
# NOTE(review): the [400:] offset presumably skips the most frequent tokens — confirm.
random_words = list(next(iter(ENCODERS.values())).model.key_to_index.keys())[400:]
random_words = [word for word in random_words if word.isalpha() and len(word) >= 4]
# random.shuffle(random_words)
SEARCH_WORDS = random_words[:20_000]
# Show the size and a short preview only, instead of dumping up to 20,000 words
# into the cell output.
len(SEARCH_WORDS), SEARCH_WORDS[:10]

# Faiss Search

In [None]:
# Number of anchors for the relative representations of the retrieval experiment.
RETRIEVAL_ANCHORS_NUM = 300
# Draw the anchors at random from the search vocabulary and keep them sorted.
RETRIEVAL_ANCHORS = random.sample(SEARCH_WORDS, RETRIEVAL_ANCHORS_NUM)
RETRIEVAL_ANCHORS.sort()

In [None]:
from rae.openfaiss import FaissIndex

# index[enc_type][enc_name]: "absolute" indices have dimension 300, "relative"
# indices have one dimension per retrieval anchor.
enc_type2enc_name2faiss_index = {
    enc_type: {
        enc_name: FaissIndex(d=300 if enc_type == "absolute" else RETRIEVAL_ANCHORS_NUM) for enc_name in ENCODERS.keys()
    }
    for enc_type in ("absolute", "relative")
}


for enc_name, encoder in ENCODERS.items():
    print(enc_name)
    # Absolute space: index the encoder's own vectors of the search words.
    latents = encoder.model.vectors_for_all(keys=SEARCH_WORDS).vectors
    enc_type2enc_name2faiss_index["absolute"][enc_name].add_vectors(
        embeddings=list(zip(SEARCH_WORDS, latents)), normalize=True
    )

    rel_attention = RelativeAttention(
        # Must equal the number of anchors passed below and the dimension of the
        # relative index above; NUM_ANCHORS (from the plotting experiment) only
        # worked here because both constants happened to be 300.
        n_anchors=RETRIEVAL_ANCHORS_NUM,
        n_classes=len(set(TARGETS)),
        similarity_mode=RelativeEmbeddingMethod.INNER,
        values_mode=ValuesMethod.SIMILARITIES,
        normalization_mode=NormalizationMode.L2,
        #  output_normalization_mode=OutputNormalization.L2,
        #          similarities_quantization_mode='differentiable_round',
        #          similarities_bin_size=0.01,
        #          similarities_num_clusters=,
        #         absolute_quantization_mode="cluster",
        #         absolute_bin_size=2,  # ignored
        #         absolute_num_clusters=2,
    )
    # Relative space: similarities of every search word to the retrieval anchors,
    # computed on the CPU.
    anchors = get_latents(words=RETRIEVAL_ANCHORS, encoder=encoder)
    latents = torch.as_tensor(latents, dtype=torch.float32)
    relative_representation = rel_attention(x=latents, anchors=anchors.cpu())

    enc_type2enc_name2faiss_index["relative"][enc_name].add_vectors(
        embeddings=list(zip(SEARCH_WORDS, relative_representation[AttentionOutput.SIMILARITIES].cpu().numpy())),
        normalize=True,
    )

In [None]:
# Display the nested {enc_type: {enc_name: FaissIndex}} mapping built above.
enc_type2enc_name2faiss_index

In [None]:
import itertools
import numpy as np

K = 5
# topk[enc_type][(source encoder, target encoder)][word]: the K nearest entries
# of the target encoder's index for the source encoder's vector of `word`.
enc_type2enc_names2word2topk = {"absolute": {}, "relative": {}}

for enc_type, enc_name2faiss_index in enc_type2enc_name2faiss_index.items():
    encoder_pairs = itertools.product(enc_name2faiss_index.keys(), repeat=2)
    for enc_name1, enc_name2 in encoder_pairs:
        source_index = enc_name2faiss_index[enc_name1]
        target_index = enc_name2faiss_index[enc_name2]

        # Read the stored source vectors back and use them as queries against the target index.
        query_vectors = np.asarray(source_index.reconstruct_n(SEARCH_WORDS), dtype="float32")
        neighbors_per_word = target_index.search_by_vectors(
            query_vectors=query_vectors, k_most_similar=K, normalize=False
        )

        enc_type2enc_names2word2topk[enc_type][(enc_name1, enc_name2)] = dict(zip(SEARCH_WORDS, neighbors_per_word))

In [None]:
# `performance_df` is built and displayed by the next cell. Referencing it here
# raised a NameError whenever the notebook was run top to bottom on a fresh kernel.

In [None]:
def _reciprocal_rank(word, ranked_words):
    """Return ``1 / rank`` (1-based) of ``word`` within ``ranked_words``, or 0 when it is absent."""
    for rank, candidate in enumerate(ranked_words, start=1):
        if candidate == word:
            return 1 / rank
    return 0


# One row per (encoding type, source encoder, target encoder) combination.
performance = {key: [] for key in ("src_enc", "tgt_enc", "enc_type", "topk_jaccard", "mrr", "semantic_horizon")}

for enc_type, enc_names2word2topk in enc_type2enc_names2word2topk.items():
    for (enc_name1, enc_name2), word2topk in enc_names2word2topk.items():
        # Reference neighbourhoods: the target encoder queried with its own vectors.
        reference_word2topk = enc_names2word2topk[(enc_name2, enc_name2)]
        # The semantic horizon is always scored against the absolute space of the target encoder.
        absolute_word2topk = enc_type2enc_names2word2topk["absolute"][(enc_name2, enc_name2)]

        jaccard_scores, reciprocal_ranks, horizon_scores = [], [], []
        for search_word in SEARCH_WORDS:
            retrieved = word2topk[search_word].keys()
            target_set = set(reference_word2topk[search_word].keys())
            actual_set = set(retrieved)

            # Overlap between the cross-encoder neighbours and the reference neighbours.
            jaccard_scores.append(len(target_set & actual_set) / len(target_set | actual_set))

            # Reciprocal rank of the query word itself among its retrieved neighbours.
            reciprocal_ranks.append(_reciprocal_rank(search_word, retrieved))

            # For every retrieved neighbour: reciprocal rank of the query word inside
            # that neighbour's own absolute neighbourhood, averaged over the neighbours.
            horizon_scores.append(
                np.mean(
                    [_reciprocal_rank(search_word, absolute_word2topk[neighbor].keys()) for neighbor in actual_set]
                )
            )

        topk_jaccard = np.mean(jaccard_scores)
        mrr = np.mean(reciprocal_ranks)
        semantic_horizon = np.mean(horizon_scores)

        performance["src_enc"].append(enc_name1)
        performance["tgt_enc"].append(enc_name2)
        performance["enc_type"].append(enc_type)
        performance["topk_jaccard"].append(topk_jaccard)
        performance["mrr"].append(mrr)
        performance["semantic_horizon"].append(semantic_horizon)

performance_df = pd.DataFrame(performance)
# Persist the table next to the experiment as a tab-separated file.
performance_df.to_csv(
    PROJECT_ROOT / "experiments" / "fig:latent-rotation-comparison" / "semantic_horizon.tsv", sep="\t"
)
performance_df

In [None]:
# Absolute space: top-K entries of the word2vec index for the fastText vector of "student".
enc_type2enc_names2word2topk["absolute"][("local_fasttext", "word2vec-google-news-300")]["student"]

In [None]:
# Relative space: top-K entries of the word2vec index for the fastText representation of "student".
enc_type2enc_names2word2topk["relative"][("local_fasttext", "word2vec-google-news-300")]["student"]

In [None]:
# Absolute space, same-encoder reference: top-K entries of the fastText index for its own "student" vector.
enc_type2enc_names2word2topk["absolute"][("local_fasttext", "local_fasttext")]["student"]

In [None]:
# Relative space, same-encoder reference: top-K entries of the fastText index for its own "student" representation.
enc_type2enc_names2word2topk["relative"][("local_fasttext", "local_fasttext")]["student"]