In [2]:
import pandas as pd
from rdkit import Chem
from rdkit.Chem import QED, Crippen, Descriptors, Lipinski, rdMolDescriptors
from rdkit.Chem import AllChem
from rdkit.Chem import rdMolDescriptors
import math
import pickle
import gzip
import os

In [None]:
from rdkit.Chem import rdMolDescriptors
import math

def calculate_sa_score(mol):
    """
    Heuristic synthetic-accessibility (SA) estimate on a 1 (easy) .. 10 (hard) scale.

    Two terms are summed:
      * fragment term: mean of log(count + 1) over the non-zero bits of a
        radius-2 Morgan count fingerprint (0 when there are no bits);
      * complexity term: log(n_atoms + 1) plus the number of stereocentres
        (assigned or not), bridgehead atoms, spiro atoms and rings with more
        than 8 atoms.
    The sum r (>= 0) is squashed into [1, 10) with 1 + 9 * r / (r + 10).

    NOTE(review): this is NOT the Ertl & Schuffenhauer SA score (RDKit Contrib
    `sascorer`), which relies on a pre-computed fragment-contribution table.
    Values from this function, and any cut-off applied to them (e.g. 4.5), are
    therefore not directly comparable with published SA scores.
    """
    # Fragment term: average log-count of the Morgan environments present.
    bit_counts = rdMolDescriptors.GetMorganFingerprint(mol, 2).GetNonzeroElements()
    fragment_term = sum(math.log(count + 1) for count in bit_counts.values())
    fragment_term /= max(len(bit_counts), 1)

    # Complexity term: size plus stereo / bridged / spiro / macrocycle penalties.
    complexity_term = (
        math.log(mol.GetNumAtoms() + 1)
        + len(Chem.FindMolChiralCenters(mol, includeUnassigned=True))
        + rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
        + rdMolDescriptors.CalcNumSpiroAtoms(mol)
        + sum(1 for ring in mol.GetRingInfo().AtomRings() if len(ring) > 8)
    )

    raw_score = fragment_term + complexity_term
    return 1 + 9 * (raw_score / (raw_score + 10.0))


import pandas as pd
from rdkit import Chem
from rdkit.Chem import QED, Crippen, Descriptors, Lipinski, rdMolDescriptors
from rdkit.Chem import AllChem

# rdMolDescriptors is used by calculate_sa_score, a heuristic SA estimate (not the Ertl & Schuffenhauer implementation)
from rdkit.Chem import rdMolDescriptors
import math
import pickle
import gzip
import os


# Lipinski-like physicochemical filter.
def passes_lipinski(mol):
    """
    Return True when `mol` satisfies the Lipinski-like limits used here:
    100 <= MolWt <= 500, Crippen logP <= 5, H-bond acceptors <= 10 and
    H-bond donors <= 5.

    Note the extra lower bound MolWt >= 100, which is not part of the classic
    Rule of Five.
    """
    mol_wt = Descriptors.MolWt(mol)
    log_p = Crippen.MolLogP(mol)
    donors = Lipinski.NumHDonors(mol)
    acceptors = Lipinski.NumHAcceptors(mol)

    violated = (
        mol_wt < 100
        or mol_wt > 500
        or log_p > 5
        or acceptors > 10
        or donors > 5
    )
    return not violated

# Medicinal chemistry filter (basic list): label -> SMARTS of a forbidden group.
# NOTE(review): "Reactive Halogen" matches any four-connected carbon ([CX4]) bonded
# to any halogen, fluorine included, so CF3/CHF2 groups are rejected as well. In the
# result tables recorded at the end of this notebook it accounts for almost all
# "Basic MCF" failures. Confirm this is intended.
forbidden_smarts = {
    "Aziridine": "[N3]1CC1",
    "Nitroso": "[NX2]=O",
    "Acyl Chloride": "C(=O)Cl",
    "Reactive Halogen": "[CX4][Cl,Br,I,F]",
    "Strained Triple Bond": "C#C-C#C",
}

# Compile each SMARTS once, keeping the same labels in the same order.
forbidden_patterns = dict(
    zip(forbidden_smarts.keys(), map(Chem.MolFromSmarts, forbidden_smarts.values()))
)

