<a href="https://colab.research.google.com/github/jwasswa2023/ChloroFinder/blob/main/USING_CHLOROFINDER.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pyteomics
!pip install joblib


import os, ast
import numpy as np
import pandas as pd
from itertools import combinations
from scipy.sparse import csr_matrix, hstack
import joblib
from pyteomics import mzml  # pip install pyteomics

# =========================
# CONFIG
# =========================
# Input locations. NOTE(review): these look like placeholders — point them at
# the real fragment CSV, the matching mzML run, and the trained model bundle.
INPUT_CSV  = "/Your_MS2_fragments.csv"
MZML_PATH  = "/Your_mzml_file.mzML"
MODEL_PATH = "/ChloroFinder.pkl"

# Output CSVs: OUT_ALL receives every input row with its scores; OUT_POS only
# the rows whose fused prediction equals 1.
OUT_ALL    = "/content/mgf_predictions_all_with_isotopes.csv"
OUT_POS    = "/content/mgf_predictions_chlorinated_only_with_isotopes.csv"

# Model feature building (match training)
# ROUND_DP  : decimal places used when rounding fragment m/z values and deltas.
# MIN_DELTA : pairwise m/z differences smaller than this are discarded.
# MAX_DELTA : differences larger than this are discarded; None = no upper limit.
ROUND_DP   = 3
MIN_DELTA  = 0.001
MAX_DELTA  = None

# Isotope inspection (MS1)
# ISOTOPE_BIN_DA     : half-width (m/z) of the window summed around a target peak.
# CL_M2_SHIFT / CL_M4_SHIFT : m/z offsets added to the precursor m/z to locate
#                      the M+2 and M+4 peaks.
# ISOTOPE_TAU        : decay scale of the score exp(-|ratio - expected| / tau).
# CL_EXPECTED_RATIOS : expected M+2/M intensity ratio per key; the key with the
#                      best score is reported as "ms1_best_nCl_like"
#                      (presumably the number of Cl atoms — the values look
#                      rounded; confirm against training).
# EPS                : tiny constant added to the denominator of the M+2/M ratio.
ISOTOPE_BIN_DA = 0.01
CL_M2_SHIFT    = 1.99705
CL_M4_SHIFT    = 3.99410
ISOTOPE_TAU    = 0.12
CL_EXPECTED_RATIOS = {1: 0.32, 2: 0.60, 3: 0.90}
EPS = 1e-12

# MS2 fragment isotope pairing
# Maximum |m/z error| allowed when looking for a fragment at (fragment + CL_M2_SHIFT).
MS2_PAIR_TOL_DA = 0.01

# CSV column names
# MZ_COL / INT_COL hold list-like cells (parsed with safe_parse_list);
# RT_COL and PRECURSOR_COL are read as one number per row.
MZ_COL        = "mz_values"
INT_COL       = "intensity_values"
RT_COL        = "RTINSECONDS"
PRECURSOR_COL = "PEPMASS"

# ======= FUSION: weights & thresholds =======
# combined_score = W_MODEL * model_prob + W_MS1ISO * ms1_iso_score
#                  + W_MS2PAIR * (1 if any MS2 pair was found else 0).
# The three weights below sum to 1.0.
W_MODEL   = 0.6    # weight for model probability
W_MS1ISO  = 0.35   # weight for MS1 isotope score (0..1)
W_MS2PAIR = 0.05   # small bonus if any fragment has +1.997 partner
FINAL_THRESHOLD = 0.60   # combined_score >= this => chlorinated

# (Optional) guardrails — if you want hard minimums in addition to the fused score
# A row is flagged only when it also meets both minimums below.
MIN_MODEL_PROB = 0.0     # e.g., 0.30; keep 0 to rely purely on fused score
MIN_ISO_SCORE  = 0.0     # e.g., 0.30; keep 0 to rely purely on fused score

