# Stanford RNA 3D Folding Part 2 - RhoFold+ Inference

This notebook predicts RNA 3D structures using RhoFold+ and generates the submission file.

**Submission Format:**
- 5 predictions per target (best-of-5 TM-score used)
- C1' atom coordinates only
- Columns: ID, resname, resid, x_1,y_1,z_1 ... x_5,y_5,z_5

In [None]:
# Install RhoFold+ dependencies
# `-q` keeps pip's output quiet. No versions are pinned here, so the packages
# that get installed depend on what the package index serves at run time.
!pip install -q einops biopython ml-collections dm-tree tqdm pyyaml scipy matplotlib pandas transformers

In [None]:
# Clone and setup RhoFold+
# `%cd RhoFold` changes the kernel's working directory for the rest of the
# session, so the relative paths used by later cells (`pretrained/...`,
# `rhofold/relax/relax.py`) resolve inside the cloned repository.
# NOTE(review): re-running this cell without restarting the kernel would run
# the clone from inside the repository and `cd` one level deeper - confirm
# this cell is only executed once per session.
!git clone https://github.com/ml4bio/RhoFold.git
%cd RhoFold
# Editable install (`-e .`) of the cloned repository into the environment.
!pip install -q -e .

In [None]:
# Download pretrained weights
!mkdir -p pretrained
!wget -q https://huggingface.co/cuhkaih/rhofold/resolve/main/rhofold_pretrained_params.pt -O pretrained/RhoFold_pretrained.pt
print("Pretrained weights downloaded")

In [None]:
import os
import torch
import pandas as pd
import numpy as np
from pathlib import Path
from tqdm import tqdm

# Check device: use the GPU when one is visible to torch, otherwise the CPU.
has_cuda = torch.cuda.is_available()
device = 'cuda' if has_cuda else 'cpu'
print(f"Using device: {device}")

# Report the name and total memory of GPU 0 when a GPU is present.
if has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Load competition data: the test sequences to fold and the sample submission
# that a later cell uses as the row template for the output file.
test_seqs = pd.read_csv('/kaggle/input/stanford-rna-3d-folding-2/test_sequences.csv')
sample_sub = pd.read_csv('/kaggle/input/stanford-rna-3d-folding-2/sample_submission.csv')

# Quick sanity summary: row counts of both tables and a preview of the sequences.
print(
    f"Test sequences: {len(test_seqs)}",
    f"Sample submission rows: {len(sample_sub)}",
    test_seqs.head(),
    sep='\n',
)

In [None]:
# Fix OpenMM import for newer versions
import sys
# Exact legacy import block expected in relax.py, and its replacement, which
# falls back to the top-level `openmm` package when `simtk` is unavailable.
SIMTK_IMPORT_BLOCK = (
    'from simtk.openmm.app import *\n'
    'from simtk.openmm import *\n'
    'from simtk.unit import *\n'
    'import simtk.openmm as mm'
)
OPENMM_COMPAT_IMPORT_BLOCK = '''try:
    from simtk.openmm.app import *
    from simtk.openmm import *
    from simtk.unit import *
    import simtk.openmm as mm
except ImportError:
    from openmm.app import *
    from openmm import *
    from openmm.unit import *
    import openmm as mm'''


def patch_openmm_imports(relax_path):
    """Rewrite the legacy ``simtk`` import block in ``relax_path`` in place.

    Args:
        relax_path: Path (or string) of the Python file to patch.

    Returns:
        True if the file was modified, False otherwise (file missing, no
        ``simtk`` imports, or the imports are not in the exact expected form,
        e.g. because the file was already patched).
    """
    relax_path = Path(relax_path)
    if not relax_path.exists():
        return False

    content = relax_path.read_text()
    if SIMTK_IMPORT_BLOCK not in content:
        # Only claim success when the replacement really applies; the file is
        # left untouched so re-running the cell is a harmless no-op.
        if 'from simtk.openmm' in content:
            print("OpenMM imports not in the expected form (already patched?); file left unchanged")
        return False

    relax_path.write_text(content.replace(SIMTK_IMPORT_BLOCK, OPENMM_COMPAT_IMPORT_BLOCK))
    print("Fixed OpenMM imports")
    return True


