In [1]:
import os

import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel

In [2]:
def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Select one hidden-state vector per row: the one at its last attended position.

    Args:
        last_hidden_states: Hidden states indexed as ``[row, position, ...]``.
        attention_mask: Mask indexed as ``[row, position]``. It is summed, so
            attended positions are expected to be 1 and padded positions 0.

    Returns:
        ``last_hidden_states[:, -1]`` when every row's final mask entry is set
        (the batch is left-padded or unpadded); otherwise, for each row, the
        hidden state at index ``attention_mask.sum(dim=1) - 1``.
    """
    num_rows = attention_mask.shape[0]

    # When the final mask column is set for every row, no row ends in padding,
    # so the last position is the last real token everywhere.
    if attention_mask[:, -1].sum() == num_rows:
        return last_hidden_states[:, -1]

    # Otherwise treat the batch as right-padded: the last real token of a row
    # sits at (number of attended positions - 1).
    last_positions = attention_mask.sum(dim=1) - 1
    row_ids = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[row_ids, last_positions]

In [3]:
def prompt(task_description: str, query: str) -> str:
    """Build the two-line instructed query text.

    Args:
        task_description: Instruction placed after ``Instruct: `` on the first line.
        query: Query text placed after ``Query: `` on the second line.

    Returns:
        The instruction line and the query line joined by a single newline.
    """
    instruction_line = f'Instruct: {task_description}'
    query_line = f'Query: {query}'
    return f'{instruction_line}\n{query_line}'

In [64]:
# Instruction shared by every query.
task = 'Given a web search query, retrieve relevant passages that answer the query'

# Each raw question is wrapped with the task instruction via prompt().
queries = [
    prompt(task, question)
    for question in (
        'What is netidx?',
        'is it possible to use netidx in a shell script?',
    )
]

# Retrieval documents are embedded as-is; they get no instruction prefix.
documents = [
    "Netidx is a middleware system for publishing and consuming values across networks or on the same machine. It uses a hierarchical namespace for globally unique names, allowing values to be updated and subscribers to receive notifications of changes. Unlike LDAP, Netidx doesn't store entries or attributes, and unlike MQTT, it doesn't have a centralized broker. It supports browsable directories, authentication, authorization, and encryption. Values can be both read and written, and subscribers receive updates reliably and in order. Netidx's data format includes primitive types and supports zero-copy decoding for efficiency. Security is optional but can be enforced with Kerberos, Local, or TLS. It's cross-platform and designed for large namespaces with delegation and replication capabilities.",
    "This shell script automates the process of publishing vmstat data to netidx. The script starts by defining a base path and uses `vmstat -n 1` to print the header and one line of data per second. It then reads each field of the output, reformats it into a publishable format, and pipes this data to `netidx` to display it in a browser as a table. Additionally, the script can aggregate and publish total values for each vmstat field across multiple hosts by listing all relevant paths under `/sys/vmstat`, filtering out the total row, and using an associative array to keep track of individual host totals for each field. This allows for real-time monitoring and aggregation of system performance metrics across a network of machines.",
    "This text is too short to summarize",
]

# Queries come first, then documents, in one flat list.
input_texts = [*queries, *documents]

In [6]:
# Checkpoint location shared by the tokenizer and the model, defined once so the
# two cannot drift apart. It defaults to the original local directory and can be
# overridden with the GTE_MODEL_PATH environment variable so the notebook is not
# tied to a single machine.
MODEL_PATH = os.environ.get('GTE_MODEL_PATH', '/home/eric/proj/gte-Qwen2-7B-instruct')

# NOTE: trust_remote_code=True allows transformers to execute Python code shipped
# with the checkpoint; only point MODEL_PATH at a checkpoint you trust.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

In [65]:
# Tokenize the input texts as one padded batch of PyTorch tensors, truncating
# anything longer than max_length tokens.
batch_dict = tokenizer(input_texts, max_length=8192, padding=True, truncation=True, return_tensors='pt')

# Inference only: run without autograd so no graph or intermediate activations
# are retained for a backward pass that never happens.
with torch.no_grad():
    outputs = model(**batch_dict)
    embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

    # normalize embeddings (unit L2 norm per row, so dot products are cosine similarities)
    embeddings = F.normalize(embeddings, p=2, dim=1)

# input_texts is the queries followed by the documents, so split at len(queries)
# instead of a hard-coded 2; the slice stays correct if queries are added or removed.
num_queries = len(queries)

# Rows are queries, columns are documents; similarities are scaled by 100.
scores = (embeddings[:num_queries] @ embeddings[num_queries:].T) * 100
print(scores.tolist())

[[69.60458374023438, 32.227603912353516, 14.335012435913086], [53.810001373291016, 53.52702331542969, 14.783313751220703]]


In [18]:
# Show the shape of the pooled, normalized embeddings.
# NOTE(review): the saved output below reports 4 rows, but input_texts currently
# holds 5 entries (2 queries + 3 documents) and this cell's execution count (18)
# is older than the embedding cell's (65) - re-run this cell to refresh it.
embeddings.shape

torch.Size([4, 3584])

In [61]:
# Inspect the token ids returned by the tokenizer call.
# NOTE(review): this displays the entire id tensor. The saved output begins with
# long runs of 151643, presumably the padding id placed before shorter texts -
# confirm against the tokenizer's pad token.
batch_dict["input_ids"]

tensor([[151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
         151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
         151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
         151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
         151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
         151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
         151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
         151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
         151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
         151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
         151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
         151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
         151643, 151643, 151