In [None]:
% load_ext autoreload
% autoreload 2

In [None]:
import logging

import pandas as pd
import torch
import torch.nn.functional as F
from rae.modules.enumerations import Output
from rae.pl_modules.pl_gautoencoder import LightningAutoencoder

try:
    # StrEnum is part of the standard library from Python 3.11; older interpreters use the backport
    from enum import StrEnum
except ImportError:
    from backports.strenum import StrEnum

import matplotlib.pyplot as plt
from tueplots import bundles
from tueplots import figsizes

# Only let ERROR (and above) records through the root logger.
logging.getLogger().setLevel(logging.ERROR)

# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
from rae.modules.text.encoder import GensimEncoder

# One GensimEncoder per embedding model name; the order here is the order of
# the subplots drawn later on.
_MODEL_NAMES = (
    "local_fasttext",
    "word2vec-google-news-300",
    "glove-wiki-gigaword-300",
)
ENCODERS = [GensimEncoder(language="en", lemmatize=False, model_name=name) for name in _MODEL_NAMES]

In [None]:
# Sanity check: all encoders must expose exactly the same set of vocabulary keys.
vocabularies = {frozenset(encoder.model.key_to_index) for encoder in ENCODERS}
assert len(vocabularies) == 1

In [None]:
import random

NUM_ANCHORS = 500
NUM_TARGETS = 1000
RANDOM_SEED = 42  # change this to sample a different set of target/anchor words

# Seed the stdlib RNG so the sampled target and anchor words are identical on
# every "Restart Kernel -> Run All".
random.seed(RANDOM_SEED)

# Candidate vocabulary: purely alphabetic keys with at least four characters.
WORDS = sorted(ENCODERS[0].model.key_to_index.keys())
WORDS = [word for word in WORDS if word.isalpha() and len(word) >= 4]

# Words whose neighbourhoods (according to the first encoder) form the clusters.
TARGET_WORDS = random.sample(WORDS, 3)
print(f"{TARGET_WORDS=}")
word2index = {word: i for i, word in enumerate(WORDS)}

# For each target word, its NUM_TARGETS most similar words in the first encoder.
target_clusters = [
    [word for word, _similarity in ENCODERS[0].model.most_similar(target_word, topn=NUM_TARGETS)]
    for target_word in TARGET_WORDS
]

# Flatten the clusters into parallel lists of words and 1-based cluster labels.
# Neighbours that did not survive the vocabulary filter above are dropped.
# NOTE(review): a word that is a neighbour of several targets is appended once
# per cluster, so WORDS (and therefore ANCHOR_WORDS) may contain duplicates.
valid_words, valid_targets = [], []
for i, target_cluster in enumerate(target_clusters):
    valid_words.append(TARGET_WORDS[i])
    valid_targets.append(i + 1)
    for word in target_cluster:
        if word in word2index:
            valid_words.append(word)
            valid_targets.append(i + 1)

# From here on WORDS/TARGETS describe only the selected neighbourhoods.
WORDS = valid_words
TARGETS = valid_targets

ANCHOR_WORDS = sorted(random.sample(WORDS, NUM_ANCHORS))

ANCHOR_WORDS[:10]

In [None]:
from sklearn.decomposition import PCA


def get_latents(words, encoder: GensimEncoder, return_df: bool = True):
    """Look up the vectors of ``words`` and optionally build a 2D plotting frame.

    Args:
        words: iterable of words; each one is passed to ``encoder.model.get_vector``.
        encoder: encoder whose ``model`` provides the word vectors.
        return_df: when True, additionally project the vectors to two PCA
            components and return them in a DataFrame.

    Returns:
        A ``(df, latents)`` tuple. ``latents`` is a tensor with one row per word,
        created on ``DEVICE``. ``df`` has the columns ``x``, ``y`` (the two PCA
        components) and ``target``; it is ``None`` when ``return_df`` is False.

    Note:
        The ``target`` column is filled from the module-level ``TARGETS``, so with
        ``return_df=True`` the number of ``words`` has to match ``len(TARGETS)``
        (presumably ``words`` is the module-level ``WORDS``).
    """
    latents = torch.tensor([encoder.model.get_vector(word) for word in words], device=DEVICE)
    if not return_df:
        # No plotting data requested: skip the PCA fit, its result would be discarded.
        return None, latents

    latents2d = PCA(n_components=2).fit_transform(latents.cpu())
    df = pd.DataFrame(
        {
            "x": latents2d[:, 0].tolist(),
            "y": latents2d[:, 1].tolist(),
            "target": TARGETS,
        }
    )
    return df, latents

