In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np
import os
from sentence_transformers import SentenceTransformer
from collections import namedtuple
from torch.utils.data import random_split
import matplotlib.pyplot as plt

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from mteb import MTEB

In [None]:
sts_tasks = [
    "STSBenchmark",   # Standard semantic similarity benchmark
    # "SICK-R",         # Semantic relatedness from the SICK dataset
    # "STS12",
    # "STS13",
    # "STS14",
    # "STS15",
    # "STS16",
]

evaluation = MTEB(tasks=sts_tasks)


def eval(eval_model):
    results = evaluation.run(eval_model, output_folder=None)
    score = [result.scores['test'][0]['main_score']
                     for result in results]
    tasks = [result.task_name for result in results]
    # Print list of tasks as comma-separated values
    print(", ".join(tasks))
    # Print list of scores as comma-separated values    
    print(", ".join([f"{score:.4f}" for score in score]))
    return score



In [3]:
task = evaluation.tasks[0]

task.load_data()

In [4]:
test_dataset = task.dataset['test']

test_dataset[0:5]

{'split': ['test', 'test', 'test', 'test', 'test'],
 'genre': ['main-captions',
  'main-captions',
  'main-captions',
  'main-captions',
  'main-captions'],
 'dataset': ['MSRvid', 'MSRvid', 'MSRvid', 'MSRvid', 'MSRvid'],
 'year': ['2012test', '2012test', '2012test', '2012test', '2012test'],
 'sid': ['0024', '0033', '0045', '0063', '0066'],
 'score': [2.5, 3.6, 5.0, 4.2, 1.5],
 'sentence1': ['A girl is styling her hair.',
  'A group of men play soccer on the beach.',
  "One woman is measuring another woman's ankle.",
  'A man is cutting up a cucumber.',
  'A man is playing a harp.'],
 'sentence2': ['A girl is brushing her hair.',
  'A group of boys are playing soccer on the beach.',
  "A woman measures another woman's ankle.",
  'A man is slicing a cucumber.',
  'A man is playing a keyboard.']}

In [5]:
model_name = "meta-llama/Llama-3.2-1B"

In [6]:


class LLaMAEmbeddingModel:
    def __init__(self, model, tokenizer, normalize_embeddings=True, mean_pooling=True, layer_idx=-1, ignore_bos_token=False, batch_size=8, W_aug=torch.nn.Identity()):
        self.model = model
        self.tokenizer = tokenizer
        self.model.eval()
        # Remove the second half of decoder layers
        # self.model.model.layers = self.model.model.layers[:len(self.model.model.layers)//2]
        self.model.lm_head = torch.nn.Identity()
        self.normalize_embeddings = normalize_embeddings
        self.mean_pooling = mean_pooling
        self.batch_size = batch_size
        self.layer_idx = layer_idx
        self.ignore_bos_token = ignore_bos_token
        self.W_aug = W_aug.cuda()

    def encode(self, sentences, **kwargs):
        all_embeddings = []
        for i in range(0, len(sentences), self.batch_size):
            batch = sentences[i:i+self.batch_size]
            # Update each sentence such that its templatized as "Sentence: {sentence}, Repeat:"
            batch = [f"Sentence: {sentence}\n Meaning:" for sentence in batch]
            inputs = self.tokenizer(
                batch, return_tensors='pt', padding=True).to(self.model.device)
            with torch.no_grad():
                output = self.model(
                    **inputs, output_hidden_states=True, return_dict=True)
                outputs = output.hidden_states[self.layer_idx]
                # Mean pooling
                if self.mean_pooling:
                    attention_mask = inputs['attention_mask']
                    if self.ignore_bos_token:
                        # Find all positions of the BOS token
                        bos_token_id = self.tokenizer.bos_token_id
                        bos_positions = (inputs['input_ids'] == bos_token_id).nonzero(as_tuple=True)
                        # Set the attention mask to 0 for BOS token positions, first dimension is batch size
                        attention_mask[bos_positions] = 0

                    embeddings = outputs * attention_mask.unsqueeze(-1)
                    # Do weighted mean pooling based on position of token
                    # basically weight = postion / sum(position)
                    # Create position weights: 1-based indexing
                    seq_length = attention_mask.size(1)
                    position_ids = torch.arange(
                       1, seq_length + 1, device=attention_mask.device).unsqueeze(0)  # (1, T)
                    position_weights = position_ids * attention_mask  # (B, T)
                    norm_factor = position_weights.sum(
                        dim=1, keepdim=True).clamp(min=1e-5)  # avoid div-by-zero

                    # Expand mask to match hidden dim
                    # (B, T, 1)
                    position_weights = position_weights.unsqueeze(-1)

                    # Apply weighted pooling
                    weighted_outputs = outputs * position_weights  # (B, T, D)
                    embeddings = weighted_outputs.sum(
                        dim=1) / norm_factor  # (B, D)


                else:
                    # get last token embedding
                    embeddings = outputs[:, -1, :]
                    # Apply ridge regression projection
                    embeddings = self.W_aug(embeddings)

                if self.normalize_embeddings:
                    embeddings = torch.nn.functional.normalize(
                        embeddings, p=2, dim=1)
                all_embeddings.append(embeddings.cpu().float().numpy())
        return np.vstack(all_embeddings)