# =========================
# HELPERS
# =========================
def safe_parse_list(x):
    """Coerce a list-like cell value into a list of floats.

    Accepted inputs:
      * a ``list`` — returned unchanged (elements are not converted);
      * a NumPy array — flattened and converted element-wise to ``float``;
      * a missing value (``None``/NaN/``pd.NA``) — yields ``[]``;
      * anything else — its string form is parsed, first as a Python literal
        list/tuple, then as a plain comma-separated sequence with optional
        surrounding brackets or parentheses.

    Returns ``[]`` whenever the value cannot be interpreted as numbers.
    """
    if isinstance(x, list):
        return x
    if isinstance(x, np.ndarray):
        # pd.isna() on an array returns an array, and using that in `if`
        # raises "truth value of an array is ambiguous" — handle arrays here.
        try:
            return [float(v) for v in x.ravel()]
        except (TypeError, ValueError):
            return []
    if pd.isna(x):
        return []

    text = str(x).strip()

    # First attempt: a well-formed Python literal such as "[1.0, 2.5]".
    try:
        parsed = ast.literal_eval(text)
        if isinstance(parsed, (list, tuple)):
            return [float(v) for v in parsed]
    except (ValueError, SyntaxError, TypeError, OverflowError,
            MemoryError, RecursionError):
        # Not a usable literal; fall through to the lenient parser below.
        pass

    # Second attempt: strip brackets and split on commas.
    try:
        tokens = [tok.strip() for tok in text.strip("[]() ").split(",") if tok.strip()]
        return [float(tok) for tok in tokens]
    except ValueError:
        return []

def round_unique(values, dp=3):
    """Round each value to ``dp`` decimals and drop repeats, keeping first-seen order.

    Every element is converted with ``float()`` before rounding, so numeric
    strings are accepted. Returns a new list.
    """
    rounded = (round(float(value), dp) for value in values)
    # dict keys preserve insertion order, giving an order-stable dedupe.
    return list(dict.fromkeys(rounded))

def compute_delta_mz_list(frag_list, round_dp=3, min_delta=0.001, max_delta=None):
    """Return the sorted, de-duplicated pairwise m/z differences of a fragment list.

    Args:
        frag_list: fragment m/z values (anything ``float()`` accepts).
        round_dp: decimal places applied to each kept difference.
        min_delta: differences strictly below this are dropped.
        max_delta: differences strictly above this are dropped; ``None``
            disables the upper bound.

    Returns:
        Ascending list of unique rounded differences; ``[]`` when fewer than
        two fragments are supplied.

    The min/max filters are applied to the unrounded difference.
    """
    if not frag_list or len(frag_list) < 2:
        return []

    unique_mz = sorted({float(mz) for mz in frag_list})
    kept = set()
    for idx, low in enumerate(unique_mz):
        for high in unique_mz[idx + 1:]:
            gap = abs(low - high)
            too_small = gap < min_delta
            too_large = (max_delta is not None) and (gap > max_delta)
            if not (too_small or too_large):
                kept.add(round(gap, round_dp))
    return sorted(kept)

def get_rt_minutes(spec):
    """Return the first scan's 'scan start time' as a float (NaN when absent).

    NOTE: despite the name, no unit conversion is performed here — the value
    is returned in whatever unit the spectrum record carries; callers are
    responsible for interpreting it.
    """
    first_scan = spec['scanList']['scan'][0]
    start_time = first_scan.get('scan start time', np.nan)
    return float(start_time)

def build_ms1_cache(mzml_path):
    """Read every MS1 spectrum from an mzML file into an in-memory list.

    Args:
        mzml_path: path handed to ``pyteomics.mzml.MzML``.

    Returns:
        A list of ``(rt_seconds, mz_array, intensity_array)`` tuples, one per
        spectrum whose 'ms level' is 1, in file order. Both arrays are cast to
        float. ``rt_seconds`` is NaN when the scan has no start time.

    Retention-time units: the parser is expected to attach the unit to the
    value itself (``unit_info`` attribute on the returned number) rather than
    under a separate dictionary key, so that attribute is consulted first and
    a 'scan start time unit' key is kept only as a fallback. Values whose unit
    mentions "second" are used as-is; anything else (including an unknown
    unit) is treated as minutes and multiplied by 60.
    """
    cache = []
    with mzml.MzML(mzml_path) as reader:
        for spec in reader:
            if spec.get('ms level') != 1:
                continue

            scan = spec['scanList']['scan'][0]
            raw_rt = scan.get('scan start time', np.nan)
            # Read the unit before float() strips the unit-carrying wrapper.
            unit = getattr(raw_rt, 'unit_info', None) or scan.get('scan start time unit', '')
            rt = float(raw_rt)
            rt_seconds = rt if 'second' in str(unit).lower() else rt * 60.0

            ms1_mz = spec['m/z array'].astype(float)
            ms1_int = spec['intensity array'].astype(float)
            cache.append((rt_seconds, ms1_mz, ms1_int))
    return cache

