In [None]:
import torch
import numpy as np

In [None]:
# Token vocabulary: position i holds the symbol emitted for token id i.
# Ids 1-4 are the nucleotides A/T/C/G; '**' (id 0) and '#' (id 5) are
# non-nucleotide tokens ('#' is the character the generation cell splits contigs on).
nucleotides = ['**', 'A', 'T', 'C', 'G', '#']
def token2nucleotide(s):
    """Map a token id to its vocabulary symbol.

    `s` may be a plain int or any integer-like object that supports
    ``__index__`` (e.g. a 0-d integer torch tensor obtained by iterating
    over a token tensor).
    """
    vocab = nucleotides
    return vocab[s]

PRIME_LENGTH = 4  # length of the random DNA primer the model starts generating from
num_seq = 2  # number of runs (one output .fna file is written per run)
context_length = 10000  # maximal length of the generated sequence; the feasible value depends on your GPU memory (upper limit is 131K)

# model can be downloaded from https://huggingface.co/lingxusb/megaDNA_updated/resolve/main/megaDNA_phage_145M.pt
model_path = "megaDNA_phage_145M.pt"  # path to the downloaded model checkpoint
# Use the GPU automatically when one is available; assign 'cpu' or 'cuda' explicitly to override.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
def write_contigs_fasta(path, contigs):
    """Write the non-empty strings in `contigs` to `path` as FASTA records.

    Each record is named ``contig_<idx>``, where ``idx`` is the string's
    position in `contigs`. Empty strings are skipped, so the numbering can
    have gaps. An existing file at `path` is overwritten.
    """
    with open(path, "w") as fasta:
        for idx, contig in enumerate(contigs):
            if contig:
                fasta.write(f">contig_{idx}\n{contig}\n")


# Load the pre-trained model ONCE and reuse it for every run; reloading the
# checkpoint from disk for each sequence is pure overhead.
# The checkpoint is a fully pickled nn.Module, so weights_only=False is required:
# since PyTorch 2.6 torch.load defaults to weights_only=True and would refuse it.
# NOTE: unpickling can execute arbitrary code -- only load checkpoints you trust.
model = torch.load(model_path, map_location=torch.device(device), weights_only=False)
model.eval()  # Set the model to evaluation mode

for j in range(num_seq):
    # set the random DNA primer: token ids 1-4 (A/T/C/G) with a leading batch dim of 1
    primer_sequence = torch.tensor(np.random.choice(np.arange(1, 5), PRIME_LENGTH)).long().to(device)[None,]
    primer_DNA = ''.join(map(token2nucleotide, primer_sequence[0]))
    print(f"Primer sequence: {primer_DNA}\n{'*' * 100}")

    # Generate a sequence using the model; sampling needs no autograd graph
    with torch.no_grad():
        seq_tokenized = model.generate(primer_sequence,
                                       seq_len=context_length,
                                       temperature=0.95,
                                       filter_thres=0.0)
    # flatten() rather than squeeze() so a single-token result is still a 1-D,
    # iterable tensor; this assumes a batch size of 1, as built above
    generated_sequence = ''.join(map(token2nucleotide, seq_tokenized.flatten().cpu().int()))

    # Split the generated sequence into contigs at the '#' character
    contigs = generated_sequence.split('#')

    # Write the contigs to a .fna file
    output_file_path = f"generate_{1+j}.fna"
    write_contigs_fasta(output_file_path, contigs)

    # Free this run's tensors (including the generated tokens) before the next run
    del primer_sequence, seq_tokenized, generated_sequence
    torch.cuda.empty_cache()

# Release the model once all runs are finished
del model
torch.cuda.empty_cache()