# CompareBinders: AlphaFold Binder Benchmark

This notebook compares a reference binder with a candidate binder design against the same target structure. It reuses the local BindCraft toolchain (ColabDesign plus optional PyRosetta analytics) to produce side-by-side metrics.


## Workflow
1. Configure the target structure plus the base and candidate binders (PDB or FASTA).
2. Extract clean amino-acid sequences and sanity-check lengths.
3. Run AlphaFold/ColabDesign binder predictions for each binder across the selected model heads.
4. Optionally evaluate PyRosetta interface energetics if the toolkit is available.
5. Aggregate per-model and summary statistics under `Results/CompareBinders` for downstream reporting.


In [9]:
import os
import copy
import functools
from pathlib import Path
from typing import Dict, Iterable, Optional, Tuple

import numpy as np
import pandas as pd
from IPython.display import display

try:
    import jax
    from jax import tree_util
except ImportError as exc:
    raise ImportError('JAX is required for ColabDesign-based evaluation. Install it inside the environment.') from exc

# Compatibility shim: newer JAX releases may not expose `jax.tree_map`, so alias
# it to `jax.tree_util.tree_map` when it is missing (presumably for dependencies
# such as ColabDesign that still call the old top-level name — confirm).
if not hasattr(jax, 'tree_map'):
    jax.tree_map = tree_util.tree_map  # type: ignore[attr-defined]

# Compatibility shim for `jax.util.wraps`: if the module imports but no longer
# provides `wraps`, install a replacement built on `functools.wraps`.
try:
    import jax.util as jax_util  # type: ignore
except ImportError:
    jax_util = None
if jax_util is not None and not hasattr(jax_util, 'wraps'):
    def _jax_util_wraps(wrapped, *, docstr=None,
                        assigned=("__module__", "__name__", "__qualname__", "__doc__", "__annotations__"),
                        updated=("__dict__",)):
        '''Return a decorator that copies metadata from `wrapped` onto the decorated function.

        Metadata is copied with `functools.wraps` using `assigned`/`updated`.
        When `docstr` is given, it replaces the resulting `__doc__` verbatim.
        '''
        # NOTE(review): the historical jax.util.wraps also accepted `namestr` and
        # treated `docstr` as a format template; this shim does neither —
        # confirm no caller relies on that behaviour.
        def decorator(func):
            result = functools.wraps(wrapped, assigned=assigned, updated=updated)(func)
            if docstr is not None:
                result.__doc__ = docstr
            return result
        return decorator
    jax_util.wraps = _jax_util_wraps  # type: ignore[attr-defined]

try:
    from colabdesign import mk_afdesign_model, clear_mem
except ImportError as exc:
    raise ImportError('ColabDesign is needed. Install it with `pip install colabdesign` in this environment.') from exc

try:
    from bindcraft.functions.pyrosetta_utils import score_interface
    import pyrosetta
    PYROSETTA_AVAILABLE = True
except Exception:
    score_interface = None
    pyrosetta = None
    PYROSETTA_AVAILABLE = False


In [10]:
# Locate the BindCraft checkout. The BINDCRAFT_ROOT environment variable, when
# set, overrides the default path used by this notebook.
root_dir = Path(os.environ.get('BINDCRAFT_ROOT', r'/mnt/e/Code/BindCraft')).resolve()
if not root_dir.exists():
    raise FileNotFoundError(f'Expected BindCraft root at {root_dir}')


def _first_existing_dir(base: Path, names: Iterable[str]) -> Path:
    '''Return the first existing ``base / name`` directory, creating the first candidate if none exist.

    Several spellings are probed (e.g. ``Results`` vs ``results``) because
    checkouts differ in how these folders are capitalised.

    Args:
        base: Directory that contains the candidates.
        names: Candidate folder names, in priority order (must be non-empty).

    Returns:
        The resolved path of the existing (or newly created) directory.
    '''
    names = list(names)
    if not names:
        raise ValueError('At least one candidate directory name is required.')
    for name in names:
        candidate = base / name
        if candidate.exists():
            return candidate.resolve()
    fallback = (base / names[0]).resolve()
    fallback.mkdir(parents=True, exist_ok=True)
    return fallback


input_dir = _first_existing_dir(root_dir, ('InputTargets', 'inputtargets'))
results_dir = _first_existing_dir(root_dir, ('Results', 'results'))

