In [1]:
import logging
import os
import sys

import numpy as np
import torch

from rhofold.data.balstn import BLASTN
from rhofold.rhofold import RhoFold
from rhofold.config import rhofold_config
from rhofold.utils import get_device, save_ss2ct, timing
from rhofold.relax.relax import AmberRelaxation
from rhofold.utils.alphabet import get_features

def seq_id_to_embedding(seq_id,
                        embeddings_dir="/home/user/rna-fold/data/RNA3D_DATA/evo2_embeddings",
                        map_location=None):
    """
    Load the precomputed Evo2 embedding stored for one sequence.

    Args:
        seq_id: identifier of the sequence; the embedding is read from
            ``<embeddings_dir>/<seq_id>.pt``.
        embeddings_dir: directory holding the per-sequence ``.pt`` files.
            Defaults to the location this notebook was originally written
            against; pass another directory to run on a different machine.
        map_location: forwarded to ``torch.load``. ``None`` keeps torch's
            default device handling; pass ``"cpu"`` to load embeddings that
            were saved from a GPU on a CPU-only machine.

    Returns:
        The object that was serialized in the file (the rest of the notebook
        treats it as a tensor).

    Raises:
        FileNotFoundError: if no embedding file exists for ``seq_id``.
    """
    file_path = os.path.join(embeddings_dir, f"{seq_id}.pt")

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Embedding file for {seq_id} not found at {file_path}")

    # NOTE: torch.load unpickles the file; only point this at embedding files you trust.
    return torch.load(file_path, map_location=map_location)

@torch.no_grad()
def main(ckpt='./pretrained/RhoFold_pretrained.pt'):
    """
    Build a RhoFold model, load pretrained weights into it and put it in eval mode.

    Args:
        ckpt: path to a checkpoint file whose ``'model'`` entry holds the
            state dict. The checkpoint is deserialized onto the CPU.

    Returns:
        The RhoFold model in eval mode.
    """
    model = RhoFold(rhofold_config)
    checkpoint = torch.load(ckpt, map_location=torch.device('cpu'))

    # strict=False lets the load succeed when the checkpoint and the model disagree
    # on parameter names. Report the mismatch instead of hiding it: parameters listed
    # as missing keep whatever values the constructor gave them.
    load_result = model.load_state_dict(checkpoint['model'], strict=False)
    log = logging.getLogger(__name__)
    if load_result.missing_keys:
        log.warning("Checkpoint %s has no weights for %d model keys: %s",
                    ckpt, len(load_result.missing_keys), load_result.missing_keys)
    if load_result.unexpected_keys:
        log.warning("Checkpoint %s has %d keys the model does not use: %s",
                    ckpt, len(load_result.unexpected_keys), load_result.unexpected_keys)

    model.eval()

    return model

def inference(seq_id='165d_B', normal_inf=False):
    """
    Run RhoFold on one sequence, export the predicted structure as a PDB file
    and return the model's final output.

    Args:
        seq_id: sequence identifier; selects ``../data/RNA3D_DATA/seq/<seq_id>.seq``,
            ``../data/RNA3D_DATA/rMSA/<seq_id>.a3m`` and, when Evo2 features are
            used, the embedding returned by ``seq_id_to_embedding(seq_id)``.
        normal_inf: if True, run the model with ``evo2_fea=None``; otherwise pass
            the Evo2 embedding (cast to float32) as ``evo2_fea``.

    Returns:
        The last element of the sequence returned by the model call.

    Side effects:
        Prints the parameter count and writes ``tmp/unrelaxed_model.pdb``
        (overwritten on every call). A fresh model is built from the checkpoint
        on every call.
    """
    device = get_device('cpu')
    model = main().to(device)
    print("Number of params:", sum(p.numel() for p in model.parameters()))
    input_fas = f'../data/RNA3D_DATA/seq/{seq_id}.seq'
    input_a3m = f'../data/RNA3D_DATA/rMSA/{seq_id}.a3m'
    data_dict = get_features(input_fas, input_a3m)

    # Only read the embedding when it is actually fed to the model, so a baseline
    # run does not fail for sequences that have no embedding file. Move it to the
    # same device as the model and the token tensors.
    if normal_inf:
        evo2_fea = None
    else:
        evo2_fea = seq_id_to_embedding(seq_id).to(device=device, dtype=torch.float32)

    outputs = model(tokens=data_dict['tokens'].to(device),
                    rna_fm_tokens=data_dict['rna_fm_tokens'].to(device),
                    seq=data_dict['seq'],
                    evo2_fea=evo2_fea,
                    )

    output = outputs[-1]

    unrelaxed_model = 'tmp/unrelaxed_model.pdb'
    # Make sure the output directory exists before the exporter writes into it.
    os.makedirs(os.path.dirname(unrelaxed_model), exist_ok=True)

    node_cords_pred = output['cord_tns_pred'][-1].squeeze(0)
    model.structure_module.converter.export_pdb_file(data_dict['seq'],
                                                        node_cords_pred.data.cpu().numpy(),
                                                        path=unrelaxed_model, chain_id=None,
                                                        confidence=output['plddt'][0].data.cpu().numpy())
    return output



