In [6]:
import torch
from Bio import SeqIO
from BCBio import GFF
import random

# Load the pretrained megaDNA model.
model_path = "megaDNA_phage_145M.pt"  # model name
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # Device configuration

# The checkpoint is a whole pickled nn.Module (see the MEGADNA repr below), not a state_dict,
# so it must be unpickled with weights_only=False: PyTorch >= 2.6 defaults to weights_only=True,
# which refuses to reconstruct arbitrary classes. (The keyword needs PyTorch >= 1.13.)
# SECURITY: unpickling can execute arbitrary code -- only load checkpoints from a trusted source.
model = torch.load(model_path, map_location=device, weights_only=False)
model.eval()

MEGADNA(
  (start_tokens): ParameterList(
      (0): Parameter containing: [torch.float32 of size 512 (GPU 0)]
      (1): Parameter containing: [torch.float32 of size 256 (GPU 0)]
      (2): Parameter containing: [torch.float32 of size 196 (GPU 0)]
  )
  (token_embs): ModuleList(
    (0): Embedding(6, 196)
    (1): Sequential(
      (0): Embedding(6, 196)
      (1): Rearrange('... r d -> ... (r d)')
      (2): LayerNorm((3136,), eps=1e-05, elementwise_affine=True)
      (3): Linear(in_features=3136, out_features=256, bias=True)
      (4): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
    (2): Sequential(
      (0): Embedding(6, 196)
      (1): Rearrange('... r d -> ... (r d)')
      (2): LayerNorm((200704,), eps=1e-05, elementwise_affine=True)
      (3): Linear(in_features=200704, out_features=512, bias=True)
      (4): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
  )
  (transformers): ModuleList(
    (0): Transformer(
      (layers): ModuleList(
        (0

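Download the lambda phage genome sequence (FASTA) and its gene annotation (GFF3) for accession NC_001416.1 from NCBI: https://www.ncbi.nlm.nih.gov/nuccore/NC_001416.1. Save them in the working directory as `NC_001416.1.fasta` and `NC_001416.1.gff3`; the next cell reads those file names.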

In [12]:
# Genome sequence(s): parallel lists of FASTA record IDs and their sequences as plain strings.
fasta_file_path = "NC_001416.1.fasta"
fasta_records = list(SeqIO.parse(fasta_file_path, "fasta"))
seq_ids = [record.id for record in fasta_records]
sequences = [str(record.seq) for record in fasta_records]

# Gene annotation: keep only CDS features and record their start, end and strand,
# in file order across all GFF records.
gff_file_path = "NC_001416.1.gff3"
limit_info = {"gff_type": ["CDS"]}

with open(gff_file_path) as gff_handle:
    # Materialise every feature while the handle is still open (GFF.parse is consumed lazily).
    cds_features = [
        feature
        for gff_record in GFF.parse(gff_handle, limit_info=limit_info)
        for feature in gff_record.features
    ]

start_position = [feature.location.start for feature in cds_features]
end_position = [feature.location.end for feature in cds_features]
strand_position = [feature.location.strand for feature in cds_features]


In [13]:
nt = ['**', 'A', 'T', 'C', 'G', '#']  # Vocabulary: list index = token id
seq_id = 0  # Sequence ID

def encode_sequence(sequence, nt_vocab=nt):
    """Encode a DNA sequence to its numerical representation.

    Returns ``[0] + tokens + [5]``. Token 0 is prepended and token 5 appended around the
    per-nucleotide ids, so the nucleotide at 0-based position ``k`` lands at list index ``k + 1``.

    Each nucleotide is upper-cased before lookup, so soft-masked (lower-case) input encodes the
    same as upper-case instead of silently collapsing to token 1. The id is the character's index
    in ``nt_vocab``. Characters not in the vocabulary (e.g. ``N``) fall back to token 1.
    Upper-casing is done per character so the output length always equals ``len(sequence) + 2``.
    """
    tokens = []
    for nucleotide in sequence:
        symbol = nucleotide.upper()
        tokens.append(nt_vocab.index(symbol) if symbol in nt_vocab else 1)
    return [0] + tokens + [5]

@torch.no_grad()
def get_loss_for_sequence(model, sequence, device):
    """Return the model's loss for a single encoded sequence.

    The token list becomes a tensor on ``device`` with a leading batch axis of size 1. That
    tensor is passed to ``model(..., return_value='loss')`` with gradient tracking disabled.
    """
    batch = torch.tensor(sequence, device=device)[None]
    return model(batch, return_value='loss')

# Get the model loss for the WT sequence (computed once; every mutant is compared against it).
encoded_wt_sequence = encode_sequence(sequences[seq_id])
wt_loss = get_loss_for_sequence(model, encoded_wt_sequence, device)
print(wt_loss)

# Get the model loss for mutants whose three start-codon bases are substituted, one per CDS.
# Encoded index = 0-based genome position + 1, because encode_sequence prepends token 0.
# Biopython feature locations use a 0-based start and an exclusive end, so:
#   + strand: start codon at genome[start:start+3] -> encoded indices start+1 .. start+3
#   - strand: start codon at genome[end-3:end]     -> encoded indices end-2   .. end
NUCLEOTIDE_TOKENS = (1, 2, 3, 4)  # token ids of 'A', 'T', 'C', 'G' in `nt`
loss_start = []
random.seed(42)  # reproducible choice of substituted bases
for start, end, strand in zip(start_position, end_position, strand_position):
    # Copy the already-encoded WT rather than re-encoding the whole genome for every CDS.
    encoded_mutant_sequence = list(encoded_wt_sequence)

    # Mutate start codon positions based on strand orientation
    positions = range(start + 1, start + 4) if strand == 1 else range(end - 2, end + 1)
    for i in positions:
        # Draw only from bases that differ from the WT base so every position is a real
        # substitution; drawing from all four would leave ~1/4 of positions unchanged.
        wt_token = encoded_wt_sequence[i]
        encoded_mutant_sequence[i] = random.choice([t for t in NUCLEOTIDE_TOKENS if t != wt_token])

    # Get model loss for mutated sequence
    mutant_loss = get_loss_for_sequence(model, encoded_mutant_sequence, device)
    loss_start.append(mutant_loss)

tensor(1.2624, device='cuda:0')
