In [None]:
# -- Import
from winnow.calibration.calibration_features import (
    PrositFeatures,
    MassErrorFeature,
    RetentionTimeFeature,
    ChimericFeatures,
    BeamFeatures,
)
from winnow.calibration.calibrator import ProbabilityCalibrator
from winnow.datasets.calibration_dataset import RESIDUE_MASSES, CalibrationDataset

from winnow.fdr.database_grounded import DatabaseGroundedFDRControl
from winnow.fdr.nonparametric import NonParametricFDRControl

import logging

import numpy as np
import pandas as pd
import ast

from sklearn.model_selection import train_test_split

import seaborn.objects as so
from seaborn import axes_style

theme_dict = dict(axes_style("whitegrid"))
theme_dict["grid.linestyle"] = ":"  # dotted grid lines on top of the whitegrid style

In [None]:
# -- Set up logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Only attach a handler once: re-running this cell would otherwise stack another
# StreamHandler on the same logger and print every record multiple times.
if not logger.handlers:
    logger.addHandler(logging.StreamHandler())

In [None]:
# Dataset to process. Known values: gluc, helaqc, herceptin, immuno, sbrodae,
# snakevenoms, tplantibodies, woundfluids.
SPECIES = "helaqc"

In [None]:
# Load data: the labelled spectra (parquet) and the raw beam predictions (CSV) for SPECIES.
annotated_spectra_path = (
    f"../input_data/spectrum_data/labelled/dataset-{SPECIES}-annotated-0000-0001.parquet"
)
raw_beam_preds_path = f"../input_data/beam_preds/raw/{SPECIES}_beam_preds.csv"

annotated_data = pd.read_parquet(annotated_spectra_path)
raw_beam_preds = pd.read_csv(raw_beam_preds_path)

In [None]:
# Function to safely convert string representations of Python literals (e.g. lists)
def try_convert(value):
    """Parse ``value`` as a Python literal when possible, otherwise return it unchanged.

    Strings such as ``"['A', 'B']"`` become ``['A', 'B']``. Strings that are not
    valid literals, and any non-string value (NaN, None, numbers), are returned
    as-is.

    Args:
        value: A single cell value, typically from an object-dtype column.

    Returns:
        The parsed literal, or ``value`` itself if it cannot be parsed.
    """
    if not isinstance(value, str):
        # literal_eval only parses text; anything else passes through untouched.
        return value
    try:
        return ast.literal_eval(value)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # literal_eval documents all of these for malformed input; TypeError covers
        # syntactically valid but unhashable literals such as "{[1]: 2}".
        return value  # Return original if conversion fails


# Re-parse every text column so stringified literals (e.g. lists) become Python objects.
for object_column in raw_beam_preds.select_dtypes(include=["object"]).columns:
    parsed_column = raw_beam_preds[object_column].apply(try_convert)
    raw_beam_preds[object_column] = parsed_column

In [None]:
# Check that spectrum_id is a row key
def check_uniqueness(df, name):
    """Raise ``ValueError`` unless every row of ``df`` has a distinct ``spectrum_id``.

    ``nunique`` ignores missing values, so any NaN ``spectrum_id`` also makes the
    check fail. On success, a confirmation line naming ``name`` is printed.
    """
    n_distinct_ids = df["spectrum_id"].nunique()
    if n_distinct_ids != len(df):
        raise ValueError(f"Dataset {name} does not have unique spectrum_id.")
    print(f"{name}: spectrum_id can uniquely identify rows.")


# Both inputs must be keyed one-to-one by spectrum_id before joining them.
for frame, frame_name in (
    (raw_beam_preds, "raw_beam_preds"),
    (annotated_data, "annotated_data"),
):
    check_uniqueness(frame, frame_name)

# Every annotated spectrum needs a beam prediction, otherwise the inner merge would drop it.
missing_scans = set(annotated_data["spectrum_id"]) - set(raw_beam_preds["spectrum_id"])
if missing_scans:
    raise ValueError(
        f"{len(missing_scans)} spectrum_id values in annotated_data are missing from beam_preds."
    )
