In [2]:
# One-time setup: uncomment to download the two evaluation files read by the cells below.
# !wget https://github.com/mesolitica/llama2-embedding/raw/main/test-set/malay-news-dataset-bge-test.sample.json
# !wget https://huggingface.co/datasets/mesolitica/embedding-pair-mining/resolve/main/malay-news-dataset-bge-test.jsonl

In [3]:
import json
import openai
from tqdm import tqdm

In [4]:
from transformers import AutoModel, AutoTokenizer
from sklearn.metrics.pairwise import cosine_similarity

# Single source of truth for the checkpoint id, so the model and the tokenizer
# are always loaded from the same repository.
MODEL_NAME = 'mesolitica/llama2-embedding-600m-8k'

# NOTE: trust_remote_code = True lets transformers execute Python code shipped in
# the model repository; only keep it enabled for repositories you trust.
# NOTE(review): the `encode` method called on `model` later in this notebook is
# presumably provided by that remote code - confirm against the model repository.
model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code = True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

[2023-10-09 07:13:55,878] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)


2023-10-09 07:13:56.361395: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [5]:
# Move the model to the GPU. The return value is bound to `_` so the cell has no
# trailing expression and the notebook does not echo it.
_ = model.cuda()

In [6]:
# Load the sample file; its keys are the texts that get embedded below.
# The encoding is pinned to UTF-8 so decoding does not depend on the platform's
# locale default.
with open('malay-news-dataset-bge-test.sample.json', encoding = 'utf-8') as fopen:
    rev_data = json.load(fopen)

In [7]:
# Embed every key of `rev_data`, one text per forward pass, and store the result
# as a NumPy vector keyed by the original text.
vectors = {}
# text -> repr(exception) for inputs that could not be embedded; kept so that
# skipped items are visible instead of disappearing silently.
failed = {}
for text in tqdm(rev_data):
    try:
        padded = tokenizer([text], return_tensors = 'pt', padding = True)
        # move every tensor produced by the tokenizer to the GPU
        for name in padded:
            padded[name] = padded[name].cuda()

        # batch of one -> take row 0
        vectors[text] = model.encode(padded).cpu().detach().numpy()[0]
    except Exception as e:
        # Best effort: skip the text but remember why. Catching `Exception`
        # (not a bare `except`) lets KeyboardInterrupt stop the loop.
        failed[text] = repr(e)

if failed:
    print(f'skipped {len(failed)} of {len(rev_data)} texts; first error: {next(iter(failed.values()))}')

100%|██████████| 18237/18237 [01:31<00:00, 198.39it/s]


In [8]:
# Number the embedded texts in the insertion order of `vectors`, then build the
# reverse lookup (text -> number).
no_string = dict(enumerate(vectors))
string_no = {text: no for no, text in no_string.items()}
no_string[0]

'Sejurus pulih, Subasic yang muncul wira Croatia dalam sepakan penalti menentang Denmark pada pusingan 16 pasukan terakhir, menyelamatkan rembatan Fedor Smolov dalam masa kecederaan.'

In [9]:
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

In [10]:
# Stop after this many queries that have an embedding.
MAX_QUERIES = 1000


def _similarities(query_vec, texts, lookup):
    """Yield the cosine similarity between `query_vec` and each text in `texts`.

    `query_vec` is a (1, dim) array. Each text is stripped and looked up in
    `lookup` (text -> vector); texts without a vector are skipped.
    """
    for text in texts:
        vec = lookup.get(text.strip())
        if vec is None:
            continue
        vec = np.array(vec).reshape((1, -1))
        yield cosine_similarity(query_vec, vec)[0, 0]


# `positive` collects cos(query, pos); `negative` collects 1 - cos(query, neg).
positive, negative = [], []
n_queries = 0
# Encoding pinned to UTF-8 so the texts match the keys of `vectors` regardless
# of the platform's locale default.
with open('malay-news-dataset-bge-test.jsonl', encoding = 'utf-8') as fopen:
    for raw in tqdm(fopen):
        row = json.loads(raw)
        v_query = vectors.get(row['query'].strip())
        if v_query is None:
            # query was never embedded -> does not count towards MAX_QUERIES
            continue
        v_query = np.array(v_query).reshape((1, -1))

        positive.extend(_similarities(v_query, row['pos'], vectors))
        negative.extend(1 - sim for sim in _similarities(v_query, row['neg'], vectors))

        n_queries += 1
        if n_queries >= MAX_QUERIES:
            break

999it [00:06, 159.38it/s]


In [11]:
# Mean of the positive-pair scores and mean of the negative-pair scores.
tuple(map(np.mean, (positive, negative)))

(0.71082395, 0.7160432709481884)

In [12]:
# Stack all embeddings into one matrix, one row per entry of `vectors` in
# insertion order.
np_vectors = np.array([*vectors.values()])
np_vectors.shape

(18237, 1536)

In [13]:
# Top-k retrieval check. For every (query, positive) pair whose texts were both
# embedded, rank all rows of `np_vectors` by cosine similarity to the positive
# and count a hit for each k where the query appears among the first k results.
tops = {
    1: 0,
    3: 0,
    5: 0,
    10: 0,
}
# number of (query, positive) pairs that were actually evaluated
total = 0
# Encoding pinned to UTF-8 so the texts match the keys of `string_no` / `vectors`
# regardless of the platform's locale default.
with open('malay-news-dataset-bge-test.jsonl', encoding = 'utf-8') as fopen:
    for raw in tqdm(fopen):
        row = json.loads(raw)
        query = row['query'].strip()
        query_no = string_no.get(query)
        if query_no is None:
            continue
        for s in row['pos']:
            s = s.strip()
            v_s = vectors.get(s)
            s_no = string_no.get(s)
            if v_s is None:
                continue
            v_s = np.array(v_s).reshape((1, -1))
            # row numbers ordered from most to least similar to the positive
            argsort = np.argsort(cosine_similarity(v_s, np_vectors)[0])[::-1]
            for k in tops:
                # If the positive's own row is among the first k results,
                # look at one extra result so it does not use up a slot.
                k_ = k + 1 if s_no in argsort[:k] else k
                if query_no in argsort[:k_]:
                    tops[k] += 1
            total += 1

19613it [07:38, 42.80it/s]  


In [14]:
# Show the raw hit count per k next to the number of evaluated pairs.
tops, total

({1: 1213, 3: 2345, 5: 3095, 10: 4005}, 8501)

In [16]:
# Print the hit rate for each k: hits divided by the number of evaluated pairs.
for k, hits in tops.items():
    # If no pair was evaluated, report nan instead of raising ZeroDivisionError.
    print(k, hits / total if total else float('nan'))

1 0.14268909540054112
3 0.27584990001176335
5 0.3640748147276791
10 0.47112104458299026
