# Filtering RFD2 outputs using PyRosetta and Biotite

You will find that many (most) of the backbones RF-diffusion 2 outputs is not suitable for sequence design. Common problems we see with the generated structures include:

**Chain breaks**\
**Insufficient interactions with ligand**\
**Unrealistic sidechains**\
**Unrealistically long helices**


In [1]:
import os, sys
import glob
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install pyrosettacolabsetup py3dmol


Collecting pyrosettacolabsetup
  Downloading pyrosettacolabsetup-1.0.9-py3-none-any.whl.metadata (294 bytes)
Collecting py3dmol
  Downloading py3dmol-2.5.3-py2.py3-none-any.whl.metadata (2.1 kB)
Downloading pyrosettacolabsetup-1.0.9-py3-none-any.whl (4.9 kB)
Downloading py3dmol-2.5.3-py2.py3-none-any.whl (7.2 kB)
Installing collected packages: pyrosettacolabsetup, py3dmol
Successfully installed py3dmol-2.5.3 pyrosettacolabsetup-1.0.9


In [None]:
import pyrosettacolabsetup; pyrosettacolabsetup.install_pyrosetta()

Mounted at /content/google_drive

Note that USE OF PyRosetta FOR COMMERCIAL PURPOSES REQUIRE PURCHASE OF A LICENSE.
See https://github.com/RosettaCommons/rosetta/blob/main/LICENSE.md or email license@uw.edu for details.

Looking for compatible PyRosetta wheel file at google-drive/PyRosetta/colab.bin//wheels...
Found compatible wheel: /content/google_drive/MyDrive/PyRosetta/colab.bin/wheels//content/google_drive/MyDrive/PyRosetta/colab.bin/wheels/pyrosetta-2025.43+release.b230e431d8-cp312-cp312-linux_x86_64.whl




In [None]:
!pip install biotite

Collecting biotite
  Downloading biotite-1.5.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (5.5 kB)
Collecting biotraj<2.0,>=1.0 (from biotite)
  Downloading biotraj-1.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (32 kB)
