https://huggingface.co/learn/nlp-course/chapter5/6

https://www.sbert.net/docs/pretrained_models.html#model-overview


In [4]:
import os
from transformers import AutoTokenizer, AutoModel
from finetune import get_device

device = get_device()

# https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1
MODEL_ID = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModel.from_pretrained(MODEL_ID)
model.to(device)

def cls_pooling(model_output):
    """Collect the last hidden state for the special [CLS] token."""
    return model_output.last_hidden_state[:, 0]

def get_embeddings(text_list):
    encoded_input = tokenizer(
        text_list, padding=True, truncation=True, return_tensors="pt"
    )
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
    model_output = model(**encoded_input)
    return cls_pooling(model_output)

In [5]:
fname = "english.txt"
assert os.path.exists(fname)
with open(fname, "r") as f:
    lines = [line.strip() for line in f.readlines() if line.strip() != ""]

print(f"read {len(lines)} lines from {fname}")

read 2048 lines from english.txt


In [6]:
from datasets import Dataset

my_dataset = Dataset.from_dict({"text": lines})
embeddings_dataset = my_dataset.map(
    lambda x: {"embeddings": get_embeddings(x["text"]).detach().cpu().numpy()[0]}
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Map: 100%|██████████| 2048/2048 [00:24<00:00, 85.18 examples/s]


In [9]:
embeddings_dataset.add_faiss_index(column="embeddings")

#embeddings = get_embeddings(lines)
#embeddings.shape
#list(embeddings_dataset.data[0].keys())

100%|██████████| 3/3 [00:00<00:00, 571.64it/s]


Dataset({
    features: ['text', 'embeddings'],
    num_rows: 2048
})

In [20]:
query = "cat"
query_embedding = get_embeddings([query]).cpu().detach().numpy()

scores, samples = embeddings_dataset.get_nearest_examples(
    "embeddings", query_embedding, k=30
)

display(samples["text"])

['cat',
 'kitten',
 'pet',
 'animal',
 'what',
 'define',
 'category',
 'example',
 'need',
 'this',
 'usage',
 'another',
 'code',
 'try',
 'way',
 'mystery',
 'step',
 'term',
 'cause',
 'issue',
 'question',
 'explain',
 'biology',
 'spell',
 'dilemma',
 'hurry',
 'problem',
 'word',
 'trouble',
 'program']