# TO DO



# Notes

- Uses the AlphaFold2 weights bundled with BindCraft (`bindcraft/params`) directly; BindCraft's own code/scripts are only run for the optional warm start.
- Supports AF2 monomer or multimer via USE_AF_MULTIMER

# MutaCraft: Hybrid Seed + De Novo Binder Generation

Run this notebook inside your local WSL installation. It swaps the standard binder initialization for a hybrid strategy that mixes a user-provided seed with random exploration, then steps through a four-stage AlphaFold2 optimization (continuous → softmax → STE → semi-greedy).

You can point the notebook at any target structure under `InputTargets`, provide a template binder (optional), and supply a seed FASTA string or file. Outputs are written to the `Results` directory so they stay consistent with the rest of the tooling.

### Stage Flow
- Stage 1 – Continuous logits with a guided seed ratio anneal plus optional guided KL prior.
- Stage 2 – Temperature annealing down to a hard distribution while the KL weight tapers out.
- Stage 3 – Straight-through (hard) updates for sharper convergence.
- Stage 4 – Semi-greedy mutation search with optional bias toward the seed amino acids at guided sites.


# Changelog

6 Jan 2026:
- Added a BindCraft warm start that runs when no template binder is provided. If BindCraft reports InterfaceResidues for the selected warm-start seed, those positions become the guided mask; otherwise a random mask is drawn using GUIDED_FRACTION.
- Added USE_AF_MULTIMER to toggle between the AlphaFold2 multimer and monomer models.

In [1]:
import os
import copy
import json
import re
import sys
import time
import subprocess
import functools
from pathlib import Path
from typing import Any, Dict, Optional, Iterable, Tuple

import jax
import jax.numpy as jnp
from jax import tree_util
import numpy as np
import pandas as pd

# Ensure ColabDesign remains compatible with modern JAX releases.
# If this JAX build no longer exposes `jax.tree_map`, alias it to
# `jax.tree_util.tree_map` so code still using the old name keeps working.
if not hasattr(jax, "tree_map"):
    jax.tree_map = tree_util.tree_map  # type: ignore[attr-defined]
# `jax.util` may not be importable at all; fall back to None so the shim below
# is simply skipped instead of failing this cell.
try:
    import jax.util as jax_util  # type: ignore
except ImportError:  # pragma: no cover - optional shim for newer JAX.
    jax_util = None
# Install a minimal `jax.util.wraps` replacement only when the module exists but
# no longer provides `wraps` (presumably still called by ColabDesign or one of
# its dependencies — confirm which).
if jax_util is not None and not hasattr(jax_util, "wraps"):
    def _jax_util_wraps(wrapped, *, docstr=None, assigned=None, updated=None):
        """Return a decorator that copies metadata from ``wrapped`` onto the decorated function.

        The copy is done with ``functools.wraps`` using the fixed attribute tuples
        below. If ``docstr`` is given it replaces ``__doc__`` verbatim (no format
        substitution is applied).

        NOTE(review): the ``assigned``/``updated`` arguments are accepted but
        ignored — they are unconditionally overwritten on the next two lines.
        """
        assigned = ("__module__", "__name__", "__qualname__", "__doc__", "__annotations__")
        updated = ("__dict__",)
        def decorator(func):
            result = functools.wraps(wrapped, assigned=assigned, updated=updated)(func)
            if docstr is not None:
                result.__doc__ = docstr
            return result
        return decorator
    jax_util.wraps = _jax_util_wraps  # type: ignore[attr-defined]

from colabdesign import mk_afdesign_model, clear_mem

# The 20 standard amino acids in alphabetical one-letter order; this order
# defines the column layout of every (L, 20) logits / one-hot array.
AA = "ACDEFGHIKLMNPQRSTVWY"
aa_to_idx = {letter: position for position, letter in enumerate(AA)}
idx_to_aa = dict(enumerate(AA))

# Root of the local BindCraft checkout (a WSL mount of a Windows drive). Edit
# this for your machine; every other path below is derived from it.
ROOT_DIR = Path(r"/mnt/e/Code/BindCraft").resolve()
if not ROOT_DIR.exists():
    raise FileNotFoundError(f"root dir not found: {ROOT_DIR}")

# Accept either capitalisation of the inputs folder; the first existing one wins.
# The for/else only raises when no candidate exists.
INPUT_DIR_CANDIDATES = [ROOT_DIR / "InputTargets", ROOT_DIR / "inputtargets"]
for cand in INPUT_DIR_CANDIDATES:
    if cand.exists():
        INPUT_DIR = cand.resolve()
        break
else:
    raise FileNotFoundError("No InputTargets directory found. Expected one of: " + ", ".join(str(p) for p in INPUT_DIR_CANDIDATES))

# Same lookup for the results folder, except that a missing folder is created
# as `Results` instead of being treated as an error.
RESULTS_DIR_CANDIDATES = [ROOT_DIR / "Results", ROOT_DIR / "results"]
for cand in RESULTS_DIR_CANDIDATES:
    if cand.exists():
        RESULTS_DIR = cand.resolve()
        break
else:
    RESULTS_DIR = (ROOT_DIR / "Results").resolve()
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# AlphaFold parameter directory inside the BindCraft checkout; required.
AF_PARAMS_DIR = (ROOT_DIR / "bindcraft" / "params").resolve()
if not AF_PARAMS_DIR.exists():
    raise FileNotFoundError(f"AlphaFold parameter directory not found: {AF_PARAMS_DIR}")
# Export the params location only if the environment does not already set it.
# NOTE(review): presumably read by ColabDesign / the BindCraft subprocess — confirm.
os.environ.setdefault("AF_PARAMS_DIR", str(AF_PARAMS_DIR))
# All MutaCraft outputs are written under Results/MutaCraft (created if needed).
OUTPUT_BASE = (RESULTS_DIR / "MutaCraft").resolve()
OUTPUT_BASE.mkdir(parents=True, exist_ok=True)

# Global RNG seed. NOTE(review): its consumers are outside this cell.
GLOBAL_RNG_SEED = 0


  if jax_util is not None and not hasattr(jax_util, "wraps"):


# User Inputs & Parameters

In [None]:
# --- User inputs ---
# `_prepare_seed` resolves the seed in this order: inline FASTA text, then a
# FASTA file that exists on disk, then the sequence of the given chains of a
# template binder PDB.
SEED_FASTA = ""  # paste binder FASTA text directly here (optional if using SEED_FASTA_PATH).
SEED_FASTA_PATH = INPUT_DIR / "seed.fasta"  # optional FASTA file on disk.
TARGET_PDB_PATH = INPUT_DIR / "HumanLysozyme.pdb"    # target protein to design binders against
TEMPLATE_BINDER_PDB_PATH = ""  # optional template binder PDB; leave empty to use the BindCraft warm start
TEMPLATE_BINDER_CHAINS = ["A"]  # chain(s) whose sequence is read from the template binder PDB
TARGET_CHAIN = "B"  # chain ID in TARGET_PDB_PATH used as the receptor (also sent to the warm start as "chains")
USE_AF_MULTIMER = True  # True = AlphaFold2 multimer model, False = monomer model (this cell selects multimer)
MUTACRAFT_MODEL_COUNT = 1  # number of AF2 models to use
MUTACRAFT_NUM_RECYCLES = 0  # number of AF2 recycles per step


# Guided positions: GUIDED_FRACTION is the fraction of binder positions tied to
# the seed when a random guided mask is drawn; the seed ratio (seed vs. free
# logits at guided sites) is annealed from START to END during Stage 1.
GUIDED_FRACTION = 0.9
GUIDED_SEED_RATIO_START = 0.9
GUIDED_SEED_RATIO_END = 0.7

# Iterations per stage: 1 = continuous logits, 2 = temperature anneal,
# 3 = straight-through (hard), 4 = semi-greedy mutation search.
STAGE1_ITERS = 50
STAGE1_EXTRA = 25  # NOTE(review): additional Stage 1 iterations; exact use is outside this cell — confirm.
STAGE2_ITERS = 60
STAGE3_ITERS = 5
STAGE4_ITERS = 30
MUT_RATE = 0.02  # Stage 4: about MUT_RATE * binder length positions mutated per proposal (minimum 1)

# Optional KL prior pulling guided sites toward the seed; its weight decays
# linearly from KL_W_START to KL_W_END.
KL_W_START = 0.1
KL_W_END = 0.0
USE_KL_PRIOR = True
BIAS_STAGE4_PROPOSALS = True  # Stage 4: bias mutations at guided sites toward the seed residue
STAGE4_SEED_BIAS = 0.15  # probability mass reassigned to seed AA when biasing proposals.

RUN_NAME = os.environ.get("MUTACRAFT_RUN_NAME", "warmstart1")  # subdirectory within Results/MutaCraft/; override via $MUTACRAFT_RUN_NAME
# NOTE(review): presumably the two specific baselines only run when RUN_BASELINES is True — confirm.
RUN_BASELINES = False
RUN_RANDOM_BASELINE = True
RUN_SEED_BASELINE = True


# --- Warm start (used when no template binder is provided) ---
WARM_START_ENABLED = True
WARMSTART_NAME = "MutaCraft_WarmStart"  # BindCraft binder_name; also names the settings JSON files below
WARMSTART_RESULTS = RESULTS_DIR / WARMSTART_NAME  # BindCraft design_path; trajectory_stats.csv is read from here
WARMSTART_SETTINGS = ROOT_DIR / "bindcraft" / "settings_target" / f"{WARMSTART_NAME}.json"  # target settings written by the notebook
WARMSTART_OVERWRITE = True  # rewrite the target/advanced settings JSON even if they already exist
WARMSTART_MAX_TRAJ = 1  # written into the advanced settings as max_trajectories (None = leave untouched)
WARMSTART_ADVANCED_PATH = ROOT_DIR / "bindcraft" / "settings_advanced" / f"{WARMSTART_NAME}.json"  # existing advanced settings; read and patched, not created here
WARMSTART_FILTERS = ROOT_DIR / "bindcraft" / "settings_filters" / "peptide_relaxed_filters_loose.json"
WARMSTART_LENGTHS = [80,120]  # BindCraft "lengths" setting; NOTE(review): presumably a [min, max] binder length range — confirm.
WARMSTART_MIN_PASS = 0 # minimum number of warm-start candidates expected to pass the filters (i_pTM/pLDDT/interface/clashes). Falling short only prints a warning; the run continues.
WARMSTART_TOP_K = 1 # number of top-ranked candidates (after the diversity penalty) kept and written to warm_start_candidates.json; the first one seeds MutaCraft.

