In [None]:
import logging

import pandas as pd
import torch
import torch.nn.functional as F
from rae import PROJECT_ROOT
from rae.modules.enumerations import Output
from rae.pl_modules.pl_gautoencoder import LightningAutoencoder

try:
    # StrEnum is part of the standard library from Python 3.11; older versions use the backport
    from enum import StrEnum
except ImportError:
    from backports.strenum import StrEnum

import matplotlib.pyplot as plt
from tueplots import bundles
from tueplots import figsizes

# Only show ERROR (and above) records from the root logger to keep the cell outputs readable.
logging.getLogger().setLevel(logging.ERROR)

# Device on which the latents are materialised: the GPU when one is visible, otherwise the
# CPU, so that the notebook still runs end-to-end on machines without CUDA.
DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
from rae.modules.text.encoder import GensimEncoder

# One English, non-lemmatizing GensimEncoder per embedding model, keyed by the model name.
ENCODERS = dict(
    (name, GensimEncoder(language="en", lemmatize=False, model_name=name))
    for name in ("local_fasttext", "word2vec-google-news-300")
)

In [None]:
import random
from pytorch_lightning import seed_everything

# Fix the seed before the anchor sampling below.
seed_everything(4)

NUM_ANCHORS = 300  # number of anchor words sampled from the selected words
NUM_TARGETS = 200  # number of nearest neighbours retrieved for every seed word
NUM_CLUSTERS = 4  # only referenced by the commented-out random choice of TARGET_WORDS

# Candidate vocabulary: alphabetic tokens of the fastText model with at least 4 characters.
WORDS = sorted(ENCODERS["local_fasttext"].model.key_to_index.keys())
WORDS = [word for word in WORDS if word.isalpha() and len(word) >= 4]
TARGET_WORDS = ["school", "ferrari", "water", "martial"]  # words to take the neighborhoods from
# TARGET_WORDS = random.sample(WORDS, NUM_CLUSTERS)
print(f"{TARGET_WORDS=}")
word2index = {word: i for i, word in enumerate(WORDS)}

# For every seed word, its NUM_TARGETS most similar words according to the fastText model.
target_clusters = [
    [word for word, _similarity in ENCODERS["local_fasttext"].model.most_similar(target_word, topn=NUM_TARGETS)]
    for target_word in TARGET_WORDS
]

# Flatten the clusters into parallel lists: each kept word and its 1-based cluster id.
# The seed word always opens its own cluster; neighbours are kept only when they belong
# to the filtered candidate vocabulary.
valid_words, valid_targets = [], []
for cluster_id, (target_word, cluster_words) in enumerate(zip(TARGET_WORDS, target_clusters), start=1):
    valid_words.append(target_word)
    valid_targets.append(cluster_id)
    for word in cluster_words:
        if word in word2index:
            valid_words.append(word)
            valid_targets.append(cluster_id)

# From here on WORDS/TARGETS only describe the selected neighbourhoods.
WORDS = valid_words
TARGETS = valid_targets

if NUM_ANCHORS > len(WORDS):
    # random.sample would raise the same exception type, but without saying which knob to turn.
    raise ValueError(f"Cannot sample {NUM_ANCHORS} anchors out of {len(WORDS)} selected words")

ANCHOR_WORDS = sorted(random.sample(WORDS, NUM_ANCHORS))  # TODO: stratified

ANCHOR_WORDS[:10]

In [None]:
# from umap import UMAP
from sklearn.manifold import TSNE
from enum import auto
from sklearn.decomposition import PCA


def latents_distance(latents):
    """Mean L2 distance between the corresponding rows of two latent matrices.

    Args:
        latents: sequence containing exactly two 2-D tensors with the same number of
            features (second dimension).

    Returns:
        The mean row-wise distance, formatted as a string with two decimals.

    Raises:
        AssertionError: if ``latents`` does not hold exactly two 2-D tensors, or if their
            feature dimensions differ.
    """
    assert len(latents) == 2, f"Expected exactly two latent spaces, got {len(latents)}"
    first, second = latents
    for x in (first, second):
        assert x.ndim == 2, f"Expected 2-D latents, got shape {tuple(x.shape)}"
    # Compare the two spaces with each other instead of against a hard-coded size, so the
    # check holds for any embedding size and any number of anchors.
    assert (
        first.shape[1] == second.shape[1]
    ), f"Feature dimensions differ: {first.shape[1]} != {second.shape[1]}"

    dist = F.pairwise_distance(first, second, p=2).mean()
    return f"{dist:.2f}"