In [2]:
import sys
# Make the sibling ``../openfold`` checkout importable: resolve it to an absolute
# path and put it at the front of sys.path, unless it is already listed there.
openfold_path = os.path.abspath(os.path.join("..", "openfold"))
if sys.path.count(openfold_path) == 0:
    sys.path.insert(0, openfold_path)
from typing import Optional, Union, Dict, Tuple

# PDB parser
from Bio.PDB import PDBParser

# openfold utilities
from openfold.utils.rigid_utils import Rigid, Rotation
from openfold.utils.loss import compute_fape, AlphaFoldLoss
import ml_collections

def compute_fape_from_output(
    output: Dict[str, torch.Tensor],
    pdb_path: str,
    chain_id: Optional[str] = None,
    length_scale: float = 10.0,
    l1_clamp_distance: Optional[float] = None,
) -> torch.Tensor:
    """
    Compute a single-scalar FAPE loss between RhoFold model output and a ground-truth RNA PDB.

    Args:
        output: RhoFold model output dict containing 'frames' and "cords_c1'" keys.
        pdb_path: path to the ground-truth RNA PDB file.
        chain_id: which chain of the first model in the PDB to use; if None,
            picks the first chain the parser yields.
        length_scale: FAPE length-scale divisor, forwarded to ``compute_fape``.
        l1_clamp_distance: optional clamp distance, forwarded to ``compute_fape``.

    Returns:
        The value returned by ``compute_fape`` for the predicted frames/C1' atoms
        against the ground-truth C1' atoms.

    Raises:
        AssertionError: if the number of residues with a C1' atom in the chosen
            chain differs from the number of predicted C1' positions.
    """
    # 1) Extract predicted frames
    pred_frames_tensor = output["frames"]

    # Take last element if the frames come as a list
    if isinstance(pred_frames_tensor, list):
        pred_frames_tensor = pred_frames_tensor[-1]

    # A 4-D tensor carries an extra leading axis; keep only its last entry
    if pred_frames_tensor.dim() == 4:  # [recycle, batch, N, 7]
        pred_frames_tensor = pred_frames_tensor[-1]  # [batch, N, 7]

    # Remove batch dim if batch=1
    if pred_frames_tensor.dim() == 3 and pred_frames_tensor.shape[0] == 1:
        pred_frames_tensor = pred_frames_tensor.squeeze(0)  # [N, 7]

    # Convert to Rigid object
    pred_frames = Rigid.from_tensor_7(pred_frames_tensor)

    # Extract predicted C1' positions
    pred_c1_positions = output["cords_c1'"][-1].squeeze(0)  # Shape [N, 3]

    # 2) Parse the PDB to get ground truth C1' positions
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("gt", pdb_path)
    if chain_id is None:
        chain = next(structure.get_chains())
    else:
        chain = structure[0][chain_id]

    # Residues without a C1' atom are skipped, so the count check below is what
    # guards the residue-by-residue alignment.
    gt_coords = [res["C1'"].get_coord() for res in chain if "C1'" in res]

    n_residues = pred_c1_positions.shape[0]
    assert len(gt_coords) == n_residues, (
        f"{pdb_path}: found {len(gt_coords)} residues with a C1' atom, "
        f"but the prediction has {n_residues}"
    )

    # Convert to a tensor with the same dtype and device as the predictions.
    # Going through one contiguous array avoids the slow element-wise conversion
    # of a list of arrays; reshape keeps the [N, 3] layout for an empty list too.
    dtype = pred_frames_tensor.dtype
    device = pred_frames_tensor.device
    gt_c1_positions = torch.as_tensor(
        np.asarray(gt_coords).reshape(-1, 3), dtype=dtype, device=device
    )

    # Every position is valid: the lengths were checked to match above.
    mask = torch.ones(n_residues, device=device, dtype=dtype)

    # 3) Create target frames
    # Identity quaternions [w, x, y, z] = [1, 0, 0, 0], with the ground-truth
    # C1' positions as translations.
    # NOTE(review): these target frames carry no ground-truth orientation, while
    # pred_frames carry the model's rotations. FAPE compares positions expressed in
    # each frame's local coordinates, so this value presumably depends on how the
    # prediction happens to be oriented relative to the PDB — confirm whether the
    # target frames should instead be built from several atoms per residue.
    zeros = torch.zeros(n_residues, 3, device=device, dtype=dtype)
    ones = torch.ones(n_residues, 1, device=device, dtype=dtype)
    quats = torch.cat([ones, zeros], dim=-1)

    target_frames = Rigid(
        Rotation(quats=quats, rot_mats=None),
        gt_c1_positions
    )

    # 4) Compute FAPE
    fape = compute_fape(
        pred_frames=pred_frames,
        target_frames=target_frames,
        frames_mask=mask,
        pred_positions=pred_c1_positions,
        target_positions=gt_c1_positions,
        positions_mask=mask,
        length_scale=length_scale,
        pair_mask=None,
        l1_clamp_distance=l1_clamp_distance,
    )

    return fape

