<a href="https://colab.research.google.com/github/jwasswa2023/ChloroFinder/blob/main/ChloroFinder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import pandas as pd
import numpy as np
from itertools import combinations
from sklearn.model_selection import train_test_split, GridSearchCV, StratifiedKFold
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import (
    classification_report, confusion_matrix, f1_score,
    balanced_accuracy_score, matthews_corrcoef, roc_auc_score,
    accuracy_score
)
from sklearn.ensemble import RandomForestClassifier
from collections import Counter
from scipy.sparse import hstack, csr_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import joblib

# ----------------------------
# Load cleaned data
# ----------------------------
df = pd.read_csv('fragments_cleaned_and_parsed.csv')


def _as_frag_list(cell):
    """Return one ``frag_list`` cell as a Python object (normally a list).

    Cells that already hold a list are passed through untouched; any other
    value (e.g. the text form read back from the CSV) is evaluated.
    """
    # SECURITY NOTE(review): ``eval`` executes arbitrary code contained in the
    # CSV cell, so this is only acceptable for a trusted, locally produced
    # file. ``ast.literal_eval`` is the safe alternative, but it rejects
    # non-literal reprs (e.g. ``np.float64(101.5)`` if numpy scalars were
    # written out) — check the file contents before switching.
    if isinstance(cell, list):
        return cell
    return eval(cell)


# ensure frag_list is a list
df['frag_list'] = df['frag_list'].apply(_as_frag_list)
# Binary target; converted to plain ints for the classifier and metrics.
y = df['chlorinated'].astype(int).values

# ----------------------------
# Helper: compute Δm/z features
# ----------------------------
def compute_delta_mz_list(frag_list, round_dp=3, min_delta=0.001, max_delta=None):
    """Return the sorted, unique pairwise m/z differences of one fragment list.

    Parameters
    ----------
    frag_list : sequence of values convertible to ``float``; duplicates allowed.
        ``None``, empty, or single-entry input yields ``[]``.
    round_dp : number of decimal places each kept difference is rounded to.
    min_delta : differences below this value (checked before rounding) are
        dropped, which discards near-identical fragment pairs.
    max_delta : when not ``None``, differences above this value (checked
        before rounding) are dropped.

    Returns
    -------
    list of float
        Distinct rounded differences in ascending order. Every unordered pair
        of distinct fragment values is considered, so the cost grows
        quadratically with the number of unique fragments.
    """
    if not frag_list or len(frag_list) < 2:
        return []

    unique_mz = sorted({float(value) for value in frag_list})

    kept = set()
    for lower, upper in combinations(unique_mz, 2):
        gap = abs(lower - upper)
        too_small = gap < min_delta
        too_large = max_delta is not None and gap > max_delta
        if too_small or too_large:
            continue
        kept.add(round(gap, round_dp))

    return sorted(kept)

# ----------------------------
# Build feature spaces
# ----------------------------
def _fit_multi_hot(label_lists):
    """Fit a MultiLabelBinarizer on *label_lists* and encode them.

    Returns ``(encoder, matrix)``: ``matrix`` is a float32 CSR matrix with one
    row per entry of *label_lists* and one column per ``encoder.classes_``
    entry, holding 1.0 where that label occurs in the row and 0 elsewhere.
    """
    # Encode straight to sparse. The default (dense) output first builds a
    # full n_rows x n_labels integer array, which can exhaust memory when the
    # label vocabulary is large (the Δm/z vocabulary in particular).
    encoder = MultiLabelBinarizer(sparse_output=True)
    matrix = csr_matrix(encoder.fit_transform(label_lists), dtype=np.float32)
    # Put the fitted encoder back on the default dense output so that its
    # transform() (e.g. after being saved and reloaded) returns the same type
    # it always has.
    encoder.set_params(sparse_output=False)
    return encoder, matrix


# 1) fragment presence (multi-hot)
mlb_frag, X_frag = _fit_multi_hot(df['frag_list'])

# 2) delta-m/z presence (multi-hot)
delta_lists = df['frag_list'].apply(lambda frags: compute_delta_mz_list(frags, round_dp=3))
mlb_delta, X_delta = _fit_multi_hot(delta_lists)

# Combine
X = hstack([X_frag, X_delta], format='csr').astype(np.float32)

