In [None]:
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
import os

# Optional: uncomment the line below to limit this process to specific GPUs
# os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'

# Define last token pooling function to extract the embedding of the last valid token in each sequence
def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Select one hidden-state vector per sequence: the one at its last attended position.

    The batch is treated as left-padded when the final column of
    ``attention_mask`` sums to the number of rows (i.e. every row ends in a
    1 for a 0/1 mask); in that case the last time step is returned for every
    row. Otherwise the position is taken to be ``attention_mask.sum(dim=1) - 1``
    for each row, which is the last real token when the mask is right-padded.

    Args:
        last_hidden_states: Hidden states indexed as ``[batch, position, ...]``
            (assumed to be ``(batch, seq_len, hidden)`` -- confirm with caller).
        attention_mask: Mask indexed as ``[batch, position]``; assumed to hold
            1 for real tokens and 0 for padding.

    Returns:
        ``last_hidden_states`` indexed at one position per batch row.
    """
    num_rows = attention_mask.shape[0]

    # Left padding: all sequences end at the final position, so no per-row lookup is needed.
    if attention_mask[:, -1].sum() == num_rows:
        return last_hidden_states[:, -1]

    # Right padding: count of attended tokens minus one is the index of the last one.
    last_positions = attention_mask.sum(dim=1) - 1
    row_ids = torch.arange(
        last_hidden_states.shape[0],
        device=last_hidden_states.device,
    )
    return last_hidden_states[row_ids, last_positions]

# End-to-end demo: embed the queries and documents with the model at
# `model_path`, then print the scaled cosine similarity of every document to
# each query, ranked from most to least similar.

# --- Model setup -----------------------------------------------------------
print("Loading model...")
model_path = "your/merged/model/path"  # placeholder: set to the merged checkpoint directory

# trust_remote_code=True allows Python code shipped with the checkpoint to run
# while loading, so only point this at checkpoints you trust.
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
)

# device_map="auto" lets the loader decide where the weights are placed
# (possibly spread across several GPUs).
model = AutoModel.from_pretrained(
    model_path,
    torch_dtype=torch.float32,
    trust_remote_code=True,
    device_map="auto",
)

# --- Inputs ----------------------------------------------------------------
# Query text(s) to search with
queries = ["your query text"]

# Document database to rank against each query
documents = ["document 1", "document 2", "document 3"]

# Queries and documents go through the model as one batch. Queries come first;
# the slicing in the score computation below relies on this ordering.
input_texts = queries + documents
max_length = 850  # inputs longer than this many tokens are truncated

print(f"Processing {len(queries)} queries and {len(documents)} documents...")

# padding=True pads every sequence to the longest one in the batch so the
# result can be returned as rectangular PyTorch tensors.
batch_dict = tokenizer(
    input_texts,
    max_length=max_length,
    padding=True,
    truncation=True,
    return_tensors='pt',
)

# Move the inputs to the device the model reports, rather than assuming the
# default CUDA device: device_map="auto" chooses the placement, and it may not
# be CUDA at all (e.g. CPU-only hosts or other accelerators).
input_device = model.device
batch_dict = {k: v.to(input_device) for k, v in batch_dict.items()}
if input_device.type == "cpu":
    print("Warning: CUDA not detected. Using CPU for inference.")

# --- Inference -------------------------------------------------------------
print("Running inference...")

# Inference only, so gradient tracking is disabled to save memory.
with torch.no_grad():
    outputs = model(**batch_dict, output_hidden_states=True)
    # hidden_states holds one tensor per layer; the last entry is the final layer.
    last_hidden_states = outputs.hidden_states[-1]

# One vector per input text: the hidden state of its last attended token
embeddings = last_token_pool(last_hidden_states, batch_dict['attention_mask'])

# L2-normalize so that the dot product below equals cosine similarity
embeddings = F.normalize(embeddings, p=2, dim=1)

# Rows = queries, columns = documents; scaled by 100 for readability
scores = (embeddings[:len(queries)] @ embeddings[len(queries):].T) * 100

# --- Report ----------------------------------------------------------------
print("\n" + "="*50)
print("Similarity Analysis Results")
print("="*50)

for query_idx, query in enumerate(queries):
    print(f"\nQuery: '{query}'")
    print("\nSimilarity scores for each document:")

    query_scores = scores[query_idx].tolist()

    # (document, score, original position) triples, highest score first.
    # sorted() is stable, so tied documents keep their original order.
    doc_score_pairs = sorted(
        zip(documents, query_scores, range(len(documents))),
        key=lambda x: x[1],
        reverse=True,
    )

    for i, (doc, score, original_idx) in enumerate(doc_score_pairs):
        # Show at most 50 characters of each document in the ranking
        doc_preview = (doc[:50] + "...") if len(doc) > 50 else doc
        print(f"  Rank {i+1} - Document {original_idx+1}: {score:.2f} - {doc_preview}")

    print(f"\nRaw scores: {[f'{s:.2f}' for s in query_scores]}")
    # Compute the maximum once and reuse it for both the index lookup and the printout
    best_score = max(query_scores)
    best_doc_idx = query_scores.index(best_score)
    print(f"Highest similarity: Document {best_doc_idx+1} (Score: {best_score:.2f})")
