Functions to generate MS1 and MS2 matrices plus metadata from RAW files

In [1]:
import os
import re
import glob
import numpy as np
from tqdm import tqdm
import gc
from fisher_py.data.business import Scan
from fisher_py import RawFile


# -----------------------------
# Config / binning
# -----------------------------
# MS1 bins are 0.1 m/z wide (bin index = round(m/z * 10)); indices 6000..19689
# cover m/z 600.0 .. 1968.9 (10 pts per m/z).
MS1_MIN_IDX, MS1_LEN = 6000, 13690   # 600.0 m/z * 10 .. 1968.9 (10 pts per m/z)
MS1_MAX_EXC = MS1_MIN_IDX + MS1_LEN  # exclusive upper MS1 bin index
# MS2 bins are 1 m/z wide (bin index = round(m/z)).
MS2_MIN_IDX, MS2_LEN = 400, 1600     # m/z 400..1999 (1 pt per m/z)
MS2_MAX_EXC = MS2_MIN_IDX + MS2_LEN  # exclusive upper MS2 bin index

# Substrings searched (in this order) in each RAW file name to assign its group;
# files whose name contains none of them are not processed.
GROUPS = ("TreatmentA", "TreatmentB", "TreatmentC", "TreatmentD")

# -----------------------------
# Helpers
# -----------------------------
def _scan_type_label(text: str) -> str:
    m = re.search(r"Full\s+(\w+)", str(text), flags=re.IGNORECASE)
    return m.group(1).lower() if m else ""

def _group_from_name(name: str) -> str:
    """Return the first entry of GROUPS contained in ``name``, or "Unknown" if none is."""
    return next((candidate for candidate in GROUPS if candidate in name), "Unknown")

def _as_float_array(x):
    if x is None:
        return np.array([], dtype=float)
    a = np.asarray(x)
    return a.astype(float, copy=False) if a.size else np.array([], dtype=float)

def _ensure_folder_list(paths):
    if isinstance(paths, (list, tuple)):
        return list(paths)
    return [paths]

def _gather_raw_files(folder_paths):
    """Collect the absolute paths of all ``.raw`` files directly inside the given folder(s).

    The extension match is case-insensitive (``.raw``, ``.RAW``, ``.Raw``, ...) on every
    platform. Folder names are used literally, so characters such as ``[`` or ``*`` in a
    path are not treated as wildcards. Only regular files are returned.

    Parameters
    ----------
    folder_paths : str or list/tuple of str
        One folder or several folders to scan (non-recursive).

    Returns
    -------
    list of str
        Sorted, de-duplicated absolute paths.

    Raises
    ------
    FileNotFoundError
        If a folder does not exist, or if no ``.raw`` file is found in any of them.
    """
    folder_list = _ensure_folder_list(folder_paths)
    found = set()
    for fp in folder_list:
        fp_abs = os.path.abspath(fp)
        if not os.path.isdir(fp_abs):
            raise FileNotFoundError(f'Folder not found: "{fp_abs}"')
        # os.listdir instead of glob: glob expands "[", "]", "*", "?" inside the folder
        # path and its extension match is case-sensitive on POSIX file systems.
        for entry in os.listdir(fp_abs):
            full_path = os.path.join(fp_abs, entry)
            if entry.lower().endswith(".raw") and os.path.isfile(full_path):
                found.add(os.path.abspath(full_path))
    raw_files = sorted(found)
    if not raw_files:
        raise FileNotFoundError(
            f'No ".raw" files found in: {", ".join(map(os.path.abspath, folder_list))}'
        )
    return raw_files

def _sanitize_metadata_dict(md: dict) -> dict:
    """Ensure arrays are numeric or Unicode (never object dtype)."""
    safe = {}
    for k, v in md.items():
        if isinstance(v, (int, float, np.number, np.bool_)):
            safe[k] = np.array(v)
            continue
        if isinstance(v, (list, tuple, np.ndarray)):
            arr = np.asarray(v)
            if arr.dtype == object:
                try:
                    arr = arr.astype(np.float32)
                except Exception:
                    arr = arr.astype("U")
            if np.issubdtype(arr.dtype, np.character):
                arr = arr.astype("U")
            safe[k] = arr
            continue
        if isinstance(v, str):
            safe[k] = np.array(v, dtype="U")
            continue
        safe[k] = np.array(str(v), dtype="U")
    return safe

def _out_paths(out_dir: str, group: str):
    base = os.path.join(os.path.abspath(out_dir), group)
    return (f"{base}.ms1.npz", f"{base}.ms2.npz", f"{base}.meta.npz")

# -----------------------------
# Core: process one treatment group at a time
# -----------------------------
# Precursor m/z in a filter string is the number immediately before "@",
# e.g. "FTMS + c NSI d Full ms2 445.1200@hcd30.00 [110.0000-905.0000]" -> 445.12.
_PRECURSOR_RE = re.compile(r"(\d+(?:\.\d+)?)@")


def _bin_spectrum(masses, intens, scale, min_idx, length):
    """Sum peak intensities into fixed-width m/z bins.

    The bin index is ``round(m/z * scale)``; only indices in ``[min_idx, min_idx + length)``
    are kept. Returns a float32 vector of ``length``, or None if no peak is in range.
    """
    idx = np.rint(masses * scale).astype(np.int32)
    mask = (idx >= min_idx) & (idx < min_idx + length)
    if not mask.any():
        return None
    vec = np.zeros(length, dtype=np.float32)
    # np.add.at accumulates correctly when several peaks fall into the same bin
    np.add.at(vec, idx[mask] - min_idx, intens[mask].astype(np.float32, copy=False))
    return vec


def _precursor_mz(raw_scan) -> float:
    """Best-effort precursor m/z of an MS2 scan; NaN when it cannot be determined.

    First tries the scan attributes ``precursor_mz``, ``master_precursor_mz`` and
    ``isolation_mz``. If none gives a number, it falls back to the number directly
    before the first "@" in the filter text (``raw_scan.scan_type``). The second decimal
    number of the filter is not the precursor: in "...ms2 445.12@hcd30.00 [...]" it is
    the activation energy (30.00).
    """
    prec = np.nan
    for attr in ("precursor_mz", "master_precursor_mz", "isolation_mz"):
        if hasattr(raw_scan, attr):
            try:
                prec = float(getattr(raw_scan, attr))
                break
            except Exception:
                pass
    if np.isnan(prec):
        found = _PRECURSOR_RE.search(str(raw_scan.scan_type))
        prec = float(found.group(1)) if found else np.nan
    return prec