def get_latents(words, encoder: GensimEncoder):
    """Look up the embedding of every word and stack them into one tensor on ``DEVICE``.

    Args:
        words: iterable of words passed one by one to ``encoder.model.get_vector``.
        encoder: encoder whose ``model`` exposes ``get_vector(word)``.

    Returns:
        A tensor on ``DEVICE`` with one row per word, in the order of ``words``.
    """
    import numpy as np  # local import: numpy is not imported in the visible cells of the notebook

    vectors = [encoder.model.get_vector(word) for word in words]
    if not vectors:
        # Nothing to stack: return the same empty tensor that `torch.tensor([])` produces.
        return torch.tensor(vectors, device=DEVICE)

    # Stack into a single contiguous array first: building a tensor directly from a list of
    # numpy arrays is the slow path PyTorch warns about.
    # NOTE(review): assumes `get_vector` returns numpy arrays (as gensim KeyedVectors do) — confirm.
    latents = torch.tensor(np.stack(vectors), device=DEVICE)
    return latents


class Reduction(StrEnum):
    PCA = auto()
    TSNE = auto()
    UMAP = auto()
    FIRST_DIMS = auto()


def to_df(latents, mode: Reduction = Reduction.PCA, targets=None):
    """Project ``latents`` to 2-D and pair every point with its cluster label.

    Args:
        latents: 2-D tensor with one row per word; it is moved to the CPU before being
            handed to scikit-learn.
        mode: projection to apply. ``PCA`` and ``TSNE`` fit the corresponding scikit-learn
            estimator, ``FIRST_DIMS`` keeps the first two features unchanged.
        targets: labels aligned with the rows of ``latents``. Defaults to the notebook-level
            ``TARGETS`` so that existing calls keep working.

    Returns:
        A DataFrame with the columns ``x``, ``y`` and ``target``.

    Raises:
        NotImplementedError: if ``mode`` is not one of the supported reductions
            (``UMAP`` is declared in the enum but currently disabled).
    """
    if targets is None:
        # Resolved at call time, so the labels defined by the data-preparation cell are used.
        targets = TARGETS

    if mode == Reduction.PCA:
        latents2d = PCA(n_components=2).fit_transform(latents.cpu())
    elif mode == Reduction.FIRST_DIMS:
        latents2d = latents[:, [0, 1]]
    elif mode == Reduction.TSNE:
        latents2d = TSNE(n_components=2, init="pca", learning_rate="auto", random_state=42).fit_transform(latents.cpu())
    # elif mode == Reduction.UMAP:
    #     latents2d = UMAP(n_components=2).fit_transform(latents.cpu())
    else:
        raise NotImplementedError(f"Unsupported reduction mode: {mode!r}")

    df = pd.DataFrame(
        {
            "x": latents2d[:, 0].tolist(),
            "y": latents2d[:, 1].tolist(),
            "target": targets,
        }
    )
    return df

In [None]:
def plot_bg(ax, df, cmap, norm, size, bg_alpha):
    """Scatter every row of ``df`` on ``ax`` as a background layer and return ``ax``.

    The points are drawn with opacity ``bg_alpha`` and marker size ``size``; their colors
    come from applying ``norm`` and then ``cmap`` to the ``target`` column. Clusters can be
    highlighted by plotting them again on top of this layer.
    """
    point_colors = cmap(norm(df["target"]))
    ax.scatter(df["x"], df["y"], c=point_colors, alpha=bg_alpha, s=size)
    return ax


def hightlight_cluster(ax, df, target, alpha, cmap, norm, size=0.5):
    """Scatter on ``ax`` only the rows of ``df`` whose ``target`` column equals ``target``.

    Colors are obtained by applying ``norm`` and then ``cmap`` to the selected labels;
    ``alpha`` and ``size`` set the opacity and the marker size. Returns nothing.
    """
    in_cluster = df["target"] == target
    selected = df.loc[in_cluster]
    point_colors = cmap(norm(selected["target"]))
    ax.scatter(selected["x"], selected["y"], c=point_colors, alpha=alpha, s=size)


def plot_latent_space(ax, df, targets, size, cmap, norm, bg_alpha, alpha):
    """Draw all points of ``df`` as background, then re-draw each cluster in ``targets`` on top.

    The background layer uses opacity ``bg_alpha`` and every highlighted cluster uses
    ``alpha``; ``size``, ``cmap`` and ``norm`` are shared by both layers. Returns the axis
    handed back by ``plot_bg``.
    """
    shared_style = dict(size=size, cmap=cmap, norm=norm)
    ax = plot_bg(ax, df, bg_alpha=bg_alpha, **shared_style)
    for cluster_id in targets:
        hightlight_cluster(ax, df, cluster_id, alpha=alpha, **shared_style)
    return ax

