In [None]:
# -- Import
from winnow.calibration.calibrator import ProbabilityCalibrator
from winnow.datasets.calibration_dataset import RESIDUE_MASSES, CalibrationDataset

from winnow.fdr.database_grounded import DatabaseGroundedFDRControl
from winnow.fdr.nonparametric import NonParametricFDRControl

import logging

import numpy as np
import pandas as pd
from pathlib import Path

from sklearn.model_selection import train_test_split

import seaborn.objects as so
from seaborn import axes_style

# Seaborn "whitegrid" axes style with dotted grid lines; passed to `Plot.theme` below.
theme_dict = dict(axes_style("whitegrid"))
theme_dict["grid.linestyle"] = ":"

In [None]:
# -- Set up logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Only attach a handler once: re-running this cell would otherwise stack another
# StreamHandler on the same logger and print every message once per re-run.
if not logger.handlers:
    logger.addHandler(logging.StreamHandler())

In [None]:
# Datasets this notebook can be pointed at. SPECIES selects both the input
# files and the calibrator checkpoint directory used further down.
KNOWN_SPECIES = (
    "gluc",
    "helaqc",
    "herceptin",
    "immuno",
    "sbrodae",
    "snakevenoms",
    "tplantibodies",
    "woundfluids",
)
SPECIES = "immuno"
# Fail fast on a typo instead of with a file-not-found error deep in loading.
assert SPECIES in KNOWN_SPECIES, f"Unknown SPECIES {SPECIES!r}; expected one of {KNOWN_SPECIES}"

In [None]:
# -- Load data
logger.info("Loading dataset.")
dataset = CalibrationDataset.from_predictions_csv(
    spectrum_path=f"../input_data/spectrum_data/labelled/dataset-{SPECIES}-annotated-0000-0001.parquet",
    beam_predictions_path=f"../input_data/beam_preds/labelled/{SPECIES}-annotated_beam_preds.csv",
)

# NOTE(review): each filter_entries call appears to DROP the entries whose
# predicate is true (the predicates would otherwise keep only malformed rows)
# — confirm against CalibrationDataset.filter_entries.
logger.info("Filtering dataset.")
filtered_dataset = (
    dataset.filter_entries(
        metadata_predicate=lambda row: not isinstance(row["prediction"], list),
    )
    .filter_entries(metadata_predicate=lambda row: not row["prediction"])
    .filter_entries(
        metadata_predicate=lambda row: row["precursor_charge"] > 6
    )  # Prosit-specific filtering, see https://github.com/Nesvilab/FragPipe/issues/1775
    .filter_entries(
        metadata_predicate=lambda row: len(row["prediction"]) > 30
    )  # Prosit-specific filtering
    .filter_entries(
        # NOTE(review): row[1] is the second element of each prediction entry,
        # presumably the runner-up beam candidate — confirm.
        predictions_predicate=lambda row: len(row[1].sequence) > 30
    )  # Prosit-specific filtering
)

# Fixed seed so the train/test split is reproducible across runs.
TEST_FRACTION = 0.2
RANDOM_STATE = 42
train, test = train_test_split(
    filtered_dataset, test_size=TEST_FRACTION, random_state=RANDOM_STATE
)


def _entries_to_dataset(entries: list) -> CalibrationDataset:
    """Rebuild a CalibrationDataset from one side of the train/test split.

    Assumes each item of ``entries`` is a ``(metadata_row, prediction)`` pair.
    The metadata rows become a DataFrame with a fresh 0..n-1 index, and the
    predictions become a parallel list in the same order.
    """
    metadata_rows, predictions = zip(*entries)
    return CalibrationDataset(
        metadata=pd.DataFrame(metadata_rows).reset_index(drop=True),
        predictions=list(predictions),
    )


train_dataset = _entries_to_dataset(train)
test_dataset = _entries_to_dataset(test)
assert len(train_dataset) + len(test_dataset) == len(filtered_dataset)

In [None]:
# Sizes of the train and test splits, one per line (train first).
print(len(train_dataset), len(test_dataset), sep="\n")

In [None]:
logger.info("Loading calibrator.")
# One checkpoint directory per species, under ../checkpoints.
calibrator = ProbabilityCalibrator.load(Path("../checkpoints") / SPECIES)

In [None]:
# Apply the loaded calibrator to the held-out test split.
# NOTE(review): later cells read test_dataset.metadata["calibrated_confidence"],
# so predict() presumably adds that column in place — confirm.
logger.info("Calibrating scores.")
calibrator.predict(test_dataset)

In [None]:
# -- Evaluate and plot
def compute_roc_curve(
    input_dataset: pd.DataFrame,
    confidence_column: str,
    label_column: str,
    name: str,
) -> pd.DataFrame:
    """Sweep a confidence threshold and report precision and recall at each cutoff.

    Despite the name, this is a precision/recall curve, not an ROC (TPR vs. FPR)
    curve. Rows are ranked by ``confidence_column`` from highest to lowest, and
    output row ``i`` describes accepting the top ``i + 1`` ranked rows:

    * ``precision``: cumulative count of truthy labels divided by ``i + 1``;
    * ``recall``: the same cumulative count divided by the TOTAL number of rows
      (not by the number of correct rows);
    * ``name``: the constant ``name``, so several curves can be concatenated
      and told apart when plotted.

    Args:
        input_dataset: Frame containing ``confidence_column`` and ``label_column``.
        confidence_column: Column used to rank rows.
        label_column: Column of correctness labels (summed cumulatively).
        name: Value written into every row of the ``name`` column.

    Returns:
        DataFrame with columns ``precision``, ``recall``, ``name`` and a
        0..n-1 index.
    """
    ranked = input_dataset[[confidence_column, label_column]].sort_values(
        by=confidence_column, ascending=False
    )
    n_rows = len(ranked)
    # Running number of correct rows among the top-k, for k = 1..n_rows.
    hits = ranked[label_column].cumsum().to_numpy()
    curve = pd.DataFrame(
        {
            "precision": hits / np.arange(1, n_rows + 1),
            "recall": hits / n_rows,
        }
    )
    return curve.assign(name=name)