def _process_group(group: str, group_files: list, out_dir: str):
    """
    Cast every RAW file of one treatment group into binned matrices and save them.

    Builds:
      - MS1 (float32, UNnormalized, 0.1 m/z bins) stacked per MS1 scan for this group
      - MS2 (float16, each scan scaled so its largest bin is 1, 1 m/z bins) stacked per MS2 scan
      - METADATA aligned to the two matrices (scan number, RT, file id, MS2 precursor m/z)
    Saves three NPZ files (paths from ``_out_paths``) and frees RAM.

    Parameters
    ----------
    group : str
        Group name; used in the output file names and stored as ``group_name``.
    group_files : list of str
        RAW file paths of this group. Files that cannot be opened are skipped, but
        they still receive a file id in the lookups.
    out_dir : str
        Output folder (created if missing).

    Returns
    -------
    dict or None
        ``{"group", "ms1", "ms2", "meta"}`` with the written paths, or None if
        ``group_files`` is empty.

    Raises
    ------
    ImportError
        If ``RawFile``/``Scan`` (fisher_py) are not defined.
    """
    if not group_files:
        return None

    # Guard: require fisher_py
    try:
        RawFile, Scan  # type: ignore # noqa
    except NameError:
        raise ImportError("fisher_py is required for RAW access: import RawFile and Scan from fisher_py at the top.")

    os.makedirs(out_dir, exist_ok=True)
    ms1_path, ms2_path, meta_path = _out_paths(out_dir, group)

    # Per-group accumulators
    file_basenames, file_abspaths = [], []
    file_to_id = {}

    # MS1
    ms1_rows = []                               # list of vectors (float32)
    ms1_scan, ms1_rt, ms1_file_id = [], [], []  # aligned to ms1_rows

    # MS2
    ms2_rows = []                               # list of vectors (float16)
    ms2_scan, ms2_rt, ms2_prec_mz, ms2_file_id = [], [], [], []

    # Iterate files in this group
    for raw_abs in group_files:
        raw_name = os.path.basename(raw_abs)
        if raw_abs not in file_to_id:
            file_to_id[raw_abs] = len(file_basenames)
            file_basenames.append(raw_name)
            file_abspaths.append(raw_abs)
        f_id = file_to_id[raw_abs]

        # open RAW
        try:
            raw = RawFile(raw_abs)
        except Exception as e:
            print(f'[skip] Cannot open RAW: {raw_abs} ({e})')
            continue

        try:
            total_scans = int(getattr(raw, "number_of_scans", 0) or 0)

            for i in tqdm(range(1, total_scans + 1), desc=f"[{group}] {raw_name}", ncols=100):
                try:
                    raw_scan = Scan.from_file(raw._raw_file_access, scan_number=i)
                except Exception:
                    continue

                stype = _scan_type_label(raw_scan.scan_type)
                if stype not in ("ms", "ms2"):
                    continue  # only "Full ms" and "Full ms2" scans are cast

                sc_num = getattr(raw_scan.scan_statistics, "scan_number", i)
                try:
                    rt = float(raw.get_retention_time_from_scan_number(sc_num))
                except Exception:
                    rt = np.nan

                masses = _as_float_array(getattr(raw_scan, "preferred_masses", None))
                intens = _as_float_array(getattr(raw_scan, "preferred_intensities", None))
                if masses.size == 0 or intens.size == 0:
                    continue

                if stype == "ms":
                    # UNnormalized MS1 row, 0.1 m/z bins: index = round(m/z*10)
                    vec = _bin_spectrum(masses, intens, 10.0, MS1_MIN_IDX, MS1_LEN)
                    if vec is None:
                        continue
                    ms1_rows.append(vec)
                    ms1_scan.append(sc_num)
                    ms1_rt.append(rt)
                    ms1_file_id.append(f_id)
                else:
                    # MS2 row, 1 m/z bins: index = round(m/z); scaled to max 1, float16 for size
                    vec = _bin_spectrum(masses, intens, 1.0, MS2_MIN_IDX, MS2_LEN)
                    if vec is None:
                        continue
                    vmax = float(vec.max())
                    if vmax > 0:
                        vec /= vmax
                    prec = _precursor_mz(raw_scan)

                    ms2_rows.append(vec.astype(np.float16, copy=False))
                    ms2_scan.append(sc_num)
                    ms2_rt.append(rt)
                    ms2_prec_mz.append(prec)
                    ms2_file_id.append(f_id)
        finally:
            # Release the RAW handle even if scan processing raises (or is interrupted).
            try:
                raw.dispose()
            except Exception:
                pass

    # ---- Build metadata (per-group) ----
    # Note: IDs are per-group (0..n_files_in_group-1)
    metadata_raw = dict(
        group_name=np.array(group, dtype="U"),

        # MS1 row-aligned meta
        ms1_scan=np.asarray(ms1_scan, dtype=np.int32),
        ms1_rt=np.asarray(ms1_rt, dtype=np.float32),
        ms1_file_id=np.asarray(ms1_file_id, dtype=np.int32),

        # MS2 row-aligned meta
        ms2_scan=np.asarray(ms2_scan, dtype=np.int32),
        ms2_rt=np.asarray(ms2_rt, dtype=np.float32),
        ms2_precursor_mz=np.asarray(ms2_prec_mz, dtype=np.float32),
        ms2_file_id=np.asarray(ms2_file_id, dtype=np.int32),

        # Lookups
        file_names_lookup=np.asarray(file_basenames, dtype="U"),
        file_paths_lookup=np.asarray(file_abspaths, dtype="U"),
    )
    metadata = _sanitize_metadata_dict(metadata_raw)

    # ---- Stack & save (release RAM right after) ----
    # MS1 (float32, UNnormalized)
    if ms1_rows:
        MS1 = np.vstack(ms1_rows).astype(np.float32, copy=False)
    else:
        MS1 = np.zeros((0, MS1_LEN), dtype=np.float32)
    np.savez_compressed(ms1_path, ms1_matrix=MS1, **metadata)
    print(f"[{group}] Saved MS1: {ms1_path}  shape={MS1.shape}, dtype={MS1.dtype}")
    del MS1, ms1_rows
    gc.collect()

    # MS2 (float16, normalized per scan)
    if ms2_rows:
        MS2 = np.vstack(ms2_rows).astype(np.float16, copy=False)
    else:
        MS2 = np.zeros((0, MS2_LEN), dtype=np.float16)
    np.savez_compressed(ms2_path, ms2_matrix=MS2, **metadata)
    print(f"[{group}] Saved MS2: {ms2_path}  shape={MS2.shape}, dtype={MS2.dtype}")
    del MS2, ms2_rows
    gc.collect()

    # Save metadata standalone (useful if you want to load meta without matrices)
    np.savez_compressed(meta_path, **metadata)
    print(f"[{group}] Saved META: {meta_path}")

    # Final cleanup
    del metadata, metadata_raw
    gc.collect()

    return {"group": group, "ms1": ms1_path, "ms2": ms2_path, "meta": meta_path}

# -----------------------------
# Public API
# -----------------------------
def wholeCasting_per_group(folder_paths, out_dir: str):
    """
    Cast all RAW files found in ``folder_paths`` into per-group NPZ files.

    A file belongs to the first GROUPS entry contained in its file name; files that
    match no group are ignored. For each group this writes:
      <out_dir>/<Group>.ms1.npz  (float32, UNnormalized)
      <out_dir>/<Group>.ms2.npz  (float16, per-scan normalized)
      <out_dir>/<Group>.meta.npz

    Groups are processed one at a time, and RAM is freed between them.
    Returns a dict keyed by group. Each value is the dict of written paths, or None
    for a group without files.
    """
    files_per_group = {name: [] for name in GROUPS}
    for raw_path in _gather_raw_files(folder_paths):
        assigned = _group_from_name(os.path.basename(raw_path))
        if assigned in files_per_group:
            files_per_group[assigned].append(raw_path)

    results = {}
    for name in GROUPS:
        results[name] = _process_group(name, files_per_group[name], out_dir)
        gc.collect()  # make sure the previous group's arrays are really released
    return results


Calling the wrapper function

In [None]:

# Do NOT wipe globals() before this call: that would also delete wholeCasting_per_group,
# the constants and imports it relies on (GROUPS, MS1_*/MS2_*, os, np, RawFile, Scan, ...)
# and IPython's own names (get_ipython, In, Out). Restart the kernel instead if a clean
# session is needed, then run the definitions cell above first.
import gc
gc.collect()

# NOTE(review): the "Combine MS2 matrices" step reads from F:/casts/databank; point
# out_dir there (or move the files) so both steps use the same folder.
cast_outputs = wholeCasting_per_group(["F:/TreatmentABC", "F:/TreatmentD"], out_dir="F:/casts/test")
cast_outputs

Combine the MS2 matrices

In [None]:
import numpy as np
import pandas as pd
import h5py

# Load the per-group MS2 casts written by wholeCasting_per_group, stack them into one
# library with a row-aligned metadata table, and save both to HDF5.
MS2_NPZ_DIR = "F:/casts/databank"
MS2_GROUP_ORDER = ("TreatmentA", "TreatmentB", "TreatmentC", "TreatmentD")
MS2_H5_PATH = "F:/casts/databank/ms2_dataset.h5"