def nearest_ms1(ms1_cache, rt_seconds):
    """Return the (mz_array, intensity_array) of the cached spectrum closest in RT.

    Args:
        ms1_cache: sequence of ``(rt_seconds, mz_array, intensity_array)``
            tuples.
        rt_seconds: query retention time, in the same unit as the cached RTs.

    Returns:
        The m/z and intensity arrays of the entry with the smallest absolute
        RT difference. Two empty arrays are returned when the cache is empty
        or when no entry has a comparable RT (all cached RTs, or the query
        itself, are NaN).
    """
    if not ms1_cache:
        return np.array([]), np.array([])

    rts = np.array([entry[0] for entry in ms1_cache], dtype=float)
    gaps = np.abs(rts - rt_seconds)

    # np.argmin treats NaN as the minimum, so a single spectrum with a missing
    # RT would win every lookup. Ignore NaN gaps instead.
    if np.isnan(gaps).all():
        return np.array([]), np.array([])
    i = int(np.nanargmin(gaps))
    return ms1_cache[i][1], ms1_cache[i][2]

def integrate_peak(ms1_mz, ms1_int, target_mz, bin_da=ISOTOPE_BIN_DA):
    """Sum the intensities whose m/z lies within ``target_mz ± bin_da`` (inclusive).

    Args:
        ms1_mz: NumPy array of m/z values.
        ms1_int: NumPy array of intensities, indexed in parallel with ``ms1_mz``.
        target_mz: centre of the window.
        bin_da: half-width of the window.

    Returns:
        The summed intensity as a Python float; ``0.0`` for an empty spectrum
        or when no point falls inside the window.
    """
    if ms1_mz.size == 0:
        return 0.0

    lower, upper = target_mz - bin_da, target_mz + bin_da
    in_window = (ms1_mz >= lower) & (ms1_mz <= upper)
    if not in_window.any():
        return 0.0
    return float(ms1_int[in_window].sum())

def ms1_isotope_metrics(ms1_mz, ms1_int, precursor_mz):
    """Measure M, M+2 and M+4 around a precursor and score the M+2/M ratio.

    Returns a 6-tuple ``(M, M2, M4, ratio, best_score, best_n)``:
      * ``M``, ``M2``, ``M4`` — windowed intensity sums at the precursor m/z
        and at the precursor shifted by CL_M2_SHIFT / CL_M4_SHIFT;
      * ``ratio`` — ``M2 / (M + EPS)``, or ``0.0`` when ``M`` is not positive;
      * ``best_score`` — the largest ``exp(-|ratio - expected| / ISOTOPE_TAU)``
        over CL_EXPECTED_RATIOS;
      * ``best_n`` — the CL_EXPECTED_RATIOS key that produced it (``0`` if no
        candidate scored above zero).

    ``M4`` is reported but does not take part in the score.
    """
    def _window_sum(center_mz):
        return integrate_peak(ms1_mz, ms1_int, center_mz, ISOTOPE_BIN_DA)

    mono = _window_sum(precursor_mz)
    plus2 = _window_sum(precursor_mz + CL_M2_SHIFT)
    plus4 = _window_sum(precursor_mz + CL_M4_SHIFT)

    if mono > 0:
        ratio = plus2 / (mono + EPS)
    else:
        ratio = 0.0

    scored = [
        (float(np.exp(-abs(ratio - expected) / ISOTOPE_TAU)), n_cl)
        for n_cl, expected in CL_EXPECTED_RATIOS.items()
    ]
    # max() keeps the first of equal scores, matching a strict ">" scan.
    top_score, top_n = max(scored, key=lambda pair: pair[0], default=(0.0, 0))
    if not top_score > 0.0:
        # Nothing beat the zero baseline (score underflowed or was NaN).
        top_score, top_n = 0.0, 0
    return mono, plus2, plus4, ratio, top_score, top_n