af_params_dir = (root_dir / 'bindcraft' / 'params').resolve()
if not af_params_dir.exists():
    raise FileNotFoundError(f'AlphaFold parameters not found at {af_params_dir}')
# setdefault: an AF_PARAMS_DIR already present in the environment wins.
os.environ.setdefault('AF_PARAMS_DIR', str(af_params_dir))

# All outputs of this notebook are written below Results/CompareBinders.
output_base = (results_dir / 'CompareBinders').resolve()
output_base.mkdir(parents=True, exist_ok=True)

# Path to the DAlphaBall binary; its existence is not checked here.
dalphaball_path = (root_dir / 'bindcraft' / 'functions' / 'DAlphaBall.gcc').resolve()


In [11]:
# --- User configuration ---

# Target structure that both binders are predicted against.
TARGET_PDB = input_dir / 'HumanLysozyme.pdb'
TARGET_CHAIN = 'B'
if TARGET_PDB:
    TARGET_LABEL = TARGET_PDB.stem
else:
    TARGET_LABEL = 'target'

# Reference ("base") binder. A FASTA path, when set and present, is preferred over the PDB.
BASE_BINDER_PDB = input_dir / 'HL6_camel_VHH_fragment.pdb'
BASE_BINDER_FASTA = None
BASE_BINDER_CHAINS = ['A']
BASE_BINDER_LABEL = 'HL6 camel VHH template'
BASE_BINDER_LEN = None  # An int forces this length (longer sequences are truncated); None uses the full sequence.

# Candidate binder design, configured the same way as the base binder.
CANDIDATE_BINDER_PDB = results_dir / 'MutaCraft' / 'HL6_VHH_run' / 'hybrid' / 'best_complex.pdb'
CANDIDATE_BINDER_FASTA = None
CANDIDATE_BINDER_CHAINS = ['B']
CANDIDATE_BINDER_LABEL = 'Mutacraft design'
CANDIDATE_BINDER_LEN = None

# AlphaFold / ColabDesign prediction settings.
EVAL_MODELS = [0]  # Indices of the AlphaFold model heads to run.
NUM_RECYCLES = 1
USE_MULTIMER = False
SAVE_PDBS = False  # Save predicted complexes; interface (PyRosetta) scoring runs only on saved PDBs.
PREDICTED_BINDER_CHAIN = 'B'  # Chain ID of the binder in the predicted complex PDBs.


In [12]:
# The 20 canonical amino acids, in alphabetical one-letter order.
AA = 'ACDEFGHIKLMNPQRSTVWY'
# One-letter code -> position in AA (A -> 0, ..., Y -> 19).
aa_to_idx = dict(zip(AA, range(len(AA))))

# Canonical three-letter residue names -> one-letter codes (standard residues only).
THREE_TO_ONE: Dict[str, str] = dict(
    zip(
        (
            'ALA', 'ARG', 'ASN', 'ASP', 'CYS',
            'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
            'LEU', 'LYS', 'MET', 'PHE', 'PRO',
            'SER', 'THR', 'TRP', 'TYR', 'VAL',
        ),
        'ARNDCQEGHILKMFPSTWYV',
    )
)

def read_fasta_str(fasta_str: str) -> str:
    '''Return the upper-cased residue string contained in FASTA-formatted text.

    Header lines (starting with ``>``) and legacy comment lines (starting with
    ``;``) are dropped wherever they appear, so multi-record input yields the
    records' sequences concatenated in file order instead of leaking header
    text into the sequence. Bare sequence text without a header is accepted.
    All whitespace (spaces, tabs, line breaks) is removed.

    Args:
        fasta_str: Raw FASTA text.

    Returns:
        The concatenated, upper-cased sequence ('' if there is none).
    '''
    residues = []
    for line in fasta_str.splitlines():
        stripped = line.strip()
        if not stripped or stripped.startswith(('>', ';')):
            continue
        # str.split() with no argument splits on any whitespace, so this also
        # removes internal spaces/tabs used for column formatting.
        residues.append(''.join(stripped.split()))
    return ''.join(residues).upper()