def load_ms2_group(npz_path: str):
    """Load one ``<Group>.ms2.npz`` file.

    Returns
    -------
    ms2 : np.ndarray
        (n_rows, n_bins) MS2 matrix as stored (per-scan normalized float16).
    meta : pd.DataFrame
        One row per matrix row: scan, rt_min (RT in minutes), precursor_mz,
        file_name, group_name.

    Raises
    ------
    ValueError
        If the metadata and matrix row counts disagree.
    """
    with np.load(npz_path) as z:
        ms2 = z["ms2_matrix"]
        meta = pd.DataFrame({
            "scan": z["ms2_scan"],
            "rt_min": z["ms2_rt"],
            "precursor_mz": z["ms2_precursor_mz"],
            # per-group file ids index into this file's name lookup
            "file_name": z["file_names_lookup"][z["ms2_file_id"]],
            "group_name": str(z["group_name"]),
        })
    if len(meta) != ms2.shape[0]:
        raise ValueError(f"{npz_path}: {len(meta)} metadata rows vs {ms2.shape[0]} matrix rows")
    return ms2, meta


ms2_parts, meta_parts = [], []
for group in MS2_GROUP_ORDER:
    ms2_group, meta_group = load_ms2_group(f"{MS2_NPZ_DIR}/{group}.ms2.npz")
    ms2_parts.append(ms2_group)
    meta_parts.append(meta_group)

metadata = pd.concat(meta_parts, ignore_index=True)
ms2_lib = np.vstack(ms2_parts)
del ms2_parts, meta_parts, ms2_group, meta_group

with h5py.File(MS2_H5_PATH, "w") as f:
    f.create_dataset("ms2_lib", data=ms2_lib, compression="gzip")
    for col in metadata.columns:
        values = metadata[col].values
        if metadata[col].dtype == object:
            # HDF5 has no object dtype: store text as UTF-8 encoded fixed-length bytes
            values = np.char.encode(values.astype(str), "utf-8")
        f.create_dataset(col, data=values)

Load the MS2 matrix and its metadata from HDF5

In [8]:
import h5py
import pandas as pd
# Text columns were written as byte strings, so they come back as bytes here.
with h5py.File("F:/casts/databank/ms2_dataset.h5", "r") as h5:
    ms2_lib = h5["ms2_lib"][:]
    meta_names = [name for name in h5.keys() if name != "ms2_lib"]
    metadata = pd.DataFrame({name: h5[name][:] for name in meta_names})

Generate the retention time drift table and the updated (aligned) retention times

In [None]:
# -*- coding: utf-8 -*-
"""
Load HDF5 (ms2_dataset.h5) -> compute per-bin RT drifts vs first run -> align RTs
Save:
  - per-scan aligned metadata CSV (drops 'cast spectra')
  - per-bin drift tables CSV
"""

import os
import h5py
import numpy as np
import pandas as pd
from typing import Tuple
from math import floor, ceil
import matplotlib.pyplot as plt

# =====================
# Config
# =====================
# MS2 library + metadata written by the "Combine MS2" step
H5_PATH = r"F:/casts/databank/ms2_dataset.h5"

# Matching: a reference spectrum is a candidate when |precursor m/z difference| < MZ_WINDOW,
# and a match when its cosine similarity to the target spectrum is > SIM_THRESHOLD.
SIM_THRESHOLD  = 0.95
MZ_WINDOW      = 1.0
# Stop collecting drift values in an RT bin once this many matched scans were found.
TARGET_N       = 50
# RT bin width; each bin's sampling window is widened by OVERLAP_MIN on both sides
# (same RT units as the metadata, minutes per the column naming).
BIN_WIDTH      = 10.0
OVERLAP_MIN    = 2.5
# If not None, the last bin edge is fixed here instead of derived from the largest RT.
FORCE_BIN_END_MIN = 80.0
# True: draw rows at random with replacement (bounded number of tries) instead of one shuffled pass.
SAMPLE_WITH_REPLACEMENT_IF_NEEDED = False

PLOT_DRIFT_CURVES = False     # set True if you want plots
PLOT_SANITY_AFTER = False

# CSV outputs
SAVE_ALIGNED_CSV  = True
CSV_OUT_PATH      = r"F:/casts/databank/aligned_metadata1.csv"

SAVE_DRIFTS_CSV   = True
DRIFTS_CSV_PATH   = r"F:/casts/databank/rt_drifts1.csv"

# =========================================================
# Helpers
# =========================================================
def _to_1d_float_array(x):
    if isinstance(x, np.ndarray):
        arr = x
    elif isinstance(x, (list, tuple)):
        arr = np.asarray(x, dtype=float)
    else:
        try:
            arr = np.asarray(x, dtype=float).ravel()
        except Exception:
            return None
    return arr.ravel().astype(float, copy=False)

def cosine(a, b):
    """Cosine similarity of two spectra after flattening them to 1-D float arrays.

    When the lengths differ, only the common leading part is compared. Returns -inf
    when either input cannot be converted, is empty, or has zero norm.
    """
    u = _to_1d_float_array(a)
    v = _to_1d_float_array(b)
    if u is None or v is None or not u.size or not v.size:
        return -np.inf
    if u.shape != v.shape:
        common = min(u.size, v.size)
        if common == 0:
            return -np.inf
        u = u[:common]
        v = v[:common]
    norm_product = np.linalg.norm(u) * np.linalg.norm(v)
    if norm_product == 0:
        return -np.inf
    return float(np.dot(u, v) / norm_product)

def decode_bytes_inplace(df: pd.DataFrame) -> None:
    """Decode bytes values (UTF-8) to str in every object/bytes column of ``df``, in place."""
    def _decode(value):
        return value.decode("utf-8") if isinstance(value, (bytes, bytearray)) else value

    for name in df.columns:
        kind = df[name].dtype
        if kind != object and not str(kind).startswith("|S"):
            continue
        df[name] = df[name].apply(_decode)

def pick_col(df: pd.DataFrame, *cands):
    """Return the first name in ``cands`` that is a column of ``df``; raise KeyError if none is."""
    missing = object()
    found = next((name for name in cands if name in df.columns), missing)
    if found is missing:
        raise KeyError(f"None of {cands} found. Available: {df.columns.tolist()}")
    return found

def harmonize_columns(df: pd.DataFrame) -> None:
    """Add the canonical columns used downstream, in place, when they are missing.

    - ``sample_name`` (str): from the first of sample_name / file_name / raw_name / run_name
    - ``m/z`` (float): from the first of m/z / mz / precursor_mz
    - ``retntion time`` (float; legacy spelling kept for compatibility): from
      retention_time, else the mean of rt_min and rt_max, else rt_min, else rt

    Raises KeyError when no source column can be found.
    """
    if "sample_name" not in df.columns:
        source = pick_col(df, "sample_name", "file_name", "raw_name", "run_name")
        df["sample_name"] = df[source].astype(str)

    if "m/z" not in df.columns:
        source = pick_col(df, "m/z", "mz", "precursor_mz")
        df["m/z"] = df[source].astype(float)

    if "retntion time" in df.columns:
        return
    available = set(df.columns)
    if "retention_time" in available:
        rt = df["retention_time"].astype(float)
    elif {"rt_min", "rt_max"} <= available:
        rt = (df["rt_min"].astype(float) + df["rt_max"].astype(float)) / 2.0
    elif "rt_min" in available:
        rt = df["rt_min"].astype(float)
    elif "rt" in available:
        rt = df["rt"].astype(float)
    else:
        raise KeyError("Could not infer 'retntion time' column from metadata.")
    df["retntion time"] = rt

