In [1]:
import torch
import os
from pathlib import Path
from transformers import T5ForConditionalGeneration

from tokenizer import BpeTokenizer
# from helpers import preprocess_rna


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Notebook setup: output directory, compute device, and the fine-tuned T5 checkpoint.
results_dir = Path("attentions")
# Race-free equivalent of "check exists, then makedirs".
results_dir.mkdir(parents=True, exist_ok=True)

# Device setup.
# `is_available` has to be *called*: the bare function object is always truthy,
# which would select 'cuda' even on a CPU-only machine and crash at `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the T5 model
model_path = "/data6/sobhan/rllm/results/train/t5/run3_20240822-152114/checkpoints/checkpoint-349800"
model = T5ForConditionalGeneration.from_pretrained(model_path).to(device)


In [3]:
def preprocess_rna(rna):
    """Re-encode a nucleotide string into the B/J/U/Z alphabet.

    The input is lower-cased first, then each character is mapped:
    ``a -> B``, ``c -> J``, ``u -> U``, ``g -> Z``. Every other character
    passes through in its lower-cased form.

    NOTE(review): 't' is not re-encoded, so DNA-style input keeps a
    lower-case 't' in the result -- confirm this matches the tokenizer's
    vocabulary.

    Args:
        rna: Nucleotide sequence as a string (any letter case).

    Returns:
        The re-encoded string, same length as the input.
    """
    recode = str.maketrans({'a': 'B', 'c': 'J', 'u': 'U', 'g': 'Z'})
    return rna.lower().translate(recode)


In [43]:
# Decoding hyper-parameters.
# NOTE(review): no cell visible in this notebook reads `generation_config`;
# confirm it is still needed.
generation_config = dict(
    max_length=100,
    num_beams=3,
    temperature=0.7,
    top_k=3,
    top_p=1.0,
    repetition_penalty=1.0,
)

# Protein (amino-acid) sequence used as the model input.
protein_sequence = 'MSNGYEDHMAEDCRGDIGRTNLIVNYLPQNMTQDELRSLFSSIGEVESAKLIRDKVAGHSLGYGFVNYVTAKDAERAINTLNGLRLQSKTIKVSYARPSSEVIKDANLYISGLPRTMTQKDVEDMFSRFGRIINSRVLVDQTTGLSRGVAFIRFDKRSEAEEAITSFNGHKPPGSSEPITVKFAANPNQNKNVALLSQLYHSPARRFGGPVHHQAQRFRFSPMGVDHMSGLSGVNVPGNASSGWCIFIYNLGQDADEGILWQMFGPFGAVTNVKVIRDFNTNKCKGFGFVTMTNYEEAAMAIASLNGYRLGDKILQVSFKTNKSHK'

# Nucleotide sequence paired with the protein, stored already re-encoded by
# preprocess_rna (the raw literal is not kept under its own name).
rna_sequence = preprocess_rna(
    'ttgggcaaaagacttttaaccctcccgaaccttggctccttacctaaaaaatagagcatctctaaagtctcttataaatgtaaatttccataaAUUAUUU'
)

In [44]:
# One BPE tokenizer per modality, each loaded from its saved vocabulary file.
rna_tokenizer = BpeTokenizer(vocab_size=1000, seq_size=1024)
rna_tokenizer.load("/data6/sobhan/RLLM/dataset/tokenizers/bpe_rna_1000_1024.json")

protein_tokenizer = BpeTokenizer(vocab_size=1000, seq_size=1024)
protein_tokenizer.load("/data6/sobhan/RLLM/dataset/tokenizers/bpe_protein_1000_1024.json")

# Tokenize the sequences and add a leading batch axis of size 1.
# `input_ids` / `output_ids` are the tokenizer results (their `.ids` attribute
# holds the integer ids); `inputs` / `output` are the batched id tensors.
input_ids = protein_tokenizer.tokenize(protein_sequence)
inputs = torch.tensor(input_ids.ids, dtype=torch.long).to(device)
inputs = inputs.unsqueeze(0)

output_ids = rna_tokenizer.tokenize(rna_sequence)
output = torch.tensor(output_ids.ids, dtype=torch.long).to(device)
output = output.unsqueeze(0)

# Pass through the model with attention outputs enabled.
# This is analysis only, so gradients are disabled: otherwise every returned
# attention map keeps the whole autograd graph alive in memory.
# NOTE(review): the target ids are fed as `decoder_input_ids` unshifted --
# confirm this is intended rather than passing them as `labels=`.
with torch.no_grad():
    out = model(input_ids=inputs, decoder_input_ids=output, return_dict=True, output_attentions=True)

