In [None]:
import torch
import pickle
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import BertTokenizer, BertForMaskedLM
from datasets import load_from_disk
import random
import os
from tqdm import tqdm

import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from collections import defaultdict


In [None]:
model_path = "/home/logs/jtorresb/Geneformer/yeast/pretraining/models/250225_192022_yeastformer_L4_emb256_SL512_E20_B8_LR0.0016_LScosine_WU50_Oadamw_torch/models"
token_dict_path = "/home/logs/jtorresb/Geneformer/yeast/yeast_data/output/yeast_token_dict.pkl"

# Gene token vocabulary written during tokenization.
# NOTE: pickle.load can execute arbitrary code — only point this at trusted files.
with open(token_dict_path, "rb") as token_file:
    token_dictionary = pickle.load(token_file)

# Prefer the GPU when one is visible to PyTorch, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Pretrained masked-LM checkpoint, moved to the chosen device and switched to
# eval mode (disables dropout) since we only run inference.
model = BertForMaskedLM.from_pretrained(model_path)
model.to(device)
model.eval()
print(device)

# Reverse lookup token_id -> gene name.
# Assumes token_dictionary maps gene_name -> token_id — confirm against the tokenizer.
id_to_gene = dict(zip(token_dictionary.values(), token_dictionary.keys()))

In [None]:
# -------------------------------
# Load dataset
# -------------------------------
dataset = load_from_disk("/home/logs/jtorresb/Geneformer/yeast/yeast_data/output/yeast_master_matrix_sgd.dataset")
# In this example, we use the full dataset (no train-test split)

# Padding id for variable-length cells. Assumes the vocabulary has a "<pad>"
# entry (Geneformer convention) — TODO confirm; falls back to 0 otherwise.
# Padded positions are excluded through attention_mask, so they are never
# attended to as keys; the exact id only feeds their own embedding lookup.
PAD_TOKEN_ID = token_dictionary.get("<pad>", 0)

# -------------------------------
# Define a collate function for batching
# -------------------------------
def collate_fn(samples, pad_token_id=PAD_TOKEN_ID):
    """Batch tokenized cells into right-padded tensors.

    Args:
        samples: list of dataset rows, each a dict with an ``"input_ids"`` list
            of token ids. Rows may have different lengths.
        pad_token_id: token id used to right-pad shorter rows.

    Returns:
        dict with
          - ``"input_ids"``: LongTensor [batch, max_len]
          - ``"attention_mask"``: LongTensor [batch, max_len], 1 for real tokens
            and 0 for padding.
        When all rows already have the same length, ``input_ids`` is exactly the
        stacked input and the mask is all ones.
    """
    sequences = [torch.as_tensor(s["input_ids"], dtype=torch.long) for s in samples]
    # torch.tensor() on ragged lists raises; pad to the longest row instead.
    input_ids = pad_sequence(sequences, batch_first=True, padding_value=pad_token_id)
    lengths = torch.tensor([len(seq) for seq in sequences])
    positions = torch.arange(input_ids.shape[1])
    attention_mask = (positions[None, :] < lengths[:, None]).long()
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# Create a DataLoader with an appropriate batch size
batch_size = 8  # Adjust based on your GPU memory
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)


In [None]:
# -------------------------------
# Aggregation configuration
# -------------------------------
# Define weights for ranking: rank 1 gets 5 points, then 4, 3, 2, and 1.
# The number of weights is also the number of top targets taken per row.
rank_weights = [5, 4, 3, 2, 1]

# Dictionary to store aggregated attention scores.
# Keys: source gene names (the gene paying attention)
# Values: dictionaries mapping target gene names -> aggregated weighted score.
gene_attention_scores = {}


