In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np
import os
from sentence_transformers import SentenceTransformer
from collections import namedtuple
from torch.utils.data import random_split
import matplotlib.pyplot as plt

os.environ["CUDA_VISIBLE_DEVICES"] = "2"
from mteb import MTEB


In [2]:
sts_tasks = [
    "STSBenchmark",   # Standard semantic similarity benchmark
    "SICK-R",         # Semantic relatedness from the SICK dataset
    "STS12",
    "STS13",
    "STS14",
    "STS15",
    "STS16",
]

evaluation = MTEB(tasks=sts_tasks)


def eval(eval_model):
    results = evaluation.run(eval_model, output_folder=None)
    score = [result.scores['test'][0]['main_score']
                     for result in results]
    tasks = [result.task_name for result in results]
    # Print list of tasks as comma-separated values
    print(", ".join(tasks))
    # Print list of scores as comma-separated values    
    print(", ".join([f"{score:.4f}" for score in score]))
    return score



In [3]:
model_name = "meta-llama/Llama-2-7b-hf"
dataset_name = "Skylion007/openwebtext"
model_suffix = model_name.split("/")[-1]
dataset_suffix = dataset_name.split("/")[-1]

dir = f"data/{dataset_suffix}_{model_suffix}"


hidden_states = torch.load(f"../{dir}/hidden_states.pt")
#embeddings = torch.load(f"../{dir}/embeddings_all-MiniLM-L6-v2.pt")
embeddings = torch.load(f"../{dir}/embeddings_all-mpnet-base-v2.pt")
#embeddings = torch.load(f"../{dir}/embeddings_UAE-Large-V1.pt")
#embeddings = torch.load(f"../{dir}/embeddings_gte-Qwen2-1.5B-instruct.pt")
# embeddings = torch.load(f"../{dir}/hidden_states_cumulative_mean.pt")
# token_ids = torch.load(f"{dir}/token_ids_merged.pt")

# Cast hidden states to float32
hidden_states = hidden_states.float()
embeddings = embeddings.float()

embeddings = embeddings[:hidden_states.shape[0]]

In [4]:

def ridge_regression(X, Y, alpha):
    d_in = X.shape[1]
    I = torch.eye(d_in, device=X.device)
    W = torch.linalg.solve(X.T @ X + alpha * I, X.T @ Y)
    return W

def evaluate(X, Y, W):
    preds = X @ W
    mse = torch.mean((preds - Y) ** 2).item()
    return mse