def passes_mcf(mol):
    """
    Basic medicinal-chemistry filter.

    The patterns in `forbidden_patterns` are tried in order. If none matches,
    any ring with more than 8 atoms is rejected.

    Returns
    -------
    (bool, str or None)
        (False, <label>) for the first matching forbidden pattern,
        (False, "Large Ring") for a ring larger than 8 atoms,
        (True, None) otherwise.
    """
    first_hit = next(
        (label for label, query in forbidden_patterns.items() if mol.HasSubstructMatch(query)),
        None,
    )
    if first_hit is not None:
        return False, first_hit
    if any(len(ring) > 8 for ring in mol.GetRingInfo().AtomRings()):
        return False, "Large Ring"
    return True, None

def maximum_ring_size(mol):
    """Number of atoms in the largest ring of `mol` (per its RingInfo), or 0 if it has no rings."""
    return max((len(ring) for ring in mol.GetRingInfo().AtomRings()), default=0)

def passes_wehi_mcf_from_files(smi, filters=None):
    """
    Reject a SMILES that matches any MCF/WEHI pattern.

    Parameters
    ----------
    smi : str
        Input SMILES.
    filters : iterable of RDKit query mols, optional
        Patterns to test. When None, the module-level `_filters` is used if it
        exists. The code that loads it (presumably from mcf.csv and
        wehi_pains.csv) is not in the visible notebook. If `_filters` is
        undefined or empty, no check is applied.

    Returns
    -------
    bool
        True if no filters are available or none matches; False if the SMILES
        does not parse or matches any pattern.
    """
    if filters is None:
        # `_filters` was referenced directly before, which raised NameError
        # whenever it had not been defined. Treat "undefined" like "not loaded".
        filters = globals().get("_filters")
    if not filters:
        return True
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        return False
    # Match on the explicit-H molecule, presumably so patterns written with
    # explicit hydrogens can match.
    h_mol = Chem.AddHs(mol)
    return not any(h_mol.HasSubstructMatch(patt) for patt in filters)

def pains_filt_from_txt(mol, pains_dict=None):
    """
    Return the PAINS tags whose SMARTS match `mol`.

    Parameters
    ----------
    mol : RDKit Mol
        Molecule to test.
    pains_dict : dict[str, str], optional
        Mapping SMARTS -> tag. When None, the module-level `_pains_txt_dict` is
        used if it exists. The code that loads it (presumably from pains.txt) is
        not in the visible notebook. If it is undefined or empty, no PAINS check
        is applied.

    Returns
    -------
    list[str]
        Tags of the matching patterns, in dict order. SMARTS that fail to parse
        are skipped.
    """
    if pains_dict is None:
        # `_pains_txt_dict` was referenced directly before, which raised
        # NameError whenever it had not been defined. Treat "undefined" like
        # "not loaded".
        pains_dict = globals().get("_pains_txt_dict")
    if not pains_dict:
        return []
    hits = []
    for smarts, tag in pains_dict.items():
        patt = Chem.MolFromSmarts(smarts)
        if patt is not None and mol.HasSubstructMatch(patt):
            hits.append(tag)
    return hits

def filter_by_pattern(mol, pattern):
    """Return True only if the SMARTS `pattern` parses and matches `mol`."""
    query = Chem.MolFromSmarts(pattern)
    if query is None:
        return False
    return mol.HasSubstructMatch(query)

def filter_phosphorus(mol):
    """
    Flag phosphorus outside a *~P(=O)~* context.

    Returns True (reject) when `mol` contains P (aliphatic or aromatic) but no
    P bearing a double-bonded O plus two other neighbours. The check is
    molecule-wide: a single P(=O) environment anywhere makes the molecule pass.
    """
    has_phosphorus = filter_by_pattern(mol, "[P,p]")
    return has_phosphorus and not filter_by_pattern(mol, "*~[P,p](=O)~*")