# Same file as WARMSTART_ADVANCED_PATH. NOTE(review): presumably used to compare AF settings with BindCraft — confirm.
COMPARE_BINDCRAFT_SETTINGS = ROOT_DIR / "bindcraft" / "settings_advanced" / f"{WARMSTART_NAME}.json"

# Base filters for warm start candidates; rows failing any of these are dropped
# when filtering is enabled. (MIN_WARM_PLDTT keeps its original spelling.)
MIN_WARM_PLDTT = 0.60
MIN_WARM_IPTM = 0.20
MIN_WARM_INTERFACE_NRES = 8
MAX_WARM_RELAXED_CLASHES = 10

# Scoring weights for warm start ranking:
# score = W_IPTM*i_pTM + W_PLDTT*pLDDT + W_INTERFACE*n_interface - W_CLASHES*clashes
SCORE_W_IPTM = 2.0
SCORE_W_PLDTT = 1.0
SCORE_W_INTERFACE = 0.02
SCORE_W_CLASHES = 0.1

# Diversity against previously used seeds (ARCHIVE_PATH):
# penalty = max identity + DIVERSITY_KMER_WEIGHT * max k-mer overlap, subtracted
# from the score with weight DIVERSITY_LAMBDA. Candidates whose identity reaches
# SIM_THRESHOLD are handled per SIM_MODE.
DIVERSITY_LAMBDA = 0.5
DIVERSITY_KMER_K = 3
DIVERSITY_KMER_WEIGHT = 0.5
SIM_THRESHOLD = 0.80
SIM_MODE = "reject"  # one of: "reject", "inflate", "resample" ("resample" currently drops the candidate, same as "reject")
SIM_PENALTY_MULT = 3.0  # "inflate" mode: multiplier applied to the diversity penalty
ARCHIVE_PATH = OUTPUT_BASE / "warm_start_archive.json"


In [3]:
def read_fasta_str(fasta_str: str) -> str:
    """Parse FASTA text and return the first record's sequence in upper case.

    Header lines (starting with ``>``) delimit records. Only the residues of the
    first record are returned, so a multi-record input no longer leaks the
    second header's text into the sequence. A bare sequence with no header is
    also accepted. All whitespace inside sequence lines is removed, ``;``
    comment lines are skipped, and a trailing ``*`` terminator is dropped.

    Args:
        fasta_str: FASTA-formatted text or a plain sequence string.

    Returns:
        The upper-cased sequence, or ``""`` if no residues are present.
    """
    residues = []
    seen_header = False
    for raw_line in fasta_str.splitlines():
        line = raw_line.strip()
        if not line or line.startswith(";"):
            continue
        if line.startswith(">"):
            if seen_header or residues:
                # A second header marks the end of the first record.
                break
            seen_header = True
            continue
        # str.split() with no argument drops every kind of whitespace (spaces, tabs).
        residues.append("".join(line.split()))
    return "".join(residues).upper().rstrip("*")


def seq_to_onehot(seq: str) -> jnp.ndarray:
    """Convert a sequence string to a one-hot ``(L, 20)`` float array.

    Column ``j`` corresponds to the amino acid with index ``j`` in ``aa_to_idx``.

    Args:
        seq: upper-case one-letter amino-acid sequence (may be empty).

    Returns:
        A JAX array of shape ``(len(seq), 20)``.

    Raises:
        ValueError: if ``seq`` contains characters outside the 20 standard amino acids.
    """
    # Raise instead of assert: asserts are stripped under `python -O`.
    unknown = sorted({a for a in seq if a not in aa_to_idx})
    if unknown:
        raise ValueError("Sequence contains non-standard amino acids: " + "".join(unknown))
    # An explicit integer dtype keeps the empty-sequence case well-formed; an empty
    # Python list would otherwise become a float array that cannot index.
    idxs = np.array([aa_to_idx[a] for a in seq], dtype=np.int32)
    return jax.nn.one_hot(idxs, 20)


def onehot_to_logits(oh: jnp.ndarray, sharp: float = 6.0) -> jnp.ndarray:
    """Scale a one-hot encoding into logits.

    Multiplying by ``sharp`` concentrates ``softmax`` of the result on the hot
    entry; larger ``sharp`` gives a harder distribution.
    """
    scaled = oh * sharp
    return scaled


In [4]:
THREE_TO_ONE = {
    'ALA': 'A',
    'ARG': 'R',
    'ASN': 'N',
    'ASP': 'D',
    'CYS': 'C',
    'GLN': 'Q',
    'GLU': 'E',
    'GLY': 'G',
    'HIS': 'H',
    'ILE': 'I',
    'LEU': 'L',
    'LYS': 'K',
    'MET': 'M',
    'PHE': 'F',
    'PRO': 'P',
    'SER': 'S',
    'THR': 'T',
    'TRP': 'W',
    'TYR': 'Y',
    'VAL': 'V',
    # Selenomethionine is commonly deposited (often as HETATM) in crystal
    # structures; mapping it to M keeps those chains contiguous.
    'MSE': 'M',
}


def extract_pdb_sequence(pdb_path: Path, chains: Optional[Iterable[str]] = None) -> str:
    """Extract a one-letter sequence from the requested chains of a PDB file.

    Residues are keyed by (chain, resSeq, iCode). The first atom seen for a
    residue determines its residue name, so later atoms, alternate locations and
    repeated MODELs of the same residue are ignored. Within a chain, residues are
    ordered by residue number, then by insertion code, not by file order.

    ATOM residues with an unrecognised name become ``'X'``. HETATM records are
    kept only when their residue name is in ``THREE_TO_ONE`` (e.g. MSE). Waters,
    ions and ligands are skipped rather than being appended as ``'X'``, which
    would later fail the standard-amino-acid check.

    Args:
        pdb_path: path to a PDB-format file.
        chains: chain IDs to include, in output order. ``None`` includes every
            chain, sorted by chain ID.

    Returns:
        The concatenated one-letter sequence (empty if nothing matched).
    """
    def _parse_resseq(val: str) -> tuple[int, str]:
        """Split a resSeq field into (signed integer, non-digit remainder) for sorting."""
        val = val.strip()
        number = []
        suffix = []
        for ch in val:
            if ch.isdigit() or (ch == '-' and not number):
                number.append(ch)
            else:
                suffix.append(ch)
        num_val = int(''.join(number)) if number else 0
        return num_val, ''.join(suffix)

    chain_order = list(chains) if chains is not None else []
    wanted = set(chain_order)
    residues_by_chain: Dict[str, Dict[tuple, str]] = {}
    with pdb_path.open() as handle:
        for line in handle:
            is_hetatm = line.startswith('HETATM')
            if not (is_hetatm or line.startswith('ATOM')):
                continue
            resname = line[17:20].strip().upper()
            if is_hetatm and resname not in THREE_TO_ONE:
                continue
            # Slices (not single-index access) tolerate truncated lines.
            chain_id = line[21:22]
            if not chain_id.strip():
                chain_id = ' '
            if chains is not None and chain_id not in wanted:
                continue
            key = (chain_id, line[22:26], line[26:27])
            chain_dict = residues_by_chain.setdefault(chain_id, {})
            if key not in chain_dict:
                chain_dict[key] = THREE_TO_ONE.get(resname, 'X')
    if chains is None:
        chain_order = sorted(residues_by_chain)
    sequence_parts = []
    for chain_id in chain_order:
        residues = residues_by_chain.get(chain_id, {})
        for key in sorted(residues, key=lambda x: (*_parse_resseq(x[1]), x[2])):
            sequence_parts.append(residues[key])
    return ''.join(sequence_parts)


