- testing how well text embedding models can encode the features of input
- goal is to have some sort of custom benchmark for this

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset
import matplotlib.pyplot as plt
%matplotlib inline

torch.manual_seed(42)

<torch._C.Generator at 0x10990d6d0>

In [2]:
def embed(texts):
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
        embeddings = outputs.last_hidden_state[:, 0]  # CLS token pooling
        embeddings = torch.nn.functional.normalize(embeddings, dim=1)
    return embeddings

def cosine_sim(embeddings):
    n = embeddings.size()[0]
    sim = lambda a, b: torch.dot(a, b) / torch.linalg.norm(a) * torch.linalg.norm(b)
    sims = torch.empty(n, n)
    for i, e in enumerate(embeddings):
        for ii, ee, in enumerate(embeddings):
            sims[i, ii] = sim(e, ee)
    return sims

In [3]:
model_path = "ibm-granite/granite-embedding-30m-english"
model = AutoModel.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model.eval()

RobertaModel(
  (embeddings): RobertaEmbeddings(
    (word_embeddings): Embedding(50265, 384, padding_idx=1)
    (position_embeddings): Embedding(514, 384, padding_idx=1)
    (token_type_embeddings): Embedding(2, 384)
    (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): RobertaEncoder(
    (layer): ModuleList(
      (0-5): 6 x RobertaLayer(
        (attention): RobertaAttention(
          (self): RobertaSdpaSelfAttention(
            (query): Linear(in_features=384, out_features=384, bias=True)
            (key): Linear(in_features=384, out_features=384, bias=True)
            (value): Linear(in_features=384, out_features=384, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): RobertaSelfOutput(
            (dense): Linear(in_features=384, out_features=384, bias=True)
            (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
            (drop

In [4]:
# here labels are world(0), sports(1), business(2), sci/tech(3)
dataset = load_dataset("ag_news")
len(dataset["train"]['text']), dataset["train"]['text'][0:2]

(120000,
 ["Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.",
  'Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.'])

In [5]:
# showing sim embeddings
texts = dataset["train"]['text'][0:5]
embeddings = embed(texts)

sim_matrix = cosine_sim(embeddings)
sim_matrix

tensor([[1.0000, 0.7019, 0.7057, 0.6213, 0.6776],
        [0.7019, 1.0000, 0.6954, 0.6464, 0.6647],
        [0.7057, 0.6954, 1.0000, 0.7319, 0.7669],
        [0.6213, 0.6464, 0.7319, 1.0000, 0.7396],
        [0.6776, 0.6647, 0.7669, 0.7396, 1.0000]])

In [20]:
train_texts = dataset["train"]["text"]
vocab = sorted(set(word.lower() for text in train_texts for word in text.split()))
word2idx = {w: i for i, w in enumerate(vocab)}

# Convert text -> token IDs
token_ids = [torch.tensor([word2idx[w.lower()] for w in text.split()]) for text in train_texts]

# Pad sequences to same length
from torch.nn.utils.rnn import pad_sequence
token_ids_padded = pad_sequence(token_ids, batch_first=True)
token_ids_padded[0]

tensor([153190, 135656,  34323,  44974,  32566,  82318, 142913,  36256,   9704,
        123425,  10731, 130870, 153190, 137429,  57832, 104235, 148767,  30077,
        128725,  72057,  25840,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0, 

In [6]:
def embed_split(split):
    texts = list(dataset[split]["text"][0:4000])
    labels = list(dataset[split]["label"][0:4000])
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
        embeddings = outputs.last_hidden_state[:, 0] # CLS token pooling
        embeddings = torch.nn.functional.normalize(embeddings, dim=1)
    return torch.tensor(embeddings, dtype=torch.float32), torch.tensor(labels, dtype=torch.long)

X_train, y_train = embed_split("train")
X_test, y_test = embed_split("test")

  return torch.tensor(embeddings, dtype=torch.float32), torch.tensor(labels, dtype=torch.long)


In [10]:
train_ds = TensorDataset(X_train, y_train)
test_ds = TensorDataset(X_test, y_test)

train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=256)

num_classes = len(set(dataset["train"]["label"]))
input_dim = X_train.shape[1]
num_classes, input_dim

(4, 384)

- using linear probing (traing a linear layer using logistic regression on the frozen embeddings) to check how well the embedding models encoded the text
- the goals is for the linear probing layer to learn the embeddings and their features
- random guessing for a dataset like ag news with 4 classes would be 25% so something like 87% is pretty good
- linear probing score == (kinda) compression efficiency score
- more so though, linear probing also checks the linear seperability between features in the embeddings. to check if the features are even there at all, it would make more sense to use non-linear probing with something like a simple 2 layer mlp

In [11]:
class LogisticRegression(nn.Module):
    def __init__(self, input_dim: int, num_classes: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.linear(x)

def train(probe, loss_fn, optimizer, train_loader, epochs):
    probe.train()
    for epoch in range(epochs):
        for xb, yb in train_loader:
            optimizer.zero_grad()
            logits = probe(xb)
            loss = loss_fn(logits, yb)
            loss.backward()
            optimizer.step()
        if epoch % (epochs // 10) == 0:
            print(f"epoch {epoch+1}/{epochs}, loss: {loss.item():.4f}")

def eval(probe, test_loader):
    probe.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for xb, yb in test_loader:
            logits = probe(xb)
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(yb.cpu().numpy())
    return all_preds, all_labels

def accuracy(all_preds, all_labels):
    total_acc = 0
    for p, l in zip(all_preds, all_labels):
        if p == l: total_acc += 1
    return total_acc / len(all_preds)

In [18]:
probe = LogisticRegression(input_dim, num_classes)
loss_fn = nn.CrossEntropyLoss(reduction="mean")
optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)
epochs = 400

total_loss = train(probe, loss_fn, optimizer, train_loader, epochs)
all_preds, all_labels = eval(probe, test_loader)
accuracy(all_preds, all_labels)

epoch 1/400, loss: 1.3183
epoch 41/400, loss: 0.3868
epoch 81/400, loss: 0.4437
epoch 121/400, loss: 0.3236
epoch 161/400, loss: 0.2183
epoch 201/400, loss: 0.2687
epoch 241/400, loss: 0.1063
epoch 281/400, loss: 0.3342
epoch 321/400, loss: 0.3588
epoch 361/400, loss: 0.3281


0.864