# 3) Expanded list of forbidden fragments (complements the basic list used by passes_mcf).
# The last entries previously used SMARTS that did not encode their labels:
#   - "epoxide" only matched an epoxide fused to a 1,3-dioxolane;
#   - "quinone" was a cumulated C(=O)=C;
#   - "nitrosamine" was C-N=N-O.
# They are replaced below with patterns that match those groups.
_expanded_forbidden_smarts = [
    "*1=**=*1",
    "*1*=*=*1",
    "*1~*=*1",
    "[F,Cl,Br]C=[O,S,N]",
    "[Br]-C-C=[O,S,N]",
    "[N,n,S,s,O,o]C[F,Cl,Br]",
    "[I]",
    "[S&X3]",
    "[S&X5]",
    "[S&X6]",
    "[B,N,n,O,S]~[F,Cl,Br,I]",
    "*=*=*=*",
    "*=[NH]",
    "[P,p]~[F,Cl,Br]",
    "SS",
    "C#C",
    "C=C=C",
    "C=C=N",
    "NNN",
    "[*;R1]1~[*]~[*]~[*]1",
    "OOO",
    "[#6]1-[#8]-[#6]1",                                   # epoxide
    "N=C=O",                                              # isocyanate
    "C1CN1",                                              # aziridine
    "[#6](=[#8])[F,Cl,Br,I]",                             # acyl halides
    "[#6]1(=[#8])[#6]=,:[#6][#6](=[#8])[#6]=,:[#6]1",     # para-quinone (incl. fused/aromatic ring bonds)
    "[#6]1(=[#8])[#6](=[#8])[#6]=,:[#6][#6]=,:[#6]1",     # ortho-quinone (incl. fused/aromatic ring bonds)
    "[#7]-[#7]=[#8]",                                     # N-nitroso (nitrosamine)
]
# Compile each SMARTS once; entries that fail to parse are dropped.
_expanded_patterns = [
    patt for patt in map(Chem.MolFromSmarts, _expanded_forbidden_smarts) if patt is not None
]

def substructure_violations(mol):
    """
    Check `mol` against `_expanded_forbidden_smarts`, in list order.

    Returns (True, "Forbidden fragment: <smarts>") for the first matching
    pattern, or (False, None) if none matches. SMARTS that fail to parse are
    skipped.
    """
    for smarts in _expanded_forbidden_smarts:
        query = Chem.MolFromSmarts(smarts)
        if query is None:
            continue
        if mol.HasSubstructMatch(query):
            return True, f"Forbidden fragment: {smarts}"
    return False, None


def passes_full_medicinal_filters(smi):
    """
    Integrated medicinal-chemistry filter for one SMILES.

    Checks, in order (the first failure is returned):
      - no silicon or tin atoms
      - the SMILES parses
      - zero net formal charge and no radical electrons
      - at most 2 bridgehead atoms and no ring larger than 8 atoms
      - no improper phosphorus (filter_phosphorus)
      - basic forbidden-group list (passes_mcf)
      - expanded forbidden-fragment list (substructure_violations)
      - MCF/WEHI patterns loaded from files (passes_wehi_mcf_from_files)
      - PAINS patterns from pains.txt (pains_filt_from_txt)
      - TPSA <= 140 and at most 10 rotatable bonds

    Returns
    -------
    (bool, str)
        (True, "") when every check passes, otherwise (False, reason).
    """
    import re

    # Si and Sn are outside the SMILES organic subset, so they can only appear as
    # bracket atoms ([Si], [SiH2], [13Sn], aromatic [si], ...). A plain substring
    # test ('Sn' in smi) also fired on e.g. "CSn1cccc1" (aliphatic S bonded to an
    # aromatic n), which contains no tin.
    if re.search(r"\[\d*[Ss][in]", smi):
        return False, "Si/Sn not allowed"

    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        return False, "Invalid SMILES"

    # Formal charge / radicals
    if Chem.rdmolops.GetFormalCharge(mol) != 0:
        return False, "Non-zero formal charge"
    if Descriptors.NumRadicalElectrons(mol) != 0:
        return False, "Radicals present"

    # Bridgeheads and ring size
    if rdMolDescriptors.CalcNumBridgeheadAtoms(mol) > 2:
        return False, "Too many bridgehead atoms"
    if maximum_ring_size(mol) > 8:
        return False, "Ring size > 8"

    # Phosphorus
    if filter_phosphorus(mol):
        return False, "Improper phosphorus context"

    # Basic forbidden-group list
    ok_basic, why_basic = passes_mcf(mol)
    if not ok_basic:
        return False, f"Basic MCF: {why_basic}"

    # Expanded forbidden-fragment list
    viol, why = substructure_violations(mol)
    if viol:
        return False, why

    # MCF/WEHI patterns loaded from files
    if not passes_wehi_mcf_from_files(smi):
        return False, "WEHI/MCF forbidden pattern"

    # PAINS patterns from pains.txt
    pains_hits = pains_filt_from_txt(mol)
    if pains_hits:
        return False, f"PAINS: {', '.join(pains_hits)}"

    # TPSA and rotatable bonds
    if Descriptors.TPSA(mol) > 140:
        return False, "TPSA > 140"
    if Descriptors.NumRotatableBonds(mol) > 10:
        return False, "Too many rotatable bonds (>10)"

    return True, ""