In [7]:
model_name = 'meta-llama/Llama-3.2-1B'
model = AutoModelForCausalLM.from_pretrained(model_name,  device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
eval_model = SentenceTransformer("all-mpnet-base-v2")
model.eval()


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb):

In [42]:

import torch
import torch.nn.functional as F


def template(
    x): return f"What does each word in this sentence mean: {x}"


def compute_entropy(logits):
    probs = F.softmax(logits, dim=-1)
    return -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)  # (B, T)


def compute_position_scores(mask):
    B, T = mask.shape
    position_ids = torch.arange(1, T + 1, device=mask.device).float()  # (T,)
    position_scores = position_ids.unsqueeze(0).expand(B, -1) * mask  # (B, T)
    norm_position = position_scores / \
        (position_scores.sum(dim=1, keepdim=True) + 1e-8)
    return norm_position


def compute_normalized_entropy(entropy, mask):
    inverted_entropy = 1.0 / (entropy + 1e-8)  # prevent division by zero
    masked_entropy = inverted_entropy * mask
    # Apply softmax with temperature 0.7
    norm_entropy = F.softmax(masked_entropy, dim=1)  # (B, T)
    return norm_entropy


def compute_combined_weights(entropy, mask, alpha):
    # Create low entropy mask
    norm_entropy = compute_normalized_entropy(entropy, mask)
    low_entropy_mask = (norm_entropy > 0.1).float() * mask  # (B, T)
    # norm_entropy = compute_normalized_entropy(entropy, mask)
    norm_position = compute_position_scores(low_entropy_mask)
    return norm_position


def compute_pooled_embedding(hidden_states, weights):
    return torch.sum(hidden_states * weights.unsqueeze(-1), dim=1)  # (B, D)


def process_sentences(model, tokenizer, sentences, alpha):
    with torch.no_grad():
        inputs = tokenizer(sentences, return_tensors='pt',
                           padding=True).to(model.device)
        output = model(**inputs, output_hidden_states=True, return_dict=True)

        logits = output.logits              # (B, T, V)
        entropy = compute_entropy(logits)   # (B, T)

        hidden = output.hidden_states[-3]   # (B, T, D3
        mask = inputs['attention_mask']     # (B, T)

        hidden = hidden[:, 1:, :]           # remove BOS → (B, T-1, D)
        mask = mask[:, 1:]                  # (B, T-1)
        entropy = entropy[:, 1:]           # (B, T-1)

        weights = compute_combined_weights(entropy, mask, alpha)  # (B, T-1)
        pooled = compute_pooled_embedding(hidden, weights)        # (B, D)

        return pooled, inputs, weights


def print_token_weights(inputs, weights, tokenizer, alpha):
    input_ids = inputs['input_ids']
    for i, (ids, w_row, mask_row) in enumerate(zip(input_ids, weights, inputs['attention_mask'][:, 1:])):
        tokens = tokenizer.convert_ids_to_tokens(ids[1:])  # skip BOS
        print(f"Sentence {i+1} tokens + weights (α = {alpha}):")
        for token, weight, valid in zip(tokens, w_row, mask_row):
            if valid.item() == 1:
                print(f"  {token:>12} : weight = {weight.item():.4f}")
        print()