# Plot stuff

In [None]:
def plot_bg(
    ax,
    df,
    cmap,
    norm,
    size=0.5,
    bg_alpha=0.01,
):
    """Scatter every row of ``df`` on ``ax`` as a faint background layer.

    Drawing the whole point cloud with low opacity keeps the overall shape of
    the space visible as context for the clusters highlighted on top of it.

    Args:
        ax: object with a ``scatter`` method (a Matplotlib axes) to draw on.
        df: frame with the columns ``x``, ``y`` and ``target``.
        cmap: callable turning normalised target values into colours.
        norm: callable applied to the ``target`` column before ``cmap``.
        size: marker size passed to ``scatter`` as ``s``.
        bg_alpha: opacity of the background points.

    Returns:
        The same ``ax`` that was passed in.
    """
    point_colors = cmap(norm(df["target"]))
    ax.scatter(df["x"], df["y"], c=point_colors, alpha=bg_alpha, s=size)
    return ax


def hightlight_cluster(ax, df, target, alpha, cmap, norm, size=0.5):
    """Scatter only the rows of ``df`` whose ``target`` equals ``target``.

    Args:
        ax: object with a ``scatter`` method (a Matplotlib axes) to draw on.
        df: frame with the columns ``x``, ``y`` and ``target``.
        target: label of the cluster to draw.
        alpha: opacity of the highlighted points.
        cmap: callable turning normalised target values into colours.
        norm: callable applied to the ``target`` column before ``cmap``.
        size: marker size passed to ``scatter`` as ``s``.
    """
    in_cluster = df["target"] == target
    cluster_df = df[in_cluster]
    cluster_colors = cmap(norm(cluster_df["target"]))
    ax.scatter(cluster_df["x"], cluster_df["y"], c=cluster_colors, alpha=alpha, s=size)


def plot_latent_space(ax, df, targets, size, cmap, norm, bg_alpha=0.1, alpha=0.5):
    """Draw the faint background cloud, then each requested cluster on top of it.

    Args:
        ax: axes to draw on.
        df: frame with the columns ``x``, ``y`` and ``target``.
        targets: iterable of cluster labels to highlight.
        size: marker size of the highlighted points only; the background keeps
            the default size of ``plot_bg``.
        cmap: callable turning normalised target values into colours.
        norm: callable applied to the ``target`` column before ``cmap``.
        bg_alpha: opacity of the background layer.
        alpha: opacity of the highlighted clusters.

    Returns:
        The axes returned by ``plot_bg``.
    """
    # Background first, so the highlighted clusters are painted over it.
    ax = plot_bg(ax, df, cmap, norm, bg_alpha=bg_alpha)
    for cluster_label in targets:
        hightlight_cluster(ax, df, cluster_label, alpha, cmap, norm, size)
    return ax


# Grid layout: a single row with one panel per encoder.
LIM = len(ENCODERS)
N_ROWS = 1
N_COLS = LIM

# tueplots ICML 2022 style, sized for a full-width figure with a
# height-to-width ratio of 1 for the chosen grid.
plt.rcParams.update(bundles.icml2022())
plt.rcParams.update(figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=1.0))

# Colour mapping shared by all plots: "Set1" resampled to 10 entries, with the
# labels normalised over the range of TARGETS.
# `plt.cm.get_cmap` was deprecated in Matplotlib 3.7 and later removed;
# `plt.get_cmap` accepts the same (name, lut) arguments.
cmap = plt.get_cmap("Set1", 10)
norm = plt.Normalize(min(TARGETS), max(TARGETS))