# 5) SA/QED check used by evaluate_smiles.
def sa_qed_pass(mol, sa_thr=4.5, qed_thr=0.3):
    """
    Score `mol` for synthetic accessibility (calculate_sa_score) and QED.

    Returns (passed, sa, qed). `passed` is True when sa < sa_thr and
    qed > qed_thr; both bounds are strict.
    """
    from rdkit.Chem import QED

    sa = calculate_sa_score(mol)
    qed = QED.qed(mol)
    passed = sa < sa_thr and qed > qed_thr
    return passed, sa, qed

def evaluate_smiles(smi, sa_thr=4.5, qed_thr=0.3):
    """
    Evaluate one SMILES with Lipinski, the integrated medicinal-chemistry
    filters and the SA/QED check, in that order.

    Parameters
    ----------
    smi : str
        Input SMILES.
    sa_thr, qed_thr : float
        The SA score must be < sa_thr and QED must be > qed_thr
        (defaults 4.5 and 0.3).

    Returns
    -------
    dict
        {"status": "FAIL", "reason": ...} for the first failing check.
        Otherwise {"status": "PASS", "reason": ""} plus SA, QED, MolWt, LogP,
        HBD, HBA, TPSA and RotB.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        return {"status": "FAIL", "reason": "Invalid SMILES"}

    # Lipinski
    if not passes_lipinski(mol):
        return {"status": "FAIL", "reason": "Lipinski (Ro5) violation"}

    # Integrated medicinal-chemistry filters
    ok, why = passes_full_medicinal_filters(smi)
    if not ok:
        return {"status": "FAIL", "reason": why}

    # SA/QED: report against the thresholds actually used.
    ok_sq, sa, qed = sa_qed_pass(mol, sa_thr=sa_thr, qed_thr=qed_thr)
    if not ok_sq:
        msg = []
        if sa >= sa_thr:
            msg.append(f"SA={sa:.2f} >= {sa_thr}")
        if qed <= qed_thr:
            msg.append(f"QED={qed:.2f} <= {qed_thr}")
        return {"status": "FAIL", "reason": "; ".join(msg)}

    return {
        "status": "PASS",
        "reason": "",
        "SA": sa,
        "QED": qed,
        "MolWt": Descriptors.MolWt(mol),
        "LogP": Crippen.MolLogP(mol),
        "HBD": Lipinski.NumHDonors(mol),
        "HBA": Lipinski.NumHAcceptors(mol),
        "TPSA": Descriptors.TPSA(mol),
        "RotB": Descriptors.NumRotatableBonds(mol)
    }

In [None]:

import os, math
import pandas as pd
from collections import Counter
from rdkit import Chem
from rdkit.Chem import QED, Descriptors




# SA/QED check (thresholds taken from TARTARUS).
def sa_qed_check(smi, sa_thr=4.5, qed_thr=0.3):
    """
    Stage-style SA/QED check for one SMILES.

    Returns (False, reason) for an unparsable SMILES, for SA >= sa_thr
    (checked first) or for QED <= qed_thr. Otherwise returns (True, "").
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        return False, "Invalid SMILES (SA/QED)"

    sa = calculate_sa_score(mol)
    qed = QED.qed(mol)
    failure_checks = (
        (sa >= sa_thr, f"SA={sa:.2f} >= {sa_thr}"),
        (qed <= qed_thr, f"QED={qed:.2f} <= {qed_thr}"),
    )
    for failed, reason in failure_checks:
        if failed:
            return False, reason
    return True, ""