print(f"Fragments feature dim: {X_frag.shape[1]}")
print(f"Delta-m/z feature dim: {X_delta.shape[1]}")
print(f"Total feature dim:     {X.shape[1]}")

# ----------------------------
# Train / Test Split (stratified)
# ----------------------------
# Hold out 20% of the rows for the final evaluation. stratify=y keeps the
# class ratio the same in both parts; random_state makes the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    random_state=42,
    stratify=y,
    test_size=0.20,
)

# Optional: quick class distribution check (label -> number of training rows)
counts = Counter(y_train)
print(f"Train class counts: {counts}")

# ----------------------------
# Random Forest + Grid Search
# ----------------------------
# Hyper-parameter grid: 5 x 5 x 3 x 3 x 2 = 450 candidate settings, each
# scored by 5-fold cross-validation (2,250 fits in total). The split criterion
# is not searched and stays at the RandomForestClassifier default; to search
# it, add "criterion" with the classification options 'gini', 'entropy',
# 'log_loss'.
param_grid = dict(
    n_estimators=[50, 100, 150, 200, 300],
    max_depth=[None, 5, 10, 20, 50],
    min_samples_split=[2, 5, 10],
    min_samples_leaf=[1, 2, 4],
    bootstrap=[True, False],
)

# Shuffled, stratified folds with a fixed seed so the CV scores are repeatable.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

base_model = RandomForestClassifier(
    class_weight=None,  # try 'balanced' if one class heavily outnumbers the other
    n_jobs=-1,          # build trees on all CPU cores
    random_state=42,
)

# NOTE(review): both the forest and the search request n_jobs=-1; if the
# machine is oversubscribed, consider parallelising at only one level.
grid = GridSearchCV(
    estimator=base_model,
    param_grid=param_grid,
    scoring='f1',  # F1 of the positive class; 'f1_macro' weighs both classes equally
    cv=cv,
    n_jobs=-1,
    verbose=1,
)

print("🔍 Starting Grid Search (Random Forest)...")
grid.fit(X_train, y_train)

# ----------------------------
# Best model & evaluation
# ----------------------------
# Best-scoring configuration; GridSearchCV's default refit=True has already
# retrained it on the whole training split.
best_model = grid.best_estimator_
print("\n Best Parameters:")
print(grid.best_params_)

# Hard labels for the threshold-based metrics, plus the positive-class score
# for ROC-AUC. Columns of predict_proba follow best_model.classes_ (sorted),
# so column 1 is label 1 — assumes the labels are exactly {0, 1}.
y_pred = best_model.predict(X_test)
proba = best_model.predict_proba(X_test)
y_prob = proba[:, 1]

print("\n📊 Classification Report:")
print(classification_report(y_test, y_pred))

# Label-based metrics, evaluated in the order they are unpacked.
acc, f1, ba, mcc = (
    metric(y_test, y_pred)
    for metric in (accuracy_score, f1_score, balanced_accuracy_score, matthews_corrcoef)
)
# ROC-AUC ranks samples by score, so it takes the probabilities, not y_pred.
roc_auc = roc_auc_score(y_test, y_prob)

print(f" Accuracy: {acc:.3f}")
print(f" F1 Score: {f1:.3f}")
print(f" Balanced Accuracy: {ba:.3f}")
print(f" MCC: {mcc:.3f}")
print(f" ROC-AUC: {roc_auc:.3f}")

# ----------------------------
# Confusion Matrix
# ----------------------------
# Tick labels in sorted label order: position 0 is label 0, position 1 is label 1.
class_names = ['Non-Chlorinated', 'Chlorinated']

# sklearn convention: rows are true labels, columns are predicted labels.
cm = confusion_matrix(y_test, y_pred)

sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
)
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix (Random Forest)')
plt.show()

# ----------------------------
# Save model + encoders
# ----------------------------
# Persist the fitted forest together with both fitted encoders. New spectra
# must be encoded with these exact vocabularies so that the feature columns
# line up with training. The Δm/z rounding (round_dp) is not stored in this
# bundle, so inference code must reuse the value the features were built with.
artifacts = {
    'model': best_model,
    'mlb_frag': mlb_frag,
    'mlb_delta': mlb_delta,
}
joblib.dump(artifacts, 'ChloroFinder.pkl')
print(" Saved model + encoders to 'ChloroFinder.pkl'")