In [None]:
import torch
from torch.utils.data import DataLoader
from collections import defaultdict
import numpy as np
import pickle
import random
from transformers import BertForMaskedLM
from torch.nn.utils.rnn import pad_sequence
from datasets import load_from_disk

# Model paths and token dictionary
model_path = "/home/logs/jtorresb/yeastformer/yeast/pretraining/models/250304_125959_yeastformer_L3_emb384_SL512_E20_B8_LR0.00115_LSlinear_WU144_Oadamw_torch/models"
token_dict_path = "/home/logs/jtorresb/Geneformer/yeast/yeast_data/output/yeast_token_dict.pkl"

# Seed every RNG this notebook touches (python `random`, numpy's global RNG, which the
# permutation test below shuffles with, and torch) so Restart & Run All reproduces the p-values.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load gene token dictionary.
# NOTE(review): pickle.load can execute arbitrary code; only load these pickles from a trusted source.
with open(token_dict_path, "rb") as fp:
    token_dictionary = pickle.load(fp)

# Load model and put it in eval mode (disables dropout).
model = BertForMaskedLM.from_pretrained(model_path)
model.eval()

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(device)

# Invert the token dictionary. Assumes it maps gene identifier -> token id (TODO confirm) and
# that token ids are unique; duplicate ids would silently keep only the last gene.
id_to_gene = {v: k for k, v in token_dictionary.items()}

dataset = load_from_disk("/home/logs/jtorresb/Geneformer/yeast/yeast_data/output/yeast_master_matrix_sgd.dataset")

# Load the housekeeping-gene and transcription-factor collections from hk_genes.pkl / tf_genes.pkl.
# NOTE(review): presumably collections of gene identifiers; confirm whether they hold gene names
# or token ids, since gene_stats (built below) is keyed by token id.
with open("/home/logs/jtorresb/yeastformer/yeast/yeast_data/genes_info/hk_genes.pkl", "rb") as f:
    housekeeping_genes = pickle.load(f)

with open("/home/logs/jtorresb/yeastformer/yeast/yeast_data/genes_info/tf_genes.pkl", "rb") as f:
    transcription_factors = pickle.load(f)


In [None]:
# Define a collate function to pad the input_ids in a batch
def collate_fn(batch, pad_token_id=0):
    """Collate dataset rows into a right-padded tensor of token ids.

    Args:
        batch: list of dataset rows, each with an ``"input_ids"`` sequence of ints.
        pad_token_id: value written into positions past each row's length. Defaults to 0,
            which this notebook assumes is the pad token (TODO confirm against token_dictionary).
            Padded positions should also be excluded via an attention mask by the caller.

    Returns:
        ``(padded, lengths, batch)``: ``padded`` is a LongTensor of shape
        ``(len(batch), max_len)``, ``lengths`` the unpadded length of each row, and ``batch``
        the original rows, returned unchanged.
    """
    input_ids = [torch.tensor(sample["input_ids"], dtype=torch.long) for sample in batch]
    lengths = [len(ids) for ids in input_ids]
    padded = pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
    return padded, lengths, batch

# Create DataLoader
batch_size = 8
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

# Per-token accumulators keyed by token id:
#   positions    - 0-based index of every occurrence within its sentence
#   sentence_ids - global indices of the sentences containing the token (once per sentence)
#   embeddings   - penultimate-layer embedding (numpy array) of every occurrence
# NOTE(review): one embedding is stored per token occurrence, so memory grows with the total
# number of tokens in the dataset.
gene_stats = defaultdict(lambda: {"positions": [], "sentence_ids": set(), "embeddings": []})

global_sentence_idx = 0  # Global sentence counter; follows dataset order because shuffle=False

