# Load model and prepare data!

In [1]:
import logging
import os
import sys

import numpy as np
import torch

from rhofold.data.balstn import BLASTN
from rhofold.rhofold import RhoFold
from rhofold.config import rhofold_config
from rhofold.utils import get_device, save_ss2ct, timing
from rhofold.relax.relax import AmberRelaxation
from rhofold.utils.alphabet import get_features


def load_model(ckpt="./pretrained/RhoFold_pretrained.pt"):
    model = RhoFold(rhofold_config)
    model.load_state_dict(
        torch.load(ckpt, map_location=torch.device("cpu"))["model"], strict=False
    )
    model.eval()

    return model

device = "cuda"
model = load_model()
model.to(device)

Please either pass the dim explicitly or simply use torch.linalg.cross.
The default value of dim will change to agree with that of linalg.cross in a future release. (Triggered internally at /opt/conda/conda-bld/pytorch_1711403590347/work/aten/src/ATen/native/Cross.cpp:63.)
  e3 = torch.cross(e1, e2)


RhoFold(
  (msa_embedder): MSAEmbedder(
    (msa_emb): MSANet(
      (embed_tokens): Embedding(17, 256, padding_idx=1)
      (embed_positions): LearnedPositionalEmbedding(4098, 256, padding_idx=1)
    )
    (pair_emb): PairNet(
      (pair_emb): PairEmbNet(
        (emb): Embedding(17, 64)
        (projection): Linear(in_features=128, out_features=128, bias=True)
        (pos): PositionalEncoding2D(
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (rna_fm): ProteinBertModel(
      (embed_tokens): Embedding(25, 640, padding_idx=1)
      (layers): ModuleList(
        (0-11): 12 x TransformerLayer(
          (self_attn): MultiheadAttention(
            (k_proj): Linear(in_features=640, out_features=640, bias=True)
            (v_proj): Linear(in_features=640, out_features=640, bias=True)
            (q_proj): Linear(in_features=640, out_features=640, bias=True)
            (out_proj): Linear(in_features=640, out_features=640, bias=True)
          )
          (se

# Prepare data

In [5]:
data_folder = "/srv/mingyang/dev/personal/rna-fold/data/"
train_csv = data_folder + "train_sequences.csv"
val_csv = data_folder + "validation_sequences.csv"
test_csv = data_folder + "test_sequences.csv"
import pandas as pd
id_to_seq = dict()
def prepare_csv(csv_path):
    df = pd.read_csv(csv_path)
    for idx, row in df.iterrows():
        seq_id = row["target_id"]
        print(seq_id)
        seq = row['sequence']
        # create a file {data_folder}/MSA/{seq_id}.fasta, write the sequence to the file
        with open(f"{data_folder}/MSA/{seq_id}.fasta", "w") as f:
            f.write(f">{seq_id}\n{seq}")
        id_to_seq[seq_id] = seq

def split_to_csv(csv_path, split="train"): # train, val, test
    if split == "train":
        csv_path = train_csv
    elif split == "val":
        csv_path = val_csv
    elif split == "test":
        csv_path = test_csv
    return csv_path

def all_seq_ids(split="train"): # train, val, test
    csv_path = split_to_csv(split)
    df = pd.read_csv(csv_path)
    return df["target_id"].tolist()

prepare_csv(train_csv)
prepare_csv(val_csv)
prepare_csv(test_csv)


1SCL_A
1RNK_A
1RHT_A
1HLX_A
1HMH_E
1RNG_A
1MME_D
1KAJ_A
1SLO_A
1BIV_A
1ANR_A
1ZIG_A
1ZIH_A
1ETF_A
1ZIF_A
1KPD_A
1IKD_A
1ZDI_S
1AFX_A
1EBQ_A
1EBR_A
1ULL_A
1KIS_B
1KIS_A
1ATO_A
1TLR_A
1VOP_A
1AQO_A
1ATV_A
1ATW_A
1UUU_A
1AUD_B
2U2A_A
1A4T_A
1A60_A
1A51_A
2A9L_A
1A1T_B
1A9N_Q
3PHP_A
2TPK_A
7MSF_S
5MSF_S
1LDZ_A
1ZDK_S
1BVJ_A
1B36_A
1HVU_I
1BGZ_A
2BJ2_B
2BJ2_A
28SP_A
1QFQ_A
17RA_A
1BAU_B
1BZ2_A
1CQ5_A
1QC8_A
484D_B
1D6K_B
1EIY_C
1ESH_A
1EXY_A
1D0U_A
1ESY_A
1EUQ_B
1F9L_A
1FFK_9
1EKZ_B
1F5U_B
1F6Z_A
1F6X_A
1F85_A
1F84_A
1FOQ_A
1FJE_A
1E4P_A
1FQZ_A
1G70_A
1E7K_D
1FYO_A
1FHK_A
1I3X_A
1I4C_A
1I46_A
1IBM_Y
1HWQ_A
1IK1_A
1E95_A
1HS2_A
1JO7_A
1IDV_A
1K5I_A
1K9W_A
1JTW_A
1K6G_A
1K4B_A
1K4A_A
1K6H_A
1JUR_A
1KKS_A
1JWC_A
1L1C_C
1KP7_A
1JOX_A
1K2G_A
1L1W_A
1LC6_A
1LS2_B
1KKA_A
1MFY_A
1M5L_A
1MFJ_A
1MFK_A
1N34_A
1JTJ_A
1MT4_A
1MNX_A
1NA2_A
1NC0_A
1N8X_A
1OQ0_A
1OW9_A
1OSW_A
1M82_A
1NYB_B
1PJY_A
1P6V_B
1P6V_D
1N66_A
1HS1_A
1HS8_A
1HS4_A
1HS3_A
1JZC_A
1QZC_C
1P5M_A
1R2W_C
1P5N_A
1QZC_B
1P5P_A
1QZA_B
1QZB_B

In [3]:
# read the csv file
import pandas as pd

from rhofold.utils.alphabet import get_features

def seq_id_to_data(seq_id):
    with torch.inference_mode():
        fasta_path = f"{data_folder}/MSA/{seq_id}.fasta"  # This should contain just your query sequence
        msa_path = f"{data_folder}/MSA/{seq_id}.MSA.fasta"  # This is your MSA file

        # Get features (optionally specify msa_depth, default is 128)
        features = get_features(
            fas_fpath=fasta_path,
            msa_fpath=msa_path,
            msa_depth=128,  # You can adjust this if you want to use more/fewer sequences from the MSA
        )
        print(features["seq"])
        print(features["tokens"].shape)
        print(features["rna_fm_tokens"].shape)

        device = "cuda"
        outputs = model(
            tokens=features["tokens"].to(device),
            rna_fm_tokens=features["rna_fm_tokens"].to(device),
            seq=features["seq"],
        )
        return outputs

train_list = all_seq_ids(train_csv)
print(train_list)
seq_id = train_list[0]
outputs = seq_id_to_data(seq_id)
print(len(outputs))
print(outputs[0].keys())

['1SCL_A', '1RNK_A', '1RHT_A', '1HLX_A', '1HMH_E', '1RNG_A', '1MME_D', '1KAJ_A', '1SLO_A', '1BIV_A', '1ANR_A', '1ZIG_A', '1ZIH_A', '1ETF_A', '1ZIF_A', '1KPD_A', '1IKD_A', '1ZDI_S', '1AFX_A', '1EBQ_A', '1EBR_A', '1ULL_A', '1KIS_B', '1KIS_A', '1ATO_A', '1TLR_A', '1VOP_A', '1AQO_A', '1ATV_A', '1ATW_A', '1UUU_A', '1AUD_B', '2U2A_A', '1A4T_A', '1A60_A', '1A51_A', '2A9L_A', '1A1T_B', '1A9N_Q', '3PHP_A', '2TPK_A', '7MSF_S', '5MSF_S', '1LDZ_A', '1ZDK_S', '1BVJ_A', '1B36_A', '1HVU_I', '1BGZ_A', '2BJ2_B', '2BJ2_A', '28SP_A', '1QFQ_A', '17RA_A', '1BAU_B', '1BZ2_A', '1CQ5_A', '1QC8_A', '484D_B', '1D6K_B', '1EIY_C', '1ESH_A', '1EXY_A', '1D0U_A', '1ESY_A', '1EUQ_B', '1F9L_A', '1FFK_9', '1EKZ_B', '1F5U_B', '1F6Z_A', '1F6X_A', '1F85_A', '1F84_A', '1FOQ_A', '1FJE_A', '1E4P_A', '1FQZ_A', '1G70_A', '1E7K_D', '1FYO_A', '1FHK_A', '1I3X_A', '1I4C_A', '1I46_A', '1IBM_Y', '1HWQ_A', '1IK1_A', '1E95_A', '1HS2_A', '1JO7_A', '1IDV_A', '1K5I_A', '1K9W_A', '1JTW_A', '1K6G_A', '1K4B_A', '1K4A_A', '1K6H_A', '1JUR_A',

# Example Evaluation

In [21]:
def kabsch_align(P, Q):
    # P, Q: [L, 3]
    P_mean = P.mean(dim=0)
    Q_mean = Q.mean(dim=0)
    P_centered = P - P_mean
    Q_centered = Q - Q_mean

    H = P_centered.T @ Q_centered
    U, S, Vt = torch.svd(H)
    R = Vt @ U.T
    if torch.det(R) < 0:
        Vt[-1, :] *= -1
        R = Vt @ U.T
    P_aligned = P_centered @ R
    return P_aligned, Q_centered


def tm_score(P, Q):
    # P: predicted coords, Q: true coords, both [L, 3]
    L = P.shape[0]
    d0 = 0.0
    if L < 12:
        d0 = 0.3
    elif L < 16:
        d0 = 0.4
    elif L < 20:
        d0 = 0.5
    elif L < 24:
        d0 = 0.6
    elif L < 30:
        d0 = 0.7
    else:
        d0 = 0.6 * (L - 0.5) ** 0.5 - 2.5

    P_aligned, Q_centered = kabsch_align(P, Q)
    dist = torch.norm(P_aligned - Q_centered, dim=1)
    score = (1 / (1 + (dist / d0) ** 2)).mean()
    return score.item()

In [24]:
train_labels = "/srv/mingyang/dev/personal/rna-fold/data/train_labels.csv"
val_labels = "/srv/mingyang/dev/personal/rna-fold/data/validation_labels.csv"
id_to_loc = dict()
for csv in [train_labels, val_labels]:
    df = pd.read_csv(csv)
    for idx, row in df.iterrows():
        id = row["ID"]
        id = "_".join(id.split("_")[:-1])
        x_1, y_1, z_1 = row["x_1"], row["y_1"], row["z_1"]
        if id not in id_to_loc:
            id_to_loc[id] = [(x_1, y_1, z_1)]
        else:
            id_to_loc[id].append((x_1, y_1, z_1))

import torch
for idx in id_to_loc.keys():
    id_to_loc[idx] = torch.tensor(id_to_loc[idx], dtype=torch.float32, device=device)




844


In [31]:
with torch.inference_mode():
    seq_id = train_list[3]
    outputs = seq_id_to_data(seq_id)
    print(len(outputs))
    print(outputs[0].keys())
    for otp in outputs:
        preds = otp["cords_c1'"][0][0]
        score = tm_score(preds, id_to_loc[seq_id])
        print(score)

GGGAUAACUUCGGUUGUCCC
torch.Size([1, 11, 20])
torch.Size([1, 20])
10
dict_keys(['frames', 'unnormalized_angles', 'angles', 'single', 'cord_tns_pred', "cords_c1'", 'plddt', 'ss', 'p', 'c4_', 'n'])
0.0027740008663386106
0.003528591711074114
0.009859115816652775
0.009748130105435848
0.00966308917850256
0.009605878964066505
0.009564719162881374
0.009551241993904114
0.009551722556352615
0.009569044224917889


In [None]:
test_list = all_seq_ids(test_csv)
print(len(test_list))