def ms2_fragment_isotope_pairs(frag_mz_list, tol_da=MS2_PAIR_TOL_DA, shift=CL_M2_SHIFT, max_pairs_to_show=5):
    """Find fragments that have a partner fragment ``shift`` m/z higher.

    Fragment m/z values are rounded to 3 decimals and de-duplicated. For each
    one, the fragment nearest to ``mz + shift`` is accepted as its partner if
    it lies within ``tol_da``.

    Returns:
        ``(count, examples)`` — the number of pairs found, and a
        ``"; "``-joined string of up to ``max_pairs_to_show`` pairs formatted
        as ``"light->heavy"``. Returns ``(0, "")`` for an empty input.
    """
    if not frag_mz_list:
        return 0, ""

    unique_mz = np.array(sorted({float(round(mz, 3)) for mz in frag_mz_list}))
    matched = []
    for light in unique_mz.tolist():
        expected_heavy = light + shift
        nearest = np.argmin(np.abs(unique_mz - expected_heavy))
        if abs(unique_mz[nearest] - expected_heavy) <= tol_da:
            matched.append((light, float(unique_mz[nearest])))

    matched = sorted(set(matched))
    shown = matched[:max_pairs_to_show]
    examples = "; ".join(f"{light:.3f}->{heavy:.3f}" for light, heavy in shown)
    return len(matched), examples

# =========================
# LOAD MODEL + ENCODERS
# =========================
# The bundle is a mapping holding the classifier and the two fitted encoders.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code — only load model bundles from a trusted source.
bundle = joblib.load(MODEL_PATH)
model, mlb_frag, mlb_delta = (
    bundle[key] for key in ("model", "mlb_frag", "mlb_delta")
)

# =========================
# LOAD DATA
# =========================
df = pd.read_csv(INPUT_CSV)
# Both list-valued columns arrive as text; turn each cell into a list.
for list_col in (MZ_COL, INT_COL):
    df[list_col] = df[list_col].apply(safe_parse_list)


# Per-row feature builders, configured to match training.
def _to_frag_list(mz_values):
    return round_unique(mz_values, dp=ROUND_DP)


def _to_delta_list(frags):
    return compute_delta_mz_list(frags, round_dp=ROUND_DP,
                                 min_delta=MIN_DELTA, max_delta=MAX_DELTA)


df["frag_list"] = df[MZ_COL].apply(_to_frag_list)
df["delta_list"] = df["frag_list"].apply(_to_delta_list)

# Encode with the trained encoders, then stack fragment and delta features
# side by side into one sparse float32 matrix.
encoded_blocks = [
    csr_matrix(encoder.transform(df[column]), dtype=np.float32)
    for encoder, column in ((mlb_frag, "frag_list"), (mlb_delta, "delta_list"))
]
X_frag, X_delta = encoded_blocks
X = hstack(encoded_blocks, format="csr").astype(np.float32)

# Probability of the second class (column index 1) for every row.
y_prob = model.predict_proba(X)[:, 1]

# =========================
# MS1 + MS2 ISOTOPE EVIDENCE
# =========================
if not os.path.isfile(MZML_PATH):
    raise FileNotFoundError(f"mzML not found at: {MZML_PATH}")

print("Caching MS1 from mzML…")
ms1_cache = build_ms1_cache(MZML_PATH)
print(f"  cached {len(ms1_cache)} MS1 spectra")

# Heuristic unit check: when the largest RT is below 100 the column is taken
# to be in minutes and converted; otherwise it is used as seconds.
rt_looks_like_minutes = df[RT_COL].max() < 100
df["_rt_seconds"] = (
    df[RT_COL] * 60.0 if rt_looks_like_minutes else df[RT_COL].astype(float)
)