# 1) Stage definitions: one function per stage, each returning (ok, reason).

def stg_parse(smi):
    """Stage: the SMILES must be parsable by RDKit."""
    if Chem.MolFromSmiles(smi) is None:
        return False, "Invalid SMILES"
    return True, ""

def stg_elements(smi):
    """
    Stage: reject molecules containing silicon or tin.

    Si and Sn are outside the SMILES organic subset, so they can only appear
    as bracket atoms (e.g. [Si], [SiH2], [13Sn], aromatic [si]). Matching the
    bracket form avoids the false positives of a plain substring test: for
    example, "CSn1cccc1" (aliphatic S bonded to an aromatic n) contains "Sn"
    but no tin.

    Returns (ok, reason).
    """
    import re

    if re.search(r"\[\d*[Ss][in]", smi):
        return False, "Si/Sn not allowed"
    return True, ""

def stg_charge_radicals(smi):
    """
    Stage: net formal charge must be zero and there must be no radical electrons.

    Assumes `smi` parses; in run_stagewise the Parse stage runs first.
    """
    mol = Chem.MolFromSmiles(smi)
    net_charge = Chem.rdmolops.GetFormalCharge(mol)
    if net_charge != 0:
        return False, "Non-zero formal charge"
    unpaired = Descriptors.NumRadicalElectrons(mol)
    if unpaired != 0:
        return False, "Radicals present"
    return True, ""

def stg_ring_bridgehead(smi):
    """Stage: no ring with more than 8 atoms (checked first) and at most 2 bridgehead atoms."""
    from rdkit.Chem import rdMolDescriptors

    mol = Chem.MolFromSmiles(smi)
    largest_ring = max((len(ring) for ring in mol.GetRingInfo().AtomRings()), default=0)
    if largest_ring > 8:
        return False, "Ring size > 8"
    n_bridgeheads = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
    if n_bridgeheads > 2:
        return False, "Too many bridgehead atoms"
    return True, ""

def stg_phosphorus(smi):
    """Stage: reject phosphorus that is not in a *~P(=O)~* context (see filter_phosphorus)."""
    improper = filter_phosphorus(Chem.MolFromSmiles(smi))
    return (False, "Improper phosphorus context") if improper else (True, "")

def stg_basic_mcf(smi):
    """Stage: basic forbidden-group list (passes_mcf); the reason is prefixed with "Basic MCF: "."""
    ok, why = passes_mcf(Chem.MolFromSmiles(smi))
    if ok:
        return True, ""
    return False, f"Basic MCF: {why}"

def stg_expanded_mcf(smi):
    """Stage: expanded forbidden-fragment list (substructure_violations)."""
    violated, reason = substructure_violations(Chem.MolFromSmiles(smi))
    if violated:
        return False, reason
    return True, ""

def stg_physchem(smi):
    """Stage: TPSA <= 140 (checked first) and at most 10 rotatable bonds."""
    mol = Chem.MolFromSmiles(smi)
    limits = (
        (Descriptors.TPSA, 140, "TPSA > 140"),
        (Descriptors.NumRotatableBonds, 10, "Too many rotatable bonds (>10)"),
    )
    for descriptor, upper, reason in limits:
        if descriptor(mol) > upper:
            return False, reason
    return True, ""

def stg_lipinski(smi):
    """Stage: Lipinski-like limits (passes_lipinski)."""
    if passes_lipinski(Chem.MolFromSmiles(smi)):
        return True, ""
    return False, "Lipinski violation"

def stg_sa_qed(smi):
    """Stage: SA < 4.5 and QED > 0.3 (via sa_qed_check)."""
    thresholds = {"sa_thr": 4.5, "qed_thr": 0.3}
    return sa_qed_check(smi, **thresholds)