def oracle_embedding(sentences):
    """
    This function computes the oracle embeddings for a list of sentences.
    It uses the SentenceTransformer model to encode the sentences.
    """
    return eval_model.encode(sentences, convert_to_tensor=True, normalize_embeddings=True)


alphas = [0.0, 1.0, 0.7]

index = 0
print(test_dataset[index]['score'])
# Get samples from evaluation
sentence1 = test_dataset[index]['sentence1']
sentence2 = test_dataset[index]['sentence2']


n_samples = 200
scores = [test_dataset[i]['score'] for i in range(n_samples)]
sims2 = []
oracle = []
for i in range(n_samples):
    sentence1 = test_dataset[i]['sentence1']
    sentence2 = test_dataset[i]['sentence2']
    
    pooled_2 = process_sentences(
        model, tokenizer, [sentence1, sentence2], alpha=1.0)[0]

    sim2 = F.cosine_similarity(pooled_2[0].unsqueeze(
        0), pooled_2[1].unsqueeze(0)).item()
    oracle_emb = oracle_embedding([sentence1, sentence2])
    oracle_sim = F.cosine_similarity(
        oracle_emb[0].unsqueeze(0), oracle_emb[1].unsqueeze(0)).item()
    oracle.append(oracle_sim)
    sims2.append(sim2)
# Print correlation with scores
correlation2 = np.corrcoef(scores, sims2)[0, 1]
correlation_oracle = np.corrcoef(scores, oracle)[0, 1]
print(f"Correlation with oracle: {correlation_oracle:.4f}")
print(f"Correlation with scores (alpha=0.0): {correlation2:.4f}")


# n_samples = 200
# print()
# for alpha in alphas:
#     sims = []
#     for i in range(n_samples):
#         sentence1 = test_dataset[i]['sentence1']
#         sentence2 = test_dataset[i]['sentence2']
#         pooled, inputs, weights = process_sentences(model, tokenizer, [sentence1, sentence2], alpha)
#         sim = F.cosine_similarity(pooled[0].unsqueeze(0), pooled[1].unsqueeze(0)).item()
#         sims.append(sim)

#     # Print correlation with scores
#     correlation = np.corrcoef(scores, sims)[0, 1]
#     print(f"Alpha: {alpha}, Correlation with scores: {correlation:.4f}")

2.5
Correlation with oracle: 0.9382
Correlation with scores (alpha=0.0): 0.0801


In [40]:
s1 =test_dataset[0]['sentence1']
s2 = test_dataset[0]['sentence2']
pooled, inputs, weights = process_sentences(model, tokenizer, [s1, s2], alpha=1.0)
print_token_weights(inputs, weights, tokenizer, alpha=1.0)



Sentence 1 tokens + weights (α = 1.0):
             A : weight = 0.0357
         Ġgirl : weight = 0.0714
           Ġis : weight = 0.1071
      Ġstyling : weight = 0.1429
          Ġher : weight = 0.1786
         Ġhair : weight = 0.2143
             . : weight = 0.2500

Sentence 2 tokens + weights (α = 1.0):
             A : weight = 0.0357
         Ġgirl : weight = 0.0714
           Ġis : weight = 0.1071
     Ġbrushing : weight = 0.1429
          Ġher : weight = 0.1786
         Ġhair : weight = 0.2143
             . : weight = 0.2500



In [None]:

eval_model = LLaMAEmbeddingModel(
    model, tokenizer, normalize_embeddings=False, mean_pooling=False)
print(eval(eval_model))

In [None]:

eval_model = LLaMAEmbeddingModel(
    model, tokenizer, normalize_embeddings=False, ignore_bos_token=True, mean_pooling=True, batch_size=16)
print(eval(eval_model))



STSBenchmark
0.5314
[0.5314290129100447]


In [73]:
n_layers = len(model.model.layers)

eval_model = LLaMAEmbeddingModel(
    model, tokenizer, normalize_embeddings=False, mean_pooling=False, ignore_bos_token=False, layer_idx=-4)
print(eval(eval_model))



STSBenchmark
0.0893
[0.08927839385848074]


In [None]:
eval_model = SentenceTransformer("all-MiniLM-L6-v2")
print(eval(eval_model))

In [None]:
eval_model = SentenceTransformer("all-mpnet-base-v2")
print(eval(eval_model))

In [None]:
model = SentenceTransformer("WhereIsAI/UAE-Large-V1")
print(eval(model))