In [1]:
from transformers import BertTokenizer, BertModel
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

In [2]:
# Single source of truth for the checkpoint: the tokenizer's vocabulary must
# match the model's embedding table, so both are loaded from the same name.
MODEL_NAME = 'bert-base-uncased'

tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = BertModel.from_pretrained(MODEL_NAME)

# The embeddings below are only meaningful with dropout disabled. Calling
# eval() explicitly keeps that requirement visible in the notebook instead of
# relying on the loader's default state.
model.eval()

# Bare last expression: display the architecture summary as the cell output.
model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [3]:
def get_sentence_embedding(text, *, tok=None, encoder=None):
    """Return mean-pooled last-layer embeddings for one text or a batch of texts.

    Args:
        text: A single string, or a list of strings to embed as one batch.
        tok: Optional tokenizer callable. Defaults to the notebook-level
            ``tokenizer``. It must return a mapping that contains an
            ``attention_mask`` tensor.
        encoder: Optional model callable. Defaults to the notebook-level
            ``model``. Its output must expose ``last_hidden_state``.

    Returns:
        A NumPy array of shape ``(batch, hidden_size)``; a single string
        yields ``(1, hidden_size)``, exactly as before.

    Notes:
        * ``padding=True`` lets texts of different lengths share one batch.
        * ``truncation=True`` clips over-long input to the tokenizer's
          configured maximum instead of letting the model fail on it.
        * Pooling is weighted by the attention mask so padded positions never
          contribute to the average. With no padding (the single-string case)
          this is the plain mean over all tokens.
    """
    # Resolve collaborators lazily so the notebook globals are only touched
    # when the caller did not supply replacements.
    tok = tokenizer if tok is None else tok
    encoder = model if encoder is None else encoder

    inputs = tok(text, return_tensors='pt', padding=True, truncation=True)
    with torch.no_grad():  # inference only: skip building the autograd graph
        outputs = encoder(**inputs)
    last_hidden_states = outputs.last_hidden_state

    # Broadcast the (batch, seq) mask across the hidden dimension.
    mask = inputs['attention_mask'].unsqueeze(-1).to(last_hidden_states.dtype)
    summed = (last_hidden_states * mask).sum(dim=1)
    # clamp guards against division by zero for a fully masked row.
    token_counts = mask.sum(dim=1).clamp(min=1)

    sentence_embedding = (summed / token_counts).numpy()
    return sentence_embedding

In [9]:
# Reference corpus: two near-paraphrases plus one unrelated sentence, so the
# similarity scores further down have an obvious expected ordering.
texts = [
    "The quick brown fox jumps over the lazy dog.",
    "A fast brown fox leaps over a sleepy dog.",
    "This sentence is completely different from the others."
]

# One embedding per corpus sentence, kept as a list of per-sentence arrays
# because later cells stack this list themselves.
embeddings = list(map(get_sentence_embedding, texts))

# Each per-sentence array keeps its leading batch axis, so stacking them
# produces a 3-D array rather than a flat (n_texts, hidden) matrix.
embeds = np.array(embeddings)
embeds.shape

(3, 1, 768)

In [5]:
# Query sentence: differs from the first corpus sentence by a single word.
query_text = "The quick red fox jumps over the lazy dog."
query_embedding = get_sentence_embedding(query_text)

# Peek at the first 20 components of the query's only row as a sanity check.
query_embedding[0, :20]

array([ 0.0158137 , -0.0613406 ,  0.0433245 ,  0.03471806,  0.3223506 ,
        0.00439711, -0.12446175,  0.38215244,  0.10109716, -0.09531727,
       -0.08614075, -0.16277851, -0.13426319, -0.05922299, -0.40886435,
       -0.19059582,  0.15335262, -0.12626106, -0.14645492,  0.0090116 ],
      dtype=float32)

In [6]:
# Join the per-sentence row vectors along the first axis into one 2-D matrix,
# then score the query against every corpus row. The result has one row per
# query and one column per corpus sentence.
similarities = cosine_similarity(
    query_embedding,
    np.concatenate(embeddings, axis=0),
)
similarities

array([[0.9852561 , 0.90269643, 0.48106045]], dtype=float32)

In [7]:
# Report the query, then each corpus sentence's cosine similarity to it,
# scaled to a percentage with one decimal place.
print(f"Query text: {query_text}")
for idx, text in enumerate(texts):
    # Row 0 is the single query; column idx is the matching corpus sentence.
    print(f"Similarity with '{text}': {similarities[0, idx]*100:.1f}%")

Query text: The quick red fox jumps over the lazy dog.
Similarity with 'The quick brown fox jumps over the lazy dog.': 98.5%
Similarity with 'A fast brown fox leaps over a sleepy dog.': 90.3%
Similarity with 'This sentence is completely different from the others.': 48.1%