print("All spectrum_id values in annotated_data are present in beam_preds.")

# Merge datasets. A column present on both sides keeps the annotated copy under its
# original name; the beam-prediction copy gets the "_from_raw" suffix and is dropped.
annotated_beam_preds = annotated_data.merge(
    raw_beam_preds,
    on=["spectrum_id"],
    how="inner",
    suffixes=("", "_from_raw"),
)
shared_columns = [
    column
    for column in raw_beam_preds.columns
    if column in annotated_data.columns and column not in ["spectrum_id"]
]
annotated_beam_preds = annotated_beam_preds.drop(
    columns=[column + "_from_raw" for column in shared_columns]
)

# With unique keys and full coverage, the inner merge must keep every annotated row.
if len(annotated_beam_preds) != len(annotated_data):
    raise ValueError(
        f"Merge conflict: Expected {len(annotated_data)} rows, but got {len(annotated_beam_preds)}."
    )

# Save output
output_path = f"../input_data/beam_preds/labelled/{SPECIES}-annotated_beam_preds.csv"
annotated_beam_preds.to_csv(output_path, index=False)
print(f"Annotated beam predictions saved: {output_path}")

In [None]:
# -- Load data
logger.info("Loading dataset.")
dataset = CalibrationDataset.from_predictions_csv(
    spectrum_path=f"../input_data/spectrum_data/labelled/dataset-{SPECIES}-annotated-0000-0001.parquet",
    beam_predictions_path=f"../input_data/beam_preds/labelled/{SPECIES}-annotated_beam_preds.csv",
)

logger.info("Filtering dataset.")
# The predicates read as drop conditions (presumably filter_entries removes entries for
# which the predicate is True — TODO confirm against CalibrationDataset.filter_entries).
filtered_dataset = dataset.filter_entries(
    metadata_predicate=lambda row: not isinstance(row["prediction"], list),
)
filtered_dataset = filtered_dataset.filter_entries(
    metadata_predicate=lambda row: not row["prediction"]
)
# Prosit-specific filtering, see https://github.com/Nesvilab/FragPipe/issues/1775
filtered_dataset = filtered_dataset.filter_entries(
    metadata_predicate=lambda row: row["precursor_charge"] > 6
)
# Prosit-specific filtering
filtered_dataset = filtered_dataset.filter_entries(
    metadata_predicate=lambda row: len(row["prediction"]) > 30
)
# Prosit-specific filtering
filtered_dataset = filtered_dataset.filter_entries(
    predictions_predicate=lambda row: len(row[1].sequence) > 30
)

TEST_FRACTION = 0.2
RANDOM_STATE = 42
train, test = train_test_split(
    filtered_dataset, test_size=TEST_FRACTION, random_state=RANDOM_STATE
)


def _as_calibration_dataset(split):
    """Rebuild a CalibrationDataset from a sequence of (metadata row, prediction) pairs."""
    metadata_rows, split_predictions = zip(*split)
    return CalibrationDataset(
        metadata=pd.DataFrame(metadata_rows).reset_index(drop=True),
        predictions=list(split_predictions),
    )


train_dataset = _as_calibration_dataset(train)
test_dataset = _as_calibration_dataset(test)

In [None]:
# Number of entries in the train and test splits, one per line.
for split_dataset in (train_dataset, test_dataset):
    print(len(split_dataset))

In [None]:
# -- Set up calibrator
# Hyperparameters, kept together so they are easy to find and tune.
SEED = 42
MZ_TOLERANCE = 0.02  # shared by PrositFeatures and ChimericFeatures
HIDDEN_DIM = 10  # RetentionTimeFeature
TRAIN_FRACTION = 0.1  # RetentionTimeFeature

logger.info("Initialising calibrator.")
calibrator = ProbabilityCalibrator(SEED)

logger.info("Adding features to calibrator.")
calibrator.add_feature(MassErrorFeature(residue_masses=RESIDUE_MASSES))
calibrator.add_feature(PrositFeatures(mz_tolerance=MZ_TOLERANCE))
calibrator.add_feature(
    RetentionTimeFeature(hidden_dim=HIDDEN_DIM, train_fraction=TRAIN_FRACTION)
)
calibrator.add_feature(ChimericFeatures(mz_tolerance=MZ_TOLERANCE))
calibrator.add_feature(BeamFeatures())