# Extract attentions (one entry per layer; the last decoder entry is used below).
encoder_attentions = out.encoder_attentions
cross_attentions  = out.cross_attentions
decoder_attentions = out.decoder_attentions

enable_padding(max_length=X) is deprecated, use enable_padding(length=X) instead
enable_padding(max_length=X) is deprecated, use enable_padding(length=X) instead


In [57]:
import torch
import numpy as np

def analyze_attention(attention_tensor, top_n=100):
    """Rank sequence positions by the total attention they receive.

    The tensor is averaged over dim 1, summed over dim 1 of the result, and
    the leading axis is dropped, leaving one score per position along the
    last axis. Assumes a single-example attention map laid out as
    (batch=1, heads, query, key) -- TODO confirm against the model output.

    Args:
        attention_tensor: 4-D attention tensor for one example.
        top_n: Maximum number of positions to return.

    Returns:
        A 3-tuple:
          * numpy array of the selected position indices in ascending order,
          * list of scores for the selected positions, highest first,
          * list of the selected position indices, highest score first.
    """
    # Average across dim 1 (heads, under the assumed layout).
    avg_attention = attention_tensor.mean(dim=1)
    # Sum across dim 1 of the averaged map, then drop ONLY the leading batch
    # axis. A bare .squeeze() would also collapse a length-1 sequence into a
    # 0-dim tensor, which cannot be sliced below.
    attention_per_token = avg_attention.sum(dim=1).squeeze(0)
    top_indices = torch.argsort(attention_per_token, descending=True)[:top_n]
    top_indices_list = np.sort(top_indices.tolist())

    return top_indices_list, attention_per_token[top_indices].tolist(), top_indices.tolist()

def highlight_attention(protein_sequence, attention_indices):
    """Wrap the characters at the given positions in curly braces.

    Despite the parameter name (kept for backward compatibility), any string
    works; the notebook calls this with an RNA sequence.

    Args:
        protein_sequence: String whose characters are to be marked.
        attention_indices: 1-D collection of 0-based character positions
            (list, tuple, set, NumPy array or tensor). Positions outside the
            string are ignored.

    Returns:
        A copy of the string with each selected character ``c`` rendered as
        ``{c}`` and every other character unchanged.
    """
    # Arrays/tensors are converted to plain Python ints so that set membership
    # compares by value.
    if hasattr(attention_indices, "tolist"):
        attention_indices = attention_indices.tolist()
    # A set gives O(1) lookups; testing against a list/array for every
    # character made the original loop quadratic.
    marked = set(attention_indices)
    return "".join(
        f"{{{char}}}" if i in marked else char
        for i, char in enumerate(protein_sequence)
    )

def does_binding_site_overlap(attended_tokens, binding_start, binding_end, token_to_nt_mapping):
    """Tell whether any attended token's span intersects the binding site.

    Args:
        attended_tokens: Iterable of indices into ``token_to_nt_mapping``.
        binding_start: First position of the binding site (inclusive).
        binding_end: Last position of the binding site (inclusive).
        token_to_nt_mapping: Sequence of ``(start, end)`` inclusive spans,
            one per token.

    Returns:
        True as soon as one attended token's span shares at least one
        position with ``[binding_start, binding_end]``; False otherwise
        (including when ``attended_tokens`` is empty).
    """
    def intersects_site(span):
        # A span misses the site only if it ends before it or starts after it.
        ends_before = span[1] < binding_start
        return not (ends_before or span[0] > binding_end)

    return any(intersects_site(token_to_nt_mapping[idx]) for idx in attended_tokens)


# Attention maps of the last decoder layer.
dec_attns = decoder_attentions[-1]

# Rank once and slice. analyze_attention sorts with a single argsort and then
# truncates to top_n, so the top-30 and top-100 rankings are prefixes of the
# top-150 ranking; re-running it per view repeated the same work.
# `attended_tokens` is reused by the coverage cell further down.
_, _, attended_tokens = analyze_attention(dec_attns, 150)
top100_sorted = np.sort(attended_tokens[:100])

print(attended_tokens[:30])
print(rna_tokenizer.decode(np.array(output_ids.ids)[top100_sorted]))
# NOTE(review): `top100_sorted` holds TOKEN indices, but highlight_attention
# marks CHARACTER positions of `rna_sequence`. The two only coincide if every
# token is one character long -- confirm, or expand the indices through
# `token_to_nt_mapping` before highlighting.
print(highlight_attention(rna_sequence, top100_sorted))