def find_best_alpha(X, Y, alphas, val_ratio=0.2, seed=42):
    # Split data into train and validation
    n = X.shape[0]
    val_size = int(n * val_ratio)
    train_size = n - val_size
    generator = torch.Generator().manual_seed(seed)
    X_train, X_val = random_split(X, [train_size, val_size], generator=generator)
    Y_train, Y_val = random_split(Y, [train_size, val_size], generator=generator)

    X_train = X_train.dataset[X_train.indices]
    Y_train = Y_train.dataset[Y_train.indices]
    X_val = X_val.dataset[X_val.indices]
    Y_val = Y_val.dataset[Y_val.indices]

    best_alpha = None
    best_val_loss = float('inf')
    losses = []

    for alpha in alphas:
        W = ridge_regression(X_train, Y_train, alpha)
        val_loss = evaluate(X_val, Y_val, W)
        losses.append(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_alpha = alpha

    # Plot results
    plt.plot(alphas, losses, marker='o')
    plt.xscale('log')
    plt.xlabel('Alpha')
    plt.ylabel('Validation MSE')
    plt.title('Ridge Regression Validation Loss vs Alpha')
    plt.show()

    print(f"Best alpha: {best_alpha:.2e} (val loss: {best_val_loss:.4f})")
    return best_alpha

In [5]:
# find_best_alpha(hidden_states, embeddings, alphas=[1e-3, 1e-2, 1e-1, 1, 10, 100], val_ratio=0.2)

In [4]:


class LLaMAEmbeddingModel:
    def __init__(self, model, tokenizer, normalize_embeddings=True, mean_pooling=True, layer_idx = -1, ignore_bos_token = False, batch_size=8, W_aug=torch.nn.Identity()):
        self.model = model
        self.tokenizer = tokenizer
        self.model.eval()
        # Remove the second half of decoder layers
        # self.model.model.layers = self.model.model.layers[:len(self.model.model.layers)//2]
        self.model.lm_head = torch.nn.Identity()
        self.normalize_embeddings = normalize_embeddings
        self.mean_pooling = mean_pooling
        self.batch_size = batch_size
        self.layer_idx = layer_idx
        self.ignore_bos_token = ignore_bos_token
        self.W_aug = W_aug.cuda()

    def encode(self, sentences, **kwargs):
        all_embeddings = []
        for i in range(0, len(sentences), self.batch_size):
            batch = sentences[i:i+self.batch_size]
            inputs = self.tokenizer(
                batch, return_tensors='pt', padding=True, truncation=True).to(self.model.device)
            with torch.no_grad():
                output = self.model(**inputs, output_hidden_states=True, return_dict=True)
                outputs = output.hidden_states[self.layer_idx]
                # Mean pooling
                if self.mean_pooling:
                    attention_mask = inputs['attention_mask']
                    embeddings = outputs * attention_mask.unsqueeze(-1)
                    if self.ignore_bos_token:
                        # Ignore the first token (BOS)
                        embeddings = embeddings[:, 1:, :]
                    # Ignore BOS token
                    #embeddings = embeddings[:, 1:, :]
                    embeddings = embeddings.mean(1)
                else:
                    # get last token embedding
                    embeddings = outputs[:, -1, :]
                    # Apply ridge regression projection
                    embeddings = self.W_aug(embeddings)

                if self.normalize_embeddings:
                    embeddings = torch.nn.functional.normalize(
                        embeddings, p=2, dim=1)
                all_embeddings.append(embeddings.cpu().float().numpy())
        return np.vstack(all_embeddings)

In [5]:
model_name = 'meta-llama/Llama-2-7b-hf'
model = AutoModelForCausalLM.from_pretrained(model_name,  device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [8]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_e

In [9]:

eval_model = LLaMAEmbeddingModel(
    model, tokenizer, normalize_embeddings=False, mean_pooling=False)
print(eval(eval_model))

Failed to extract metadata from model: 'LLaMAEmbeddingModel' object has no attribute 'model_card_data'. Upgrading to sentence-transformers v3.0.0 or above is recommended.


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


SICK-R, STS12, STS13, STS14, STS15, STS16, STSBenchmark
0.2429, 0.2703, 0.0920, 0.1568, 0.2605, 0.3930, 0.1788
[np.float64(0.2429343467909805), np.float64(0.2703277420492475), np.float64(0.09201553833122342), np.float64(0.15676796511235006), np.float64(0.26054698162452655), np.float64(0.392956947080111), np.float64(0.1787514500176398)]


In [16]:
eval_model = LLaMAEmbeddingModel(
    model, tokenizer, normalize_embeddings=False, mean_pooling=True)
print(eval(eval_model))

Failed to extract metadata from model: 'LLaMAEmbeddingModel' object has no attribute 'model_card_data'. Upgrading to sentence-transformers v3.0.0 or above is recommended.


SICK-R, STS12, STS13, STS14, STS15, STS16, STSBenchmark
0.4944, 0.3539, 0.5375, 0.3991, 0.5470, 0.5298, 0.4280
[np.float64(0.49435149334104195), np.float64(0.3539113999132947), np.float64(0.5375398915505257), np.float64(0.39907971860367547), np.float64(0.5469675363562213), np.float64(0.5297650976899748), np.float64(0.4279684825684324)]


In [9]:
n_layers = len(model.model.layers)

eval_model = LLaMAEmbeddingModel(
    model, tokenizer, normalize_embeddings=False, mean_pooling=True, ignore_bos_token=False, layer_idx=19)
print(eval(eval_model))

Failed to extract metadata from model: 'LLaMAEmbeddingModel' object has no attribute 'model_card_data'. Upgrading to sentence-transformers v3.0.0 or above is recommended.


SICK-R, STS12, STS13, STS14, STS15, STS16, STSBenchmark
0.4482, 0.1955, 0.3361, 0.2598, 0.4988, 0.4538, 0.1097
[np.float64(0.44819931283683284), np.float64(0.19549343469143526), np.float64(0.33614121869939007), np.float64(0.2597619528018424), np.float64(0.4987740336827443), np.float64(0.45378752743999423), np.float64(0.10973299743363822)]


In [10]:
W_aug = ridge_regression(hidden_states, embeddings, alpha=1e-3)

# Convert to bf16
W_aug = W_aug.to(torch.bfloat16)

W_aug_model = torch.nn.Linear(W_aug.shape[0], W_aug.shape[1], bias=False)
W_aug_model.weight.data = W_aug.T


eval_model = LLaMAEmbeddingModel(
    model, tokenizer, normalize_embeddings=False, mean_pooling=False, W_aug=W_aug_model)

print(eval(eval_model))

NameError: name 'ridge_regression' is not defined

In [None]:
eval_model = SentenceTransformer("all-MiniLM-L6-v2")
print(eval(eval_model))

In [None]:
eval_model = SentenceTransformer("all-mpnet-base-v2")
print(eval(eval_model))

In [None]:
model = SentenceTransformer("WhereIsAI/UAE-Large-V1")
print(eval(model))