# Ordered filtering pipeline: (stage label, fn(smiles) -> (ok, reason)).
# run_stagewise evaluates each stage only on molecules that survived the previous
# ones, so "Parse" must stay first: the RDKit-based stages call
# Chem.MolFromSmiles(smi) without checking for None.
STAGES = [
    ("Parse", stg_parse),
    ("Elements", stg_elements),
    ("Charge/Radicals", stg_charge_radicals),
    ("Rings/Bridgeheads", stg_ring_bridgehead),
    ("Phosphorus", stg_phosphorus),
    ("Basic MCF", stg_basic_mcf),
    ("Expanded MCF", stg_expanded_mcf),
    ("PhysChem (TPSA/Rot)", stg_physchem),
    ("Lipinski", stg_lipinski),
    ("SA/QED", stg_sa_qed),
]


# 2) Runner: applies the stages in order and writes per-stage outputs.

def run_stagewise(
    input_csv,
    smiles_col="smiles",
    outdir="stagewise_outputs",
    enrich_final=True,
    stages=None,
):
    """
    Stage-by-stage filtering of a CSV of SMILES, with English console output.

    Each stage is a (name, fn) pair where fn(smiles) returns (ok, reason).
    Only molecules that survived every previous stage are evaluated by the
    next one, so each molecule is attributed to the first stage it fails.

    Parameters
    ----------
    input_csv : str
        Path to the input CSV.
    smiles_col : str
        SMILES column name. If absent, a column named "smiles"
        (case-insensitive) is used instead.
    outdir : str
        Output directory (created if missing).
    enrich_final : bool
        If True, also write SA/QED/physchem properties of the final survivors.
    stages : list of (str, callable), optional
        Stages to apply, in order. Defaults to the module-level STAGES.

    Creates, inside `outdir`:
      - survivors_after_XX_<Stage>.csv   SMILES surviving stage XX ('/' -> '-')
      - full_audit.csv                   all molecules, status/failed_stage/reason
      - stage_summary.csv                per-stage counts and top failure reasons
      - final_passed_smiles.csv          SMILES passing every stage
      - final_passed_with_props.csv      only if enrich_final=True and any survivors

    Returns
    -------
    (summary, audit) : tuple of pandas.DataFrame

    Raises
    ------
    ValueError
        If no SMILES column can be found.
    """
    import os
    from collections import Counter

    import pandas as pd

    if stages is None:
        stages = STAGES

    os.makedirs(outdir, exist_ok=True)
    df = pd.read_csv(input_csv)

    # Detect the SMILES column, falling back to a case-insensitive "smiles".
    if smiles_col not in df.columns:
        cand = [c for c in df.columns if c.lower() == "smiles"]
        if not cand:
            raise ValueError(f"Column '{smiles_col}' not found, and no 'smiles' found case-insensitively.")
        smiles_col = cand[0]

    # Base audit table: the evaluated SMILES plus the other input columns.
    # Another column literally named "smiles" is skipped so that it cannot
    # overwrite the SMILES actually being evaluated.
    audit = pd.DataFrame({"smiles": df[smiles_col].astype(str).str.strip()})
    for c in df.columns:
        if c not in (smiles_col, "smiles"):
            audit[c] = df[c]

    # Result columns are always reset. Values carried in from the input (e.g.
    # when re-running on a previous full_audit.csv) would otherwise leave stale
    # reasons on failures and stale failed_stage/reason values on PASS rows.
    for col in ("status", "failed_stage", "reason"):
        audit[col] = pd.NA

    survivors_mask = pd.Series(True, index=audit.index)
    stage_stats = []

    for i, (name, fn) in enumerate(stages, start=1):
        n_before = int(survivors_mask.sum())

        # Evaluate this stage only on molecules that survived the previous ones.
        fails, reasons = [], []
        for idx in audit.index[survivors_mask]:
            ok, why = fn(audit.at[idx, "smiles"])
            if not ok:
                fails.append(idx)
                reasons.append(why or "Failed")

        # Each molecule can fail at most once, so its reason is written directly.
        if fails:
            audit.loc[fails, "status"] = "FAIL"
            audit.loc[fails, "failed_stage"] = name
            audit.loc[fails, "reason"] = reasons
            survivors_mask.loc[fails] = False

        n_after = int(survivors_mask.sum())
        n_failed = n_before - n_after

        top_reasons = Counter(reasons).most_common(5)
        stage_stats.append({
            "stage": name,
            "failed": n_failed,
            "survivors": n_after,
            "top_reasons": "; ".join(f"{r} (n={cnt})" for r, cnt in top_reasons),
        })

        # Snapshot of the survivors after this stage ('/' is not valid in file names).
        snap_path = os.path.join(outdir, f"survivors_after_{i:02d}_{name.replace('/','-')}.csv")
        audit.loc[survivors_mask, ["smiles"]].to_csv(snap_path, index=False)

        print(f"[{i:02d}] {name}: failed {n_failed} | survivors {n_after}")
        if top_reasons:
            print("   Top failure reasons:")
            for r, cnt in top_reasons:
                print(f"     - {r}: {cnt}")

    # Molecules that survived every stage pass.
    audit.loc[survivors_mask, "status"] = "PASS"
    audit["failed_stage"] = audit["failed_stage"].fillna("")
    audit["reason"] = audit["reason"].fillna("")

    summary = pd.DataFrame(stage_stats)
    summary.to_csv(os.path.join(outdir, "stage_summary.csv"), index=False)
    audit.to_csv(os.path.join(outdir, "full_audit.csv"), index=False)

    # Final survivors (SMILES only)
    final_smiles = audit.loc[audit["status"] == "PASS", ["smiles"]].copy()
    final_smiles.to_csv(os.path.join(outdir, "final_passed_smiles.csv"), index=False)

    # Enrich the final survivors with properties. RDKit is only needed here.
    if enrich_final and not final_smiles.empty:
        from rdkit import Chem
        from rdkit.Chem import Crippen, Descriptors, Lipinski, QED

        rows = []
        for smi in final_smiles["smiles"]:
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                continue
            rows.append({
                "smiles": smi,
                "SA": calculate_sa_score(mol),
                "QED": QED.qed(mol),
                "MolWt": Descriptors.MolWt(mol),
                "LogP": Crippen.MolLogP(mol),
                "HBD": Lipinski.NumHDonors(mol),
                "HBA": Lipinski.NumHAcceptors(mol),
                "TPSA": Descriptors.TPSA(mol),
                "RotB": Descriptors.NumRotatableBonds(mol),
            })
        pd.DataFrame(rows).to_csv(os.path.join(outdir, "final_passed_with_props.csv"), index=False)

    print("\n=== STAGE SUMMARY ===")
    if not summary.empty:
        print(summary[["stage", "failed", "survivors"]])
        worst = summary.sort_values("failed", ascending=False).iloc[0]
        print(f" Most filtering stage: '{worst['stage']}' (filtered {worst['failed']} molecules).")
    else:
        print("No stages applied.")

    return summary, audit



