# Embedding Matching

## Algorithm
1. choose anchor
2. same embedding transformation learning

In [1]:
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances

In [2]:
import umap
import umap.plot

In [3]:
from scipy.spatial.distance import cdist
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html#scipy.spatial.distance.cdist

In [4]:
from sklearn.datasets import fetch_20newsgroups
news_data = fetch_20newsgroups(subset='test')
n = len(news_data.data)

In [5]:
print(news_data.target.shape)
news_data.target
news_data.data[0]

(7532,)


'From: v064mb9k@ubvmsd.cc.buffalo.edu (NEIL B. GANDLER)\nSubject: Need info on 88-89 Bonneville\nOrganization: University at Buffalo\nLines: 10\nNews-Software: VAX/VMS VNEWS 1.41\nNntp-Posting-Host: ubvmsd.cc.buffalo.edu\n\n\n I am a little confused on all of the models of the 88-89 bonnevilles.\nI have heard of the LE SE LSE SSE SSEI. Could someone tell me the\ndifferences are far as features or performance. I am also curious to\nknow what the book value is for prefereably the 89 model. And how much\nless than book value can you usually get them for. In other words how\nmuch are they in demand this time of year. I have heard that the mid-spring\nearly summer is the best time to buy.\n\n\t\t\tNeil Gandler\n'

In [6]:
# https://www.sbert.net/docs/pretrained_models.html
main_model_type = 'all-mpnet-base-v1'
mimic_model_type = 'all-distilroberta-v1' 

# all-mpnet-base-v2
# all-distilroberta-v1
#'all-MiniLM-L12-v2' 
# average_word_embeddings_glove.6B.300d

from sentence_transformers import SentenceTransformer
main_model = SentenceTransformer(main_model_type)
mimic_model = SentenceTransformer(mimic_model_type)

In [7]:
main_emb = main_model.encode(news_data.data)
mimic_emb = mimic_model.encode(news_data.data)

In [8]:
def calculate_metrics(match_index_prediction, match_main_emb, verbose=True):
    # evaluation metric
    # 1. accuracy
    ground_truth = np.arange(len(match_index_prediction))
    accuracy = np.mean(ground_truth == match_index_prediction)

    # 2. average origin doc emb cosine similarity
    predict_main_emb = match_main_emb[match_index_prediction]
    cos_pred_match = cosine_similarity(predict_main_emb, match_main_emb)

    average_cos = np.trace(cos_pred_match) / match_main_emb.shape[0]

    if verbose:
        print('len', len(match_index_prediction))
        print('accuracy', accuracy)
        print('global average cos similarity', np.mean(cos_pred_match))
        print('predict average cos similarity', average_cos)
    
    return accuracy, average_cos

## Select Anchor

In [9]:
select_num = 200

In [10]:
anchor_index = np.random.choice(np.arange(n), size=select_num, replace=False)
match_index = np.delete(np.arange(n), anchor_index)

anchor_index, match_index

(array([4341, 2401,  434, 4955, 5055, 4782, 1599, 6915, 2328, 6718, 3483,
         324,  745, 6352, 3687, 5580, 2043, 5244, 6571, 6699, 2435, 7291,
        1077, 6032, 7362, 5047, 3696, 6595, 3094,  598, 3341, 3495, 5534,
        5724, 7475, 7321, 5598, 7323, 7390, 4091, 3545,  439, 5460, 5913,
        5508, 5380, 1468, 1153, 5758, 1822, 4276, 3909, 7490, 5958, 5199,
        7249, 1437, 5728, 2384, 2051, 3199, 3340, 3569, 2499, 1402, 5214,
        5733, 4254, 1346,  167, 1028, 2643, 5727, 3949, 1874, 6920, 3354,
        6845,  424, 1520,  427, 1290,  461, 5409, 2121,  471, 4317, 4810,
        1805, 2451, 3355,  766, 2862, 3744, 5535, 5288, 6936, 2580, 5360,
        5830, 6190, 4804, 2757, 6503,  485,  643, 3984, 3029, 1510, 3145,
        4775, 4658, 1151, 1942, 2920, 4752, 1550, 4297, 6599, 6817, 6067,
        5880, 5609, 3582, 2542,  616, 6216, 1639, 4932, 2897, 7377, 3459,
        2077, 6289, 4807, 6357,  497, 2206, 5638,  143,  718,  205, 7447,
        6664, 3186, 3694,  221, 6339, 

## Embedding Transformation Learning

In [11]:
train_main_emb = main_emb[anchor_index]
valid_main_emb = main_emb[match_index]

train_mimic_emb = mimic_emb[anchor_index]
valid_mimic_emb = mimic_emb[match_index]

## Procrustes baseline

In [12]:
from procrustes.orthogonal import orthogonal
from procrustes.generic import generic
from procrustes.rotational import rotational
from procrustes.orthogonal import orthogonal_2sided

def procrustes_train(model):
    print('train num', train_main_emb.shape[0])
    if model == 'orthogonal':
        result = orthogonal(train_main_emb, train_mimic_emb, translate=translate, scale=scale)
    elif model == 'generic':
        result = generic(train_main_emb, train_mimic_emb, translate=translate, scale=scale)
    elif model == 'rotational':
        result = rotational(train_main_emb, train_mimic_emb, translate=translate, scale=scale)
    elif model == 'orthogonal_2sided':
        result = orthogonal_2sided(train_main_emb, train_mimic_emb, translate=translate, scale=scale, single=False)
    
    pred_main_emb = valid_main_emb@result.t
    dm = cdist(pred_main_emb, valid_mimic_emb, 'euclidean')
    match_idx = np.argmin(dm, axis=1)
    
    acc, cos = calculate_metrics(match_idx, valid_main_emb)
    
model = 'rotational'
translate = True
scale = False
procrustes_train(model)

train num 200
len 7332
accuracy 0.5932896890343698
global average cos similarity 0.10376183
predict average cos similarity 0.8506935564989089


## Unsupervised Baseline

In [21]:
main_emb.T.shape

(768, 7532)

In [33]:
from procrustes.permutation import permutation

result = permutation(main_emb.T, mimic_emb.T, translate=False, scale=False)

pred_main_emb = main_emb.T@result.t

In [34]:
result.t.shape

(7532, 7532)

In [35]:
(np.argmax(result.t, axis=0) == np.arange(result.t.shape[0])).sum()

0