def plot_row(df, title, equal=True, sharey=False, sharex=False, dpi=150):
    """Create a new figure with one latent-space panel per frame in ``df``.

    Args:
        df: sequence of DataFrames (columns ``x``, ``y``, ``target``); the frame
            at position ``j`` is drawn on the ``j``-th panel of the
            ``N_ROWS`` x ``N_COLS`` grid.
        title: title placed on the first panel only.
        equal: when True, every panel gets an equal aspect ratio.
        sharey: forwarded to ``plt.subplots``.
        sharex: forwarded to ``plt.subplots``.
        dpi: figure resolution, forwarded to ``plt.subplots``.

    Returns:
        The created figure.
    """
    # squeeze=False always returns a 2D array of axes, so iterating works for a
    # single panel too (a squeezed 1x1 grid is a bare Axes and is not iterable).
    fig, axes = plt.subplots(dpi=dpi, nrows=N_ROWS, ncols=N_COLS, sharey=sharey, sharex=sharex, squeeze=False)

    for j, ax in enumerate(axes.flat):
        if j == 0:
            ax.set_title(title)
        if equal:
            ax.set_aspect("equal")
        # NOTE(review): only labels 0 and 1 are highlighted; confirm this is the
        # intended subset of the cluster labels stored in the "target" column.
        plot_latent_space(ax, df[j], targets=[0, 1], size=0.75, bg_alpha=0.25, alpha=1, cmap=cmap, norm=norm)
    return fig

In [None]:
def latents_distance(latents):
    """Describe the mean squared error between every pair of latent spaces.

    Args:
        latents: sequence of ``(df, tensor)`` pairs; only the tensor at index 1
            of each pair is used.

    Returns:
        A space-separated string of ``"i-j: value"`` entries, one for every pair
        ``i < j``. The string is empty when fewer than two entries are given.
    """
    dists = []
    for i in range(len(latents)):
        for j in range(i + 1, len(latents)):
            # Mean squared error is the only metric reported; the previous
            # infinity-norm distance was computed and immediately overwritten.
            dist = F.mse_loss(latents[i][1], latents[j][1], reduction="mean")
            dists.append(f"{i}-{j}: {dist}")
    return " ".join(dists)

## AE

In [None]:
import copy

# For every encoder: the (plot frame, vectors) pair of all selected words, and
# the vectors of the anchor words (no plot frame needed for those).
ae_latents, anchors_latents = [], []
for encoder in ENCODERS:
    df, latents = get_latents(words=WORDS, encoder=encoder, return_df=True)
    ae_latents.append((df, latents))
    _, a_latents = get_latents(words=ANCHOR_WORDS, encoder=encoder, return_df=False)
    anchors_latents.append(a_latents)

# Independent copies of both lists, taken right after they are built.
original_ae_latents = copy.deepcopy(ae_latents)
original_anchor_latents = copy.deepcopy(anchors_latents)

In [None]:
# One panel per encoder with the 2D projection of its absolute embeddings,
# followed by the pairwise distances between the encoders' spaces.
ae_frames = [frame for frame, _ in original_ae_latents[:LIM]]
f = plot_row(ae_frames, "AE", equal=True, sharey=True, sharex=True)
latents_distance(ae_latents[:LIM])

## Rel Attention Quantized

In [None]:
from sklearn.decomposition import PCA

# Baseline row: the absolute embeddings, titled with their pairwise distances.
absolute_frames = [frame for frame, _ in ae_latents[:LIM]]
f = plot_row(absolute_frames, f"AE: {latents_distance(ae_latents[:LIM])}", equal=True, sharey=True, sharex=True)

# Relative
from rae.modules.attention import *

# Quantized: first without quantization, then differentiable rounding with
# increasing bin sizes.
QUANTIZATION_SETTINGS = [(None, None)] + [
    (SimilaritiesQuantizationMode.DIFFERENTIABLE_ROUND, size)
    for size in (0.0001, 0.05, 0.1, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8)
]