In [None]:
# -- Calibrate
# Fit the calibrator on the training split only, then score the held-out test split.
# NOTE(review): predict() appears to update test_dataset in place — the test CSV written
# from it afterwards is read back with a "calibrated_confidence" column. TODO confirm.
logger.info("Calibrating scores.")
calibrator.fit(train_dataset)
calibrator.predict(test_dataset)

In [None]:
# Persist both splits; the evaluation cells below read the test CSV back from disk.
labelled_output_dir = "../calibrated_datasets/labelled"
train_dataset.to_csv(f"{labelled_output_dir}/{SPECIES}_train_labelled.csv")
test_dataset.to_csv(f"{labelled_output_dir}/{SPECIES}_test_labelled.csv")
test_dataset.metadata

In [None]:
# -- Evaluate and plot
def compute_roc_curve(
    metadata_path: str,
    confidence_column: str,
    label_column: str,
    name: str,
) -> pd.DataFrame:
    """Cumulative precision and recall as the confidence threshold is lowered.

    Despite the name, this builds a precision-recall curve, not an ROC curve.
    Rows of the CSV at ``metadata_path`` are ranked by ``confidence_column``
    (highest first). At 1-based rank ``k``:

    * precision = (sum of ``label_column`` over the top ``k`` rows) / ``k``
    * recall    = (sum of ``label_column`` over the top ``k`` rows) / total rows

    With boolean labels, the sum is the number of correct rows.

    Args:
        metadata_path: CSV file (or anything ``pd.read_csv`` accepts).
        confidence_column: Column used to rank rows.
        label_column: Column holding per-row correctness.
        name: Constant written to the ``name`` column (used as the legend key).

    Returns:
        DataFrame with one row per rank and columns ``precision``, ``recall`` and
        ``name``, indexed 0..n-1.
    """
    ranked = pd.read_csv(metadata_path)[[confidence_column, label_column]].sort_values(
        by=confidence_column, ascending=False
    )
    hits_so_far = ranked[label_column].cumsum().to_numpy()
    ranks = np.arange(1, len(ranked) + 1)
    curve = pd.DataFrame(
        {"precision": hits_so_far / ranks, "recall": hits_so_far / len(ranked)}
    )
    curve["name"] = name
    return curve

In [None]:
# Precision-recall curves for the original and calibrated scores on the test split.
test_labelled_path = f"../calibrated_datasets/labelled/{SPECIES}_test_labelled.csv"
original = compute_roc_curve(
    metadata_path=test_labelled_path,
    confidence_column="confidence",
    label_column="correct",
    name="Original",
)
calibrated = compute_roc_curve(
    metadata_path=test_labelled_path,
    confidence_column="calibrated_confidence",
    label_column="correct",
    name="Calibrated",
)
metrics = pd.concat([original, calibrated], axis=0).reset_index(drop=True)
# FDR = 1 - precision; kept on the frame for inspection (not plotted).
metrics["fdr"] = 1 - metrics["precision"]

# Precision is plotted against recall, so this is a precision-recall curve, not an ROC
# curve (ROC would plot true-positive rate against false-positive rate).
plot = (
    so.Plot(metrics, x="recall", y="precision", color="name")
    .add(so.Line(), group="name")
    .theme(theme_dict)
    .label(
        y="Precision",
        x="Recall",
        title="Precision-recall curve for original and calibrated confidence",
    )
)
plot

In [None]:
# Read CSV: the calibrated test split written above. List-valued columns presumably come
# back as text, so object columns are re-parsed into Python literals.
test_dataset_metadata = pd.read_csv(
    f"../calibrated_datasets/labelled/{SPECIES}_test_labelled.csv"
)

# Reuse the try_convert defined in the raw beam-prediction parsing cell instead of
# redefining it here; a second definition would silently shadow the first.
for col in test_dataset_metadata.select_dtypes(include=["object"]).columns:
    test_dataset_metadata[col] = test_dataset_metadata[col].apply(try_convert)