In [5]:
def make_guided_mask(L_local: int, fraction: float = 0.5, rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """Create a boolean mask marking a random subset of positions as guided.

    Args:
        L_local: number of positions.
        fraction: target fraction of guided positions. The count is
            ``round(fraction * L_local)`` clamped to ``[0, L_local]``.
        rng: optional parent generator. A child generator is seeded from one
            draw of it, so the parent advances by exactly one draw. ``None``
            uses fresh OS entropy.

    Returns:
        Boolean array of shape ``(L_local,)``.
    """
    rng = np.random.default_rng(None if rng is None else rng.integers(1 << 32))
    idx = np.arange(L_local)
    rng.shuffle(idx)
    # Clamp: a negative count would become a negative slice stop and select
    # almost every position; a count above L_local just selects all of them.
    k = min(max(int(round(fraction * L_local)), 0), L_local)
    guided = np.zeros(L_local, dtype=bool)
    guided[idx[:k]] = True
    return guided


def _linear_schedule(t: int, t_max: int, start: float, end: float) -> float:
    """Linearly interpolate from ``start`` at t<=0 to ``end`` at t>=t_max (t_max floored at 1)."""
    t_max = max(1, t_max)
    frac = np.clip(t / t_max, 0.0, 1.0)
    return float(start + (end - start) * frac)


def guided_seed_ratio_schedule(t: int, t_max: int, a0: float = 0.9, a1: float = 0.2) -> float:
    """Linear decay of the guided seed ratio from a0 to a1 over t_max steps."""
    return _linear_schedule(t, t_max, a0, a1)


def kl_weight_schedule(t: int, t_max: int, w0: float = 0.1, w1: float = 0.0) -> float:
    """Linear decay of the KL weight from w0 to w1 over t_max steps."""
    return _linear_schedule(t, t_max, w0, w1)


In [6]:
def ensure_seed_length(seed_seq: str, L_local: int, rng: np.random.Generator) -> str:
    """Return the seed (stripped, upper-cased) truncated or padded to exactly L_local residues.

    Padding residues are drawn uniformly from ``AA`` using ``rng``; ``rng`` is
    only consumed when padding is needed.
    """
    normalized = seed_seq.strip().upper()
    shortfall = L_local - len(normalized)
    if shortfall <= 0:
        return normalized[:L_local]
    filler = rng.choice(list(AA), size=shortfall)
    return (normalized + "".join(filler))[:L_local]


def init_mixed_logits(seed_seq: str, L_local: int, guided_mask: np.ndarray, seed_ratio: float, rng: Optional[np.random.Generator] = None, rand_scale: float = 1.0) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Initialise binder logits by blending sharp seed logits with Gaussian noise.

    For position ``i`` the initial logits are
    ``r_i * seed_logits[i] + (1 - r_i) * noise[i]`` with
    ``r_i = guided_mask[i] * seed_ratio``, so unguided positions start as pure
    ``N(0, rand_scale**2)`` noise.

    Args:
        seed_seq: seed sequence; its first ``L_local`` residues are used.
        L_local: binder length.
        guided_mask: one entry per position (bool or 0/1 weights).
        seed_ratio: weight of the seed logits at guided positions.
        rng: optional parent generator; a child generator seeded from one draw
            of it produces the noise. ``None`` uses fresh OS entropy.
        rand_scale: standard deviation of the noise logits.

    Returns:
        ``(logits0, seed_logits)``, both of shape ``(L_local, 20)``.

    Raises:
        ValueError: if the seed is shorter than ``L_local`` or ``guided_mask``
            does not have exactly ``L_local`` entries.
    """
    # Explicit errors instead of assert (which is stripped under `python -O`).
    if len(seed_seq) < L_local:
        raise ValueError(f"Seed sequence must be at least L residues long (got {len(seed_seq)}, need {L_local}).")
    mask_values = np.asarray(guided_mask)
    if mask_values.size != L_local:
        # A length-1 mask would otherwise broadcast silently over every position.
        raise ValueError(f"guided_mask has {mask_values.size} entries, expected {L_local}.")
    seed_oh = seq_to_onehot(seed_seq[:L_local])
    seed_logits = onehot_to_logits(seed_oh, sharp=6.0)

    rng = np.random.default_rng(None if rng is None else rng.integers(1 << 32))
    rand_logits = jnp.array(rng.normal(0.0, rand_scale, size=(L_local, 20)))

    mask = jnp.array(mask_values).reshape(-1, 1).astype(jnp.float32)
    pos_seed_ratio = mask * seed_ratio
    logits0 = pos_seed_ratio * seed_logits + (1.0 - pos_seed_ratio) * rand_logits
    return logits0, seed_logits


def blend_logits(current_logits: jnp.ndarray, seed_logits: jnp.ndarray, guided_mask: np.ndarray, seed_ratio: float) -> jnp.ndarray:
    """Pull guided positions toward the seed logits.

    Each row is ``m * (seed_ratio * seed + (1 - seed_ratio) * current) + (1 - m) * current``,
    where ``m`` is that position's mask value. For a 0/1 mask, guided rows are
    mixed and the other rows keep ``current``.
    """
    weight = jnp.array(guided_mask).reshape(-1, 1).astype(jnp.float32)
    mixed = seed_ratio * seed_logits + (1.0 - seed_ratio) * current_logits
    untouched = (1.0 - weight) * current_logits
    return weight * mixed + untouched


def indices_to_seq(idxs: np.ndarray) -> str:
    """Map a sequence of integer amino-acid indices back to a one-letter string via idx_to_aa."""
    return "".join(map(lambda position: idx_to_aa[int(position)], idxs))


In [7]:

def set_binder_softseq(model, softseq: np.ndarray) -> None:
    """Inject a soft sequence (per-position probabilities) into the binder rows of the model.

    The probabilities are clipped to ``[1e-8, 1]`` and converted to log-probabilities.
    They are written into the first ``binder_len`` positions along axis 1 of
    ``model._params["seq"]``; all other positions are left unchanged. The same
    array is also stored as ``model._tmp["seq_logits"]``.

    Args:
        model: object with a ``_params["seq"]`` array; the indexing below assumes
            shape ``(batch, L_total, 20)`` (presumably a ColabDesign model — confirm).
        softseq: probabilities of shape ``(binder_len, 20)`` or ``(1, binder_len, 20)``.

    Raises:
        ValueError: if ``binder_len`` exceeds the model's sequence length.
    """
    logits = np.log(np.clip(np.asarray(softseq, dtype=float), 1e-8, 1.0))
    if logits.ndim == 2:
        logits = logits[None, :, :]
    # Read the length from the residue axis *after* normalising to 3-D; using
    # shape[0] on a batched (1, L, 20) input would give 1.
    binder_len = logits.shape[1]
    total_logits = np.array(model._params["seq"])
    if binder_len > total_logits.shape[1]:
        raise ValueError(f"soft sequence has {binder_len} positions but the model sequence has {total_logits.shape[1]}.")
    total_logits[:, :binder_len, :] = logits
    model._params["seq"] = total_logits
    model._tmp = getattr(model, "_tmp", {})
    model._tmp["seq_logits"] = total_logits


def guided_seq_kl(logits: jnp.ndarray, seed_logits: jnp.ndarray, guided_mask: np.ndarray) -> jnp.ndarray:
    """Sum of KL(p || q) over guided positions.

    ``p = softmax(logits)`` and ``q = softmax(seed_logits)`` over the last axis.
    Both are clipped to ``[1e-8, 1]`` before taking logs so the result stays
    finite. Each position's term is weighted by its mask value, so unguided
    positions contribute zero.
    """
    weights = jnp.array(guided_mask).reshape(-1, 1).astype(jnp.float32)
    p_safe = jnp.clip(jax.nn.softmax(logits, axis=-1), 1e-8, 1.0)
    q_safe = jnp.clip(jax.nn.softmax(seed_logits, axis=-1), 1e-8, 1.0)
    log_ratio = jnp.log(p_safe) - jnp.log(q_safe)
    per_entry = p_safe * log_ratio
    return jnp.sum(weights * per_entry)


def collect_metrics(model, stage: str, iter_in_stage: int, binder_aux_slice: slice, guided_seed_ratio: Optional[float], kl_weight: float, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Package one iteration's losses and confidence metrics into a flat dict.

    Reads ``model.aux`` (an empty dict if absent) and ``model._k`` (0 if absent).
    The entry contains:

    - stage / iteration bookkeeping;
    - ``loss_total`` and one ``loss_<name>`` per entry of ``aux["losses"]``;
    - binder pLDDT mean/min/max over ``binder_aux_slice``;
    - ``target_plddt_mean`` over the positions before the binder slice, when
      there are any;
    - ``ptm``/``i_ptm``/``pae``/``i_pae`` when present.

    Array-valued losses and metrics are reduced with ``nanmean``. ``extra`` is
    merged in last and may override keys.

    Args:
        model: object exposing ``aux`` and ``_k`` (presumably a ColabDesign model).
        stage: stage label stored as-is.
        iter_in_stage: iteration index within the stage.
        binder_aux_slice: slice selecting the binder positions of ``aux["plddt"]``.
        guided_seed_ratio: current seed ratio, or ``None`` if not applicable.
        kl_weight: current KL weight.
        extra: additional keys to merge into the entry.

    Returns:
        A dict of plain Python scalars (plus whatever ``extra`` contains).
    """
    def _scalar(value: Any) -> float:
        # Reduce a scalar or a per-model/per-position array (NumPy or JAX) to one float.
        arr = np.asarray(value, dtype=float)
        if arr.size == 0:
            return float("nan")
        return float(arr) if arr.ndim == 0 else float(np.nanmean(arr))

    aux = getattr(model, "aux", {}) or {}
    entry: Dict[str, Any] = {
        "stage": stage,
        "iter_in_stage": int(iter_in_stage),
        "global_iter": int(getattr(model, "_k", 0)),
        "guided_seed_ratio": None if guided_seed_ratio is None else float(guided_seed_ratio),
        "kl_weight": float(kl_weight),
        "loss_total": _scalar(aux.get("loss", np.nan)),
    }
    for name, value in (aux.get("losses") or {}).items():
        entry[f"loss_{name}"] = _scalar(value)
    if "plddt" in aux:
        plddt = np.asarray(aux["plddt"], dtype=float)
        binder_plddt = plddt[binder_aux_slice]
        # nanmin/nanmax raise on an empty slice; skip the binder keys instead.
        if binder_plddt.size:
            entry["binder_plddt_mean"] = float(np.nanmean(binder_plddt))
            entry["binder_plddt_min"] = float(np.nanmin(binder_plddt))
            entry["binder_plddt_max"] = float(np.nanmax(binder_plddt))
        # slice.start is None for slices like slice(None, n); treat that as 0.
        start = binder_aux_slice.start or 0
        if start > 0:
            entry["target_plddt_mean"] = float(np.nanmean(plddt[:start]))
    for key in ("ptm", "i_ptm", "pae", "i_pae"):
        if key in aux:
            entry[key] = _scalar(aux[key])
    if extra:
        entry.update(extra)
    return entry


def update_best_record(model, best: Dict[str, Any], stage: str, iter_in_stage: int, binder_slice: slice) -> Dict[str, Any]:
    """Return the best-so-far record, replacing it when the model's current loss is a new best.

    The current design replaces ``best`` when:

    - ``best`` is empty (no ``"loss"`` yet), or
    - the current loss is finite and either beats ``best["loss"]`` or
      ``best["loss"]`` is not finite.

    A stored NaN/inf can therefore always be displaced, while a non-finite
    current loss never displaces a finite best.

    The stored sequence is the argmax of ``model.aux["seq"]["logits"]`` if
    present, else of ``model._params["seq"]``. ``aux`` is deep-copied so later
    iterations cannot mutate it.

    Args:
        model: object exposing ``aux``, ``_params`` and ``_k``.
        best: previous record (``{}`` to start).
        stage: stage label stored on a new record.
        iter_in_stage: iteration index stored on a new record.
        binder_slice: accepted for interface compatibility; not used.

    Returns:
        Either the unchanged ``best`` or a new record dict.
    """
    aux = getattr(model, "aux", {}) or {}
    current_loss = float(aux.get("loss", np.inf))
    best_loss = best.get("loss")
    if best_loss is None:
        improved = True
    elif not np.isfinite(current_loss):
        improved = False
    else:
        # Every comparison with NaN is False, so a NaN best would otherwise never be beaten.
        improved = (not np.isfinite(best_loss)) or current_loss < best_loss
    if not improved:
        return best
    if "seq" in aux and "logits" in aux["seq"]:
        seq_logits = np.array(aux["seq"]["logits"])
    else:
        seq_logits = np.array(model._params["seq"])
    seq_idx = seq_logits.argmax(-1)
    return {
        "loss": current_loss,
        "seq": np.array(seq_idx, dtype=np.int64),
        "aux": copy.deepcopy(aux),
        "stage": stage,
        "iter_in_stage": int(iter_in_stage),
        "global_iter": int(getattr(model, "_k", 0)),
    }


In [None]:
def _softmax_np(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis, computed in float64 NumPy.

    The per-row maximum is subtracted before exponentiating to avoid overflow,
    and the normaliser is floored at 1e-8 to avoid division by zero.
    """
    values = np.asarray(logits, dtype=float)
    shifted = values - np.max(values, axis=-1, keepdims=True)
    weights = np.exp(shifted)
    normaliser = np.clip(np.sum(weights, axis=-1, keepdims=True), 1e-8, None)
    return weights / normaliser




def summarize_bindcraft_af_settings(path: Path) -> dict:
    """Summarise the AF2 design settings in a BindCraft advanced-settings JSON.

    Returns ``{"models": n, "recycles": r}``:

    - ``n`` is the length of ``design_models_override`` when it is a list,
      otherwise ``None``;
    - ``r`` is the raw ``num_recycles_design`` value, or ``None`` if absent.

    A missing file yields both as ``None``.
    """
    summary = {"models": None, "recycles": None}
    if not path.exists():
        return summary
    with path.open() as fh:
        settings = json.load(fh)
    override = settings.get("design_models_override")
    summary["models"] = len(override) if isinstance(override, list) else None
    summary["recycles"] = settings.get("num_recycles_design")
    return summary


def write_json(path: Path, payload: dict) -> None:
    """Write ``payload`` as indented (2-space) UTF-8 JSON to ``path``, creating parent directories."""
    parent_dir = path.parent
    parent_dir.mkdir(exist_ok=True, parents=True)
    with open(path, mode="w", encoding="utf-8") as out:
        json.dump(payload, out, indent=2)


def stream_process(cmd: list[str], log_path: Path, env: dict[str, str] | None = None) -> tuple[int, float, Path]:
    log_path.parent.mkdir(parents=True, exist_ok=True)
    start = time.perf_counter()
    proc = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,
        universal_newlines=True,
        env=env,
        cwd=str(ROOT_DIR),
    )
    with log_path.open('w', encoding='utf-8') as logf:
        if proc.stdout:
            for line in proc.stdout:
                print(line, end='')
                logf.write(line)
    proc.wait()
    elapsed = time.perf_counter() - start
    return proc.returncode, elapsed, log_path


def maybe_create_warmstart_settings() -> tuple[Path, Path, Path]:
    """Prepare the BindCraft settings files for the warm start and return their paths.

    The target-settings JSON (``WARMSTART_SETTINGS``) is written when
    ``WARMSTART_OVERWRITE`` is set or the file does not exist yet.

    When ``WARMSTART_MAX_TRAJ`` is set, ``max_trajectories`` is patched into the
    existing advanced-settings JSON, which is written back only when
    ``WARMSTART_OVERWRITE`` is set. I/O, JSON-decoding and value/type problems
    with that file are reported as a warning so the run can continue; other
    exceptions propagate.

    Returns:
        ``(settings_path, filters_path, advanced_path)``.

    Raises:
        FileNotFoundError: if ``TARGET_PDB_PATH`` does not exist.
    """
    if not TARGET_PDB_PATH.exists():
        raise FileNotFoundError(f"Target PDB missing: {TARGET_PDB_PATH}")
    template_payload = {
        "design_path": str(WARMSTART_RESULTS),
        "binder_name": WARMSTART_NAME,
        "starting_pdb": str(TARGET_PDB_PATH),
        "chains": TARGET_CHAIN,
        "target_hotspot_residues": None,
        "lengths": WARMSTART_LENGTHS,
        "number_of_final_designs": 1,
    }
    if WARMSTART_OVERWRITE or not WARMSTART_SETTINGS.exists():
        write_json(WARMSTART_SETTINGS, template_payload)

    if WARMSTART_MAX_TRAJ is not None:
        try:
            with WARMSTART_ADVANCED_PATH.open() as fh:
                advanced_payload = json.load(fh)
            advanced_payload["max_trajectories"] = int(WARMSTART_MAX_TRAJ)
            if WARMSTART_OVERWRITE:
                write_json(WARMSTART_ADVANCED_PATH, advanced_payload)
        except (OSError, ValueError, TypeError) as exc:
            # Best effort: a missing or malformed advanced file should not stop the
            # notebook (json.JSONDecodeError is a ValueError subclass).
            print(f"Warning: could not update warm-start advanced settings: {exc}")

    return WARMSTART_SETTINGS, WARMSTART_FILTERS, WARMSTART_ADVANCED_PATH


def run_warm_start() -> dict:
    """Prepare the warm-start settings, then run BindCraft's bindcraft.py in a subprocess.

    Output is streamed to the console and to ``OUTPUT_BASE/warmstart.log``.

    Returns:
        ``{"returncode": int, "elapsed_seconds": float, "log": Path}``. A non-zero
        return code is reported, not raised.
    """
    settings_path, filters_path, advanced_path = maybe_create_warmstart_settings()
    script = ROOT_DIR / "bindcraft" / "bindcraft.py"
    command = [sys.executable, str(script)]
    command += ["--settings", str(settings_path)]
    command += ["--filters", str(filters_path)]
    command += ["--advanced", str(advanced_path)]
    rc, seconds, written_log = stream_process(command, OUTPUT_BASE / "warmstart.log")
    return {"returncode": rc, "elapsed_seconds": seconds, "log": written_log}


def parse_interface_residues(raw: Any, binder_len: int) -> list[int]:
    """Parse a BindCraft ``InterfaceResidues`` cell into sorted, unique 0-based binder indices.

    Every run of digits in the text is read as a 1-based residue number. Numbers
    outside ``1..binder_len`` are dropped. ``None``, empty text, ``"nan"`` and
    ``"none"`` (any case) yield ``[]``.
    """
    text = "" if raw is None else str(raw).strip()
    if not text or text.lower() in {"nan", "none"}:
        return []
    one_based = (int(token) for token in re.findall(r"\d+", text))
    return sorted({number - 1 for number in one_based if 1 <= number <= binder_len})


def sequence_identity(a: str, b: str) -> float:
    """Fractional identity between two sequences.

    Equal-length inputs are compared position by position. Otherwise the shorter
    sequence is slid (ungapped) along the longer one, and the best window's match
    count is divided by the shorter length. Either input empty gives 0.0.
    """
    if not (a and b):
        return 0.0
    if len(a) == len(b):
        return sum(x == y for x, y in zip(a, b)) / len(a)
    query, ref = sorted((a, b), key=len)
    width = len(query)
    best_hits = max(
        sum(q == r for q, r in zip(query, ref[offset:offset + width]))
        for offset in range(len(ref) - width + 1)
    )
    return best_hits / width


def kmer_overlap(a: str, b: str, k: int = 3) -> float:
    """Shared distinct k-mers divided by the smaller distinct k-mer set size.

    Returns 0.0 if either sequence is empty or shorter than ``k``.
    """
    if not (a and b) or min(len(a), len(b)) < k:
        return 0.0

    def _kmers(text: str) -> set:
        return {text[i:i + k] for i in range(len(text) - k + 1)}

    kmers_a, kmers_b = _kmers(a), _kmers(b)
    if not (kmers_a and kmers_b):
        return 0.0
    return len(kmers_a & kmers_b) / max(1, min(len(kmers_a), len(kmers_b)))


def load_warm_start_archive(path: Path) -> list[str]:
    """Load previously used warm-start sequences from ``{"sequences": [...]}`` JSON.

    Entries are stringified, stripped and upper-cased; blank entries are
    dropped. A missing file gives ``[]``.
    """
    if not path.exists():
        return []
    with path.open() as handle:
        payload = json.load(handle)
    cleaned = []
    for entry in payload.get("sequences", []):
        text = str(entry).strip()
        if text:
            cleaned.append(text.upper())
    return cleaned


def save_warm_start_archive(path: Path, sequences: list[str]) -> None:
    """Persist the warm-start archive as ``{"sequences": [...]}`` JSON via write_json."""
    write_json(path, {"sequences": sequences})


def _row_float(row: pd.Series, *keys: str, default: float = 0.0) -> float:
    """Return the first present, finite numeric value among ``keys`` in ``row``, else ``default``.

    pandas stores empty CSV cells as NaN, which is truthy. The earlier
    ``float(row.get(key, 0.0) or 0.0)`` idiom therefore let NaN leak into scores
    and filters; here a NaN or non-numeric cell falls through to the next key or
    the default.
    """
    for key in keys:
        value = row.get(key)
        if value is None:
            continue
        try:
            number = float(value)
        except (TypeError, ValueError):
            continue
        if np.isfinite(number):
            return number
    return default


def score_warmstart_row(row: pd.Series) -> float:
    """Weighted warm-start score (higher is better).

    score = SCORE_W_IPTM * i_pTM + SCORE_W_PLDTT * pLDDT
            + SCORE_W_INTERFACE * n_InterfaceResidues - SCORE_W_CLASHES * clashes

    Missing/NaN cells count as 0. Clashes use ``Relaxed_Clashes``, falling back
    to ``Unrelaxed_Clashes``.
    """
    iptm = _row_float(row, "i_pTM")
    plddt = _row_float(row, "pLDDT")
    n_interface = _row_float(row, "n_InterfaceResidues")
    clashes = _row_float(row, "Relaxed_Clashes", "Unrelaxed_Clashes")
    return (SCORE_W_IPTM * iptm) + (SCORE_W_PLDTT * plddt) + (SCORE_W_INTERFACE * n_interface) - (SCORE_W_CLASHES * clashes)


def select_warmstart_seeds(traj_df: pd.DataFrame, archive: list[str], apply_filters: bool = True, apply_diversity: bool = True) -> list[dict]:
    """Rank BindCraft trajectory rows as MutaCraft seed candidates.

    Processing steps:

    1. Rows without a sequence are skipped.
    2. With ``apply_filters``, rows below MIN_WARM_IPTM / MIN_WARM_PLDTT /
       MIN_WARM_INTERFACE_NRES or above MAX_WARM_RELAXED_CLASHES are dropped.
    3. With ``apply_diversity``, the penalty is
       ``max identity + DIVERSITY_KMER_WEIGHT * max k-mer overlap`` against
       ``archive``. Candidates at or above SIM_THRESHOLD identity are dropped
       ("reject"/"resample") or get the penalty multiplied by SIM_PENALTY_MULT
       ("inflate").

    total_score = score_warmstart_row - DIVERSITY_LAMBDA * penalty.

    Returns:
        Up to WARMSTART_TOP_K candidate dicts sorted by total_score (descending),
        each with sequence, 0-based interface_residues, score parts and the
        source ``row``.
    """
    candidates = []
    for _, row in traj_df.iterrows():
        raw_seq = row.get("Sequence")
        # An empty CSV cell is NaN; str(NaN) would otherwise become the bogus sequence "NAN".
        if raw_seq is None or (isinstance(raw_seq, float) and np.isnan(raw_seq)):
            continue
        seq = str(raw_seq).strip().upper()
        if not seq:
            continue
        iptm = _row_float(row, "i_pTM")
        plddt = _row_float(row, "pLDDT")
        n_interface = _row_float(row, "n_InterfaceResidues")
        clashes = _row_float(row, "Relaxed_Clashes", "Unrelaxed_Clashes")
        if apply_filters and (iptm < MIN_WARM_IPTM or plddt < MIN_WARM_PLDTT or n_interface < MIN_WARM_INTERFACE_NRES or clashes > MAX_WARM_RELAXED_CLASHES):
            continue
        base_score = score_warmstart_row(row)
        max_identity = 0.0
        max_kmer = 0.0
        diversity_penalty = 0.0
        if apply_diversity:
            for prev in archive:
                max_identity = max(max_identity, sequence_identity(seq, prev))
                max_kmer = max(max_kmer, kmer_overlap(seq, prev, k=DIVERSITY_KMER_K))
            diversity_penalty = max_identity + (DIVERSITY_KMER_WEIGHT * max_kmer)
            if max_identity >= SIM_THRESHOLD:
                # NOTE: "resample" currently behaves exactly like "reject".
                if SIM_MODE in ("reject", "resample"):
                    continue
                if SIM_MODE == "inflate":
                    diversity_penalty *= SIM_PENALTY_MULT
        total_score = base_score - (DIVERSITY_LAMBDA * diversity_penalty)
        # A missing/NaN/non-positive Length falls back to the sequence length
        # (int(NaN) would raise).
        length = _row_float(row, "Length")
        binder_len = int(length) if length > 0 else len(seq)
        interface_res = parse_interface_residues(row.get("InterfaceResidues"), binder_len)
        candidates.append({
            "sequence": seq,
            "interface_residues": interface_res,
            "base_score": base_score,
            "diversity_penalty": diversity_penalty,
            "total_score": total_score,
            "row": row,
        })
    candidates.sort(key=lambda x: x["total_score"], reverse=True)
    return candidates[:WARMSTART_TOP_K]


def warm_start_select_seed() -> dict:
    """Choose the MutaCraft seed from the BindCraft warm-start trajectory table.

    Steps:

    1. Read ``WARMSTART_RESULTS/trajectory_stats.csv``.
    2. Rank rows with ``select_warmstart_seeds`` (filters plus diversity against
       ``ARCHIVE_PATH``). If nothing passes, fall back to the unfiltered,
       diversity-free ranking.
    3. Append the selected sequences to the archive and save it.
    4. Write a summary to ``OUTPUT_BASE/warm_start_candidates.json``.
    5. Convert the top candidate's interface residues (if any) into a boolean
       guided mask over its sequence.

    Returns:
        ``{"seed_sequence": str, "guided_mask": np.ndarray | None, "candidates": list[dict]}``.

    Raises:
        FileNotFoundError: if the trajectory CSV is missing.
        ValueError: if the CSV is empty, or has no usable candidate even without filters.
    """
    traj_csv = WARMSTART_RESULTS / "trajectory_stats.csv"
    if not traj_csv.exists():
        raise FileNotFoundError(f"BindCraft warm start trajectory CSV not found: {traj_csv}")
    traj_df = pd.read_csv(traj_csv)
    if traj_df.empty:
        raise ValueError("BindCraft warm start produced no trajectory rows.")

    archive = load_warm_start_archive(ARCHIVE_PATH)
    selected = select_warmstart_seeds(traj_df, archive)
    n_passed = len(selected)
    if n_passed < WARMSTART_MIN_PASS:
        print(f"Warning: only {n_passed} warm start candidates passed filters (target {WARMSTART_MIN_PASS}).")
    if n_passed == 0:
        print("No warm-start candidates passed filters; selecting best available trajectory.")
        selected = select_warmstart_seeds(traj_df, archive, apply_filters=False, apply_diversity=False)
        if not selected:
            raise ValueError("No warm-start candidates available in trajectory_stats.csv.")
        fallback_design = selected[0]["row"].get("Design", "<unknown>")
        print(f"Warm-start fallback seed from trajectory: {fallback_design}")

    archive += [item["sequence"] for item in selected]
    save_warm_start_archive(ARCHIVE_PATH, archive)

    def _summarise(item: dict) -> dict:
        # JSON-friendly view of one candidate (the pandas row itself is not serialisable).
        row = item["row"]
        return {
            "sequence": item["sequence"],
            "interface_residues": item["interface_residues"],
            "base_score": item["base_score"],
            "diversity_penalty": item["diversity_penalty"],
            "total_score": item["total_score"],
            "pLDDT": float(row.get("pLDDT", 0.0) or 0.0),
            "i_pTM": float(row.get("i_pTM", 0.0) or 0.0),
            "n_InterfaceResidues": float(row.get("n_InterfaceResidues", 0.0) or 0.0),
            "Relaxed_Clashes": float(row.get("Relaxed_Clashes", row.get("Unrelaxed_Clashes", 0.0)) or 0.0),
            "Design": str(row.get("Design", "")),
        }

    write_json(OUTPUT_BASE / "warm_start_candidates.json", {"candidates": [_summarise(item) for item in selected]})

    top = selected[0]
    mask = None
    if top["interface_residues"]:
        seq_len = len(top["sequence"])
        mask = np.zeros(seq_len, dtype=bool)
        in_range = [i for i in top["interface_residues"] if 0 <= i < seq_len]
        mask[np.asarray(in_range, dtype=int)] = True
    return {
        "seed_sequence": top["sequence"],
        "guided_mask": mask,
        "candidates": selected,
    }


def _prepare_seed(seed_input: str, seed_path: Path, template_binder: Optional[Path] = None, binder_chains: Optional[Iterable[str]] = None) -> str:
    """Resolve the seed sequence from the first available source.

    Sources are tried in this order:

    1. inline FASTA text (``seed_input``, if not blank);
    2. the FASTA file at ``seed_path``, if it is an existing file;
    3. the ``binder_chains`` sequence of ``template_binder``, if that PDB exists
       and yields residues.

    Raises:
        ValueError: if none of the sources provides a sequence.
    """
    if seed_input.strip():
        return read_fasta_str(seed_input)
    if seed_path.is_file():
        return read_fasta_str(seed_path.read_text())
    pdb_seq = extract_pdb_sequence(template_binder, binder_chains) if template_binder and template_binder.exists() else ""
    if pdb_seq:
        return pdb_seq
    raise ValueError("Provide SEED_FASTA, a FASTA file at SEED_FASTA_PATH, or a valid template binder PDB.")


def run_stage4_semigreedy(model, stage_len: int, mut_rate: float, guided_mask: np.ndarray, seed_idxs: np.ndarray, binder_param_slice: slice, binder_aux_slice: slice, rng: np.random.Generator, bias_stage4: bool, bias_weight: float, metrics: list[Dict[str, Any]], best_record: Dict[str, Any], af_counter: Dict[str, int]) -> Dict[str, Any]:
    """Semi-greedy refinement with optional seed-aware proposal bias.

    Each of the ``stage_len`` iterations:

    1. Reads the current sequence (argmax of ``aux["seq"]["logits"]``, or of
       ``model._params["seq"]``) and the current loss from ``model.aux``.
    2. Builds ``tries_per_iter`` mutants. Each mutant changes ``num_mutations``
       distinct binder positions, sampled with weight ``1 - pLDDT`` so
       low-confidence positions are favoured. The replacement residue is drawn
       from the position's softmax with the current residue excluded and, at
       guided sites when ``bias_stage4`` is set, ``bias_weight`` mass is moved to
       the seed residue.
    3. Scores every mutant with ``model.predict`` (hard=True, temp=1e-2,
       models=[0]).
    4. Accepts the lowest-loss mutant if its loss does not exceed the current
       loss (or the current loss is not finite). On rejection the current
       design's aux is restored.
    5. Appends a metrics entry, updates ``best_record`` and increments ``model._k``.

    ``metrics`` and ``af_counter`` are mutated in place.

    Returns:
        The updated ``best_record``.

    Raises:
        ValueError: if the binder pLDDT slice does not match the binder length.
    """
    binder_len = binder_param_slice.stop - binder_param_slice.start
    # Never ask for more distinct positions than exist (rng.choice with replace=False would raise).
    num_mutations = min(binder_len, max(1, int(round(mut_rate * binder_len))))
    tries_per_iter = max(4, num_mutations * 2)
    for local_idx in range(stage_len):
        aux = getattr(model, "aux", {}) or {}
        # model.predict() replaces model.aux with each proposal's outputs (see the
        # copy taken per proposal below). Keep the current design's aux so it can be
        # restored when every proposal is rejected; otherwise the next iteration would
        # start from the last *rejected* proposal.
        current_aux = copy.deepcopy(aux)
        if "seq" in aux and "logits" in aux["seq"]:
            logits_full = np.array(aux["seq"]["logits"], dtype=float)
        else:
            logits_full = np.array(model._params["seq"], dtype=float)
        current_seq = logits_full.argmax(-1)
        binder_logits = logits_full[0, binder_param_slice, :]
        binder_seq = current_seq[:, binder_param_slice].copy()
        current_loss = float(aux.get("loss", np.inf))
        pos_weights = np.ones(binder_len, dtype=float)
        if "plddt" in aux:
            binder_plddt = np.asarray(aux["plddt"], dtype=float)[binder_aux_slice]
            if binder_plddt.shape != (binder_len,):
                raise ValueError(f"binder pLDDT slice has shape {binder_plddt.shape}, expected ({binder_len},).")
            # Favour low-confidence positions; a NaN pLDDT gets the neutral weight 1.0
            # instead of poisoning the sampling distribution.
            pos_weights = np.nan_to_num(1.0 - np.clip(binder_plddt, 0.0, 1.0), nan=1.0)
        pos_weights = np.clip(pos_weights, 1e-3, None)
        pos_weights /= pos_weights.sum()
        proposals = []
        for _ in range(tries_per_iter):
            mutate_positions = rng.choice(binder_len, size=num_mutations, replace=False, p=pos_weights)
            mutant = binder_seq.copy()
            for pos in mutate_positions:
                current_aa = int(mutant[0, pos])
                probs = _softmax_np(binder_logits[pos])
                probs[current_aa] = 0.0
                remaining = probs.sum()
                if not np.isfinite(remaining) or remaining <= 0.0:
                    # Softmax underflowed to (or NaN'd around) the current residue:
                    # fall back to uniform over the other 19 amino acids.
                    probs = np.full(20, 1.0 / 19.0)
                    probs[current_aa] = 0.0
                if bias_stage4 and guided_mask[pos]:
                    # NOTE(review): if the seed residue equals current_aa this re-adds mass
                    # to it, so the "mutation" can be a no-op.
                    probs = (1.0 - bias_weight) * probs
                    probs[int(seed_idxs[pos])] += bias_weight
                probs = probs / np.clip(probs.sum(), 1e-8, None)
                mutant[0, pos] = rng.choice(20, p=probs)
            candidate = current_seq.copy()
            candidate[:, binder_param_slice] = mutant
            # NOTE(review): proposals are always scored with AF model 0 only — confirm intended.
            aux_candidate = model.predict(seq=candidate, return_aux=True, verbose=False, dropout=False, hard=True, soft=False, temp=1e-2, models=[0], num_models=1, sample_models=False)
            af_counter["count"] = af_counter.get("count", 0) + 1
            proposals.append((float(aux_candidate["loss"]), candidate.copy(), copy.deepcopy(model.aux)))
        # Rank NaN losses as +inf so they cannot shadow finite proposals in min().
        best_loss, best_seq_full, best_aux = min(proposals, key=lambda item: item[0] if np.isfinite(item[0]) else np.inf)
        if best_loss <= current_loss or not np.isfinite(current_loss):
            model.set_seq(seq=best_seq_full, bias=model._inputs.get("bias"), set_state=False)
            model.aux = best_aux
        else:
            model.aux = current_aux
        metrics.append(collect_metrics(model, stage="stage4", iter_in_stage=local_idx, binder_aux_slice=binder_aux_slice, guided_seed_ratio=None, kl_weight=0.0, extra={"temp": 1e-2, "accepted_loss": float(model.aux.get("loss", np.nan))}))
        best_record = update_best_record(model, best_record, stage="stage4", iter_in_stage=local_idx, binder_slice=binder_param_slice)
        model._k = int(getattr(model, "_k", 0)) + 1
    return best_record


def run_hybrid_design(
    target_pdb_path: Path,
    binder_len: int,
    seed_fasta: str,
    template_binder_path: Optional[Path] = None,
    stage_params: Optional[Dict[str, int]] = None,
    init_params: Optional[Dict[str, float]] = None,
    kl_params: Optional[Dict[str, float]] = None,
    guided_fraction: float = 0.5,
    mut_rate: float = 0.05,
    use_kl_prior: bool = True,
    bias_stage4: bool = True,
    guided_mask_override: Optional[np.ndarray] = None,
    rng_seed: int = 0,
    run_tag: str = "hybrid",
    output_root: Path = OUTPUT_BASE,
    baseline_mode: str = "hybrid",  # "hybrid" (default), "random" or "seed"
) -> Dict[str, Any]:
    """Execute the hybrid AF2 binder hallucination workflow.

    Builds a ColabDesign ``binder`` model for ``target_pdb_path``, initialises
    the binder logits with ``init_mixed_logits`` (seed sequence at guided
    positions mixed with exploration), then optimises in four stages:

    1. continuous logits; a pre-step callback blends the logits with the seed
       logits on the guided mask using a ratio scheduled from
       ``seed_ratio_start`` to ``seed_ratio_end``, plus an optional guided KL
       prior gradient;
    2. softmax temperature annealed linearly from 1.0 to 0.05, with the KL
       prior applied only during the first half of the stage;
    3. hard / straight-through steps (``soft=0, hard=1, temp=1e-2``);
    4. semi-greedy mutation search via ``run_stage4_semigreedy``.

    Outputs (``metrics.csv``, ``best_binder.fasta``, ``best_complex.pdb``,
    ``run_metadata.json``) are written to ``output_root / run_tag``.

    Args:
        target_pdb_path: Target structure; chain ``TARGET_CHAIN`` is used.
        binder_len: Requested binder length. Replaced by the model's own binder
            length when they differ (e.g. when a template binder is supplied).
        seed_fasta: Seed sequence as FASTA text, parsed by ``read_fasta_str``.
        template_binder_path: Optional binder PDB passed to ``prep_inputs``;
            ignored if the file does not exist.
        stage_params: Overrides for the ``STAGE*_ITERS`` / ``STAGE1_EXTRA`` counts.
        init_params: ``seed_ratio_start`` / ``seed_ratio_end`` overrides.
        kl_params: ``w_start`` / ``w_end`` overrides for the KL weight schedule.
        guided_fraction: Fraction of guided positions when a mask is generated.
        mut_rate: Mutation rate forwarded to stage 4.
        use_kl_prior: Whether to add the guided KL gradient in stages 1-2.
        bias_stage4: Whether stage 4 biases proposals toward seed residues at
            guided sites.
        guided_mask_override: Explicit 1-D boolean guided mask; ignored (with a
            message) if its shape does not match the binder length.
        rng_seed: Seed for the NumPy generator used for masks, init and stage 4.
        run_tag: Sub-directory name under ``output_root`` for this run.
        output_root: Root output directory (``str`` or ``Path`` accepted).
        baseline_mode: ``"hybrid"``; ``"random"`` (guided fraction and seed
            ratios forced to 0); or ``"seed"`` (all forced to 1).

    Returns:
        Dict with the model, metrics DataFrame, output paths, best binder
        sequence / loss / binder pLDDT, AF call count, guided mask and related
        run information.

    Raises:
        FileNotFoundError: If ``target_pdb_path`` does not exist.
        ValueError: If ``baseline_mode`` is unknown, the seed FASTA is empty,
            or the model reports a non-positive binder length.
        RuntimeError: If no best design was recorded (e.g. every stage ran
            zero iterations), so there is nothing to write out.
    """
    target_pdb_path = Path(target_pdb_path)
    if not target_pdb_path.exists():
        raise FileNotFoundError(f"Target PDB not found: {target_pdb_path}")
    valid_modes = ("hybrid", "random", "seed")
    if baseline_mode not in valid_modes:
        # Previously any other value silently ran as "hybrid".
        raise ValueError(f"Unknown baseline_mode {baseline_mode!r}; expected one of {valid_modes}.")
    template_binder_path = Path(template_binder_path) if template_binder_path else None
    output_root = Path(output_root)

    # Merge caller overrides into fresh dicts so the baseline-mode overrides
    # below never mutate the caller's config objects.
    stage_defaults = {
        "STAGE1_ITERS": 50,
        "STAGE1_EXTRA": 25,
        "STAGE2_ITERS": 45,
        "STAGE3_ITERS": 5,
        "STAGE4_ITERS": 15,
    }
    stage_params = {**stage_defaults, **(stage_params or {})}
    init_params = {
        "seed_ratio_start": (init_params or {}).get("seed_ratio_start", 0.9),
        "seed_ratio_end": (init_params or {}).get("seed_ratio_end", 0.2),
    }
    kl_params = {
        "w_start": (kl_params or {}).get("w_start", 0.1),
        "w_end": (kl_params or {}).get("w_end", 0.0),
    }
    rng = np.random.default_rng(rng_seed)
    af_counter = {"count": 0}
    seed_seq_raw = read_fasta_str(seed_fasta)
    if not seed_seq_raw:
        raise ValueError("Seed FASTA is empty after parsing.")
    if baseline_mode == "random":
        guided_fraction = 0.0
        init_params["seed_ratio_start"] = 0.0
        init_params["seed_ratio_end"] = 0.0
    elif baseline_mode == "seed":
        guided_fraction = 1.0
        init_params["seed_ratio_start"] = 1.0
        init_params["seed_ratio_end"] = 1.0

    # --- Model construction -------------------------------------------------
    clear_mem()
    model = mk_afdesign_model(protocol="binder", data_dir=str(AF_PARAMS_DIR), use_multimer=USE_AF_MULTIMER)
    if MUTACRAFT_MODEL_COUNT and MUTACRAFT_MODEL_COUNT > 0:
        model._model_names = model._model_names[:MUTACRAFT_MODEL_COUNT]
    model.set_opt(num_models=MUTACRAFT_MODEL_COUNT, sample_models=False, num_recycles=MUTACRAFT_NUM_RECYCLES)
    bindcraft_settings = summarize_bindcraft_af_settings(COMPARE_BINDCRAFT_SETTINGS)
    same_models = bindcraft_settings.get("models") == MUTACRAFT_MODEL_COUNT
    same_recycles = bindcraft_settings.get("recycles") == MUTACRAFT_NUM_RECYCLES
    print(f"BindCraft models: {bindcraft_settings.get('models')}, MutaCraft models: {MUTACRAFT_MODEL_COUNT}")
    print(f"BindCraft recycles: {bindcraft_settings.get('recycles')}, MutaCraft recycles: {MUTACRAFT_NUM_RECYCLES}")
    print(f"AF2 settings match: models={same_models}, recycles={same_recycles}")
    prep_kwargs = {"pdb_filename": str(target_pdb_path), "binder_len": binder_len, "chain": TARGET_CHAIN}
    if template_binder_path and template_binder_path.exists():
        prep_kwargs["binder_pdb"] = str(template_binder_path)
    model.prep_inputs(**prep_kwargs)
    target_len = model._target_len
    total_seq_len = int(model._params["seq"].shape[1])
    binder_len_model = int(getattr(model, "_binder_len", total_seq_len))
    if binder_len_model <= 0:
        raise ValueError(f"Invalid binder length ({binder_len_model}). Check TARGET_CHAIN={TARGET_CHAIN} and that {target_pdb_path} contains that chain.")
    print(f"target_len={target_len}, total_seq_len={total_seq_len}, binder_len={binder_len_model}")
    if binder_len_model != binder_len:
        print(f"Adjusted binder length from {binder_len} to {binder_len_model} to match AF2 model inputs.")
    binder_len = binder_len_model
    seed_seq = ensure_seed_length(seed_seq_raw, binder_len, rng)
    # The sequence parameters cover only the binder; aux outputs cover target + binder.
    binder_param_slice = slice(0, binder_len)
    binder_aux_slice = slice(target_len, target_len + binder_len)

    # --- Guided mask and initial logits -------------------------------------
    guided_mask = None
    if guided_mask_override is not None:
        guided_mask = np.array(guided_mask_override, dtype=bool)
        if guided_mask.ndim != 1 or guided_mask.shape[0] != binder_len:
            print(f"Guided mask shape {guided_mask.shape} does not match binder length {binder_len}; falling back to random mask.")
            guided_mask = None
    if guided_mask is None:
        guided_mask = make_guided_mask(binder_len, fraction=guided_fraction, rng=rng)
    logits0, seed_logits = init_mixed_logits(seed_seq, binder_len, guided_mask, seed_ratio=init_params["seed_ratio_start"], rng=rng)
    logits0_np = np.array(np.asarray(logits0), dtype=np.float32)
    print(f"logits0_np shape: {logits0_np.shape}, seed length: {len(seed_seq)}")
    seed_logits = jnp.array(seed_logits)
    seed_idxs = np.array([aa_to_idx[a] for a in seed_seq], dtype=np.int64)
    seq_logits_full = np.array(model._params["seq"], dtype=float)
    seq_logits_full[:, binder_param_slice, :] = logits0_np
    model._params["seq"] = seq_logits_full
    soft_init = np.array(jax.nn.softmax(logits0, axis=-1))
    set_binder_softseq(model, soft_init)
    model.set_opt(soft=1.0, hard=0.0, temp=1.0, dropout=True, num_models=1, sample_models=False)
    model._callbacks["design"]["pre"] = []
    metrics: list[Dict[str, Any]] = []
    best_record: Dict[str, Any] = {"loss": None}
    if use_kl_prior:
        kl_loss_fn = jax.jit(lambda x: guided_seq_kl(x, seed_logits, guided_mask))
        kl_grad_fn = jax.jit(jax.grad(lambda x: guided_seq_kl(x, seed_logits, guided_mask)))
    else:
        kl_loss_fn = None
        kl_grad_fn = None

    def _add_guided_kl(mod, weight: float) -> None:
        """Add ``weight * dKL/dlogits`` to the binder sequence gradient and log KL + weight.

        Called from post-step callbacks; only used when the KL functions exist.
        """
        binder_logits = jnp.array(mod._params["seq"][0, binder_param_slice, :])
        kl_val = kl_loss_fn(binder_logits)
        grad = kl_grad_fn(binder_logits)
        mod.aux["grad"]["seq"][0, binder_param_slice, :] += np.array(weight * grad)
        log = mod.aux.setdefault("log", {})
        log["guided_kl"] = float(kl_val)
        log["kl_weight"] = float(weight)

    # --- Stage 1: continuous logits with seed blending ----------------------
    stage1_total = stage_params["STAGE1_ITERS"] + stage_params["STAGE1_EXTRA"]
    print(f"Stage 1/4 (logits): {stage1_total} iterations")
    # The KL weight schedule spans stage 1 plus the first half of stage 2.
    kl_total = stage1_total + max(0, stage_params["STAGE2_ITERS"] // 2)
    kl_counter = 0
    for local_idx in range(stage1_total):
        print(f"Stage 1 progress {local_idx+1}/{stage1_total}", end="\r", flush=True)
        seed_ratio = guided_seed_ratio_schedule(local_idx, max(1, stage1_total - 1), init_params["seed_ratio_start"], init_params["seed_ratio_end"])
        kl_weight = 0.0
        if use_kl_prior and kl_loss_fn is not None:
            kl_weight = kl_weight_schedule(kl_counter, max(1, kl_total - 1), kl_params["w_start"], kl_params["w_end"])
            kl_counter += 1

        # Per-iteration values are bound as default args so each callback sees
        # this iteration's schedule values, not the loop's latest ones.
        def pre_cb(mod, seed_ratio_val=seed_ratio):
            logits_now = jnp.array(mod._params["seq"][0, binder_param_slice, :])
            blended = blend_logits(logits_now, seed_logits, guided_mask, seed_ratio_val)
            seq_all = np.array(mod._params["seq"], dtype=float)
            seq_all[0, binder_param_slice, :] = np.array(blended)
            mod._params["seq"] = seq_all

        def post_cb(mod, seed_ratio_val=seed_ratio, kl_weight_val=kl_weight):
            if use_kl_prior and kl_loss_fn is not None and kl_weight_val > 0:
                _add_guided_kl(mod, kl_weight_val)
            mod.aux.setdefault("log", {})["guided_seed_ratio"] = float(seed_ratio_val)

        model._callbacks["design"]["pre"] = [pre_cb]
        model.step(callback=post_cb, save_best=True, models=[0], num_models=1, sample_models=False)
        af_counter["count"] += 1
        model._callbacks["design"]["pre"] = []
        metrics.append(collect_metrics(model, stage="stage1", iter_in_stage=local_idx, binder_aux_slice=binder_aux_slice, guided_seed_ratio=seed_ratio, kl_weight=kl_weight))
        best_record = update_best_record(model, best_record, stage="stage1", iter_in_stage=local_idx, binder_slice=binder_param_slice)
    if stage1_total > 0:
        print()  # terminate the "\r" progress line

    # --- Stage 2: softmax temperature anneal ---------------------------------
    print(f"Stage 2/4 (softmax anneal): {stage_params['STAGE2_ITERS']} iterations")
    for local_idx in range(stage_params["STAGE2_ITERS"]):
        print(f"Stage 2 progress {local_idx+1}/{stage_params['STAGE2_ITERS']}", end="\r", flush=True)
        frac = np.clip(local_idx / max(1, stage_params["STAGE2_ITERS"] - 1), 0.0, 1.0)
        temp = 1.0 + (0.05 - 1.0) * frac  # linear 1.0 -> 0.05
        kl_weight = 0.0
        apply_kl = use_kl_prior and kl_loss_fn is not None and local_idx < (stage_params["STAGE2_ITERS"] // 2)
        if apply_kl:
            kl_weight = kl_weight_schedule(kl_counter, max(1, kl_total - 1), kl_params["w_start"], kl_params["w_end"])
            kl_counter += 1

        def post_cb(mod, temp_val=temp, kl_weight_val=kl_weight, apply_kl_val=apply_kl):
            mod.aux.setdefault("log", {})["temp"] = float(temp_val)
            if apply_kl_val and kl_weight_val > 0:
                _add_guided_kl(mod, kl_weight_val)

        model.set_opt(temp=float(temp), num_models=1, sample_models=False)
        model.step(callback=post_cb, save_best=True, models=[0], num_models=1, sample_models=False)
        af_counter["count"] += 1
        metrics.append(collect_metrics(model, stage="stage2", iter_in_stage=local_idx, binder_aux_slice=binder_aux_slice, guided_seed_ratio=None, kl_weight=kl_weight, extra={"temp": float(temp)}))
        best_record = update_best_record(model, best_record, stage="stage2", iter_in_stage=local_idx, binder_slice=binder_param_slice)
    if stage_params["STAGE2_ITERS"] > 0:
        print()

    # --- Stage 3: hard / straight-through updates ----------------------------
    model.set_opt(soft=0.0, hard=1.0, temp=1e-2, dropout=False, num_models=1, sample_models=False)
    print(f"Stage 3/4 (hard/STE): {stage_params['STAGE3_ITERS']} iterations")
    for local_idx in range(stage_params["STAGE3_ITERS"]):
        print(f"Stage 3 progress {local_idx+1}/{stage_params['STAGE3_ITERS']}", end="\r", flush=True)

        def post_cb(mod):
            mod.aux.setdefault("log", {})["temp"] = 1e-2

        model.step(callback=post_cb, save_best=True, models=[0], num_models=1, sample_models=False)
        af_counter["count"] += 1
        metrics.append(collect_metrics(model, stage="stage3", iter_in_stage=local_idx, binder_aux_slice=binder_aux_slice, guided_seed_ratio=None, kl_weight=0.0, extra={"temp": 1e-2}))
        best_record = update_best_record(model, best_record, stage="stage3", iter_in_stage=local_idx, binder_slice=binder_param_slice)
    if stage_params["STAGE3_ITERS"] > 0:
        print()
    # Stage 4 starts from model.aux; make sure one prediction exists.
    if not getattr(model, "aux", None):
        model.predict(return_aux=True, verbose=False, dropout=False, hard=True, soft=False)
        af_counter["count"] += 1

    # --- Stage 4: semi-greedy mutation search --------------------------------
    print(f"Stage 4/4 (semi-greedy): {stage_params['STAGE4_ITERS']} iterations")
    best_record = run_stage4_semigreedy(
        model=model,
        stage_len=stage_params["STAGE4_ITERS"],
        mut_rate=mut_rate,
        guided_mask=np.array(guided_mask, dtype=bool),
        seed_idxs=seed_idxs,
        binder_param_slice=binder_param_slice,
        binder_aux_slice=binder_aux_slice,
        rng=rng,
        bias_stage4=bias_stage4,
        bias_weight=float(STAGE4_SEED_BIAS),
        metrics=metrics,
        best_record=best_record,
        af_counter=af_counter,
    )

    # --- Outputs --------------------------------------------------------------
    run_dir = output_root / run_tag
    run_dir.mkdir(parents=True, exist_ok=True)
    metrics_df = pd.DataFrame(metrics)
    metrics_path = run_dir / "metrics.csv"
    metrics_df.to_csv(metrics_path, index=False)
    if "seq" not in best_record or best_record.get("loss") is None:
        raise RuntimeError(
            f"No best design was recorded for run '{run_tag}' (stage iterations: {stage_params}); "
            f"metrics were written to {metrics_path}."
        )
    best_seq_full = np.array(best_record["seq"], dtype=np.int64)
    binder_idx = best_seq_full[:, binder_param_slice]
    binder_seq = indices_to_seq(binder_idx[0])
    fasta_path = run_dir / "best_binder.fasta"
    with open(fasta_path, "w") as handle:
        handle.write(">designed_binder\n")
        handle.write(binder_seq + "\n")
    pdb_path = run_dir / "best_complex.pdb"
    model.set_seq(seq=best_seq_full, bias=model._inputs.get("bias"), set_state=False)
    best_aux = best_record["aux"]
    model.save_pdb(filename=str(pdb_path), get_best=False, aux=best_aux)
    if "plddt" in best_aux:
        binder_plddt = float(np.nanmean(np.array(best_aux["plddt"], dtype=float)[binder_aux_slice]))
    else:
        binder_plddt = np.nan
    metadata = {
        "run_tag": run_tag,
        "baseline_mode": baseline_mode,
        "rng_seed": int(rng_seed),
        "binder_len": binder_len,
        "af_calls": af_counter["count"],
        "best_loss": float(best_record["loss"]),
        "best_plddt": binder_plddt,
        "stage_params": stage_params,
        "init_params": init_params,
        "kl_params": kl_params,
        "mut_rate": mut_rate,
        # Effective value (after any baseline-mode override).
        "guided_fraction": float(guided_fraction),
        "num_guided_positions": int(np.sum(guided_mask)),
        "use_kl_prior": bool(use_kl_prior),
        "bias_stage4": bool(bias_stage4),
    }
    metadata_path = run_dir / "run_metadata.json"
    with open(metadata_path, "w") as handle:
        json.dump(metadata, handle, indent=2)
    result = {
        "binder_len": binder_len,
        "model": model,
        "metrics_df": metrics_df,
        "metrics_path": metrics_path,
        "fasta_path": fasta_path,
        "best_pdb_path": pdb_path,
        "best_binder_seq": binder_seq,
        "best_loss": float(best_record["loss"]),
        "best_plddt": binder_plddt,
        "af_calls": af_counter["count"],
        "run_tag": run_tag,
        "guided_mask": np.array(guided_mask, dtype=bool),
        "seed_seq": seed_seq,
        "binder_slice": binder_aux_slice,
        "best_record": best_record,
        # prep_kwargs only carries the *input* chain under "chain", so the old
        # prep_kwargs.get("target_chain"/"binder_chain") lookups always fell back
        # to these values. NOTE(review): assumes the saved complex PDB labels
        # target as "A" and binder as "B" — confirm against ColabDesign save_pdb.
        "target_chain": "A",
        "binder_chain": "B",
        "metadata_path": metadata_path,
    }
    return result


In [9]:
# Resolve the seed sequence (plus optional template binder / guided interface
# mask), then run the hybrid design and any requested baselines.
seed_source = ""
seed_interface_mask = None
template_path = None
if WARM_START_ENABLED:
    has_seed = bool(SEED_FASTA.strip()) or (SEED_FASTA_PATH.exists() and SEED_FASTA_PATH.is_file())
    temp_candidate = Path(TEMPLATE_BINDER_PDB_PATH) if str(TEMPLATE_BINDER_PDB_PATH).strip() else None
    has_template = bool(temp_candidate and temp_candidate.exists() and temp_candidate.is_file())
    if not has_seed and not has_template:
        # No seed and no template: warm-start with BindCraft and take its selected design as the seed.
        print("No template binder provided; running BindCraft warm start...")
        run_info = run_warm_start()
        print(f"BindCraft warm start exit code {run_info['returncode']} | elapsed {run_info['elapsed_seconds']:.1f} s")
        if run_info["returncode"] != 0:
            # Previously ignored: a failed run can leave stale or partial designs on
            # disk that warm_start_select_seed would then silently pick up.
            print("Warning: BindCraft warm start exited with a non-zero code; the selected seed may come from stale or incomplete results.")
        warm_info = warm_start_select_seed()
        seed_source = warm_info["seed_sequence"]
        # Interface positions from the warm start (when available) become the guided mask.
        seed_interface_mask = warm_info.get("guided_mask")
    else:
        template_path = temp_candidate if has_template else None
        seed_source = _prepare_seed(SEED_FASTA, SEED_FASTA_PATH, template_binder=template_path, binder_chains=TEMPLATE_BINDER_CHAINS)
else:
    template_path = Path(TEMPLATE_BINDER_PDB_PATH) if str(TEMPLATE_BINDER_PDB_PATH).strip() else None
    seed_source = _prepare_seed(SEED_FASTA, SEED_FASTA_PATH, template_binder=template_path, binder_chains=TEMPLATE_BINDER_CHAINS)

if template_path and (not str(template_path).strip() or not template_path.exists() or not template_path.is_file()):
    template_path = None
# Binder length follows the seed, or the template binder's sequence when one is usable.
binder_len = len(seed_source)
if template_path:
    template_seq = extract_pdb_sequence(template_path, TEMPLATE_BINDER_CHAINS)
    if template_seq:
        binder_len = len(template_seq)
stage_cfg = {
    "STAGE1_ITERS": STAGE1_ITERS,
    "STAGE1_EXTRA": STAGE1_EXTRA,
    "STAGE2_ITERS": STAGE2_ITERS,
    "STAGE3_ITERS": STAGE3_ITERS,
    "STAGE4_ITERS": STAGE4_ITERS,
}
init_cfg = {"seed_ratio_start": GUIDED_SEED_RATIO_START, "seed_ratio_end": GUIDED_SEED_RATIO_END}
kl_cfg = {"w_start": KL_W_START, "w_end": KL_W_END}
run_root = OUTPUT_BASE / RUN_NAME
run_root.mkdir(parents=True, exist_ok=True)

# Settings shared by the hybrid run and every baseline, kept in one place so
# the runs cannot drift apart.
common_run_kwargs = dict(
    target_pdb_path=TARGET_PDB_PATH,
    seed_fasta=seed_source,
    template_binder_path=template_path,
    stage_params=stage_cfg,
    init_params=init_cfg,
    kl_params=kl_cfg,
    guided_fraction=GUIDED_FRACTION,
    mut_rate=MUT_RATE,
    use_kl_prior=USE_KL_PRIOR,
    bias_stage4=BIAS_STAGE4_PROPOSALS,
    output_root=run_root,
)
HYBRID_RESULT = run_hybrid_design(
    binder_len=binder_len,
    guided_mask_override=seed_interface_mask,
    rng_seed=GLOBAL_RNG_SEED,
    run_tag="hybrid",
    baseline_mode="hybrid",
    **common_run_kwargs,
)
# The model may have adjusted the binder length; baselines reuse the adjusted value.
binder_len = HYBRID_RESULT["binder_len"]
print("Hybrid design complete.")
print(f"  Best loss: {HYBRID_RESULT['best_loss']:.4f}")
print(f"  Binder pLDDT (mean): {HYBRID_RESULT['best_plddt']:.3f}")
print(f"  Sequence saved to: {HYBRID_RESULT['fasta_path']}")
print(f"  Complex PDB saved to: {HYBRID_RESULT['best_pdb_path']}")
print(f"  AlphaFold calls: {HYBRID_RESULT['af_calls']}")
BASELINE_RESULTS = []
if RUN_BASELINES:
    modes = []
    if RUN_RANDOM_BASELINE:
        modes.append("random")
    if RUN_SEED_BASELINE:
        modes.append("seed")
    # Each baseline gets its own RNG seed offset; the guided mask is derived
    # inside run_hybrid_design from the mode (fraction 0 or 1).
    for offset, mode in enumerate(modes, start=1):
        result = run_hybrid_design(
            binder_len=binder_len,
            rng_seed=GLOBAL_RNG_SEED + offset,
            run_tag=mode,
            baseline_mode=mode,
            **common_run_kwargs,
        )
        BASELINE_RESULTS.append(result)
        print(f"Baseline '{mode}' complete -> loss {result['best_loss']:.4f}")
        print(f"  AlphaFold calls: {result['af_calls']}")


No template binder provided; running BindCraft warm start...
Available GPUs:
NVIDIA GeForce RTX 4070 Ti1: gpu
┌───────────────────────────────────────────────────────────────────────────────┐
│                                  PyRosetta-4                                  │
│               Created in JHU by Sergey Lyskov and PyRosetta Team              │
│               (C) Copyright Rosetta Commons Member Institutions               │
│                                                                               │
│ NOTE: USE OF PyRosetta FOR COMMERCIAL PURPOSES REQUIRES PURCHASE OF A LICENSE │
│          See LICENSE.PyRosetta.md or email license@uw.edu for details         │
└───────────────────────────────────────────────────────────────────────────────┘
PyRosetta-4 2025 [Rosetta PyRosetta4.conda.ubuntu.cxx11thread.serialization.Ubuntu.python310.Release 2025.41+release.de3cc17d509259e29147a2ed8f2a726d644e7e34 2025-10-06T16:25:46] retrieved from: http://www.pyrosetta.org
Running binder

KeyboardInterrupt: 

In [None]:
import py3Dmol

# Interactive 3D view of the best hybrid complex: the target is drawn as a
# white cartoon and the binder as a rainbow cartoon. Chain IDs come from the
# result dict, falling back to A (target) / B (binder).
pdb_view = py3Dmol.view(width=640, height=480)
best_pdb_file = str(HYBRID_RESULT['best_pdb_path'])
with open(best_pdb_file) as handle:
    pdb_text = handle.read()
    pdb_view.addModel(pdb_text, 'pdb')

binder_chain = HYBRID_RESULT.get('binder_chain', 'B')
target_chain = HYBRID_RESULT.get('target_chain', 'A')

chain_styles = (
    (target_chain, 'white'),
    (binder_chain, 'rainbow'),
)
for chain_id, cartoon_color in chain_styles:
    pdb_view.setStyle({'chain': chain_id}, {'cartoon': {'color': cartoon_color}})

pdb_view.zoomTo()
pdb_view.show()


In [None]:
# One summary row per run (hybrid first, then any baselines) for side-by-side comparison.
def _summary_row(run_name, run_result):
    """Return the summary-table row for a single run_hybrid_design result dict."""
    return {
        'run': run_name,
        'loss': run_result['best_loss'],
        'binder_plddt': run_result['best_plddt'],
        'fasta_path': str(run_result['fasta_path']),
        'pdb_path': str(run_result['best_pdb_path']),
    }


summary_rows = [_summary_row('hybrid', HYBRID_RESULT)]
for res in BASELINE_RESULTS:
    summary_rows.append(_summary_row(res['run_tag'], res))
summary_df = pd.DataFrame(summary_rows)
# Render exactly once. A trailing bare `summary_df` expression after display()
# would make Jupyter show the same table a second time.
display(summary_df)


Unnamed: 0,run,loss,binder_plddt,fasta_path,pdb_path
0,hybrid,4.475759,0.340782,/mnt/e/Code/BindCraft/Results/MutaCraft/warmst...,/mnt/e/Code/BindCraft/Results/MutaCraft/warmst...


Unnamed: 0,run,loss,binder_plddt,fasta_path,pdb_path
0,hybrid,4.475759,0.340782,/mnt/e/Code/BindCraft/Results/MutaCraft/warmst...,/mnt/e/Code/BindCraft/Results/MutaCraft/warmst...
