# CHD Variant Structural Analysis Pipeline (v5.2)

**Changes from v5.1:**
- **DDG concordance bug fix**: DDG vote now uses `max(|mono|, |multi_max|, |multi_min|)` — v5.1 only checked multi_max for positive thresholds and multi_min for negative, missing single-partner multimer values stored in multi_max that were negative (affects ZIC3 H318N, K405E, R350G, S402P)
- **DDG confidence gating uses `ddg_confidence` column**: cleaner than raw pLDDT thresholds, same result
- **New sub-score columns**: `structure_strict`, `structure_relaxed`, `external_strict`, `external_relaxed` (and T3 variants) for transparent vote decomposition
- **Column ordering**: concordance columns grouped logically with sub-scores before totals

**Concordance formula (v5.2):**
```
DDG value = max( |ddg_monomer|, |ddg_multimer_max|, |ddg_multimer_min| )

Standard:  DDG vote = 1 if value >= 2.0 AND ddg_confidence = 'high'
Relaxed:   DDG vote = 1 if value >= 1.0 AND ddg_confidence != 'low'

AM strict:  likely_pathogenic
AM relaxed: likely_pathogenic OR ambiguous

Franklin strict:  pathogenic, LP, VUS(high)
Franklin relaxed: pathogenic, LP, VUS(high), VUS(mid)
```


In [None]:
# =============================================================================
# CELL 1: CONFIGURATION
# =============================================================================

import os, sys, re, warnings, subprocess, tempfile, shutil
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Set
import pandas as pd
import numpy as np
from Bio.PDB import PDBParser, MMCIFParser, NeighborSearch, ShrakeRupley

# Silence ALL warnings for the rest of the session.
# NOTE(review): this also hides pandas/numpy warnings raised by later cells.
warnings.filterwarnings('ignore')

# =============================================================================
# PATHS — adjust these to your local environment
# =============================================================================
# Set WORKING_DIR to the root of your project (where structure files live).
# Overridable via the CHD_WORKING_DIR environment variable; default: current directory.
WORKING_DIR    = Path(os.environ.get("CHD_WORKING_DIR", ".")).resolve()
RESULTS_DIR    = WORKING_DIR / "results"

# mkdssp binary — install via: conda install -c salilab dssp  OR  brew install dssp
# Resolution order: $DSSP_PATH, then `mkdssp` found on PATH, then the bare name "mkdssp".
DSSP_PATH      = os.environ.get("DSSP_PATH", shutil.which("mkdssp") or "mkdssp")

# Create output directories up front (exist_ok makes this safe to re-run).
for d in [RESULTS_DIR]:
    d.mkdir(parents=True, exist_ok=True)

# Directories to search for structure files (in priority order; used by find_file)
SEARCH_DIRS = [
    WORKING_DIR,
    WORKING_DIR / "structures",
    WORKING_DIR / "structures" / "monomers",
    WORKING_DIR / "structures" / "multimers",
]

def find_file(filename, search_dirs=None):
    """Return the first regular file named *filename* found in the search dirs.

    Args:
        filename: File name (or relative path) to look for. ``None`` or an
            empty string yields ``None``.
        search_dirs: Optional iterable of directories to search, in priority
            order. Defaults to the module-level ``SEARCH_DIRS``.

    Returns:
        The first matching ``Path``, or ``None`` if no directory contains it.
    """
    if not filename:
        return None
    dirs = SEARCH_DIRS if search_dirs is None else search_dirs
    for d in dirs:
        candidate = Path(d) / filename
        # is_file(), not exists(): a directory must never be returned as a
        # "structure file" (note that ``d / ''`` is ``d`` itself).
        if candidate.is_file():
            return candidate
    return None

# =============================================================================
# INPUT FILES
# =============================================================================
# These three inputs are read from WORKING_DIR directly (not via SEARCH_DIRS).
VARIANTS_FILE    = WORKING_DIR / "variants_with_alphamissense_and_franklin_expanded.csv"
MONOMER_DDG_FILE = WORKING_DIR / "foldx_ddg_monomer_results_all.csv"
MULTIMER_DDG_FILE = WORKING_DIR / "foldx_ddg_multimer_results.csv"

# =============================================================================
# MONOMER STRUCTURE DEFINITIONS
# gene → (cif_filename_or_None, pdb_filename_or_None)
# CIF preferred for pLDDT on shroom3/cdh2/dvl2/ctnnb1/zic3/actb (PDB B-factors=0)
# PDB B-factors valid for gli3/kpna1/kpna6/mdfi/rock2/tcf7l1
# Filenames are resolved through find_file(); None means "no file of that kind".
# NOTE(review): zic2/zic5 are PDB-only but are not listed in the B-factor note
# above — confirm their B-factors are populated.
# =============================================================================
MONOMER_STRUCTURES = {
    'shroom3': ('fold_shroom3_model_0.cif',  'fold_shroom3_model_0.pdb'),
    'zic3':    ('fold_zic3_model_0.cif',     'fold_zic3_model_0.pdb'),
    'cdh2':    ('fold_cdh2_model_0.cif',     'fold_cdh2_model_0.pdb'),
    'dvl2':    ('fold_dvl2_model_0.cif',     'fold_dvl2_model_0.pdb'),
    'ctnnb1':  ('fold_ctnnb1_model_0.cif',   'fold_ctnnb1_model_0.pdb'),
    'actb':    ('fold_actb_model_0.cif',     'fold_actb_model_0.pdb'),
    'rock2':   (None,                        'rock2.pdb'),
    'gli3':    (None,                        'gli3.pdb'),
    'kpna1':   (None,                        'kpna1.pdb'),
    'kpna6':   (None,                        'kpna6.pdb'),
    'mdfi':    (None,                        'mdfi.pdb'),
    'tcf7l1':  (None,                        'tcf7l1.pdb'),
    'zic2':    (None,                        'zic2.pdb'),
    'zic5':    (None,                        'zic5.pdb'),
}

# =============================================================================
# MULTIMER COMPLEX DEFINITIONS
# (gene1, partner_label, cif_file_or_None, pdb_file, chain_gene1, chain_partner, is_primary)
# chain_gene1 is the chain holding gene1; chain_partner the partner chain
# (e.g. the shroom3–actb model has shroom3 on chain B).
# partner_label becomes the column prefix multi_{partner_label}_* and must be
# unique per row.
# NOTE: For pLDDT, PDB B-factors are used (not CIF) because CIF chain order
# varies by complex (e.g., actb CIF has chains swapped vs PDB).
# NOTE(review): is_primary is unpacked by the extraction loop but not used in
# the visible cells — confirm whether a downstream cell relies on it.
# =============================================================================
MULTIMER_STRUCTURES = [
    ('shroom3','actin',         'fold_shroom3_actin_chain_model_0.cif',            'fold_shroom3_actin_chain_model_0.pdb',        'A','B', True),
    ('shroom3','actb',          'fold_shroom3_actb_model_0.cif',                   'fold_shroom3_actb_model_0.pdb',               'B','A', False),
    ('shroom3','dvl2',          'fold_shroom3_dvl2_model_0.cif',                   'fold_shroom3_dvl2_model_0.pdb',               'A','B', True),
    ('shroom3','cdh2_truncated','fold_shroom3_cdh2_truncated_model_0.cif',         'fold_shroom3_cdh2_truncated_model_0.pdb',     'A','B', False),
    ('shroom3','ctnnb1',        'fold_shroom3_ctnnb1_model_0.cif',                 'fold_shroom3_ctnnb1_model_0.pdb',             'A','B', True),
    ('shroom3','cdh2_cyto',     'fold_shroom3_cdh2_cytoplasmic_domain_model_0.cif','fold_shroom3_cdh2_cytoplasmic_domain.pdb',    'A','B', True),
    ('shroom3','actb_no_bind',  'fold_shroom3_no_actin_binding_actb_chain_model_0.cif','fold_shroom3_no_actin_binding_actb_chain.pdb','A','B', False),
    ('shroom3','rock2',         None,                                              'fold_shroom3_rock2_model_0.pdb',              'A','B', True),
    ('zic3','gli3',   'fold_zic3_gli3_model_0.cif',   'fold_zic3_gli3_model_0.pdb',  'A','B', True),
    ('zic3','kpna1',  'fold_zic3_kpna1_model_0.cif',  'fold_zic3_kpna1_model_0.pdb', 'A','B', True),
    ('zic3','kpna6',  'fold_zic3_kpna6_model_0.cif',  'fold_zic3_kpna6_model_0.pdb', 'A','B', True),
    ('zic3','mdfi',   'fold_zic3_mdfi_model_0.cif',   'fold_zic3_mdfi_model_0.pdb',  'A','B', True),
    ('zic3','tcf7l1', 'fold_zic3_tcf7l1_model_0.cif', 'fold_zic3_tcf7l1_model_0.pdb','A','B', True),
]

# =============================================================================
# AMINO ACID DATA
# =============================================================================
# Three-letter → one-letter codes for the 20 standard residues only; callers
# supply their own default for non-standard residue names ('?' / 'X').
THREE_TO_ONE = {
    'ALA':'A','CYS':'C','ASP':'D','GLU':'E','PHE':'F','GLY':'G','HIS':'H',
    'ILE':'I','LYS':'K','LEU':'L','MET':'M','ASN':'N','PRO':'P','GLN':'Q',
    'ARG':'R','SER':'S','THR':'T','VAL':'V','TRP':'W','TYR':'Y',
}
# Coarse physico-chemical labels compared by get_property_changes().
AA_PROPERTIES = {
    'A':{'size':'small','charge':'neutral','hydrophobic':True},
    'R':{'size':'large','charge':'positive','hydrophobic':False},
    'N':{'size':'medium','charge':'neutral','hydrophobic':False},
    'D':{'size':'medium','charge':'negative','hydrophobic':False},
    'C':{'size':'small','charge':'neutral','hydrophobic':True},
    'E':{'size':'medium','charge':'negative','hydrophobic':False},
    'Q':{'size':'medium','charge':'neutral','hydrophobic':False},
    'G':{'size':'small','charge':'neutral','hydrophobic':False},
    'H':{'size':'medium','charge':'positive','hydrophobic':False},
    'I':{'size':'medium','charge':'neutral','hydrophobic':True},
    'L':{'size':'medium','charge':'neutral','hydrophobic':True},
    'K':{'size':'large','charge':'positive','hydrophobic':False},
    'M':{'size':'medium','charge':'neutral','hydrophobic':True},
    'F':{'size':'large','charge':'neutral','hydrophobic':True},
    'P':{'size':'small','charge':'neutral','hydrophobic':False},
    'S':{'size':'small','charge':'neutral','hydrophobic':False},
    'T':{'size':'small','charge':'neutral','hydrophobic':False},
    'W':{'size':'large','charge':'neutral','hydrophobic':True},
    'Y':{'size':'large','charge':'neutral','hydrophobic':False},
    'V':{'size':'small','charge':'neutral','hydrophobic':True},
}
# Max SASA (Tien et al 2013, theoretical Gly-X-Gly), in Å²; denominator for
# relative solvent accessibility in get_accessibility().
MAX_SASA = {
    'A':129,'R':274,'N':195,'D':193,'C':167,'E':223,'Q':225,'G':104,
    'H':224,'I':197,'L':201,'K':236,'M':224,'F':240,'P':159,'S':155,
    'T':172,'V':174,'W':285,'Y':263,
}

# Grantham (1974) amino-acid distance matrix. Each unordered pair is stored
# once; get_grantham() looks up both orderings.
GRANTHAM = {
    ('A','R'):112,('A','N'):111,('A','D'):126,('A','C'):195,('A','Q'):91,('A','E'):107,
    ('A','G'):60,('A','H'):86,('A','I'):94,('A','L'):96,('A','K'):106,('A','M'):84,
    ('A','F'):113,('A','P'):27,('A','S'):99,('A','T'):58,('A','W'):148,('A','Y'):112,('A','V'):64,
    ('R','N'):86,('R','D'):96,('R','C'):180,('R','Q'):43,('R','E'):54,('R','G'):125,
    ('R','H'):29,('R','I'):97,('R','L'):102,('R','K'):26,('R','M'):91,('R','F'):97,
    ('R','P'):103,('R','S'):110,('R','T'):71,('R','W'):101,('R','Y'):77,('R','V'):96,
    ('N','D'):23,('N','C'):139,('N','Q'):46,('N','E'):42,('N','G'):80,('N','H'):68,
    ('N','I'):149,('N','L'):153,('N','K'):94,('N','M'):142,('N','F'):158,('N','P'):91,
    ('N','S'):46,('N','T'):65,('N','W'):174,('N','Y'):143,('N','V'):133,
    ('D','C'):154,('D','Q'):61,('D','E'):45,('D','G'):94,('D','H'):81,('D','I'):168,
    ('D','L'):172,('D','K'):101,('D','M'):160,('D','F'):177,('D','P'):108,('D','S'):65,
    ('D','T'):85,('D','W'):181,('D','Y'):160,('D','V'):152,
    ('C','Q'):154,('C','E'):170,('C','G'):159,('C','H'):174,('C','I'):198,('C','L'):198,
    ('C','K'):202,('C','M'):196,('C','F'):205,('C','P'):169,('C','S'):112,('C','T'):149,
    ('C','W'):215,('C','Y'):194,('C','V'):192,
    ('Q','E'):29,('Q','G'):87,('Q','H'):24,('Q','I'):109,('Q','L'):113,('Q','K'):53,
    ('Q','M'):101,('Q','F'):116,('Q','P'):76,('Q','S'):68,('Q','T'):42,('Q','W'):130,
    ('Q','Y'):99,('Q','V'):96,
    ('E','G'):98,('E','H'):40,('E','I'):134,('E','L'):138,('E','K'):56,('E','M'):126,
    ('E','F'):140,('E','P'):93,('E','S'):80,('E','T'):65,('E','W'):152,('E','Y'):122,('E','V'):121,
    ('G','H'):98,('G','I'):135,('G','L'):138,('G','K'):127,('G','M'):127,('G','F'):153,
    ('G','P'):42,('G','S'):56,('G','T'):59,('G','W'):184,('G','Y'):147,('G','V'):109,
    ('H','I'):94,('H','L'):99,('H','K'):32,('H','M'):87,('H','F'):100,('H','P'):77,
    ('H','S'):89,('H','T'):47,('H','W'):115,('H','Y'):83,('H','V'):84,
    ('I','L'):5,('I','K'):102,('I','M'):10,('I','F'):21,('I','P'):95,('I','S'):142,
    ('I','T'):89,('I','W'):61,('I','Y'):33,('I','V'):29,
    ('L','K'):107,('L','M'):15,('L','F'):22,('L','P'):98,('L','S'):145,('L','T'):92,
    ('L','W'):61,('L','Y'):36,('L','V'):32,
    ('K','M'):95,('K','F'):102,('K','P'):103,('K','S'):121,('K','T'):78,('K','W'):110,
    ('K','Y'):85,('K','V'):97,
    ('M','F'):28,('M','P'):87,('M','S'):135,('M','T'):81,('M','W'):67,('M','Y'):36,('M','V'):21,
    ('F','P'):114,('F','S'):155,('F','T'):103,('F','W'):40,('F','Y'):22,('F','V'):50,
    ('P','S'):74,('P','T'):38,('P','W'):147,('P','Y'):110,('P','V'):68,
    ('S','T'):58,('S','W'):177,('S','Y'):144,('S','V'):124,
    ('T','W'):128,('T','Y'):92,('T','V'):69,
    ('W','Y'):37,('W','V'):88,
    ('Y','V'):55,
}

def get_grantham(a1, a2):
    """Return the Grantham distance between two one-letter amino-acid codes.

    The lookup is case-insensitive and symmetric. Identical residues give 0.
    A missing code (None/NaN) or a pair absent from ``GRANTHAM`` gives -1.
    """
    if pd.isna(a1) or pd.isna(a2):
        return -1
    first, second = str(a1).upper(), str(a2).upper()
    if first == second:
        return 0
    forward = GRANTHAM.get((first, second))
    if forward is not None:
        return forward
    return GRANTHAM.get((second, first), -1)

def classify_grantham(d):
    """Bucket a Grantham distance into a qualitative substitution class.

    The distance is truncated to int, then classed as <=50 'conservative',
    <=100 'moderately_conservative', <=150 'moderately_radical', otherwise
    'radical'. Missing or negative distances give 'unknown'.
    """
    if pd.isna(d) or d is None or d < 0:
        return 'unknown'
    score = int(d)
    for ceiling, label in ((50, 'conservative'),
                           (100, 'moderately_conservative'),
                           (150, 'moderately_radical')):
        if score <= ceiling:
            return label
    return 'radical'

def grantham_severity(d):
    """Map a Grantham distance onto a 0–4 severity score (distance / 53.75, capped at 4.0).

    Missing or negative (unknown) distances score 0.0.
    """
    if pd.isna(d) or d is None or d < 0:
        return 0.0
    scaled = float(d) / 53.75
    return scaled if scaled < 4.0 else 4.0

def get_property_changes(r, a):
    """Describe which coarse properties change between ref (*r*) and alt (*a*) residues.

    Compares the 'size', 'charge' and 'hydrophobic' entries of ``AA_PROPERTIES``.
    Returns a ';'-joined list such as ``"size:small->large"``, 'none' if nothing
    changes, or 'unknown' when either code is missing or unrecognised.
    """
    if pd.isna(r) or pd.isna(a):
        return 'unknown'
    ref_props = AA_PROPERTIES.get(str(r).upper(), {})
    alt_props = AA_PROPERTIES.get(str(a).upper(), {})
    if not (ref_props and alt_props):
        return 'unknown'
    diffs = [
        f"{prop}:{ref_props[prop]}->{alt_props[prop]}"
        for prop in ('size', 'charge', 'hydrophobic')
        if ref_props.get(prop) != alt_props.get(prop)
    ]
    return ';'.join(diffs) or 'none'

def sf(v, d=0.0):
    """Safely coerce *v* to ``float``.

    Args:
        v: Any scalar (number, numeric string, None, NaN, ...).
        d: Value returned when *v* is missing or not convertible.

    Returns:
        ``float(v)``, or *d* for None/NaN and for values ``float()`` rejects.
    """
    if pd.isna(v) or v is None:
        return d
    try:
        return float(v)
    # Narrow catch: a bare ``except:`` would also swallow KeyboardInterrupt/SystemExit.
    except (TypeError, ValueError, OverflowError):
        return d
def si(v, d=0):
    """Safely coerce *v* to ``int``.

    Args:
        v: Any scalar (number, integer string, None, NaN, ...).
        d: Value returned when *v* is missing or not convertible.

    Returns:
        ``int(v)`` (floats are truncated), or *d* for None/NaN, non-integer
        strings, and infinities.
    """
    if pd.isna(v) or v is None:
        return d
    try:
        return int(v)
    # OverflowError: int(float('inf')). Narrow catch instead of a bare ``except:``.
    except (TypeError, ValueError, OverflowError):
        return d
def ss(v):
    """Coerce *v* to ``str``; missing values (None/NaN) become the empty string."""
    if pd.isna(v) or v is None:
        return ''
    return str(v)
def sb(v, d=False):
    """Coerce *v* to ``bool``; missing values (None/NaN) yield the default *d*."""
    missing = pd.isna(v) or v is None
    return d if missing else bool(v)

# Verify files exist: a monomer counts as found if either its CIF or its PDB
# resolves via find_file(). This only reports; nothing is aborted.
found = 0
for g, (cif, pdb) in MONOMER_STRUCTURES.items():
    cf = find_file(cif)
    pf = find_file(pdb)
    if cf or pf:
        found += 1
    else:
        print(f"  ⚠ {g}: no structure found (tried {cif}, {pdb})")
print(f"✓ Configuration loaded: {found}/{len(MONOMER_STRUCTURES)} monomer structures found")
print(f"  Multimer complexes: {len(MULTIMER_STRUCTURES)}")

# Multimers: only the PDB is checked, since the extraction cell requires it
# (a missing CIF is tolerated there).
mfound = 0
for g1, pl, cif, pdb, *_ in MULTIMER_STRUCTURES:
    pf = find_file(pdb)
    if pf: mfound += 1
    else: print(f"  ⚠ {g1}-{pl}: {pdb} NOT FOUND")
print(f"  Multimer PDBs found: {mfound}/{len(MULTIMER_STRUCTURES)}")


In [None]:
# =============================================================================
# CELL 2: STRUCTURE LOADING AND EXTRACTION FUNCTIONS
# =============================================================================

# Shared Biopython parser instances reused by load_pdb()/load_cif().
# QUIET=True: do not emit parser construction warnings.
_pdb_parser = PDBParser(QUIET=True)
_cif_parser = MMCIFParser(QUIET=True)


def load_cif(path):
    """Parse an mmCIF file with the shared MMCIFParser.

    Args:
        path: ``Path`` or ``str`` to the file, or None.

    Returns:
        The Bio.PDB Structure, or None when *path* is empty/None, is not an
        existing file, or fails to parse. A parse failure is reported on stdout
        rather than silently discarded.
    """
    if not path:
        return None
    path = Path(path)  # accept plain strings as well as Path objects
    if not path.is_file():
        return None
    try:
        return _cif_parser.get_structure('s', str(path))
    except Exception as exc:  # parser raises many types; stay best-effort but visible
        print(f"    ⚠ Failed to parse CIF {path}: {exc}")
        return None

def load_pdb(path):
    """Parse a PDB file with the shared PDBParser.

    Args:
        path: ``Path`` or ``str`` to the file, or None.

    Returns:
        The Bio.PDB Structure, or None when *path* is empty/None, is not an
        existing file, or fails to parse. A parse failure is reported on stdout
        rather than silently discarded.
    """
    if not path:
        return None
    path = Path(path)  # accept plain strings as well as Path objects
    if not path.is_file():
        return None
    try:
        return _pdb_parser.get_structure('s', str(path))
    except Exception as exc:  # parser raises many types; stay best-effort but visible
        print(f"    ⚠ Failed to parse PDB {path}: {exc}")
        return None


def get_plddt(structure, chain_id='A'):
    """Collect per-residue pLDDT from the B-factor column of one chain.

    Only the first model is read. If *chain_id* is absent, the model's first
    chain is used instead. For every standard residue (hetero flag ``' '``) the
    CA atom's B-factor is taken, or the residue's first atom when there is no CA.
    Values that are missing or <= 0 are skipped, so an all-zero B-factor column
    yields an empty dict.

    Returns:
        dict mapping residue sequence number -> pLDDT rounded to 2 decimals.
    """
    if structure is None:
        return {}
    model = structure[0]
    if chain_id not in model:
        # Requested chain missing: fall back to the first chain in the model.
        chain_id = next((ch.id for ch in model), chain_id)
        if chain_id not in model:
            return {}
    scores = {}
    for residue in model[chain_id].get_residues():
        if residue.id[0] != ' ':
            continue  # hetero residues / waters
        if 'CA' in residue:
            bfac = residue['CA'].bfactor
        else:
            bfac = next((atom.bfactor for atom in residue), None)
        if bfac is not None and bfac > 0:
            scores[residue.id[1]] = round(bfac, 2)
    return scores


def get_monomer_plddt(gene_lower):
    """Load a gene's monomer model and its pLDDT, preferring the CIF.

    The CIF is tried first because some PDB exports in this project carry zero
    B-factors. The PDB is the fallback.

    Returns:
        (plddt_dict, structure, source, path), where source is 'cif' or 'pdb'.
        If the PDB loads but has no usable pLDDT, returns
        ``({}, structure, 'pdb_no_plddt', path)`` so the structure can still be
        used for contacts. Returns ``({}, None, None, None)`` when nothing loads.
    """
    cif_name, pdb_name = MONOMER_STRUCTURES.get(gene_lower, (None, None))
    cif_path, pdb_path = find_file(cif_name), find_file(pdb_name)

    cif_struct = load_cif(cif_path) if cif_path else None
    if cif_struct:
        cif_scores = get_plddt(cif_struct, 'A')
        if cif_scores:
            return cif_scores, cif_struct, 'cif', cif_path

    if not pdb_path:
        return {}, None, None, None
    pdb_struct = load_pdb(pdb_path)
    if not pdb_struct:
        return {}, None, None, None
    pdb_scores = get_plddt(pdb_struct, 'A')
    if pdb_scores:
        return pdb_scores, pdb_struct, 'pdb', pdb_path
    # Structure is usable for geometry even without pLDDT
    return {}, pdb_struct, 'pdb_no_plddt', pdb_path


def get_multimer_plddt(pdb_path, cif_path, chain_id):
    """Return ``(plddt_dict, source)`` for one chain of a multimer model.

    The PDB is tried first because its chain order is consistent across
    complexes. The CIF is only a fallback: its chain order may differ, so
    *chain_id* could select the wrong molecule there. Returns ``({}, None)``
    when neither file yields pLDDT values.
    """
    for loader, path, label in ((load_pdb, pdb_path, 'pdb'),
                                (load_cif, cif_path, 'cif')):
        struct = loader(path)
        if not struct:
            continue
        scores = get_plddt(struct, chain_id)
        if scores:
            return scores, label
    return {}, None


def get_residue_aa(structure, chain_id='A'):
    """Map residue sequence number -> one-letter code for one chain of model 0.

    Non-standard residue names map to '?'. If *chain_id* is absent, the first
    chain is used. Returns {} for a missing structure or an empty model.
    """
    if structure is None:
        return {}
    model = structure[0]
    if chain_id not in model:
        chain_id = next((ch.id for ch in model), chain_id)
        if chain_id not in model:
            return {}
    aa_by_pos = {}
    for residue in model[chain_id].get_residues():
        if residue.id[0] == ' ':
            aa_by_pos[residue.id[1]] = THREE_TO_ONE.get(residue.resname, '?')
    return aa_by_pos


def count_contacts(structure, chain_id='A', distance=5.0):
    """Count unique intra-chain residue contacts for every standard residue.

    Two standard residues are in contact when any pair of their atoms is closer
    than *distance* (strict ``<``, in the structure's coordinate units).
    Residues fewer than 3 places apart in the chain's residue list (self and
    immediate neighbours) are ignored. Neighbours are counted by residue
    sequence number. Matches original working pipeline (cell 24 of 85-cell
    notebook).

    Candidate atoms come from a single NeighborSearch over the chain instead of
    an atom-by-atom comparison of every residue pair. The previous version was
    quadratic in chain length and took minutes for ~2000-residue chains. The
    exact pairwise ``<`` test is still applied, so results are unchanged.

    Returns:
        dict mapping residue sequence number -> number of contacting residues.
        If *chain_id* is absent the first chain of model 0 is used. Returns {}
        for a missing structure or chain.
    """
    if structure is None: return {}
    model = structure[0]
    if chain_id not in model:
        for c in model: chain_id = c.id; break
    if chain_id not in model: return {}

    residues = [r for r in model[chain_id].get_residues() if r.id[0] == ' ']
    atoms = [a for r in residues for a in r.get_atoms()]
    if not atoms:
        # Nothing to search (also covers an empty chain): every residue has 0 contacts.
        return {r.id[1]: 0 for r in residues}

    # List position of each residue, needed for the |i - j| >= 3 separation rule.
    index_of = {id(r): i for i, r in enumerate(residues)}
    ns = NeighborSearch(atoms)
    # Query slightly wider than the cutoff, then apply the exact strict test below,
    # so the KD-tree's inclusive boundary / float rounding cannot change results.
    search_radius = distance + 0.01

    contacts = {}
    for i, res in enumerate(residues):
        neighbor_set = set()
        for atom_i in res.get_atoms():
            for atom_j in ns.search(atom_i.coord, search_radius, level='A'):
                parent = atom_j.get_parent()
                j = index_of.get(id(parent))
                if j is None or abs(i - j) < 3:  # skip self and immediate neighbors
                    continue
                other_pos = parent.id[1]
                if other_pos in neighbor_set:
                    continue  # already counted; skip the distance computation
                if atom_i - atom_j < distance:
                    neighbor_set.add(other_pos)
        contacts[res.id[1]] = len(neighbor_set)
    return contacts


def count_interface(structure, my_chain, partner_chain, distance=5.0):
    """Find interface residues of *my_chain* against *partner_chain* (model 0).

    For each standard residue in *my_chain*, this counts the distinct standard
    partner residues (by sequence number) that have an atom within *distance*
    of any of its atoms.

    Returns:
        (counts, interface) — counts maps residue number -> partner-residue
        count (only residues with >= 1 contact), and interface is the set of
        those residue numbers. Both are empty if either chain is missing.
    """
    if structure is None:
        return {}, set()
    model = structure[0]
    if not (my_chain in model and partner_chain in model):
        return {}, set()
    targets = list(model[partner_chain].get_atoms())
    if not targets:
        return {}, set()
    searcher = NeighborSearch(targets)
    counts = {}
    for residue in model[my_chain].get_residues():
        if residue.id[0] != ' ':
            continue
        touched = {
            hit.id[1]
            for atom in residue.get_atoms()
            for hit in searcher.search(atom.coord, distance, 'R')
            if hit.id[0] == ' '
        }
        if touched:
            counts[residue.id[1]] = len(touched)
    return counts, set(counts)


def get_accessibility(structure, chain_id='A'):
    """Relative solvent accessibility per residue via Biopython's ShrakeRupley.

    SASA is computed on the whole of model 0, so other chains occlude this one.
    It is then divided by the residue's ``MAX_SASA`` (200 for unknown
    residues), capped at 1.0 and rounded to 9 decimals. If *chain_id* is absent
    the first chain is used. Compatible with Biopython 1.86. On an error the
    warning is printed and whatever was collected so far is returned.
    """
    if structure is None:
        return {}
    model = structure[0]
    if chain_id not in model:
        chain_id = next((ch.id for ch in model), chain_id)
        if chain_id not in model:
            return {}
    rsa = {}
    try:
        ShrakeRupley().compute(structure[0], level='R')
        for residue in model[chain_id].get_residues():
            if residue.id[0] != ' ':
                continue
            ref_area = MAX_SASA.get(THREE_TO_ONE.get(residue.resname, 'X'), 200)
            if ref_area > 0 and hasattr(residue, 'sasa'):
                rsa[residue.id[1]] = round(min(1.0, residue.sasa / ref_area), 9)
    except Exception as err:
        print(f"    ShrakeRupley warning: {err}")
    return rsa


def _run_mkdssp(pdb_path):
    """Run mkdssp on *pdb_path*, trying several command-line dialects in turn.

    Returns the stdout of the first invocation that exits 0 with non-trivial
    output (> 100 characters). Returns None if the binary cannot be executed
    or every dialect fails or times out.
    """
    attempts = (
        [DSSP_PATH, '--output-format', 'dssp', str(pdb_path)],  # mkdssp v4, classic DSSP format
        [DSSP_PATH, '-i', str(pdb_path)],
        [DSSP_PATH, str(pdb_path)],
    )
    for cmd in attempts:
        try:
            # errors='replace': odd bytes in the output must not abort parsing.
            result = subprocess.run(cmd, capture_output=True, text=True,
                                    errors='replace', timeout=60)
        except OSError:
            # Binary missing / not executable: the other dialects would fail too.
            return None
        except subprocess.SubprocessError:
            # e.g. TimeoutExpired for this dialect: try the next one instead of giving up.
            continue
        if result.returncode == 0 and len(result.stdout) > 100:
            return result.stdout
    return None


def _parse_dssp_classic(text, chain_id):
    """Parse classic DSSP text into {residue number: SS code} for one chain ('-' = none)."""
    ss_map = {}
    in_data = False
    for line in text.split('\n'):
        if '  #  RESIDUE' in line:
            in_data = True
            continue
        if not in_data or len(line) < 17:
            continue
        if line[13] == '!':  # chain-break marker line
            continue
        if line[11] != chain_id:
            continue
        resnum_str = line[5:10].strip()
        if not resnum_str:
            continue
        try:
            resnum = int(resnum_str)
        except ValueError:
            continue
        ss_map[resnum] = line[16] if line[16] != ' ' else '-'
    return ss_map


def get_secondary_structure(pdb_path, chain_id='A'):
    """Secondary structure per residue of *chain_id* via the mkdssp executable.

    Handles the mkdssp v4 output format as well as older command-line forms.
    A timeout on one command-line form now falls through to the next one
    instead of abandoning the run.

    Args:
        pdb_path: ``Path`` or ``str`` to a PDB file, or None.
        chain_id: Chain letter to extract (DSSP column 12).

    Returns:
        dict mapping residue number -> DSSP code ('-' when unassigned). Empty if
        the file is missing, mkdssp is unavailable, or every attempt fails.
    """
    if pdb_path is None:
        return {}
    pdb_path = Path(pdb_path)
    if not pdb_path.exists():
        return {}
    output = _run_mkdssp(pdb_path)
    if output is None:
        return {}
    return _parse_dssp_classic(output, chain_id)


def classify_burial(a):
    """Bin relative accessibility: <0.05 'buried_core', <0.25 'partially_buried', else 'surface_exposed'.

    None/NaN gives 'unknown'.
    """
    if a is None or pd.isna(a):
        return 'unknown'
    rsa = float(a)
    if rsa >= 0.25:
        return 'surface_exposed'
    return 'buried_core' if rsa < 0.05 else 'partially_buried'

def classify_plddt(v):
    """Bin a pLDDT score: >=90 'very_high', >=70 'confident', >=50 'low', else 'very_low'.

    None/NaN gives 'unknown'.
    """
    if v is None or pd.isna(v):
        return 'unknown'
    for floor, label in ((90, 'very_high'), (70, 'confident'), (50, 'low')):
        if v >= floor:
            return label
    return 'very_low'

def classify_contacts(n):
    """Bin a contact count: >=8 'high_contact', >=1 'medium_contact', else 'low_contact'.

    None/NaN gives 'unknown'.
    """
    if n is None or pd.isna(n):
        return 'unknown'
    if n < 1:
        return 'low_contact'
    return 'high_contact' if n >= 8 else 'medium_contact'

print("✓ Extraction functions defined")

In [None]:
# =============================================================================
# CELL 3: LOAD VARIANTS
# =============================================================================

variants_df = pd.read_csv(VARIANTS_FILE)
# Normalise headers: lower-case, then trim surrounding whitespace.
variants_df.columns = [name.strip() for name in variants_df.columns.str.lower()]
variants_df['position'] = variants_df['position'].astype(int)

# Keep whichever known annotation columns are present, in file column order.
ann_cols = [col for col in variants_df.columns
            if col in ('alphamissense', 'franklin', 'alphamissense_pathogenicity')]
annotation_df = variants_df[['gene', 'position', 'ref_aa', 'alt_aa', *ann_cols]].copy()

print(f"✓ Loaded {len(variants_df)} variants across {variants_df['gene'].nunique()} genes")
print(variants_df.groupby('gene').size().to_string())

In [None]:
# =============================================================================
# CELL 4: MONOMER METRICS
# =============================================================================

# Build one row of monomer-structure features per variant, gene by gene.
print("Extracting monomer structural metrics...")
monomer_rows = []

for gene in variants_df['gene'].unique():
    g = str(gene).lower()
    gene_vars = variants_df[variants_df['gene'].str.lower() == g]

    # Get pLDDT (CIF first for genes with zero PDB B-factors)
    plddt_map, struct_plddt, plddt_src, plddt_path = get_monomer_plddt(g)

    # Get PDB structure for contacts and accessibility
    _, pdb_name = MONOMER_STRUCTURES.get(g, (None, None))
    pdb_path = find_file(pdb_name)
    # Also try generic patterns if not found
    if pdb_path is None:
        for pat in [f'fold_{g}_model_0.pdb', f'{g}.pdb']:
            pdb_path = find_file(pat)
            if pdb_path: break

    # load_pdb(None) returns None, so a gene without any PDB is handled below.
    struct_pdb = load_pdb(pdb_path)

    # Use whichever structure we have for contacts (PDB preferred).
    # NOTE(review): when pLDDT came from the CIF but geometry from the PDB,
    # residue numbering is assumed identical in both files — confirm.
    struct_contacts = struct_pdb or struct_plddt
    contact_map = count_contacts(struct_contacts, 'A') if struct_contacts else {}

    # Accessibility (ShrakeRupley — Biopython 1.86 compatible)
    acc_map = get_accessibility(struct_pdb or struct_plddt, 'A')

    # Secondary structure (subprocess mkdssp — handles v4 format)
    ss_map = get_secondary_structure(pdb_path, 'A')

    # AA identity (pLDDT-source structure preferred here, unlike contacts above)
    aa_map = get_residue_aa(struct_plddt or struct_pdb, 'A')

    # Per-gene coverage report (counts are over this gene's variant rows).
    n_plddt = sum(1 for _, r in gene_vars.iterrows() if plddt_map.get(int(r['position'])))
    n_acc = sum(1 for _, r in gene_vars.iterrows() if acc_map.get(int(r['position'])) is not None)
    has_struct = struct_plddt is not None or struct_pdb is not None
    print(f"  {gene}: struct={'YES' if has_struct else 'NO'} src={plddt_src} pLDDT={n_plddt}/{len(gene_vars)} acc={n_acc}/{len(gene_vars)} contacts={len(contact_map)} SS={len(ss_map)}")

    for _, row in gene_vars.iterrows():
        pos = int(row['position'])
        gd = get_grantham(row['ref_aa'], row['alt_aa'])
        p = plddt_map.get(pos)
        # c: 0 when the residue is absent from a non-empty contact map,
        #    None when no contacts were computed for this gene at all.
        c = contact_map.get(pos, 0) if contact_map else None
        a = acc_map.get(pos)
        # '-' = no DSSP assignment for this position; None = DSSP gave nothing for the gene.
        sec_struct = ss_map.get(pos, '-') if ss_map else None

        monomer_rows.append({
            'gene': gene, 'position': pos,
            'ref_aa': row['ref_aa'], 'alt_aa': row['alt_aa'],
            'grantham_distance': gd, 'grantham_class': classify_grantham(gd),
            'substitution_severity': round(grantham_severity(gd), 2),
            'property_changes': get_property_changes(row['ref_aa'], row['alt_aa']),
            'monomer_plddt': p, 'monomer_plddt_category': classify_plddt(p),
            'monomer_n_contacts': float(c) if c is not None else None,
            'monomer_contact_category': classify_contacts(c),
            'monomer_aa': aa_map.get(pos),
            'monomer_accessibility': a,
            'monomer_burial': classify_burial(a),
            'monomer_secondary_structure': sec_struct,
            # Currently identical to monomer_n_contacts.
            'monomer_contact_disruption': float(c) if c is not None else None,
        })

df = pd.DataFrame(monomer_rows)
print(f"\n✓ Monomer: {len(df)} variants")
print(f"  pLDDT populated:     {df['monomer_plddt'].notna().sum()}/{len(df)}")
print(f"  Accessibility:       {df['monomer_accessibility'].notna().sum()}/{len(df)}")
print(f"  Burial (non-unknown): {(df['monomer_burial'] != 'unknown').sum()}/{len(df)}")
print(f"  Secondary structure: {df['monomer_secondary_structure'].notna().sum()}/{len(df)}")

In [None]:
# =============================================================================
# CELL 5: MULTIMER EXTRACTION (PDB pLDDT + BIDIRECTIONAL)
# =============================================================================

print("Extracting multimer metrics...")

multi_data = {}  # (gene_lower, position) → {col: val}
all_partner_labels = set()  # every partner label that produced multi_{label}_* columns
variant_genes = set(df['gene'].str.lower().unique())

# is_primary is unpacked here but not used inside this loop.
for gene1, partner_label, cif_file, pdb_file, chain1, chain2, is_primary in MULTIMER_STRUCTURES:
    g1 = gene1.lower()
    plabel = partner_label.lower()

    pdb_path = find_file(pdb_file)
    cif_path = find_file(cif_file) if cif_file else None

    if pdb_path is None:
        print(f"  ⚠ {pdb_file} NOT FOUND — skipping {g1}-{plabel}")
        continue

    # Load PDB for everything (contacts, interface, accessibility, pLDDT)
    # (pLDDT is re-read via get_multimer_plddt, which may fall back to the CIF.)
    struct_pdb = load_pdb(pdb_path)
    if struct_pdb is None:
        print(f"  ⚠ Failed to load {pdb_file}")
        continue

    # === FORWARD: gene1 variants get multi_{partner_label}_* ===
    if g1 in variant_genes:
        all_partner_labels.add(plabel)
        gene1_positions = set(df[df['gene'].str.lower() == g1]['position'].values)

        # pLDDT from PDB first (consistent chain order), CIF fallback
        plddt_a, psrc = get_multimer_plddt(pdb_path, cif_path, chain1)
        contacts_a = count_contacts(struct_pdb, chain1)  # intra-chain contacts only
        inter_a, iface_a = count_interface(struct_pdb, chain1, chain2)  # cross-chain
        acc_a = get_accessibility(struct_pdb, chain1)  # SASA of the whole complex model
        ss_a = get_secondary_structure(pdb_path, chain1)

        for pos in gene1_positions:
            key = (g1, pos)
            if key not in multi_data: multi_data[key] = {}
            pfx = f"multi_{plabel}"
            p = plddt_a.get(pos)
            c = contacts_a.get(pos, 0)
            ic = inter_a.get(pos, 0)
            multi_data[key][f"{pfx}_plddt"] = p
            multi_data[key][f"{pfx}_n_contacts"] = float(c)
            multi_data[key][f"{pfx}_inter_contacts"] = float(ic)
            multi_data[key][f"{pfx}_is_interface"] = pos in iface_a
            multi_data[key][f"{pfx}_accessibility"] = acc_a.get(pos)
            multi_data[key][f"{pfx}_burial"] = classify_burial(acc_a.get(pos))
            multi_data[key][f"{pfx}_sec_struct"] = ss_a.get(pos, '-')
            # Currently identical to {pfx}_n_contacts.
            multi_data[key][f"{pfx}_disruption"] = float(c)

        n_fwd = len([p for p in gene1_positions if plddt_a.get(p)])
        print(f"  FWD {g1} → multi_{plabel}: {n_fwd}/{len(gene1_positions)} pLDDT (src={psrc})")

    # === REVERSE: partner gene variants get multi_{gene1}_* ===
    # (The actual column prefix is multi_{partner_label}; see rev_label below.)
    # Maps a partner label to the gene symbol used in the variant table;
    # labels not listed map to themselves. Loop-invariant (rebuilt each pass).
    partner_gene_map = {
        'dvl2':'dvl2','ctnnb1':'ctnnb1','rock2':'rock2',
        'gli3':'gli3','kpna1':'kpna1','kpna6':'kpna6',
        'mdfi':'mdfi','tcf7l1':'tcf7l1',
        'cdh2_truncated':'cdh2','cdh2_cyto':'cdh2',
        'actin':'actb','actb':'actb','actb_no_bind':'actb',
    }
    partner_gene = partner_gene_map.get(plabel, plabel)

    if partner_gene in variant_genes:
        # Column label: same as forward (multi_{partner_label})
        # e.g., rock2 variants in shroom3-rock2 complex → multi_rock2 (not multi_shroom3)
        rev_label = plabel
        all_partner_labels.add(rev_label)
        partner_positions = set(df[df['gene'].str.lower() == partner_gene]['position'].values)

        plddt_b, psrc_b = get_multimer_plddt(pdb_path, cif_path, chain2)
        contacts_b = count_contacts(struct_pdb, chain2)
        inter_b, iface_b = count_interface(struct_pdb, chain2, chain1)
        acc_b = get_accessibility(struct_pdb, chain2)
        ss_b = get_secondary_structure(pdb_path, chain2)

        for pos in partner_positions:
            key = (partner_gene, pos)
            if key not in multi_data: multi_data[key] = {}
            pfx = f"multi_{rev_label}"
            # Only write if not already populated (first complex wins)
            # With the current MULTIMER_STRUCTURES every partner label is distinct,
            # so this guard only takes effect if a label is repeated.
            if f"{pfx}_plddt" not in multi_data[key] or multi_data[key][f"{pfx}_plddt"] is None:
                p = plddt_b.get(pos)
                c = contacts_b.get(pos, 0)
                ic = inter_b.get(pos, 0)
                multi_data[key][f"{pfx}_plddt"] = p
                multi_data[key][f"{pfx}_n_contacts"] = float(c)
                multi_data[key][f"{pfx}_inter_contacts"] = float(ic)
                multi_data[key][f"{pfx}_is_interface"] = pos in iface_b
                multi_data[key][f"{pfx}_accessibility"] = acc_b.get(pos)
                multi_data[key][f"{pfx}_burial"] = classify_burial(acc_b.get(pos))
                multi_data[key][f"{pfx}_sec_struct"] = ss_b.get(pos, '-')
                multi_data[key][f"{pfx}_disruption"] = float(c)

        n_rev = len([p for p in partner_positions if plddt_b.get(p)])
        print(f"  REV {partner_gene} → multi_{rev_label}: {n_rev}/{len(partner_positions)} pLDDT (src={psrc_b})")

# Merge per-(gene, position) multimer metrics into df.
# Guard the empty case: MultiIndex.from_tuples([]) cannot infer the number of
# levels and raises TypeError, which crashed this cell whenever no multimer
# PDB was found / no variant gene appeared in any complex.
if multi_data:
    multi_df = pd.DataFrame.from_dict(multi_data, orient='index')
    multi_df.index = pd.MultiIndex.from_tuples(multi_df.index, names=['gene_lower','position'])
    multi_df = multi_df.reset_index()

    df['gene_lower'] = df['gene'].str.lower()
    df = df.merge(multi_df, on=['gene_lower','position'], how='left')
    df = df.drop(columns=['gene_lower'])
else:
    # Keep the name defined for later cells; df gets no multi_* columns.
    multi_df = pd.DataFrame(columns=['gene_lower', 'position'])
    print("  ⚠ No multimer data extracted — skipping multimer merge")

# === Summary columns ===
def compute_summary(row):
    """Aggregate one variant row's per-complex multi_* columns into summary fields.

    A partner label counts as present when its ``multi_{label}_plddt`` is not
    missing. Contacts/disruption default to 0 and interface to False when
    their columns are missing.

    Labels are visited in sorted order. ``all_partner_labels`` is a set of
    strings, and its iteration order can change between interpreter runs
    (string hash randomisation), which made the ';'-joined
    ``multimer_partners`` / ``interface_partners`` values non-reproducible.

    Returns:
        pd.Series with counts, ';'-joined partner lists (None if empty) and
        max/mean pLDDT, contact and disruption values (None if no partners).
    """
    partners, plddt_v, contact_v, disrupt_v, iface_partners = [], [], [], [], []
    for pl in sorted(all_partner_labels):
        p_col = f"multi_{pl}_plddt"
        if p_col in row.index and pd.notna(row[p_col]):
            partners.append(pl)
            plddt_v.append(row[p_col])
            c = sf(row.get(f"multi_{pl}_n_contacts"), 0)
            contact_v.append(c)
            d = sf(row.get(f"multi_{pl}_disruption"), 0)
            disrupt_v.append(d)
            if sb(row.get(f"multi_{pl}_is_interface"), False):
                iface_partners.append(pl)
    return pd.Series({
        'n_multimer_complexes': len(partners),
        'multimer_partners': ';'.join(partners) if partners else None,
        'is_interface_any': len(iface_partners) > 0,
        'interface_partners': ';'.join(iface_partners) if iface_partners else None,
        'n_interface_partners': len(iface_partners),
        'multimer_plddt_max': max(plddt_v) if plddt_v else None,
        'multimer_plddt_avg': round(np.mean(plddt_v), 2) if plddt_v else None,
        'multimer_contacts_max': max(contact_v) if contact_v else None,
        'multimer_contacts_avg': round(np.mean(contact_v), 2) if contact_v else None,
        'multimer_disruption_max': max(disrupt_v) if disrupt_v else None,
        'multimer_disruption_avg': round(np.mean(disrupt_v), 2) if disrupt_v else None,
    })

# Row-wise aggregation of the multi_* columns; the resulting summary columns
# are appended to df side by side (aligned on df's index).
summary = df.apply(compute_summary, axis=1)
df = pd.concat([df, summary], axis=1)

# best_plddt
def best_p(row):
    """Return the highest pLDDT available for a variant.

    Candidates are ``monomer_plddt`` plus every present ``multi_{label}_plddt``
    column; returns None when none of them is populated.
    """
    mono = row.get('monomer_plddt')
    candidates = [mono] if pd.notna(mono) else []
    for label in all_partner_labels:
        col = f"multi_{label}_plddt"
        if col in row.index and pd.notna(row[col]):
            candidates.append(row[col])
    return max(candidates) if candidates else None


def _plddt_confidence(value):
    # 'high' at pLDDT >= 70, 'low' below that, 'unknown' when missing.
    if pd.isna(value):
        return 'unknown'
    return 'high' if value >= 70 else 'low'


df['best_plddt'] = df.apply(best_p, axis=1)
df['confidence'] = df['best_plddt'].apply(_plddt_confidence)

print(f"\n✓ Multimer complete")
print(f"  Interface variants: {df['is_interface_any'].sum()}/{len(df)}")
print(f"  With multimer data: {df['n_multimer_complexes'].gt(0).sum()}/{len(df)}")
print(f"  best_plddt filled:  {df['best_plddt'].notna().sum()}/{len(df)}")

In [None]:
# =============================================================================
# CELL 6: FOLDX DDG (v5.1)
# =============================================================================
# v5.1 changes:
#   1. Added ddg_multimer_min (most stabilizing complex)
#   2. Added ddg_confidence flag based on best_pLDDT
#   3. ddg_category reclassified at low pLDDT (appends _unreliable)
# =============================================================================

# Variant identity used to join FoldX results onto df.
ddg_keys = ['gene','position','ref_aa','alt_aa']

if MONOMER_DDG_FILE.exists():
    ddg = pd.read_csv(MONOMER_DDG_FILE)
    ddg.columns = [c.lower() for c in ddg.columns]
    dc = next((c for c in ['ddg','ddg_monomer','total_ddg'] if c in ddg.columns), None)
    if dc:
        # Select before renaming so a file carrying both 'ddg' and 'ddg_monomer'
        # cannot end up with two columns named 'ddg_monomer'.
        ddg = ddg[ddg_keys + [dc]].rename(columns={dc: 'ddg_monomer'})
        ddg['ddg_monomer'] = pd.to_numeric(ddg['ddg_monomer'], errors='coerce')
        # A left merge against duplicate keys duplicates the matching df rows;
        # collapse repeated entries for one variant to their mean instead.
        n_dup = int(ddg.duplicated(subset=ddg_keys).sum())
        if n_dup:
            print(f"  ⚠ {n_dup} duplicate monomer DDG rows — averaged per variant")
        ddg = ddg.groupby(ddg_keys, as_index=False, dropna=False)['ddg_monomer'].mean()
        df = df.merge(ddg, on=ddg_keys, how='left')
        print(f"✓ Monomer DDG: {df['ddg_monomer'].notna().sum()}/{len(df)}")
    else:
        print(f"⚠ Monomer DDG file has no ddg/ddg_monomer/total_ddg column")
else:
    df['ddg_monomer'] = None
    print(f"⚠ Monomer DDG file not found")

if MULTIMER_DDG_FILE.exists():
    ddgm = pd.read_csv(MULTIMER_DDG_FILE)
    ddgm.columns = [c.lower() for c in ddgm.columns]
    dc = next((c for c in ['ddg','ddg_multimer','total_ddg'] if c in ddgm.columns), None)
    if dc:
        # Numeric coercion so max/min are numeric, not lexicographic.
        ddgm[dc] = pd.to_numeric(ddgm[dc], errors='coerce')
        # One row per variant, so this merge cannot duplicate df rows.
        grp = ddgm.groupby(ddg_keys)
        agg = grp.agg(
            ddg_multimer_max=(dc,'max'),
            ddg_multimer_min=(dc,'min'),       # v5.1: most stabilizing complex
            ddg_multimer_mean=(dc,'mean'),
            n_complexes_tested=(dc,'count')
        ).reset_index()
        if 'partner' in ddgm.columns:
            pt = grp['partner'].apply(lambda x: ';'.join(x.astype(str))).reset_index()
            pt.columns = ddg_keys + ['partners_tested']
            agg = agg.merge(pt, on=ddg_keys, how='left')
        df = df.merge(agg, on=ddg_keys, how='left')
        print(f"✓ Multimer DDG: {df['ddg_multimer_max'].notna().sum()}/{len(df)}")
        print(f"  ddg_multimer_min range: {df['ddg_multimer_min'].min():.2f} to {df['ddg_multimer_min'].max():.2f}")
    else:
        print(f"⚠ Multimer DDG file has no ddg/ddg_multimer/total_ddg column")
else:
    print(f"⚠ Multimer DDG file not found")
    for c in ['ddg_multimer_max','ddg_multimer_min','ddg_multimer_mean','n_complexes_tested','partners_tested']:
        df[c] = None

# Guarantee every DDG column exists regardless of which branch ran above.
for c in ['ddg_monomer','ddg_multimer_max','ddg_multimer_min','ddg_multimer_mean','n_complexes_tested','partners_tested']:
    if c not in df.columns: df[c] = None

# -------------------------------------------------------------------------
# DDG CONFIDENCE (v5.1): based on best_pLDDT
#   high:     pLDDT >= 70  → DDG values are reliable
#   moderate: pLDDT 50-69  → DDG values are usable with caution
#   low:      pLDDT < 50   → DDG values are unreliable (FoldX on bad structure)
# -------------------------------------------------------------------------
def assign_ddg_confidence(plddt):
    """Map a best-pLDDT value to a DDG reliability label.

    Returns 'unknown' for missing values, 'high' for pLDDT >= 70,
    'moderate' for 50 <= pLDDT < 70 and 'low' below 50.
    """
    if pd.isna(plddt):
        return 'unknown'
    value = float(plddt)
    for cutoff, label in ((70, 'high'), (50, 'moderate')):
        if value >= cutoff:
            return label
    return 'low'

# Per-variant DDG reliability label derived from best_plddt (the max pLDDT over
# the monomer and multimer models, computed by best_p above).
df['ddg_confidence'] = df['best_plddt'].apply(assign_ddg_confidence)

# -------------------------------------------------------------------------
# DDG CATEGORY (v5.1): classify monomer DDG, flag unreliable at low pLDDT
# -------------------------------------------------------------------------
def classify_ddg(v):
    """Bin a FoldX DDG value into one of seven categories.

    Bins use strict lower bounds: > 2.0 highly_destabilizing, > 1.0
    destabilizing, > 0.5 mildly_destabilizing, > -0.5 neutral, > -1.0
    mildly_stabilizing, > -2.0 stabilizing, otherwise highly_stabilizing.
    Missing values give None.
    """
    if pd.isna(v):
        return None
    value = float(v)
    bins = (
        (2.0, 'highly_destabilizing'),
        (1.0, 'destabilizing'),
        (0.5, 'mildly_destabilizing'),
        (-0.5, 'neutral'),
        (-1.0, 'mildly_stabilizing'),
        (-2.0, 'stabilizing'),
    )
    return next((label for lower, label in bins if value > lower), 'highly_stabilizing')

def classify_ddg_with_confidence(row):
    """Return classify_ddg(row['ddg_monomer']), with '_unreliable' appended when
    row['ddg_confidence'] is 'low'; None when there is no monomer DDG."""
    category = classify_ddg(row.get('ddg_monomer'))
    if category is None:
        return None
    is_low = row.get('ddg_confidence', 'unknown') == 'low'
    return f"{category}_unreliable" if is_low else category

# Keep the ungated bin next to the confidence-gated one so the effect of the
# low-pLDDT '_unreliable' suffix can be inspected afterwards.
df['ddg_category_raw'] = df['ddg_monomer'].apply(classify_ddg)
df['ddg_category'] = df.apply(classify_ddg_with_confidence, axis=1)

print(f"✓ DDG complete (v5.1)")
print(f"\nDDG confidence distribution:")
print(df['ddg_confidence'].value_counts().to_string())
print(f"\nDDG category distribution:")
print(df['ddg_category'].value_counts().to_string())

# Show impact of confidence gating
# Lists variants whose raw bin is (highly_)destabilizing but whose DDG confidence
# is 'low', i.e. the ones whose gated category now carries '_unreliable'.
if df['ddg_category'].notna().any():
    unreliable = df['ddg_category'].str.contains('_unreliable', na=False).sum()
    print(f"\n  Variants reclassified as unreliable: {unreliable}")
    hi_ddg_low_conf = df[(df['ddg_category_raw'].isin(['highly_destabilizing','destabilizing'])) & 
                          (df['ddg_confidence'] == 'low')]
    if len(hi_ddg_low_conf) > 0:
        print(f"  High/destabilizing DDG at low pLDDT (now flagged):")
        for _, r in hi_ddg_low_conf.iterrows():
            print(f"    {r['gene']} {r['ref_aa']}{int(r['position'])}{r['alt_aa']}: "
                  f"DDG={r['ddg_monomer']:.2f}, pLDDT={r['best_plddt']:.0f}")


In [None]:
# =============================================================================
# CELL 6a: FOLDX MONOMER DDG FOR MISSING VARIANTS
# =============================================================================

# NOTE(review): shutil is already imported in Cell 1; this re-import is harmless.
import shutil

# FoldX binary — set FOLDX_PATH env var or ensure 'foldx' is on your PATH
# If neither is set this becomes the relative path "foldx", so the
# FOLDX_BINARY.exists() checks below resolve it against the current directory.
FOLDX_BINARY = Path(os.environ.get("FOLDX_PATH", shutil.which("foldx") or "foldx"))
# Root of the per-variant FoldX working directories created further below.
FOLDX_DIR_BASE = WORKING_DIR / "foldx_expanded"
FOLDX_DIR_BASE.mkdir(parents=True, exist_ok=True)

# rotabase.txt should be in the FoldX binary directory
# (the FoldX helpers copy it into each work dir when it exists).
ROTABASE = FOLDX_BINARY.parent / "rotabase.txt"

# Inverse of THREE_TO_ONE (defined in an earlier cell, not visible here);
# presumably one-letter → three-letter residue codes. Not referenced in this cell.
ONE_TO_THREE = {v:k for k,v in THREE_TO_ONE.items()}

def run_foldx_buildmodel(structure_path, chain_id, ref_aa, position, alt_aa, work_dir, n_runs=3):
    """Run FoldX BuildModel for one point mutation and return the mean DDG.

    The structure (and rotabase.txt, when present next to the FoldX binary) is
    copied into ``work_dir``. An ``individual_list.txt`` holding the single
    mutation ``{ref_aa}{chain_id}{position}{alt_aa};`` is written there, and
    FoldX is run with ``cwd=work_dir``.

    Args:
        structure_path: input structure (str or Path).
        chain_id: chain carrying the mutated residue.
        ref_aa, alt_aa: one-letter wild-type / mutant residues.
        position: residue number in the structure's numbering.
        work_dir: directory for inputs and FoldX outputs (created if needed).
        n_runs: value for FoldX ``--numberOfRuns``.

    Returns:
        Mean of the second tab-separated field over the data rows of the
        Dif_*.fxout file (presumably FoldX's per-run total-energy change),
        rounded to 4 decimals. Returns None when FoldX fails, times out or
        leaves nothing parsable; failures are printed, not raised.
    """
    structure_path = Path(structure_path)  # callers may pass a str
    work_dir = Path(work_dir)
    work_dir.mkdir(parents=True, exist_ok=True)

    # Copy structure to work dir
    struct_name = structure_path.name
    shutil.copy2(structure_path, work_dir / struct_name)

    # Copy rotabase if needed
    if ROTABASE.exists() and not (work_dir / "rotabase.txt").exists():
        shutil.copy2(ROTABASE, work_dir / "rotabase.txt")

    # FoldX mutation format: {wt_aa_1letter}{chain}{position}{mut_aa_1letter};
    mut_str = f"{ref_aa}{chain_id}{position}{alt_aa};"
    (work_dir / "individual_list.txt").write_text(mut_str + "\n")

    cmd = [
        str(FOLDX_BINARY),
        "--command=BuildModel",
        f"--pdb={struct_name}",
        "--mutant-file=individual_list.txt",
        f"--numberOfRuns={n_runs}",
        f"--output-dir={work_dir}",
    ]

    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=300, cwd=str(work_dir))
        if result.returncode != 0:
            print(f"    FoldX error: {result.stderr[:200]}")
            return None

        # Expected output: Dif_{stem}.fxout. Path.stem drops only the final
        # suffix, unlike str.replace('.pdb', ''), which removed the substring
        # wherever it occurred in the file name.
        dif_file = work_dir / f"Dif_{Path(struct_name).stem}.fxout"
        if not dif_file.exists():
            # Fall back to any Dif file; sorted so the choice is deterministic.
            candidates = sorted(work_dir.glob("Dif_*.fxout"))
            if not candidates:
                print(f"    No Dif output found in {work_dir}")
                return None
            dif_file = candidates[0]

        # Read DDG: skip header/comment lines, average the runs.
        ddg_values = []
        with open(dif_file) as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith('Pdb') or line.startswith('#'):
                    continue
                parts = line.split('\t')
                if len(parts) >= 2:
                    try:
                        ddg_values.append(float(parts[1]))
                    except ValueError:
                        continue

        if ddg_values:
            return round(sum(ddg_values) / len(ddg_values), 4)
        return None

    except subprocess.TimeoutExpired:
        print(f"    FoldX timeout for {ref_aa}{position}{alt_aa}")
        return None
    except Exception as e:
        # Best-effort batch step: report and move on to the next variant.
        print(f"    FoldX exception: {e}")
        return None

# Identify variants missing monomer DDG
missing_mono = df[df['ddg_monomer'].isna()].copy()
print(f"Variants missing monomer DDG: {len(missing_mono)}")
if len(missing_mono) > 0:
    print(f"  Genes: {missing_mono['gene'].value_counts().to_string()}")

# Verify FoldX binary exists
if not FOLDX_BINARY.exists():
    print(f"⚠ FoldX binary not found at {FOLDX_BINARY}")
    # The binary location comes from the FOLDX_PATH env var (or 'foldx' on PATH).
    print("  Skipping FoldX computation. Set the FOLDX_PATH env var (or put 'foldx' on PATH) and re-run this cell.")
else:
    new_ddg_mono = {}
    for idx, row in missing_mono.iterrows():
        gene = str(row['gene']).lower()
        pos = int(row['position'])
        ref = str(row['ref_aa'])
        alt = str(row['alt_aa'])

        # Find PDB structure; genes absent from MONOMER_STRUCTURES have no model,
        # so do not hand find_file a None name.
        _, pdb_name = MONOMER_STRUCTURES.get(gene, (None, None))
        pdb_path = find_file(pdb_name) if pdb_name else None
        if pdb_path is None:
            print(f"  ⚠ No monomer PDB for {gene} — skipping {ref}{pos}{alt}")
            continue

        work_dir = FOLDX_DIR_BASE / "monomer" / f"{gene}_{ref}{pos}{alt}"
        print(f"  Running FoldX: {gene} {ref}{pos}{alt}...", end=" ")
        # Monomer models are assumed to carry the protein on chain 'A'.
        ddg_value = run_foldx_buildmodel(pdb_path, 'A', ref, pos, alt, work_dir)
        if ddg_value is not None:
            new_ddg_mono[idx] = ddg_value
            print(f"DDG = {ddg_value:.2f}")
        else:
            print("FAILED")

    # Merge new DDG values. Refresh both category columns the way Cell 6 builds
    # them: ddg_category_raw is the plain bin, and ddg_category gets the
    # '_unreliable' suffix when ddg_confidence is 'low'.
    for idx, ddg_value in new_ddg_mono.items():
        df.loc[idx, 'ddg_monomer'] = ddg_value
        df.loc[idx, 'ddg_category_raw'] = classify_ddg(ddg_value)
        df.loc[idx, 'ddg_category'] = classify_ddg_with_confidence(df.loc[idx])

    print(f"\n✓ Computed {len(new_ddg_mono)} new monomer DDG values")
    print(f"  Total monomer DDG coverage: {df['ddg_monomer'].notna().sum()}/{len(df)}")


In [None]:
# =============================================================================
# CELL 6b: FOLDX MULTIMER DDG FOR MISSING VARIANTS
# =============================================================================

# For multimer DDG, run BuildModel on each complex PDB where the variant
# is at the interface, then compute interaction energy difference via AnalyseComplex.

def run_foldx_analysecomplex(structure_path, chains, work_dir):
    """Run FoldX AnalyseComplex and return the interaction energy.

    Args:
        structure_path: structure to analyse (str or Path); copied into
            ``work_dir`` unless a file of the same name is already there.
        chains: value for ``--analyseComplexChains``, e.g. ``"A,B"``.
        work_dir: FoldX working/output directory (created if missing).

    Returns:
        Column 6 ("Interaction Energy") of the first parsable data row of the
        Interaction_*_AC.fxout output (files taken in sorted name order), or
        None when FoldX fails or nothing can be parsed. Failures are printed
        rather than raised, like run_foldx_buildmodel.
    """
    structure_path = Path(structure_path)
    work_dir = Path(work_dir)
    # Create the directory here so the copy below cannot fail on a missing dir.
    work_dir.mkdir(parents=True, exist_ok=True)
    struct_name = structure_path.name
    if not (work_dir / struct_name).exists():
        shutil.copy2(structure_path, work_dir / struct_name)
    if ROTABASE.exists() and not (work_dir / "rotabase.txt").exists():
        shutil.copy2(ROTABASE, work_dir / "rotabase.txt")

    cmd = [
        str(FOLDX_BINARY),
        "--command=AnalyseComplex",
        f"--pdb={struct_name}",
        f"--analyseComplexChains={chains}",
        f"--output-dir={work_dir}",
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=300, cwd=str(work_dir))
        if result.returncode != 0:
            print(f"    FoldX AnalyseComplex error: {result.stderr[:200]}")
            return None

        # Parse Interaction_{name}_AC.fxout; sorted so the pick is deterministic.
        for f in sorted(work_dir.glob("Interaction_*_AC.fxout")):
            with open(f) as fh:
                for line in fh:
                    if line.startswith('Pdb') or line.startswith('#') or not line.strip():
                        continue
                    parts = line.strip().split('\t')
                    if len(parts) >= 6:
                        try:
                            return float(parts[5])  # Interaction Energy
                        except ValueError:
                            continue
        return None
    except subprocess.TimeoutExpired:
        print(f"    FoldX AnalyseComplex timeout ({chains})")
        return None
    except Exception as e:
        # Still best-effort, but no longer silent.
        print(f"    FoldX AnalyseComplex exception: {e}")
        return None


def run_foldx_multimer_ddg(pdb_path, chain_gene, chain_partner, ref_aa, position, alt_aa, work_dir, n_runs=3):
    """Compute fold and binding DDG for one mutation in a two-chain complex.

    Steps:
      1. AnalyseComplex on the input (wild-type) complex in ``work_dir/wt``.
      2. BuildModel for the mutation in ``work_dir``.
      3. AnalyseComplex on the first mutant model in ``work_dir/mut``.

    NOTE(review): the wild-type interaction energy is taken from the input
    structure itself. FoldX BuildModel also writes WT_* reference models, and
    comparing against those may be more consistent — confirm.

    Returns:
        (ddg_fold, ddg_binding). ddg_fold is run_foldx_buildmodel's result, and
        ddg_binding = IE(mutant) - IE(wild type), rounded to 4 decimals. Either
        value is None when the corresponding FoldX step failed.
    """
    pdb_path = Path(pdb_path)
    work_dir = Path(work_dir)
    work_dir.mkdir(parents=True, exist_ok=True)
    chains_str = f"{chain_gene},{chain_partner}"

    # Step 1: AnalyseComplex on wild-type
    wt_dir = work_dir / "wt"
    wt_dir.mkdir(exist_ok=True)
    ie_wt = run_foldx_analysecomplex(pdb_path, chains_str, wt_dir)

    # Step 2: BuildModel to get mutant structure
    ddg_fold = run_foldx_buildmodel(pdb_path, chain_gene, ref_aa, position, alt_aa, work_dir, n_runs)

    # Step 3: find the mutant model. The first pattern matches {stem}_1_<run>.pdb.
    # The fallback *_1.pdb must not select the WT_* reference model or the
    # copied input structure, either of which would give a binding DDG of ~0.
    # Candidates are sorted so the first run is chosen deterministically.
    struct_base = pdb_path.stem
    candidates = sorted(work_dir.glob(f"{struct_base}_1_*.pdb"))
    if not candidates:
        candidates = sorted(
            p for p in work_dir.glob("*_1.pdb")
            if not p.name.startswith("WT_") and p.name != pdb_path.name
        )
    mutant_pdb = candidates[0] if candidates else None

    ie_mut = None
    if mutant_pdb is not None:
        mut_dir = work_dir / "mut"
        mut_dir.mkdir(exist_ok=True)
        ie_mut = run_foldx_analysecomplex(mutant_pdb, chains_str, mut_dir)

    # DDG_binding = IE(mutant) - IE(wt)
    ddg_binding = None
    if ie_wt is not None and ie_mut is not None:
        ddg_binding = round(ie_mut - ie_wt, 4)

    return ddg_fold, ddg_binding


# Identify interface variants missing multimer DDG
if not FOLDX_BINARY.exists():
    print("⚠ FoldX binary not found — skipping multimer DDG")
else:
    # Partner label → gene symbol. It recognises complexes where the variant's
    # gene is the *partner* chain of a complex defined for another gene.
    # The dict is constant, so it is built once rather than inside the inner loop.
    partner_gene_map = {
        'dvl2':'dvl2','ctnnb1':'ctnnb1','rock2':'rock2',
        'gli3':'gli3','kpna1':'kpna1','kpna6':'kpna6',
        'mdfi':'mdfi','tcf7l1':'tcf7l1',
        'cdh2_truncated':'cdh2','cdh2_cyto':'cdh2',
        'actin':'actb','actb':'actb','actb_no_bind':'actb',
    }

    # Variants at interfaces that don't have multimer DDG
    missing_multi = df[df['is_interface_any'] & df['ddg_multimer_max'].isna()].copy()
    print(f"Interface variants missing multimer DDG: {len(missing_multi)}")

    new_multi_results = []

    for idx, row in missing_multi.iterrows():
        gene = str(row['gene']).lower()
        pos = int(row['position'])
        ref = str(row['ref_aa'])
        alt = str(row['alt_aa'])
        partners = row.get('interface_partners')
        # compute_summary stores None (not 'nan') when there are no partners.
        if pd.isna(partners) or not str(partners).strip():
            continue

        for partner in str(partners).split(';'):
            partner = partner.strip()
            if not partner:
                continue

            # Find the matching multimer definition: the gene is either chain 1
            # (chains used as listed) or the partner chain (chains swapped).
            matched = None
            for g1, pl, cif, pdb, c1, c2, _primary in MULTIMER_STRUCTURES:
                if pl != partner:
                    continue
                if g1 == gene:
                    matched = (pdb, c1, c2)
                    break
                if partner_gene_map.get(pl, pl) == gene:
                    matched = (pdb, c2, c1)  # Swap chains
                    break

            if matched is None:
                continue

            pdb_file, my_chain, partner_chain = matched
            pdb_path = find_file(pdb_file)
            if pdb_path is None:
                continue

            work_dir = FOLDX_DIR_BASE / "multimer" / f"{gene}_{ref}{pos}{alt}_{partner}"
            print(f"  Running FoldX multimer: {gene} {ref}{pos}{alt} × {partner}...", end=" ")

            ddg_fold, ddg_binding = run_foldx_multimer_ddg(
                pdb_path, my_chain, partner_chain, ref, pos, alt, work_dir
            )

            # Prefer the binding DDG; fall back to the complex fold DDG.
            ddg_val = ddg_binding if ddg_binding is not None else ddg_fold
            if ddg_val is not None:
                new_multi_results.append({
                    'idx': idx, 'gene': gene, 'position': pos,
                    'ref_aa': ref, 'alt_aa': alt, 'partner': partner,
                    'ddg_multimer': ddg_val, 'ddg_fold': ddg_fold, 'ddg_binding': ddg_binding
                })
                print(f"DDG_binding={ddg_binding}, DDG_fold={ddg_fold}")
            else:
                print("FAILED")

    # Aggregate and merge
    if new_multi_results:
        multi_new = pd.DataFrame(new_multi_results)
        grp = multi_new.groupby('idx').agg(
            ddg_max=('ddg_multimer', 'max'),
            ddg_min=('ddg_multimer', 'min'),
            ddg_mean=('ddg_multimer', 'mean'),
            n_tested=('ddg_multimer', 'count'),
            partners=('partner', lambda x: ';'.join(x))
        )
        for idx, agg_row in grp.iterrows():
            existing_max = df.loc[idx, 'ddg_multimer_max']
            if pd.isna(existing_max) or agg_row['ddg_max'] > existing_max:
                df.loc[idx, 'ddg_multimer_max'] = agg_row['ddg_max']
            # ddg_multimer_min feeds the stabilizing-mechanism logic and the v5.2
            # concordance vote, so newly computed variants must get it too.
            existing_min = df.loc[idx, 'ddg_multimer_min']
            if pd.isna(existing_min) or agg_row['ddg_min'] < existing_min:
                df.loc[idx, 'ddg_multimer_min'] = agg_row['ddg_min']
            existing_mean = df.loc[idx, 'ddg_multimer_mean']
            if pd.isna(existing_mean):
                df.loc[idx, 'ddg_multimer_mean'] = agg_row['ddg_mean']
            existing_n = df.loc[idx, 'n_complexes_tested']
            if pd.isna(existing_n):
                df.loc[idx, 'n_complexes_tested'] = agg_row['n_tested']
            else:
                df.loc[idx, 'n_complexes_tested'] = int(existing_n) + agg_row['n_tested']
            existing_p = str(df.loc[idx, 'partners_tested'])
            if existing_p in ('nan', 'None') or not existing_p:
                df.loc[idx, 'partners_tested'] = agg_row['partners']
            else:
                df.loc[idx, 'partners_tested'] = existing_p + ';' + agg_row['partners']

        print(f"\n✓ Computed {len(new_multi_results)} new multimer DDG values across {len(grp)} variants")
    else:
        print("  No new multimer DDG computed")

    # Save expanded DDG results for future re-use (includes the min and raw
    # category columns used downstream).
    foldx_out = RESULTS_DIR / "foldx_ddg_expanded_results.csv"
    ddg_cols_save = ['gene','position','ref_aa','alt_aa','ddg_monomer','ddg_category_raw','ddg_category',
                     'ddg_multimer_max','ddg_multimer_min','ddg_multimer_mean','n_complexes_tested','partners_tested']
    df[[c for c in ddg_cols_save if c in df.columns]].to_csv(foldx_out, index=False)
    print(f"  Saved expanded DDG to {foldx_out}")

print(f"\nFinal DDG coverage:")
print(f"  Monomer DDG: {df['ddg_monomer'].notna().sum()}/{len(df)}")
print(f"  Multimer DDG: {df['ddg_multimer_max'].notna().sum()}/{len(df)}")


In [None]:
# =============================================================================
# CELL 7: SCORING, TIERS, MECHANISM (v5.1)
# =============================================================================
# v5 changes:
#   1. Disruption = substitution_severity × (monomer_contacts + Σ inter_contacts)
#   2. Disruption thresholds: ≥20→+4, ≥10→+3, ≥4→+2, ≥1→+1
#   3. Burial: best across monomer + confident multimer (pLDDT ≥ 50)
#   4. Grantham bonus removed (integrated into disruption)
#   5. pLDDT multiplier graduated: ≥70→×1.0, 50-69→×0.7, <50→×0.4
# v5.1 changes:
#   6. Mechanism: DDG-based classifications gated by pLDDT (ddg_confidence)
#   7. Low-pLDDT variants with high DDG → 'DDG unreliable (low confidence)'
#   8. Stabilizing DDG recognized in mechanism (highly_stabilizing category)
# =============================================================================

# Ordinal scale for burial labels, so calculate_score can compare monomer and
# multimer burial numerically and keep the most buried one.
BURIAL_RANK = {'unknown': 0, 'surface_exposed': 1, 'partially_buried': 2, 'buried_core': 3}
# Inverse mapping rank → label, used to turn the winning rank back into a label.
RANK_TO_BURIAL = {v: k for k, v in BURIAL_RANK.items()}


def calculate_score(row):
    """
    Structural disruption score (v5).

    final_score = (disruption_pts + interface_pts + burial_pts) × pLDDT_multiplier

    Max possible: (4 + 2 + 2) × 1.0 = 8.0

    Returns:
        Tuple (final_score, evidence, disruption, total_contacts,
        inter_contacts_sum, best_burial, best_burial_source). ``evidence`` is
        the ';'-joined list of scoring tags in the order they were added.
    """
    score, ev = 0.0, []

    # =================================================================
    # DISRUPTION: substitution_severity × (mono_contacts + Σ inter_contacts)
    # =================================================================
    sev = sf(row.get('substitution_severity'), 0)
    mono_c = sf(row.get('monomer_n_contacts'), 0)

    # Sum positive inter-chain contacts across ALL multimer complexes
    # (the unused per-partner detail list was dropped).
    inter_sum = 0.0
    for pl in all_partner_labels:
        ic_col = f"multi_{pl}_inter_contacts"
        if ic_col in row.index and pd.notna(row[ic_col]):
            ic_val = float(row[ic_col])
            if ic_val > 0:
                inter_sum += ic_val

    total_contacts = mono_c + inter_sum
    disruption = round(sev * total_contacts, 2)

    # Disruption → points (thresholds calibrated on high-confidence distribution).
    # The highest matching band wins; the for-else covers "no band matched".
    for cutoff, pts, tag in ((20, 4.0, 'very_high_disruption'),
                             (10, 3.0, 'high_disruption'),
                             (4, 2.0, 'moderate_disruption'),
                             (1, 1.0, 'low_disruption')):
        if disruption >= cutoff:
            score += pts
            ev.append(f'{tag}({disruption:.1f})')
            break
    else:
        ev.append(f'no_disruption({disruption:.1f})')

    # =================================================================
    # INTERFACE POINTS (unchanged from v4)
    # =================================================================
    ni = si(row.get('n_interface_partners'), 0)
    ip = ss(row.get('interface_partners'))
    if   ni >= 2: score += 2.0; ev.append(f'multi_interface({ni})')
    elif ni == 1: score += 1.5; ev.append(f'interface({ip})')

    # =================================================================
    # BURIAL: best across monomer + confident multimer (pLDDT ≥ 50)
    # =================================================================
    mono_burial = ss(row.get('monomer_burial'))
    best_rank = BURIAL_RANK.get(mono_burial, 0)
    best_source = 'monomer'

    for pl in all_partner_labels:
        burial_col = f"multi_{pl}_burial"
        plddt_col  = f"multi_{pl}_plddt"
        if burial_col in row.index and pd.notna(row.get(burial_col)):
            pl_plddt = sf(row.get(plddt_col), 0)
            if pl_plddt >= 50:  # only trust multimer burial if confident
                pl_rank = BURIAL_RANK.get(ss(row[burial_col]), 0)
                if pl_rank > best_rank:
                    best_rank = pl_rank
                    best_source = pl

    best_burial = RANK_TO_BURIAL.get(best_rank, 'unknown')
    if   best_burial == 'buried_core':      score += 2.0; ev.append(f'buried_core({best_source})')
    elif best_burial == 'partially_buried': score += 1.0; ev.append(f'partially_buried({best_source})')

    # =================================================================
    # pLDDT MULTIPLIER (graduated); missing pLDDT → no discount
    # =================================================================
    bp = row.get('best_plddt')
    if pd.notna(bp):
        bp_val = float(bp)
        if bp_val < 50:
            score *= 0.4
            ev.append(f'very_low_plddt_discount({int(bp_val)})')
        elif bp_val < 70:
            score *= 0.7
            ev.append(f'low_plddt_discount({int(bp_val)})')
        # else: ×1.0, no discount

    final = round(score, 2)
    return (final, ';'.join(ev), disruption, total_contacts, inter_sum, best_burial, best_source)


def assign_tier(s):
    """Map a final structural score to a tier label; missing scores → Tier 4.

    Cut-offs are inclusive lower bounds: >= 5.0 Tier 1, >= 3.0 Tier 2,
    >= 1.5 Tier 3, otherwise Tier 4.
    """
    if pd.isna(s):
        return 'Tier 4 - Likely benign'
    value = float(s)
    tiers = (
        (5.0, 'Tier 1 - High confidence pathogenic'),
        (3.0, 'Tier 2 - Likely pathogenic'),
        (1.5, 'Tier 3 - VUS with evidence'),
    )
    for cutoff, label in tiers:
        if value >= cutoff:
            return label
    return 'Tier 4 - Likely benign'


def classify_mechanism(row):
    """
    FoldX DDG mechanism classification (v5.1).

    Changes from v5:
      - DDG-based classifications gated by ddg_confidence
      - Low-pLDDT variants → 'DDG unreliable (low confidence)' instead of mechanism credit
      - Stabilizing DDG (< -2) recognized as potential gain-of-function mechanism

    Review fixes:
      - The "moderate" band is now 0.5 ≤ DDG ≤ 2.0 (was ≤ 1.5). Previously
        1.5 < DDG ≤ 2.0 matched neither moderate nor high and was reported as
        'Structural tier (low DDG)', below values of 0.5–1.5. The new band is
        consistent with classify_ddg, which calls (1.0, 2.0] 'destabilizing'.
      - The low-confidence gate also checks |ddg_multimer_min| > 2.0, in line
        with the v5.2 concordance, which reads mono, multi_max and multi_min.

    Only ddg_confidence == 'low' triggers the gate; 'moderate', 'high' and
    'unknown' all use the DDG rules below.
    """
    tier     = ss(row.get('tier'))
    ddg_m    = row.get('ddg_monomer')
    ddg_x    = row.get('ddg_multimer_max')
    ddg_min  = row.get('ddg_multimer_min')
    is_if    = sb(row.get('is_interface_any'), False)
    ht       = 'Tier 1' in tier or 'Tier 2' in tier
    ddg_conf = ss(row.get('ddg_confidence'))

    has_dm = pd.notna(ddg_m)
    has_dx = pd.notna(ddg_x)
    has_dn = pd.notna(ddg_min)

    # The moderate band ends where the high band starts, so every value >= 0.5
    # lands in exactly one of the two.
    mod_lo, hi_cut = 0.5, 2.0

    # ---- pLDDT gate: if low confidence, DDG is unreliable ----
    if ddg_conf == 'low' and (has_dm or has_dx):
        any_extreme = ((has_dm and abs(float(ddg_m)) > hi_cut)
                       or (has_dx and abs(float(ddg_x)) > hi_cut)
                       or (has_dn and abs(float(ddg_min)) > hi_cut))
        if any_extreme:
            return 'DDG unreliable (low confidence)'
        # Mild DDG at low pLDDT — still unreliable but not worth flagging
        if ht:
            return 'Interface disruption (DDG-neutral)' if is_if else 'Structural tier (DDG unreliable)'
        # DDG data is guaranteed present on this path, so no 'No DDG data' case.
        return 'Likely benign'

    # ---- DDG treated as usable (confidence not 'low') ----
    # Monomer DDG thresholds
    hi_dm   = has_dm and float(ddg_m) > hi_cut
    neut_dm = has_dm and abs(float(ddg_m)) < 1.0
    mod_dm  = has_dm and mod_lo <= float(ddg_m) <= hi_cut
    stab_dm = has_dm and float(ddg_m) < -hi_cut  # v5.1: highly stabilizing

    # Multimer DDG thresholds
    hi_dx   = has_dx and float(ddg_x) > hi_cut
    mod_dx  = has_dx and mod_lo <= float(ddg_x) <= hi_cut
    stab_dx = has_dx and has_dn and float(ddg_min) < -hi_cut  # v5.1: use min for stabilizing

    if ht:
        # High-tier mechanisms (Tier 1 or 2)
        if hi_dm and is_if and hi_dx:
            return 'Dual mechanism (fold + PPI)'
        elif hi_dm and is_if:
            return 'Dual mechanism (fold + interface)'
        elif neut_dm and is_if and hi_dx:
            return 'Complex destabilization (PPI-specific)'
        elif is_if and not hi_dm:
            return 'Interface disruption (DDG-neutral)'
        elif hi_dm:
            return 'Fold destabilization'
        elif stab_dm or stab_dx:
            return 'Potential gain-of-function (over-stabilization)'
        elif mod_dm or mod_dx:
            return 'Structural concern (moderate DDG)'
        elif has_dm or has_dx:
            return 'Structural tier (low DDG)'
        else:
            return 'No DDG data'
    else:
        # Low-tier mechanisms (Tier 3 or 4)
        if hi_dm or hi_dx:
            return 'High DDG only (Tier 3/4)'
        elif stab_dm or stab_dx:
            return 'High stabilizing DDG only (Tier 3/4)'
        elif not has_dm and not has_dx:
            return 'No DDG data'
        else:
            return 'Likely benign'


# === Apply scoring ===
# calculate_score returns a 7-tuple per row; each position is unpacked into its
# own column below (index order matches the tuple built in calculate_score).
results = df.apply(lambda r: calculate_score(r), axis=1)
df['final_score']       = [r[0] for r in results]
df['score_evidence']    = [r[1] for r in results]
df['contact_disruption']= [r[2] for r in results]
df['total_contacts']    = [r[3] for r in results]
df['inter_contacts_sum']= [r[4] for r in results]
df['best_burial']       = [r[5] for r in results]
df['best_burial_source']= [r[6] for r in results]

# Tier depends only on final_score; mechanism reads the tier, so order matters.
df['tier'] = df['final_score'].apply(assign_tier)
df['final_mechanism'] = df.apply(classify_mechanism, axis=1)
# Copy of final_mechanism under a second name (presumably for older consumers).
df['pathogenic_mechanism'] = df['final_mechanism']

# === Summary statistics ===
print("✓ Scoring complete (v5)")
print(f"\nTier distribution:")
print(df['tier'].value_counts().to_string())
print(f"\nScore statistics:")
print(f"  Range: {df['final_score'].min():.2f} – {df['final_score'].max():.2f}")
print(f"  Mean:  {df['final_score'].mean():.2f}")
print(f"\nDisruption statistics:")
print(f"  Range: {df['contact_disruption'].min():.1f} – {df['contact_disruption'].max():.1f}")
print(f"  Non-zero: {(df['contact_disruption'] > 0).sum()}/{len(df)}")
print(f"\nBurial upgrades (multimer > monomer):")
upgraded = df[df['best_burial_source'] != 'monomer']
print(f"  {len(upgraded)}/{len(df)} variants use multimer burial")
print(f"\npLDDT multiplier distribution:")
# Half-open bins [lo, hi); 999 is just an upper sentinel. Rows with missing
# best_plddt fall in no bin (NaN comparisons are False).
for label, lo, hi in [('×1.0 (≥70)', 70, 999), ('×0.7 (50-69)', 50, 70), ('×0.4 (<50)', 0, 50)]:
    n = ((df['best_plddt'] >= lo) & (df['best_plddt'] < hi)).sum()
    print(f"  {label}: {n}")
print(f"\nMechanism classification:")
print(df['final_mechanism'].value_counts().to_string())



In [None]:
# =============================================================================
# CELL 8: ANNOTATIONS + CONCORDANCE (v5.2)
# =============================================================================
# v5.2 changes:
#   1. DDG vote uses max(|mono|, |multi_max|, |multi_min|) — fixes multi_max bug
#   2. DDG confidence gating via ddg_confidence column (not raw pLDDT)
#   3. New sub-score columns: structure_strict/relaxed, external_strict/relaxed
#   4. T3-inclusive variants of all sub-scores
# =============================================================================

# Join annotations on a case-insensitive gene key plus position and residues.
df['gene_lower'] = df['gene'].astype(str).str.lower()
if annotation_df is not None:
    ann = annotation_df.copy()
    ann.columns = [c.lower() for c in ann.columns]
    ann['gene_lower'] = ann['gene'].astype(str).str.lower()
    ann_keys = ['gene_lower','position','ref_aa','alt_aa']
    rn = {'alphamissense':'AlphaMissense','alphamissense_pathogenicity':'AlphaMissense_pathogenicity'}
    for col in ['alphamissense','alphamissense_pathogenicity','franklin']:
        if col in ann.columns:
            # Keep one annotation per variant: drop empty values, then keep the
            # first value per key. Deduplicating on all columns left conflicting
            # annotations in place, and the left merge then duplicated that
            # variant's row in df.
            m = ann[ann_keys + [col]].dropna(subset=[col]).drop_duplicates()
            n_conflict = int(m.duplicated(subset=ann_keys).sum())
            if n_conflict:
                print(f"  ⚠ {col}: {n_conflict} conflicting duplicate annotations — keeping first per variant")
            m = m.drop_duplicates(subset=ann_keys)
            df = df.merge(m, on=ann_keys, how='left')
            if col in rn: df = df.rename(columns={col: rn[col]})
            print(f"✓ {col}: {df[rn.get(col,col)].notna().sum()}/{len(df)}")

# Guarantee the annotation columns exist even without annotation data.
for col in ['AlphaMissense','AlphaMissense_pathogenicity','franklin']:
    if col not in df.columns: df[col] = None

df = df.drop(columns=['gene_lower'], errors='ignore')

def build_evidence(row):
    """Compose a '; '-separated evidence string for one variant.

    Tags, in order:
      - monomer DDG (DDG_high when > 2.0, DDG_mod when > 1.0);
      - interface partners;
      - buried monomer core;
      - AlphaMissense class (AM_path / AM_amb).

    Returns 'Limited evidence' when no tag applies.
    """
    tags = []

    ddg_value = row.get('ddg_monomer')
    if pd.notna(ddg_value):
        v = float(ddg_value)
        if v > 2.0:
            tags.append(f'DDG_high({v:.1f})')
        elif v > 1.0:
            tags.append(f'DDG_mod({v:.1f})')

    partners = ss(row.get('interface_partners'))
    if partners:
        tags.append(f'interface({partners})')

    if ss(row.get('monomer_burial')) == 'buried_core':
        tags.append('buried')

    am_tag = {'likely_pathogenic': 'AM_path', 'ambiguous': 'AM_amb'}.get(
        ss(row.get('AlphaMissense')).lower()
    )
    if am_tag:
        tags.append(am_tag)

    return '; '.join(tags) if tags else 'Limited evidence'

df['evidence_summary'] = df.apply(build_evidence, axis=1)

def classify_integrated(row):
    """Combine structural tier, AlphaMissense class and pLDDT confidence into an
    integrated A–J class.

    Rules are evaluated in order. For Tier 1/2 variants, Classes G and H are only
    reached when confidence is neither 'high' nor 'low'.
    """
    tier = ss(row.get('tier'))
    am = ss(row.get('AlphaMissense')).lower()
    conf = ss(row.get('confidence'))
    is_if = sb(row.get('is_interface_any'), False)
    high_tier = 'Tier 1' in tier or 'Tier 2' in tier
    am_path = am == 'likely_pathogenic'
    am_amb = am == 'ambiguous'

    if high_tier:
        if am_path:
            return 'Class A - Concordant pathogenic'
        if am_amb and is_if:
            return 'Class B - Likely pathogenic (structural)'
        if is_if and conf == 'high':
            return 'Class C - Interface disruptor (structural only)'
        if conf == 'high':
            return 'Class D - Structural evidence only (high confidence)'
        if conf == 'low':
            return 'Class E - Structural evidence only (low confidence)'
        if am_amb:
            return 'Class G - VUS (mixed evidence)'
        return 'Class H - VUS (weak evidence)'

    if am_path and 'Tier 3' in tier:
        return 'Class F - AlphaMissense pathogenic (weak structural)'
    if am_path:
        return 'Class I - AlphaMissense only'
    return 'Class J - Likely benign'

df['integrated_class'] = df.apply(classify_integrated, axis=1)

def std_franklin(v):
    """Map a free-text Franklin classification onto a canonical display label.

    Matching is a case-insensitive substring test, applied in priority
    order: the pathogenic family, then the benign family, then the VUS
    sub-buckets (high, mid, low, then bare VUS). A missing value yields
    ``'No data'``. Unrecognized text is returned stripped but otherwise
    unchanged.
    """
    if pd.isna(v):
        return 'No data'
    text = str(v).strip()
    lowered = text.lower()

    # (keyword, label when 'likely' also appears, label otherwise)
    families = (
        ('pathogenic', 'Likely pathogenic', 'Pathogenic'),
        ('benign', 'Likely benign', 'Benign'),
    )
    for keyword, likely_label, plain_label in families:
        if keyword in lowered:
            return likely_label if 'likely' in lowered else plain_label

    if 'vus' in lowered:
        for level, label in (('high', 'VUS (high)'),
                             ('mid', 'VUS (mid)'),
                             ('low', 'VUS (low)')):
            if level in lowered:
                return label
        return 'VUS'
    return text

# --- Normalize AlphaMissense: classify raw scores ---
def classify_am(val):
    """Convert an AlphaMissense value to its class label.

    Accepts either an existing class label or a raw pathogenicity score.

    * Labels may use any casing and '_' or ' ' as separator; they are
      returned in canonical form ('likely_pathogenic', 'likely_benign',
      'ambiguous').
    * Numeric scores are binned with the published AlphaMissense cut-offs:
      likely_benign below 0.34, likely_pathogenic above 0.564, and
      ambiguous otherwise (both cut-off values themselves are ambiguous).
    * Missing values are returned unchanged. Non-finite parses such as the
      string 'nan' become NaN instead of falling through to 'ambiguous'.
    * Any other unparseable text is returned stripped.
    """
    if pd.isna(val):
        return val
    s = str(val).strip()
    label = s.lower().replace(' ', '_')
    if label in ('likely_pathogenic', 'likely_benign', 'ambiguous'):
        return label
    try:
        score = float(s)
    except ValueError:
        return s
    if not np.isfinite(score):
        # float() accepts 'nan'/'inf'; NaN fails both comparisons below and
        # would otherwise be silently labelled 'ambiguous'.
        return np.nan
    if score > 0.564:
        return 'likely_pathogenic'
    if score < 0.34:
        return 'likely_benign'
    return 'ambiguous'

if 'AlphaMissense' in df.columns:
    # Snapshot the as-loaded values only once. Re-running this cell would
    # otherwise overwrite the raw scores with already-classified labels.
    if 'AlphaMissense_raw' not in df.columns:
        df['AlphaMissense_raw'] = df['AlphaMissense']
    df['AlphaMissense'] = df['AlphaMissense'].apply(classify_am)
    print("AlphaMissense normalized:", df['AlphaMissense'].value_counts().to_string())

    # evidence_summary and integrated_class were computed earlier from the
    # un-normalized column, where a numeric score never equals
    # 'likely_pathogenic' / 'ambiguous'. Recompute them so they agree with
    # the class labels the concordance votes below use.
    df['evidence_summary'] = df.apply(build_evidence, axis=1)
    df['integrated_class'] = df.apply(classify_integrated, axis=1)

# --- Normalize Franklin: fix casing and typos ---
if 'franklin' in df.columns:
    # Snapshot the as-loaded values only once, so re-running this cell keeps
    # the true raw labels.
    if 'franklin_raw' not in df.columns:
        df['franklin_raw'] = df['franklin']

    def norm_franklin(v):
        """Canonicalize one Franklin label's casing and spacing.

        Internal whitespace is collapsed and the text lower-cased before
        matching. Any VUS sub-bucket becomes 'VUS (low|mid|high)', even when
        written with missing or extra spaces or a dropped parenthesis
        ('vus mid)', 'VUS(high)', 'vus (low'). Benign and pathogenic labels
        map to lower-case. Unrecognized values are returned unchanged.
        """
        if pd.isna(v): return v
        s = re.sub(r'\s+', ' ', str(v).strip().lower())
        vus = re.fullmatch(r'vus\s*\(?\s*(low|mid|high)\s*\)?', s)
        if vus:
            return f'VUS ({vus.group(1)})'
        m = {'benign':'benign','likely benign':'likely benign',
             'pathogenic':'pathogenic','likely pathogenic':'likely pathogenic'}
        return m.get(s, v)
    df['franklin'] = df['franklin'].apply(norm_franklin)
    print("Franklin normalized:", df['franklin'].value_counts().to_string())


# =============================================================================
# CONCORDANCE (v5.2)
# =============================================================================
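# v5.2 fixes:
#   - DDG vote: max(|mono|, |multi_max|, |multi_min|) with confidence gating
#     (v5.1 bug: multi_max was tested only against the positive threshold and
#     multi_min only against the negative one, so negative single-partner
#     multimer values stored in multi_max were missed)
#   - New sub-score columns for transparent decomposition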
#
# Four evidence lines:
#   1. Structural tier:  1 if Tier 1/2 (or Tier 3 for T3-inclusive)
#   2. DDG:              1 if max_abs_ddg >= threshold AND confidence passes
#   3. AlphaMissense:    1 if pathogenic (strict) or ambiguous+ (relaxed)
#   4. Franklin:         1 if VUS(high)+ (strict) or VUS(mid)+ (relaxed)
#
# DDG thresholds:
#   Standard: max_abs_ddg >= 2.0, ddg_confidence = 'high'
#   Relaxed:  max_abs_ddg >= 1.0, ddg_confidence != 'low'
# =============================================================================

def concordance_v52(row):
    """Score one variant against the four v5.2 evidence lines.

    Each line contributes 0 or 1 in every scoring mode:
      1. Structural tier: Tier 1/2; the ``*_t3`` outputs also accept Tier 3.
      2. DDG: the largest |DDG| over ``ddg_monomer``, ``ddg_multimer_max``
         and ``ddg_multimer_min``.
         - Strict needs >= 2.0 with ``ddg_confidence`` 'high'.
         - Relaxed needs >= 1.0 with any confidence other than 'low', so a
           missing confidence is not excluded.
      3. AlphaMissense: strict counts likely_pathogenic; relaxed also
         counts ambiguous.
      4. Franklin: strict counts pathogenic, likely pathogenic and VUS (high);
         relaxed also counts VUS (mid).

    Returns a pd.Series containing:
      * ``structure_*``: tier vote + DDG vote.
      * ``external_*``: AlphaMissense vote + Franklin vote.
      * ``four_way*``: structure + external.
      * ``three_way`` and ``three_way_t3``: text labels built from tier +
        strict AlphaMissense + strict Franklin.
      * ``three_way_ambiguous``: an integer count (tier + relaxed
        AlphaMissense + relaxed Franklin), kept for backward compatibility.
    """
    tier = ss(row.get('tier'))
    am_class = ss(row.get('AlphaMissense')).lower()
    franklin_label = std_franklin(row.get('franklin'))
    franklin_key = franklin_label.lower() if isinstance(franklin_label, str) else ''
    ddg_confidence = ss(row.get('ddg_confidence')).lower()

    # Sign-agnostic: the largest magnitude from any available DDG source.
    ddg_raw = (row.get(c) for c in ('ddg_monomer', 'ddg_multimer_max', 'ddg_multimer_min'))
    magnitudes = [abs(float(v)) for v in ddg_raw if pd.notna(v)]
    max_abs_ddg = max(magnitudes, default=0.0)

    # Individual 0/1 votes
    tier_vote = int('Tier 1' in tier or 'Tier 2' in tier)
    tier_vote_t3 = int(bool(tier_vote) or 'Tier 3' in tier)
    ddg_strict = int(ddg_confidence == 'high' and max_abs_ddg >= 2.0)
    ddg_relaxed = int(ddg_confidence != 'low' and max_abs_ddg >= 1.0)
    am_strict = int(am_class == 'likely_pathogenic')
    am_relaxed = int(am_class in ('likely_pathogenic', 'ambiguous'))
    fr_strict = int(franklin_key in ('vus (high)', 'pathogenic', 'likely pathogenic'))
    fr_relaxed = int(franklin_key in ('vus (high)', 'vus (mid)', 'pathogenic', 'likely pathogenic'))

    # Sub-scores
    structure_strict = tier_vote + ddg_strict
    structure_relaxed = tier_vote + ddg_relaxed
    structure_strict_t3 = tier_vote_t3 + ddg_strict
    structure_relaxed_t3 = tier_vote_t3 + ddg_relaxed
    external_strict = am_strict + fr_strict
    external_relaxed = am_relaxed + fr_relaxed

    three_way_labels = ['All 3 benign/VUS', '1 of 3 pathogenic',
                        '2 of 3 pathogenic', 'All 3 agree pathogenic']

    return pd.Series({
        # Sub-scores (new in v5.2)
        'structure_strict':      structure_strict,
        'structure_relaxed':     structure_relaxed,
        'structure_strict_t3':   structure_strict_t3,
        'structure_relaxed_t3':  structure_relaxed_t3,
        'external_strict':       external_strict,
        'external_relaxed':      external_relaxed,

        # Three-way concordance (backward compat)
        'three_way':             three_way_labels[tier_vote + am_strict + fr_strict],
        'three_way_ambiguous':   tier_vote + am_relaxed + fr_relaxed,
        'three_way_t3':          three_way_labels[min(tier_vote_t3 + am_strict + fr_strict, 3)],

        # Four-way concordance (sub-score sums)
        'four_way':              structure_strict + external_strict,
        'four_way_ambiguous_as_pathogenic_ddg_1_threshold': structure_relaxed + external_relaxed,
        'four_way_t3':           structure_strict_t3 + external_strict,
        'four_way_t3_ambiguous': structure_relaxed_t3 + external_relaxed,
    })


conc = df.apply(concordance_v52, axis=1)
for c in conc.columns:
    df[c] = conc[c]

print(f"\n✓ Concordance computed (v5.2 — fixed DDG + sub-scores)")

# --- Summary stats ---
print("\nSub-score distributions:")
for col in ['structure_strict','structure_relaxed','external_strict','external_relaxed']:
    print(f"  {col}: {df[col].value_counts().sort_index().to_dict()}")

print(f"\nFour-way strict:")
print(df['four_way'].value_counts().sort_index().to_string())
print(f"\nFour-way relaxed:")
print(df['four_way_ambiguous_as_pathogenic_ddg_1_threshold'].value_counts().sort_index().to_string())

# --- DDG gating impact ---
# concordance_v52 reads these columns with row.get(), so a missing column just
# means "no data". Mirror that here instead of raising KeyError.
_has_ddg = pd.Series(False, index=df.index)
for _c in ('ddg_monomer', 'ddg_multimer_max', 'ddg_multimer_min'):
    if _c in df.columns:
        _has_ddg |= df[_c].notna()
# Lower-case like concordance_v52 does before comparing, so 'Low'/'HIGH'
# are counted in the same bucket the votes place them in.
if 'ddg_confidence' in df.columns:
    _conf = df['ddg_confidence'].fillna('').astype(str).str.lower()
else:
    _conf = pd.Series('', index=df.index)
_low = _has_ddg & (_conf == 'low')
_mod = _has_ddg & (_conf == 'moderate')
_hi  = _has_ddg & (_conf == 'high')
print(f"\nDDG gating impact:")
print(f"  Variants with DDG data: {_has_ddg.sum()}")
print(f"  confidence=low  (excluded from all concordance): {_low.sum()}")
print(f"  confidence=moderate (relaxed only):              {_mod.sum()}")
print(f"  confidence=high (strict + relaxed):              {_hi.sum()}")


In [None]:
# =============================================================================
# CELL 9: SAVE RESULTS (v5.2)
# =============================================================================

# Column groups, listed in the order they should appear in the output CSV.
id_cols = ['gene','position','ref_aa','alt_aa']
grantham_cols = ['grantham_distance','grantham_class','substitution_severity','property_changes']
mono_cols = ['monomer_plddt','monomer_plddt_category','monomer_n_contacts',
             'monomer_contact_category','monomer_aa','monomer_accessibility',
             'monomer_burial','monomer_secondary_structure','monomer_contact_disruption']
summary_cols = ['n_multimer_complexes','multimer_partners']

partner_order = ['actin','actb','dvl2','cdh2_truncated','ctnnb1','cdh2_cyto','actb_no_bind','rock2',
                 'shroom3','gli3','kpna1','kpna6','mdfi','tcf7l1','zic3']
per_complex_cols = []
for pl in partner_order:
    for suffix in ['_plddt','_n_contacts','_inter_contacts','_is_interface',
                   '_accessibility','_burial','_sec_struct','_disruption']:
        col = f"multi_{pl}{suffix}"
        if col in df.columns: per_complex_cols.append(col)

summary2 = ['is_interface_any','interface_partners','n_interface_partners',
            'multimer_plddt_max','multimer_plddt_avg','multimer_contacts_max',
            'multimer_contacts_avg','multimer_disruption_max','multimer_disruption_avg']

# v5 scoring columns
score_cols_v5 = ['contact_disruption','total_contacts','inter_contacts_sum',
                 'best_burial','best_burial_source',
                 'final_score','confidence','best_plddt','score_evidence','tier']

am_cols = ['AlphaMissense','AlphaMissense_pathogenicity','franklin','integrated_class']
ddg_cols = ['ddg_monomer','ddg_category','ddg_category_raw','ddg_confidence',
            'pathogenic_mechanism','evidence_summary',
            'ddg_multimer_max','ddg_multimer_min','ddg_multimer_mean',
            'n_complexes_tested','partners_tested']
mech_cols = ['final_mechanism']

# v5.2: sub-scores before concordance totals
subscore_cols = ['structure_strict','structure_relaxed',
                 'structure_strict_t3','structure_relaxed_t3',
                 'external_strict','external_relaxed']
conc_cols = ['three_way','three_way_ambiguous','three_way_t3',
             'four_way','four_way_ambiguous_as_pathogenic_ddg_1_threshold',
             'four_way_t3','four_way_t3_ambiguous']

ordered = (id_cols + grantham_cols + mono_cols + summary_cols + per_complex_cols +
           summary2 + score_cols_v5 + am_cols + ddg_cols + mech_cols +
           subscore_cols + conc_cols)
# De-duplicate (first occurrence wins) so a column listed in two groups can
# never be selected twice by df[final_cols].
ordered = list(dict.fromkeys(ordered))
_ordered_set = set(ordered)
# Columns not named in any group (e.g. the *_raw copies) go at the end.
remaining = [c for c in df.columns if c not in _ordered_set]
final_cols = [c for c in ordered if c in df.columns] + remaining
df_out = df[final_cols]

out = RESULTS_DIR / "variant_comprehensive_v5_2.csv"
df_out.to_csv(out, index=False)

# High-priority subset = Tier 1/2. A missing tier counts as "no match".
# Without this, str.contains returns NaN for it and boolean indexing raises.
_is_high_tier = df_out['tier'].fillna('').astype(str).str.contains('Tier 1|Tier 2', regex=True)
hp = df_out[_is_high_tier]
hp.to_csv(RESULTS_DIR / "high_priority_variants_v5_2.csv", index=False)

scols = ['gene','position','ref_aa','alt_aa','contact_disruption','total_contacts',
         'inter_contacts_sum','best_burial','best_burial_source',
         'interface_partners','best_plddt','score_evidence','tier',
         'ddg_monomer','ddg_category','ddg_confidence',
         'ddg_multimer_max','ddg_multimer_min',
         'evidence_summary','final_mechanism','AlphaMissense','integrated_class','franklin',
         'structure_strict','structure_relaxed','external_strict','external_relaxed',
         'three_way','three_way_ambiguous',
         'four_way','four_way_ambiguous_as_pathogenic_ddg_1_threshold',
         'three_way_t3','four_way_t3','four_way_t3_ambiguous']
df_out[[c for c in scols if c in df_out.columns]].to_csv(
    RESULTS_DIR / "variant_pipeline_results_summary_v5_2.csv", index=False)

print("=" * 60)
print("PIPELINE v5.2 COMPLETE")
print("=" * 60)
print(f"Total variants: {len(df_out)}")
print(f"Total columns:  {len(final_cols)}")

# =============================================================================
# VALIDATION: Concordance spot-checks (v5.2)
# =============================================================================
print("\n=== CONCORDANCE VALIDATION (v5.2) ===")

def validate_concordance(df_out, gene, pos, expected, alt_aa=None):
    """Print a ✓/✗ comparison of one variant's columns against expected values.

    Parameters
    ----------
    df_out : pd.DataFrame
        Output table with ``gene``, ``position``, ``ref_aa`` and ``alt_aa``.
    gene, pos
        Variant locus, matched exactly against ``gene`` / ``position``.
    expected : dict
        Maps column name to expected value. ``int`` expectations are compared
        numerically, as floats, so 2.7 does not match 2. Everything else is
        compared as ``str``.
    alt_aa : str, optional
        Restrict the lookup to this alternate residue when several variants
        share a position. If omitted, the first matching row is used.

    Returns
    -------
    bool
        True if every expectation matched. False if any mismatched or the
        variant was not found.
    """
    mask = (df_out['gene'] == gene) & (df_out['position'] == pos)
    if alt_aa is not None:
        mask &= df_out['alt_aa'] == alt_aa
    row = df_out[mask]
    if len(row) == 0:
        print(f"\n  ⚠ {gene} pos {pos} not found")
        return False
    r = row.iloc[0]
    label = f"{r['ref_aa']}{pos}{r['alt_aa']}"
    print(f"\n  {gene} {label}:")
    all_pass = True
    for col, exp_val in expected.items():
        got = r.get(col)
        if isinstance(exp_val, int):
            # Compare as floats: int(got) would truncate (2.7 -> 2) and hide
            # a mismatch. None/NaN/non-numeric text fall back to str compare.
            try:
                match = float(got) == exp_val
            except (TypeError, ValueError):
                match = str(got) == str(exp_val)
        else:
            match = str(got) == str(exp_val)
        sym = '✓' if match else '✗'
        if not match: all_pass = False
        print(f"    {sym} {col:45s} got={str(got):>5s}  expected={str(exp_val):>5s}")
    return all_pass

# Concordance spot-checks: (gene, position, expected column values).
# The rationale for each expectation under the v5.2 formula is noted inline.
_concordance_checks = [
    # ZIC3 C297F: T1, mono=8.54 (high conf) → struct=2, AM=path, FR=VUS(high) → ext=2, 4way=4
    ('zic3', 297, {
        'structure_strict': 2, 'structure_relaxed': 2,
        'external_strict': 2, 'external_relaxed': 2,
        'four_way': 4, 'four_way_ambiguous_as_pathogenic_ddg_1_threshold': 4,
    }),
    # ZIC3 K405E: T2, mono=-0.15 but multi_max=-4.32 (high conf) → strict DDG=1 (|4.32|>=2)
    ('zic3', 405, {
        'structure_strict': 2, 'structure_relaxed': 2,
        'external_strict': 2, 'external_relaxed': 2,
        'four_way': 4, 'four_way_ambiguous_as_pathogenic_ddg_1_threshold': 4,
    }),
    # ZIC3 H318N: T1, mono=0.20, multi_max=-1.71 (high conf) → strict DDG=0 (|1.71|<2), relax DDG=1
    ('zic3', 318, {
        'structure_strict': 1, 'structure_relaxed': 2,
        'external_strict': 2, 'external_relaxed': 2,
        'four_way': 3, 'four_way_ambiguous_as_pathogenic_ddg_1_threshold': 4,
    }),
    # SHROOM3 H161Q: T2, mono=-0.15, multi_min=-4.67 (moderate conf) → strict DDG=0, relax DDG=1
    ('shroom3', 161, {
        'structure_strict': 1, 'structure_relaxed': 2,
        'external_strict': 0, 'external_relaxed': 2,
        'four_way': 1, 'four_way_ambiguous_as_pathogenic_ddg_1_threshold': 4,
    }),
    # GLI3 P103S: T4, mono=-1.47 (LOW conf) → DDG=0 regardless, FR=VUS(mid) → ext_r=1
    ('gli3', 103, {
        'structure_strict': 0, 'structure_relaxed': 0,
        'external_strict': 0, 'external_relaxed': 1,
        'four_way': 0, 'four_way_ambiguous_as_pathogenic_ddg_1_threshold': 1,
    }),
]

# Keep each check's result (a falsy value means a mismatch or a variant that
# was not found) so any failure shows up in one summary line, not only as ✗
# marks scattered through the output.
_concordance_results = [validate_concordance(df_out, g, p, exp)
                        for g, p, exp in _concordance_checks]
_n_pass = sum(bool(res) for res in _concordance_results)
print(f"\n  Concordance spot-checks passed: {_n_pass}/{len(_concordance_results)}")

# =============================================================================
# VALIDATION: Manual score checks against v5 formula (unchanged from v5.1)
# =============================================================================
print("\n=== SCORING VALIDATION (v5 formula) ===")

def validate_variant(df_out, gene, pos, expected_checks, tol=0.15, alt_aa=None):
    """Print ✓/✗/· comparisons of one variant's columns against expectations.

    Parameters
    ----------
    df_out : pd.DataFrame
        Output table with ``gene``, ``position``, ``ref_aa`` and ``alt_aa``.
    gene, pos
        Variant locus, matched exactly against ``gene`` / ``position``.
    expected_checks : iterable of (display_name, column, expected)
        How each ``expected`` is handled:
        * None: printed as informational only ('·').
        * str: must be a substring of the first 35 characters of str(value).
        * int or float: the value must satisfy |value - expected| < ``tol``.
        * anything else: str(expected) must be a substring of that text.
    tol : float
        Absolute tolerance for numeric expectations. Default 0.15, the
        previously hard-coded value.
    alt_aa : str, optional
        Restrict the lookup to this alternate residue when several variants
        share a position. If omitted, the first matching row is used.

    Returns
    -------
    bool
        True if no check failed. False if any check printed ✗ or the variant
        was not found.
    """
    mask = (df_out['gene'] == gene) & (df_out['position'] == pos)
    if alt_aa is not None:
        mask &= df_out['alt_aa'] == alt_aa
    row = df_out[mask]
    if len(row) == 0:
        print(f"\n  ⚠ {gene} pos {pos} not found")
        return False
    r = row.iloc[0]
    label = f"{r['ref_aa']}{pos}{r['alt_aa']}"
    print(f"\n  {gene} {label}:")
    all_pass = True
    for name, col, expected in expected_checks:
        got = r.get(col)
        got_s = str(got)[:35] if got is not None else 'None'
        if expected is None:
            match = '·'
        elif isinstance(expected, str):
            match = '✓' if expected in got_s else '✗'
        elif isinstance(expected, (int, float)):
            try:
                match = '✓' if abs(float(got) - expected) < tol else '✗'
            except (TypeError, ValueError):
                # Missing or non-numeric value cannot satisfy a numeric check.
                match = '✗'
        else:
            match = '✓' if str(expected) in got_s else '✗'
        if match == '✗':
            all_pass = False
        print(f"    {match} {name:30s}  got={got_s:>20s}  expected={str(expected):>15s}")
    return all_pass

# Scoring spot-checks, keyed by (gene, position). An expected value of None
# only prints the value; a string must appear in it.
_scoring_checks = {
    ('shroom3', 35): [
        ('final_score',  'final_score',  None),
        ('tier',         'tier',         'Tier 1'),
        ('four_way',     'four_way',     None),
    ],
    ('zic3', 350): [
        ('final_score',  'final_score',  None),
        ('tier',         'tier',         None),
        ('four_way',     'four_way',     None),
    ],
}
for (_gene, _pos), _checks in _scoring_checks.items():
    validate_variant(df_out, _gene, _pos, _checks)

# Despite its name, this counts rows whose monomer pLDDT is missing (NaN).
zero_plddt = df_out['monomer_plddt'].isna().sum()
print(f"\nMonomer pLDDT missing: {zero_plddt}/{len(df_out)}")

print(f"\nSaved to {RESULTS_DIR}/")
print("  variant_comprehensive_v5_2.csv")
print("  high_priority_variants_v5_2.csv")
print("  variant_pipeline_results_summary_v5_2.csv")