for quant_mode, bin_size in QUANTIZATION_SETTINGS:
    rel_attention = RelativeAttention(
        n_anchors=NUM_ANCHORS,
        n_classes=len(set(TARGETS)),
        similarity_mode=RelativeEmbeddingMethod.INNER,
        values_mode=ValuesMethod.SIMILARITIES,
        normalization_mode=NormalizationMode.L2,
        similarities_quantization_mode=quant_mode,
        similarities_bin_size=bin_size,
        hidden_features=None,
        transform_elements=None,
        in_features=None,
        values_self_attention_nhead=None,
        similarities_aggregation_mode=None,
        similarities_aggregation_n_groups=None,
        anchors_sampling_mode=None,
        n_anchors_sampling_per_class=None,
    )
    # The module is expected to be parameter-free in this configuration.
    assert sum(p.numel() for p in rel_attention.parameters()) == 0

    rel_latents = []
    for (_, latents), a_latents in zip(ae_latents, anchors_latents):
        rel = rel_attention(x=latents, anchors=a_latents)[AttentionOutput.SIMILARITIES]
        # Plot the first two columns of the similarities directly (no PCA).
        rellatents2d = rel[:, [0, 1]]
        df = pd.DataFrame(
            {
                "x": rellatents2d[:, 0].tolist(),
                "y": rellatents2d[:, 1].tolist(),
                "target": TARGETS,
            }
        )
        rel_latents.append((df, rel))

    f = plot_row(
        [frame for frame, _ in rel_latents[:LIM]],
        f"QAtt, bin size: {bin_size}: {latents_distance(rel_latents[:LIM])}",
        equal=True,
        sharey=True,
        sharex=True,
    )

# Optimal transform

In [None]:
from rae.modules.attention import *

# NOTE(review): `PL_MODULE`, `checkpoints`, `images_batch`, `anchors_batch` and
# `targets` are not assigned in any cell of this notebook, and `parse_checkpoint`
# is not imported by name (it may or may not come from the star import above).
# Define/import them before running this cell on a fresh kernel.
ae, _ = parse_checkpoint(
    module_class=PL_MODULE,
    checkpoint_path=checkpoints["mnist"]["ae"][0],
    map_location="cpu",
)

# Settings shared by the plain and the quantized relative attention.
# `n_anchors` is a count (the earlier sweep passes the integer NUM_ANCHORS), so
# pass the size of the first dimension rather than the whole shape.
# Assumes anchors are stacked along dim 0 of `anchors_batch` — TODO confirm.
_shared_attention_kwargs = dict(
    n_anchors=anchors_batch.shape[0],
    n_classes=len(set(targets)),
    similarity_mode=RelativeEmbeddingMethod.INNER,
    values_mode=ValuesMethod.SIMILARITIES,
    normalization_mode=NormalizationMode.L2,
    output_normalization_mode=OutputNormalization.NONE,
)

# Relative attention without quantization of the similarities.
att = RelativeAttention(
    **_shared_attention_kwargs,
    similarities_quantization_mode=None,
    similarities_bin_size=None,
)
# Same configuration, but with similarities rounded in bins of 0.1.
att_q = RelativeAttention(
    **_shared_attention_kwargs,
    similarities_quantization_mode=SimilaritiesQuantizationMode.CUSTOM_ROUND,
    similarities_bin_size=0.1,
)

# Latents of the images and of the anchors, detached from the autograd graph.
ae.eval()
images_z = ae(images_batch)[Output.DEFAULT_LATENT].detach()
anchors_z = ae(anchors_batch)[Output.DEFAULT_LATENT].detach()

In [None]:
from tqdm import tqdm
import torch
from torch.optim.adam import Adam

# Absolute
from scipy.stats import ortho_group

# Learnable transform: a matrix initialised to a random orthogonal matrix and a
# shift initialised to zero, both sized by the last dimension of the latents.
opt_isometry = torch.tensor(ortho_group.rvs(images_z.shape[-1]), dtype=torch.float, requires_grad=True)
opt_shift = torch.zeros(images_z.shape[-1], dtype=torch.float, requires_grad=True)

opt = Adam([opt_isometry, opt_shift], lr=1e-4)


def transform(x):
    """Apply the learnable map to ``x``: ``x @ opt_isometry + opt_shift``."""
    return x @ opt_isometry + opt_shift