In [None]:
# Calibrated vs. original confidence for each test-split row, coloured by correctness.
# The y = x diagonal is the reference: points above it had their score raised by
# calibration, points below it lowered.
scatter_data = test_dataset_metadata[
    ["confidence", "calibrated_confidence", "correct"]
].copy(deep=True)
scatter_data["correct"] = pd.Categorical(scatter_data["correct"])

# The diagonal layer inherits the plot-level colour mapping on "correct", so its data
# carries a placeholder "Null" value for that column (presumably required for the layer
# to resolve every plot variable — TODO confirm against seaborn.objects).
identity_line = pd.DataFrame(
    {
        "confidence": [0.0, 1.0],
        "calibrated_confidence": [0.0, 1.0],
        "correct": ["Null", "Null"],
    }
)

(
    so.Plot(scatter_data, x="confidence", y="calibrated_confidence", color="correct")
    .add(so.Dot(alpha=0.2))
    .add(so.Line(color="black", linestyle="-"), data=identity_line)
    .theme(theme_dict)
    .label(
        x="Original confidence",
        y="Calibrated confidence",
        color="Correct",
        title="Calibrated vs. original confidence (test split)",
    )
)

In [None]:
# Histogram (100 bins) of the uncalibrated confidence, split by correctness ("T"/"F").
plot_df = test_dataset_metadata.loc[:, ["confidence", "correct"]].copy(deep=True)
plot_df["correct"] = plot_df["correct"].apply(
    lambda is_correct: "T" if is_correct else "F"
)
(
    so.Plot(plot_df, "confidence")
    .add(so.Bars(), so.Hist(bins=100), color="correct")
)

In [None]:
# Histogram (100 bins) of the calibrated confidence, split by correctness ("T"/"F").
plot_df = test_dataset_metadata.loc[:, ["calibrated_confidence", "correct"]].copy(
    deep=True
)
plot_df["correct"] = plot_df["correct"].apply(
    lambda is_correct: "T" if is_correct else "F"
)
(
    so.Plot(plot_df, "calibrated_confidence")
    .add(so.Bars(), so.Hist(bins=100), color="correct")
)

In [None]:
# Database-grounded FDR control on the uncalibrated "confidence" column of the test
# split; shows the cutoff at a 0.05 threshold.
database_grounded_fdr_control = DatabaseGroundedFDRControl(confidence_feature="confidence")
database_grounded_fdr_control.fit(dataset=test_dataset_metadata, residue_masses=RESIDUE_MASSES)
database_grounded_fdr_control.get_confidence_cutoff(threshold=0.05)

In [None]:
# Non-parametric FDR cutoff (threshold 0.05) from the uncalibrated test-split scores.
nonparametric_fdr_control = NonParametricFDRControl()
uncalibrated_scores = test_dataset_metadata["confidence"]
nonparametric_fdr_control.fit(dataset=uncalibrated_scores)
nonparametric_fdr_control.get_confidence_cutoff(threshold=0.05)

In [None]:
# Non-parametric FDR cutoff (threshold 0.05) from the calibrated test-split scores.
nonparametric_fdr_control = NonParametricFDRControl()
calibrated_scores = test_dataset_metadata["calibrated_confidence"]
nonparametric_fdr_control.fit(dataset=calibrated_scores)
nonparametric_fdr_control.get_confidence_cutoff(threshold=0.05)

In [None]:
# The labelled spectra (annotated_data) feed the calibration train/test split above;
# exclude them so the de novo run below only scores the remaining raw spectra.
labelled_ids = annotated_data["spectrum_id"]

raw_data = pd.read_parquet(
    f"../input_data/spectrum_data/raw/dataset-{SPECIES}-raw-0000-0001.parquet"
)

# Exclude rows with spectrum_id in labelled_ids. The results get new names rather than
# overwriting raw_data / raw_beam_preds: overwriting raw_beam_preds would make the
# earlier merge cell fail on re-run, because its missing-spectrum check would trip.
unlabelled_data = raw_data[~raw_data["spectrum_id"].isin(labelled_ids)]
unlabelled_beam_preds = raw_beam_preds[~raw_beam_preds["spectrum_id"].isin(labelled_ids)]

