# Protein Design using protein Language Model (pLM)

In [1]:
#@title Install dependencies
%%time
import os


USE_AMBER = False
USE_TEMPLATES = False
from sys import version_info
PYTHON_VERSION = f"{version_info.major}.{version_info.minor}"

if not os.path.isfile("COLABFOLD_READY"):
  print("installing colabfold...")
  os.system("pip install -q --no-warn-conflicts 'colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold'")
  if os.environ.get('TPU_NAME', False) != False:
    os.system("pip uninstall -y jax jaxlib")
    os.system("pip install --no-warn-conflicts --upgrade dm-haiku==0.0.10 'jax[cuda12_pip]'==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html")
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/colabfold colabfold")
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/alphafold alphafold")
  # hack to fix TF crash
  os.system("rm -f /usr/local/lib/python3.*/dist-packages/tensorflow/core/kernels/libtfkernel_sobol_op.so")
  os.system("touch COLABFOLD_READY")

if USE_AMBER or USE_TEMPLATES:
  if not os.path.isfile("CONDA_READY"):
    print("installing conda...")
    os.system("wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh")
    os.system("bash Miniforge3-Linux-x86_64.sh -bfp /usr/local")
    os.system("mamba config --set auto_update_conda false")
    os.system("touch CONDA_READY")

if USE_TEMPLATES and not os.path.isfile("HH_READY") and USE_AMBER and not os.path.isfile("AMBER_READY"):
  print("installing hhsuite and amber...")
  os.system(f"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=8.2.0 python='{PYTHON_VERSION}' pdbfixer")
  os.system("touch HH_READY")
  os.system("touch AMBER_READY")
else:
  if USE_TEMPLATES and not os.path.isfile("HH_READY"):
    print("installing hhsuite...")
    os.system(f"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 python='{PYTHON_VERSION}'")
    os.system("touch HH_READY")
  if USE_AMBER and not os.path.isfile("AMBER_READY"):
    print("installing amber...")
    os.system(f"mamba install -y -c conda-forge openmm=8.2.0 python='{PYTHON_VERSION}' pdbfixer")
    os.system("touch AMBER_READY")

installing colabfold...
CPU times: user 2.94 ms, sys: 1.12 ms, total: 4.06 ms
Wall time: 34.1 s


#### Setup ColabFold

In [2]:
from pathlib import Path
from colabfold.download import download_alphafold_params, default_data_dir
from colabfold.batch import set_model_type

In [3]:
model_type = "auto"
is_complex = False
model_type = set_model_type(is_complex, model_type)
model_type

'alphafold2_ptm'

In [4]:
download_alphafold_params(model_type, Path("."))

Downloading alphafold2_ptm weights to .: 100%|██████████| 3.47G/3.47G [02:43<00:00, 22.8MB/s]


## Input Original Sequence

In [5]:
# The full amino acid sequence for TEM-1 beta-lactamase (UniProt: P62593)
sequence = "MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRIDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPVAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW"
len(sequence)

286

In [6]:
START = 164
L = 7

In [7]:
from colabfold.batch import run

test_name = "wt"
wt_outdir    = "af_orig_input"

queries = [(test_name, sequence, None, None)]  # <-- 4-tuple per item

_ = run(
    queries=queries,
    result_dir=wt_outdir,
    data_dir=Path("."),
    is_complex=is_complex,            # monomer
    msa_mode="single_sequence",  # fast; no server call
    use_templates=False,
    model_type="auto",
    num_models=1,
    num_recycles=1,
    rank_by="pLDDT",
    zip_results=False,
)
print("Done →", wt_outdir)

Done → af_orig_input


In [8]:
import json, os
import numpy as np
from typing import Optional, Tuple, Dict, Any
from Bio.PDB import PDBParser, Superimposer, Chain, Atom

In [9]:
scores_json_path = "/content/af_orig_input/wt_scores_rank_001_alphafold2_ptm_model_1_seed_000.json"
pdb_path = "/content/af_orig_input/wt_unrelaxed_rank_001_alphafold2_ptm_model_1_seed_000.pdb"
pae_json_path = "/content/af_orig_input/wt_predicted_aligned_error_v1.json"

In [10]:
def load_plddt_from_scores(scores_json_path: str) -> np.ndarray:
    """
    Read per-residue pLDDT from a ColabFold/AlphaFold scores_rank_*.json.
    Returns an array of shape (N_res,) with values in [0, 100].
    """
    with open(scores_json_path, "r") as f:
        d = json.load(f)
    v = np.asarray(d["plddt"], dtype=float)
    if v.ndim != 1:
        raise ValueError(f"plddt array must be 1D, got shape {v.shape}")
    return v

wt_plddt = load_plddt_from_scores(scores_json_path)
wt_plddt

