
# MALDI-TOF Spectra Classification

This notebook loads Bruker MALDI-TOF spectra from a specified genus (and optional species), preprocesses them,
and trains a Random Forest classifier to identify bacteria at the species level.

**Steps:**
1. Configure parameters (data paths, target genus/species, bin size).
2. Load and filter acquisition (`.acqu`) and time-domain (`.fid`) files.
3. Preprocess spectra: variance stabilization, smoothing, baseline correction, normalization, trimming, binning.
4. Extract intensities as feature vectors and build labels.
5. Split into train/test sets and train a Random Forest.
6. Evaluate performance with classification report.


In [None]:
import os
import logging
from datetime import datetime
import numpy as np
import joblib
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from spectrum import SpectrumObject, VarStabilizer, Smoother, BaselineCorrecter, Normalizer, Trimmer, Binner
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, accuracy_score

# Adjust logging
logging.basicConfig(level=logging.INFO,
                    format='[%(asctime)s] %(levelname)s: %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S')

# %%
# --- User Parameters ---
base_dir    = '/MARISMa'
years       = [str(y) for y in range(2018, 2020)] # From which years to load data
# e.g. [2018, 2019, 2020]
bin_step    = 9                      # m/z bin width for Binner

target_genus   = 'Klebsiella'        # e.g. 'Escherichia'
target_species = None                 # e.g. 'Coli' or None for all species

random_state = 42

# %%
# --- Data Loading Function ---
def load_spectra(base_dir, genus, species=None, years=None):
    """
    Walks through base_dir/YYYY/matched_bacteria/genus[/species] and collects
    Bruker acqu/fid file paths and corresponding labels.
    Returns lists: acqu_files, fid_files, labels.
    Label format: "Genus_species" if species provided or discovered, else genus only.
    """
    acqu_files, fid_files, labels = [], [], []
    for year in years:
        matched = os.path.join(base_dir, year, genus)
        if not os.path.isdir(matched):
            continue

        species_dirs = []
        if species:
            spath = os.path.join(matched, species)
            if os.path.isdir(spath):
                species_dirs = [spath]
        else:
            species_dirs = [os.path.join(matched, d)
                            for d in os.listdir(matched)
                            if os.path.isdir(os.path.join(matched, d))]

        for sp_dir in species_dirs:
            sp_label = f"{genus}_{os.path.basename(sp_dir)}"
            for extern_id in os.listdir(sp_dir):
                extern_dir = os.path.join(sp_dir, extern_id)
                if not os.path.isdir(extern_dir):
                    continue
                for pos in os.listdir(extern_dir):
                    pos_dir = os.path.join(extern_dir, pos)
                    if not os.path.isdir(pos_dir):
                        continue
                    for meas in os.listdir(pos_dir):
                        meas_dir = os.path.join(pos_dir, meas)
                        if not os.path.isdir(meas_dir):
                            continue
                        slin = os.path.join(meas_dir, '1SLin')
                        if not os.path.isdir(slin):
                            continue
                        a = os.path.join(slin, 'acqu')
                        f = os.path.join(slin, 'fid')
                        if os.path.exists(a) and os.path.exists(f):
                            acqu_files.append(a)
                            fid_files.append(f)
                            labels.append(sp_label)
    return acqu_files, fid_files, labels

# %%
# --- Preprocessing Function ---
def preprocess_spectrum(acqu_file, fid_file, bin_step):
    """
    Load Bruker files, apply variance stabilization, smoothing, baseline correction,
    normalization, trimming (2–20 kDa), and binning.
    Returns intensity vector (numpy array).
    """
    spec = SpectrumObject.from_bruker(acqu_file, fid_file)
    spec = VarStabilizer(method='sqrt')(spec)
    spec = Smoother(halfwindow=10, polyorder=3)(spec)
    spec = BaselineCorrecter(method='SNIP', snip_n_iter=10)(spec)
    spec = Normalizer()(spec)
    spec = Trimmer(min=2000, max=20000)(spec)
    spec = Binner(step=bin_step)(spec)
    return spec.intensity



In [2]:
# %%
# --- Load and Preprocess All Spectra ---
acqs, fids, labels = load_spectra(base_dir,
                                  target_genus,
                                  target_species,
                                  years)
logging.info(f"Found {len(acqs)} spectra for {target_genus} {target_species or ''}")

# Preallocate feature matrix
X = []
for a, f in zip(acqs, fids):
    X.append(preprocess_spectrum(a, f, bin_step))
X = np.vstack(X)            # shape = (n_samples, n_bins)

logging.info(f"Feature matrix shape: {X.shape}")

# Encode labels as categorical
y = np.array(labels)


[2025-05-13 12:19:08] INFO: Found 5502 spectra for Klebsiella 
[2025-05-13 12:20:09] INFO: Feature matrix shape: (5502, 2000)


In [3]:

# %%
# --- Train-Test Split ---
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,
    stratify=y,
    random_state=random_state
)
logging.info(f"Training on {len(y_train)} spectra, testing on {len(y_test)} spectra")


[2025-05-13 12:20:09] INFO: Training on 4401 spectra, testing on 1101 spectra


In [4]:
# %%
# --- Random Forest Training ---
clf = RandomForestClassifier(n_estimators=100,
                             random_state=random_state,
                             n_jobs=-1)
clf.fit(X_train, y_train)

logging.info("Random Forest training complete")

# Evaluate
y_pred = clf.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print(f"Test accuracy: {acc:.3%}")
print(classification_report(y_test, y_pred))

# %%
# --- Save Model ---
model_path = 'rf_maldi_model.joblib'
joblib.dump(clf, model_path)
logging.info(f"Model saved to {model_path}")

[2025-05-13 12:20:11] INFO: Random Forest training complete
[2025-05-13 12:20:11] INFO: Model saved to rf_maldi_model.joblib


Test accuracy: 99.728%
                       precision    recall  f1-score   support

 Klebsiella_Aerogenes       1.00      1.00      1.00        58
   Klebsiella_Oxytoca       1.00      1.00      1.00       155
Klebsiella_Pneumoniae       1.00      1.00      1.00       857
 Klebsiella_Variicola       1.00      0.90      0.95        31

             accuracy                           1.00      1101
            macro avg       1.00      0.98      0.99      1101
         weighted avg       1.00      1.00      1.00      1101