Downloading biotite-1.5.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (57.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.5/57.5 MB[0m [31m13.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading biotraj-1.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m75.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: biotraj, biotite
Successfully installed biotite-1.5.0 biotraj-1.2.2


In [None]:
sys.path.append('/content/drive/MyDrive/rfd2/')

In [None]:
!pip install -e /content/google_drive/MyDrive/rfd2/

Obtaining file:///content/google_drive/MyDrive/rfd2
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: rf_diffusion
  Building editable for rf_diffusion (pyproject.toml) ... [?25l[?25hdone
  Created wheel for rf_diffusion: filename=rf_diffusion-0.0.0-py3-none-any.whl size=2017 sha256=fea898fca5db3684de751709508a95803431fc3c3db9b1e8c53efe232436b8d0
  Stored in directory: /tmp/pip-ephem-wheel-cache-y6a6wxa8/wheels/c0/d9/a2/7a47dd5b76ced4f21b9cd8090e7b27b6aebbebd8fe0fb1df68
Successfully built rf_diffusion
Installing collected packages: rf_diffusion
Successfully installed rf_diffusion-0.0.0


In [14]:
os.chdir('/content/drive/MyDrive/clean_enzdes_rfd2_tutorial/') #make sure this is the correct path!
working_dir = '/content/drive/MyDrive/clean_enzdes_rfd2_tutorial'
rfd2_pdbdir = f'{working_dir}/examples_for_pyrosetta_filtering'
os.path.exists(rfd2_pdbdir)

OSError: [Errno 107] Transport endpoint is not connected: '/content/drive/MyDrive/clean_enzdes_rfd2_tutorial/'

In [None]:
import pickle
from pyrosetta import *
from pyrosetta.rosetta import *
import pyrosetta.rosetta.core.scoring.dssp as dssp
from pyrosetta.rosetta.core.scoring import score_type_from_name
from pyrosetta.rosetta.protocols.analysis import InterfaceAnalyzerMover
from pyrosetta.rosetta.core.scoring.sasa import SasaCalc

import biotite
import biotite.structure.io.pdb as Pdb
from collections import defaultdict, Counter
from scipy.spatial import cKDTree
import numpy as np
import pandas as pd
import math
import py3Dmol


In [None]:
#obtains where in the sequence your theozyme residues are
def get_input_residues(trb):
    with open(trb, 'rb') as f:
        data = pickle.load(f)
        resis = data['con_hal_pdb_idx']
        resis = [int(i[1]) for i in resis]
        return resis


In [None]:
#use pyrosetta to get ROG, secondary structure, and backbone+sidechain quality

#The function below
def get_rosetta_scores(pose, sfx, sfx_cart, catres_list):
    """
    Getting sidechain quality scores for a given pose.
    Inspired by Florence Hardy and Aiko Muraishi
    """
    sfx(pose)
    sfx_cart(pose)

    scoredict = {}
    for k in ["cart_bonded", "fa_dun"]:
        scoredict[k] = {
        res.seqpos(): pose.energies().residue_total_energies(res.seqpos()).get(score_type_from_name(k))
        for res in pose.residues if res.seqpos() in catres_list}

    averages = {k: np.average(list(scores.values())) for k,scores in scoredict.items()}

    return averages, scoredict

def get_ROG(pose):
    centroid = np.array([np.average([res.xyz("CA").__getattribute__(c) for res in pose.residues if res.is_protein()]) for c in "xyz"])
    ROG = max([np.linalg.norm(centroid - res.xyz("CA")) for res in pose.residues if res.is_protein()])
    return ROG


def res_on_loop(secstruct, seq, resis): #how many catalytic residues were placed on a loop?
    c = 0
    for resi in resis:
        if secstruct[resi] == 'L':
            c+=1
    return c

def longest_helix(secstruct): #sometimes rfd2 makes backbones with unrealistically long helices
    if 'H' in secstruct:
        longest_helix = max([len(x.replace("E", "")) for x in secstruct.split("L") if "H" in x])
    else:
        longest_helix = 0
    return longest_helix

def get_SASA(pose):
    #get ligand index
    ligand_indices = [i for i in range(1, pose.size()+1) if pose.residue(i).is_ligand()]
    #make ligand only pose
    ligand_residue = pose.residue(ligand_indices[0])
    ligand_pose = Pose()
    ligand_pose.append_residue_by_jump(ligand_residue, 1)

    #setup sasa calc
    sasa_calc = rosetta.core.scoring.sasa.SasaCalc()
    # typical probe radius 1.4 A
    sasa_calc.set_probe_radius(1.4)
    # perform calculations
    sasa_calc.calculate(pose)
    rsd_sasa_full = sasa_calc.get_residue_sasa()   # utility.vector1_double, 1-based indexing

    sasa_calc.calculate(ligand_pose)
    rsd_sasa_free = sasa_calc.get_residue_sasa()

    # sum ligand residues in full pose (non-protein residues)
    total_bound_sasa = 0.0
    for i in range(1, pose.total_residue() + 1):
        res = pose.residue(i)
        if not res.is_protein():
            # vector1_double is 1-based; convert via [i]
            try:
                total_bound_sasa += float(rsd_sasa_full[i])
            except Exception:
                # if value missing or -1, skip
                pass

    # sum free ligand sasa (all residues in pose_lig should be ligand residues)
    total_free_sasa = 0.0
    for i in range(1, ligand_pose.total_residue() + 1):
        try:
            total_free_sasa += float(rsd_sasa_free[i])
        except Exception:
            pass

        # safety: if total_free_sasa is zero or nan, return nan (same semantics as original)
    if total_free_sasa == 0 or np.isnan(total_free_sasa):
        return np.nan

    return float(total_bound_sasa / total_free_sasa)

In [None]:
#Using biotite to get chain breaks and clashes

BREAK_CUTOFF = 2.4
STANDARD_ATOM_NAMES = {
    "N", "CA", "C", "O", "CB",
    "CG", "CD", "CE", "NZ", "OG", "SG",
    "OD1", "OD2", "OE1", "OE2",
    "ND1", "ND2", "NE", "NE1", "NE2",
    "SD", "CE1", "CE2", "CE3",
    "CZ", "CZ2", "CZ3", "CH2",
    "NH1", "NH2", "OH", "OG1",
    "CD1", "CD2", "CG1", "CG2",
    "ND3", "OXT", "H", "HA", "HB", "HG", "HD", "HE", "HH"
}


def get_atom_array(pdb):
    pdb_file = Pdb.PDBFile.read(pdb)
    atom_array = Pdb.get_structure(pdb_file, model=1, altloc='occupancy')
    return atom_array

def clean_atom_array(atom_array):
    """For UNK residues keep only standard atom names; otherwise keep all atoms."""
    is_unk = atom_array.res_name == "UNK"
    is_standard_atom = np.isin(atom_array.atom_name, list(STANDARD_ATOM_NAMES))
    keep_mask = (~is_unk) | (is_unk & is_standard_atom)
    return atom_array[keep_mask]


def chain_break_check(atm_array):
    """Return True if there's any chain break (distance between C(i) and N(i+1) > BREAK_CUTOFF)."""
    is_bb = np.isin(atm_array.atom_name, ["N", "C"])
    is_prot = np.isin(atm_array.res_name, [
        "ALA","CYS","ASP","GLU","PHE","GLY","HIS",
        "ILE","LYS","LEU","MET","ASN","PRO","GLN",
        "ARG","SER","THR","VAL","TRP","TYR"
    ])
    bb_atoms = atm_array[is_bb & is_prot]

    backbone_map = defaultdict(dict)
    for atom in bb_atoms:
        key = (atom.chain_id, int(atom.res_id))
        backbone_map[key][atom.atom_name] = atom.coord

    # iterate by chain
    chains = sorted(set(ch for ch, res in backbone_map.keys()))
    for chain in chains:
        res_ids = sorted(res for ch, res in backbone_map.keys() if ch == chain)
        for i in range(len(res_ids) - 1):
            r1 = res_ids[i]
            r2 = res_ids[i+1]
            c_coord = backbone_map[(chain, r1)].get("C")
            n_coord = backbone_map[(chain, r2)].get("N")
            if c_coord is None or n_coord is None:
                return True
            if np.linalg.norm(c_coord - n_coord) > BREAK_CUTOFF:
                return True
    return False


In [None]:
#make an empty dictionary to store metrics from rfd2 outputs

dic = {'pdb path':[],
       'pdb': [],
       'fixed residues': [],
       'chain break': [],
       'num on loop': [],
       'rog' : [],
       'loop frac': [],
       'helix frac': [],
       'cart bonded': [],
       'fa dun': [],
       'worst cart': [],
       'worst fa': [],
       'longest helix': [],
       'ligand SASA':[]
      }

In [None]:
trbs = glob.glob(f'{rfd2_pdbdir}/*trb')
params_file = f'{working_dir}/theozymes/QA0.params'
os.path.exists(params_file)

True

In [None]:
pyrosetta.init(f'-extra_res_fa {params_file} -mute all')
sfx = pyrosetta.get_fa_scorefxn()
sfx_cart = sfx.clone()
sfx_cart.set_weight(score_type_from_name("cart_bonded"), 0.5)
sfx_cart.set_weight(score_type_from_name("pro_close"), 0.0)
for trb in trbs:
    pdb = trb.replace('.trb', '.pdb')
    dic['pdb path'].append(pdb)
    dic['pdb'].append(pdb.split('/')[-1])
    input_residues = get_input_residues(trb)
    dic['fixed residues'].append(" ".join(str(i) for i in input_residues))
    atm_array = get_atom_array(pdb)
    atm_array = clean_atom_array(atm_array)
    dic['chain break'].append(chain_break_check(atm_array))
    pose = pose_from_pdb(pdb)
    dic['rog'].append(get_ROG(pose))
    dssp = pyrosetta.rosetta.core.scoring.dssp.Dssp(pose)
    secstruct = dssp.get_dssp_secstruct()
    seq = pose.sequence()
    dic['num on loop'].append(res_on_loop(secstruct, seq, input_residues))
    dic['longest helix'].append(longest_helix(secstruct))
    dic['loop frac'].append(secstruct.count("L") / pose.size())
    dic['helix frac'].append(secstruct.count("H") / pose.size())
    averages, scoredict = get_rosetta_scores(pose, sfx, sfx_cart, input_residues)
    dic['cart bonded'].append(averages['cart_bonded'])
    dic['fa dun'].append(averages['fa_dun'])
    dic['worst cart'].append(max([scoredict['cart_bonded'][k] for k in scoredict['cart_bonded']]))
    dic['worst fa'].append(max([scoredict['fa_dun'][k] for k in scoredict['fa_dun']]))
    dic['ligand SASA'].append(get_SASA(pose))

┌───────────────────────────────────────────────────────────────────────────────┐
│                                  PyRosetta-4                                  │
│               Created in JHU by Sergey Lyskov and PyRosetta Team              │
│               (C) Copyright Rosetta Commons Member Institutions               │
│                                                                               │
│ NOTE: USE OF PyRosetta FOR COMMERCIAL PURPOSES REQUIRES PURCHASE OF A LICENSE │
│          See LICENSE.PyRosetta.md or email license@uw.edu for details         │
└───────────────────────────────────────────────────────────────────────────────┘
PyRosetta-4 2025 [Rosetta PyRosetta4.MinSizeRel.python312.ubuntu 2025.43+release.b230e431d8ef0bcdea01dbb0065ca62c7dd694ad 2025-10-15T17:07:03] retrieved from: http://www.pyrosetta.org




In [None]:
df = pd.DataFrame(dic)

In [None]:
i = 29
pdb_path = df.iloc[i]['pdb path']
with open(pdb_path, 'r') as f:
    pdb_data = f.read()
view = py3Dmol.view(width=600, height=400)
view.addModel(pdb_data, 'pdb')
view.setStyle({'cartoon': {'color': 'spectrum'}})
view.addStyle({'and': [
    {'resn': ['ALA','VAL','LEU','ILE','PHE','TYR','TRP','SER','THR','ASN','GLN','ASP','GLU','LYS','ARG','HIS','CYS','MET','PRO','GLY']}
]}, {'stick': {}})

# Sticks for ligands (anything that’s not a standard residue)
view.addStyle({'hetflag': True}, {'stick': {'colorscheme': 'greenCarbon'}})
view.addStyle({'atom': 'ZN'}, {'sphere': {'color': 'silver', 'radius': 1.0}})

print(f'pdb: {df.iloc[i]["pdb"]}')
print(f'SASA: {df.iloc[i]["ligand SASA"]}')
print(f'ROG: {df.iloc[i]["rog"]}')
print(f'loop: {df.iloc[i]["loop frac"]} \t helix: {df.iloc[i]["helix frac"]} \t sheet: {1-df.iloc[i]["loop frac"]-df.iloc[i]["helix frac"]}')

view.zoomTo()
view.show()

pdb: 1qji_theozyme_HEHHMY_ORI_91_LENGTH_140_160_tutorial_t1_1_2-atomized-bb-False.pdb
SASA: 0.7012632665395816
ROG: 22.15479600787047
loop: 0.1592356687898089 	 helix: 0.8407643312101911 	 sheet: 0.0


In [None]:
filtered = df.loc[df['chain break'] == False]
filtered = filtered.loc[filtered['num on loop'] <= 2]
filtered = filtered.loc[filtered['longest helix'] < 35]
filtered = filtered.loc[filtered['ligand SASA'] < 0.5]
filtered = filtered.loc[filtered['ligand SASA'] > 0.3]
filtered = filtered.loc[filtered['loop frac'] <= 0.2]

filtered.shape[0]

33

In [None]:
i = 2
pdb_path = filtered.iloc[i]['pdb path']
with open(pdb_path, 'r') as f:
    pdb_data = f.read()
view = py3Dmol.view(width=600, height=400)
view.addModel(pdb_data, 'pdb')
view.setStyle({'cartoon': {'color': 'spectrum'}})
view.addStyle({'and': [
    {'resn': ['ALA','VAL','LEU','ILE','PHE','TYR','TRP','SER','THR','ASN','GLN','ASP','GLU','LYS','ARG','HIS','CYS','MET','PRO','GLY']}
]}, {'stick': {}})

# Sticks for ligands (anything that’s not a standard residue)
view.addStyle({'hetflag': True}, {'stick': {'colorscheme': 'greenCarbon'}})
view.addStyle({'atom': 'ZN'}, {'sphere': {'color': 'silver', 'radius': 1.0}})

print(f'pdb: {filtered.iloc[i]["pdb"]}')
print(f'SASA: {filtered.iloc[i]["ligand SASA"]}')
print(f'ROG: {filtered.iloc[i]["rog"]}')
print(f'loop: {filtered.iloc[i]["loop frac"]} \t helix: {filtered.iloc[i]["helix frac"]} \t sheet: {1-filtered.iloc[i]["loop frac"]-filtered.iloc[i]["helix frac"]}')

view.zoomTo()
view.show()

pdb: 1qji_theozyme_HEHHMY_ORI_102_LENGTH_200_220_tutorial_t1_1_1-atomized-bb-False.pdb
SASA: 0.4951705295188125
ROG: 25.29791983200225
loop: 0.1650485436893204 	 helix: 0.7669902912621359 	 sheet: 0.06796116504854366


In [None]:

filtered.to_csv('')