array([32.5 , 24.12, 27.72, 25.33, 25.25, 26.42, 26.75, 25.64, 24.27,
       26.61, 26.52, 24.23, 26.14, 26.  , 26.59, 25.23, 31.5 , 26.69,
       25.73, 24.92, 26.36, 29.02, 26.31, 34.25, 31.81, 33.03, 33.5 ,
       31.12, 32.16, 40.22, 41.34, 42.88, 48.06, 45.41, 48.03, 50.28,
       45.28, 42.41, 46.91, 43.94, 53.  , 50.94, 53.66, 52.72, 50.12,
       56.41, 51.59, 53.25, 44.59, 45.44, 44.03, 38.59, 45.78, 43.66,
       50.94, 51.09, 47.69, 48.06, 48.5 , 40.25, 38.97, 29.59, 33.69,
       33.75, 38.03, 39.28, 39.62, 41.44, 43.09, 47.28, 50.88, 44.31,
       47.47, 53.34, 54.69, 47.75, 47.12, 49.06, 48.91, 44.34, 43.94,
       44.38, 43.25, 37.12, 34.19, 43.16, 39.25, 37.31, 35.62, 33.  ,
       30.39, 28.41, 28.66, 27.19, 28.69, 23.44, 24.12, 24.66, 23.94,
       22.28, 24.7 , 24.48, 24.14, 24.12, 26.42, 29.33, 31.17, 28.98,
       33.  , 35.84, 37.03, 38.28, 38.28, 39.66, 47.41, 50.44, 49.44,
       50.78, 51.72, 52.78, 55.59, 50.5 , 49.44, 52.38, 47.59, 43.03,
       49.16, 43.06,

In [11]:
wt_plddt_global = float(wt_plddt.mean())
wt_plddt_local  = float(wt_plddt[START:START+L].mean()) if len(wt_plddt) >= START+L else float("nan")
print(f"pLDDT shape: {wt_plddt.shape}")
print(f"pLDDT_global: {wt_plddt_global:.2f}")
print(f"pLDDT_local[{START}:{START+L}]: {wt_plddt_local:.2f}")

pLDDT shape: (286,)
pLDDT_global: 41.13
pLDDT_local[164:171]: 41.00


In [12]:
def load_pae_shape(pae_json_path: Optional[str]) -> Optional[Tuple[int, int]]:
    """
    Read PAE JSON (keys vary: 'predicted_aligned_error' or 'pae') and return its shape.
    Returns (N, N) or None if path is None.
    """
    if pae_json_path is None:
        return None
    with open(pae_json_path, "r") as f:
        d = json.load(f)
    mat = d.get("predicted_aligned_error", None)
    if mat is None:
        mat = d.get("pae", None)
    if mat is None:
        return None
    arr = np.asarray(mat)
    if arr.ndim != 2 or arr.shape[0] != arr.shape[1]:
        raise ValueError(f"PAE must be square 2D; got shape {arr.shape}")
    return tuple(arr.shape)

wt_pae_shape = load_pae_shape(pae_json_path)
print(f"PAE shape: {wt_pae_shape}")

PAE shape: (286, 286)


In [13]:
def _ca_atoms_from_pdb(pdb_path: str, chain_id: str = "A") -> list[Atom.Atom]:
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("x", pdb_path)
    model = structure[0]
    chain: Chain = model[chain_id] if chain_id in [c.id for c in model] else list(model)[0]
    return [res["CA"] for res in chain if "CA" in res]


def rmsd_outside_window(wt_pdb_path: str,
                        mut_pdb_path: str,
                        start: int,
                        L: int,
                        chain_id: str = "A") -> float:
    """
    Superimpose WT and mutant on all Cα positions OUTSIDE [start, start+L) and return RMSD (Å).
    Indexing is 0-based over the sequence order in AF/ColabFold PDB (sequential residues).
    """
    wt_atoms = _ca_atoms_from_pdb(wt_pdb_path, chain_id)
    mut_atoms = _ca_atoms_from_pdb(mut_pdb_path, chain_id)
    n = min(len(wt_atoms), len(mut_atoms))
    if n == 0:
        raise ValueError("No CA atoms found in one of the PDBs.")
    # indices to align on (exclude the edited window)
    idx = [i for i in range(n) if not (start <= i < start + L)]
    if len(idx) == 0:
        return float("nan")  # nothing to align on
    ref = [wt_atoms[i] for i in idx]
    mob = [mut_atoms[i] for i in idx]
    sup = Superimposer()
    sup.set_atoms(ref, mob)
    # sup.apply(mob)  # we only need the RMSD
    return float(sup.rms)

rmsd_self = rmsd_outside_window(pdb_path, pdb_path, START, L, chain_id="A")  # sanity ≈ 0
print(f"RMSD_outside (WT vs WT): {rmsd_self:.4f} Å")

RMSD_outside (WT vs WT): 0.0000 Å


In [14]:
# --- 3D viewer (color by pLDDT via B-factors) ---
import py3Dmol, IPython

def show_pdb(pdb_path):
  view = py3Dmol.view(width=600, height=450)
  view.addModel(open(pdb_path,'r').read(), 'pdb')

  # color by pLDDT (stored in B-factor)
  view.setStyle({'cartoon': {
      'colorscheme': {'prop':'b', 'gradient':'roygb', 'min':50, 'max':100}
  }})

  # overlay: highlight the edited window in magenta
  resi1, resi2 = START+1, START+L   # 3Dmol uses 1-based residue indices
  sel = {'chain':'A', 'resi': list(range(resi1, resi2+1))}
  view.addStyle(sel, {'cartoon': {'color':'magenta'}})
  # (optional) make it pop even more
  view.addStyle(sel, {'stick': {'color':'magenta','radius':0.25}})

  view.setBackgroundColor('white')
  view.zoomTo()
  return view

IPython.display.display(show_pdb(pdb_path))

<py3Dmol.view at 0x7c1128193350>

# Generate Candidates

In [15]:
import torch
from torch.nn.functional import log_softmax
from transformers import EsmTokenizer, EsmForMaskedLM

In [16]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [17]:
model_name = "facebook/esm2_t36_3B_UR50D"
tok = EsmTokenizer.from_pretrained(model_name)
mlm = EsmForMaskedLM.from_pretrained(model_name).to(device).eval()

tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

pytorch_model.bin.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

pytorch_model-00002-of-00002.bin:   0%|          | 0.00/1.39G [00:00<?, ?B/s]

pytorch_model-00001-of-00002.bin:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [18]:
AA20 = list("ACDEFGHIKLMNPQRSTVWY")

def _mk_masked_spaced(seq: str, start: int, length: int) -> str:
    # Turn "ACDE..." into "A C D E ..." and insert <mask> tokens
    left  = " ".join(list(seq[:start]))
    mid   = " ".join(["<mask>"] * length)
    right = " ".join(list(seq[start+length:]))
    return " ".join([p for p in (left, mid, right) if p])

In [19]:
def infill_span(seq: str, mask_start: int, mask_len: int,
                n_samples: int = 64, top_k: int = 12, temperature: float = 1.0,
                hard_constraints: dict | None = None, seed: int = 0):
    """
    Redesign a contiguous span [mask_start : mask_start+mask_len) in `seq`,
    keeping the rest fixed. Returns `n_samples` full-length sequences.
    Requires global: tok (EsmTokenizer), mlm (EsmForMaskedLM), device.
    """
    torch.manual_seed(seed)

    # 1) Build a spaced input with explicit <mask> tokens
    masked_spaced = _mk_masked_spaced(seq, mask_start, mask_len)

    # 2) Batch it
    batch = [masked_spaced] * n_samples
    enc = tok(batch, return_tensors="pt", padding=True, add_special_tokens=True).to(device)

    # Safety checks
    mask_id = tok.mask_token_id
    input_ids = enc["input_ids"].clone()
    attn = enc["attention_mask"].bool()

    # Ensure we truly have `mask_len` masks per row
    mask_counts = (input_ids == mask_id).sum(dim=1)
    if not torch.all(mask_counts == mask_len):
        raise ValueError(f"Expected {mask_len} masks per row, got {mask_counts.tolist()}")

    # 3) Precompute AA token ids
    aa_ids = torch.tensor([tok.convert_tokens_to_ids(a) for a in AA20], device=device)

    # 4) Fill left→right one residue at a time
    for j in range(mask_len):
        with torch.no_grad():
            logits = mlm(input_ids=input_ids, attention_mask=attn).logits  # [B, T, V]

        is_mask = (input_ids == mask_id)                        # [B, T]
        has_mask = is_mask.any(dim=1)                           # [B]
        if not has_mask.any():
            break

        active_rows = torch.nonzero(has_mask, as_tuple=False).squeeze(-1)  # [B_active]
        # **Key fix**: pick exactly ONE column per row — the leftmost `<mask>`
        first_pos_by_row = torch.argmax(is_mask[active_rows].to(torch.int8), dim=1)  # [B_active]

        # Slice logits at those positions → [B_active, V]
        token_logits = logits[active_rows, first_pos_by_row, :]
        token_logits = token_logits / temperature

        # Restrict to 20 AAs and sample with top-k
        aa_logits = token_logits[:, aa_ids]                     # [B_active, 20]
        k = min(top_k, aa_logits.shape[1])
        topk_logits, topk_idx = torch.topk(aa_logits, k=k, dim=-1)
        probs = torch.softmax(topk_logits, dim=-1)
        sampled_k = torch.multinomial(probs, num_samples=1).squeeze(-1)    # [B_active]
        sampled_aa_idx = topk_idx[torch.arange(active_rows.size(0), device=device), sampled_k]
        sampled_token_ids = aa_ids[sampled_aa_idx]                           # [B_active]

        # Optional per-position constraints (absolute seq position)
        abs_pos = mask_start + j
        if hard_constraints and abs_pos in hard_constraints:
            allowed = list(hard_constraints[abs_pos])
            allowed_ids = torch.tensor([tok.convert_tokens_to_ids(a) for a in allowed], device=device)
            bad = ~torch.isin(sampled_token_ids, allowed_ids)
            if bad.any():
                replace = allowed_ids[torch.randint(len(allowed_ids), (int(bad.sum()),), device=device)]
                sampled_token_ids[bad] = replace

        # Write sampled tokens back
        input_ids[active_rows, first_pos_by_row] = sampled_token_ids

    # 5) Decode and strip spaces
    dec = tok.batch_decode(input_ids, skip_special_tokens=True)
    return [d.replace(" ", "") for d in dec]

In [20]:
cands = infill_span(
    sequence,
    mask_start=START, mask_len=L,
    n_samples=64, top_k=12, temperature=0.9,
    hard_constraints=None, seed=0,
)

In [21]:
len(cands)

64

In [22]:
cands[0]

'MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRIDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPVAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW'

In [23]:
def pll_window(seq: str, start: int, length: int) -> float:
    total = 0.0
    for i in range(start, start+length):
        tokens = list(seq)
        true_res = tokens[i]
        tokens[i] = "<mask>"
        masked_spaced = " ".join(tokens)
        enc = tok(masked_spaced, return_tensors="pt", add_special_tokens=True).to(device)
        with torch.no_grad():
            logits = mlm(**enc).logits[0]
        mask_pos = (enc["input_ids"][0] == tok.mask_token_id).nonzero(as_tuple=False).item()
        true_id = tok.convert_tokens_to_ids(true_res)
        total += log_softmax(logits[mask_pos], dim=-1)[true_id].item()
    return total

In [24]:
scored = [(s, pll_window(s, START, L)) for s in set(cands)]
scored.sort(key=lambda x: x[1], reverse=True)
top10 = [s for s,_ in scored[:10]]

In [25]:
top10[0]

'MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRIDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPVAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW'

In [26]:
for x in top10:
  print(x[START:START+L])

PELNEAI
PELNEAT
TELNEAT
PSLNEAI
PELNEAL
PELNEAM
PELNEAV
TELNSAI
PELNEAS
PALNEAI


### Sequence-level checks (fast, model-only)

These don’t need structure prediction and will immediately separate junk from plausible designs.

What to compute

- PLL_window: pseudo-log-likelihood over the edited span.
- ΔPLL_window vs WT in the same span.
- PLL_full_per_res: pseudo-log-likelihood over the full sequence, normalized by length.
- entropy_window: average predictive entropy in the span (lower = more certain).
- BLOSUM_window: sum of BLOSUM62 substitution scores for mutated positions vs WT (how conservative the edits are).
- Simple developability heuristics on the span: net charge (pH 7), Kyte–Doolittle hydropathy, low-complexity (Shannon entropy), glycosylation motif N-X-S/T (X≠P), cysteine count.

In [27]:
import math

KD = dict(zip(list("IVLFWYPMTASGHCNQEDKR"), [4.5,3.8,3.8,2.8,2.5,2.3,2.0,1.9,1.3,1.0,0.0,-0.4,-0.8,-1.0,-3.5,-3.5,-3.5,-3.9,-4.5,-3.9]))
pKa_side = {"D":3.9,"E":4.3,"H":6.0,"C":8.3,"Y":10.1,"K":10.5,"R":12.5}
BLOSUM62 = {('A','A'):4,('A','R'):-1,('A','N'):-2,('A','D'):-2,('A','C'):0,('A','Q'):-1,('A','E'):-1,('A','G'):0,('A','H'):-2,('A','I'):-1,('A','L'):-1,('A','K'):-1,('A','M'):-1,('A','F'):-2,('A','P'):-1,('A','S'):1,('A','T'):0,('A','W'):-3,('A','Y'):-2,('A','V'):0,
('R','A'):-1,('R','R'):5,('R','N'):0,('R','D'):-2,('R','C'):-3,('R','Q'):1,('R','E'):0,('R','G'):-2,('R','H'):0,('R','I'):-3,('R','L'):-2,('R','K'):2,('R','M'):-1,('R','F'):-3,('R','P'):-2,('R','S'):-1,('R','T'):-1,('R','W'):-3,('R','Y'):-2,('R','V'):-3,
('N','A'):-2,('N','R'):0,('N','N'):6,('N','D'):1,('N','C'):-3,('N','Q'):0,('N','E'):0,('N','G'):0,('N','H'):1,('N','I'):-3,('N','L'):-3,('N','K'):0,('N','M'):-2,('N','F'):-3,('N','P'):-2,('N','S'):1,('N','T'):0,('N','W'):-4,('N','Y'):-2,('N','V'):-3,
('D','A'):-2,('D','R'):-2,('D','N'):1,('D','D'):6,('D','C'):-3,('D','Q'):0,('D','E'):2,('D','G'):-1,('D','H'):-1,('D','I'):-3,('D','L'):-4,('D','K'):-1,('D','M'):-3,('D','F'):-3,('D','P'):-1,('D','S'):0,('D','T'):-1,('D','W'):-4,('D','Y'):-3,('D','V'):-3,
('C','A'):0,('C','R'):-3,('C','N'):-3,('C','D'):-3,('C','C'):9,('C','Q'):-3,('C','E'):-4,('C','G'):-3,('C','H'):-3,('C','I'):-1,('C','L'):-1,('C','K'):-3,('C','M'):-1,('C','F'):-2,('C','P'):-3,('C','S'):-1,('C','T'):-1,('C','W'):-2,('C','Y'):-2,('C','V'):-1,
('Q','A'):-1,('Q','R'):1,('Q','N'):0,('Q','D'):0,('Q','C'):-3,('Q','Q'):5,('Q','E'):2,('Q','G'):-2,('Q','H'):0,('Q','I'):-3,('Q','L'):-2,('Q','K'):1,('Q','M'):0,('Q','F'):-3,('Q','P'):-1,('Q','S'):0,('Q','T'):-1,('Q','W'):-2,('Q','Y'):-1,('Q','V'):-2,
('E','A'):-1,('E','R'):0,('E','N'):0,('E','D'):2,('E','C'):-4,('E','Q'):2,('E','E'):5,('E','G'):-2,('E','H'):0,('E','I'):-3,('E','L'):-3,('E','K'):1,('E','M'):-2,('E','F'):-3,('E','P'):-1,('E','S'):0,('E','T'):-1,('E','W'):-3,('E','Y'):-2,('E','V'):-2,
('G','A'):0,('G','R'):-2,('G','N'):0,('G','D'):-1,('G','C'):-3,('G','Q'):-2,('G','E'):-2,('G','G'):6,('G','H'):-2,('G','I'):-4,('G','L'):-4,('G','K'):-2,('G','M'):-3,('G','F'):-3,('G','P'):-2,('G','S'):0,('G','T'):-2,('G','W'):-2,('G','Y'):-3,('G','V'):-3,
('H','A'):-2,('H','R'):0,('H','N'):1,('H','D'):-1,('H','C'):-3,('H','Q'):0,('H','E'):0,('H','G'):-2,('H','H'):8,('H','I'):-3,('H','L'):-3,('H','K'):-1,('H','M'):-2,('H','F'):-1,('H','P'):-2,('H','S'):-1,('H','T'):-2,('H','W'):-2,('H','Y'):2,('H','V'):-3,
('I','A'):-1,('I','R'):-3,('I','N'):-3,('I','D'):-3,('I','C'):-1,('I','Q'):-3,('I','E'):-3,('I','G'):-4,('I','H'):-3,('I','I'):4,('I','L'):2,('I','K'):-3,('I','M'):1,('I','F'):0,('I','P'):-3,('I','S'):-2,('I','T'):-1,('I','W'):-3,('I','Y'):-1,('I','V'):3,
('L','A'):-1,('L','R'):-2,('L','N'):-3,('L','D'):-4,('L','C'):-1,('L','Q'):-2,('L','E'):-3,('L','G'):-4,('L','H'):-3,('L','I'):2,('L','L'):4,('L','K'):-2,('L','M'):2,('L','F'):0,('L','P'):-3,('L','S'):-2,('L','T'):-1,('L','W'):-2,('L','Y'):-1,('L','V'):1,
('K','A'):-1,('K','R'):2,('K','N'):0,('K','D'):-1,('K','C'):-3,('K','Q'):1,('K','E'):1,('K','G'):-2,('K','H'):-1,('K','I'):-3,('K','L'):-2,('K','K'):5,('K','M'):-1,('K','F'):-3,('K','P'):-1,('K','S'):0,('K','T'):-1,('K','W'):-3,('K','Y'):-2,('K','V'):-2,
('M','A'):-1,('M','R'):-1,('M','N'):-2,('M','D'):-3,('M','C'):-1,('M','Q'):0,('M','E'):-2,('M','G'):-3,('M','H'):-2,('M','I'):1,('M','L'):2,('M','K'):-1,('M','M'):5,('M','F'):0,('M','P'):-2,('M','S'):-1,('M','T'):-1,('M','W'):-1,('M','Y'):-1,('M','V'):1,
('F','A'):-2,('F','R'):-3,('F','N'):-3,('F','D'):-3,('F','C'):-2,('F','Q'):-3,('F','E'):-3,('F','G'):-3,('F','H'):-1,('F','I'):0,('F','L'):0,('F','K'):-3,('F','M'):0,('F','F'):6,('F','P'):-4,('F','S'):-2,('F','T'):-2,('F','W'):1,('F','Y'):3,('F','V'):-1,
('P','A'):-1,('P','R'):-2,('P','N'):-2,('P','D'):-1,('P','C'):-3,('P','Q'):-1,('P','E'):-1,('P','G'):-2,('P','H'):-2,('P','I'):-3,('P','L'):-3,('P','K'):-1,('P','M'):-2,('P','F'):-4,('P','P'):7,('P','S'):-1,('P','T'):-1,('P','W'):-4,('P','Y'):-3,('P','V'):-2,
('S','A'):1,('S','R'):-1,('S','N'):1,('S','D'):0,('S','C'):-1,('S','Q'):0,('S','E'):0,('S','G'):0,('S','H'):-1,('S','I'):-2,('S','L'):-2,('S','K'):0,('S','M'):-1,('S','F'):-2,('S','P'):-1,('S','S'):4,('S','T'):1,('S','W'):-3,('S','Y'):-2,('S','V'):-2,
('T','A'):0,('T','R'):-1,('T','N'):0,('T','D'):-1,('T','C'):-1,('T','Q'):-1,('T','E'):-1,('T','G'):-2,('T','H'):-2,('T','I'):-1,('T','L'):-1,('T','K'):-1,('T','M'):-1,('T','F'):-2,('T','P'):-1,('T','S'):1,('T','T'):5,('T','W'):-2,('T','Y'):-2,('T','V'):0,
('W','A'):-3,('W','R'):-3,('W','N'):-4,('W','D'):-4,('W','C'):-2,('W','Q'):-2,('W','E'):-3,('W','G'):-2,('W','H'):-2,('W','I'):-3,('W','L'):-2,('W','K'):-3,('W','M'):-1,('W','F'):1,('W','P'):-4,('W','S'):-3,('W','T'):-2,('W','W'):11,('W','Y'):2,('W','V'):-3,
('Y','A'):-2,('Y','R'):-2,('Y','N'):-2,('Y','D'):-3,('Y','C'):-2,('Y','Q'):-1,('Y','E'):-2,('Y','G'):-3,('Y','H'):2,('Y','I'):-1,('Y','L'):-1,('Y','K'):-2,('Y','M'):-1,('Y','F'):3,('Y','P'):-3,('Y','S'):-2,('Y','T'):-2,('Y','W'):2,('Y','Y'):7,('Y','V'):-1,
('V','A'):0,('V','R'):-3,('V','N'):-3,('V','D'):-3,('V','C'):-1,('V','Q'):-2,('V','E'):-2,('V','G'):-3,('V','H'):-3,('V','I'):3,('V','L'):1,('V','K'):-2,('V','M'):1,('V','F'):-1,('V','P'):-2,('V','S'):-2,('V','T'):0,('V','W'):-3,('V','Y'):-1,('V','V'):4}

def pll_over_positions(seq: str, positions: list[int]) -> float:
    total = 0.0
    for i in positions:
        tokens = list(seq)
        true_res = tokens[i]
        tokens[i] = "<mask>"
        spaced = " ".join(tokens)
        enc = tok(spaced, return_tensors="pt", add_special_tokens=True).to(device)
        with torch.no_grad():
            logits = mlm(**enc).logits[0]
        mask_pos = (enc["input_ids"][0] == tok.mask_token_id).nonzero(as_tuple=False).item()
        total += log_softmax(logits[mask_pos], dim=-1)[tok.convert_tokens_to_ids(true_res)].item()
    return total

def pll_full_per_res(seq: str) -> float:
    return pll_over_positions(seq, list(range(len(seq)))) / len(seq)

def entropy_at(seq: str, i: int) -> float:
    tokens = list(seq); tokens[i] = "<mask>"
    spaced = " ".join(tokens)
    enc = tok(spaced, return_tensors="pt", add_special_tokens=True).to(device)
    with torch.no_grad():
        logits = mlm(**enc).logits[0]
    mask_pos = (enc["input_ids"][0] == tok.mask_token_id).nonzero(as_tuple=False).item()
    p = torch.softmax(logits[mask_pos], dim=-1)
    # restrict to AA20
    aa_ids = torch.tensor([tok.convert_tokens_to_ids(a) for a in AA20], device=p.device)
    q = torch.softmax(logits[mask_pos, aa_ids], dim=-1)
    ent = -(q * (q+1e-12).log()).sum().item()
    return ent

def entropy_window(seq: str, start: int, L: int) -> float:
    return sum(entropy_at(seq, i) for i in range(start, start+L)) / L

def blosum_window(mut_seq: str, wt_seq: str, start: int, L: int) -> int:
    s = 0
    for i in range(start, start+L):
        a, b = wt_seq[i], mut_seq[i]
        s += BLOSUM62.get((b,a), BLOSUM62.get((a,b), -4))  # default to a mild penalty
    return s

def net_charge_span(seq: str, start: int, L: int, pH=7.0) -> float:
    # Very rough Henderson–Hasselbalch: side chains only
    charge = 0.0
    for aa in seq[start:start+L]:
        if aa in ('D','E'):
            charge += -1.0/(1+10**(pKa_side[aa]-pH))
        elif aa in ('K','R'):
            charge += 1.0/(1+10**(pH-pKa_side[aa]))
        elif aa=='H':
            charge += 1.0/(1+10**(pH-pKa_side[aa]))
    return charge

def kd_hydropathy_span(seq: str, start: int, L: int) -> float:
    vals = [KD.get(a,0.0) for a in seq[start:start+L]]
    return sum(vals)/L if L>0 else 0.0

def has_nglyc_motif(seq: str, start: int, L: int) -> bool:
    s = seq[start:start+L]
    for i in range(len(s)-2):
        if s[i]=='N' and s[i+1] != 'P' and s[i+2] in ('S','T'):
            return True
    return False

def shannon_entropy_span(seq: str, start: int, L: int) -> float:
    from collections import Counter
    sub = seq[start:start+L]
    counts = Counter(sub)
    tot = float(L)
    return -sum((c/tot)*math.log(c/tot+1e-12) for c in counts.values())

In [28]:
# assuming: wt = original sequence, start=164, L=7, and top10 is a list[str]
def score_record(mut):
    pll_win = pll_over_positions(mut, list(range(START, START+L)))
    return {
        "sequence": mut,
        "PLL_window": pll_win,
        "ΔPLL_window_vs_WT": pll_win - pll_over_positions(wt, list(range(START, START+L))),
        "PLL_full_per_res": pll_full_per_res(mut),
        "entropy_window": entropy_window(mut, START, L),
        "BLOSUM_window": blosum_window(mut, wt, START, L),
        "charge_span_pH7": net_charge_span(mut, START, L),
        "KD_hydropathy_span": kd_hydropathy_span(mut, START, L),
        "has_NXS_T_motif": has_nglyc_motif(mut, START, L),
        "cysteines_in_span": mut[START:START+L].count('C'),
    }

wt = sequence  # your original sequence
seq_table = [score_record(s) for s in top10]
import pandas as pd
df = pd.DataFrame(seq_table).sort_values(["PLL_window","BLOSUM_window"], ascending=[False, False])
df.head(10)

Unnamed: 0,sequence,PLL_window,ΔPLL_window_vs_WT,PLL_full_per_res,entropy_window,BLOSUM_window,charge_span_pH7,KD_hydropathy_span,has_NXS_T_motif,cysteines_in_span
0,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,-0.62828,0.0,-0.429313,0.273735,35,-1.996017,0.114286,False,0
1,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,-2.186505,-1.558225,-0.437887,0.241537,30,-1.996017,-0.342857,False,0
2,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,-3.337751,-2.709471,-0.444886,0.28439,22,-1.996017,-0.442857,False,0
3,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,-4.077062,-3.448782,-0.443166,0.298199,30,-0.998009,0.614286,False,0
4,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,-4.325497,-3.697216,-0.443249,0.288715,33,-1.996017,0.014286,False,0
5,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,-4.402236,-3.773956,-0.444997,0.278376,32,-1.996017,-0.257143,False,0
6,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,-4.863198,-4.234918,-0.448053,0.268563,34,-1.996017,0.014286,False,0
7,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,-5.096482,-4.468202,-0.445916,0.268847,22,-0.998009,0.514286,False,0
8,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,-5.217226,-4.588946,-0.449853,0.247924,29,-1.996017,-0.528571,False,0
9,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,-5.309041,-4.680761,-0.449528,0.302671,29,-0.998009,0.757143,False,0


## Structure screen (adds geometry)

Metrics to compute
- pLDDT_global
- pLDDT_local: mean pLDDT over residues `[start:start+L]`.
- RMSD_outside: CA-RMSD of all residues outside `[start:start+L]` between WT structure (predicted once) and each mutant (after superposition on the outside region).



In [29]:
# Build ColabFold queries - expects 4-tuples
names   = [f"cand_{i}" for i in range(len(top10))]
queries = [(n, s, None, None) for n, s in zip(names, top10)]
queries[0]

('cand_0',
 'MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRIDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPVAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW',
 None,
 None)

In [30]:
outdir = "af_candidates"
_ = run(
    queries=queries,
    result_dir=outdir,
    data_dir=Path("."),
    is_complex=is_complex,            # monomer
    msa_mode="single_sequence",  # fast; no server call
    use_templates=False,
    model_type="auto",
    num_models=1,
    num_recycles=1,
    rank_by="pLDDT",
    zip_results=False,
)
print("Done →", outdir)

Done → af_candidates


In [31]:
def cand_paths(job: str):
    pdb    = f"{outdir}/{job}_unrelaxed_rank_001_alphafold2_ptm_model_1_seed_000.pdb"
    scores = f"{outdir}/{job}_scores_rank_001_alphafold2_ptm_model_1_seed_000.json"
    pae    = f"{outdir}/{job}_predicted_aligned_error_v1.json"
    return pdb, scores, pae

cand_paths(names[0])

('af_candidates/cand_0_unrelaxed_rank_001_alphafold2_ptm_model_1_seed_000.pdb',
 'af_candidates/cand_0_scores_rank_001_alphafold2_ptm_model_1_seed_000.json',
 'af_candidates/cand_0_predicted_aligned_error_v1.json')

In [32]:
def compute_metrics_for_target(
    *,
    wt_pdb_path: str,
    mut_pdb_path: str,
    start: int,
    L: int,
    mut_scores_json_path: str,
    mut_pae_json_path: str,
    chain_id: str = "A",
) -> dict:
    # prefer JSON; fallback to PDB B-factors
    if mut_scores_json_path:
        plddt_arr = load_plddt_from_scores(mut_scores_json_path)
        plddt_source = "scores_json"
    else:
        plddt_arr = load_plddt_from_pdb_bfactor(mut_pdb_path, chain_id=chain_id)
        plddt_source = "pdb_bfactor"

    plddt_local  = float(plddt_arr[start:start+L].mean()) if len(plddt_arr) >= start+L else float("nan")
    plddt_global = float(plddt_arr.mean())
    pae_shape    = load_pae_shape(mut_pae_json_path)
    rmsd         = rmsd_outside_window(wt_pdb_path, mut_pdb_path, start, L, chain_id=chain_id)

    return {
        "pLDDT_global": plddt_global,
        "pLDDT_local": plddt_local,
        "plddt_shape": (len(plddt_arr),),
        "plddt_source": plddt_source,
        "pae_shape": pae_shape,
        "RMSD_outside": rmsd,
    }

In [33]:
wt_pdb_path = pdb_path
wt_plddt_local = wt_plddt_local

In [34]:
rows = []
for job, seq in zip(names, top10):
    pdb_path, scores_json, pae_json = cand_paths(job)
    metrics = compute_metrics_for_target(
        wt_pdb_path=wt_pdb_path,
        mut_pdb_path=pdb_path,
        mut_scores_json_path=scores_json,
        mut_pae_json_path=pae_json,
        start=START, L=L,
        chain_id="A",
    )
    rows.append({
        "name": job,
        "sequence": seq,
        "pdb_path": pdb_path,
        **metrics,
        "ΔpLDDT_local_vs_WT": metrics["pLDDT_local"] - wt_plddt_local,
        "N_res": metrics["plddt_shape"][0],
    })


In [35]:
df_top10 = (
    pd.DataFrame(rows)
      .sort_values(["pLDDT_local","RMSD_outside"], ascending=[False, True])
      .reset_index(drop=True)
)

# Quick glance at shapes & key metrics
print(df_top10[["name","plddt_shape","pae_shape"]])
df_top10

     name plddt_shape   pae_shape
0  cand_4      (286,)  (286, 286)
1  cand_5      (286,)  (286, 286)
2  cand_0      (286,)  (286, 286)
3  cand_6      (286,)  (286, 286)
4  cand_2      (286,)  (286, 286)
5  cand_9      (286,)  (286, 286)
6  cand_7      (286,)  (286, 286)
7  cand_3      (286,)  (286, 286)
8  cand_8      (286,)  (286, 286)
9  cand_1      (286,)  (286, 286)


Unnamed: 0,name,sequence,pdb_path,pLDDT_global,pLDDT_local,plddt_shape,plddt_source,pae_shape,RMSD_outside,ΔpLDDT_local_vs_WT,N_res
0,cand_4,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,af_candidates/cand_4_unrelaxed_rank_001_alphaf...,41.692203,49.03,"(286,)",scores_json,"(286, 286)",2.339937,8.03,286
1,cand_5,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,af_candidates/cand_5_unrelaxed_rank_001_alphaf...,39.799021,41.245714,"(286,)",scores_json,"(286, 286)",2.552505,0.245714,286
2,cand_0,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,af_candidates/cand_0_unrelaxed_rank_001_alphaf...,41.134056,41.0,"(286,)",scores_json,"(286, 286)",9.24958e-15,0.0,286
3,cand_6,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,af_candidates/cand_6_unrelaxed_rank_001_alphaf...,40.865874,39.255714,"(286,)",scores_json,"(286, 286)",1.02201,-1.744286,286
4,cand_2,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,af_candidates/cand_2_unrelaxed_rank_001_alphaf...,37.738427,37.238571,"(286,)",scores_json,"(286, 286)",11.73585,-3.761429,286
5,cand_9,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,af_candidates/cand_9_unrelaxed_rank_001_alphaf...,38.632098,34.685714,"(286,)",scores_json,"(286, 286)",7.074556,-6.314286,286
6,cand_7,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,af_candidates/cand_7_unrelaxed_rank_001_alphaf...,39.35993,33.807143,"(286,)",scores_json,"(286, 286)",5.232699,-7.192857,286
7,cand_3,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,af_candidates/cand_3_unrelaxed_rank_001_alphaf...,38.422797,30.874286,"(286,)",scores_json,"(286, 286)",2.949343,-10.125714,286
8,cand_8,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,af_candidates/cand_8_unrelaxed_rank_001_alphaf...,37.716014,30.641429,"(286,)",scores_json,"(286, 286)",2.411529,-10.358571,286
9,cand_1,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,af_candidates/cand_1_unrelaxed_rank_001_alphaf...,37.967063,29.764286,"(286,)",scores_json,"(286, 286)",3.148229,-11.235714,286


In [36]:
df_top10.columns

Index(['name', 'sequence', 'pdb_path', 'pLDDT_global', 'pLDDT_local',
       'plddt_shape', 'plddt_source', 'pae_shape', 'RMSD_outside',
       'ΔpLDDT_local_vs_WT', 'N_res'],
      dtype='object')

In [37]:
df_top10[['sequence', 'pLDDT_global', 'pLDDT_local', 'ΔpLDDT_local_vs_WT', 'RMSD_outside']]

Unnamed: 0,sequence,pLDDT_global,pLDDT_local,ΔpLDDT_local_vs_WT,RMSD_outside
0,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,41.692203,49.03,8.03,2.339937
1,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,39.799021,41.245714,0.245714,2.552505
2,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,41.134056,41.0,0.0,9.24958e-15
3,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,40.865874,39.255714,-1.744286,1.02201
4,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,37.738427,37.238571,-3.761429,11.73585
5,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,38.632098,34.685714,-6.314286,7.074556
6,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,39.35993,33.807143,-7.192857,5.232699
7,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,38.422797,30.874286,-10.125714,2.949343
8,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,37.716014,30.641429,-10.358571,2.411529
9,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,37.967063,29.764286,-11.235714,3.148229


In [38]:
for x in df_top10['sequence']:
  print(x[START: START + L])

PELNEAL
PELNEAM
PELNEAI
PELNEAV
TELNEAT
PALNEAI
TELNSAI
PSLNEAI
PELNEAS
PELNEAT


In [39]:
sequence[START: START + L]

'PELNEAI'

## With MSA

In [40]:
test_name = "wt_msa"
wt_outdir    = "af_orig_input"

queries = [(test_name, sequence, None, None)]  # <-- 4-tuple per item

_ = run(
    queries=queries,
    result_dir=wt_outdir,
    data_dir=Path("."),
    is_complex=is_complex,            # monomer
    msa_mode="mmseqs2_uniref_env",  # fast; no server call
    use_templates=False,
    model_type="auto",
    num_models=5,
    num_recycles=3,
    rank_by="pLDDT",
    zip_results=False,
)
print("Done →", wt_outdir)

COMPLETE: 100%|██████████| 150/150 [elapsed: 00:02 remaining: 00:00]


Done → af_orig_input


In [41]:
import os

def get_rank1_triples(output_dir: str, jobnames: list[str], prefer_relaxed: bool = False):
    """
    Return a dict: jobname -> {"pdb": <path or None>, "scores": <path or None>, "pae": <path or None>}
    Scans:
      - <output_dir>/
      - <output_dir>/<job>/
    Picks files for rank_001 only; PAE is the job-level predicted_aligned_error JSON.
    """
    out = {}
    for job in jobnames:
        # Collect candidate files from flat dir and (if present) job subdir
        flat_dir = output_dir
        sub_dir  = os.path.join(output_dir, job)
        files = []
        if os.path.isdir(flat_dir):
            files += [os.path.join(flat_dir, f) for f in os.listdir(flat_dir) if f.startswith(job + "_")]
        if os.path.isdir(sub_dir):
            files += [os.path.join(sub_dir, f) for f in os.listdir(sub_dir)]

        # Normalize to base names for matching
        def pick(pred):
            hits = [p for p in files if pred(os.path.basename(p))]
            hits.sort()
            return hits[0] if hits else None

        # Scores JSON (rank_001 from *any* model/seed)
        scores = pick(lambda b: "_scores_rank_001_" in b and b.endswith(".json"))

        # PAE JSON (job-level)
        pae = pick(lambda b: "predicted_aligned_error" in b and b.endswith(".json"))

        # PDB (prefer un/relaxed per flag; both restricted to rank_001)
        pdb_relaxed   = pick(lambda b: "_relaxed_rank_001_"   in b and b.endswith(".pdb"))
        pdb_unrelaxed = pick(lambda b: "_unrelaxed_rank_001_" in b and b.endswith(".pdb"))
        pdb = (pdb_relaxed if prefer_relaxed else pdb_unrelaxed) or (pdb_unrelaxed if prefer_relaxed else pdb_relaxed)

        # Fallback (rare naming variants): any rank_001 PDB
        if pdb is None:
            pdb = pick(lambda b: "_rank_001_" in b and b.endswith(".pdb"))

        out[job] = {"pdb": pdb, "scores": scores, "pae": pae}
    return out

rank1_triples = get_rank1_triples(output_dir=wt_outdir, jobnames=[test_name])
print(rank1_triples)

{'wt_msa': {'pdb': 'af_orig_input/wt_msa_unrelaxed_rank_001_alphafold2_ptm_model_3_seed_000.pdb', 'scores': 'af_orig_input/wt_msa_scores_rank_001_alphafold2_ptm_model_3_seed_000.json', 'pae': 'af_orig_input/wt_msa_predicted_aligned_error_v1.json'}}


In [42]:
# Build ColabFold queries - expects 4-tuples
msa_names   = [f"MSA_cand_{i}" for i in range(len(top10))]
msa_queries = [(n, s, None, None) for n, s in zip(msa_names, top10)]
msa_queries[0]

('MSA_cand_0',
 'MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRIDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPVAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW',
 None,
 None)

In [43]:
_ = run(
    queries=msa_queries,
    result_dir=outdir,
    data_dir=Path("."),
    is_complex=is_complex,            # monomer
    msa_mode="mmseqs2_uniref_env",
    use_templates=False,
    model_type="auto",
    num_models=5,
    num_recycles=3,
    rank_by="pLDDT",
    zip_results=False,
)
print("Done →", outdir)

COMPLETE: 100%|██████████| 150/150 [elapsed: 00:02 remaining: 00:00]
COMPLETE: 100%|██████████| 150/150 [elapsed: 00:02 remaining: 00:00]
COMPLETE: 100%|██████████| 150/150 [elapsed: 00:02 remaining: 00:00]
COMPLETE: 100%|██████████| 150/150 [elapsed: 00:02 remaining: 00:00]
COMPLETE: 100%|██████████| 150/150 [elapsed: 00:02 remaining: 00:00]
COMPLETE: 100%|██████████| 150/150 [elapsed: 00:02 remaining: 00:00]
COMPLETE: 100%|██████████| 150/150 [elapsed: 00:02 remaining: 00:00]
COMPLETE: 100%|██████████| 150/150 [elapsed: 00:02 remaining: 00:00]
COMPLETE: 100%|██████████| 150/150 [elapsed: 00:02 remaining: 00:00]
COMPLETE: 100%|██████████| 150/150 [elapsed: 00:02 remaining: 00:00]


Done → af_candidates


In [44]:
cand_triples = get_rank1_triples(output_dir=outdir, jobnames=msa_names, prefer_relaxed=False)
cand_triples['MSA_cand_0']

{'pdb': 'af_candidates/MSA_cand_0_unrelaxed_rank_001_alphafold2_ptm_model_3_seed_000.pdb',
 'scores': 'af_candidates/MSA_cand_0_scores_rank_001_alphafold2_ptm_model_3_seed_000.json',
 'pae': 'af_candidates/MSA_cand_0_predicted_aligned_error_v1.json'}

In [45]:
msa_wt_path = rank1_triples['wt_msa']['pdb']
msa_wt_path

'af_orig_input/wt_msa_unrelaxed_rank_001_alphafold2_ptm_model_3_seed_000.pdb'

In [46]:
rows = []
for job, seq in zip(msa_names, top10):
    t = cand_triples[job]   # {'pdb': ..., 'scores': ..., 'pae': ...}
    m = compute_metrics_for_target(
        wt_pdb_path=msa_wt_path,          # WT predicted with MSA settings
        mut_pdb_path=t["pdb"],
        mut_scores_json_path=t["scores"], # per-candidate JSON
        mut_pae_json_path=t["pae"],
        start=START, L=L, chain_id="A",
    )
    rows.append({"name": job, "sequence": seq, **m, "pdb_path": t["pdb"]})

df_msa = (pd.DataFrame(rows)
            .sort_values(["pLDDT_local","RMSD_outside"], ascending=[False, True])
            .reset_index(drop=True))
df_msa

Unnamed: 0,name,sequence,pLDDT_global,pLDDT_local,plddt_shape,plddt_source,pae_shape,RMSD_outside,pdb_path
0,MSA_cand_4,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,94.357168,98.464286,"(286,)",scores_json,"(286, 286)",5.865924,af_candidates/MSA_cand_4_unrelaxed_rank_001_al...
1,MSA_cand_7,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,94.411259,98.295714,"(286,)",scores_json,"(286, 286)",0.185134,af_candidates/MSA_cand_7_unrelaxed_rank_001_al...
2,MSA_cand_6,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,94.322098,98.222857,"(286,)",scores_json,"(286, 286)",0.29136,af_candidates/MSA_cand_6_unrelaxed_rank_001_al...
3,MSA_cand_1,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,94.356783,98.205714,"(286,)",scores_json,"(286, 286)",0.2813332,af_candidates/MSA_cand_1_unrelaxed_rank_001_al...
4,MSA_cand_3,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,94.556294,98.125714,"(286,)",scores_json,"(286, 286)",1.165166,af_candidates/MSA_cand_3_unrelaxed_rank_001_al...
5,MSA_cand_9,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,94.336748,98.117143,"(286,)",scores_json,"(286, 286)",0.568555,af_candidates/MSA_cand_9_unrelaxed_rank_001_al...
6,MSA_cand_0,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,94.166294,98.108571,"(286,)",scores_json,"(286, 286)",1.584206e-14,af_candidates/MSA_cand_0_unrelaxed_rank_001_al...
7,MSA_cand_2,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,94.43493,98.062857,"(286,)",scores_json,"(286, 286)",0.7649345,af_candidates/MSA_cand_2_unrelaxed_rank_001_al...
8,MSA_cand_8,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,94.448671,98.002857,"(286,)",scores_json,"(286, 286)",0.495805,af_candidates/MSA_cand_8_unrelaxed_rank_001_al...
9,MSA_cand_5,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,94.357727,98.002857,"(286,)",scores_json,"(286, 286)",0.876236,af_candidates/MSA_cand_5_unrelaxed_rank_001_al...


In [47]:
df_msa[['name', 'sequence', 'pLDDT_global', 'pLDDT_local', 'RMSD_outside']]

Unnamed: 0,name,sequence,pLDDT_global,pLDDT_local,RMSD_outside
0,MSA_cand_4,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,94.357168,98.464286,5.865924
1,MSA_cand_7,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,94.411259,98.295714,0.185134
2,MSA_cand_6,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,94.322098,98.222857,0.29136
3,MSA_cand_1,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,94.356783,98.205714,0.2813332
4,MSA_cand_3,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,94.556294,98.125714,1.165166
5,MSA_cand_9,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,94.336748,98.117143,0.568555
6,MSA_cand_0,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,94.166294,98.108571,1.584206e-14
7,MSA_cand_2,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,94.43493,98.062857,0.7649345
8,MSA_cand_8,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,94.448671,98.002857,0.495805
9,MSA_cand_5,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,94.357727,98.002857,0.876236


In [50]:
df_msa.sequence.iloc[0][START: START + L]

'PELNEAL'