# Relative path: resolves against the current working directory of the kernel.
rhofold_relax_path = Path('rhofold/relax/relax.py')
openmm_imports_patched = patch_openmm_imports(rhofold_relax_path)

In [None]:
# Import RhoFold
from rhofold.rhofold import RhoFold
from rhofold.config import rhofold_config
from rhofold.utils.alphabet import get_features

# Initialize model
# NOTE(review): confirm that `rhofold_config` is a callable factory; if the
# package exports it as a ready-made config object, calling it will fail.
config = rhofold_config()
model = RhoFold(config)

# Load weights: the checkpoint is deserialized on the CPU first and the
# parameters stored under its 'model' key are copied into the network before
# the network is moved to the selected device.
checkpoint_path = 'pretrained/RhoFold_pretrained.pt'
ckpt = torch.load(checkpoint_path, map_location='cpu')
pretrained_state = ckpt['model']
model.load_state_dict(pretrained_state)
model = model.to(device)
# Switch to evaluation mode for inference.
model.eval()

print("RhoFold+ model loaded")

In [None]:
def predict_structure(sequence, model, device, n_recycles=3):
    """Run one forward pass of ``model`` on ``sequence`` and return C1' coordinates.

    Args:
        sequence: RNA sequence passed to ``get_features``.
        model: Model invoked as ``model(features, n_recycles=...)``.
        device: Device that tensor-valued features are moved to.
        n_recycles: Forwarded unchanged to the model call.

    Returns:
        NumPy array obtained by taking atom index 1 of the last entry of
        ``output['positions']``.

    NOTE(review): this relies on ``get_features`` accepting a raw sequence
    string and on the model's forward accepting ``(features, n_recycles=...)``
    and returning a mapping with a 'positions' entry - confirm against the
    installed RhoFold version.
    """
    features = get_features(sequence)

    # Move tensor-valued entries to the target device in place; entries of any
    # other type are left untouched.
    tensor_keys = [key for key, value in features.items() if isinstance(value, torch.Tensor)]
    for key in tensor_keys:
        features[key] = features[key].to(device)

    # Inference only: no autograd bookkeeping.
    with torch.no_grad():
        output = model(features, n_recycles=n_recycles)

    # Per the original author's note the atoms are ordered C4', C1', N1/N9, C2,
    # C6 and the array is [seq_len, n_atoms, 3], which makes C1' index 1 on the
    # atom axis - TODO confirm against the model's output layout.
    final_positions = output['positions'][-1]
    coords = final_positions.cpu().numpy()
    return coords[:, 1, :]

def predict_with_seeds(sequence, model, device, seeds=(42, 123, 456, 789, 1234)):
    """Generate one prediction per seed (five with the default seeds).

    Args:
        sequence: RNA sequence forwarded to ``predict_structure``.
        model: Model forwarded to ``predict_structure``.
        device: Device forwarded to ``predict_structure``.
        seeds: Iterable of integer seeds; one prediction is attempted per seed.
            The default is an immutable tuple so it cannot be altered between
            calls.

    Returns:
        List with one entry per seed, in seed order: the coordinates returned
        by ``predict_structure``, or ``None`` if that attempt raised.

    NOTE(review): the seed only matters if the forward pass draws random
    numbers; if inference is deterministic all predictions will be identical -
    confirm.
    """
    predictions = []

    for seed in seeds:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

        # Best effort by design: a failing seed is reported and recorded as
        # None so the remaining seeds are still attempted.
        try:
            coords = predict_structure(sequence, model, device)
            predictions.append(coords)
        except Exception as e:
            print(f"  Seed {seed} failed: {e}")
            predictions.append(None)

    return predictions

In [None]:
# Run predictions for all targets
# Maps each target_id to a list of five entries (coordinate arrays or None).
all_predictions = {}