def load_h5_build_df(h5_path: str) -> pd.DataFrame:
    """Load the MS2 library HDF5 into a single DataFrame.

    Every dataset except ``ms2_lib`` becomes a metadata column. Bytes are decoded and
    the canonical columns are added by ``harmonize_columns``. The rows of ``ms2_lib``
    are attached as the ``cast spectra`` column, one array per row.

    Raises FileNotFoundError if the file is missing, KeyError if ``ms2_lib`` is
    absent, and ValueError if the metadata and matrix row counts differ.
    """
    if not os.path.exists(h5_path):
        raise FileNotFoundError(h5_path)

    with h5py.File(h5_path, "r") as h5:
        if "ms2_lib" not in h5:
            raise KeyError("HDF5 must contain 'ms2_lib' dataset.")
        spectra = h5["ms2_lib"][:]  # (N, L)
        columns = {name: h5[name][:] for name in h5.keys() if name != "ms2_lib"}

    meta_df = pd.DataFrame(columns)
    decode_bytes_inplace(meta_df)
    harmonize_columns(meta_df)

    n_spectra = spectra.shape[0]
    if len(meta_df) != n_spectra:
        raise ValueError(f"Row mismatch: metadata={len(meta_df)} vs ms2_lib={n_spectra}")

    meta_df = meta_df.copy()
    meta_df["cast spectra"] = pd.Series(list(spectra), index=meta_df.index)
    return meta_df

def build_bins_for_target(df_target: pd.DataFrame,
                          bin_width: float,
                          force_end_min):
    """Split the RT range of one run into consecutive bins of ``bin_width``.

    Edges are aligned to multiples of ``bin_width``. The first bin starts at
    ``floor(rt_min / w) * w``. Bins continue until the end edge is covered: that edge
    is ``ceil(rt_max / w) * w``, or ``force_end_min`` when given. A final bin may
    extend past ``force_end_min``.

    Parameters
    ----------
    df_target : pd.DataFrame
        Must contain a ``retntion time`` column unless empty.
    bin_width : float
        Bin width (> 0).
    force_end_min : float or None
        Optional fixed end edge.

    Returns
    -------
    (bins, rt_min, rt_max)
        ``bins`` is a list of ``(start, end)`` tuples. ``rt_min`` and ``rt_max`` are
        the observed RT extremes. Returns ``([], nan, nan)`` when there is no
        finite RT.

    Raises
    ------
    ValueError
        If ``bin_width`` is not positive, or ``force_end_min`` is not after the first edge.
    """
    if df_target.empty:
        return [], np.nan, np.nan
    if not bin_width > 0:
        raise ValueError(f"bin_width must be > 0, got {bin_width}.")

    rt_min = float(df_target["retntion time"].min())
    rt_max = float(df_target["retntion time"].max())
    if not (np.isfinite(rt_min) and np.isfinite(rt_max)):
        # pandas min/max skip NaN, so this means every RT is NaN: nothing to bin
        return [], np.nan, np.nan

    start_edge = bin_width * floor(rt_min / bin_width)
    end_edge   = bin_width * ceil(rt_max / bin_width)

    if force_end_min is not None:
        end_edge = float(force_end_min)
        if end_edge <= start_edge:
            raise ValueError(f"FORCE_BIN_END_MIN ({force_end_min}) must be > start_edge ({start_edge}).")

    # Edges are start + k*w rather than a running sum: repeatedly adding w drifts
    # (ten additions of 0.1 give 0.9999999999999999) and can create a spurious extra bin.
    # The tiny tolerance absorbs the same rounding error in the bin-count division.
    n_bins = max(int(ceil((end_edge - start_edge) / bin_width - 1e-9)), 0)
    bins = [(start_edge + k * bin_width, start_edge + (k + 1) * bin_width)
            for k in range(n_bins)]
    return bins, rt_min, rt_max

def collect_valid_drifts(bin_df: pd.DataFrame,
                         mz_ref: np.ndarray,
                         rt_ref: np.ndarray,
                         cast_ref: np.ndarray,
                         sim_threshold: float,
                         mz_window: float,
                         target_n: int,
                         sample_with_replacement: bool,
                         random_state: int = 42) -> list:
    """Collect up to ``target_n`` RT drifts (target RT minus matched reference RT) for one bin.

    For a row of ``bin_df``, the candidate references are those with
    ``|mz_ref - m/z| < mz_window``. A candidate matches when
    ``cosine(cast spectra, cast_ref[j]) > sim_threshold``. The row's drift is its RT
    minus the mean RT of all matching references; rows without a match are skipped.

    By default, rows are visited once in shuffled order. When
    ``sample_with_replacement`` is True, rows are drawn at random with replacement, for
    at most ``max(200, 20 * target_n)`` draws. Both modes are seeded with
    ``random_state``, so the result is reproducible.

    Returns
    -------
    list of float
        At most ``target_n`` drift values (empty if ``bin_df`` is empty).
    """
    if bin_df.empty:
        return []

    def drift_for_row(row):
        mz_i   = float(row["m/z"])
        rt_i   = float(row["retntion time"])
        cast_i = row["cast spectra"]

        mask = np.abs(mz_ref - mz_i) < mz_window
        idxs = np.where(mask)[0]
        if idxs.size == 0:
            return None

        match_count = 0
        rt_sum = 0.0
        for j in idxs:
            if cosine(cast_i, cast_ref[j]) > sim_threshold:
                match_count += 1
                rt_sum += rt_ref[j]
        if match_count == 0:
            return None
        return rt_i - (rt_sum / match_count)

    drifts = []
    if sample_with_replacement:
        # One seeded generator shared by all draws: reproducible, yet successive draws differ.
        rng = np.random.RandomState(random_state)
        tries = 0
        max_tries = max(200, target_n * 20)
        while len(drifts) < target_n and tries < max_tries:
            row = bin_df.sample(n=1, replace=True, random_state=rng).iloc[0]
            tries += 1
            d = drift_for_row(row)
            if d is not None:
                drifts.append(d)
        return drifts

    bin_df_shuf = bin_df.sample(frac=1.0, replace=False, random_state=random_state).reset_index(drop=True)
    for _, row in bin_df_shuf.iterrows():
        if len(drifts) >= target_n:
            break
        d = drift_for_row(row)
        if d is not None:
            drifts.append(d)
    return drifts