In [None]:
# Precision/recall curves for the original and the calibrated confidence on the test split.
original = compute_roc_curve(
    input_dataset=test_dataset.metadata,
    confidence_column="confidence",
    label_column="correct",
    name="Original",
)
calibrated = compute_roc_curve(
    input_dataset=test_dataset.metadata,
    confidence_column="calibrated_confidence",
    label_column="correct",
    name="Calibrated",
)
metrics = pd.concat([original, calibrated], axis=0).reset_index(drop=True)
# False discovery rate at each cutoff is the complement of precision.
metrics["fdr"] = 1 - metrics["precision"]

# The figure plots precision against recall, so it is titled as a
# precision-recall curve rather than an ROC curve.
plot = (
    so.Plot(metrics, x="recall", y="precision", color="name")
    .add(so.Line(), group="name")
    .theme(theme_dict)
    .label(
        y="Precision",
        x="Recall",
        title="Precision-recall curve for original and calibrated confidence",
    )
)
plot

In [None]:
# Histogram of the original confidence on the test split, coloured by correctness.
plot_df = test_dataset.metadata[["confidence", "correct"]].copy(deep=True)
# Short categorical labels so the legend shows T/F instead of a numeric scale.
plot_df["correct"] = plot_df["correct"].apply(lambda x: "T" if x else "F")
(
    so.Plot(plot_df, "confidence")
    .add(so.Bars(), so.Hist(bins=100), color="correct")
    .theme(theme_dict)
    .label(
        x="Confidence",
        y="Count",
        color="Correct",
        title="Original confidence on the test set",
    )
)

In [None]:
# Histogram of the calibrated confidence on the test split, coloured by correctness.
plot_df = test_dataset.metadata[["calibrated_confidence", "correct"]].copy(deep=True)
# Short categorical labels so the legend shows T/F instead of a numeric scale.
plot_df["correct"] = plot_df["correct"].apply(lambda x: "T" if x else "F")
(
    so.Plot(plot_df, "calibrated_confidence")
    .add(so.Bars(), so.Hist(bins=100), color="correct")
    .theme(theme_dict)
    .label(
        x="Calibrated confidence",
        y="Count",
        color="Correct",
        title="Calibrated confidence on the test set",
    )
)

In [None]:
# Database-grounded FDR control on the ORIGINAL confidence: fit on the test
# metadata (with residue masses), then report the cutoff at threshold=0.05.
database_grounded_fdr_control = DatabaseGroundedFDRControl(confidence_feature="confidence")
database_grounded_fdr_control.fit(
    dataset=test_dataset.metadata,
    residue_masses=RESIDUE_MASSES,
)
database_grounded_fdr_control.get_confidence_cutoff(threshold=0.05)

In [None]:
# Non-parametric FDR control fitted on the ORIGINAL confidence column alone,
# then queried for the cutoff at threshold=0.05.
non_parametric_fdr_control = NonParametricFDRControl()
non_parametric_fdr_control.fit(
    dataset=test_dataset.metadata["confidence"],
)
non_parametric_fdr_control.get_confidence_cutoff(threshold=0.05)

In [None]:
# Database-grounded FDR control on the CALIBRATED confidence: fit on the test
# metadata (with residue masses), then report the cutoff at threshold=0.05.
database_grounded_fdr_control = DatabaseGroundedFDRControl(confidence_feature="calibrated_confidence")
database_grounded_fdr_control.fit(
    dataset=test_dataset.metadata,
    residue_masses=RESIDUE_MASSES,
)
database_grounded_fdr_control.get_confidence_cutoff(threshold=0.05)

In [None]:
# Non-parametric FDR control fitted on the CALIBRATED confidence column alone,
# then queried for the cutoff at threshold=0.05. This fitted object is reused
# by the PSM annotation cell below.
non_parametric_fdr_control = NonParametricFDRControl()
non_parametric_fdr_control.fit(
    dataset=test_dataset.metadata["calibrated_confidence"],
)
non_parametric_fdr_control.get_confidence_cutoff(threshold=0.05)

In [None]:
# Annotate the test metadata using the non-parametric control most recently
# fitted on "calibrated_confidence": first add_psm_fdr, then add_psm_pep.
# The PEP step must receive the frame returned by the FDR step; passing
# test_dataset.metadata again would discard the FDR annotation whenever
# add_psm_fdr returns a new frame rather than mutating in place.
test_dataset_metadata = non_parametric_fdr_control.add_psm_fdr(
    test_dataset.metadata, "calibrated_confidence"
)
test_dataset_metadata = non_parametric_fdr_control.add_psm_pep(
    test_dataset_metadata, "calibrated_confidence"
)

test_dataset_metadata