In [27]:
import evodiff
import pandas as pd
import numpy as np
import torch
from tqdm.notebook import tqdm
from pathlib import Path

# Directory holding the input CSVs and the generated FASTA output (relative to
# the notebook's working directory).
data = Path('data')

# Always build a torch.device so `device` has the same type on GPU and CPU
# machines (the CPU branch used to be a bare string).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'device: {device}')

device: cuda


In [20]:
# Read the train and test splits with identical options (first CSV column
# becomes the index), then stack them into one frame for length statistics.
train_df, test_df = (
    pd.read_csv(data / csv_name, index_col=0)
    for csv_name in ('train_ec_311.csv', 'test_ec_311.csv')
)
df = pd.concat([train_df, test_df])

In [21]:
# Sequence lengths in ascending order; the index still carries the original
# row labels from `df`, so it is no longer in positional order after sorting.
lengths = df['Sequence'].apply(len).sort_values()

# Take the values at the 10% and 90% positions of the sorted lengths.
# `.iloc` is required here: plain `lengths[k]` is a label lookup on an integer
# index (returning the length of the row *labelled* k, or several rows when
# the concatenated index contains duplicates) rather than the k-th smallest.
low_bound = lengths.iloc[int(0.1 * len(lengths))]
high_bound = lengths.iloc[int(0.9 * len(lengths))]

In [23]:
from evodiff.pretrained import OA_DM_38M

# Load the pretrained checkpoint; it unpacks into the model plus the collater,
# tokenizer and scheme objects that accompany it.
checkpoint = OA_DM_38M()
model, collater, tokenizer, scheme = checkpoint

# Move the model to the selected device. Rebinding `model` suppresses the cell
# output without assigning to `_`, which would stop IPython from updating its
# `_` output-history variables for the rest of the session.
# NOTE(review): assumes `model` is a torch.nn.Module, whose `.to()` returns the
# same module object -- confirm against evodiff.
model = model.to(device)

In [None]:
from evodiff.generate import generate_oaardm

# Generate one sequence per sampled target length and collect them in
# `generated_sequences`.
SEED = 42            # fixed seed so the sampled lengths can be reproduced
N_SEQUENCES = 500    # number of sequences to generate

# Seed the global NumPy and torch RNGs before any sampling.
# NOTE(review): presumably generate_oaardm draws from these global RNGs as
# well -- confirm against the evodiff source.
np.random.seed(SEED)
torch.manual_seed(SEED)

generated_sequences = []
# np.random.randint excludes the upper bound, so the target lengths fall in
# [low_bound, high_bound).
target_lengths = np.random.randint(low_bound, high_bound, N_SEQUENCES)

for length in tqdm(target_lengths):
    # Use the notebook-level `device` rather than a hard-coded 'cuda' so the
    # cell also runs on CPU-only machines; pass the length as a plain int.
    tokenized_sample, generated_sequence = generate_oaardm(
        model, tokenizer, int(length), batch_size=1, device=device
    )
    # With batch_size=1 the returned list is expected to hold one sequence;
    # extend() keeps `generated_sequences` flat either way.
    generated_sequences.extend(generated_sequence)

# Show a short summary instead of dumping every sequence into the cell output.
print(f"Generated {len(generated_sequences)} sequences; first 3:", generated_sequences[:3])

  0%|          | 0/500 [00:00<?, ?it/s]

100%|██████████| 318/318 [00:02<00:00, 122.07it/s]
100%|██████████| 215/215 [00:01<00:00, 130.04it/s]
100%|██████████| 479/479 [00:03<00:00, 125.94it/s]
100%|██████████| 442/442 [00:03<00:00, 127.89it/s]
100%|██████████| 231/231 [00:01<00:00, 127.92it/s]
100%|██████████| 352/352 [00:02<00:00, 125.71it/s]
100%|██████████| 155/155 [00:01<00:00, 122.78it/s]
100%|██████████| 172/172 [00:01<00:00, 127.12it/s]
100%|██████████| 167/167 [00:01<00:00, 126.42it/s]
100%|██████████| 283/283 [00:02<00:00, 132.65it/s]
100%|██████████| 233/233 [00:01<00:00, 133.03it/s]
100%|██████████| 305/305 [00:02<00:00, 135.43it/s]
100%|██████████| 218/218 [00:01<00:00, 129.53it/s]
100%|██████████| 168/168 [00:01<00:00, 126.76it/s]
100%|██████████| 258/258 [00:02<00:00, 128.55it/s]
100%|██████████| 227/227 [00:01<00:00, 129.34it/s]
100%|██████████| 360/360 [00:02<00:00, 130.38it/s]
100%|██████████| 322/322 [00:02<00:00, 126.53it/s]
100%|██████████| 512/512 [00:03<00:00, 128.09it/s]
100%|██████████| 440/440 [00:03

In [29]:
# Write the generated sequences as FASTA: a ">sequence_<n>" header line
# (numbered from 1) followed by the sequence itself on the next line.
with open(data / 'baseline_generated.fasta', "w") as fasta_file:
    for number, sequence in enumerate(generated_sequences, start=1):
        fasta_file.write(f">sequence_{number}" + "\n" + sequence + "\n")