# Loss weights: R scales the (negated) relative term, Q the quantized term,
# I the orthogonality penalty and S the L1 penalty on the shift (0 disables it).
R = 1000
Q = 1
I = 1000
S = 0
for i in (bar := tqdm(range(100))):
    # Negated: the optimiser is pushed to make the plain relative
    # representations of original and transformed latents differ.
    rel = att(x=images_z, anchors=anchors_z)[AttentionOutput.SIMILARITIES]
    rel_iso = att(x=transform(images_z), anchors=transform(anchors_z))[AttentionOutput.SIMILARITIES]
    rel_dist = F.mse_loss(rel, rel_iso, reduction="sum")
    rel_loss = -rel_dist * R

    # The quantized term compares the quantized representations of original and
    # transformed latents (`qrel_iso`); it previously used `rel_iso`, which left
    # `qrel_iso` unused and disagreed with the report printed after the loop.
    qrel = att_q(x=images_z, anchors=anchors_z)[AttentionOutput.SIMILARITIES]
    qrel_iso = att_q(x=transform(images_z), anchors=transform(anchors_z))[AttentionOutput.SIMILARITIES]
    qrel_dist = F.mse_loss(qrel, qrel_iso, reduction="sum")
    qrel_loss = qrel_dist * Q

    # Penalise the off-diagonal entries of opt_isometry @ opt_isometry.T.
    # NOTE(review): the diagonal itself is not constrained, so row norms may
    # drift away from 1 — confirm this is intended.
    t_temp = opt_isometry @ opt_isometry.T
    iso_loss = ((t_temp - t_temp.diag().diag()) ** 2).sum() * I
    shift_loss = opt_shift.abs().sum() * S
    loss = rel_loss + qrel_loss + iso_loss + shift_loss

    # ".3f" prints three decimals (the former "3f" only set a minimum width).
    bar.set_description(f"Rel: {rel_loss.item():.3f} \t Qua: {qrel_loss.item():.3f} \t  Iso: {iso_loss.item():.3f}")
    loss.backward()
    opt.step()
    opt.zero_grad()

# Report both distances with the optimised transform.
rel = att(x=images_z, anchors=anchors_z)[AttentionOutput.SIMILARITIES]
rel_iso = att(x=transform(images_z), anchors=transform(anchors_z))[AttentionOutput.SIMILARITIES]
print("Relative mse:", F.mse_loss(rel, rel_iso, reduction="sum"))

qrel = att_q(x=images_z, anchors=anchors_z)[AttentionOutput.SIMILARITIES]
qrel_iso = att_q(x=transform(images_z), anchors=transform(anchors_z))[AttentionOutput.SIMILARITIES]
print("Quantized mse:", F.mse_loss(qrel, qrel_iso, reduction="sum"))

In [None]:
# Recompute the latents from the autoencoder and measure how much the current
# transform changes the plain and the quantized relative representations.
ae.eval()
images_z = ae(images_batch)[Output.DEFAULT_LATENT].detach()
anchors_z = ae(anchors_batch)[Output.DEFAULT_LATENT].detach()

images_z_iso = transform(images_z)
anchors_z_iso = transform(anchors_z)

rel = att(x=images_z, anchors=anchors_z)[AttentionOutput.SIMILARITIES]
rel_iso = att(x=images_z_iso, anchors=anchors_z_iso)[AttentionOutput.SIMILARITIES]
relative_mse = F.mse_loss(rel, rel_iso, reduction="sum")
print("Relative mse:", relative_mse)

qrel = att_q(x=images_z, anchors=anchors_z)[AttentionOutput.SIMILARITIES]
qrel_iso = att_q(x=images_z_iso, anchors=anchors_z_iso)[AttentionOutput.SIMILARITIES]
quantized_mse = F.mse_loss(qrel, qrel_iso, reduction="sum")
print("Quantized mse:", quantized_mse)

In [None]:
# Display the quantized similarities computed in the previous cell.
qrel

In [None]:
# Plot  y = x - a * cos(2*pi*freq*x + s) / (2*pi*freq)  on [-1, 1], where the
# frequency is the reciprocal of b.
b = torch.as_tensor(0.5)
x = torch.linspace(-1, 1, 200)

a = 1  # amplitude of the cosine term
s = 0  # phase shift of the cosine term
freq = 1 / b  # kept separate from `f`, which holds the plotted lines below
y = x - a * torch.cos(2 * torch.pi * freq * x + s) / (2 * torch.pi * freq)

fig, ax = plt.subplots(1, 1, dpi=150)
f = ax.plot(
    x,
    y,
)