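## Sentence Embedding Similarities

For each language file in `translations/`, embed every passage's translations with a multilingual sentence encoder (LaBSE or GTE, chosen per language in `LANGUAGE_MODELS`) and report cosine similarities: the official translation vs. the ChatGPT and Google translations, and the English source vs. its ChatGPT and Google round-trip translations.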


### Imports

In [58]:
%load_ext autoreload
%autoreload 2

import os
import json
import torch
import torch.nn.functional as F
from transformers import BertModel, BertTokenizerFast, AutoModel, AutoTokenizer

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [59]:
# Which sentence-embedding model to use for each language, keyed by the JSON
# file stem in translations/ (e.g. 'wolof' for wolof.json). Languages without
# an entry here have no model and therefore get no embeddings.
LANGUAGE_MODELS = {
    'malay': 'gte',      # can be either
    'tibetan': 'labse',  # must be labse
    'wolof': 'labse',    # must be labse
    'quechua': 'gte',    # must be gte
    'spanish': 'gte',    # can be either
}

# Translation pairs to compare; each name is a key inside a passage's
# translations dict, and each pair is reported as "src:tgt".
target_pairs = [
    ('official', 'chatgpt'),
    ('official', 'google'),
    ('english', 'roundtrip_chatgpt'),
    ('english', 'roundtrip_google'),
]

# LaBSE: currently used for Tibetan and Wolof (Malay and Spanish could use it too).
labse_tokenizer = BertTokenizerFast.from_pretrained("setu4993/LaBSE")
labse_model = BertModel.from_pretrained("setu4993/LaBSE").eval()

# GTE multilingual: currently used for Quechua, Malay and Spanish.
# The checkpoint ships its own modelling code (it loads as "NewModel"), so
# trust_remote_code=True is passed explicitly; without it transformers asks for
# interactive confirmation or refuses, which breaks a headless Restart & Run All.
gte_tokenizer = AutoTokenizer.from_pretrained("Alibaba-NLP/gte-multilingual-base")
gte_model = AutoModel.from_pretrained(
    "Alibaba-NLP/gte-multilingual-base", trust_remote_code=True
).eval()



Some weights of the model checkpoint at Alibaba-NLP/gte-multilingual-base were not used when initializing NewModel: {'classifier.bias', 'classifier.weight'}
- This IS expected if you are initializing NewModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing NewModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [60]:
def similarity(embeddings_1, embeddings_2):
    """Return the pairwise cosine-similarity matrix of two embedding batches.

    Intended for 2-D (rows, dim) tensors. Each input is L2-normalised along
    dim 1 (the F.normalize default), so entry [i, j] of the result is the
    cosine similarity between row i of `embeddings_1` and row j of
    `embeddings_2`. Two single-row inputs give a 1x1 tensor.
    """
    unit_rows_1 = F.normalize(embeddings_1, p=2)
    unit_rows_2 = F.normalize(embeddings_2, p=2)
    unit_cols_2 = unit_rows_2.transpose(0, 1)
    return unit_rows_1 @ unit_cols_2

def get_embedding(text, model_type, max_length=512):
    """Encode `text` with the sentence encoder named by `model_type`.

    Args:
        text: the passage to embed (passed straight to the tokenizer).
        model_type: 'labse' or 'gte' (see LANGUAGE_MODELS); any other value,
            including None, yields None.
        max_length: token limit used when truncating for LaBSE. Tibetan needs
            this truncation to 512 tokens or the model won't accept it. It is
            NOT applied to GTE, which truncates at its tokenizer's own default.

    Returns:
        LaBSE's pooler_output, or GTE's first-token (CLS-position) hidden
        state, as a (batch, hidden) tensor; None for an unrecognised model_type.
    """
    if model_type not in ('labse', 'gte'):
        return None

    with torch.no_grad():
        if model_type == 'labse':
            encoded = labse_tokenizer(
                text, return_tensors="pt", padding=True, truncation=True, max_length=max_length
            )
            return labse_model(**encoded).pooler_output

        # Remaining case: 'gte' (Quechua, Malay, Spanish in the current mapping).
        encoded = gte_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        return gte_model(**encoded).last_hidden_state[:, 0, :]

In [61]:
# Score every translations/<language>.json file: embed each passage's
# translations with the language's model and compare the pairs in target_pairs.
# Results are collected per language in `all_results[language]`:
#   'passage_similarities': {passage_id: {"src:tgt": cosine similarity}}
#   'avg_similarities':     {"src:tgt": mean over passages that had both texts}
TRANSLATIONS_DIR = 'translations'

# Only translation types that appear in some target pair are ever compared, so
# only those are embedded (saves a model forward pass per unused type).
needed_types = {trans_type for pair in target_pairs for trans_type in pair}

all_results = {}
# os.listdir order is arbitrary; sort so runs (and printed output) are reproducible.
for file in sorted(os.listdir(TRANSLATIONS_DIR)):
    if not file.endswith('.json'):
        continue

    # All json files should be in the format language_name.json
    language_name = file.split('.')[0]
    model_type = LANGUAGE_MODELS.get(language_name)
    print(f"Processing {language_name} using {model_type} model")

    try:
        # Close the file as soon as it is parsed; model work happens afterwards.
        with open(os.path.join(TRANSLATIONS_DIR, file), 'r', encoding='utf-8') as f:
            data = json.load(f)

        passages = data.get('passages', {})
        if not passages:
            print("empty file")
            continue

        all_results[language_name] = {
            'passage_similarities': {},
            'avg_similarities': {}
        }

        if model_type is None:
            # No model configured in LANGUAGE_MODELS: keep the empty entry so later
            # cells still list the language, but skip the (pointless) embedding loop.
            print(f"No embedding model configured for {language_name}; skipping")
            continue

        all_pair_similarities = {f"{src}:{tgt}": [] for src, tgt in target_pairs}

        for passage_id, translations in passages.items():
            embeddings = {}
            for trans_type, text in translations.items():
                if trans_type not in needed_types:
                    continue
                # Passages with no official translation have the literal string "None"
                # as their text (only passages 5/6); also skip blank and non-string values.
                if not isinstance(text, str) or text == "None" or not text.strip():
                    continue
                embedding = get_embedding(text, model_type)
                if embedding is not None:
                    embeddings[trans_type] = embedding

            passage_sims = {}
            for src, tgt in target_pairs:
                if src in embeddings and tgt in embeddings:
                    pair_name = f"{src}:{tgt}"
                    sim = float(similarity(embeddings[src], embeddings[tgt]))
                    passage_sims[pair_name] = sim
                    all_pair_similarities[pair_name].append(sim)

            if passage_sims:
                print(f"passage id: {passage_id}, similarities: {passage_sims}")
                all_results[language_name]['passage_similarities'][passage_id] = passage_sims

        for pair_name, sims in all_pair_similarities.items():
            if sims:
                all_results[language_name]['avg_similarities'][pair_name] = sum(sims) / len(sims)

    except json.JSONDecodeError as e:
        print(f"Error decoding JSON from file {file}: {e}")
    except Exception as e:
        # Best effort: one bad file should not stop the remaining languages.
        print(f"Error processing file {file}: {e}")

Processing wolof using labse model
passage id: 1, similarities: {'official:chatgpt': 0.9103513956069946, 'official:google': 0.9570765495300293, 'english:roundtrip_chatgpt': 0.9767275452613831, 'english:roundtrip_google': 0.9953362345695496}
passage id: 2, similarities: {'official:chatgpt': 0.5727707743644714, 'official:google': 0.859481692314148, 'english:roundtrip_chatgpt': 0.9873533844947815, 'english:roundtrip_google': 0.9770512580871582}
passage id: 3, similarities: {'official:chatgpt': 0.8564803600311279, 'official:google': 0.9046306610107422, 'english:roundtrip_chatgpt': 0.9779288172721863, 'english:roundtrip_google': 1.0000001192092896}
passage id: 4, similarities: {'official:chatgpt': 0.8179197311401367, 'official:google': 0.8423664569854736, 'english:roundtrip_chatgpt': 0.9694747924804688, 'english:roundtrip_google': 0.9845173358917236}
passage id: 5, similarities: {'english:roundtrip_chatgpt': 0.8109887838363647, 'english:roundtrip_google': 0.9091354608535767}
passage id: 6, 

In [62]:
# Report the mean similarity per comparison pair for every processed language,
# flagging languages for which no pair could be scored.
for language, results in all_results.items():
    avg_sims = results['avg_similarities']
    report_lines = [f"language: {language}, avg_sims: {avg_sims}"]
    if not avg_sims:
        report_lines.append(f"{language} has empty avg_sims")
    print("\n".join(report_lines))
    

language: wolof, avg_sims: {'official:chatgpt': 0.7893805652856827, 'official:google': 0.8908888399600983, 'english:roundtrip_chatgpt': 0.9352389474709829, 'english:roundtrip_google': 0.9738694429397583}
language: spanish, avg_sims: {'official:chatgpt': 0.9740583300590515, 'official:google': 0.9782789796590805, 'english:roundtrip_chatgpt': 0.9728038807710012, 'english:roundtrip_google': 0.964999665816625}
language: nahuatl, avg_sims: {}
nahuatl has empty avg_sims
language: tibetan, avg_sims: {'official:chatgpt': 0.7794427126646042, 'official:google': 0.8647293299436569, 'english:roundtrip_chatgpt': 0.8182365496953329, 'english:roundtrip_google': 0.9385332564512888}
language: quechua, avg_sims: {'official:chatgpt': 0.7556637823581696, 'official:google': 0.7148463129997253, 'english:roundtrip_chatgpt': 0.8158847788969675, 'english:roundtrip_google': 0.9731576144695282}
language: malay, avg_sims: {'official:chatgpt': 0.8715976029634476, 'official:google': 0.8615406304597855, 'english:roun