In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Run on the GPU when one is present; fall back to the CPU so the script
# still starts (slowly) on machines without CUDA instead of crashing.
device = "cuda" if torch.cuda.is_available() else "cpu"
hidden_size = 512  # width of every embedding vector
n_x, n_y = 200000, 1024  # number of x items and y items

# One learnable vector per x item and per y item.
emb_x = nn.Embedding(n_x, hidden_size).to(device)
emb_y = nn.Embedding(n_y, hidden_size).to(device)

# Dense (n_x, n_y) 0/1 target matrix. Every y column receives exactly one
# positive x row, drawn uniformly at random; the draw is with replacement,
# so several columns may end up sharing the same x row.
targets = torch.zeros(n_x, n_y, device=device)
pos_x = torch.randint(0, n_x, (n_y,), device=device)
pos_y = torch.arange(n_y, device=device)
targets[pos_x, pos_y] = 1.0

# A single Adam optimizer updates both embedding tables.
opt = torch.optim.Adam(list(emb_x.parameters()) + list(emb_y.parameters()), lr=1e-3)

def metrics(sim, tgt):
    """Return ``(recall@1, recall@10, MRR)`` for row-wise retrieval.

    Args:
        sim: 2-D score tensor; each row holds the scores of all candidate
            columns for one query.
        tgt: Tensor with the same layout as ``sim``; the target column of a
            row is taken as the argmax of that row.

    Returns:
        A tuple of three Python floats, each averaged over the rows: the
        fraction of rows whose target is ranked first, the fraction whose
        target is within the top ten, and the mean reciprocal rank.
    """
    # Candidate column indices for every row, highest score first.
    order = sim.sort(dim=1, descending=True).indices
    gold = tgt.argmax(dim=1).unsqueeze(1)

    # Each row of `order` is a permutation of the columns, so exactly one
    # entry per row matches the gold column; its 0-based position plus one
    # is the 1-based rank of the target.
    _, hit_cols = torch.where(order == gold)
    rank = hit_cols + 1

    recall_at_1 = (rank == 1).float().mean().item()
    recall_at_10 = (rank <= 10).float().mean().item()
    mean_reciprocal_rank = (1.0 / rank.float()).mean().item()
    return recall_at_1, recall_at_10, mean_reciprocal_rank

scheme = "standard"  # "simplified" or "standard"

# Class counts over the full (n_x, n_y) target matrix.
N_total = n_x * n_y
N_pos = targets.sum().item()
N_neg = N_total - N_pos

# Per-element loss weights for positive and negative entries. Under both
# schemes the positives and the negatives receive the same total weight;
# the schemes differ only in overall scale.
if scheme == "simplified":
    # Negatives keep unit weight; positives are scaled by the neg/pos ratio.
    w_pos, w_neg = N_neg / N_pos, 1.0
elif scheme == "standard":
    # Each class is scaled by N_total / (2 * its own count).
    w_pos, w_neg = N_total / (2 * N_pos), N_total / (2 * N_neg)
else:
    # Fail here rather than with a NameError on w_pos / w_neg later on.
    raise ValueError(f"unknown weighting scheme: {scheme!r}")

# Loop-invariant tensors, built once instead of on every iteration.
all_x = torch.arange(n_x, device=device)
all_y = torch.arange(n_y, device=device)
weights = torch.where(targets == 1, w_pos, w_neg)

for i in range(1000000):
    # Full-batch forward pass: cosine similarity of every x with every y.
    x = F.normalize(emb_x(all_x), dim=-1)
    y = F.normalize(emb_y(all_y), dim=-1)
    sim = x @ y.t()

    # Class-weighted squared error against the 0/1 targets.
    loss = F.mse_loss(sim, targets, reduction="none")
    loss = (loss * weights).mean()

    opt.zero_grad()
    loss.backward()
    opt.step()

    with torch.no_grad():
        # `sim` was computed before opt.step(), so the reported metrics
        # describe the parameters as they were at the start of this step.
        # NOTE(review): pos_x may contain repeated rows; such rows carry
        # several positives in tgt_eval and metrics() scores only the first
        # one (argmax) - confirm this is acceptable.
        sim_eval = sim[pos_x]                # (n_y, n_y)
        tgt_eval = targets[pos_x]
        r1, r10, mrr = metrics(sim_eval, tgt_eval)
        # Index with a device-side row index to avoid mixing CPU and device
        # index tensors.
        pos_scores = sim_eval[all_y, pos_y]
        # Push every positive far below any cosine value before the max so
        # only negatives can win.
        hardest_neg = (sim_eval + tgt_eval * -1e9).max(dim=1).values
        print(f"{i}: {scheme} | loss {loss.item():.4f} | R1 {r1:.3f} R10 {r10:.3f} MRR {mrr:.3f} | pos {pos_scores.mean().item():.3f} hardneg {hardest_neg.mean().item():.3f}")


0: standard | loss 0.5013 | R1 0.000 R10 0.008 MRR 0.006 | pos 0.001 hardneg 0.144
1: standard | loss 0.4997 | R1 0.000 R10 0.010 MRR 0.006 | pos 0.002 hardneg 0.144
2: standard | loss 0.4982 | R1 0.000 R10 0.010 MRR 0.007 | pos 0.004 hardneg 0.144
3: standard | loss 0.4966 | R1 0.000 R10 0.011 MRR 0.007 | pos 0.005 hardneg 0.144
4: standard | loss 0.4950 | R1 0.000 R10 0.011 MRR 0.007 | pos 0.007 hardneg 0.144
5: standard | loss 0.4934 | R1 0.000 R10 0.013 MRR 0.008 | pos 0.008 hardneg 0.144
6: standard | loss 0.4919 | R1 0.000 R10 0.014 MRR 0.008 | pos 0.010 hardneg 0.144
7: standard | loss 0.4903 | R1 0.001 R10 0.014 MRR 0.010 | pos 0.012 hardneg 0.144
8: standard | loss 0.4887 | R1 0.001 R10 0.016 MRR 0.010 | pos 0.013 hardneg 0.144
9: standard | loss 0.4872 | R1 0.001 R10 0.018 MRR 0.011 | pos 0.015 hardneg 0.144
10: standard | loss 0.4856 | R1 0.002 R10 0.021 MRR 0.012 | pos 0.016 hardneg 0.144
11: standard | loss 0.4841 | R1 0.003 R10 0.024 MRR 0.014 | pos 0.018 hardneg 0.144
12

KeyboardInterrupt: 