In [None]:
# Per encoder: the embeddings of all selected words and of the anchor words only.
ae_latents = {enc_name: get_latents(words=WORDS, encoder=encoder) for enc_name, encoder in ENCODERS.items()}
anchors_latents = {
    enc_name: get_latents(words=ANCHOR_WORDS, encoder=encoder) for enc_name, encoder in ENCODERS.items()
}

import copy

# Independent deep copies of both latent dictionaries, taken right after they are computed.
original_ae_latents, original_anchor_latents = (
    copy.deepcopy(latents_by_encoder) for latents_by_encoder in (ae_latents, anchors_latents)
)

In [None]:
# 2-D projection applied to the latents before plotting; it is also part of the exported file name.
REDUCTION = Reduction.PCA

In [None]:
from sklearn.decomposition import PCA
from rae.modules.attention import *


import matplotlib as mpl

# Every entry adds one row of relative-representation plots. The (quant_mode, bin_size)
# pair is only used to label the matching entry of `distances`.
col_config = ((None, None),)

# Grid layout actually drawn below: one column per encoder, row 0 for the absolute latents
# and row i + 1 for the relative latents of the i-th `col_config` entry. Deriving the grid
# from that layout (instead of relying on both sides happening to equal 2) keeps the figure
# unchanged for the current setup and correct when encoders or configurations are added.
N_ROWS = len(col_config) + 1
N_COLS = len(ENCODERS)
print(
    N_ROWS,
    N_COLS,
)
plt.rcParams.update(bundles.icml2022())
plt.rcParams.update(figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=1.0))

# One color per cluster, taken from "Set1" resampled to 10 entries.
num_colors = len(TARGET_WORDS)
try:
    # Colormap registry API; `plt.cm.get_cmap` is deprecated and removed in recent Matplotlib.
    base_cmap = mpl.colormaps["Set1"].resampled(10)
except AttributeError:
    # Older Matplotlib without `mpl.colormaps` / `Colormap.resampled`.
    base_cmap = plt.cm.get_cmap("Set1", 10)
cmap = mpl.colors.ListedColormap(base_cmap.colors[:num_colors], name="rgb", N=num_colors)
norm = plt.Normalize(min(TARGETS), max(TARGETS))

# squeeze=False keeps `axes` two-dimensional even when there is a single row or column.
fig, axes = plt.subplots(dpi=300, nrows=N_ROWS, ncols=N_COLS, sharey=True, sharex=True, squeeze=False)

S = 7
BG_ALPHA = 0.35
ALPHA = 0.5

TARGET_HIGHLIGHT = [1]

# Row 0: absolute latent space of every encoder.
for ax_encoder, latents in zip(axes[0], ae_latents.values()):
    plot_latent_space(
        ax_encoder,
        to_df(latents, mode=REDUCTION),
        targets=TARGET_HIGHLIGHT,
        size=S,
        bg_alpha=BG_ALPHA,
        alpha=ALPHA,
        cmap=cmap,
        norm=norm,
    )

# NOTE: latents_distance compares exactly two spaces, i.e. it expects two encoders.
distances = {"absolute": latents_distance(list(ae_latents.values()))}

# Rows 1..: relative representations, one row per `col_config` entry.
for row_i, (quant_mode, bin_size) in enumerate(col_config, start=1):
    rel_attention = RelativeAttention(
        n_anchors=NUM_ANCHORS,
        n_classes=len(set(TARGETS)),
        similarity_mode=RelativeEmbeddingMethod.INNER,
        values_mode=ValuesMethod.SIMILARITIES,
        normalization_mode=NormalizationMode.L2,
    )
    # The projection must be parameter-free: nothing is trained in this notebook.
    assert sum(x.numel() for x in rel_attention.parameters()) == 0
    rels = []
    for ax_encoder, (enc_name, latents), (a_enc_name, a_latents) in zip(
        axes[row_i], ae_latents.items(), anchors_latents.items()
    ):
        # Both dictionaries must be iterated in the same encoder order.
        assert enc_name == a_enc_name
        rel = rel_attention(x=latents, anchors=a_latents)[AttentionOutput.SIMILARITIES]
        rels.append(rel)
        plot_latent_space(
            ax_encoder,
            to_df(rel, mode=REDUCTION),
            targets=TARGET_HIGHLIGHT,
            size=S,
            bg_alpha=BG_ALPHA,
            alpha=ALPHA,
            cmap=cmap,
            norm=norm,
        )
    distances[f"relative({quant_mode}, {bin_size})"] = latents_distance(rels)

distances

In [None]:
file_name: str = f"word-embeddings-qualitative-{REDUCTION}"
file_name_pdf: str = f"{file_name}.pdf"
file_name_svg: str = f"{file_name}.svg"

fig.savefig(f"{file_name}.svg", bbox_inches="tight")
!rsvg-convert -f pdf -o $file_name_pdf $file_name_svg
!rm $file_name_svg