In [1]:
import logging
import os
import sys

import numpy as np
import torch

from rhofold.data.balstn import BLASTN
from rhofold.rhofold import RhoFold
from rhofold.config import rhofold_config
from rhofold.utils import get_device, save_ss2ct, timing
from rhofold.relax.relax import AmberRelaxation
from rhofold.utils.alphabet import get_features


@torch.no_grad()
def main(ckpt="./pretrained/RhoFold_pretrained.pt"):
    model = RhoFold(rhofold_config)
    model.load_state_dict(
        torch.load(ckpt, map_location=torch.device("cpu"))["model"], strict=False
    )
    model.eval()

    return model



In [10]:
data_folder = "/srv/mingyang/dev/personal/rna-fold/data/"
input_csv = data_folder + "train_sequences.csv"
# read the csv file
import pandas as pd
df = pd.read_csv(input_csv)
print(df.head())
# compute largest temp cut_off
temp_cut_off = df["temporal_cutoff"].max()
print(temp_cut_off)

def id_to_msa_file(id):
    file = data_folder + "MSA/" + f'{id}.MSA.fasta'
    return file

id0 = df.iloc[0]["target_id"]
msa_file = id_to_msa_file(id0)
print(msa_file)



  target_id                            sequence temporal_cutoff  \
0    1SCL_A       GGGUGCUCAGUACGAGAGGAACCGCACCC      1995-01-26   
1    1RNK_A  GGCGCAGUGGGCUAGCGCCACUCAAAAGGCCCAU      1995-02-27   
2    1RHT_A            GGGACUGACGAUCACGCAGUCUAU      1995-06-03   
3    1HLX_A                GGGAUAACUUCGGUUGUCCC      1995-09-15   
4    1HMH_E  GGCGACCCUGAUGAGGCCGAAAGGCCGAAACCGU      1995-12-07   

                                         description  \
0               THE SARCIN-RICIN LOOP, A MODULAR RNA   
1  THE STRUCTURE OF AN RNA PSEUDOKNOT THAT CAUSES...   
2  24-MER RNA HAIRPIN COAT PROTEIN BINDING SITE F...   
3  P1 HELIX NUCLEIC ACIDS (DNA/RNA) RIBONUCLEIC ACID   
4  THREE-DIMENSIONAL STRUCTURE OF A HAMMERHEAD RI...   

                                       all_sequences  