for idx, row in tqdm(test_seqs.iterrows(), total=len(test_seqs)):
    target_id = row['target_id']
    sequence = row['sequence']

    print(f"\n[{idx+1}/{len(test_seqs)}] {target_id} (len={len(sequence)})")

    if len(sequence) > 1500:
        # Skip very long sequences (>1500 nt) due to memory
        print(f"  Skipping - too long for GPU memory")
        preds = [None] * 5
    else:
        try:
            preds = predict_with_seeds(sequence, model, device)
            print(f"  Success: {sum(1 for p in preds if p is not None)}/5 predictions")
        except Exception as e:
            # Any unexpected failure leaves the target with placeholder entries.
            print(f"  Failed: {e}")
            preds = [None] * 5

    all_predictions[target_id] = preds

print(f"\nCompleted {len(all_predictions)} targets")

In [None]:
# Generate submission file
# Column layout: ID, resname, resid, then x/y/z for each of the five models.
# Passing it explicitly to the DataFrame keeps the layout fixed even when
# there are no rows.
submission_columns = ['ID', 'resname', 'resid'] + [
    f'{axis}_{k}' for k in range(1, 6) for axis in ('x', 'y', 'z')
]
submission_rows = []

for idx, row in sample_sub.iterrows():
    row_id = row['ID']
    resname = row['resname']
    resid = row['resid']

    # Parse target_id from ID (format: TARGET_RESIDUE)
    parts = row_id.rsplit('_', 1)
    target_id = parts[0]
    res_idx = int(parts[1]) - 1  # 0-indexed

    # Get predictions for this target (placeholders when the target is unknown)
    preds = all_predictions.get(target_id, [None] * 5)

    # Build row with coordinates from 5 models
    new_row = {
        'ID': row_id,
        'resname': resname,
        'resid': resid
    }

    for model_idx in range(5):
        pred = preds[model_idx] if model_idx < len(preds) else None

        # The lower bound matters: a residue number <= 0 would otherwise give a
        # negative index and silently pick a residue from the end of the array.
        if pred is not None and 0 <= res_idx < len(pred):
            # Convert to plain Python floats before rounding so the stored
            # values are 3-decimal doubles rather than NumPy scalars.
            x, y, z = (float(c) for c in pred[res_idx])
        else:
            # Missing prediction or out-of-range residue: zero placeholder.
            x, y, z = 0.0, 0.0, 0.0

        new_row[f'x_{model_idx+1}'] = round(x, 3)
        new_row[f'y_{model_idx+1}'] = round(y, 3)
        new_row[f'z_{model_idx+1}'] = round(z, 3)

    submission_rows.append(new_row)

submission = pd.DataFrame(submission_rows, columns=submission_columns)
print(f"Submission shape: {submission.shape}")
print(submission.head())

In [None]:
# Validate submission
# Expected layout: three identifier columns followed by x/y/z triplets for
# models 1 through 5, in that order.
coordinate_cols = [f'{axis}_{k}' for k in range(1, 6) for axis in ('x', 'y', 'z')]
expected_cols = ['ID', 'resname', 'resid'] + coordinate_cols

# Structural checks: exact column order, one row per sample-submission row,
# and no missing values anywhere in the table.
assert list(submission.columns) == expected_cols, "Column mismatch!"
assert len(submission) == len(sample_sub), f"Row count mismatch: {len(submission)} vs {len(sample_sub)}"
assert not submission.isnull().any().any(), "Contains NaN values!"

print("Validation passed!")
print(f"Rows: {len(submission)}")
# Counts rows whose first-model x coordinate is not the zero placeholder.
print(f"Non-zero predictions: {(submission['x_1'] != 0).sum()}")

In [None]:
# Save submission
# Written to the current working directory without the DataFrame index, so
# the file contains only the submission columns.
submission.to_csv('submission.csv', index=False)
print("Saved submission.csv")

# Verify file: show the first five lines (header plus leading rows) and the
# total line count of the written CSV.
!head -5 submission.csv
!wc -l submission.csv