# Per-row accumulators, later written out as columns.
ms1_M_list, ms1_M2_list, ms1_M4_list = [], [], []
ms1_ratio_list, ms1_iso_score_list, ms1_bestN_list = [], [], []
ms2_pair_count_list, ms2_pair_examples_list = [], []
# Same order as the tuple returned by ms1_isotope_metrics.
ms1_columns = (ms1_M_list, ms1_M2_list, ms1_M4_list,
               ms1_ratio_list, ms1_iso_score_list, ms1_bestN_list)

for rt_value, precursor_value, frag_mzs in zip(df["_rt_seconds"], df[PRECURSOR_COL], df["frag_list"]):
    rt_sec = float(rt_value)
    prec_mz = float(precursor_value)

    # MS1 evidence: isotope pattern around the precursor in the nearest MS1 scan.
    ms1_mz, ms1_int = nearest_ms1(ms1_cache, rt_sec)
    for column, value in zip(ms1_columns, ms1_isotope_metrics(ms1_mz, ms1_int, prec_mz)):
        column.append(value)

    # MS2 evidence: fragment pairs separated by the M+2 shift.
    cnt, ex = ms2_fragment_isotope_pairs(frag_mzs, tol_da=MS2_PAIR_TOL_DA, shift=CL_M2_SHIFT)
    ms2_pair_count_list.append(cnt)
    ms2_pair_examples_list.append(ex)

# =========================
# FUSED DECISION
# =========================
# Normalize MS2 evidence to {0,1} as “any Cl-like pair?”
ms2_evidence = (np.asarray(ms2_pair_count_list) > 0).astype(float)

# Weighted sum of the three evidence sources.
y_prob_arr = np.array(y_prob, dtype=float)
iso_arr = np.array(ms1_iso_score_list, dtype=float)
model_term = W_MODEL * y_prob_arr
iso_term = W_MS1ISO * iso_arr
ms2_term = W_MS2PAIR * ms2_evidence
combined = model_term + iso_term + ms2_term

# Hard minimums on the individual scores (no effect while both are 0.0).
passes_model_floor = y_prob_arr >= MIN_MODEL_PROB
passes_iso_floor = iso_arr >= MIN_ISO_SCORE
meets_min = passes_model_floor & passes_iso_floor

# Final 0/1 call: fused score reaches the threshold AND both minimums are met.
combined_pred = ((combined >= FINAL_THRESHOLD) & meets_min).astype(int)

# =========================
# OUTPUT
# =========================
# Assemble the result table: the input rows plus every computed score.
df_out = df.copy()
df_out["chlorinated_probability"] = y_prob_arr
df_out["ms1_M"]            = ms1_M_list
df_out["ms1_M2"]           = ms1_M2_list
df_out["ms1_M4"]           = ms1_M4_list
df_out["ms1_M2_over_M"]    = ms1_ratio_list
df_out["ms1_iso_score"]    = iso_arr
df_out["ms1_best_nCl_like"]= ms1_bestN_list
df_out["ms2_cl_pairs_count"]   = ms2_pair_count_list
df_out["ms2_cl_pairs_examples"]= ms2_pair_examples_list

df_out["combined_score"] = combined
df_out["combined_pred"]  = combined_pred

df_out.to_csv(OUT_ALL, index=False)

# Second file: only the rows flagged as chlorinated by the fused decision.
df_pos = df_out[df_out["combined_pred"] == 1].copy()
df_pos.to_csv(OUT_POS, index=False)

print(f"✅ Wrote full predictions + fused decision: {OUT_ALL}")
print(f"✅ Wrote chlorinated-only (fused): {OUT_POS}")

# Console preview of the first positives. Column names come from the config
# constants, and columns absent from the input (e.g. no FEATURE_ID) are
# skipped with a notice instead of silently suppressing the whole preview.
preview_cols = ["FEATURE_ID", RT_COL, PRECURSOR_COL,
                "chlorinated_probability", "ms1_iso_score",
                "ms2_cl_pairs_count", "combined_score", "combined_pred"]
preview_missing = [col for col in preview_cols if col not in df_pos.columns]
preview_present = [col for col in preview_cols if col in df_pos.columns]
if preview_missing:
    print(f"(preview: skipping columns not present in the data: {preview_missing})")
print(df_pos.head(10)[preview_present])
