In [None]:
import torch
import pickle
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import BertTokenizer, BertForMaskedLM
from datasets import load_from_disk
import random
import os
from tqdm import tqdm

import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from collections import defaultdict


In [None]:
seed_num = 0
random.seed(seed_num)
np.random.seed(seed_num)
seed_val = 42
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

In [None]:
model_path = "/home/logs/jtorresb/Geneformer/yeast/pretraining/models/250225_192022_yeastformer_L4_emb256_SL512_E20_B8_LR0.0016_LScosine_WU50_Oadamw_torch/models"
token_dict_path = "/home/logs/jtorresb/Geneformer/yeast/yeast_data/output/yeast_token_dict.pkl"

# Load gene token dictionary
with open(token_dict_path, "rb") as fp:
    token_dictionary = pickle.load(fp)

# Load model
model = BertForMaskedLM.from_pretrained(model_path)
model.eval()  # Set model to evaluation mode

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(device)

# Invert the token dictionary (if it maps gene_name -> token_id)
id_to_gene = {v: k for k, v in token_dictionary.items()}

In [None]:
# -------------------------------
# Load dataset and split into train/validation
# ------------------------------
dataset = load_from_disk("/home/logs/jtorresb/Geneformer/yeast/yeast_data/output/yeast_master_matrix_sgd.dataset")
#dataset_split = dataset.train_test_split(test_size=0.05, seed=seed_val)
#train_dataset = dataset_split['train'].select(range(len(dataset_split['train'])))

In [5]:
# Define a collate function to pad the input_ids in a batch
def collate_fn(batch):
    # Each item in batch is assumed to be a dict with key "input_ids"
    input_ids = [torch.tensor(sample['input_ids'], dtype=torch.long) for sample in batch]
    lengths = [len(ids) for ids in input_ids]
    padded = pad_sequence(input_ids, batch_first=True, padding_value=0)  # Assumes 0 is the pad token id
    return padded, lengths, batch

# Create a DataLoader – adjust batch_size based on your GPU memory
batch_size = 8
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

# Dictionary to accumulate statistics for each gene token
# For each token id, we store:
#   - "positions": a list of positions (ranks) where it appears in its sentences
#   - "sentence_ids": a set of sentence indices where it appears (to count unique sentences)
#   - "embeddings": a list of contextual embedding vectors (one per occurrence)
gene_stats = defaultdict(lambda: {"positions": [], "sentence_ids": set(), "embeddings": []})

global_sentence_idx = 0  # A global counter for sentence IDs

model.eval()
with torch.no_grad():
    for batch_input_ids, lengths, original_batch in dataloader:
        # Move the batch to the device (GPU/CPU)
        batch_input_ids = batch_input_ids.to(device)
        # Forward pass through the encoder with hidden states enabled
        outputs = model.bert(batch_input_ids, output_hidden_states=True)
        # Use the penultimate hidden state (second-to-last layer)
        penultimate_hidden_state = outputs.hidden_states[-2]

        # Loop over each sentence in the batch
        for i, seq_length in enumerate(lengths):
            # Get the tokens and embeddings for this sentence (exclude padded positions)
            sentence_input_ids = batch_input_ids[i, :seq_length]
            sentence_embeddings = penultimate_hidden_state[i, :seq_length]  # shape: (seq_length, hidden_dim)
            
            # Track tokens seen in this sentence (for sentence-level count)
            seen_tokens = set()
            for j in range(seq_length):
                token_id = sentence_input_ids[j].item()
                # Record the token's rank (position in the sentence)
                gene_stats[token_id]["positions"].append(j)
                # Save the corresponding embedding (moved to CPU and converted to numpy)
                gene_stats[token_id]["embeddings"].append(sentence_embeddings[j].cpu().numpy())
                seen_tokens.add(token_id)
            
            # Mark that these tokens appeared in the current sentence (unique count)
            for token in seen_tokens:
                gene_stats[token]["sentence_ids"].add(global_sentence_idx)
            global_sentence_idx += 1

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


In [6]:
results = {}  # Initialize a dictionary to store results for each gene token

# Iterate over each token (gene) and its statistics in gene_stats
for token_id, stats in gene_stats.items():
    # Extract the list of positions (ranks) where this token appears in sentences
    positions = stats["positions"]
    
    # Compute the average rank (position) of this token if it appears in any sentence
    avg_rank = np.mean(positions) if positions else None
    
    # Count the number of unique sentences in which this token appears
    sentence_count = len(stats["sentence_ids"])
    
    # Check if we have stored embeddings for this token
    if stats["embeddings"]:
        # Stack the list of embedding vectors into a 2D NumPy array
        # Shape will be: (number of occurrences, hidden_dim)
        embeddings = np.stack(stats["embeddings"])
    else:
        embeddings = None  # No embeddings were collected for this token
    
    # If embeddings exist and there is more than one occurrence
    if embeddings is not None and embeddings.shape[0] > 1:
        # Normalize each embedding vector to unit length along the feature dimension
        norm_embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
        
        # Compute the cosine similarity matrix by taking the dot product of normalized embeddings
        # Each element (i,j) in this matrix represents the cosine similarity between occurrence i and j
        cos_sim_matrix = np.dot(norm_embeddings, norm_embeddings.T)
        
        # Extract the indices of the upper triangular part of the similarity matrix (excluding the diagonal)
        # This ensures that each pair is considered only once and self-similarities are ignored
        triu_indices = np.triu_indices_from(cos_sim_matrix, k=1)
        
        # Calculate the average similarity score from the upper triangle values
        similarity_score = np.mean(cos_sim_matrix[triu_indices])
    else:
        # If there is only one occurrence or no embeddings, we cannot compute pairwise similarity
        similarity_score = None

    # Store the computed average rank, sentence count, and similarity score in the results dictionary
    results[token_id] = {
        "avg_rank": avg_rank,
        "sentence_count": sentence_count,
        "similarity_score": similarity_score
    }

In [7]:
# Total number of sentences in the training dataset
total_sentences = global_sentence_idx  # This counter was updated during processing

report_filename = "embeddings_report.txt"
with open(report_filename, "w") as f:
    header = "Gene\tAvg_Rank\tAppearance_Percentage\tEmb_Similarity_Score\n"
    f.write(header)
    for token_id, stats in results.items():
        gene_name = id_to_gene.get(token_id, f"Token_{token_id}")
        avg_rank_str = f"{stats['avg_rank']:.2f}" if stats['avg_rank'] is not None else "N/A"
        # Calculate the percentage of sentences where the gene appears
        percentage = (len(gene_stats[token_id]["sentence_ids"]) / total_sentences) * 100
        percentage_str = f"{percentage:.2f}%"
        sim_score_str = f"{stats['similarity_score']:.4f}" if stats['similarity_score'] is not None else "N/A"
        line = f"{gene_name}\t{avg_rank_str}\t{percentage_str}\t{sim_score_str}\n"
        f.write(line)

print(f"Report written to {report_filename}")

Report written to embeddings_report.txt