# Inverse vocabulary: token id -> token string.
vocab = rna_tokenizer.tokenizer.get_vocab()
id2token = {v: k for k, v in vocab.items()}
def convert_ids_to_tokens(token_ids):
    """Map each token id to its vocabulary string via `id2token`."""
    return [id2token[id_] for id_ in token_ids]
decoded_tokens = convert_ids_to_tokens(output_ids.ids)

# Build a mapping: for each token, the inclusive (start, end) character span
# it covers, obtained by laying the token strings end to end from position 0.
# NOTE(review): this assumes the token strings concatenate back to the RNA
# string (no special/padding tokens or marker characters) -- verify.
token_to_nt_mapping = []
start_pos = 0
for tok in decoded_tokens:
    # 'tok' might be multiple characters, e.g. "ACG" if BPE merges them
    token_len = len(tok)
    token_to_nt_mapping.append((start_pos, start_pos + token_len - 1))
    start_pos += token_len

# Binding-site window (inclusive) used for the overlap check.
# NOTE(review): the coverage cell below uses 20-40 instead; confirm which
# window is the intended one.
BINDING_START, BINDING_END = 1, 21
overlap = does_binding_site_overlap(attended_tokens, BINDING_START, BINDING_END, token_to_nt_mapping)

if overlap:
    print("The most-attended tokens overlap with the binding site.")
else:
    print("The most-attended tokens do not overlap with the binding site.")


[16, 0, 15, 5, 1, 2, 3, 6, 4, 7, 12, 9, 26, 63, 30, 27, 23, 28, 24, 64, 62, 59, 25, 65, 66, 48, 45, 10, 67, 43]
z z z j b b b b z b j b jjjjzb bjjzzj j j j b b b b b b b b j b j b b b b b b b b b u u u
{t}{t}{Z}{Z}{Z}{J}{B}{B}{B}{B}{Z}B{J}tt{t}{t}{B}{B}J{J}J{t}{J}{J}{J}{Z}{B}{B}J{J}t{t}{Z}{Z}Jt{J}{J}ttBJ{J}{t}{B}B{B}{B}{B}BtBZBZJB{t}{J}{t}{J}{t}{B}{B}{B}{Z}{t}{J}{t}{J}{t}{t}{B}{t}{B}{B}{B}{t}{Z}{t}{B}{B}{B}{t}{t}{t}{J}{J}{B}{t}{B}{B}{B}{U}{U}{B}{U}{U}{U}
The most-attended tokens overlap with the binding site.


In [58]:
def compute_binding_site_coverage(top_attended_tokens, token_to_nt_mapping, binding_start, binding_end):
    """Percentage of the binding site covered by the top-attended tokens.

    Args:
        top_attended_tokens: Iterable of indices into ``token_to_nt_mapping``.
        token_to_nt_mapping: Sequence of ``(start, end)`` inclusive spans,
            one per token.
        binding_start: First position of the binding site (inclusive).
        binding_end: Last position of the binding site (inclusive).

    Returns:
        A float in ``[0.0, 100.0]``: the share of binding-site positions that
        fall inside at least one of the given token spans.

    Raises:
        ValueError: If ``binding_end`` is smaller than ``binding_start``.
    """
    if binding_end < binding_start:
        # An inverted window used to surface as ZeroDivisionError or as a
        # negative percentage.
        raise ValueError(
            f"binding_end ({binding_end}) must be >= binding_start ({binding_start})"
        )

    # The total length of the binding site in nucleotides
    binding_length = binding_end - binding_start + 1

    # Collect distinct covered positions so that a repeated token index or
    # overlapping spans cannot be counted twice (which could push the result
    # above 100%).
    covered_positions = set()
    for token_idx in top_attended_tokens:
        token_start, token_end = token_to_nt_mapping[token_idx]

        # Clip the token span to the binding-site window.
        overlap_start = max(token_start, binding_start)
        overlap_end = min(token_end, binding_end)

        if overlap_start <= overlap_end:
            covered_positions.update(range(overlap_start, overlap_end + 1))

    return 100.0 * len(covered_positions) / binding_length

# Share of the binding-site window (positions 20-40, inclusive) that the
# top-attended decoder tokens cover.
# NOTE(review): the overlap check in the previous cell uses a 1-21 window;
# confirm which coordinates are the intended binding site.
coverage = compute_binding_site_coverage(attended_tokens, token_to_nt_mapping, 20, 40)
print(f"Binding site coverage by top-attended tokens: {coverage:.2f}%")

Binding site coverage by top-attended tokens: 90.48%