def accumulate_top_attention(attentions, input_ids, scores, id_to_gene, weights, attention_mask=None):
    """Add rank-weighted top-k attention targets from one batch into ``scores``.

    For every sample, head and source position, the k most-attended *other*
    positions are found (k = len(weights), capped by the sequence length).
    Each target gene then receives ``weights[rank]`` points under its source gene.

    Args:
        attentions: tensor [batch, heads, seq_len, seq_len]; row i holds how much
            position i attends to every position.
        input_ids: tensor [batch, seq_len] of token ids.
        scores: dict source_gene -> {target_gene: score}; updated in place.
        id_to_gene: mapping token id -> gene name; ids missing from it are skipped.
        weights: per-rank weights, highest rank first.
        attention_mask: optional tensor [batch, seq_len], nonzero for real tokens.
            Padding positions are ignored both as sources and as targets.
    """
    b_size, num_heads, seq_len, _ = attentions.shape
    # topk cannot select more entries than there are other positions.
    k = min(len(weights), seq_len - 1)
    if k <= 0:
        return

    neg_inf = float("-inf")
    # Exclude self-attention for all samples and heads at once (diagonal -> -inf).
    eye = torch.eye(seq_len, dtype=torch.bool, device=attentions.device)
    masked = attentions.masked_fill(eye, neg_inf)
    if attention_mask is not None:
        key_is_pad = ~attention_mask.to(attentions.device).bool()
        masked = masked.masked_fill(key_is_pad[:, None, None, :], neg_inf)

    # One batched top-k and a single device->host copy, instead of a topk call
    # plus .item() GPU syncs for every (sample, head, position).
    top_indices = torch.topk(masked, k=k, dim=-1).indices.cpu().tolist()
    ids = input_ids.cpu().tolist()
    valid = attention_mask.bool().cpu().tolist() if attention_mask is not None else None

    for b in range(b_size):
        sample_ids = ids[b]
        for head in range(num_heads):
            head_top = top_indices[b][head]
            for i in range(seq_len):
                if valid is not None and not valid[b][i]:
                    continue
                source_token_id = sample_ids[i]
                if source_token_id not in id_to_gene:
                    continue
                source_scores = scores.setdefault(id_to_gene[source_token_id], {})

                for rank, target_index in enumerate(head_top[i]):
                    # Masked (-inf) positions can still be returned when fewer
                    # than k real candidates exist; never count them.
                    if target_index == i or (valid is not None and not valid[b][target_index]):
                        continue
                    target_token_id = sample_ids[target_index]
                    if target_token_id not in id_to_gene:
                        continue
                    target_gene = id_to_gene[target_token_id]
                    source_scores[target_gene] = source_scores.get(target_gene, 0) + weights[rank]


# -------------------------------
# Process each batch from the DataLoader
# -------------------------------
for batch in tqdm(dataloader, desc="Processing batches"):
    # Move input_ids to device; shape: [batch_size, seq_len]
    input_ids = batch["input_ids"].to(device)
    # Present when the collate function pads variable-length cells; without it,
    # padded positions would be attended to like real genes.
    attention_mask = batch.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    # Forward pass with output_attentions=True to obtain attention matrices.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, output_attentions=True)

    # outputs.attentions[-1] is the last layer: [batch_size, num_heads, seq_len, seq_len]
    accumulate_top_attention(
        outputs.attentions[-1],
        input_ids,
        gene_attention_scores,
        id_to_gene,
        rank_weights,
        attention_mask=attention_mask,
    )


In [None]:
# -------------------------------
# Write each source gene's 5 highest-scoring target genes to a text file
# -------------------------------
output_file = "top5_genes.txt"

with open(output_file, "w") as out:
    for gene, partner_scores in gene_attention_scores.items():
        # Highest aggregated score first; sorted() is stable, so ties keep insertion order.
        ranked = sorted(partner_scores.items(), key=lambda pair: pair[1], reverse=True)[:5]
        block_lines = [f"{gene}:\n"]
        block_lines.extend(
            f"    {position}. {partner} (score: {value})\n"
            for position, (partner, value) in enumerate(ranked, start=1)
        )
        block_lines.append("\n")
        out.writelines(block_lines)

print(f"Top 5 genes for each gene have been written to '{output_file}'.")