# MutaCraft: Hybrid Seed + De Novo Binder Hallucination

Run this notebook inside your local WSL installation. It replaces the standard binder initialization with a hybrid strategy that mixes a user-provided seed sequence with random exploration, then runs a four-stage AlphaFold2 optimization (continuous logits → softmax annealing → straight-through estimator (STE) → semi-greedy mutation search).

You can point the notebook at any target structure under `InputTargets`, optionally provide a template binder PDB, and supply the seed as a FASTA string or FASTA file; if neither is given, the template binder's sequence is used as the seed. Outputs are written under `Results/MutaCraft/<RUN_NAME>` so they stay alongside the results of the rest of the tooling.

### Stage Flow
- Stage 1 – Continuous logits with alpha-annealed seed mixing plus optional guided KL prior.
- Stage 2 – Temperature annealing down to a hard distribution while the KL weight tapers out.
- Stage 3 – Straight-through (hard) updates for sharper convergence.
- Stage 4 – Semi-greedy mutation search with optional bias toward the seed amino acids at guided sites.


In [1]:
import os
import copy
import functools
from pathlib import Path
from typing import Any, Dict, Optional, Iterable, Tuple

import jax
import jax.numpy as jnp
from jax import tree_util
import numpy as np
import pandas as pd

# Ensure ColabDesign remains compatible with modern JAX releases.
if not hasattr(jax, "tree_map"):
    jax.tree_map = tree_util.tree_map  # type: ignore[attr-defined]
try:
    import jax.util as jax_util  # type: ignore
except ImportError:  # pragma: no cover - optional shim for newer JAX.
    jax_util = None
if jax_util is not None and not hasattr(jax_util, "wraps"):
    def _jax_util_wraps(wrapped, *, docstr=None, assigned=None, updated=None):
        assigned = ("__module__", "__name__", "__qualname__", "__doc__", "__annotations__")
        updated = ("__dict__",)
        def decorator(func):
            result = functools.wraps(wrapped, assigned=assigned, updated=updated)(func)
            if docstr is not None:
                result.__doc__ = docstr
            return result
        return decorator
    jax_util.wraps = _jax_util_wraps  # type: ignore[attr-defined]

from colabdesign import mk_afdesign_model, clear_mem

AA = "ACDEFGHIKLMNPQRSTVWY"
aa_to_idx = {a: i for i, a in enumerate(AA)}
idx_to_aa = {i: a for a, i in aa_to_idx.items()}

# Project root. Set the BINDCRAFT_ROOT environment variable to run on another machine;
# the previous hard-coded location remains the default.
ROOT_DIR = Path(os.environ.get("BINDCRAFT_ROOT", r"/mnt/e/Code/BindCraft")).resolve()
if not ROOT_DIR.exists():
    raise FileNotFoundError(f"root dir not found: {ROOT_DIR}")


def _first_existing_dir(candidates):
    """Return the first existing path in ``candidates`` (resolved), or None if none exist."""
    return next((cand.resolve() for cand in candidates if cand.exists()), None)


# Directory names differ in case between checkouts, so accept either spelling.
INPUT_DIR_CANDIDATES = [ROOT_DIR / "InputTargets", ROOT_DIR / "inputtargets"]
INPUT_DIR = _first_existing_dir(INPUT_DIR_CANDIDATES)
if INPUT_DIR is None:
    raise FileNotFoundError("No InputTargets directory found. Expected one of: " + ", ".join(str(p) for p in INPUT_DIR_CANDIDATES))

RESULTS_DIR_CANDIDATES = [ROOT_DIR / "Results", ROOT_DIR / "results"]
RESULTS_DIR = _first_existing_dir(RESULTS_DIR_CANDIDATES)
if RESULTS_DIR is None:
    # No results directory yet: create the canonical one.
    RESULTS_DIR = (ROOT_DIR / "Results").resolve()
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)

AF_PARAMS_DIR = (ROOT_DIR / "bindcraft" / "params").resolve()
if not AF_PARAMS_DIR.exists():
    raise FileNotFoundError(f"AlphaFold parameter directory not found: {AF_PARAMS_DIR}")
# Exported for tools that read AF_PARAMS_DIR from the environment; an existing value is kept.
os.environ.setdefault("AF_PARAMS_DIR", str(AF_PARAMS_DIR))
OUTPUT_BASE = (RESULTS_DIR / "MutaCraft").resolve()
OUTPUT_BASE.mkdir(parents=True, exist_ok=True)

GLOBAL_RNG_SEED = 0


  if jax_util is not None and not hasattr(jax_util, "wraps"):


In [2]:
# --- User inputs ---
# Seed sources are checked in this order: inline FASTA text, the FASTA file, then the
# sequence of TEMPLATE_BINDER_CHAINS read from the template binder PDB.
SEED_FASTA = ""  # paste binder FASTA directly here (optional if using SEED_FASTA_PATH).
SEED_FASTA_PATH = INPUT_DIR / "seed.fasta"  # optional FASTA file on disk.

# Structures.
TARGET_PDB_PATH = INPUT_DIR / "HumanLysozyme.pdb"  # receptor or complex PDB.
TARGET_CHAIN = "B"  # chain ID in TARGET_PDB_PATH used for receptor
TEMPLATE_BINDER_PDB_PATH = INPUT_DIR / "HL6_camel_VHH_fragment.pdb"  # optional template binder PDB.
TEMPLATE_BINDER_CHAINS = ["A"]  # chains to pull sequence from when using template binder

# Seed guidance used to initialise and steer Stage 1.
L = None  # binder length (None auto-detects from seed/template).
GUIDED_FRACTION = 0.5  # fraction of binder positions tied to the seed
INIT_ALPHA_START, INIT_ALPHA_END = 0.9, 0.2  # seed-mixing weight, annealed start -> end

# Iterations per stage (Stage 1 runs STAGE1_ITERS + STAGE1_EXTRA steps).
STAGE1_ITERS, STAGE1_EXTRA = 50, 25
STAGE2_ITERS = 45
STAGE3_ITERS = 5
STAGE4_ITERS = 15
MUT_RATE = 0.05  # fraction of binder positions mutated per Stage-4 proposal

# Guided KL prior and Stage-4 seed bias.
KL_W_START, KL_W_END = 0.1, 0.0
USE_KL_PRIOR = True
BIAS_STAGE4_PROPOSALS = True
STAGE4_SEED_BIAS = 0.15  # probability mass reassigned to seed AA when biasing proposals.

# Output / comparison runs.
RUN_NAME = "HL6_VHH_run"  # subdirectory within Results/MutaCraft/
RUN_BASELINES = False
RUN_RANDOM_BASELINE, RUN_SEED_BASELINE = True, True


