# In[ ]:
import torch
import random
import numpy as np
from MEGABYTE_pytorch import MEGABYTE

# Token vocabulary: token id i produced by the model is written as nucleotides[i].
# NOTE(review): id 0 ('**') looks like a special/padding token and id 5 ('#') is
# used below as the contig separator -- confirm the model never emits id 0
# mid-sequence, otherwise '**' will end up inside the FASTA records.
nucleotides = ['**', 'A', 'T', 'C', 'G', '#']  # vocabulary
PRIME_LENGTH = 4  # length of the random DNA primer the model starts from
num_seq = 200  # number of runs (one generate_<j>.fna file per run)
context_length = 50000  # maximal length for the generated sequence, depends on your GPU memory
model_path = "megaDNA_phage_145M.pt"  # model checkpoint (a full pickled module)
TEMPERATURE = 0.95  # sampling temperature passed to model.generate
FILTER_THRES = 0.0  # filter threshold passed to model.generate

# Fall back to CPU when no GPU is available (generation will be much slower).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pre-trained model ONCE: generation does not modify it, so re-reading
# the checkpoint on every run only costs time and disk I/O.
# The checkpoint is a whole pickled nn.Module (not a state_dict), which requires
# weights_only=False on torch >= 2.6. This unpickles arbitrary objects -- only
# load checkpoints from a trusted source.
model = torch.load(model_path, map_location=device, weights_only=False)
model.eval()  # Set the model to evaluation mode

# Output files are numbered generate_1.fna .. generate_{num_seq}.fna; the upper
# bound must include num_seq (range(1, num_seq) performed only num_seq - 1 runs).
for j in range(1, num_seq + 1):
    # Random primer of PRIME_LENGTH tokens drawn uniformly from ids 1..4 (A/T/C/G).
    primer_tokens = [random.randint(1, 4) for _ in range(PRIME_LENGTH)]
    primer_DNA = ''.join(nucleotides[token] for token in primer_tokens)
    primer_sequence = torch.tensor([primer_tokens], dtype=torch.long, device=device)
    print(f"Primer sequence: {primer_DNA}\n{'*' * 100}")

    # Generate a sequence using the model; no_grad avoids tracking autograd state.
    with torch.no_grad():
        generated_sequence = model.generate(primer_sequence,
                                            seq_len=context_length,
                                            temperature=TEMPERATURE,
                                            filter_thres=FILTER_THRES)
    # Convert the whole first batch item to Python ints in one call instead of
    # iterating over tens of thousands of 0-d tensors.
    generated_tokens = generated_sequence[0].flatten().tolist()
    generated_str = ''.join(nucleotides[token] for token in generated_tokens)

    # Split the generated sequence into contigs at the '#' character
    contigs = generated_str.split('#')

    # Write the non-empty contigs to a .fna (FASTA) file. Contig numbers keep
    # their position in the split, so skipped empty pieces leave gaps.
    output_file_path = f"generate_{j}.fna"
    with open(output_file_path, "w", encoding="utf-8") as file:
        for idx, contig in enumerate(contigs):
            if contig:
                file.write(f">contig_{idx}\n{contig}\n")

    # Drop per-run objects and return cached GPU memory before the next run.
    del primer_sequence, generated_str
    torch.cuda.empty_cache()

# Release the model once every sequence has been written.
del model
torch.cuda.empty_cache()