def compute_drift_table_for_target(df_target: pd.DataFrame,
                                   mz_ref: np.ndarray,
                                   rt_ref: np.ndarray,
                                   cast_ref: np.ndarray) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Per-RT-bin average drift of one target run against the reference run.

    The bins come from ``build_bins_for_target(df_target, BIN_WIDTH, FORCE_BIN_END_MIN)``.
    Each bin is widened by OVERLAP_MIN on both sides, clipped to the run's RT range,
    with an exclusive end. The rows in that window are passed to ``collect_valid_drifts``,
    and the bin's drift is the mean of the returned values (NaN if there are none).

    Returns
    -------
    (result_df, plot_df_valid)
        ``result_df`` has one row per bin (edges, window edges, counts,
        ``avg_rt_drift``, ``bin_center_min``). ``plot_df_valid`` is the subset with a
        finite drift and at least one match. When there are no bins, both are the
        same empty frame.
    """
    bins, rt_min, rt_max = build_bins_for_target(df_target, BIN_WIDTH, FORCE_BIN_END_MIN)

    def _bin_record(t0, t1):
        win_start = max(t0 - OVERLAP_MIN, rt_min)
        win_end = min(t1 + OVERLAP_MIN, rt_max)
        rts = df_target["retntion time"]
        in_window = df_target[(rts >= win_start) & (rts < win_end)].copy()

        drifts = collect_valid_drifts(
            in_window,
            mz_ref=mz_ref, rt_ref=rt_ref, cast_ref=cast_ref,
            sim_threshold=SIM_THRESHOLD,
            mz_window=MZ_WINDOW,
            target_n=TARGET_N,
            sample_with_replacement=SAMPLE_WITH_REPLACEMENT_IF_NEEDED,
        )
        return {
            "bin_start_min": t0,
            "bin_end_min": t1,
            "expanded_start_min": win_start,
            "expanded_end_min": win_end,
            "n_in_expanded_window": len(in_window),
            "n_valid_used": len(drifts),
            "target_n": TARGET_N,
            "avg_rt_drift": float(np.mean(drifts)) if drifts else float("nan"),
        }

    result_df = pd.DataFrame.from_records([_bin_record(t0, t1) for (t0, t1) in bins])
    if result_df.empty:
        return result_df, result_df

    result_df["bin_center_min"] = 0.5 * (result_df["bin_start_min"] + result_df["bin_end_min"])
    has_drift = result_df["avg_rt_drift"].notna() & (result_df["n_valid_used"] > 0)
    return result_df, result_df[has_drift].copy()

def build_alignment_function(plot_df_valid: pd.DataFrame):
    """Return ``f(rt) -> drift`` built from per-bin drifts.

    - no valid bins: zero correction everywhere
    - one bin: that bin's drift everywhere
    - otherwise: linear interpolation over bin centers, held constant beyond the
      first and last center
    """
    if plot_df_valid is None or plot_df_valid.empty:
        def zero_drift(x):
            return np.zeros_like(np.asarray(x, dtype=float))
        return zero_drift

    centers = plot_df_valid["bin_center_min"].to_numpy()
    drifts = plot_df_valid["avg_rt_drift"].to_numpy()
    perm = np.argsort(centers)
    centers, drifts = centers[perm], drifts[perm]

    if centers.size == 1:
        constant = float(drifts[0])

        def constant_drift(rt):
            return np.full_like(np.asarray(rt, dtype=float), constant)
        return constant_drift

    def interpolated_drift(rt):
        return np.interp(np.asarray(rt, dtype=float), centers, drifts,
                         left=drifts[0], right=drifts[-1])
    return interpolated_drift

def align_runs_from_h5(h5_path: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Align the retention times of every run in the MS2 library to a reference run.

    The reference is the first ``sample_name`` in row order of the loaded metadata;
    every other sample is a target. For each target, a per-bin drift table is computed
    against the reference with ``compute_drift_table_for_target``. The valid bins are
    turned into a correction curve with ``build_alignment_function``, and that curve is
    subtracted from the target's RTs. Optional plots are controlled by
    PLOT_DRIFT_CURVES / PLOT_SANITY_AFTER.

    Returns:
      aligned_df   : per-scan DataFrame with rt_correction and rt_aligned
                     (rt_aligned = retntion time - rt_correction; 0 correction for the reference)
      drift_table  : per-bin drift table for all targets, with a ``target_name`` column;
                     an empty DataFrame if no table was collected

    Raises:
      ValueError   : fewer than two distinct samples, or no reference rows
    """
    df = load_h5_build_df(h5_path)

    # reference & targets
    sample_order = df["sample_name"].dropna().unique().tolist()
    if len(sample_order) < 2:
        raise ValueError(f"Need ≥2 samples to align; found {len(sample_order)}: {sample_order}")
    ref_name = sample_order[0]
    target_names = sample_order[1:]

    df_ref = df[df["sample_name"] == ref_name].copy()
    if df_ref.empty:
        raise ValueError(f"No reference rows found for '{ref_name}'.")
    # Reference arrays searched for every target scan (precursor m/z, RT, spectrum)
    mz_ref   = df_ref["m/z"].to_numpy()
    rt_ref   = df_ref["retntion time"].to_numpy()
    cast_ref = df_ref["cast spectra"].to_numpy(object)

    # Start from "no correction" for every row; targets are overwritten below
    df = df.copy()
    df["rt_correction"] = 0.0
    df["rt_aligned"] = df["retntion time"].astype(float)

    all_drifts = []  # collect per-target drift tables

    if PLOT_DRIFT_CURVES:
        plt.figure()
        any_series = False  # NOTE(review): assigned but never read

    for tname in target_names:
        dft = df[df["sample_name"] == tname].copy()
        if dft.empty:
            print(f"Warning: no rows for target '{tname}', skipping.")
            continue

        res_df, plot_df_valid = compute_drift_table_for_target(dft, mz_ref, rt_ref, cast_ref)

        # add target name & collect drift table
        res_df = res_df.copy()
        res_df["target_name"] = tname
        all_drifts.append(res_df)

        # optional: weighted avg summary (bins weighted by their number of matches)
        if not plot_df_valid.empty:
            weights = plot_df_valid["n_valid_used"].to_numpy()
            vals    = plot_df_valid["avg_rt_drift"].to_numpy()
            wavg    = np.average(vals, weights=weights)
            print(f"{tname}: weighted overall avg drift = {wavg:.3f} min "
                  f"(kept {plot_df_valid.shape[0]} bins with ≥1 valid match; TARGET_N={TARGET_N})")
        else:
            print(f"{tname}: no bins with ≥1 valid match.")

        if PLOT_DRIFT_CURVES and not plot_df_valid.empty:
            plt.plot(plot_df_valid["bin_center_min"], plot_df_valid["avg_rt_drift"], marker="o", label=tname)

        # build & apply alignment: aligned RT = raw RT - interpolated drift at that RT
        align_fn = build_alignment_function(plot_df_valid)
        rt_vals = dft["retntion time"].to_numpy(dtype=float)
        corr = align_fn(rt_vals)
        aligned = rt_vals - corr
        # dft keeps df's index labels, so .loc writes back to exactly these rows
        df.loc[dft.index, "rt_correction"] = corr
        df.loc[dft.index, "rt_aligned"] = aligned

    # reference unchanged
    df.loc[df["sample_name"] == ref_name, "rt_correction"] = 0.0
    df.loc[df["sample_name"] == ref_name, "rt_aligned"] = df.loc[df["sample_name"] == ref_name, "retntion time"].astype(float)

    if PLOT_DRIFT_CURVES:
        plt.axhline(0.0, linestyle="--", color="gray")
        plt.axhline(5.0, linestyle="--", alpha=0.6)
        plt.axhline(-5.0, linestyle="--", alpha=0.6)
        plt.xlabel("Retention time (min, bin center)")
        plt.ylabel("Average RT drift vs ref (min)")
        # plt.legend(title="Target samples", fontsize=9)
        plt.grid(True, which="both", linestyle=":", linewidth=0.5)
        plt.tight_layout()
        plt.show()

    # combine drift tables
    drift_table = pd.concat(all_drifts, ignore_index=True) if all_drifts else pd.DataFrame()

    if PLOT_SANITY_AFTER:
        # Median applied correction per 2-min raw-RT bin, one curve per sample
        plt.figure()
        for name in df["sample_name"].dropna().unique().tolist():
            dfx = df[df["sample_name"] == name]
            tmp = dfx[["retntion time", "rt_correction"]].copy()
            tmp["bin"] = (tmp["retntion time"] // 2.0) * 2.0  # 2-min bins
            grp = tmp.groupby("bin", as_index=False)["rt_correction"].median()
            plt.plot(grp["bin"], grp["rt_correction"], marker=".", alpha=0.85, label=name)
        plt.axhline(0.0, linestyle="--", color="gray")
        plt.xlabel("Raw RT (min, 2-min bins)")
        plt.ylabel("Median applied correction (min)")
        # plt.legend(fontsize=8)
        plt.grid(True, linestyle=":", linewidth=0.5)
        plt.tight_layout()
        plt.show()

    return df, drift_table

# =====================
# Run
# =====================
if __name__ == "__main__":
    aligned_df, drift_table = align_runs_from_h5(H5_PATH)

    # Per-scan aligned metadata; the per-row spectra are dropped to keep the CSV small.
    if SAVE_ALIGNED_CSV:
        os.makedirs(os.path.dirname(CSV_OUT_PATH), exist_ok=True)
        per_scan = aligned_df.drop(columns=["cast spectra"], errors="ignore")
        per_scan.to_csv(CSV_OUT_PATH, index=False)
        print(f"Saved aligned metadata to: {CSV_OUT_PATH}")

    # Per-bin drift table (one row per target run and RT bin).
    if SAVE_DRIFTS_CSV:
        os.makedirs(os.path.dirname(DRIFTS_CSV_PATH), exist_ok=True)
        drift_table.to_csv(DRIFTS_CSV_PATH, index=False)
        print(f"Saved per-bin RT drifts to: {DRIFTS_CSV_PATH}")

    # Quick summary: median applied correction per run, in order of first appearance.
    for name in aligned_df["sample_name"].dropna().unique().tolist():
        corrections = aligned_df.loc[aligned_df["sample_name"] == name, "rt_correction"]
        med_corr = float(np.nanmedian(corrections)) if len(corrections) else np.nan
        print(f"{name:30s} median correction: {med_corr: .3f} min")

    print("\nColumns in aligned_df:")
    print("  sample_name, m/z, retntion time, rt_correction, rt_aligned, cast spectra")
    if not drift_table.empty:
        print("\nDrift table columns:")
        print(drift_table.columns.tolist())

# ---- Make runs × bins drift matrix and save to CSV ----
import numpy as np
import pandas as pd
import os

SAVE_DRIFT_MATRIX_CSV = True
DRIFT_MATRIX_CSV_PATH = r"F:/casts/databank/rt_drifts_matrix1.csv"

if not drift_table.empty and SAVE_DRIFT_MATRIX_CSV:
    # Rows = runs (target_name). Columns = bin centers in minutes, rounded to 2 decimals
    # for clean headers. Values = average drift vs the reference run; mean is safe if
    # duplicates ever appear.
    drift_matrix = (
        drift_table
        .assign(bin_center_min=drift_table["bin_center_min"].astype(float).round(2))
        .pivot_table(
            index="target_name",
            columns="bin_center_min",
            values="avg_rt_drift",
            aggfunc="mean",
        )
        .sort_index(axis=1)  # sort bins left→right
    )

    # Add the reference run (first sample of aligned_df, the same choice align_runs_from_h5
    # makes) as an all-zero row, so every run appears in the matrix.
    if "aligned_df" in globals():
        sample_names = aligned_df["sample_name"].dropna().unique().tolist()
        if sample_names and sample_names[0] not in drift_matrix.index:
            drift_matrix.loc[sample_names[0]] = 0.0
            drift_matrix = drift_matrix.sort_index()
    else:
        print("aligned_df is not available: the reference run is not added to the drift matrix.")

    os.makedirs(os.path.dirname(DRIFT_MATRIX_CSV_PATH), exist_ok=True)
    drift_matrix.to_csv(DRIFT_MATRIX_CSV_PATH, float_format="%.5f")
    print(f"Saved drift matrix (runs × bins) to: {DRIFT_MATRIX_CSV_PATH}")



Use the drift table to do the actual quantification

In [None]:
# Free the earlier steps' data before quantification.
# Warning: this deletes every global that does not start with "_", except IPython's own
# entries. get_ipython must survive because every %magic line is executed through it;
# In/Out/exit/quit are IPython built-ins.
_KEEP_NAMES = {"In", "Out", "get_ipython", "exit", "quit"}
for _var in list(globals().keys()):
    if not _var.startswith("_") and _var not in _KEEP_NAMES:
        del globals()[_var]

import gc
gc.collect()


# -*- coding: utf-8 -*-
import os
import re
import numpy as np
import pandas as pd

# ------------------ helpers ------------------

def _resolve_csv(path: str) -> str:
    """Return path (or path.csv) if exists; else raise."""
    if os.path.exists(path):
        return path
    root, ext = os.path.splitext(path)
    if not ext and os.path.exists(path + ".csv"):
        return path + ".csv"
    raise FileNotFoundError(f"Drift file not found: {path}  (also tried {path+'.csv'})")

def _decode_bytes_arr(a):
    """Decode a 1D array of bytes/objects to str objects."""
    if isinstance(a, np.ndarray) and (a.dtype.kind in ("S", "O")):
        out = []
        for x in a:
            if isinstance(x, (bytes, bytearray)):
                try:
                    out.append(x.decode("utf-8"))
                except Exception:
                    out.append(str(x))
            else:
                out.append(str(x))
        return np.array(out, dtype=object)
    return a

def _safe_metadata_from_npz_with_lut(z: np.lib.npyio.NpzFile, n_rows: int) -> pd.DataFrame:
    """
    Robust metadata builder:
      - keep 1D arrays of length n_rows
      - broadcast 0D or length-1 arrays
      - skip other shapes/lengths (e.g., ms2_* if mismatched)
      - build sample_name from file_names_lookup[ms1_file_id] when available
      - set retntion time from ms1_rt

    Parameters:
      z      : opened NPZ archive; every entry except "ms1_matrix" is considered
      n_rows : number of rows of the resulting DataFrame (presumably the MS1 matrix
               row count; confirm at the call site)

    Returns a DataFrame with one row per n_rows, always containing str columns
    "sample_name" and "group_name" and a float "retntion time" column.

    Raises KeyError if no retention-time source column is found.
    """
    cols = {}
    for k in z.files:
        if k == "ms1_matrix":
            continue
        arr = z[k]
        # Keep lookups for later mapping
        if k in ("file_names_lookup", "file_paths_lookup"):
            cols[k] = arr
            continue

        a = np.asarray(arr)
        if a.ndim == 0:
            # scalar entries (e.g. group_name) are repeated for every row
            cols[k] = np.repeat(a.item(), n_rows)
        elif a.ndim == 1:
            if a.shape[0] == n_rows:
                # NOTE(review): an ms2_* array whose length happens to equal n_rows
                # is kept here as well, although it is not aligned to these rows.
                cols[k] = a
            elif a.shape[0] == 1:
                cols[k] = np.repeat(a[0], n_rows)
            else:
                # skip mismatched lengths
                pass
        else:
            # skip 2D+
            pass

    # The lookups are not row-aligned, so they are kept out of the DataFrame
    df = pd.DataFrame({k: cols[k] for k in cols if k not in ("file_names_lookup", "file_paths_lookup")})

    # decode bytes in df columns (new Series gets a default RangeIndex, like df)
    for c in df.columns:
        if df[c].dtype == object or str(df[c].dtype).startswith("|S"):
            df[c] = pd.Series([x.decode("utf-8") if isinstance(x, (bytes, bytearray)) else x for x in df[c]])

    # sample_name via file_names_lookup[ms1_file_id] when possible
    if "ms1_file_id" in df.columns and "file_names_lookup" in cols:
        fid = pd.Series(df["ms1_file_id"]).astype(int).to_numpy()
        names_lut = _decode_bytes_arr(cols["file_names_lookup"])
        names_lut = np.asarray(names_lut, dtype=object)
        # ids outside the lookup keep a synthetic "fid_<id>" name
        fallback = np.array([f"fid_{i}" for i in fid], dtype=object)
        ok = (fid >= 0) & (fid < names_lut.shape[0])
        mapped = fallback.copy()
        mapped[ok] = names_lut[fid[ok]]
        df["sample_name"] = mapped.astype(str)
    else:
        if "file_name" in df.columns:
            df["sample_name"] = df["file_name"].astype(str)
        elif "raw_name" in df.columns:
            df["sample_name"] = df["raw_name"].astype(str)
        elif "run_name" in df.columns:
            df["sample_name"] = df["run_name"].astype(str)
        elif "ms1_file_id" in df.columns:
            df["sample_name"] = ("fid_" + pd.Series(df["ms1_file_id"]).astype(int).astype(str)).astype(str)
        else:
            df["sample_name"] = "UnknownRun"

    if "group_name" not in df.columns:
        df["group_name"] = "Unknown"

    # retention time (legacy spelling); ms1_rt takes priority over generic names
    if "retntion time" not in df.columns:
        if "ms1_rt" in df.columns:
            df["retntion time"] = pd.Series(df["ms1_rt"]).astype(float)
        elif "retention_time" in df.columns:
            df["retntion time"] = pd.Series(df["retention_time"]).astype(float)
        elif "rt" in df.columns:
            df["retntion time"] = pd.Series(df["rt"]).astype(float)
        elif {"rt_min", "rt_max"}.issubset(df.columns):
            df["retntion time"] = (pd.Series(df["rt_min"]).astype(float) + pd.Series(df["rt_max"]).astype(float)) / 2.0
        else:
            raise KeyError("Couldn't infer 'retntion time' (looked for ms1_rt, retention_time, rt, rt_min/rt_max).")

    df["sample_name"] = df["sample_name"].astype(str)
    df["group_name"]  = df["group_name"].astype(str)
    return df

def _make_drift_fn(x_bins, y_drift):
    """
    Build an RT -> drift callable from one drift curve.

    Linear interpolation over (x_bins, y_drift); outside the grid the first/last
    drift value is held constant. A single-point curve becomes a constant.
    ``x_bins`` must be sorted ascending and finite (np.interp requirement).
    """
    xv = np.asarray(x_bins, dtype=float)
    yv = np.asarray(y_drift, dtype=float)
    if xv.size == 1:
        c = float(yv[0])
        return lambda rt: np.full_like(np.asarray(rt, float), c)

    def f(rt):
        rt = np.asarray(rt, float)
        return np.interp(rt, xv, yv, left=yv[0], right=yv[-1])
    return f


def _zero_drift(rt):
    """Default curve: zero drift for every RT (result has the shape of ``rt``)."""
    return np.zeros_like(np.asarray(rt, float))


def _build_align_functions_from_drift(drift_path: str):
    """
    Load per-run RT drift curves and return one interpolating function per run.

    Accepts:
      1) Wide matrix CSV: index=runs, columns=bin centers (minutes)
      2) Long table  CSV: target_name, bin_center_min, avg_rt_drift
         (header names are matched case-insensitively)
    All missing drift values are filled with 0. In the long table, rows whose
    bin center is missing/non-finite are ignored.

    Returns (fns, default_fn):
      fns        -- dict str(run) -> callable(rt) giving the drift at each RT
      default_fn -- the zero-curve, for runs without a drift curve

    Raises KeyError if the file is not a wide matrix and a long-table column is missing.
    """
    p = _resolve_csv(drift_path)

    # --- 1) Wide matrix: every column header must parse as a float bin center ---
    try:
        wide = pd.read_csv(p, index_col=0)
    except Exception:
        # Not readable this way; the long-table read below re-reads the same file
        # and re-raises any genuine I/O/parse error, so nothing is silently lost.
        wide = None

    x_hdr = None
    if wide is not None and len(wide.columns) > 0:
        try:
            x_hdr = np.array([float(c) for c in wide.columns], dtype=float)
        except (TypeError, ValueError):
            x_hdr = None  # non-numeric header -> this is a long table

    if x_hdr is not None:
        order = np.argsort(x_hdr)  # np.interp needs ascending x
        x_all = x_hdr[order]
        # positional reorder, then fill all missing with 0
        wide = wide.iloc[:, order].apply(pd.to_numeric, errors="coerce").fillna(0.0)
        fns = {
            str(run): _make_drift_fn(x_all, row.to_numpy(dtype=float))
            for run, row in wide.iterrows()
        }
        return fns, _zero_drift

    # --- 2) Long table ---
    long = pd.read_csv(p)
    required = ("target_name", "bin_center_min", "avg_rt_drift")
    # normalize headers (case-insensitive match)
    rename = {}
    for need in required:
        if need not in long.columns:
            for c in long.columns:
                if str(c).lower() == need.lower():
                    rename[c] = need
    if rename:
        long = long.rename(columns=rename)
    for need in required:
        if need not in long.columns:
            raise KeyError(f"Drift file missing column '{need}'")

    # Shared grid of bin centers; NaN centers are dropped because np.interp
    # requires a finite, increasing x grid.
    all_bins = np.sort(long["bin_center_min"].astype(float).unique())
    all_bins = all_bins[np.isfinite(all_bins)]
    bin_index = {bx: i for i, bx in enumerate(all_bins)}  # built once for all runs

    fns = {}
    for run, grp in long.groupby("target_name"):
        if all_bins.size == 0:
            # no usable bin centers at all -> treat as zero drift
            fns[str(run)] = _zero_drift
            continue
        y = np.zeros(all_bins.size, dtype=float)  # bins this run lacks stay 0
        x_run = grp["bin_center_min"].astype(float).to_numpy()
        y_run = grp["avg_rt_drift"].astype(float).to_numpy()
        for xr, yr in zip(x_run, y_run):
            i = bin_index.get(xr)
            if i is not None and np.isfinite(yr):
                y[i] = yr
        fns[str(run)] = _make_drift_fn(all_bins, y)

    return fns, _zero_drift

def _sum_rows_chunked(M, idxs, chunk_rows=1024, out_dtype=np.float32):
    """
    Sum the rows of ``M`` selected by the index array ``idxs`` in bounded-size pieces.

    At most ``chunk_rows`` rows are gathered at a time; partial sums are
    accumulated in float64 and the total is cast to ``out_dtype``.
    An empty selection returns a zero vector of length ``M.shape[1]``.
    """
    n_cols = M.shape[1]
    if not idxs.size:
        return np.zeros(n_cols, dtype=out_dtype)

    total = np.zeros(n_cols, dtype=np.float64)
    for lo in range(0, idxs.size, chunk_rows):
        hi = lo + chunk_rows
        picked = M[idxs[lo:hi]]
        total += picked.sum(axis=0, dtype=np.float64)
    return total.astype(out_dtype, copy=False)

# ------------------ main (per-sample) ------------------

def bin_ms1_npz_with_alignment_per_sample(
    npz_path: str,
    drift_path: str,
    out_csv_path: str,
    bin_width: float = 10.0,
    overlap: float = 2.5,
    num_bins: int = 8,
    chunk_rows: int = 1024
) -> str:
    """
    Align per-scan RT using per-run drift curves, then for **each sample_name**
    sum its MS1 spectra into ``num_bins`` windows on the aligned RT axis.

    Window k nominally covers [k*bin_width, (k+1)*bin_width) minutes and is
    widened by ``overlap`` minutes on both sides, so neighbouring windows share
    scans. Runs without a drift curve get zero drift. Writes ONE CSV with
    ``num_bins`` rows per sample and returns its path.

    Parameters
    ----------
    npz_path : .npz containing 'ms1_matrix' (rows = scans) plus the metadata
        consumed by ``_safe_metadata_from_npz_with_lut``.
    drift_path : drift table accepted by ``_build_align_functions_from_drift``.
    out_csv_path : destination CSV; its parent directory is created if needed.
    bin_width, overlap : window width and two-sided widening, in minutes.
    num_bins : number of windows.
    chunk_rows : rows summed per chunk; lower it to reduce peak memory.

    Raises
    ------
    FileNotFoundError if ``npz_path`` does not exist.
    KeyError if the NPZ has no 'ms1_matrix'.
    """
    if not os.path.exists(npz_path):
        raise FileNotFoundError(npz_path)

    # NOTE(review): allow_pickle=True lets a crafted .npz execute code on load;
    # only open NPZs produced by this pipeline. Kept enabled as before in case
    # some files store object arrays -- confirm and drop if not needed.
    with np.load(npz_path, allow_pickle=True) as z:
        if "ms1_matrix" not in z:
            raise KeyError("NPZ must contain 'ms1_matrix'")
        MS1 = z["ms1_matrix"]       # shape (N, L); read into memory, valid after close
        N, L = MS1.shape
        # metadata with sample_name, group_name, retntion time
        metadata = _safe_metadata_from_npz_with_lut(z, n_rows=N)

    # build align functions; runs without a curve fall back to default_fn
    align_fns, default_fn = _build_align_functions_from_drift(drift_path)

    rt_raw = metadata["retntion time"].to_numpy(dtype=float)
    runs   = metadata["sample_name"].astype(str).to_numpy()
    groups = metadata["group_name"].astype(str).to_numpy()

    # scan indices per run, computed once and reused for alignment and binning
    unique_runs = np.unique(runs)
    idx_by_run = {run: np.flatnonzero(runs == run) for run in unique_runs}

    # per-scan aligned RT = raw RT - drift predicted by that run's curve
    rt_corr = np.zeros_like(rt_raw, dtype=float)
    for run, idx_run in idx_by_run.items():
        rt_corr[idx_run] = align_fns.get(run, default_fn)(rt_raw[idx_run])
    rt_aligned = rt_raw - rt_corr

    # Window starts from an integer range: np.arange with a float stop/step can
    # return num_bins + 1 elements because of rounding.
    starts  = np.arange(num_bins, dtype=float) * bin_width
    ends    = starts + bin_width
    centers = 0.5 * (starts + ends)

    # observed aligned-RT range; used only to clip the *reported* window edges
    rt_min = float(np.nanmin(rt_aligned)) if rt_aligned.size else 0.0
    rt_max = float(np.nanmax(rt_aligned)) if rt_aligned.size else 0.0

    cast_cols = [f"cast_{i:05d}" for i in range(L)]
    rows = []

    # ---- PER-SAMPLE LOOP ----
    for run in unique_runs:
        idx_run = idx_by_run[run]

        # group label for this run (assumed constant within a run; np.unique
        # sorts, so a mixed run would take the alphabetically first label)
        grp_vals = np.unique(groups[idx_run])
        group_label = grp_vals[0] if grp_vals.size > 0 else "Unknown"

        rt_run = rt_aligned[idx_run]

        for t0, t1, mid in zip(starts, ends, centers):
            lo = t0 - overlap
            hi = t1 + overlap
            # Membership uses the unclipped window: clipping the exclusive upper
            # edge to rt_max would drop the scan(s) sitting exactly at rt_max.
            mask_local = (rt_run >= lo) & (rt_run < hi)
            win_start = max(lo, rt_min)
            win_end   = min(hi, rt_max)

            idxs = idx_run[mask_local]
            n_scans = int(idxs.size)

            if n_scans > 0:
                vec = _sum_rows_chunked(MS1, idxs, chunk_rows=chunk_rows, out_dtype=np.float32)
                rt_obs_min = float(rt_run[mask_local].min())
                rt_obs_max = float(rt_run[mask_local].max())
            else:
                vec = np.zeros(L, dtype=np.float32)
                rt_obs_min, rt_obs_max = np.nan, np.nan

            # include sample_name so output is per-sample
            rows.append([
                run, group_label,
                t0, t1, win_start, win_end, mid,
                n_scans, rt_obs_min, rt_obs_max
            ] + vec.tolist())

    out_df = pd.DataFrame(
        rows,
        columns=[
            "sample_name", "group_name",
            "rt_start_min","rt_end_min",
            "expanded_start_min","expanded_end_min",
            "rt_center_min","n_scans",
            "rt_aligned_min_obs","rt_aligned_max_obs"
        ] + cast_cols
    )

    # dirname is "" for a bare filename, and os.makedirs("") raises
    out_dir = os.path.dirname(out_csv_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    out_df.to_csv(out_csv_path, index=False)
    return out_csv_path

# ------------------ run ------------------
if __name__ == "__main__":
    # Example run for the TreatmentD databank. NOTE: inside a Jupyter kernel
    # __name__ is "__main__", so this executes every time the cell is run.
    npz_path   = r"F:\casts\databank\TreatmentD.ms1.npz"
    drift_path = r"F:\casts\databank\rt_drifts_matrix"  # auto-tries .csv (path passes through _resolve_csv)
    out_csv    = r"F:\casts\databank\TreatmentD_aligned_bins_per_sample.csv"

    # 8 windows of 10 min on aligned RT, each widened by 2.5 min on both sides
    wrote = bin_ms1_npz_with_alignment_per_sample(
        npz_path=npz_path,
        drift_path=drift_path,
        out_csv_path=out_csv,
        bin_width=10.0,
        overlap=2.5,
        num_bins=8,
        chunk_rows=1024  # lower if you still see MemoryError
    )
    print("Saved:", wrote)


Importing the TDPortal identification report into the databank

In [1]:
def ID_import(tdportal, databank, cast_path):
    """
    Attach TDPortal identifications to every scan of ``databank`` and save it as CSV.

    A databank row (sample_name, scan) matches the first ``tdportal`` row whose
    'File Name' equals sample_name and whose 'Fragment Scans' text contains the
    scan number as one of its integer tokens. Matched rows receive that row's
    'Sequence', 'Average Mass', 'Accession' and 'PFR'; unmatched rows get None.

    Parameters
    ----------
    tdportal : DataFrame with 'File Name' and 'Fragment Scans', plus the four
        result columns (any result column that is absent is left empty and a
        warning is printed).
    databank : DataFrame (or dict of equal-length columns) with 'sample_name'
        and 'scan'. Modified in place: columns 'sequence', 'MASS', 'Accession'
        and 'PFR' are added.
    cast_path : destination CSV path.

    Prints the set of databank sample names that never occur in tdportal.
    Returns an empty tuple (kept for backward compatibility).
    """
    # (databank column, tdportal column)
    result_cols = (
        ('sequence', 'Sequence'),
        ('MASS', 'Average Mass'),
        ('Accession', 'Accession'),
        ('PFR', 'PFR'),
    )

    # {sample name -> {scan number -> position of the first tdportal row listing it}}.
    # Built once so each databank row is an O(1) lookup instead of a linear
    # list search (the databank has ~1e6 rows).
    scan_to_pos = {}
    for pos, (sample, frag_scans) in enumerate(
        zip(tdportal['File Name'], tdportal['Fragment Scans'])
    ):
        per_sample = scan_to_pos.setdefault(sample, {})
        for token in re.findall(r'\d+', str(frag_scans)):
            per_sample.setdefault(int(token), pos)  # first (lowest-row) match wins

    absent = [src for _, src in result_cols if src not in tdportal.columns]
    if absent:
        print(f"Warning: tdportal lacks column(s) {absent}; they will be left empty.")
    # positional arrays, so a matched row's values are read without per-row .at lookups
    source = {
        src: (None if src in absent else tdportal[src].to_numpy())
        for _, src in result_cols
    }

    annotations = {dst: [] for dst, _ in result_cols}
    missing = []
    pairs = zip(databank['sample_name'], databank['scan'])
    for sample, scan in tqdm(pairs, total=len(databank['scan']),
                             desc="Processing scans", ncols=100):
        per_sample = scan_to_pos.get(sample)
        if per_sample is None:
            missing.append(sample)  # sample never appears in the TDPortal report
            pos = None
        else:
            pos = per_sample.get(scan)
        for dst, src in result_cols:
            col = source[src]
            annotations[dst].append(None if pos is None or col is None else col[pos])

    print(set(missing))

    for dst, _ in result_cols:
        databank[dst] = annotations[dst]  # in place, as callers may rely on it

    pd.DataFrame(databank).to_csv(cast_path, index=False)

    return ()


In [3]:
import pandas as pd
import re
from tqdm import tqdm

# Load the TDPortal identification report and the aligned per-scan metadata,
# then annotate every scan and write the result to cast_path.
tdportal = pd.read_csv(r'F:\casts\databank\csv_files\tdportal.csv')
df = pd.read_csv(r'F:\casts\databank\csv_files\aligned_metadata.csv')
cast_path = 'F:/casts/databank/csv_files/databank_pfr.csv'

# NOTE: ID_import also adds the annotation columns to `df` in place.
ID_import(tdportal, df, cast_path)

Processing scans: 100%|█████████████████████████████████| 1219397/1219397 [08:10<00:00, 2486.71it/s]


set()


()

This cell removes rows whose PFR value is missing (NaN or blank) and saves the remaining rows to a new CSV.

In [1]:
import pandas as pd

# Source (annotated databank) and destination (PFR-only copy)
input_path = r'F:/casts/databank/csv_files/databank_pfr.csv'
output_path = r'F:/casts/databank/csv_files/databank_pfr_clean.csv'

df = pd.read_csv(input_path)

# Keep only scans that received a PFR: drop NaN first, then blank/whitespace-only text.
df_clean = (
    df.dropna(subset=["PFR"])
      .loc[lambda frame: frame["PFR"].astype(str).str.strip() != ""]
)

df_clean.to_csv(output_path, index=False)
print(f"✅ Cleaned file saved to: {output_path}")


✅ Cleaned file saved to: F:/casts/databank/csv_files/databank_pfr_clean.csv