model.eval()
with torch.no_grad():
    for batch_input_ids, lengths, original_batch in dataloader:
        batch_input_ids = batch_input_ids.to(device)

        # collate_fn right-pads each row to the batch's longest row. Without an attention_mask
        # BertModel attends to every position, so real tokens would attend to padding and their
        # embeddings would depend on which rows happened to share the batch.
        position_index = torch.arange(batch_input_ids.shape[1], device=device)
        row_lengths = torch.tensor(lengths, device=device)
        attention_mask = (position_index.unsqueeze(0) < row_lengths.unsqueeze(1)).long()

        outputs = model.bert(batch_input_ids, attention_mask=attention_mask, output_hidden_states=True)
        # hidden_states[-2]: output of the second-to-last layer, indexed as (batch, seq, hidden).
        penultimate_hidden_state = outputs.hidden_states[-2]

        # One device->host copy per batch instead of one sync per token.
        batch_token_ids = batch_input_ids.cpu().tolist()
        batch_embeddings = penultimate_hidden_state.cpu().numpy()

        for i, seq_length in enumerate(lengths):
            sentence_token_ids = batch_token_ids[i][:seq_length]
            for j, token_id in enumerate(sentence_token_ids):
                token_stats = gene_stats[token_id]
                token_stats["positions"].append(j)
                # .copy() so the stored vector does not keep the whole padded batch array alive.
                token_stats["embeddings"].append(batch_embeddings[i, j].copy())

            # Sentence-level count: each distinct token is credited once per sentence.
            for token_id in set(sentence_token_ids):
                gene_stats[token_id]["sentence_ids"].add(global_sentence_idx)
            global_sentence_idx += 1

In [None]:
def _mean_pairwise_cosine(embeddings):
    """Mean cosine similarity over all pairs i != j of the given vectors (0.0 if fewer than two)."""
    vectors = np.asarray(embeddings, dtype=np.float64)
    n = vectors.shape[0]
    if n < 2:
        return 0.0
    vectors = vectors.reshape(n, -1)
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    # Cosine is undefined for a zero vector; treat it as similarity 0 instead of propagating NaN.
    unit = np.divide(vectors, norms, out=np.zeros_like(vectors), where=norms > 0)
    # sum_{i != j} u_i . u_j == ||sum_i u_i||^2 - sum_i ||u_i||^2. This is O(n*d) instead of
    # materialising the n x n similarity matrix, which is infeasible for frequent genes.
    total = unit.sum(axis=0)
    off_diagonal_sum = float(total @ total) - float(np.einsum("ij,ij->", unit, unit))
    # Average over the n*(n-1) off-diagonal pairs only (self-similarity excluded).
    return off_diagonal_sum / (n * (n - 1))


# Function to compute average index, number of sentences, and embedding similarity
def compute_gene_metrics(gene_stats, gene_list):
    """Summarise per-gene statistics for the genes in ``gene_list``.

    Args:
        gene_stats: mapping from key -> dict with ``"positions"`` (list of ints),
            ``"sentence_ids"`` (set) and ``"embeddings"`` (list of 1-D vectors).
        gene_list: keys to look up in ``gene_stats``.

    Returns:
        ``(avg_indices, num_sentences, embedding_similarities)``: for each gene found, the mean
        position, the number of distinct sentences, and the mean pairwise cosine similarity of
        its embeddings (0 when it has fewer than two). Genes absent from ``gene_stats`` are
        skipped, so the lists may be shorter than ``gene_list`` and are not index-aligned with it.
    """
    avg_indices = []
    num_sentences = []
    embedding_similarities = []

    for gene_id in gene_list:
        stats = gene_stats.get(gene_id, None)  # .get avoids creating entries in a defaultdict
        if not stats:
            continue
        positions = stats["positions"]
        avg_indices.append(np.mean(positions) if positions else 0)
        num_sentences.append(len(stats["sentence_ids"]))
        embedding_similarities.append(_mean_pairwise_cosine(stats["embeddings"]))

    return avg_indices, num_sentences, embedding_similarities

# Compute metrics for housekeeping and transcription factors.
# gene_stats is keyed by token id, while the gene lists may hold gene identifiers (keys of
# token_dictionary; TODO confirm the pickles' contents). Map identifiers to token ids. Anything
# that is not a key, such as a value that is already a token id, passes through unchanged.
housekeeping_token_ids = [token_dictionary.get(gene, gene) for gene in housekeeping_genes]
transcription_factor_token_ids = [token_dictionary.get(gene, gene) for gene in transcription_factors]

