In [1]:
from modeling_contrastive import LlamaModelEmbedding
from transformers import AutoTokenizer
from transformers.trainer_utils import get_last_checkpoint
from sklearn.metrics.pairwise import cosine_similarity
import torch

In [2]:
# Directory the training run wrote its `checkpoint-<step>` folders into.
OUTPUT_DIR = 'embedding-model-llama-600m-contrastive'

# get_last_checkpoint returns None when the directory contains no
# `checkpoint-<step>` sub-folder. Fail here with a clear message instead of
# letting a None path surface later as an obscure error in from_pretrained.
checkpoint = get_last_checkpoint(OUTPUT_DIR)
if checkpoint is None:
    raise FileNotFoundError(f'no checkpoint-* directory found in {OUTPUT_DIR!r}')
checkpoint

'embedding-model-llama-600m-contrastive/checkpoint-85500'

In [4]:
# Register the custom model class with the transformers auto-class machinery.
# NOTE(review): presumably done so the uploaded repo can be loaded through
# AutoModel with trust_remote_code=True — confirm against modeling_contrastive.
LlamaModelEmbedding.register_for_auto_class()

In [5]:
# Load the embedding model from the checkpoint directory resolved earlier.
model = LlamaModelEmbedding.from_pretrained(checkpoint)

In [6]:
# Load the tokenizer from the same checkpoint directory as the model.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [7]:
# Sample sentences to embed: a mix of short phrases and long news-style
# sentences, used below to eyeball the similarity structure.
sentences = [
    'tak suka ayam',
    'Isu perkauman: Kerajaan didakwa terdesak kaitkan pemimpin PN',
    'nasi ayam tu sedap',
    'suka ikan goreng?',
    'Kerajaan tidak akan berkompromi dengan isu perkauman dan agama yang dimanipulasi pihak tertentu untuk mengganggu-gugat kestabilan negara serta ketenteraman rakyat.',
    'rasis bodo mamat tu',
    'kerajaan sekarang xde otak',
    'aku nak sukan olimpik ni',
    'malaysia dapat x pingat kt sukan asia?',
    'pingat gangsa menerusi terjun dan olahraga pada hari ke-10',
    'Kerajaan negeri kini dibenarkan melaksanakan penerokaan awal unsur nadir bumi (REE) berdasarkan prosedur operasi standard (SOP) sedia ada untuk perlombongan nadir bumi dan mineral.',
    'KONTINJEN Malaysia mendekati sasaran 27 pingat di Sukan Asia kali ini esok, selepas menuai dua lagi pingat gangsa menerusi terjun dan olahraga pada hari ke-10 pertandingan, pada Selasa.',
]

# Tokenize the whole batch at once, padded so it can be returned as PyTorch
# tensors. Despite its name, `input_ids` holds the tokenizer's full output
# object, not only the token ids.
input_ids = tokenizer(sentences, padding=True, return_tensors='pt')

In [8]:
# Inference only: run the forward pass with autograd disabled so no graph is
# built and intermediate activations are not retained. Under no_grad the result
# has requires_grad=False, so it converts to NumPy without an explicit detach.
# Assumes model.encode returns a single CPU tensor, as the original
# `.detach().numpy()` call also required.
with torch.no_grad():
    v = model.encode(input_ids).numpy()
v.shape

(12, 1536)

In [9]:
# Pairwise cosine similarity between every pair of rows (sentence embeddings)
# in `v`; called with a single argument, it compares `v` against itself.
cosine_similarity(v)

array([[1.        , 0.4209137 , 0.9546999 , 0.9597863 , 0.3120873 ,
        0.8405844 , 0.58216   , 0.58645266, 0.40890247, 0.4996181 ,
        0.5089604 , 0.3351455 ],
       [0.4209137 , 0.9999997 , 0.3542768 , 0.42114827, 0.9504733 ,
        0.60532546, 0.8145397 , 0.51836056, 0.6349023 , 0.51608515,
        0.80891323, 0.5999971 ],
       [0.9546999 , 0.3542768 , 0.9999994 , 0.893784  , 0.27815583,
        0.73887515, 0.4366102 , 0.5009967 , 0.42337325, 0.47003913,
        0.51452726, 0.38308212],
       [0.9597863 , 0.42114827, 0.893784  , 1.        , 0.30760384,
        0.8868808 , 0.5656576 , 0.61279464, 0.4210244 , 0.5397133 ,
        0.44625235, 0.34082997],
       [0.3120873 , 0.9504733 , 0.27815583, 0.30760384, 1.0000002 ,
        0.50049514, 0.7748455 , 0.44755244, 0.5778815 , 0.46799105,
        0.8130318 , 0.5849888 ],
       [0.8405844 , 0.60532546, 0.73887515, 0.8868808 , 0.50049514,
        0.99999994, 0.64166284, 0.56571054, 0.43035477, 0.5582535 ,
        0.4899069 ,

In [10]:
# Upload the model to the Hugging Face Hub repository; safe_serialization=True
# requests the safetensors weight format.
model.push_to_hub('mesolitica/llama2-embedding-600m-8k-contrastive', safe_serialization = True)

model.safetensors:   0%|          | 0.00/2.17G [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/mesolitica/llama2-embedding-600m-8k-contrastive/commit/0d60f553fe443549d5d3d89378236dfc1113b3ec', commit_message='Upload model', commit_description='', oid='0d60f553fe443549d5d3d89378236dfc1113b3ec', pr_url=None, pr_revision=None, pr_num=None)

In [11]:
# Upload the tokenizer files to the same Hub repository as the model.
# NOTE(review): safe_serialization concerns the weight file format and
# presumably has no effect for a tokenizer — confirm and drop if so.
tokenizer.push_to_hub('mesolitica/llama2-embedding-600m-8k-contrastive', safe_serialization = True)

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/mesolitica/llama2-embedding-600m-8k-contrastive/commit/2bdf016e21c07f977dc935e6eaa6060381c648cf', commit_message='Upload tokenizer', commit_description='', oid='2bdf016e21c07f977dc935e6eaa6060381c648cf', pr_url=None, pr_revision=None, pr_num=None)