## Init

In [None]:
import numpy as np
from tqdm.auto import trange, tqdm
from scipy.linalg import orthogonal_procrustes
from scipy.optimize import quadratic_assignment
from sklearn.cluster import KMeans

import torch
import torch.nn.functional as F
from datasets import load_dataset

def cos_sim_matrix(X, Y):
    """Return the pairwise cosine-similarity matrix between the rows of X and Y.

    Args:
        X: 2-D numpy array or torch tensor, one vector per row.
        Y: 2-D numpy array or torch tensor, one vector per row (same width as X).

    Returns:
        Tensor whose (i, j) entry is the cosine similarity of X[i] and Y[j].
        All-zero rows produce similarity 0 instead of NaN.
    """
    if isinstance(X, np.ndarray):
        X = torch.from_numpy(X)
    if isinstance(Y, np.ndarray):
        Y = torch.from_numpy(Y)
    # F.normalize clamps the norm at a small eps, so an all-zero row becomes a
    # zero vector instead of 0/0 = NaN (which would poison topk/argsort later).
    # For non-zero rows the result is identical to dividing by the norm.
    X_norm = F.normalize(X, dim=-1)
    Y_norm = F.normalize(Y, dim=-1)
    return X_norm @ Y_norm.T

def tensor(x):
    """Return a float32 copy of *x* as a torch tensor.

    Accepts anything ``torch.tensor`` accepts (lists, numpy arrays, scalars) as
    well as existing tensors. An existing tensor is detached and copied with
    ``Tensor.to(..., copy=True)``. This gives the same result as
    ``torch.tensor(t)`` (a detached copy on the same device) without the
    UserWarning PyTorch emits when copy-constructing from a tensor.
    """
    if isinstance(x, torch.Tensor):
        return x.detach().to(dtype=torch.float32, copy=True)
    return torch.tensor(x).float()

def N(X, dim=-1, **kwargs):
    """Normalise X along ``dim`` (a thin alias for torch.nn.functional.normalize).

    Uses the Euclidean norm by default; any extra keyword arguments
    (e.g. ``p``, ``eps``) are forwarded unchanged.
    """
    normalized = F.normalize(X, dim=dim, **kwargs)
    return normalized

def sim(X, Y):
    """Return the double-centred linear kernel between the rows of X and Y.

    Computes ``(H_x X)(H_y Y)^T`` where ``H = I - 11^T / n`` is the centring
    matrix. Equivalently, each input has its per-column mean subtracted before
    the inner products are taken. When ``len(X) == len(Y)`` this equals
    ``H @ X @ Y.T @ H`` (H is symmetric), the standard kernel centring used in
    CKA/HSIC. Unlike that form, it also works when the row counts differ and
    never materialises an n x n centring matrix.

    Args:
        X: array-like of shape (n, d).
        Y: array-like of shape (m, d).

    Returns:
        float32 tensor of shape (n, m).
    """
    X = torch.as_tensor(X, dtype=torch.float32)
    Y = torch.as_tensor(Y, dtype=torch.float32)
    # H @ X == X - column means of X, so centre each input directly.
    X_centered = X - X.mean(dim=0, keepdim=True)
    Y_centered = Y - Y.mean(dim=0, keepdim=True)
    return X_centered @ Y_centered.T

def train_orthogonal_linear(X, Y):
    """Fit the orthogonal matrix W minimising ||X @ W - Y||_F (orthogonal Procrustes).

    Returns W as a float32 torch tensor so it can be applied as ``X @ W``.
    """
    rotation = orthogonal_procrustes(X, Y)[0]  # second element (scale) is unused
    return torch.tensor(rotation).float()

def eval_score(X_eval, Y_eval, W, backward=False):
    """Mean row-wise cosine similarity between aligned eval pairs, rounded to 2 decimals.

    Forward mode compares ``X_eval @ W`` with ``Y_eval``. Backward mode maps Y
    back with ``W.T`` and compares ``X_eval`` with ``Y_eval @ W.T``.
    """
    if backward:
        lhs, rhs = X_eval, Y_eval @ W.T
    else:
        lhs, rhs = X_eval @ W, Y_eval
    mean_cos = torch.cosine_similarity(lhs, rhs, dim=-1).mean()
    return torch.round(mean_cos, decimals=2)

## Load Data

In [None]:
from huggingface_hub import hf_hub_download
import torch

# Precomputed sentence embeddings from four encoders, one .pt file per model,
# downloaded from the Hugging Face Hub and loaded onto the CPU.
repo_id = 'dar-tau/nq-embeddings'
filenames = {
    'stella': 'stella.pt',
    'e5': 'e5.pt',
    'granite': 'granite.pt',
    'gtr': 'gtr.pt',
}
sent_embeds = {}
for key, filename in filenames.items():
    # NOTE(review): torch.load unpickles the downloaded file. Only load from
    # repositories you trust, or pass weights_only=True if the files hold
    # plain tensors.
    sent_embeds[key] = torch.load(
        hf_hub_download(repo_id=repo_id, filename=filename),
        map_location='cpu',
    )

