In [None]:
# -- Import
import logging
import os
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn.objects as so
from seaborn import axes_style
from sklearn.model_selection import train_test_split

from winnow.calibration.calibration_features import (
    PrositFeatures,
    MassErrorFeature,
    RetentionTimeFeature,
)
from winnow.calibration.calibrator import ProbabilityCalibrator
from winnow.datasets.calibration_dataset import RESIDUE_MASSES, CalibrationDataset
from winnow.fdr.database_grounded import DatabaseGroundedFDRControl
from winnow.fdr.bayes import EmpiricalBayesFDRControl

# Plot theme: seaborn's "whitegrid" axes style with the grid line style
# overridden to ":".
theme_dict = dict(axes_style("whitegrid"))
theme_dict["grid.linestyle"] = ":"

In [None]:
# -- Set up logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# logging.getLogger returns the same logger object on every call, so re-running
# this cell would otherwise stack another StreamHandler and print every message
# once per re-run. Only attach a handler when none is present yet.
if not logger.handlers:
    logger.addHandler(logging.StreamHandler())

In [None]:
# -- Load data
SPECIES = "wound_fluids"

# Directory containing "<SPECIES>_labelled.ipc" and
# "<SPECIES>_labelled_beam_preds.csv". The default is the location this
# notebook was originally written against; set the WINNOW_DATA_DIR environment
# variable to run it on another machine without editing this cell.
DATA_DIR = Path(os.environ.get("WINNOW_DATA_DIR", "/home/j-daniel/Repos/winnow"))

logger.info("Loading dataset.")
# Paths are passed as plain strings, exactly as before.
dataset = CalibrationDataset.from_predictions_csv(
    spectrum_path=str(DATA_DIR / f"{SPECIES}_labelled.ipc"),
    predictions_path=str(DATA_DIR / f"{SPECIES}_labelled_beam_preds.csv"),
)

logger.info("Filtering dataset.")
# TODO: confirm only low confidence identifications

# Row-level predicates on the "prediction" metadata column, applied one after
# another in this order. The first one guards the membership tests that follow
# against non-list predictions.
# NOTE(review): predicates presumably flag rows to *remove* (the last one
# matches empty predictions) -- confirm against CalibrationDataset.filter_entries.
metadata_predicates = (
    lambda row: not isinstance(row["prediction"], list),
    lambda row: "N(+.98)" in row["prediction"],
    lambda row: "Q(+.98)" in row["prediction"],
    lambda row: not row["prediction"],
)

filtered_dataset = dataset
for metadata_predicate in metadata_predicates:
    filtered_dataset = filtered_dataset.filter_entries(
        metadata_predicate=metadata_predicate,
    )

# Hold out a fixed fraction of the filtered entries for evaluation; the fixed
# random state keeps the split reproducible across runs.
TEST_FRACTION = 0.2
RANDOM_STATE = 42
split_options = {"test_size": TEST_FRACTION, "random_state": RANDOM_STATE}
train, test = train_test_split(filtered_dataset, **split_options)

# Each split is a sequence of (metadata, predictions) pairs; unzip it and wrap
# the two halves back into a CalibrationDataset with a fresh 0..n-1 index.
train_metadata, train_predictions = zip(*train)
train_frame = pd.DataFrame(train_metadata).reset_index(drop=True)
train_dataset = CalibrationDataset(
    metadata=train_frame,
    predictions=list(train_predictions),
)

test_metadata, test_predictions = zip(*test)
test_frame = pd.DataFrame(test_metadata).reset_index(drop=True)
test_dataset = CalibrationDataset(
    metadata=test_frame,
    predictions=list(test_predictions),
)

In [None]:
# -- Set up calibrator
logger.info("Initialising calibrator.")
SEED = 42
calibrator = ProbabilityCalibrator(SEED)

logger.info("Adding features to calibrator.")
# Feature hyper-parameters.
MZ_TOLERANCE = 0.02
HIDDEN_DIM = 10
TRAIN_FRACTION = 0.1

# Features are built and registered one at a time, in this order.
mass_error_feature = MassErrorFeature(residue_masses=RESIDUE_MASSES)
calibrator.add_feature(mass_error_feature)

prosit_features = PrositFeatures(mz_tolerance=MZ_TOLERANCE)
calibrator.add_feature(prosit_features)

retention_time_feature = RetentionTimeFeature(
    hidden_dim=HIDDEN_DIM,
    train_fraction=TRAIN_FRACTION,
)
calibrator.add_feature(retention_time_feature)

# Currently disabled features:
# calibrator.add_feature(ChimericFeatures(mz_tolerance=MZ_TOLERANCE))
# calibrator.add_feature(BeamFeatures())

In [None]:
# Sanity check: number of entries in the held-out test split.
len(test_dataset)

In [None]:
# -- Calibrate
# Fit on the training split only, then score the held-out test split.
logger.info("Calibrating scores.")
calibrator.fit(train_dataset)
# NOTE(review): later cells read a "calibrated_confidence" column from
# test_dataset.metadata, so predict presumably writes it in place -- confirm.
calibrator.predict(test_dataset)

In [None]:
# Inspect the test metadata after prediction; the cells below read its
# "confidence", "calibrated_confidence" and "correct" columns.
test_dataset.metadata