# compute_gene_metrics silently skips genes absent from gene_stats, so report coverage:
# a group with no matches would otherwise produce empty metric lists.
for group_name, group_token_ids in (
    ("Housekeeping", housekeeping_token_ids),
    ("Transcription factor", transcription_factor_token_ids),
):
    n_found = sum(token_id in gene_stats for token_id in group_token_ids)
    print(f"{group_name} genes found in dataset: {n_found}/{len(group_token_ids)}")

housekeeping_avg_index, housekeeping_num_sentences, housekeeping_embedding_similarity = compute_gene_metrics(gene_stats, housekeeping_token_ids)
transcription_factors_avg_index, transcription_factors_num_sentences, transcription_factors_embedding_similarity = compute_gene_metrics(gene_stats, transcription_factor_token_ids)

In [None]:

# Function to compute the p-value based on null distribution
def compute_p_value(observed_diff, null_diffs):
    """Two-sided Monte-Carlo permutation p-value.

    Counts null differences at least as extreme (in absolute value) as the observed one and
    applies the standard +1 correction, ``(b + 1) / (m + 1)``. A randomly sampled permutation
    p-value is therefore never exactly 0.

    Args:
        observed_diff: observed test statistic (e.g. difference in group means).
        null_diffs: statistics from resampled (null) data.

    Returns:
        The p-value as a float in (0, 1], or NaN when ``observed_diff`` is NaN. Without that
        check every comparison would be False and the result would wrongly be ~0.
    """
    null = np.asarray(null_diffs, dtype=float)
    if np.isnan(observed_diff):
        return float("nan")
    n_extreme = int(np.count_nonzero(np.abs(null) >= abs(observed_diff)))
    return (n_extreme + 1) / (null.size + 1)

# Resampling to generate null distribution
def generate_null_distribution(data1, data2, num_resamples=10000, rng=None):
    """Null distribution of ``mean(data1) - mean(data2)`` under random group relabelling.

    Args:
        data1, data2: 1-D sequences of numbers (the two groups). Neither is modified.
        num_resamples: number of random permutations to draw.
        rng: object with a ``shuffle`` method (``np.random.Generator`` or ``RandomState``).
            Defaults to numpy's global RNG, so ``np.random.seed`` controls reproducibility.

    Returns:
        List of ``num_resamples`` differences in means.

    Raises:
        ValueError: if either group is empty. An empty group's mean is NaN, which would make
            every downstream comparison False and report a spurious p-value of 0.
    """
    if len(data1) == 0 or len(data2) == 0:
        raise ValueError(
            f"both groups must be non-empty (got sizes {len(data1)} and {len(data2)})"
        )
    shuffle = (np.random if rng is None else rng).shuffle
    combined = np.concatenate([data1, data2])  # fresh array; shuffled in place below
    n1 = len(data1)
    null_diffs = []
    for _ in range(num_resamples):
        shuffle(combined)
        null_diffs.append(np.mean(combined[:n1]) - np.mean(combined[n1:]))
    return null_diffs

def _run_permutation_test(group_a, group_b):
    """Difference-in-means permutation test between two groups.

    Returns ``(null_diffs, observed_diff, p_value)`` using generate_null_distribution and
    compute_p_value.
    """
    null = generate_null_distribution(group_a, group_b)
    observed = np.mean(group_a) - np.mean(group_b)
    return null, observed, compute_p_value(observed, null)


# Average position of the gene within a sentence
null_avg_index_diffs, observed_avg_index_diff, avg_index_p_value = _run_permutation_test(
    housekeeping_avg_index, transcription_factors_avg_index
)

# Number of sentences the gene appears in
null_num_sentences_diffs, observed_num_sentences_diff, num_sentences_p_value = _run_permutation_test(
    housekeeping_num_sentences, transcription_factors_num_sentences
)

# Mean pairwise cosine similarity of the gene's embeddings
null_embedding_similarity_diffs, observed_embedding_similarity_diff, embedding_similarity_p_value = _run_permutation_test(
    housekeeping_embedding_similarity, transcription_factors_embedding_similarity
)

# Report the results
print(f"Average Index P-value: {avg_index_p_value}")
print(f"Number of Sentences P-value: {num_sentences_p_value}")
print(f"Embedding Similarity P-value: {embedding_similarity_p_value}")