def extract_pdb_sequence(pdb_path: Path, chains: Optional[Iterable[str]] = None) -> str:
    '''Extract a concatenated one-letter sequence from the specified PDB chains.

    Residues are keyed by (chain, resSeq, iCode), so each residue contributes a
    single letter even though it spans many atom lines; the first line seen for
    a residue decides its identity. Within a chain residues are ordered by
    residue number, then insertion code. Chains are concatenated in the order
    given by ``chains``, or alphabetically when ``chains`` is None.

    ``ATOM`` residues with an unrecognised name become ``'X'``. ``HETATM``
    residues are kept only when they are a standard amino acid or a common
    modified one (e.g. selenomethionine ``MSE`` -> ``M``). Waters, ions and
    ligands are skipped so they do not inject ``'X'`` into the sequence.

    Args:
        pdb_path: PDB file to read.
        chains: Chain IDs to include, in output order; None means all chains.

    Returns:
        The concatenated sequence ('' if no matching residues were found).
    '''
    # Modified residues frequently deposited as HETATM, mapped to their parent amino acid.
    modified_to_one = {'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'PTR': 'Y', 'HYP': 'P'}

    def _parse_resseq(val: str) -> Tuple[int, str]:
        # Split a resSeq field into (number, trailing suffix) for numeric sorting;
        # a leading '-' is kept so negative residue numbers sort correctly.
        val = val.strip()
        number = []
        suffix = []
        for ch in val:
            if ch.isdigit() or (ch == '-' and not number):
                number.append(ch)
            else:
                suffix.append(ch)
        num_val = int(''.join(number)) if number else 0
        return num_val, ''.join(suffix)

    chain_order = list(chains) if chains is not None else []
    wanted_chains = set(chain_order)
    residues_by_chain: Dict[str, Dict[Tuple[str, str, str], str]] = {}
    with Path(pdb_path).open() as handle:
        for line in handle:
            is_het = line.startswith('HETATM')
            if not (is_het or line.startswith('ATOM')):
                continue
            chain_id = line[21] if len(line) > 21 else ' '
            chain_id = chain_id if chain_id.strip() else ' '
            if chains is not None and chain_id not in wanted_chains:
                continue
            resname = line[17:20].strip().upper()
            aa = THREE_TO_ONE.get(resname) or modified_to_one.get(resname)
            if aa is None:
                if is_het:
                    # Water, ion or ligand: not part of the polymer sequence.
                    continue
                aa = 'X'
            resseq = line[22:26]
            # Truncated lines may omit the insertion-code column.
            icode = line[26] if len(line) > 26 else ' '
            key = (chain_id, resseq, icode)
            # setdefault keeps the first line seen for each residue (e.g. altlocs).
            residues_by_chain.setdefault(chain_id, {}).setdefault(key, aa)
    if chains is None:
        chain_order = sorted(residues_by_chain)
    sequence_parts = []
    for chain_id in chain_order:
        residues = residues_by_chain.get(chain_id, {})
        for key in sorted(residues, key=lambda k: (*_parse_resseq(k[1]), k[2])):
            sequence_parts.append(residues[key])
    return ''.join(sequence_parts)

def load_binder_sequence(label: str, fasta_path: Optional[Path], pdb_path: Optional[Path], chains: Optional[Iterable[str]]) -> str:
    '''Return the binder's sequence, preferring a FASTA file over a PDB.

    Each source is tried in priority order and used only if its path is set,
    exists on disk and yields a non-empty sequence. PDB extraction is limited
    to ``chains``.

    Raises:
        ValueError: if neither source produces a sequence.
    '''
    # Readers are lambdas so a source is only parsed once it is reached.
    sources = (
        (fasta_path, lambda path: read_fasta_str(path.read_text())),
        (pdb_path, lambda path: extract_pdb_sequence(path, chains)),
    )
    for raw_path, reader in sources:
        if not raw_path:
            continue
        path = Path(raw_path)
        if not path.exists():
            continue
        recovered = reader(path)
        if recovered:
            return recovered
    raise ValueError(f'Could not recover a sequence for {label}. Provide a valid FASTA or PDB.')

def sanitize_label(label: str) -> str:
    '''Turn a free-form label into a filesystem-friendly name.

    Characters other than alphanumerics, '-' and '_' become '_'. Leading and
    trailing underscores are then trimmed. An empty result falls back to
    'binder'.
    '''
    kept_punctuation = {'-', '_'}
    pieces = []
    for char in label:
        if char.isalnum() or char in kept_punctuation:
            pieces.append(char)
        else:
            pieces.append('_')
    cleaned = ''.join(pieces).strip('_')
    if not cleaned:
        return 'binder'
    return cleaned


In [13]:
PYROSETTA_FLAGS = '-ignore_unrecognized_res -ignore_zero_occupancy -mute all'


def ensure_pyrosetta_ready() -> bool:
    '''Initialise PyRosetta once per session and report whether it is usable.

    Returns False when PyRosetta failed to import or ``pyrosetta.init`` raised;
    the failure is printed. The DAlphaBall flag is added only when the binary
    exists. A successful init is remembered on the function object, so later
    calls return True immediately.
    '''
    if not PYROSETTA_AVAILABLE:
        return False
    already_initialized = getattr(ensure_pyrosetta_ready, '_initialized', False)
    if already_initialized:
        return True
    init_flags = PYROSETTA_FLAGS
    if dalphaball_path.exists():
        init_flags = ' '.join((init_flags, f'-holes:dalphaball {dalphaball_path}'))
    try:
        pyrosetta.init(init_flags)
    except Exception as exc:
        print(f'PyRosetta init failed: {exc}')
        return False
    ensure_pyrosetta_ready._initialized = True
    return True

def compute_interface_metrics(pdb_path: Path, binder_chain: str = 'B') -> Dict[str, object]:
    '''Score the target/binder interface of ``pdb_path`` with BindCraft's PyRosetta helper.

    Returns an empty dict when PyRosetta cannot be initialised or the scoring
    helper failed to import. Otherwise every score returned by
    ``score_interface`` is included under an ``iface_`` prefix, followed by
    ``iface_residue_ids`` and ``iface_aa_counts``.
    '''
    can_score = ensure_pyrosetta_ready() and score_interface is not None
    if not can_score:
        return {}
    raw_scores, per_aa_counts, interface_ids = score_interface(str(pdb_path), binder_chain=binder_chain)
    result: Dict[str, object] = {f'iface_{name}': value for name, value in raw_scores.items()}
    result.update(iface_residue_ids=interface_ids, iface_aa_counts=per_aa_counts)
    return result

def _normalize_binder_sequence(binder_label: str, binder_sequence: str, binder_len: Optional[int]) -> Tuple[str, int]:
    '''Clean a raw binder sequence and reconcile it with the requested length.

    Non-letter characters are removed and the result is upper-cased. When
    ``binder_len`` is None the full cleaned length is used. A longer sequence is
    truncated to ``binder_len``, with a printed notice.

    Raises:
        ValueError: if the cleaned sequence is empty, ``binder_len`` is not
            positive, or the sequence is shorter than ``binder_len``.
    '''
    sequence = ''.join(aa for aa in binder_sequence.upper() if aa.isalpha())
    if not sequence:
        raise ValueError(f'Binder `{binder_label}` produced an empty sequence.')
    if binder_len is None:
        binder_len = len(sequence)
    if binder_len <= 0:
        raise ValueError(f'Requested length for {binder_label} must be positive (got {binder_len}).')
    if len(sequence) < binder_len:
        raise ValueError(f'Sequence for {binder_label} shorter than requested length ({len(sequence)} < {binder_len}).')
    if len(sequence) > binder_len:
        print(f'Truncating {binder_label} from {len(sequence)} to {binder_len} residues.')
        sequence = sequence[:binder_len]
    return sequence, binder_len


def _summarize_per_model(per_model_df: pd.DataFrame) -> Dict[str, float]:
    '''Return ``<metric>_mean`` / ``<metric>_std`` for each AF2 metric column.

    Standard deviations use ``ddof=0`` (population std), so a single model
    yields 0.0 rather than NaN.
    '''
    stats: Dict[str, float] = {}
    for metric in ('plddt', 'ptm', 'i_ptm', 'pae', 'i_pae', 'loss'):
        column = per_model_df[metric]
        stats[f'{metric}_mean'] = float(column.mean())
        stats[f'{metric}_std'] = float(column.std(ddof=0))
    return stats


def predict_binder_complex(
    binder_label: str,
    binder_sequence: str,
    binder_len: Optional[int],
    target_pdb: Path,
    target_chain: str,
    models: Iterable[int],
    num_recycles: int,
    use_multimer: bool,
    save_dir: Path,
    save_pdbs: bool,
    binder_chain: str,
) -> Tuple[pd.DataFrame, Dict[str, object]]:
    '''Predict the target/binder complex with ColabDesign for each requested model head.

    Args:
        binder_label: Human-readable name used in tables, messages and file names.
        binder_sequence: Raw binder sequence; non-letter characters are dropped.
        binder_len: Length to predict (longer sequences are truncated); None uses
            the full cleaned sequence.
        target_pdb: Target structure file.
        target_chain: Chain of ``target_pdb`` to use as the target.
        models: AlphaFold model indices to run, one prediction each.
        num_recycles: Recycles per prediction.
        use_multimer: Whether to load multimer parameters.
        save_dir: Directory for outputs (created if missing).
        save_pdbs: Save each predicted complex as a PDB in ``save_dir``.
        binder_chain: Binder chain ID in the saved PDBs, used for interface scoring.

    Returns:
        ``(per_model_df, summary)``. ``per_model_df`` has one row per model, with
        loss/pLDDT/pTM/ipTM/pAE/ipAE. ``summary`` holds the sequence, the saved
        PDB paths, per-metric mean/std and, when available, ``iface_*`` metrics.

    Raises:
        ValueError: for an empty sequence, a non-positive length, or a sequence
            shorter than ``binder_len``.
    '''
    sequence, binder_len = _normalize_binder_sequence(binder_label, binder_sequence, binder_len)
    save_dir.mkdir(parents=True, exist_ok=True)
    # Presumably releases device memory held by previously built models — see ColabDesign.
    clear_mem()
    model = mk_afdesign_model(protocol='binder', data_dir=str(af_params_dir), use_multimer=use_multimer, num_recycles=num_recycles, best_metric='loss')
    model.prep_inputs(pdb_filename=str(target_pdb), chain=target_chain, binder_len=binder_len)
    records = []
    pdb_outputs = []
    for model_idx in models:
        model_idx = int(model_idx)
        model.predict(seq=sequence, models=[model_idx], num_recycles=num_recycles, verbose=False)
        # Only scalars are read and converted to float immediately, so the
        # (large) aux dict does not need to be deep-copied.
        aux = model.aux
        metrics = aux.get('log', {})
        records.append({
            'binder': binder_label,
            'model_index': model_idx,
            'loss': float(aux.get('loss', np.nan)),
            'plddt': float(metrics.get('plddt', np.nan)),
            'ptm': float(metrics.get('ptm', np.nan)),
            'i_ptm': float(metrics.get('i_ptm', np.nan)),
            'pae': float(metrics.get('pae', np.nan)),
            'i_pae': float(metrics.get('i_pae', np.nan)),
        })
        if save_pdbs:
            # File names use 1-based model numbers.
            out_path = save_dir / f'{sanitize_label(binder_label)}_model{model_idx + 1}.pdb'
            model.save_pdb(str(out_path))
            pdb_outputs.append(out_path)
    per_model_df = pd.DataFrame(records)
    summary: Dict[str, object] = {
        'binder': binder_label,
        'binder_length': binder_len,
        'sequence': sequence,
        'prediction_pdbs': [str(p) for p in pdb_outputs],
    }
    if not per_model_df.empty:
        summary.update(_summarize_per_model(per_model_df))
    # Interface scoring needs a structure on disk, so it only runs when PDBs were
    # saved; it scores the first model's prediction only. Failures are reported
    # but do not abort the comparison.
    if pdb_outputs:
        try:
            iface_stats = compute_interface_metrics(pdb_outputs[0], binder_chain=binder_chain)
        except Exception as exc:
            print(f'Interface scoring failed for {binder_label}: {exc}')
        else:
            summary.update(iface_stats)
    return per_model_df, summary


In [14]:
# One spec per binder; every binder is predicted against the same target.
binder_specs = [
    {
        'label': BASE_BINDER_LABEL,
        'fasta': BASE_BINDER_FASTA,
        'pdb': BASE_BINDER_PDB,
        'chains': BASE_BINDER_CHAINS,
        'length': BASE_BINDER_LEN,
    },
    {
        'label': CANDIDATE_BINDER_LABEL,
        'fasta': CANDIDATE_BINDER_FASTA,
        'pdb': CANDIDATE_BINDER_PDB,
        'chains': CANDIDATE_BINDER_CHAINS,
        'length': CANDIDATE_BINDER_LEN,
    },
]

# Output directories are derived from sanitized labels; refuse to run if two
# binders would write into the same directory and overwrite each other.
_binder_dirnames = [sanitize_label(spec['label']) for spec in binder_specs]
if len(set(_binder_dirnames)) != len(_binder_dirnames):
    raise ValueError(f'Binder labels must map to distinct output directories, got {_binder_dirnames}.')

comparison_rows = []
per_model_tables = []
for spec in binder_specs:
    label = spec['label']
    sequence = load_binder_sequence(label, spec['fasta'], spec['pdb'], spec['chains'])
    length = spec['length'] if spec['length'] is not None else len(sequence)
    binder_dir = output_base / sanitize_label(label)
    per_model_df, summary = predict_binder_complex(
        binder_label=label,
        binder_sequence=sequence,
        binder_len=length,
        target_pdb=TARGET_PDB,
        target_chain=TARGET_CHAIN,
        models=EVAL_MODELS,
        num_recycles=NUM_RECYCLES,
        use_multimer=USE_MULTIMER,
        save_dir=binder_dir,
        save_pdbs=SAVE_PDBS,
        binder_chain=PREDICTED_BINDER_CHAIN,
    )
    if not per_model_df.empty:
        per_model_df.to_csv(binder_dir / 'per_model_metrics.csv', index=False)
    comparison_rows.append(summary)
    per_model_tables.append(per_model_df.assign(output_dir=str(binder_dir)))
comparison_df = pd.DataFrame(comparison_rows)
summary_path = output_base / f'{sanitize_label(TARGET_LABEL)}_comparison_summary.csv'
comparison_df.to_csv(summary_path, index=False)
print(f'Summary metrics saved to: {summary_path}')
if not PYROSETTA_AVAILABLE:
    print('PyRosetta not detected; interface metrics columns may be empty.')
elif not SAVE_PDBS:
    # Interface scoring runs on saved prediction PDBs, so it is skipped silently otherwise.
    print('SAVE_PDBS is False; interface metrics are computed only from saved predictions and were skipped.')
display(comparison_df)


Summary metrics saved to: /mnt/e/Code/BindCraft/Results/CompareBinders/HumanLysozyme_comparison_summary.csv
PyRosetta not detected; interface metrics columns may be empty.


Unnamed: 0,binder,binder_length,sequence,prediction_pdbs,plddt_mean,plddt_std,ptm_mean,ptm_std,i_ptm_mean,i_ptm_std,pae_mean,pae_std,i_pae_mean,i_pae_std,loss_mean,loss_std
0,HL6 camel VHH template,229,QVQLQESGGGSVQAGGSLRLSCSASGYTYISGWFRQAPGKEREGVA...,[],0.397544,0.0,0.439117,0.0,0.154836,0.0,0.766483,0.0,0.881567,0.0,6.450178,0.0
1,Mutacraft design,229,MPAQMQAQQKQQQEMQQQQQRRRQDEQRYKMERRREQQNMQQLQQE...,[],0.494419,0.0,0.44309,0.0,0.155191,0.0,0.761355,0.0,0.853633,0.0,5.961434,0.0


## Suggested Reporting Metrics
- **AlphaFold confidence (pLDDT / pTM / iPTM / pAE / iPAE):** Quantify structural certainty for both the binder core and the interface. The per-model table exposes variance so you can report robustness across heads.
- **Interface thermodynamics (Rosetta ΔG, ΔSASA, Packstat):** Capture binding energetics, burial, and packing if PyRosetta is available. Highlight improvements or trade-offs versus the seed binder.
- **Interface residue composition:** Track the count and identity of contacting residues (hydrophobic fraction, hydrogen-bond density, buried unsatisfied H-bonds) to discuss how Mutacraft reshapes the binding surface.
- **Binder structural confidence:** Use binder-only pLDDT (from `per_model_metrics.csv`) or rerun monomer predictions to ensure the designed binder is intrinsically stable.
- **Sequence divergence:** Report sequence identity between the seed and designed binder within guided vs unguided regions to illustrate how the hybrid strategy explores sequence space.
- **Kinetic/validation hooks:** For wet-lab validation, pair these metrics with docking scores or experimental assays.