In [3]:
def read_fasta_str(fasta_str: str) -> str:
    """Parse FASTA text and return the first record's sequence, upper-cased.

    Lines starting with ``>`` are headers. Only the first record is returned:
    parsing stops at the next header, so multi-record input never splices header
    text into the sequence. Text without any header is treated as a bare sequence.
    All whitespace inside sequence lines is removed, and a trailing ``*`` (stop
    symbol) is dropped.

    Args:
        fasta_str: FASTA-formatted text or a bare sequence.

    Returns:
        The upper-cased sequence, or ``""`` if no sequence lines are present.
    """
    seq_chunks = []
    seen_header = False
    for raw_line in fasta_str.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith(">"):
            if seen_header or seq_chunks:
                break  # start of the next record
            seen_header = True
            continue
        seq_chunks.append("".join(line.split()))
    return "".join(seq_chunks).upper().rstrip("*")


def seq_to_onehot(seq: str) -> jnp.ndarray:
    """Convert a sequence string to a one-hot ``(len(seq), 20)`` float array.

    Column ``j`` corresponds to the amino acid with ``aa_to_idx[aa] == j``.
    An empty sequence yields a ``(0, 20)`` array.

    Raises:
        ValueError: if ``seq`` contains characters outside the 20 standard amino acids.
    """
    bad = sorted({aa for aa in seq if aa not in aa_to_idx})
    if bad:
        # Explicit raise (not assert) so the check survives ``python -O`` and names the culprits.
        raise ValueError(f"Sequence contains non-standard amino acids: {''.join(bad)}")
    # Integer dtype keeps the empty case valid (an empty float array cannot be used as an index).
    idxs = jnp.array([aa_to_idx[aa] for aa in seq], dtype=jnp.int32)
    return jax.nn.one_hot(idxs, 20)


def onehot_to_logits(oh: jnp.ndarray, sharp: float = 6.0) -> jnp.ndarray:
    """Scale a one-hot encoding into logits.

    The hot class gets a logit of ``sharp`` and every other class 0, so a larger
    ``sharp`` makes ``softmax`` of the result closer to the original one-hot.
    """
    scaled = oh * sharp
    return scaled


In [4]:
THREE_TO_ONE = {
    'ALA': 'A',
    'ARG': 'R',
    'ASN': 'N',
    'ASP': 'D',
    'CYS': 'C',
    'GLN': 'Q',
    'GLU': 'E',
    'GLY': 'G',
    'HIS': 'H',
    'ILE': 'I',
    'LEU': 'L',
    'LYS': 'K',
    'MET': 'M',
    'PHE': 'F',
    'PRO': 'P',
    'SER': 'S',
    'THR': 'T',
    'TRP': 'W',
    'TYR': 'Y',
    'VAL': 'V',
}

# Common modified / force-field residue names mapped to their parent amino acid.
MODIFIED_THREE_TO_ONE = {
    'MSE': 'M',  # selenomethionine (often written as HETATM)
    'HID': 'H', 'HIE': 'H', 'HIP': 'H',  # Amber histidine protonation states
    'HSD': 'H', 'HSE': 'H', 'HSP': 'H',  # CHARMM histidine protonation states
    'CYX': 'C',  # Amber disulfide-bonded cysteine
    'SEP': 'S', 'TPO': 'T', 'PTR': 'Y',  # phosphorylated residues
}


def extract_pdb_sequence(pdb_path: Path, chains: Optional[Iterable[str]] = None) -> str:
    """Extract a concatenated one-letter sequence from the given chains of a PDB file.

    One letter is emitted per unique ``(chain, resSeq, iCode)``; the first ATOM/HETATM
    record for a residue determines its identity. Residues are ordered by residue
    number, then insertion code, within each chain.

    Residue naming:
      * standard amino acids map through ``THREE_TO_ONE``;
      * common modified residues (e.g. MSE) map through ``MODIFIED_THREE_TO_ONE``;
      * other HETATM residues (waters, ions, ligands) are skipped, so they no longer
        appear as ``'X'`` in the sequence;
      * other ATOM residues are kept as ``'X'``.

    Args:
        pdb_path: Path to a PDB-format file.
        chains: Chain IDs to include, in output order. ``None`` includes every chain,
            sorted by chain ID.

    Returns:
        The concatenated sequence (``""`` if nothing matched).
    """
    def _parse_resseq(val: str) -> tuple[int, str]:
        # Split "  52" / "-3" into (number, non-numeric remainder) for numeric sorting.
        val = val.strip()
        number = []
        suffix = []
        for ch in val:
            if ch.isdigit() or (ch == '-' and not number):
                number.append(ch)
            else:
                suffix.append(ch)
        num_val = int(''.join(number)) if number else 0
        return num_val, ''.join(suffix)

    wanted = None if chains is None else list(chains)
    residues_by_chain: Dict[str, Dict[tuple, str]] = {}
    with Path(pdb_path).open() as handle:
        for line in handle:
            if not line.startswith(('ATOM', 'HETATM')):
                continue
            if len(line) < 27:
                continue  # too short to hold the chain / resSeq / iCode columns
            chain_id = line[21] if line[21].strip() else ' '
            if wanted is not None and chain_id not in wanted:
                continue
            resname = line[17:20].strip().upper()
            if resname in THREE_TO_ONE:
                aa = THREE_TO_ONE[resname]
            elif resname in MODIFIED_THREE_TO_ONE:
                aa = MODIFIED_THREE_TO_ONE[resname]
            elif line.startswith('HETATM'):
                continue  # not part of the polypeptide
            else:
                aa = 'X'
            key = (chain_id, line[22:26], line[26])
            # First record per residue wins (later atoms / altlocs are ignored).
            residues_by_chain.setdefault(chain_id, {}).setdefault(key, aa)

    chain_order = wanted if wanted is not None else sorted(residues_by_chain)
    sequence_parts = []
    for chain_id in chain_order:
        residues = residues_by_chain.get(chain_id, {})
        for key in sorted(residues, key=lambda k: (*_parse_resseq(k[1]), k[2])):
            sequence_parts.append(residues[key])
    return ''.join(sequence_parts)


