In [2]:
import pandas as pd
import torch

# Assume these are in the current path or installed
from evodiff.utils import Tokenizer
from evodiff.losses import D3PMLVBLoss, D3PMCELoss
# Use your specific sequence constants
from sequence_models.constants import PROTEIN_ALPHABET, PAD, MASK, GAP, START, STOP, SEP, MSA_AAS

from src.collator import ConditionalD3PMCollator 
from src.data import ConditionalProteinDataset    
from src.model import ConditionalByteNetLMTime
from src.generate import generate_conditional_d3pm as generate

# Select the compute device once. Always build a torch.device (never a bare
# string) so downstream code can rely on a single type, e.g. `device.type`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"On device: {device}")


On device: cuda


In [3]:
# Placeholder vocabulary size; a later cell recomputes n_tokens from the tokenizer.
n_tokens = 28

# Run configuration for the "large" model (output_dir='runs/large').
# Key order is kept identical to the original literal.
args = dict(
    # data and output locations
    data_dir='data/',
    train_data='train_ec_all.csv',
    valid_data='test_ec_all.csv',
    output_dir='runs/large',
    # model architecture (forwarded to ConditionalByteNetLMTime in a later cell)
    d_embed=8,
    d_model=1280,
    n_layers=32,
    kernel_size=5,
    r=128,
    class_dropout_prob=0.1,
    embedding_scale=1.0,
    slim=False,
    activation='gelu',
    # diffusion
    diffusion_timesteps=500,
    reweighting_term=0.01,
    # optimisation
    epochs=500,
    batch_size=32,
    max_seq_len=1024,
    lr=1e-4,
    warmup_steps=16000,
    accumulate_grad_batches=4,
    clip_grad_norm=1.0,
    # logging and checkpointing
    log_freq=50,
    checkpoint_freq_steps=1000,
    save_latest_only=False,
    resume_checkpoint=None,
    log_to_file=True,
    # runtime
    seed=42,
    num_workers=0,
    device=None,
)

In [4]:
# Alternative "small" run configuration, kept disabled (entirely commented out).
# Compared with the active `args` dict it differs in: output_dir ('runs/small'),
# d_model (1024), n_layers (16), slim (True), activation ('relu'),
# batch_size (224), warmup_steps (10000) and accumulate_grad_batches (1).
# To use it, uncomment this dict so it replaces the active `args`; any path that
# is hard-coded to 'runs/large' elsewhere in the notebook must then be switched too.
# args = {
#     'data_dir': 'data/',
#     'train_data': 'train_ec_all.csv',
#     'valid_data': 'test_ec_all.csv',
#     'output_dir': 'runs/small',
#     'd_embed': 8,
#     'd_model': 1024,
#     'n_layers': 16,
#     'kernel_size': 5,
#     'r': 128,
#     'class_dropout_prob': 0.1,
#     'embedding_scale': 1.0,
#     'slim': True,
#     'activation': 'relu',
#     'diffusion_timesteps': 500,
#     'reweighting_term': 0.01,
#     'epochs': 500,
#     'batch_size': 224,
#     'max_seq_len': 1024,
#     'lr': 0.0001,
#     'warmup_steps': 10000,
#     'accumulate_grad_batches': 1,
#     'clip_grad_norm': 1.0,
#     'log_freq': 50,
#     'checkpoint_freq_steps': 1000,
#     'save_latest_only': False,
#     'resume_checkpoint': None,
#     'log_to_file': True,
#     'seed': 42,
#     'num_workers': 0,
#     'device': None,
# }

In [5]:
# Tokenizer over the protein alphabet; the model's vocabulary size is derived from it.
tokenizer = Tokenizer(protein_alphabet=PROTEIN_ALPHABET, pad=PAD, all_aas=MSA_AAS, sequences=True)
# NOTE(review): the "+ 2" presumably reserves room for extra special tokens — confirm
# against the value used at training time.
n_tokens = tokenizer.K + 2

# Build the diffusion schedule with the SAME number of steps the model is
# configured with below, instead of a separate hard-coded literal, so the two
# cannot silently drift apart when the config changes.
Q_prod, Q_t = tokenizer.q_random_schedule(timesteps=args['diffusion_timesteps'])

