In [2]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [1]:
# Colab-only step: mount Google Drive under /content/drive/.
# NOTE(review): none of the cells visible in this notebook read from the
# mounted drive — confirm this cell is still needed.
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [125]:
import torch
from transformers import AutoTokenizer, AutoModel

# Load the pre-trained BERT tokenizer and model for the checkpoint named below.
# AutoModel instantiates a plain BertModel here, so the checkpoint's `cls.*`
# weights are not used; the "Some weights ... were not used" warning printed by
# this cell is the expected consequence of that, not an error.
model_name = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [161]:
def load_corpus_data():
  """Build the small, hard-coded document collection used as the search corpus.

  Returns:
    list[str]: a newly created list of three example sentences, always in the
    same order.
  """
  documents = []
  documents.append("Usually the people called Daniel are bad.")
  documents.append("Cosine similarity is a measure of similarity, often used to measure document similarity in text analysis.")
  documents.append("Peter likes fish, steak, and roast chicken")
  return documents

# Materialise the corpus once so later cells can index into it.
corpus = load_corpus_data()

In [162]:
def load_search_queries():
  """Build the hard-coded list of example search queries.

  Returns:
    list[str]: a newly created list of three query strings, always in the
    same order.
  """
  search_queries = []
  search_queries.append("which name is bad")
  search_queries.append("what is cosine similarity")
  search_queries.append("who likes to eat meat")
  return search_queries

# Materialise the queries once so later cells can iterate over them.
queries = load_search_queries()

In [163]:
import torch
# One mean-pooled BERT vector per corpus document, in corpus order.
corpus_embeddings = []
# Inference only: no_grad() stops autograd from recording a graph for each
# forward pass, so the outputs can be converted to NumPy directly.
with torch.no_grad():
    for text in corpus:
        # truncation=True caps the token sequence at the tokenizer's maximum
        # length, so an over-long document cannot overflow the model's input.
        input_ids = torch.tensor(
            tokenizer.encode(text, add_special_tokens=True, truncation=True)
        ).unsqueeze(0)  # add a leading batch dimension of size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]
        # Average over dim=1 (one entry per input token), then drop the
        # size-1 batch dimension to leave a single flat vector.
        avg_pooling = torch.mean(last_hidden_states, dim=1).squeeze()
        corpus_embeddings.append(avg_pooling.numpy())



In [164]:
# One mean-pooled BERT vector per search query, in query order. The pooling
# matches the corpus embedding cell so the two sets of vectors are comparable.
query_embeddings = []
# Inference only: no_grad() stops autograd from recording a graph for each
# forward pass, so the outputs can be converted to NumPy directly.
with torch.no_grad():
    for text in queries:
        # truncation=True caps the token sequence at the tokenizer's maximum
        # length, so an over-long query cannot overflow the model's input.
        input_ids = torch.tensor(
            tokenizer.encode(text, add_special_tokens=True, truncation=True)
        ).unsqueeze(0)  # add a leading batch dimension of size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]
        # Average over dim=1 (one entry per input token), then drop the
        # size-1 batch dimension to leave a single flat vector.
        avg_pooling = torch.mean(last_hidden_states, dim=1).squeeze()
        query_embeddings.append(avg_pooling.numpy())
    

In [165]:
from sklearn.metrics.pairwise import cosine_similarity

# For every query: score all corpus documents by cosine similarity, then print
# the raw score row followed by the documents ordered from best to worst match.
for query_embedding in query_embeddings:
  # Present the single query vector as a one-row 2-D array for the comparison.
  query_row = query_embedding.reshape(1, -1)
  relevance_scores = cosine_similarity(query_row, corpus_embeddings)

  # argsort() orders ascending; reversing it puts the highest score first.
  ascending_indices = relevance_scores[0].argsort()
  sorted_indices = ascending_indices[::-1]

  # Keep at most the ten best-scoring documents.
  relevant_documents = []
  for doc_index in sorted_indices[:10]:
    relevant_documents.append(corpus[doc_index])

  print(relevance_scores)
  print(relevant_documents)

[[0.61023784 0.4218006  0.46227834]]
['Usually the people called Daniel are bad.', 'Peter likes fish, steak, and roast chicken', 'Cosine similarity is a measure of similarity, often used to measure document similarity in text analysis.']
[[0.4857527  0.60021305 0.44286966]]
['Cosine similarity is a measure of similarity, often used to measure document similarity in text analysis.', 'Usually the people called Daniel are bad.', 'Peter likes fish, steak, and roast chicken']
[[0.5862737 0.4151404 0.6338297]]
['Peter likes fish, steak, and roast chicken', 'Usually the people called Daniel are bad.', 'Cosine similarity is a measure of similarity, often used to measure document similarity in text analysis.']