In [4]:
# Predict the same sequence twice: once with evo2_fea=None (normal_inf=True) and
# once with the Evo2 embedding passed to the model (normal_inf=False).
seq_id = '8psh_B'
normal_output = inference(seq_id, normal_inf=True)
special_output = inference(seq_id, normal_inf=False)

Number of params: 130059855
Number of params: 130059855


In [5]:
# Size of axis 1 of the first pLDDT entry of the baseline run.
print("Sequence length:", normal_output["plddt"][0].size(1))

Sequence length: 8


In [6]:
# Score both predictions against the same ground-truth PDB of the sequence.
normal_L = compute_fape_from_output(normal_output, pdb_path=f"../data/RNA3D_DATA/pdb/{seq_id}.pdb")
special_L = compute_fape_from_output(special_output, pdb_path=f"../data/RNA3D_DATA/pdb/{seq_id}.pdb")



In [7]:
# FAPE of the Evo2-conditioned run first, then of the baseline run, one per line.
print(special_L, normal_L, sep="\n")

tensor(2.9983, grad_fn=<DivBackward0>)
tensor(2.9983, grad_fn=<DivBackward0>)


In [8]:
# Second pLDDT entry of each run (Evo2-conditioned first, baseline second), one per line.
print(special_output["plddt"][1], normal_output["plddt"][1], sep="\n")

tensor([0.6061], grad_fn=<MeanBackward1>)
tensor([0.6061], grad_fn=<MeanBackward1>)


In [3]:
# Names of the entries in the model output dict (iterating a dict yields its keys).
list(normal_output)

['frames',
 'unnormalized_angles',
 'angles',
 'single',
 'cord_tns_pred',
 "cords_c1'",
 'plddt',
 'ss',
 'p',
 'c4_',
 'n']

In [9]:
# Backpropagate the FAPE loss of the Evo2-conditioned run through its graph.
# NOTE(review): no visible cell reads gradients afterwards, and the model built
# inside inference() is not returned — confirm what this call is meant to check.
special_L.backward()

In [10]:
# Frames at index 0 of the leading axis for the baseline run (normal_inf=True).
# NOTE(review): compute_fape_from_output scores the last index of a 4-D frames
# tensor, not index 0 — confirm which slice this comparison should inspect.
normal_output["frames"][0]