# Instantiate the conditional model from the run configuration and move it to
# the selected device.
model = ConditionalByteNetLMTime(
    n_tokens=n_tokens,
    d_embedding=args['d_embed'],
    d_model=args['d_model'],
    n_layers=args['n_layers'],
    kernel_size=args['kernel_size'],
    r=args['r'],
    slim=args['slim'],
    activation=args['activation'],
    n_classes=7,  # number of conditioning classes; must match the trained checkpoint
    class_dropout_prob=args['class_dropout_prob'],
    embedding_scale=args['embedding_scale'],
    timesteps=args['diffusion_timesteps'],
    padding_idx=tokenizer.pad_id
    ).to(device)
# Inference only: switch to eval mode (assigning to `_` suppresses the cell's repr output).
_ = model.eval()

sohl-dickstein


In [6]:
# Load the trained weights from the best checkpoint (.pt) of the configured run.
# The path is derived from args['output_dir'] so that the checkpoint always
# belongs to the same configuration the model above was built from.
checkpoint_path = f"{args['output_dir']}/checkpoint_best.pt"
# weights_only=True restricts unpickling to tensors/primitive containers;
# map_location places the tensors on the selected device.
checkpoint = torch.load(checkpoint_path, weights_only=True, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

print("loaded model")

loaded model


In [7]:
# Read the baseline FASTA file and record the length of every sequence in it.
with open('data/baseline_generated.fasta', 'r') as f:
    lines = f.readlines()

# One entry per FASTA record. A record's sequence may be wrapped over several
# lines, so lengths are accumulated until the next ">" header instead of
# treating every non-header line as its own sequence. Blank lines and records
# without residues are skipped so no zero-length entry is produced.
seq_lens = []
current_len = 0
for line in lines:
    line = line.strip()
    if not line:
        continue
    if line.startswith(">"):
        # A new header closes the previous record (if it had any residues).
        if current_len:
            seq_lens.append(current_len)
        current_len = 0
    else:
        current_len += len(line)
# Flush the final record, which has no following header to close it.
if current_len:
    seq_lens.append(current_len)

In [None]:
from pathlib import Path

from tqdm.auto import tqdm

classes = range(1, 8)   # class labels to condition on (1..7)
n_samples = 50          # number of baseline lengths (and thus sequences) per class
guidance_scale = 3.0    # higher values = stronger conditioning

# Create the output directory up front so a missing folder cannot fail the
# write after a whole class has already been generated.
generated_dir = Path(args['output_dir']) / 'generated_sequences'
generated_dir.mkdir(parents=True, exist_ok=True)

for label in classes:
    untokenized_seqs = []

    for idx, length in enumerate(tqdm(seq_lens[:n_samples], desc=f"Generating class: {label}")):
        # Seed per (class, sample index). Seeding from the length alone would
        # give identical seeds to repeated lengths within a class and could
        # therefore reproduce the very same sequence more than once.
        torch.manual_seed(args['seed'] + 1_000_000 * label + idx)
        sequences, untokenized = generate(
            model=model,
            tokenizer=tokenizer,
            Q=Q_t,
            Q_bar=Q_prod,
            timesteps=args['diffusion_timesteps'],  # must match the model's training timesteps
            seq_len=length,                 # desired sequence length
            class_labels=label,             # which class to condition on (1 to n_classes)
            guidance_scale=guidance_scale,
            batch_size=1,                   # one sequence per call
            device=device,                  # use the detected device instead of assuming CUDA
        )
        untokenized_seqs.append(untokenized[0])

    # One FASTA file per class; the "g<scale>" prefix records the guidance scale
    # (3.0 -> "g3"), so the file name stays in sync with the setting above.
    out_path = generated_dir / f'g{guidance_scale:g}_class{label}.fasta'
    with open(out_path, 'w') as f:
        for i, seqs in enumerate(untokenized_seqs):
            f.write(f'>sequence_{i}\n')
            f.write(f'{seqs}\n')



Generating class: 1:   0%|          | 0/50 [00:00<?, ?it/s]

Generating class: 2:   0%|          | 0/50 [00:00<?, ?it/s]

Generating class: 3:   0%|          | 0/50 [00:00<?, ?it/s]

Generating class: 4:   0%|          | 0/50 [00:00<?, ?it/s]

Generating class: 5:   0%|          | 0/50 [00:00<?, ?it/s]