# index=False: after the boolean filter, the index is no longer a clean 0..n-1 range.
# Writing it adds a spurious column (an unnamed one in the CSV), unlike the labelled CSV,
# which is written without an index.
unlabelled_data.to_parquet(
    f"../input_data/spectrum_data/de_novo/{SPECIES}_raw_filtered.parquet", index=False
)
unlabelled_beam_preds.to_csv(
    f"../input_data/beam_preds/de_novo/{SPECIES}_raw_beam_preds_filtered.csv",
    index=False,
)

In [None]:
# -- Load the raw, unlabelled data
logger.info("Loading raw dataset.")
# A distinct name keeps the labelled `dataset` loaded earlier intact instead of
# silently rebinding it to the unlabelled data.
raw_dataset = CalibrationDataset.from_predictions_csv(
    spectrum_path=f"../input_data/spectrum_data/de_novo/{SPECIES}_raw_filtered.parquet",
    beam_predictions_path=f"../input_data/beam_preds/de_novo/{SPECIES}_raw_beam_preds_filtered.csv",
)

logger.info("Filtering dataset.")
# Same filters, in the same order, as for the labelled data. The predicates read as drop
# conditions (presumably filter_entries removes entries for which the predicate is True —
# TODO confirm against CalibrationDataset.filter_entries).
raw_filtered_dataset = (
    raw_dataset.filter_entries(
        metadata_predicate=lambda row: not isinstance(row["prediction"], list),
    )
    .filter_entries(metadata_predicate=lambda row: not row["prediction"])
    .filter_entries(
        metadata_predicate=lambda row: row["precursor_charge"] > 6
    )  # Prosit-specific filtering, see https://github.com/Nesvilab/FragPipe/issues/1775
    .filter_entries(
        metadata_predicate=lambda row: len(row["prediction"]) > 30
    )  # Prosit-specific filtering
    .filter_entries(
        predictions_predicate=lambda row: len(row[1].sequence) > 30
    )  # Prosit-specific filtering
)

In [None]:
# -- Predict on the raw, unlabelled data
# Reuses the calibrator fitted on the labelled training split; no refitting happens here.
calibrator.predict(raw_filtered_dataset)

In [None]:
# Persist the calibrated de novo predictions.
de_novo_output_path = f"../calibrated_datasets/de_novo/{SPECIES}_de_novo_preds.csv"
raw_filtered_dataset.to_csv(de_novo_output_path)

In [None]:
# Score column used for FDR estimation and filtering in the cells below.
# "confidence" (the uncalibrated score) is presumably also available — verify before switching.
confidence_type = "calibrated_confidence"

In [None]:
# Non-parametric FDR fit on the de novo scores; the resulting cutoff is reused below.
nonparametric_fdr_control = NonParametricFDRControl()
de_novo_scores = raw_filtered_dataset.metadata[confidence_type]
nonparametric_fdr_control.fit(dataset=de_novo_scores)
confidence_cutoff = nonparametric_fdr_control.get_confidence_cutoff(threshold=0.05)
confidence_cutoff

In [None]:
# Annotate each de novo PSM with FDR, PEP and p-value, in that order; each helper
# returns the metadata frame that is passed into the next one.
raw_filtered_dataset_metadata = raw_filtered_dataset.metadata
for add_statistic in (
    nonparametric_fdr_control.add_psm_fdr,
    nonparametric_fdr_control.add_psm_pep,
    nonparametric_fdr_control.add_psm_p_value,
):
    raw_filtered_dataset_metadata = add_statistic(
        raw_filtered_dataset_metadata, confidence_type
    )

In [None]:
# Keep only PSMs whose score reaches the FDR confidence cutoff computed above.
passes_cutoff = raw_filtered_dataset_metadata[confidence_type] >= confidence_cutoff
raw_filtered_dataset_metadata.loc[passes_cutoff]