In [1]:
from datasets import load_dataset

dataset = load_dataset("ag_news")
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 120000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 7600
    })
})

In [2]:
dataset["train"][0]

{'text': "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.",
 'label': 2}

In [3]:
documents = dataset["train"]["text"][:1000]
len(documents)

1000

In [4]:
from transformers import AutoTokenizer, AutoModel
import torch

embed_model_name = "sentence-transformers/all-MiniLM-L6-v2"

embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name)
embed_model = AutoModel.from_pretrained(embed_model_name)
embed_model.eval()

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
W0101 22:11:45.931000 20759 site-packages/torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
  _torch_pytree._register_pytree_node(


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 384, padding_idx=0)
    (position_embeddings): Embedding(512, 384)
    (token_type_embeddings): Embedding(2, 384)
    (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-5): 6 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=384, out_features=384, bias=True)
            (key): Linear(in_features=384, out_features=384, bias=True)
            (value): Linear(in_features=384, out_features=384, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=384, out_features=384, bias=True)
            (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
    

In [5]:
def embed_texts(texts, tokenizer, model, batch_size=32):
    embeddings = []
    with torch.no_grad():
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i+batch_size]
            inputs = tokenizer(
                batch,
                padding=True,
                truncation=True,
                return_tensors="pt"
            )
            outputs = model(**inputs)
            batch_embeddings = outputs.last_hidden_state.mean(dim=1)
            embeddings.append(batch_embeddings)
    return torch.cat(embeddings)

In [6]:
doc_embeddings = embed_texts(documents, embed_tokenizer, embed_model)
doc_embeddings.shape

torch.Size([1000, 384])

<h3>384 is the dimensionality of the semantic embedding space.</h3>

In [7]:
def embed_query(query, tokenizer, model):
    with torch.no_grad():
        inputs = tokenizer(
            [query],
            padding=True,
            truncation=True,
            return_tensors="pt"
        )
        outputs = model(**inputs)
        embedding = outputs.last_hidden_state.mean(dim=1)
    return embedding

In [8]:
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

def semantic_search(query, documents, doc_embeddings, tokenizer, model, top_k=5):
    query_emb = embed_query(query, tokenizer, model)
    scores = cosine_similarity(query_emb, doc_embeddings)[0]
    ranked_indices = np.argsort(scores)[::-1][:top_k]

    results = []
    for idx in ranked_indices:
        results.append((scores[idx], documents[idx]))

    return results

In [9]:
query = "oil prices fall after global uncertainty"

results = semantic_search(
    query,
    documents,
    doc_embeddings,
    embed_tokenizer,
    embed_model,
    top_k=5
)

for score, doc in results:
    print(f"{score:.3f} | {doc[:120]}...")

0.459 | Oil Holds Near Record Level Oil prices fell 23 cents to \$46.35 a barrel after Venezuelan Hugo Chavez won a recall refer...
0.433 | Stocks Fall as Oil Hits High (Reuters) Reuters - Exporters led a fall in Asian shares\on Monday as oil prices set new hi...
0.428 | Stocks Fall as Oil Hits High  SINGAPORE (Reuters) - Exporters led a fall in Asian shares  on Monday as oil prices set ne...
0.409 | Oil Prices Hit Record (Reuters) Reuters - Oil prices jumped to a new record\high near  #36;47 on Monday with traders on ...
0.399 | No Need for OPEC to Pump More-Iran Gov  TEHRAN (Reuters) - OPEC can do nothing to douse scorching  oil prices when marke...


<h1>Why embeddings + cosine similarity are better for large-scale search?</h1>

<h3>Because you can precompute embeddings once and search fast, instead of running the model for every queryâ€“document pair.</h3>