In [1]:
# !pip install torch
# !pip install git+https://github.com/huggingface/transformers.git

In [2]:
import torch 
import math

from transformers import DataCollatorWithPadding
from transformers import AutoTokenizer, AutoModelForMaskedLM
from datasets import load_dataset

from tqdm import tqdm

from torch.utils.data import DataLoader

torch.set_float32_matmul_precision('high')

In [3]:
MODEL_NAME = "answerdotai/ModernBERT-base"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

NUM_SAMPLES = 1000
MINIBATCH_SIZE = 64

In [4]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
model = model.to(DEVICE)

In [5]:
dataset = load_dataset('wmt14', 'de-en', split='train')
dataset = dataset.select(range(NUM_SAMPLES))

In [6]:
def tokenize_data(sentence_pair, tokenizer: AutoTokenizer = None):
    return tokenizer(
        sentence_pair["translation"]["de"],        	
        truncation=True,   
        padding='max_length',
        return_tensors='pt',  	
        add_special_tokens=True, 
        max_length=512
    )

# dataset = dataset.map(lambda sentence_pair: tokenize_data(sentence_pair, tokenizer=tokenizer), remove_columns=["translation"])

In [7]:

input_data = [data_point["translation"]["de"] for data_point in dataset]
tokenized_data = tokenizer.batch_encode_plus(
    input_data,        	
    truncation=True,   
    padding='max_length',
    return_tensors='pt',  	
    add_special_tokens=True, 
    max_length=512
)

input_ids_chunked = torch.split(tokenized_data["input_ids"], MINIBATCH_SIZE)
attn_mask_chunked = torch.split(tokenized_data["attention_mask"], MINIBATCH_SIZE)

for input_ids, attn_mask in zip(input_ids_chunked, attn_mask_chunked): 
    input_ids = input_ids.to(DEVICE)
    attn_mask = attn_mask.to(DEVICE)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attn_mask, output_hidden_states=True)
    
    word_embeddings = outputs.hidden_states[-1]

    # This is the PCA code. Depending on what we want to do, there are some interesting resources available: 
    #     - PyTorch PCA: https://pytorch.org/docs/stable/generated/torch.pca_lowrank.html 
    #     - Discussion on an embedding approximation: https://stackoverflow.com/questions/75796047/how-to-evaluate-the-quality-of-pca-returned-by-torch-pca-lowrank
    #     - Batched processing: https://github.com/pytorch/pytorch/issues/99705
    U, S, V = torch.pca_lowrank(word_embeddings, q=25, center=True, niter=2)

    # TODO: We may have to assess the PCA quality at some point.
    
    print(f"Word Embeddings Shape: {word_embeddings.shape}")
    print(f"PCA Embeddings Shape: {S}")
    # At this point we can either perform an avg_pool (procedure identical to SentenceTransformers) 
    # or we can run a PCA on the word embeddings (this might be helpful to identify "interesting" sequences).
    document_embedding = torch.nanmean(word_embeddings, dim=1)
    print(f"Document Embeddings Shape: {document_embedding.shape}")
    break


Word Embeddings Shape: torch.Size([64, 512, 768])
PCA Embeddings Shape: tensor([[33581.3203,  4239.0229,   922.3452,  ...,   109.7223,    99.1819,
            97.9879],
        [51179.2930,  7132.0757,  1350.3467,  ...,   166.1138,   160.5107,
           157.8768],
        [41247.7383,  8174.0542,  1336.0531,  ...,   169.6024,   164.6267,
           156.4999],
        ...,
        [40624.5977,  8774.3887,  1185.7632,  ...,   147.7459,   145.5760,
           135.4269],
        [36260.4844,  7496.4736,  1036.7146,  ...,   120.5133,   107.6837,
            96.1345],
        [29509.9902,  5759.0776,   993.5135,  ...,   120.3719,   110.7224,
           106.2065]], device='cuda:0')
Document Embeddings Shape: torch.Size([64, 768])


In [8]:
# with torch.no_grad():
# 
#     encoding = {k: v.to(DEVICE) for k, v in encoding.items()}
#     
#     outputs = model(**encoding, output_hidden_states=True)
#     word_embeddings = outputs.hidden_states[-1]  
# 
#     document_embedding = torch.nanmean(word_embeddings,dim = 1)
#     hidden_state = torch.nan_to_num(word_embeddings,nan = 0)
#     
# print(f"Word Embeddings Shape: {document_embedding.shape}")