0  >1SCL_1|Chain A|RNA SARCIN-RICIN LOOP|Rattus n...  
1  >1RNK_1|Chain A|RNA PSEUDOKNOT|null\nGGCGCAGUG...  
2  >1RHT_1|Chain A|RNA (5'-R(P*GP*GP*GP*AP*CP*UP*...  
3  >1HLX_1|Chain A|RNA (5'-R(*GP*GP*GP*A

In [None]:
fasta_path = f"{data_folder}/MSA/1SCL_A.fasta"  # This should contain just your query sequence
msa_path = f"{data_folder}/MSA/1SCL_A.MSA.fasta"  # This is your MSA file

# Get features (optionally specify msa_depth, default is 128)
features = get_features(
    fas_fpath=fasta_path,
    msa_fpath=msa_path,
    msa_depth=128,  # You can adjust this if you want to use more/fewer sequences from the MSA
)
print(features['seq'])
print(features['tokens'].shape)
print(features['rna_fm_tokens'].shape)

device = "cuda"
model = main().to(device)
outputs = model(tokens=features['tokens'].to(device),
                rna_fm_tokens=features['rna_fm_tokens'].to(device),
                seq=features['seq'],
                )



GGGUGCUCAGUACGAGAGGAACCGCACCC
torch.Size([1, 128, 29])
torch.Size([1, 29])


In [None]:
def kabsch_align(P, Q):
    # P, Q: [L, 3]
    P_mean = P.mean(dim=0)
    Q_mean = Q.mean(dim=0)
    P_centered = P - P_mean
    Q_centered = Q - Q_mean

    H = P_centered.T @ Q_centered
    U, S, Vt = torch.svd(H)
    R = Vt @ U.T
    if torch.det(R) < 0:
        Vt[-1, :] *= -1
        R = Vt @ U.T
    P_aligned = P_centered @ R
    return P_aligned, Q_centered


def tm_score(P, Q):
    # P: predicted coords, Q: true coords, both [L, 3]
    L = P.shape[0]
    d0 = 1.24 * ((L - 15) ** (1 / 3)) - 1.8
    d0 = max(d0, 0.5)  # clamp minimum

    P_aligned, Q_centered = kabsch_align(P, Q)
    dist = torch.norm(P_aligned - Q_centered, dim=1)
    score = (1 / (1 + (dist / d0) ** 2)).mean()
    return score.item()

In [35]:
g = outputs[0]["cords_c1'"]
f = outputs[1]["cords_c1'"]

print(len(f), len(g))
print(f[0].shape)


f = outputs[1]["single"]

1 1
torch.Size([1, 29, 3])


In [27]:
print(outputs[0].keys())
for k, elt in outputs[0].items():
    print(k, type(elt))
    if isinstance(elt, torch.Tensor):
        print(elt.shape)






dict_keys(['frames', 'unnormalized_angles', 'angles', 'single', 'cord_tns_pred', "cords_c1'", 'plddt', 'ss', 'p', 'c4_', 'n'])
frames <class 'torch.Tensor'>
torch.Size([8, 1, 29, 7])
unnormalized_angles <class 'torch.Tensor'>
torch.Size([8, 1, 29, 6, 2])
angles <class 'torch.Tensor'>
torch.Size([8, 1, 29, 6, 2])
single <class 'torch.Tensor'>
torch.Size([8, 1, 29, 384])
cord_tns_pred <class 'list'>
cords_c1' <class 'list'>
plddt <class 'tuple'>
ss <class 'torch.Tensor'>
torch.Size([1, 1, 29, 29])
p <class 'torch.Tensor'>
torch.Size([1, 40, 29, 29])
c4_ <class 'torch.Tensor'>
torch.Size([1, 40, 29, 29])
n <class 'torch.Tensor'>
torch.Size([1, 40, 29, 29])


In [3]:
import logging
import os
import sys

import numpy as np
import torch

from rhofold.data.balstn import BLASTN
from rhofold.rhofold import RhoFold
from rhofold.config import rhofold_config
from rhofold.utils import get_device, save_ss2ct, timing
from rhofold.relax.relax import AmberRelaxation
from rhofold.utils.alphabet import get_features

@torch.no_grad()
def main(ckpt='./pretrained/RhoFold_pretrained.pt'):
    model = RhoFold(rhofold_config)
    model.load_state_dict(torch.load(ckpt, map_location=torch.device('cpu'))['model'], strict=False)
    model.eval()

    return model

def inference(seq_id='165d_B'):
    device = get_device('cpu')
    model = main().to(device)
    print("Number of params:", sum(p.numel() for p in model.parameters()))
    input_fas = f'../data/RNA3D_DATA/seq/{seq_id}.seq'
    input_a3m = f'../data/RNA3D_DATA/rMSA/{seq_id}.a3m'
    data_dict = get_features(input_fas, input_a3m)
    embedding = seq_id_to_embedding(seq_id)

    outputs = model(tokens=data_dict['tokens'].to(device),
                    rna_fm_tokens=data_dict['rna_fm_tokens'].to(device),
                    seq=data_dict['seq'],
                    evo2_fea=embedding,
                    )

    output = outputs[-1]

    unrelaxed_model = f'tmp/unrelaxed_model.pdb'

    node_cords_pred = output['cord_tns_pred'][-1].squeeze(0)
    model.structure_module.converter.export_pdb_file(data_dict['seq'],
                                                        node_cords_pred.data.cpu().numpy(),
                                                        path=unrelaxed_model, chain_id=None,
                                                        confidence=output['plddt'][0].data.cpu().numpy())

In [4]:
inference()

Number of params: 130059855


FileNotFoundError: [Errno 2] No such file or directory: '../data/RNA3D_DATA/seq/165d_B.seq'