In [None]:

import os

# Input CSV for this run. The data directory can be overridden with the
# TB_DATA_DIR environment variable instead of editing a hard-coded absolute path.
DATA_DIR = os.environ.get("TB_DATA_DIR", "/home/resperanca/Tuberculosis_Tese/Data/Model_results")
INPUT_CSV = os.path.join(DATA_DIR, "resultado_final_10_99.csv")

summary, audit = run_stagewise(INPUT_CSV, smiles_col="smiles", outdir="stagewise_outputs")



# Results

## 10_95%

In [2]:
# NOTE(review): `df` is not assigned by any visible cell in this notebook, so this
# cell relies on hidden kernel state. The table shown has the columns of the
# `summary` DataFrame returned by run_stagewise (stage, failed, survivors,
# top_reasons). It is presumably the stage_summary.csv of the "10_95%" run, loaded
# in a deleted cell. Confirm, and make the load explicit.
df

Unnamed: 0,stage,failed,survivors,top_reasons
0,Parse,0,5786,
1,Elements,0,5786,
2,Charge/Radicals,60,5726,Non-zero formal charge (n=60)
3,Rings/Bridgeheads,54,5672,Ring size > 8 (n=34); Too many bridgehead atom...
4,Phosphorus,0,5672,
5,Basic MCF,1559,4113,Basic MCF: Reactive Halogen (n=1558); Basic MC...
6,Expanded MCF,377,3736,Forbidden fragment: [*;R1]1~[*]~[*]~[*]1 (n=18...
7,PhysChem (TPSA/Rot),1366,2370,Too many rotatable bonds (>10) (n=703); TPSA >...
8,Lipinski,1206,1164,Lipinski violation (n=1206)
9,SA/QED,195,969,SA=4.52 >= 4.5 (n=29); SA=4.50 >= 4.5 (n=24); ...


## 10_99%

In [4]:
# NOTE(review): `df` is not assigned by any visible cell in this notebook, so this
# cell relies on hidden kernel state. The table shown has the columns of the
# `summary` DataFrame returned by run_stagewise. It is presumably the
# stage_summary.csv of the "10_99%" run. Confirm, and make the load explicit.
df

Unnamed: 0,stage,failed,survivors,top_reasons
0,Parse,0,317,
1,Elements,0,317,
2,Charge/Radicals,2,315,Non-zero formal charge (n=2)
3,Rings/Bridgeheads,1,314,Ring size > 8 (n=1)
4,Phosphorus,0,314,
5,Basic MCF,157,157,Basic MCF: Reactive Halogen (n=157)
6,Expanded MCF,14,143,Forbidden fragment: [*;R1]1~[*]~[*]~[*]1 (n=10...
7,PhysChem (TPSA/Rot),87,56,Too many rotatable bonds (>10) (n=54); TPSA > ...
8,Lipinski,28,28,Lipinski violation (n=28)
9,SA/QED,14,14,SA=4.54 >= 4.5 (n=7); QED=0.25 <= 0.3 (n=4); S...


## Iso_95%

In [6]:
# NOTE(review): `df` is not assigned by any visible cell in this notebook, so this
# cell relies on hidden kernel state. The table shown has the columns of the
# `summary` DataFrame returned by run_stagewise. It is presumably the
# stage_summary.csv of the "Iso_95%" run. Confirm, and make the load explicit.
df

Unnamed: 0,stage,failed,survivors,top_reasons
0,Parse,0,974,
1,Elements,0,974,
2,Charge/Radicals,4,970,Non-zero formal charge (n=4)
3,Rings/Bridgeheads,14,956,Too many bridgehead atoms (n=11); Ring size > ...
4,Phosphorus,0,956,
5,Basic MCF,225,731,Basic MCF: Reactive Halogen (n=225)
6,Expanded MCF,117,614,Forbidden fragment: [*;R1]1~[*]~[*]~[*]1 (n=62...
7,PhysChem (TPSA/Rot),266,348,TPSA > 140 (n=145); Too many rotatable bonds (...
8,Lipinski,183,165,Lipinski violation (n=183)
9,SA/QED,32,133,SA=4.52 >= 4.5 (n=4); SA=4.54 >= 4.5 (n=3); SA...


## Iso_99%

In [8]:
# NOTE(review): `df` is not assigned by any visible cell in this notebook, so this
# cell relies on hidden kernel state. The table shown has the columns of the
# `summary` DataFrame returned by run_stagewise. It is presumably the
# stage_summary.csv of the "Iso_99%" run. Confirm, and make the load explicit.
df

Unnamed: 0,stage,failed,survivors,top_reasons
0,Parse,0,31,
1,Elements,0,31,
2,Charge/Radicals,0,31,
3,Rings/Bridgeheads,0,31,
4,Phosphorus,0,31,
5,Basic MCF,10,21,Basic MCF: Reactive Halogen (n=10)
6,Expanded MCF,1,20,Forbidden fragment: [*;R1]1~[*]~[*]~[*]1 (n=1)
7,PhysChem (TPSA/Rot),13,7,Too many rotatable bonds (>10) (n=9); TPSA > 1...
8,Lipinski,5,2,Lipinski violation (n=5)
9,SA/QED,1,1,QED=0.25 <= 0.3 (n=1)