In [5]:
def make_guided_mask(L_local: int, fraction: float = 0.5, rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """Create a boolean mask marking a random subset of positions as seed-guided.

    Args:
        L_local: Number of positions.
        fraction: Fraction of positions to mark. Values outside ``[0, 1]`` are clamped.
        rng: Optional parent generator. A child generator is seeded from one draw of
            it, so the parent stream advances by one ``integers`` call. ``None``
            uses fresh OS entropy.

    Returns:
        Boolean array of shape ``(L_local,)`` with ``round(fraction * L_local)`` True entries.
    """
    rng = np.random.default_rng(None if rng is None else rng.integers(1 << 32))
    idx = np.arange(L_local)
    rng.shuffle(idx)
    # Clamp so a negative fraction cannot become a negative slice bound (which would
    # select almost every position) and the count never exceeds L_local.
    k = int(round(float(np.clip(fraction, 0.0, 1.0)) * L_local))
    guided = np.zeros(L_local, dtype=bool)
    guided[idx[:k]] = True
    return guided


def alpha_schedule(t: int, t_max: int, a0: float = 0.9, a1: float = 0.2) -> float:
    """Interpolate linearly from ``a0`` at ``t == 0`` to ``a1`` at ``t >= t_max``.

    ``t_max`` is floored at 1 to avoid dividing by zero, and progress is clamped
    to ``[0, 1]``, so the value saturates at the endpoints.
    """
    span = t_max if t_max > 1 else 1
    progress = min(max(t / span, 0.0), 1.0)
    return float(a0 + (a1 - a0) * progress)


def kl_weight_schedule(t: int, t_max: int, w0: float = 0.1, w1: float = 0.0) -> float:
    """Anneal the KL weight linearly from ``w0`` (``t == 0``) to ``w1`` (``t >= t_max``).

    Progress ``t / t_max`` uses ``t_max`` floored at 1 and is clamped to ``[0, 1]``.
    """
    denom = t_max if t_max >= 1 else 1
    progress = float(np.clip(t / denom, 0.0, 1.0))
    return float(w0 + (w1 - w0) * progress)


In [6]:
def ensure_seed_length(seed_seq: str, L_local: int, rng: np.random.Generator) -> str:
    """Fit a seed sequence to exactly ``L_local`` residues.

    The seed is stripped and upper-cased. Longer seeds are truncated; shorter ones
    are extended with residues drawn uniformly from ``AA`` using ``rng``.
    """
    cleaned = seed_seq.strip().upper()
    missing = L_local - len(cleaned)
    if missing <= 0:
        return cleaned[:L_local]
    filler = "".join(rng.choice(list(AA), size=missing))
    return (cleaned + filler)[:L_local]


def init_mixed_logits(seed_seq: str, L_local: int, guided_mask: np.ndarray, alpha: float, rng: Optional[np.random.Generator] = None, rand_scale: float = 1.0) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Build initial binder logits by blending seed logits with Gaussian noise.

    At guided positions the result is ``alpha * seed + (1 - alpha) * noise``. At
    unguided positions it is pure noise. Seed logits are the one-hot seed scaled by 6.

    Args:
        seed_seq: Seed sequence; only its first ``L_local`` residues are used.
        L_local: Binder length.
        guided_mask: Per-position boolean mask of guided sites.
        alpha: Seed weight at guided sites.
        rng: Optional parent generator; one ``integers`` draw seeds the noise generator.
        rand_scale: Standard deviation of the noise.

    Returns:
        ``(initial_logits, seed_logits)``, both shaped ``(L_local, 20)``.
    """
    assert len(seed_seq) >= L_local, "Seed sequence must be at least L residues long."
    seed_logits = onehot_to_logits(seq_to_onehot(seed_seq[:L_local]), sharp=6.0)

    noise_rng = np.random.default_rng(None if rng is None else rng.integers(1 << 32))
    noise = jnp.array(noise_rng.normal(0.0, rand_scale, size=(L_local, 20)))

    weight = jnp.array(guided_mask).reshape(-1, 1).astype(jnp.float32) * alpha
    initial_logits = weight * seed_logits + (1.0 - weight) * noise
    return initial_logits, seed_logits


def blend_logits(current_logits: jnp.ndarray, seed_logits: jnp.ndarray, guided_mask: np.ndarray, alpha: float) -> jnp.ndarray:
    """Pull guided positions toward the seed logits; leave the others untouched.

    Guided rows become ``alpha * seed + (1 - alpha) * current``. Unguided rows
    keep ``current``.
    """
    weight = jnp.array(guided_mask).reshape(-1, 1).astype(jnp.float32)
    mixed = alpha * seed_logits + (1.0 - alpha) * current_logits
    return weight * mixed + (1.0 - weight) * current_logits


def indices_to_seq(idxs: np.ndarray) -> str:
    """Map each integer index in ``idxs`` to its one-letter amino acid via ``idx_to_aa``."""
    letters = [idx_to_aa[int(k)] for k in idxs]
    return "".join(letters)


In [7]:
def set_binder_softseq(model, softseq: np.ndarray) -> None:
    """Write the log of a soft sequence into the binder rows of the model's seq logits.

    ``softseq`` is a probability array with one row per binder position. It is
    clipped to ``[1e-8, 1]`` before the log. Rows
    ``model._target_len : model._target_len + len(softseq)`` of every leading
    (batch) entry of ``model._params["seq"]`` are overwritten. The updated array
    is also stored as ``model._tmp["seq_logits"]``, and ``_tmp`` is created if missing.

    NOTE(review): this assumes ``_params["seq"]`` spans target + binder with the
    binder starting at ``_target_len``; confirm against the ColabDesign binder
    protocol in use.
    """
    n_rows = softseq.shape[0]
    start = model._target_len
    seq_params = np.array(model._params["seq"])
    seq_params[:, start:start + n_rows, :] = np.log(np.clip(softseq, 1e-8, 1.0))
    model._params["seq"] = seq_params
    if not hasattr(model, "_tmp"):
        model._tmp = {}
    model._tmp["seq_logits"] = seq_params


def guided_seq_kl(logits: jnp.ndarray, seed_logits: jnp.ndarray, guided_mask: np.ndarray) -> jnp.ndarray:
    """Sum of per-position KL(p || q) over guided sites.

    ``p = softmax(logits)`` is the current distribution and ``q = softmax(seed_logits)``
    the seed distribution. Both are clipped to ``[1e-8, 1]`` before taking logs.
    Unguided rows contribute zero. Returns a scalar array, so the function can be
    jitted and differentiated.
    """
    def _clipped_probs(x):
        return jnp.clip(jax.nn.softmax(x, axis=-1), 1e-8, 1.0)

    p_cur = _clipped_probs(logits)
    p_seed = _clipped_probs(seed_logits)
    weights = jnp.array(guided_mask).reshape(-1, 1).astype(jnp.float32)
    per_entry = p_cur * (jnp.log(p_cur) - jnp.log(p_seed))
    return jnp.sum(weights * per_entry)


def collect_metrics(model, stage: str, iter_in_stage: int, binder_slice: slice, alpha: Optional[float], kl_weight: float, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Flatten one iteration's model outputs into a metrics dict (one CSV row).

    Reads ``model.aux`` (missing or empty is treated as ``{}``). Records:
    stage info; ``model._k`` as ``global_iter``; total loss (NaN if absent); each
    ``aux["losses"]`` entry as ``loss_<name>``; binder pLDDT mean/min/max over
    ``binder_slice``; target pLDDT mean over positions before the binder (when the
    binder does not start at 0); and ``ptm``/``i_ptm``/``pae``/``i_pae``, with
    arrays reduced by ``nanmean``. ``extra`` keys are merged in last and may
    override earlier keys.
    """
    aux = getattr(model, "aux", {}) or {}
    entry: Dict[str, Any] = {
        "stage": stage,
        "iter_in_stage": int(iter_in_stage),
        "global_iter": int(getattr(model, "_k", 0)),
        "alpha": float(alpha) if alpha is not None else None,
        "kl_weight": float(kl_weight),
        "loss_total": float(aux.get("loss", np.nan)),
    }
    entry.update({f"loss_{name}": float(val) for name, val in aux.get("losses", {}).items()})

    if "plddt" in aux:
        plddt_all = np.array(aux["plddt"], dtype=float)
        binder_part = plddt_all[binder_slice]
        entry["binder_plddt_mean"] = float(np.nanmean(binder_part))
        entry["binder_plddt_min"] = float(np.nanmin(binder_part))
        entry["binder_plddt_max"] = float(np.nanmax(binder_part))
        if binder_slice.start > 0:
            entry["target_plddt_mean"] = float(np.nanmean(plddt_all[:binder_slice.start]))

    for key in ("ptm", "i_ptm", "pae", "i_pae"):
        if key not in aux:
            continue
        raw = aux[key]
        if isinstance(raw, (np.ndarray, jnp.ndarray)):
            raw = np.nanmean(np.array(raw, dtype=float))
        entry[key] = float(raw)

    if extra:
        entry.update(extra)
    return entry


def update_best_record(model, best: Dict[str, Any], stage: str, iter_in_stage: int, binder_slice: slice) -> Dict[str, Any]:
    """Track the best (lowest-loss) design encountered so far.

    The current loss is ``model.aux["loss"]``, or ``inf`` if absent. Rules:
      * the first call (``best["loss"] is None``) always records the current state,
        so downstream code always has a sequence;
      * a NaN current loss never replaces an existing record;
      * a NaN stored loss is always replaced by a non-NaN one. Before this rule,
        ``x < nan`` was always False and the record froze forever.

    The stored sequence is the argmax of ``aux["seq"]["logits"]`` when available,
    otherwise of ``model._params["seq"]``.

    Args:
        model: Object exposing ``aux`` (dict), optionally ``_params`` and ``_k``.
        best: Previous record (``{"loss": None}`` initially).
        stage: Stage label stored in the record.
        iter_in_stage: Iteration index within the stage.
        binder_slice: Unused; kept for interface compatibility.

    Returns:
        A new record dict when improved, otherwise the original ``best`` object.
    """
    aux = getattr(model, "aux", {}) or {}
    current_loss = float(aux.get("loss", np.inf))
    previous_loss = best.get("loss")
    if previous_loss is None:
        improved = True
    elif np.isnan(current_loss):
        improved = False
    else:
        improved = bool(np.isnan(previous_loss)) or current_loss < previous_loss
    if not improved:
        return best

    if "seq" in aux and "logits" in aux["seq"]:
        seq_logits = np.array(aux["seq"]["logits"])
    else:
        seq_logits = np.array(model._params["seq"])
    return {
        "loss": current_loss,
        "seq": np.array(seq_logits.argmax(-1), dtype=np.int64),
        "aux": copy.deepcopy(aux),
        "stage": stage,
        "iter_in_stage": int(iter_in_stage),
        "global_iter": int(getattr(model, "_k", 0)),
    }


In [8]:
def _softmax_np(logits: np.ndarray) -> np.ndarray:
    logits = np.asarray(logits, dtype=float)
    logits = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(logits)
    return exp / np.clip(exp.sum(axis=-1, keepdims=True), 1e-8, None)
def _prepare_seed(seed_input: str, seed_path: Path, template_binder: Optional[Path] = None, binder_chains: Optional[Iterable[str]] = None) -> str:
    """Resolve the seed sequence from the first source that yields one.

    Sources are tried in order:
      1. inline FASTA text ``seed_input``;
      2. the FASTA file at ``seed_path``;
      3. the sequence of ``binder_chains`` in the ``template_binder`` PDB.

    A source that parses to an empty sequence (e.g. a header-only string or an
    empty file) falls through to the next one. Paths may be ``str`` or ``Path``;
    directories are ignored.

    Raises:
        ValueError: if no source yields a sequence.
    """
    if seed_input and seed_input.strip():
        seq = read_fasta_str(seed_input)
        if seq:
            return seq
    if seed_path and Path(seed_path).is_file():
        seq = read_fasta_str(Path(seed_path).read_text())
        if seq:
            return seq
    if template_binder and Path(template_binder).is_file():
        seq = extract_pdb_sequence(Path(template_binder), binder_chains)
        if seq:
            return seq
    raise ValueError("Provide SEED_FASTA, a FASTA file at SEED_FASTA_PATH, or a valid template binder PDB.")
def run_stage4_semigreedy(model, stage_len: int, mut_rate: float, guided_mask: np.ndarray, seed_idxs: np.ndarray, binder_slice: slice, rng: np.random.Generator, bias_stage4: bool, bias_weight: float, metrics: list[Dict[str, Any]], best_record: Dict[str, Any]) -> Dict[str, Any]:
    """Semi-greedy refinement with optional seed-aware proposal bias.

    Each of ``stage_len`` iterations:
      1. reads the accepted state (sequence = argmax of the current logits, plus its loss);
      2. samples ``tries_per_iter`` mutants. Each mutates ``round(mut_rate * binder_len)``
         distinct binder positions, chosen with weight ``1 - pLDDT``. Replacement amino
         acids are drawn from the softmax of the current logits, excluding the current
         residue. At guided sites, when ``bias_stage4`` is set, exactly ``bias_weight``
         of the probability mass is moved onto the seed residue;
      3. scores every mutant with ``model.predict``. The best finite-loss mutant is
         accepted if it does not increase the loss (or if the current loss is not finite).
         Otherwise the previous state, including ``model.aux``, is restored.

    Each iteration appends a row to ``metrics`` (mutated in place) and increments ``model._k``.

    Returns:
        The updated best record.
    """
    binder_len = binder_slice.stop - binder_slice.start
    # Never request more distinct positions than exist (rng.choice(replace=False) would raise).
    num_mutations = min(binder_len, max(1, int(round(mut_rate * binder_len))))
    tries_per_iter = max(4, num_mutations * 2)
    bias_weight = float(np.clip(bias_weight, 0.0, 1.0))

    def _proposal_probs(pos_logits, current_aa, seed_aa, apply_bias):
        """Distribution over residues for one mutated position, excluding the current residue."""
        probs = _softmax_np(pos_logits)
        probs[current_aa] = 0.0
        total = probs.sum()
        if not np.isfinite(total) or total <= 0.0:
            # Logits so sharp that every alternative underflowed: fall back to uniform.
            probs = np.ones_like(probs)
            probs[current_aa] = 0.0
            total = probs.sum()
        probs = probs / total
        if apply_bias:
            # Normalised first, so exactly `bias_weight` of the mass moves to the seed residue.
            probs = (1.0 - bias_weight) * probs
            probs[seed_aa] += bias_weight
        return probs / probs.sum()

    for local_idx in range(stage_len):
        # Snapshot the accepted state: model.predict() below overwrites model.aux with each
        # candidate's outputs, so a rejected round must restore it. Otherwise the next
        # iteration would start from (and compare against) the last rejected candidate.
        accepted_aux = copy.deepcopy(getattr(model, "aux", {}) or {})
        if "seq" in accepted_aux and "logits" in accepted_aux["seq"]:
            logits_full = np.array(accepted_aux["seq"]["logits"], dtype=float)
        else:
            logits_full = np.array(model._params["seq"], dtype=float)
        current_seq = logits_full.argmax(-1)
        binder_logits = logits_full[0, binder_slice, :]
        binder_seq = current_seq[:, binder_slice].copy()
        current_loss = float(accepted_aux.get("loss", np.inf))

        # Mutate low-confidence residues more often; non-finite weights become neutral (1.0).
        if "plddt" in accepted_aux:
            binder_plddt = np.array(accepted_aux["plddt"], dtype=float)[binder_slice]
            pos_weights = 1.0 - np.clip(binder_plddt, 0.0, 1.0)
        else:
            pos_weights = np.ones(binder_len, dtype=float)
        pos_weights = np.where(np.isfinite(pos_weights), pos_weights, 1.0)
        pos_weights = np.clip(pos_weights, 1e-3, None)
        pos_weights = pos_weights / pos_weights.sum()

        proposals = []
        for _ in range(tries_per_iter):
            mutate_positions = rng.choice(binder_len, size=num_mutations, replace=False, p=pos_weights)
            mutant = binder_seq.copy()
            for pos in mutate_positions:
                apply_bias = bool(bias_stage4 and guided_mask[pos])
                seed_aa = int(seed_idxs[pos]) if apply_bias else -1
                probs = _proposal_probs(binder_logits[pos], int(mutant[0, pos]), seed_aa, apply_bias)
                mutant[0, pos] = rng.choice(len(probs), p=probs)
            candidate = current_seq.copy()
            candidate[:, binder_slice] = mutant
            aux_candidate = model.predict(seq=candidate, return_aux=True, verbose=False, dropout=False, hard=True, soft=False, temp=1e-2, models=[0], num_models=1, sample_models=False)
            proposals.append((float(aux_candidate["loss"]), candidate.copy(), copy.deepcopy(aux_candidate)))

        # Ignore NaN losses when picking the best (a NaN would otherwise win `min` if first).
        finite = [prop for prop in proposals if np.isfinite(prop[0])]
        best_loss, best_seq_full, best_aux = min(finite or proposals, key=lambda prop: prop[0])
        if best_loss <= current_loss or not np.isfinite(current_loss):
            model.set_seq(seq=best_seq_full, bias=model._inputs.get("bias"), set_state=False)
            model.aux = best_aux
        else:
            model.aux = accepted_aux

        metrics.append(collect_metrics(model, stage="stage4", iter_in_stage=local_idx, binder_slice=binder_slice, alpha=None, kl_weight=0.0, extra={"temp": 1e-2, "accepted_loss": float(model.aux.get("loss", np.nan))}))
        best_record = update_best_record(model, best_record, stage="stage4", iter_in_stage=local_idx, binder_slice=binder_slice)
        model._k = int(getattr(model, "_k", 0)) + 1
    return best_record
def _pdb_chain_order(pdb_path: Path) -> list[str]:
    """Return the chain IDs of ATOM/HETATM records in a PDB file, in order of first appearance."""
    order: list[str] = []
    with Path(pdb_path).open() as handle:
        for line in handle:
            if line.startswith(("ATOM", "HETATM")) and len(line) > 21 and line[21] not in order:
                order.append(line[21])
    return order


def run_hybrid_design(
    target_pdb_path: Path,
    binder_len: int,
    seed_fasta: str,
    template_binder_path: Optional[Path] = None,
    stage_params: Optional[Dict[str, int]] = None,
    init_params: Optional[Dict[str, float]] = None,
    kl_params: Optional[Dict[str, float]] = None,
    guided_fraction: float = 0.5,
    mut_rate: float = 0.05,
    use_kl_prior: bool = True,
    bias_stage4: bool = True,
    rng_seed: int = 0,
    run_tag: str = "hybrid",
    output_root: Path = OUTPUT_BASE,
    baseline_mode: str = "hybrid",
    target_chain: Optional[str] = None,
    stage4_seed_bias: Optional[float] = None,
) -> Dict[str, Any]:
    """Execute the hybrid AF2 binder hallucination workflow.

    Builds a ColabDesign binder model and runs four stages:
      1. seed-blended logit optimisation (alpha annealed, optional guided KL);
      2. temperature annealing (KL only in its first half);
      3. hard/straight-through steps;
      4. semi-greedy mutation search.
    It then writes ``metrics.csv``, ``best_binder.fasta`` and ``best_complex.pdb``
    under ``output_root / run_tag``.

    Args:
        target_pdb_path: Receptor PDB passed to ``prep_inputs``.
        binder_len: Requested binder length; replaced by the length derived from the
            model's inputs if they differ (a message is printed).
        seed_fasta: Seed FASTA text or bare sequence.
        template_binder_path: Optional template binder PDB.
        stage_params: Overrides for the ``STAGE*_ITERS`` / ``STAGE1_EXTRA`` counts.
        init_params: ``alpha_start`` / ``alpha_end`` overrides.
        kl_params: ``w_start`` / ``w_end`` overrides.
        guided_fraction: Fraction of binder positions guided by the seed.
        mut_rate: Stage-4 mutation rate.
        use_kl_prior: Enable the guided KL gradient in Stages 1-2.
        bias_stage4: Bias Stage-4 proposals toward the seed at guided sites.
        rng_seed: Seed for the NumPy generator.
        run_tag: Output subdirectory name.
        output_root: Parent directory of the output subdirectory.
        baseline_mode: ``"hybrid"``; ``"random"`` (guided fraction and alpha forced to 0);
            or ``"seed"`` (both forced to 1).
        target_chain: Receptor chain(s) for ``prep_inputs``; defaults to ``TARGET_CHAIN``.
        stage4_seed_bias: Stage-4 seed-bias mass; defaults to ``STAGE4_SEED_BIAS``.

    Returns:
        Dict with the model, metrics DataFrame, output paths, best sequence/loss/pLDDT,
        guided mask, seed, binder slice, and best record. ``target_chain`` and
        ``binder_chain`` are the chain IDs found in the saved complex PDB (first/last
        chain). ``input_target_chain`` is the chain that was read from the input PDB.

    Raises:
        FileNotFoundError: if the target PDB does not exist.
        ValueError: if the seed FASTA is empty after parsing.
    """
    target_pdb_path = Path(target_pdb_path)
    if not target_pdb_path.exists():
        raise FileNotFoundError(f"Target PDB not found: {target_pdb_path}")
    template_binder_path = Path(template_binder_path) if template_binder_path else None
    # Newer keyword arguments fall back to the notebook-level configuration.
    target_chain = TARGET_CHAIN if target_chain is None else target_chain
    stage4_seed_bias = STAGE4_SEED_BIAS if stage4_seed_bias is None else stage4_seed_bias
    stage_defaults = {
        "STAGE1_ITERS": 50,
        "STAGE1_EXTRA": 25,
        "STAGE2_ITERS": 45,
        "STAGE3_ITERS": 5,
        "STAGE4_ITERS": 15,
    }
    if stage_params:
        stage_defaults.update(stage_params)
    stage_params = stage_defaults
    init_defaults = {"alpha_start": 0.9, "alpha_end": 0.2}
    if init_params:
        init_defaults.update({"alpha_start": init_params.get("alpha_start", 0.9), "alpha_end": init_params.get("alpha_end", 0.2)})
    init_params = init_defaults
    kl_defaults = {"w_start": 0.1, "w_end": 0.0}
    if kl_params:
        kl_defaults.update({"w_start": kl_params.get("w_start", 0.1), "w_end": kl_params.get("w_end", 0.0)})
    kl_params = kl_defaults
    rng = np.random.default_rng(rng_seed)
    seed_seq_raw = read_fasta_str(seed_fasta)
    if not seed_seq_raw:
        raise ValueError("Seed FASTA is empty after parsing.")
    if baseline_mode == "random":
        guided_fraction = 0.0
        init_params["alpha_start"] = 0.0
        init_params["alpha_end"] = 0.0
    elif baseline_mode == "seed":
        guided_fraction = 1.0
        init_params["alpha_start"] = 1.0
        init_params["alpha_end"] = 1.0
    clear_mem()
    model = mk_afdesign_model(protocol="binder", data_dir=str(AF_PARAMS_DIR))
    model._model_names = model._model_names[:1] or ["model_1_multimer_v3"]
    model.set_opt(num_models=1, sample_models=False)
    prep_kwargs = {"pdb_filename": str(target_pdb_path), "binder_len": binder_len, "target_chain": target_chain}
    if template_binder_path and template_binder_path.exists():
        # NOTE(review): confirm prep_inputs() actually consumes "binder_pdb"; if it is
        # silently ignored, the template only contributes its sequence as the seed.
        prep_kwargs["binder_pdb"] = str(template_binder_path)
    model.prep_inputs(**prep_kwargs)
    binder_start = model._target_len
    total_seq_len = int(model._params["seq"].shape[1])
    # NOTE(review): the binder_slice used throughout assumes _params["seq"] spans target +
    # binder, with the binder starting at _target_len. If the binder protocol's seq params
    # cover only the binder, total_seq_len equals the requested binder_len (the logged run
    # printed total_seq_len=229 for a requested length of 229). In that case every slice
    # below is offset by _target_len. Confirm before trusting results.
    binder_len_model = total_seq_len - binder_start
    print(f"target_len={binder_start}, total_seq_len={total_seq_len}, binder_slots={binder_len_model}")
    if binder_len_model != binder_len:
        print(f"Adjusted binder length from {binder_len} to {binder_len_model} to match AF2 model inputs.")
    binder_len = binder_len_model
    seed_seq = ensure_seed_length(seed_seq_raw, binder_len, rng)
    guided_mask = make_guided_mask(binder_len, fraction=guided_fraction, rng=rng)
    logits0, seed_logits = init_mixed_logits(seed_seq, binder_len, guided_mask, alpha=init_params["alpha_start"], rng=rng)
    logits0_np = np.array(np.asarray(logits0), dtype=np.float32)
    print(f"logits0_np shape: {logits0_np.shape}, seed length: {len(seed_seq)}")
    seed_logits = jnp.array(seed_logits)
    seed_idxs = np.array([aa_to_idx[a] for a in seed_seq], dtype=np.int64)
    binder_slice = slice(binder_start, binder_start + binder_len)
    seq_logits_full = np.array(model._params["seq"], dtype=float)
    seq_logits_full[:, binder_slice, :] = logits0_np
    model._params["seq"] = seq_logits_full
    soft_init = np.array(jax.nn.softmax(logits0, axis=-1))
    set_binder_softseq(model, soft_init)
    model.set_opt(soft=1.0, hard=0.0, temp=1.0, dropout=True, num_models=1, sample_models=False)
    model._callbacks["design"]["pre"] = []
    metrics: list[Dict[str, Any]] = []
    best_record: Dict[str, Any] = {"loss": None}
    if use_kl_prior:
        kl_loss_fn = jax.jit(lambda x: guided_seq_kl(x, seed_logits, guided_mask))
        kl_grad_fn = jax.jit(jax.grad(lambda x: guided_seq_kl(x, seed_logits, guided_mask)))
    else:
        kl_loss_fn = None
        kl_grad_fn = None
    stage1_total = stage_params["STAGE1_ITERS"] + stage_params["STAGE1_EXTRA"]
    print(f"Stage 1/4 (logits): {stage1_total} iterations")
    # The KL weight anneals across Stage 1 and the first half of Stage 2 on one shared counter.
    kl_total = stage1_total + max(0, stage_params["STAGE2_ITERS"] // 2)
    kl_counter = 0
    for local_idx in range(stage1_total):
        print(f"Stage 1 progress {local_idx+1}/{stage1_total}", end="\r", flush=True)
        alpha = alpha_schedule(local_idx, max(1, stage1_total - 1), init_params["alpha_start"], init_params["alpha_end"])
        kl_weight = 0.0
        if use_kl_prior and kl_loss_fn is not None:
            kl_weight = kl_weight_schedule(kl_counter, max(1, kl_total - 1), kl_params["w_start"], kl_params["w_end"])
            kl_counter += 1
        # Default arguments bind this iteration's alpha / kl_weight into the callbacks.
        def pre_cb(mod, alpha_val=alpha):
            logits_now = jnp.array(mod._params["seq"][0, binder_slice, :])
            blended = blend_logits(logits_now, seed_logits, guided_mask, alpha_val)
            seq_all = np.array(mod._params["seq"], dtype=float)
            seq_all[0, binder_slice, :] = np.array(blended)
            mod._params["seq"] = seq_all
        def post_cb(mod, alpha_val=alpha, kl_weight_val=kl_weight):
            if use_kl_prior and kl_loss_fn is not None and kl_weight_val > 0:
                binder_logits = jnp.array(mod._params["seq"][0, binder_slice, :])
                kl_val = kl_loss_fn(binder_logits)
                grad = kl_grad_fn(binder_logits)
                mod.aux["grad"]["seq"][0, binder_slice, :] += np.array(kl_weight_val * grad)
                mod.aux.setdefault("log", {})["guided_kl"] = float(kl_val)
                mod.aux["log"]["kl_weight"] = float(kl_weight_val)
            mod.aux.setdefault("log", {})["alpha"] = float(alpha_val)
        model._callbacks["design"]["pre"] = [pre_cb]
        model.step(callback=post_cb, save_best=True, models=[0], num_models=1, sample_models=False)
        model._callbacks["design"]["pre"] = []
        metrics.append(collect_metrics(model, stage="stage1", iter_in_stage=local_idx, binder_slice=binder_slice, alpha=alpha, kl_weight=kl_weight))
        best_record = update_best_record(model, best_record, stage="stage1", iter_in_stage=local_idx, binder_slice=binder_slice)
    print(f"Stage 2/4 (softmax anneal): {stage_params['STAGE2_ITERS']} iterations")
    for local_idx in range(stage_params["STAGE2_ITERS"]):
        print(f"Stage 2 progress {local_idx+1}/{stage_params['STAGE2_ITERS']}", end="\r", flush=True)
        temp = alpha_schedule(local_idx, max(1, stage_params["STAGE2_ITERS"] - 1), 1.0, 0.05)
        kl_weight = 0.0
        apply_kl = use_kl_prior and kl_loss_fn is not None and local_idx < (stage_params["STAGE2_ITERS"] // 2)
        if apply_kl:
            kl_weight = kl_weight_schedule(kl_counter, max(1, kl_total - 1), kl_params["w_start"], kl_params["w_end"])
            kl_counter += 1
        def post_cb(mod, temp_val=temp, kl_weight_val=kl_weight):
            mod.aux.setdefault("log", {})["temp"] = float(temp_val)
            if apply_kl and kl_weight_val > 0:
                binder_logits = jnp.array(mod._params["seq"][0, binder_slice, :])
                kl_val = kl_loss_fn(binder_logits)
                grad = kl_grad_fn(binder_logits)
                mod.aux["grad"]["seq"][0, binder_slice, :] += np.array(kl_weight_val * grad)
                mod.aux["log"]["guided_kl"] = float(kl_val)
                mod.aux["log"]["kl_weight"] = float(kl_weight_val)
        model.set_opt(temp=float(temp), num_models=1, sample_models=False)
        model.step(callback=post_cb, save_best=True, models=[0], num_models=1, sample_models=False)
        metrics.append(collect_metrics(model, stage="stage2", iter_in_stage=local_idx, binder_slice=binder_slice, alpha=None, kl_weight=kl_weight, extra={"temp": float(temp)}))
        best_record = update_best_record(model, best_record, stage="stage2", iter_in_stage=local_idx, binder_slice=binder_slice)
    model.set_opt(soft=0.0, hard=1.0, temp=1e-2, dropout=False, num_models=1, sample_models=False)
    print(f"Stage 3/4 (hard/STE): {stage_params['STAGE3_ITERS']} iterations")
    for local_idx in range(stage_params["STAGE3_ITERS"]):
        print(f"Stage 3 progress {local_idx+1}/{stage_params['STAGE3_ITERS']}", end="\r", flush=True)
        def post_cb(mod):
            mod.aux.setdefault("log", {})["temp"] = 1e-2
        model.step(callback=post_cb, save_best=True, models=[0], num_models=1, sample_models=False)
        metrics.append(collect_metrics(model, stage="stage3", iter_in_stage=local_idx, binder_slice=binder_slice, alpha=None, kl_weight=0.0, extra={"temp": 1e-2}))
        best_record = update_best_record(model, best_record, stage="stage3", iter_in_stage=local_idx, binder_slice=binder_slice)
    if not getattr(model, "aux", None):
        model.predict(return_aux=True, verbose=False, dropout=False, hard=True, soft=False)
    print(f"Stage 4/4 (semi-greedy): {stage_params['STAGE4_ITERS']} iterations")
    best_record = run_stage4_semigreedy(
        model=model,
        stage_len=stage_params["STAGE4_ITERS"],
        mut_rate=mut_rate,
        guided_mask=np.array(guided_mask, dtype=bool),
        seed_idxs=seed_idxs,
        binder_slice=binder_slice,
        rng=rng,
        bias_stage4=bias_stage4,
        bias_weight=float(stage4_seed_bias),
        metrics=metrics,
        best_record=best_record,
    )
    run_dir = output_root / run_tag
    run_dir.mkdir(parents=True, exist_ok=True)
    metrics_df = pd.DataFrame(metrics)
    metrics_path = run_dir / "metrics.csv"
    metrics_df.to_csv(metrics_path, index=False)
    best_seq_full = np.array(best_record["seq"], dtype=np.int64)
    binder_idx = best_seq_full[:, binder_slice]
    binder_seq = indices_to_seq(binder_idx[0])
    fasta_path = run_dir / "best_binder.fasta"
    with open(fasta_path, "w") as handle:
        handle.write(">designed_binder\n")
        handle.write(binder_seq + "\n")
    pdb_path = run_dir / "best_complex.pdb"
    model.set_seq(seq=best_seq_full, bias=model._inputs.get("bias"), set_state=False)
    model.save_pdb(filename=str(pdb_path), get_best=False, aux=best_record["aux"])
    binder_plddt = np.nan
    if "plddt" in best_record["aux"]:
        binder_plddt = float(np.nanmean(np.array(best_record["aux"]["plddt"], dtype=float)[binder_slice]))
    # Report the chain IDs actually present in the saved complex. The input TARGET_CHAIN
    # (e.g. "B") can collide with the binder label, which made the viewer style both as one
    # chain. Target-before-binder order matches the pLDDT slicing used above.
    saved_chains = _pdb_chain_order(pdb_path) if pdb_path.exists() else []
    out_target_chain = saved_chains[0] if len(saved_chains) >= 2 else "A"
    out_binder_chain = saved_chains[-1] if len(saved_chains) >= 2 else "B"
    result = {
        "binder_len": binder_len,
        "model": model,
        "metrics_df": metrics_df,
        "metrics_path": metrics_path,
        "fasta_path": fasta_path,
        "best_pdb_path": pdb_path,
        "best_binder_seq": binder_seq,
        "best_loss": float(best_record["loss"]),
        "best_plddt": binder_plddt,
        "run_tag": run_tag,
        "guided_mask": np.array(guided_mask, dtype=bool),
        "seed_seq": seed_seq,
        "binder_slice": binder_slice,
        "best_record": best_record,
        "target_chain": out_target_chain,
        "binder_chain": out_binder_chain,
        "input_target_chain": target_chain,
    }
    return result


In [None]:
# Resolve the optional template binder. None, "" or a path that is not a file disables it.
# Path("") is ".", which always exists, so the file itself must be checked.
template_path = Path(TEMPLATE_BINDER_PDB_PATH) if TEMPLATE_BINDER_PDB_PATH else None
if template_path is not None and not template_path.is_file():
    template_path = None
seed_source = _prepare_seed(SEED_FASTA, SEED_FASTA_PATH, template_binder=template_path, binder_chains=TEMPLATE_BINDER_CHAINS)
binder_len = L if L is not None else len(seed_source)
stage_cfg = {
    "STAGE1_ITERS": STAGE1_ITERS,
    "STAGE1_EXTRA": STAGE1_EXTRA,
    "STAGE2_ITERS": STAGE2_ITERS,
    "STAGE3_ITERS": STAGE3_ITERS,
    "STAGE4_ITERS": STAGE4_ITERS,
}
init_cfg = {"alpha_start": INIT_ALPHA_START, "alpha_end": INIT_ALPHA_END}
kl_cfg = {"w_start": KL_W_START, "w_end": KL_W_END}
run_root = OUTPUT_BASE / RUN_NAME
run_root.mkdir(parents=True, exist_ok=True)

# Settings shared by the hybrid run and every baseline. run_hybrid_design copies the
# config dicts before modifying them, so sharing them is safe.
shared_kwargs = dict(
    target_pdb_path=TARGET_PDB_PATH,
    seed_fasta=seed_source,
    template_binder_path=template_path,
    stage_params=stage_cfg,
    init_params=init_cfg,
    kl_params=kl_cfg,
    guided_fraction=GUIDED_FRACTION,
    mut_rate=MUT_RATE,
    use_kl_prior=USE_KL_PRIOR,
    bias_stage4=BIAS_STAGE4_PROPOSALS,
    output_root=run_root,
)
HYBRID_RESULT = run_hybrid_design(
    binder_len=binder_len,
    rng_seed=GLOBAL_RNG_SEED,
    run_tag="hybrid",
    baseline_mode="hybrid",
    **shared_kwargs,
)
# Baselines reuse the binder length the model actually used.
binder_len = HYBRID_RESULT["binder_len"]
print("Hybrid design complete.")
print(f"  Best loss: {HYBRID_RESULT['best_loss']:.4f}")
print(f"  Binder pLDDT (mean): {HYBRID_RESULT['best_plddt']:.3f}")
print(f"  Sequence saved to: {HYBRID_RESULT['fasta_path']}")
print(f"  Complex PDB saved to: {HYBRID_RESULT['best_pdb_path']}")
BASELINE_RESULTS = []
if RUN_BASELINES:
    modes = [mode for mode, enabled in (("random", RUN_RANDOM_BASELINE), ("seed", RUN_SEED_BASELINE)) if enabled]
    for offset, mode in enumerate(modes, start=1):
        result = run_hybrid_design(
            binder_len=binder_len,
            rng_seed=GLOBAL_RNG_SEED + offset,
            run_tag=mode,
            baseline_mode=mode,
            **shared_kwargs,
        )
        BASELINE_RESULTS.append(result)
        print(f"Baseline '{mode}' complete -> loss {result['best_loss']:.4f}")


target_len=130, total_seq_len=229, binder_slots=99
Adjusted binder length from 229 to 99 to match AF2 model inputs.
logits0_np shape: (99, 20), seed length: 99
Stage 1/4 (logits): 75 iterations
Stage 1 progress 1/75

In [None]:
import py3Dmol

best_pdb_path = Path(HYBRID_RESULT['best_pdb_path'])
pdb_text = best_pdb_path.read_text()

# Chain IDs actually present in the saved complex, in order of first appearance.
chains_in_file = list(dict.fromkeys(
    line[21]
    for line in pdb_text.splitlines()
    if line.startswith(('ATOM', 'HETATM')) and len(line) > 21
))
binder_chain = HYBRID_RESULT.get('binder_chain', 'B')
target_chain = HYBRID_RESULT.get('target_chain', 'A')
# If the reported chains collide or are missing from the file, use the file's first and
# last chains. This assumes the target is written before the binder; confirm on a saved PDB.
reported_ok = target_chain != binder_chain and {target_chain, binder_chain} <= set(chains_in_file)
if not reported_ok and len(chains_in_file) >= 2:
    target_chain, binder_chain = chains_in_file[0], chains_in_file[-1]

pdb_view = py3Dmol.view(width=640, height=480)
pdb_view.addModel(pdb_text, 'pdb')
pdb_view.setStyle({'chain': target_chain}, {'cartoon': {'color': 'white'}})
# 'spectrum' is 3Dmol's N-to-C rainbow colouring.
pdb_view.setStyle({'chain': binder_chain}, {'cartoon': {'color': 'spectrum'}})
pdb_view.zoomTo()
pdb_view.show()


In [None]:
def _summary_row(run_name: str, res: dict) -> dict:
    """One summary-table row for a run result returned by run_hybrid_design."""
    return {
        'run': run_name,
        'loss': res['best_loss'],
        'binder_plddt': res['best_plddt'],
        'fasta_path': str(res['fasta_path']),
        'pdb_path': str(res['best_pdb_path']),
    }


summary_rows = [_summary_row('hybrid', HYBRID_RESULT)]
summary_rows += [_summary_row(res['run_tag'], res) for res in BASELINE_RESULTS]
summary_df = pd.DataFrame(summary_rows)
# Bare last expression renders the table once (display() plus a bare name rendered it twice).
summary_df