In [None]:
# -- Evaluate and plot
def compute_roc_curve(
    input_dataset: CalibrationDataset,
    confidence_column: str,
    label_column: str,
    name: str,
) -> pd.DataFrame:
    """Compute running precision and recall over rows ranked by confidence.

    Despite the function name, the result is a precision/recall table rather
    than a ROC (TPR vs. FPR) table.

    Args:
        input_dataset: Dataset whose ``metadata`` frame holds the two columns
            named below. The frame itself is not modified.
        confidence_column: Column to rank by, highest value first.
        label_column: Column whose cumulative sum counts the correct rows.
        name: Label written to every row of the output's ``name`` column.

    Returns:
        A frame with a fresh 0..n-1 index and the columns ``precision``
        (cumulative correct / rows seen so far), ``recall`` (cumulative
        correct / total number of rows, not the number of correct rows)
        and ``name``.
    """
    ranked = (
        input_dataset.metadata[[confidence_column, label_column]]
        .copy(deep=True)
        .sort_values(by=confidence_column, ascending=False)
    )
    total_rows = len(ranked)
    # Number of correct rows among the top-k, for k = 1..total_rows.
    hits_so_far = np.cumsum(ranked[label_column])
    rows_so_far = np.arange(1, total_rows + 1)

    curve = pd.DataFrame(
        {
            "precision": hits_so_far / rows_so_far,
            "recall": hits_so_far / total_rows,
        }
    )
    # Drop the shuffled original index left over from the sort.
    return curve.reset_index(drop=True).assign(name=name)

In [None]:
# Build one precision/recall table per confidence score so the raw and the
# calibrated scores can be compared on the same axes.
original = compute_roc_curve(
    input_dataset=test_dataset,
    confidence_column="confidence",
    label_column="correct",
    name="Original",
)
calibrated = compute_roc_curve(
    input_dataset=test_dataset,
    confidence_column="calibrated_confidence",
    label_column="correct",
    name="Calibrated",
)
metrics = pd.concat([original, calibrated], axis=0).reset_index(drop=True)
# FDR at each rank is the complement of the running precision.
metrics["fdr"] = 1 - metrics["precision"]

plot = so.Plot(metrics, x="recall", y="precision", color="name")
plot = plot.add(so.Line(), group="name")
plot = plot.theme(theme_dict)
# The axes are precision vs. recall, so the figure is a precision-recall curve;
# the previous title called it a ROC curve, which it is not.
plot = plot.label(
    y="Precision",
    x="Recall",
    title="Precision-recall curve for original and calibrated confidence",
)
plot

In [None]:
# TODO: quantify confidence shift
# Scatter of raw vs. calibrated confidence, coloured by correctness, with a
# black diagonal from (0, 0) to (1, 1) as the "no change" reference.
data = (
    test_dataset.metadata[["confidence", "calibrated_confidence", "correct"]]
    .copy(deep=True)
    .assign(correct=lambda frame: pd.Categorical(frame["correct"]))
)

# Two-point frame for the reference line; it carries a placeholder "correct"
# value because the plot maps colour to that column.
identity_line = pd.DataFrame(
    {
        "confidence": [0.0, 1.0],
        "calibrated_confidence": [0.0, 1.0],
        "correct": ["Null", "Null"],
    }
)

scatter = so.Plot(data, x="confidence", y="calibrated_confidence", color="correct")
scatter = scatter.add(so.Dot(alpha=0.2))
scatter = scatter.add(so.Line(color="black", linestyle="-"), data=identity_line)
scatter

In [None]:
# Histogram of the raw confidence, split by correctness ("T" / "F").
plot_df = (
    test_dataset.metadata[["confidence", "correct"]]
    .copy(deep=True)
    .assign(
        correct=lambda frame: frame["correct"].apply(
            lambda flag: "T" if flag else "F"
        )
    )
)
histogram = so.Plot(plot_df, "confidence")
histogram.add(so.Bars(), so.Hist(bins=100), color="correct")

In [None]:
# Histogram of the calibrated confidence, split by correctness ("T" / "F").
plot_df = (
    test_dataset.metadata[["calibrated_confidence", "correct"]]
    .copy(deep=True)
    .assign(
        correct=lambda frame: frame["correct"].apply(
            lambda flag: "T" if flag else "F"
        )
    )
)
histogram = so.Plot(plot_df, "calibrated_confidence")
histogram.add(so.Bars(), so.Hist(bins=100), color="correct")

In [None]:
# Database-grounded FDR control fitted on the raw "confidence" column of the
# test metadata.
database_grounded_fdr_control = DatabaseGroundedFDRControl(
    confidence_feature="confidence"
)
database_grounded_fdr_control.fit(
    dataset=test_dataset.metadata,
    residue_masses=RESIDUE_MASSES,
)
# Confidence cutoff for a threshold of 0.05; shown as the cell output.
database_grounded_cutoff = database_grounded_fdr_control.get_confidence_cutoff(
    threshold=0.05
)
database_grounded_cutoff

In [None]:
# Empirical-Bayes FDR control fitted on the raw confidence scores.
mixture_fdr_control = EmpiricalBayesFDRControl()
raw_confidence = test_dataset.metadata["confidence"]
mixture_fdr_control.fit(dataset=raw_confidence)
# Confidence cutoff for a threshold of 0.05; shown as the cell output.
raw_mixture_cutoff = mixture_fdr_control.get_confidence_cutoff(threshold=0.05)
raw_mixture_cutoff

In [None]:
# Empirical-Bayes FDR control fitted on the calibrated confidence scores.
# Note: this rebinds mixture_fdr_control to a new instance.
mixture_fdr_control = EmpiricalBayesFDRControl()
calibrated_confidence = test_dataset.metadata["calibrated_confidence"]
mixture_fdr_control.fit(dataset=calibrated_confidence)
# Confidence cutoff for a threshold of 0.05; shown as the cell output.
calibrated_mixture_cutoff = mixture_fdr_control.get_confidence_cutoff(threshold=0.05)
calibrated_mixture_cutoff