tensor([[[  0.4549,   0.1827,   0.3823,  -0.7833,   4.3325,  10.6797,  -0.9257],
         [  0.6294,   0.4387,   0.3185,  -0.5568,   2.1081,   5.0937,   3.0105],
         [  0.7294,   0.5880,   0.2750,  -0.2158,  -0.5881,  -1.5042,   4.5593],
         [  0.7802,   0.5562,   0.2772,   0.0714,  -0.3522,  -7.9385,   2.5301],
         [  0.8623,   0.2562,   0.2818,   0.3339,  -0.7047, -12.9255,   1.7130],
         [  0.6348,  -0.6304,  -0.4455,  -0.0342,  -0.9307,  -7.9234,  -5.4222],
         [  0.6770,  -0.7215,  -0.0331,   0.1413,  -5.5291,  -2.1601,  -2.5444],
         [  0.5620,  -0.7495,   0.2591,   0.2351,  -6.6236,   6.9702,   0.2256],
         [  0.5077,   0.5388,  -0.2539,   0.6225,  -0.9057,   5.7893,  -5.3306]]],
       grad_fn=<SelectBackward0>)

In [11]:
# Frames at index 0 of the leading axis for the Evo2-conditioned run (normal_inf=False).
# NOTE(review): compute_fape_from_output scores the last index of a 4-D frames
# tensor, not index 0 — confirm which slice this comparison should inspect.
special_output["frames"][0]

tensor([[[  0.4549,   0.1827,   0.3823,  -0.7833,   4.3326,  10.6789,  -0.9256],
         [  0.6294,   0.4387,   0.3185,  -0.5568,   2.1084,   5.0929,   3.0104],
         [  0.7294,   0.5880,   0.2750,  -0.2158,  -0.5878,  -1.5046,   4.5585],
         [  0.7802,   0.5562,   0.2772,   0.0713,  -0.3519,  -7.9391,   2.5294],
         [  0.8623,   0.2562,   0.2818,   0.3338,  -0.7046, -12.9258,   1.7119],
         [  0.6348,  -0.6304,  -0.4455,  -0.0342,  -0.9306,  -7.9234,  -5.4226],
         [  0.6770,  -0.7215,  -0.0331,   0.1413,  -5.5289,  -2.1602,  -2.5448],
         [  0.5620,  -0.7495,   0.2591,   0.2351,  -6.6234,   6.9700,   0.2253],
         [  0.5077,   0.5388,  -0.2539,   0.6225,  -0.9059,   5.7885,  -5.3302]]],
       grad_fn=<SelectBackward0>)

In [11]:
# `output` exists only as a local inside inference(); no visible cell defines it at
# notebook level, so the original line raises NameError on a fresh kernel.
# NOTE(review): the baseline run is used here — switch to special_output if the
# Evo2-conditioned prediction was the one meant to be inspected.
normal_output["cords_c1'"]

[tensor([[[  7.5441,   8.0379,   3.7535],
          [  3.7557,   4.4225,   5.5446],
          [  1.0021,  -0.6748,   5.1964],
          [  0.7398,  -5.7531,   3.2155],
          [  1.4135, -10.5815,   2.6508],
          [  0.2612,  -6.4298,  -3.9785],
          [ -2.8647,  -0.6457,  -1.4059],
          [ -2.6314,   4.8960,   1.4187],
          [ -0.1994,   7.6226,   1.2691]]], grad_fn=<UnsqueezeBackward0>)]

In [19]:
# `output` exists only as a local inside inference(); no visible cell defines it at
# notebook level, so the original line raises NameError on a fresh kernel.
# NOTE(review): the baseline run is used here — switch to special_output if needed.
normal_output["frames"].shape

torch.Size([8, 1, 9, 7])

In [14]:
# `output` exists only as a local inside inference(); no visible cell defines it at
# notebook level, so the original line raises NameError on a fresh kernel.
# NOTE(review): this cell repeats the shape check of the cell before it and can
# probably be removed; the baseline run is used here.
normal_output['frames'].shape

torch.Size([8, 1, 9, 7])