This notebook walks through how to design RNAs given a protein binder (sequence and backbone structure) and a desired RNA sequence length.

1. Set up the conda environment (details are given in the repository).
2. Define four input lists:
    - PROT_SEQS: [str, ...]
    - PROT_COORDS: [Tensor, ...]
    - RNA_LENS: [int, ...]
    - NUM_SAMPLES: [int, ...]

These lists must all have the same length, with one entry per design job. PROT_COORDS should contain tensors of shape N x 3 x 3 (N residues, 3 backbone atoms per residue, and x, y, z coordinates per atom). PROT_SEQS should contain strings of length N, matching the first dimension of the corresponding PROT_COORDS tensor. RNA_LENS is a list of integers giving the length of the RNA that the model should generate. NUM_SAMPLES is a list of integers, where each integer gives the number of designs to sample for the corresponding protein and RNA length. The next cell uses our dataloader to build a sample input; comment that cell out if you are supplying your own inputs.

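3. Run the remaining cells to generate the designs. Each design's RNA sequence is printed, and the designed protein-RNA complex (sequence and backbone structure) is saved as a PDB file in the job's `tmp_files/sample{idx}` folder.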

In [None]:
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
import sys
import os

sys.path.append("rnaflow")
from data.dataloader import RFDataModule
from models.rnaflow import RNAFlow
from models.inverse_folding import InverseFoldingModel
from utils import pdb_utils

In [None]:
# Build a single demonstration input from the test split.
# Comment out this cell if PROT_SEQS, PROT_COORDS, RNA_LENS and NUM_SAMPLES
# are already defined by you.
RF_DATA_FOLDER = "rnaflow/data/rf_data"
DATASET_PKL = "rnaflow/data/seq_sim_dataset.pickle"
data_module = RFDataModule(
    rf_data_folder=RF_DATA_FOLDER,
    dataset_pkl=DATASET_PKL,
    batch_size=1,
)
test_dataloader = data_module.test_dataloader()

PROT_SEQS, RNA_LENS, PROT_COORDS = [], [], []
NUM_SAMPLES = [5]

# Only the first batch is used. With batch_size=1 each field holds a single
# example, hence the [0] indexing below. If the dataloader is empty, the
# lists simply stay empty.
first_batch = next(iter(test_dataloader), None)
if first_batch is not None:
    example = first_batch[1]  # dict of per-example fields (prot_seq, rna_seq, prot_coords, ...)
    PROT_SEQS.append(example["prot_seq"][0])
    RNA_LENS.append(len(example["rna_seq"][0]))
    PROT_COORDS.append(example["prot_coords"][0])

In [None]:
def write_tmp_files(prot_seq, rna_len, tmp_folder_path="tmp_files/sample1"):
    """Write the placeholder sequence files for one design job into a folder.

    The same folder is later passed to ``rnaflow.design_rna``, which presumably
    reads these files (TODO confirm against ``RNAFlow.design_rna``).

    Files written (each a single FASTA-style record ``>name\\nSEQUENCE\\n``):
        prot.fa   -- the protein sequence.
        prot.a3m  -- the protein sequence as a single-sequence A3M (no homologs).
        rna.fa    -- a placeholder RNA of ``rna_len`` adenines; only the length
                     is known before design.
        rna.afa   -- the same placeholder RNA, as an aligned-FASTA file.

    Args:
        prot_seq: Protein sequence string.
        rna_len: Number of nucleotides in the RNA to design.
        tmp_folder_path: Destination folder. It is created (including parents)
            if it does not already exist; existing files are overwritten.
            NOTE: the default ``sample1`` does not match the ``sample{idx}``
            folders (starting at 0) used by the design loop. It is kept only
            for backward compatibility.
    """
    # Create the folder here so the function also works standalone, e.g. with
    # its own default path, instead of raising FileNotFoundError.
    os.makedirs(tmp_folder_path, exist_ok=True)

    rna_placeholder = "A" * rna_len
    records = (
        ("prot.fa", "prot", prot_seq),
        ("prot.a3m", "prot", prot_seq),
        ("rna.fa", "rna", rna_placeholder),
        ("rna.afa", "rna", rna_placeholder),
    )
    for filename, header, seq in records:
        with open(os.path.join(tmp_folder_path, filename), "w") as handle:
            handle.write(f">{header}\n{seq}\n")

In [None]:
# Load the trained RNAFlow model from its Lightning checkpoint file.
rnaflow = RNAFlow.load_from_checkpoint(checkpoint_path="checkpoints/seq-sim-rnaflow-epoch32.ckpt")

In [None]:
# Design RNAs for every (protein, RNA length) job and save each sampled complex.
# Job `idx` uses the scratch folder tmp_files/sample{idx}. The folder holds the
# placeholder inputs from write_tmp_files, is passed to rnaflow.design_rna, and
# receives one PDB per sampled design.
if not (len(PROT_SEQS) == len(PROT_COORDS) == len(RNA_LENS) == len(NUM_SAMPLES)):
    # zip() would silently drop the surplus entries, so fail loudly instead.
    raise ValueError(
        "PROT_SEQS, PROT_COORDS, RNA_LENS and NUM_SAMPLES must have the same length; got "
        f"{len(PROT_SEQS)}, {len(PROT_COORDS)}, {len(RNA_LENS)} and {len(NUM_SAMPLES)}"
    )

for idx, (prot_seq, prot_coords, rna_len, num_samples) in enumerate(
    zip(PROT_SEQS, PROT_COORDS, RNA_LENS, NUM_SAMPLES)
):
    sample_dir = f"tmp_files/sample{idx}"
    # makedirs also creates tmp_files/ when it does not exist yet
    # (os.mkdir would raise FileNotFoundError there).
    os.makedirs(sample_dir, exist_ok=True)
    # Always (re)write the inputs. A folder left over from an earlier run with
    # different inputs would otherwise feed stale files to the model.
    write_tmp_files(prot_seq, rna_len, tmp_folder_path=sample_dir)

    for sample_idx in range(num_samples):
        rna_seq_design, cplx_struct_design = rnaflow.design_rna(prot_seq, prot_coords, rna_len, sample_dir)
        # Include the sample index in the filename; a per-job name alone is
        # overwritten by every later sample, keeping only the last design.
        out_path = f"{sample_dir}/final_cplx_{idx}_{sample_idx}.pdb"
        pdb_utils.save_cplx_pdb(cplx_struct_design, prot_seq, rna_seq_design, out_path)
        print(rna_seq_design)