In [None]:
import torch
import pickle
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import BertTokenizer, BertForMaskedLM
from datasets import load_from_disk

In [None]:
# ========== STEP 1: Load Pretrained Model & Token Dictionary ==========
model_path = "/home/logs/jtorresb/Geneformer/yeast/pretraining/models/250218_185527_yeastformer_L4_emb256_SL512_E10_B8_LR0.0016_LScosine_WU53_Oadamw_torch/checkpoint-2000"
token_dict_path = "/home/logs/jtorresb/Geneformer/yeast/yeast_data/output/yeast_token_dict.pkl"

# Load gene token dictionary.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load token dictionaries that come from a trusted source.
with open(token_dict_path, "rb") as fp:
    token_dictionary = pickle.load(fp)

# Load model with the "eager" attention implementation.
# Depending on the installed transformers version, BERT may default to SDPA
# attention, which does not compute explicit attention probabilities; asking
# for output_attentions=True (Step 3) then either warns and falls back or
# returns no attention weights. Forcing eager makes Step 3 reliable.
model = BertForMaskedLM.from_pretrained(model_path, attn_implementation="eager")
model.eval()  # Evaluation mode: disables dropout for deterministic inference

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [None]:
# ========== STEP 2: Load a Sample Gene Sequence ==========
# NOTE(review): assumes a tokenized dataset of ~11,889 examples with up to 512
# tokens each — confirm against the tokenization run that produced it.
dataset_path = "/home/logs/jtorresb/Geneformer/yeast/yeast_data/output/yeast_master_matrix_sgd.dataset"
dataset = load_from_disk(dataset_path)

# Select one fixed (not random) example so the analysis is reproducible.
sample_idx = 150  # Change this to test different samples
gene_sequence = dataset[sample_idx]["input_ids"]  # Assumes tokens are stored under "input_ids"

if len(gene_sequence) == 0:
    raise ValueError(f"Sample {sample_idx} has no input_ids to feed the model.")

# The model has a fixed-size position-embedding table; a longer sequence would
# fail inside the embedding lookup, so keep only the leading tokens.
# NOTE(review): leading tokens are presumably the highest-ranked genes — confirm.
max_len = model.config.max_position_embeddings
if len(gene_sequence) > max_len:
    print(f"Truncating sample {sample_idx} from {len(gene_sequence)} to {max_len} tokens.")
    gene_sequence = gene_sequence[:max_len]

# Convert to an integer tensor on the model's device & add batch dimension
input_ids = torch.tensor([gene_sequence], dtype=torch.long, device=device)  # Shape: [1, seq_len]

In [None]:
# ========== STEP 3: Get Attention Matrices ==========
with torch.no_grad():
    outputs = model(input_ids, output_attentions=True)
# Per-layer attention tensors; indexing shape[1] below relies on the
# transformers layout [batch, heads, seq_len, seq_len].
attentions = outputs.attentions

# Non-eager attention backends may return no weights at all (None, or a tuple
# of Nones). Fail here with an actionable message instead of an opaque
# TypeError/AttributeError further down.
if attentions is None or any(layer_attn is None for layer_attn in attentions):
    raise RuntimeError(
        "Model returned no attention weights; reload it with "
        "attn_implementation='eager' to enable output_attentions."
    )

# Summarise the attention stack (tensors remain torch tensors on `device`)
num_layers = len(attentions)
num_heads = attentions[0].shape[1]  # Attention heads

# Print model details
print(f"Extracted attention from {num_layers} layers and {num_heads} attention heads.")

In [None]:
# ========== STEP 4: Visualize Attention (Optional) ==========
# Pick a layer and head to visualize
layer_idx = 0  # Change to inspect different layers
head_idx = 0   # Change to inspect different heads

# Extract the attention matrix for the selected layer and head of the first
# (only) example. Index the batch axis explicitly: `.squeeze(0)` is a no-op
# when batch > 1, in which case `[head_idx]` would pick a batch item, not a head.
attn_matrix = attentions[layer_idx][0, head_idx].cpu().numpy()  # Shape: [seq_len, seq_len]