In [None]:
# Source space A = e5, target space B = gtr. Both tensors embed the same
# sentences in the same row order, so they must have the same length.
embed_A, embed_B = sent_embeds['e5'].cpu(), sent_embeds['gtr'].cpu()
if len(embed_A) != len(embed_B):
    raise ValueError(
        f'embed_A and embed_B must have the same number of rows '
        f'(got {len(embed_A)} and {len(embed_B)})'
    )

n_eval = 8192  # the last n_eval rows are held out for evaluation only
n = len(embed_A) - n_eval
if n < 2:
    raise ValueError(f'need at least {n_eval + 2} embeddings, got {len(embed_A)}')

# Centre with the mean of the training rows only (no eval leakage), then L2-normalise.
mean_A = embed_A[:n].mean(dim=0)
E_A1, E_A2, E_A3 = N(embed_A[:n // 2] - mean_A), N(embed_A[n // 2:n] - mean_A), N(embed_A[n:] - mean_A)

mean_B = embed_B[:n].mean(dim=0)
E_B1, E_B2, E_B3 = N(embed_B[:n // 2] - mean_B), N(embed_B[n // 2:n] - mean_B), N(embed_B[n:] - mean_B)

# Explanation:
# E_A1 and E_B1 represent the same sentences, just embedded in different spaces
# (and similarly for E_A2/E_B2 and E_A3/E_B3); row i of one matches row i of the other.

# We want to train on non-overlapping sets of sentences
# (because we assume that only the distribution of sentences is shared, not the individual sentences)
# Thus, we will train on E_A1 and E_B2, meaning there is no ground-truth alignment between the two sets
X_train, Y_train = E_A1, E_B2

# For evaluation, we will use the ground-truth alignment of E_A3 and E_B3
X_eval, Y_eval = E_A3, E_B3

## Step 1: Match Anchors

In [None]:
def aligned_centroids(X_train, Y_train, n_runs=300, n_clusters=50, method='2opt', subsample=None):
    """Cluster both spaces independently and match their centroids via a QAP.

    KMeans is run separately on X_train and Y_train. The double-centred
    similarity kernels between each space's centroids (see ``sim``) are then
    aligned with ``scipy.optimize.quadratic_assignment`` (maximising). The
    resulting permutation reorders Y's centroids so their pairwise-similarity
    structure best matches X's. No cross-space row correspondence is used.

    Args:
        X_train: embeddings in the source space, one per row.
        Y_train: embeddings in the target space, one per row.
        n_runs: number of randomly initialised QAP solves; the best is kept.
        n_clusters: number of KMeans clusters per space.
        method: ``quadratic_assignment`` method (e.g. '2opt' or 'faq').
        subsample: if given, independently keep at most this many random rows
            from each space before clustering.

    Returns:
        (centers1, centers2) as float32 tensors with n_clusters rows each;
        row i of centers2 is the centroid matched to row i of centers1.

    Raises:
        ValueError: if n_runs < 1 (no QAP solution would be available).
    """
    if n_runs < 1:
        raise ValueError(f'n_runs must be >= 1, got {n_runs}')
    options = {'P0': 'randomized', 'maximize': True}
    if subsample is not None:
        # Independent permutations: the two subsamples are drawn separately.
        X_train = X_train[torch.randperm(len(X_train))[:subsample]]
        Y_train = Y_train[torch.randperm(len(Y_train))[:subsample]]

    clusterer1 = KMeans(n_clusters=n_clusters)
    clusterer1.fit(X_train)
    clusterer2 = KMeans(n_clusters=n_clusters)
    clusterer2.fit(Y_train)
    centers1, centers2 = clusterer1.cluster_centers_, clusterer2.cluster_centers_
    # sim already returns float32 tensors.
    kernel1 = sim(centers1, centers1)
    kernel2 = sim(centers2, centers2)

    best = None
    # QAP solvers (even 2opt) get stuck in local optima, so restart from random
    # initial permutations and keep the best assignment found.
    for _ in trange(n_runs):
        candidate = quadratic_assignment(kernel1, kernel2, method=method, options=options)
        # maximize=True, so a larger objective value (`fun`) is better.
        if best is None or candidate.fun > best.fun:
            best = candidate
    centers2 = centers2[best.col_ind]
    return tensor(centers1), tensor(centers2)

In [None]:
# Repeat the anchor matching on 30 independent subsamples and pool the
# matched centroid pairs from every run into one anchor set.
all_centers1, all_centers2 = [], []
for run_idx in trange(30):
    centers1, centers2 = aligned_centroids(
        X_train,
        Y_train,
        subsample=10_000,
        n_clusters=20,
        n_runs=30,
        method='2opt',
    )
    all_centers1.append(centers1)
    all_centers2.append(centers2)

In [None]:
# Stack the per-run anchors; row j of all_centers1 is matched to row j of all_centers2.
all_centers1, all_centers2 = torch.cat(all_centers1), torch.cat(all_centers2)

In [None]:
# Describe every training point by its cosine similarities to the anchors of its
# own space. Anchor j in A is matched to anchor j in B, so these profiles are
# comparable across spaces; cross-space similarity = similarity of profiles.
sim1, sim2 = cos_sim_matrix(X_train, all_centers1), cos_sim_matrix(Y_train, all_centers2)
sim_similarity = cos_sim_matrix(sim1, sim2)

In [None]:
# For each X_train row: indices of the k Y_train rows with the most similar anchor profiles.
k = 50
top_similar = torch.topk(sim_similarity, k, dim=-1).indices

In [None]:
# Pseudo-target for each X_train row: a uniformly weighted average of its k
# most similar Y_train rows. A rank-decaying alternative would be
# N(1 / (1 + torch.arange(k))**.5, p=1), i.e. 1/sqrt(rank) weights L1-normalised.
coefs = torch.ones(k).div(k)
Y_matched = torch.matmul(Y_train[top_similar].transpose(-1, -2), coefs)

## Step 2: Train Mapping

In [None]:
# Initial map: orthogonal Procrustes from each X_train row to its pseudo-target.
W = train_orthogonal_linear(X=X_train, Y=Y_matched)

In [None]:
# Diagnostic only: uses the ground-truth eval pairing, which is unavailable in practice.
initial_score = eval_score(X_eval, Y_eval, W)
print('Eval score:', initial_score)

## Step 3: Refinement

### Refine-1: Iterative Closest Point Average

In [None]:
# Iterative closest point with averaging: map a random batch of X_train into the
# Y space, use the mean of each point's k nearest Y_train rows as its target,
# refit the orthogonal map on those pairs, and blend it 50/50 with the current W.
# Depending on how well the previous step went, you might want to increase/decrease n_iters;
# of course, in reality you are not privy to the eval score.
n_iters = 100
k = 50
n_samples = 1000  # X_train points resampled each iteration
# `step` (not `iter`) so the builtin iter() is not shadowed in the notebook namespace.
for step in trange(n_iters):
    print('-' * 20)
    print(f'ITER {step + 1}')
    sample_points = X_train[torch.randperm(len(X_train))[:n_samples]]
    sample_similarities = cos_sim_matrix(sample_points @ W, Y_train)
    neighbors = sample_similarities.topk(dim=-1, k=k).indices
    sample_matched = Y_train[neighbors].mean(dim=1)
    W_new = train_orthogonal_linear(sample_points, sample_matched)
    # Damped update. The average of two orthogonal matrices is generally not
    # orthogonal; the cosine-based scoring ignores per-row scale.
    W = 0.5 * W + 0.5 * W_new
    print('Eval score:', eval_score(X_eval, Y_eval, W))

### Refine-2: Cluster-Based Alignment Correction

In [None]:
# Cluster-based correction: cluster X_train, seed KMeans on Y_train with the
# mapped X centroids so cluster j in Y starts at the image of cluster j in X,
# then refit the orthogonal map on the resulting centroid pairs.
n_iters = 1

# `step` (not `iter`) so the builtin iter() is not shadowed in the notebook namespace.
for step in trange(n_iters):
    print('-' * 20)
    print(f'ITER {step + 1}')
    print('Compute KMeans #1...')
    kmeans1 = KMeans(n_clusters=500).fit(X_train)
    centers1 = tensor(kmeans1.cluster_centers_)
    print('Compute KMeans #2...')
    # With an explicit init array sklearn runs a single initialisation anyway.
    # Passing n_init=1 states that explicitly and avoids the n_init warning
    # some sklearn versions emit.
    kmeans2 = KMeans(n_clusters=500, init=centers1 @ W, n_init=1).fit(Y_train)
    centers2 = tensor(kmeans2.cluster_centers_)

    print('Self consistency of KMeans',
          torch.cosine_similarity(centers1 @ W, centers2, dim=-1).mean()
          )
    W_new = train_orthogonal_linear(centers1, centers2)
    W = 0.5 * W + 0.5 * W_new
    print('Eval score:', eval_score(X_eval, Y_eval, W))

## Final Evaluation

In [None]:
# Run the final ranking on the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def rank(X):
    """Return the 0-based ascending rank of every entry within its row (last dim).

    The largest value in a row gets rank ``X.shape[-1] - 1``; ties follow
    argsort's ordering.
    """
    order = X.argsort(dim=-1)
    return order.argsort(dim=-1)

In [None]:
# For each mapped X_eval row, the ascending rank of its true match (the diagonal)
# among all Y_eval candidates.
ranks = rank(
    cos_sim_matrix(X_eval.to(device) @ W.to(device), Y_eval.to(device))
).diagonal()

In [None]:
# Ranks are ascending, so the true match is top-1 when its rank is len(X_eval) - 1,
# and len(X_eval) - rank is a 1-based position where 1 is best.
n_candidates = len(X_eval)
print("Top-1 Accuracy:", (ranks == n_candidates - 1).float().mean())
print("Average Rank:", n_candidates - ranks.float().mean())