In [None]:
# -- Import
import ast
import glob
import logging
import os

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.calibration import calibration_curve
from sklearn.model_selection import train_test_split

from winnow.datasets.calibration_dataset import CalibrationDataset, RESIDUE_MASSES
from winnow.datasets.data_loaders import InstaNovoDatasetLoader
from winnow.fdr.database_grounded import DatabaseGroundedFDRControl
from winnow.fdr.nonparametric import NonParametricFDRControl
from winnow.scripts.main import (
    filter_dataset,
    initialise_calibrator,
)

In [None]:
# -- Set up logging
# Emit INFO-level progress messages from this notebook through a stream handler.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# logging.getLogger returns the same object on every call, so re-running this
# cell would otherwise stack another handler and print each message repeatedly.
if not logger.handlers:
    logger.addHandler(logging.StreamHandler())

## Load data

In [None]:
# Dataset key interpolated into the input and output paths of the cells below.
SPECIES = "helaqc"  # [gluc, helaqc, herceptin, immuno, sbrodae, snakevenoms, tplantibodies, woundfluids]

In [None]:
# -- Load data
# Labelled spectra and the matching beam-search predictions for the selected
# species. The paths are derived from SPECIES so that this cell stays consistent
# with the output paths used further down.
# NOTE(review): assumes every species uses the same "-0000-0001" shard naming
# as helaqc — confirm before switching SPECIES.
logger.info("Loading dataset.")
dataset = InstaNovoDatasetLoader().load(
    f"../validation_datasets_corrected/spectrum_data/labelled/dataset-{SPECIES}-annotated-0000-0001.parquet",
    f"../validation_datasets_corrected/beam_preds/labelled/{SPECIES}-annotated_beam_preds.csv",
)

logger.info("Filtering dataset.")
filtered_dataset = filter_dataset(dataset)

# Split dataset (fixed seed so the train/test partition is reproducible)
TEST_FRACTION = 0.2
RANDOM_STATE = 42
train, test = train_test_split(
    filtered_dataset, test_size=TEST_FRACTION, random_state=RANDOM_STATE
)

# Each split is a sequence of (metadata row, predictions) pairs; unzip it and
# rebuild a CalibrationDataset with a clean 0..n-1 index.
train_metadata, train_predictions = zip(*train)
train_dataset = CalibrationDataset(
    metadata=pd.DataFrame(train_metadata).reset_index(drop=True),
    predictions=list(train_predictions),
)

test_metadata, test_predictions = zip(*test)
test_dataset = CalibrationDataset(
    metadata=pd.DataFrame(test_metadata).reset_index(drop=True),
    predictions=list(test_predictions),
)

## Train model

In [None]:
# Build a calibrator via the project helper and fit it on the training split.
logger.info("Training calibrator.")
calibrator = initialise_calibrator()
calibrator.fit(train_dataset)

## Evaluate model on labelled test set

In [None]:
logger.info("Calibrating scores.")
# The return value is not captured. Later cells read
# test_dataset.metadata["calibrated_confidence"], so predict presumably writes
# the calibrated scores into the dataset in place — TODO confirm.
calibrator.predict(test_dataset)

## Save metadata

In [None]:
logger.info("Saving evaluation results.")
# to_csv does not create missing parent directories, so make sure the
# per-species results directory exists before writing into it.
os.makedirs(f"../{SPECIES}_results", exist_ok=True)
test_dataset.metadata.to_csv(f"../{SPECIES}_results/test_dataset.csv", index=False)

## Compute FDR metrics on raw confidence

In [None]:
# Fit the non-parametric FDR estimator on the raw confidence scores of the test split.
non_parametric_fdr_control = NonParametricFDRControl()
non_parametric_fdr_control.fit(dataset=test_dataset.metadata["confidence"])

confidence_cutoff = non_parametric_fdr_control.get_confidence_cutoff(threshold=0.05)
logger.info(f"Winnow FDR threshold for raw confidence: {confidence_cutoff}")

# Annotate the metadata with PSM-level FDR, PEP and q-value, in that order.
test_dataset_metadata = test_dataset.metadata
for add_metric in (
    non_parametric_fdr_control.add_psm_fdr,
    non_parametric_fdr_control.add_psm_pep,
    non_parametric_fdr_control.add_psm_qvalue,
):
    test_dataset_metadata = add_metric(test_dataset_metadata, "confidence")

# Save metrics
metric_columns = ["spectrum_id", "psm_fdr", "psm_pep", "psm_qvalue"]
test_dataset_metadata[metric_columns].to_csv(
    f"../{SPECIES}_results/test_dataset_raw_confidence_winnow_fdr.csv", index=False
)

## Compute FDR metrics on calibrated confidence

In [None]:
# Fit the non-parametric FDR estimator on the calibrated confidence scores of the test split.
non_parametric_fdr_control = NonParametricFDRControl()
non_parametric_fdr_control.fit(dataset=test_dataset.metadata["calibrated_confidence"])

confidence_cutoff = non_parametric_fdr_control.get_confidence_cutoff(threshold=0.05)
logger.info(f"Winnow FDR threshold for calibrated confidence: {confidence_cutoff}")

# Annotate the metadata with PSM-level FDR, PEP and q-value, in that order.
test_dataset_metadata = test_dataset.metadata
for add_metric in (
    non_parametric_fdr_control.add_psm_fdr,
    non_parametric_fdr_control.add_psm_pep,
    non_parametric_fdr_control.add_psm_qvalue,
):
    test_dataset_metadata = add_metric(test_dataset_metadata, "calibrated_confidence")

# Save metrics
metric_columns = ["spectrum_id", "psm_fdr", "psm_pep", "psm_qvalue"]
test_dataset_metadata[metric_columns].to_csv(
    f"../{SPECIES}_results/test_dataset_calibrated_confidence_winnow_fdr.csv",
    index=False,
)

## Compute database-grounded FDR metrics on raw confidence

In [None]:
# Fit the database-grounded FDR estimator on the labelled test split, scoring by raw confidence.
database_grounded_fdr_control = DatabaseGroundedFDRControl(
    confidence_feature="confidence"
)
database_grounded_fdr_control.fit(
    dataset=test_dataset.metadata, residue_masses=RESIDUE_MASSES
)

confidence_cutoff = database_grounded_fdr_control.get_confidence_cutoff(threshold=0.05)
logger.info(
    f"Database-grounded FDR threshold for raw confidence: {confidence_cutoff}"
)

# Annotate the metadata with PSM-level FDR and q-value, in that order.
test_dataset_metadata = test_dataset.metadata
for add_metric in (
    database_grounded_fdr_control.add_psm_fdr,
    database_grounded_fdr_control.add_psm_qvalue,
):
    test_dataset_metadata = add_metric(test_dataset_metadata, "confidence")

# Save metrics
metric_columns = ["spectrum_id", "psm_fdr", "psm_qvalue"]
test_dataset_metadata[metric_columns].to_csv(
    f"../{SPECIES}_results/test_dataset_raw_confidence_dbg_fdr.csv", index=False
)

## Compute database-grounded FDR metrics on calibrated confidence

In [None]:
# Fit the database-grounded FDR estimator on the labelled test split, scoring by calibrated confidence.
database_grounded_fdr_control = DatabaseGroundedFDRControl(
    confidence_feature="calibrated_confidence"
)
database_grounded_fdr_control.fit(
    dataset=test_dataset.metadata, residue_masses=RESIDUE_MASSES
)

confidence_cutoff = database_grounded_fdr_control.get_confidence_cutoff(threshold=0.05)
logger.info(
    f"Database-grounded FDR threshold for calibrated confidence: {confidence_cutoff}"
)

# Annotate the metadata with PSM-level FDR and q-value, in that order.
test_dataset_metadata = test_dataset.metadata
for add_metric in (
    database_grounded_fdr_control.add_psm_fdr,
    database_grounded_fdr_control.add_psm_qvalue,
):
    test_dataset_metadata = add_metric(test_dataset_metadata, "calibrated_confidence")

# Save metrics
metric_columns = ["spectrum_id", "psm_fdr", "psm_qvalue"]
test_dataset_metadata[metric_columns].to_csv(
    f"../{SPECIES}_results/test_dataset_calibrated_confidence_dbg_fdr.csv", index=False
)

## Evaluate model on full search space

In [None]:
# -- Load the raw, unlabelled data
# Spectra and beam predictions of the "raw less train" set for the selected species.
logger.info("Loading raw dataset.")
raw_spectrum_path = (
    f"../validation_datasets_corrected/spectrum_data/raw/{SPECIES}_raw_less_train.parquet"
)
raw_beam_preds_path = (
    f"../validation_datasets_corrected/beam_preds/raw/{SPECIES}_raw_less_train.csv"
)
raw_dataset = InstaNovoDatasetLoader().load(raw_spectrum_path, raw_beam_preds_path)

# Apply the same project-level filter that was used for the labelled data.
logger.info("Filtering dataset.")
raw_filtered_dataset = filter_dataset(raw_dataset)

In [None]:
logger.info("Calibrating scores.")
# Return value is not captured; the following cells read
# raw_filtered_dataset.metadata["calibrated_confidence"], so predict presumably
# updates the dataset in place — TODO confirm.
calibrator.predict(raw_filtered_dataset)

## Save metadata

In [None]:
logger.info("Saving evaluation results.")
# Persist the metadata of the raw-less-train set as CSV without the index column.
raw_less_train_csv = f"../{SPECIES}_results/raw_less_train.csv"
raw_filtered_dataset.metadata.to_csv(raw_less_train_csv, index=False)

## Compute FDR metrics on raw confidence

In [None]:
# Fit the non-parametric FDR estimator on the raw confidence scores of the raw-less-train set.
non_parametric_fdr_control = NonParametricFDRControl()
non_parametric_fdr_control.fit(dataset=raw_filtered_dataset.metadata["confidence"])

confidence_cutoff = non_parametric_fdr_control.get_confidence_cutoff(threshold=0.05)
logger.info(f"Winnow FDR threshold for raw confidence: {confidence_cutoff}")

# Annotate the metadata with PSM-level FDR, PEP and q-value, in that order.
raw_filtered_dataset_metadata = raw_filtered_dataset.metadata
for add_metric in (
    non_parametric_fdr_control.add_psm_fdr,
    non_parametric_fdr_control.add_psm_pep,
    non_parametric_fdr_control.add_psm_qvalue,
):
    raw_filtered_dataset_metadata = add_metric(
        raw_filtered_dataset_metadata, "confidence"
    )

# Save metrics
metric_columns = ["spectrum_id", "psm_fdr", "psm_pep", "psm_qvalue"]
raw_filtered_dataset_metadata[metric_columns].to_csv(
    f"../{SPECIES}_results/raw_less_train_dataset_raw_confidence_winnow_fdr.csv",
    index=False,
)

## Compute FDR metrics on calibrated confidence

In [None]:
# Fit the non-parametric FDR estimator on the calibrated confidence scores of the raw-less-train set.
non_parametric_fdr_control = NonParametricFDRControl()
calibrated_scores = raw_filtered_dataset.metadata["calibrated_confidence"]
non_parametric_fdr_control.fit(dataset=calibrated_scores)

confidence_cutoff = non_parametric_fdr_control.get_confidence_cutoff(threshold=0.05)
logger.info(f"Winnow FDR threshold for calibrated confidence: {confidence_cutoff}")

# Annotate the metadata with PSM-level FDR, PEP and q-value, in that order.
raw_filtered_dataset_metadata = raw_filtered_dataset.metadata
for add_metric in (
    non_parametric_fdr_control.add_psm_fdr,
    non_parametric_fdr_control.add_psm_pep,
    non_parametric_fdr_control.add_psm_qvalue,
):
    raw_filtered_dataset_metadata = add_metric(
        raw_filtered_dataset_metadata, "calibrated_confidence"
    )

# Save metrics
metric_columns = ["spectrum_id", "psm_fdr", "psm_pep", "psm_qvalue"]
raw_filtered_dataset_metadata[metric_columns].to_csv(
    f"../{SPECIES}_results/raw_less_train_dataset_calibrated_confidence_winnow_fdr.csv",
    index=False,
)

## Evaluate model on de novo data

In [None]:
# -- Load the de novo data
# NOTE: this rebinds raw_dataset and raw_filtered_dataset, replacing the
# raw-less-train objects loaded in the previous section.
logger.info("Loading de novo dataset.")
de_novo_spectrum_path = (
    f"../validation_datasets_corrected/spectrum_data/de_novo/{SPECIES}_raw_filtered.parquet"
)
de_novo_beam_preds_path = (
    f"../validation_datasets_corrected/beam_preds/de_novo/{SPECIES}_raw_beam_preds_filtered.csv"
)
raw_dataset = InstaNovoDatasetLoader().load(
    de_novo_spectrum_path, de_novo_beam_preds_path
)

# Apply the same project-level filter that was used for the other datasets.
logger.info("Filtering dataset.")
raw_filtered_dataset = filter_dataset(raw_dataset)

In [None]:
logger.info("Calibrating scores.")
# Return value is not captured; the following cells read
# raw_filtered_dataset.metadata["calibrated_confidence"], so predict presumably
# updates the dataset in place — TODO confirm.
calibrator.predict(raw_filtered_dataset)

## Save metadata

In [None]:
logger.info("Saving evaluation results.")
# Persist the metadata of the de novo set as CSV without the index column.
de_novo_csv = f"../{SPECIES}_results/raw_filtered_dataset.csv"
raw_filtered_dataset.metadata.to_csv(de_novo_csv, index=False)

## Compute FDR metrics on raw confidence

In [None]:
# Fit the non-parametric FDR estimator on the raw confidence scores of the de novo set.
non_parametric_fdr_control = NonParametricFDRControl()
non_parametric_fdr_control.fit(dataset=raw_filtered_dataset.metadata["confidence"])

confidence_cutoff = non_parametric_fdr_control.get_confidence_cutoff(threshold=0.05)
logger.info(f"Winnow FDR threshold for raw confidence: {confidence_cutoff}")

# Annotate the metadata with PSM-level FDR, PEP and q-value, in that order.
raw_filtered_dataset_metadata = raw_filtered_dataset.metadata
for add_metric in (
    non_parametric_fdr_control.add_psm_fdr,
    non_parametric_fdr_control.add_psm_pep,
    non_parametric_fdr_control.add_psm_qvalue,
):
    raw_filtered_dataset_metadata = add_metric(
        raw_filtered_dataset_metadata, "confidence"
    )

# Save metrics
metric_columns = ["spectrum_id", "psm_fdr", "psm_pep", "psm_qvalue"]
raw_filtered_dataset_metadata[metric_columns].to_csv(
    f"../{SPECIES}_results/de_novo_dataset_raw_confidence_winnow_fdr.csv", index=False
)

## Compute FDR metrics on calibrated confidence

In [None]:
# Fit the non-parametric FDR estimator on the calibrated confidence scores of the de novo set.
non_parametric_fdr_control = NonParametricFDRControl()
calibrated_scores = raw_filtered_dataset.metadata["calibrated_confidence"]
non_parametric_fdr_control.fit(dataset=calibrated_scores)

confidence_cutoff = non_parametric_fdr_control.get_confidence_cutoff(threshold=0.05)
logger.info(f"Winnow FDR threshold for calibrated confidence: {confidence_cutoff}")

# Annotate the metadata with PSM-level FDR, PEP and q-value, in that order.
raw_filtered_dataset_metadata = raw_filtered_dataset.metadata
for add_metric in (
    non_parametric_fdr_control.add_psm_fdr,
    non_parametric_fdr_control.add_psm_pep,
    non_parametric_fdr_control.add_psm_qvalue,
):
    raw_filtered_dataset_metadata = add_metric(
        raw_filtered_dataset_metadata, "calibrated_confidence"
    )

# Save metrics
metric_columns = ["spectrum_id", "psm_fdr", "psm_pep", "psm_qvalue"]
raw_filtered_dataset_metadata[metric_columns].to_csv(
    f"../{SPECIES}_results/de_novo_dataset_calibrated_confidence_winnow_fdr.csv",
    index=False,
)

## Plot proteome mapping results

This panel illustrates proteome mapping on the helaqc dataset. It contains:

- PR plot (confidence vs calibrated confidence)
- FDR run plot (DBG FDR on confidence and on calibrated confidence; winnow on calibrated confidence)
- Calibration plot (confidence vs calibrated confidence)

In [None]:
# Named hex colours shared by every plot helper in this notebook.
COLORS = dict(
    fairy="#FFCAE9",
    magenta="#8E5572",
    ash="#BBC5AA",
    ebony="#5A6650",
    sky="#7FC8F8",
    navy="#3C81AE",
)

# Dataset key -> human-readable name for plot labels.
SPECIES_NAME_MAPPING = dict(
    gluc="HeLa degradome",
    helaqc="HeLa single shot",
    herceptin="Herceptin",
    immuno="Immunopeptidomics-1",
    sbrodae="Scalindua brodae",
    snakevenoms="Snake venomics",
    woundfluids="Wound exudates",
    PXD014877="C. elegans",
    PXD019483="HepG2",
    PXD023064="Immunopeptidomics-2",
    general="General test set",
)

# Configure matplotlib to mimic seaborn's "paper" plotting context.
# NOTE(review): the font, line and tick sizes below (9.6 / 8.8 / 1.2 / 4.8) look
# like the "paper" context at font_scale=1 rather than font_scale=2 — confirm
# which scale is intended.
plt.style.use("seaborn-v0_8-paper")
plt.rcParams.update(
    {
        # Figure settings
        "figure.figsize": [12.0, 10.0],
        "figure.facecolor": "white",
        "figure.dpi": 100.0,
        # Axes settings
        "axes.labelcolor": ".15",
        "axes.axisbelow": True,
        "axes.grid": False,
        "axes.facecolor": "white",
        "axes.edgecolor": ".15",
        "axes.linewidth": 1.0,
        "axes.spines.left": True,
        "axes.spines.bottom": True,
        "axes.spines.right": True,
        "axes.spines.top": True,
        # Tick settings
        "xtick.direction": "out",
        "ytick.direction": "out",
        "xtick.color": ".15",
        "ytick.color": ".15",
        "xtick.top": False,
        "ytick.right": False,
        "xtick.bottom": False,
        "ytick.left": False,
        "xtick.major.width": 1.0,
        "ytick.major.width": 1.0,
        "xtick.minor.width": 0.8,
        "ytick.minor.width": 0.8,
        "xtick.major.size": 4.8,
        "ytick.major.size": 4.8,
        "xtick.minor.size": 3.2,
        "ytick.minor.size": 3.2,
        # Grid settings
        "grid.linestyle": "-",
        "grid.color": ".8",
        "grid.linewidth": 0.8,
        # Text settings
        "text.color": ".15",
        "font.family": ["sans-serif"],
        "font.sans-serif": [
            "Arial",
            "DejaVu Sans",
            "Liberation Sans",
            "Bitstream Vera Sans",
            "sans-serif",
        ],
        "font.size": 9.6,
        "axes.labelsize": 9.6,
        "axes.titlesize": 9.6,
        "xtick.labelsize": 8.8,
        "ytick.labelsize": 8.8,
        "legend.fontsize": 8.8,
        "legend.title_fontsize": 9.6,
        # Line and patch settings
        "lines.linewidth": 1.2,
        "lines.markersize": 4.8,
        "lines.solid_capstyle": "round",
        "patch.linewidth": 0.8,
        "patch.edgecolor": "black",
        "patch.force_edgecolor": True,
        # Image settings
        # NOTE(review): "rocket" is a seaborn colormap and seaborn is not imported
        # in the visible cells — confirm it is registered before any image plot
        # relies on the default cmap.
        "image.cmap": "rocket",
    }
)

# Print style information (the banner names the style that was actually applied above)
print("Matplotlib seaborn-v0_8-paper Style Characteristics:")
print(f"Figure size: {plt.rcParams['figure.figsize']}")
print(f"DPI: {plt.rcParams['figure.dpi']}")
print(f"Font size: {plt.rcParams['font.size']}")
print(f"Axes line width: {plt.rcParams['axes.linewidth']}")

print("\nKey rcParams settings:")
paper_relevant_params = [
    "figure.figsize",
    "figure.dpi",
    "font.size",
    "axes.linewidth",
    "axes.grid",
    "axes.spines.left",
    "axes.spines.bottom",
    "axes.spines.top",
    "axes.spines.right",
    "xtick.bottom",
    "xtick.top",
    "ytick.left",
    "ytick.right",
    "axes.axisbelow",
    "grid.linewidth",
    "lines.linewidth",
    "patch.linewidth",
    "lines.markersize",
    "axes.titlesize",
    "axes.labelsize",
    "xtick.labelsize",
    "ytick.labelsize",
    "legend.fontsize",
]

for param in paper_relevant_params:
    if param in plt.rcParams:
        print(f"  {param}: {plt.rcParams[param]}")

# mpl.get_backend() reports the rendering backend, not a style sheet, so label it as such.
print(f"\nCurrent backend: {mpl.get_backend()}")
print(
    f"Available styles containing 'seaborn': {[s for s in plt.style.available if 'seaborn' in s]}"
)


def compute_pr_curve(
    input_dataset: pd.DataFrame,
    confidence_column: str,
    label_column: str,
    name: str,
) -> pd.DataFrame:
    """Build a precision-recall curve by ranking rows on a confidence column.

    Rows are ordered from highest to lowest confidence. At each rank,
    precision is the running count of true labels divided by the rank, and
    recall is the same running count divided by the total number of rows.

    Args:
        input_dataset: DataFrame holding the confidence scores and labels
        confidence_column: Column with the confidence scores used for ranking
        label_column: Column with the boolean labels
        name: Value written to the ``name`` column of every returned row

    Returns:
        DataFrame with ``precision``, ``recall`` and ``name`` columns and a
        fresh 0..n-1 index, one row per input row
    """
    ranked = input_dataset[[confidence_column, label_column]].sort_values(
        by=confidence_column, ascending=False
    )
    n_rows = len(ranked)
    hits_so_far = np.cumsum(ranked[label_column])
    ranks = np.arange(1, n_rows + 1)

    curve = pd.DataFrame(
        {"precision": hits_so_far / ranks, "recall": hits_so_far / n_rows}
    ).reset_index(drop=True)
    curve["name"] = name
    return curve


def plot_pr_curve_on_axes(
    metadata: pd.DataFrame,
    ax: plt.Axes,
    title: str = "Precision-Recall Curve",
    label_column: str = "correct",
) -> None:
    """Draw the raw and calibrated precision-recall curves on ``ax``.

    Args:
        metadata: DataFrame with ``confidence``, ``calibrated_confidence`` and
            the label column
        ax: Matplotlib axes to draw on
        title: Axes title
        label_column: Column with the boolean labels
    """
    curve_specs = (
        ("confidence", "Raw confidence"),
        ("calibrated_confidence", "Calibrated confidence"),
    )
    curves = [
        compute_pr_curve(
            input_dataset=metadata,
            confidence_column=score_column,
            label_column=label_column,
            name=curve_name,
        )
        for score_column, curve_name in curve_specs
    ]
    metrics = pd.concat(curves, axis=0).reset_index(drop=True)

    # Sky blue for the raw curve; every other curve (the calibrated one) is ebony.
    palette = {"Raw confidence": COLORS["sky"]}
    for curve_name, curve in metrics.groupby("name"):
        ax.plot(
            curve["recall"],
            curve["precision"],
            label=curve_name,
            color=palette.get(curve_name, COLORS["ebony"]),
            zorder=2,
        )

    # Keep the light-gray grid underneath the curves.
    ax.set_axisbelow(True)
    ax.grid(True, color="lightgray", zorder=0)
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title(title)
    ax.legend()
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)


def plot_confidence_distribution_on_axes(
    metadata: pd.DataFrame,
    ax: plt.Axes,
    confidence_column: str = "confidence",
    title: str = "Confidence Distribution",
    density: bool = False,
    label_column: str = "correct",
) -> None:
    """Draw overlaid confidence histograms for incorrect and correct rows on ``ax``.

    Args:
        metadata: DataFrame with the confidence and label columns
        ax: Matplotlib axes to draw on
        confidence_column: Column with the scores to histogram
        title: Axes title
        density: Passed to ``ax.hist``; also selects the y-axis label
        label_column: Column whose truthiness marks a row as correct
    """
    # Work on a copy so the caller's frame is never modified.
    labelled = metadata[[confidence_column, label_column]].copy(deep=True)
    labelled[label_column] = ["T" if flag else "F" for flag in labelled[label_column]]

    # Incorrect rows are drawn first so the correct ones are layered on top.
    groups = (
        ("F", "Incorrect", COLORS["sky"]),
        ("T", "Correct", COLORS["ebony"]),
    )
    for tag, legend_label, fill_color in groups:
        scores = labelled[labelled[label_column] == tag][confidence_column]
        ax.hist(
            scores,
            bins=50,
            alpha=0.7,
            label=legend_label,
            color=fill_color,
            density=density,
            edgecolor="#333333",
        )

    ax.set_xlabel(confidence_column.replace("_", " ").title())
    ax.set_ylabel("Density" if density else "Frequency")
    ax.set_title(title)
    ax.legend()


def plot_calibration_curve_on_axes(
    metadata: pd.DataFrame,
    ax: plt.Axes,
    confidence_column: str = "confidence",
    title: str = "Confidence Calibration",
    label_column: str = "correct",
) -> None:
    """Draw a 10-bin uniform calibration curve for one confidence column on ``ax``.

    Args:
        metadata: DataFrame with the confidence and label columns
        ax: Matplotlib axes to draw on
        confidence_column: Column with the predicted probabilities; the value
            ``"confidence"`` is styled as raw, anything else as calibrated
        title: Axes title
        label_column: Column with the true labels
    """
    scores = metadata[confidence_column].values
    observed = metadata[label_column].values

    positive_rate, predicted_mean = calibration_curve(
        observed, scores, n_bins=10, strategy="uniform"
    )

    # Raw scores are drawn in sky blue, everything else in ebony.
    is_raw = confidence_column == "confidence"
    curve_color = COLORS["sky"] if is_raw else COLORS["ebony"]
    curve_label = "Raw confidence" if is_raw else "Calibrated confidence"

    ax.plot(
        predicted_mean,
        positive_rate,
        "s-",
        label=curve_label,
        color=curve_color,
        zorder=2,
    )
    # Diagonal reference: predicted probability equals observed frequency.
    ax.plot([0, 1], [0, 1], "k--", label="Perfectly calibrated", alpha=0.5, zorder=2)

    ax.set_axisbelow(True)
    ax.grid(True, color="lightgray", zorder=0)
    ax.set_xlabel("Mean predicted probability")
    ax.set_ylabel("Fraction of positives")
    ax.set_title(title)
    ax.legend()
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])


def plot_combined_calibration_curves(
    metadata: pd.DataFrame,
    title: str = "Confidence Calibration Comparison",
    label_column: str = "correct",
) -> plt.Figure:
    """Create an 8x6 figure with raw and calibrated calibration curves on one axis.

    Args:
        metadata: DataFrame with ``confidence``, ``calibrated_confidence`` and
            the label column
        title: Axes title
        label_column: Column with the true labels

    Returns:
        The newly created figure
    """
    fig, ax = plt.subplots(figsize=(8, 6))

    raw_scores = metadata["confidence"].values
    observed = metadata[label_column].values
    calibrated_scores = metadata["calibrated_confidence"].values

    # (scores, marker/line format, legend label, colour); raw is drawn first.
    series_specs = (
        (raw_scores, "s-", "Raw confidence", COLORS["sky"]),
        (calibrated_scores, "o-", "Calibrated confidence", COLORS["ebony"]),
    )
    for scores, line_format, legend_label, line_color in series_specs:
        positive_rate, predicted_mean = calibration_curve(
            observed, scores, n_bins=10, strategy="uniform"
        )
        ax.plot(
            predicted_mean,
            positive_rate,
            line_format,
            label=legend_label,
            color=line_color,
            zorder=2,
        )

    # Diagonal reference: predicted probability equals observed frequency.
    ax.plot([0, 1], [0, 1], "k--", label="Perfectly calibrated", alpha=0.5, zorder=2)

    ax.set_axisbelow(True)
    ax.grid(True, color="lightgray", zorder=0)

    ax.set_xlabel("Mean predicted probability")
    ax.set_ylabel("Fraction of positives")
    ax.set_title(title)
    ax.legend()
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])

    return fig


def plot_combined_confidence_distributions(
    metadata: pd.DataFrame,
    title: str = "Confidence Distributions",
    label_column: str = "correct",
) -> plt.Figure:
    """Create a 16x6 figure with raw and calibrated confidence histograms side by side.

    Args:
        metadata: DataFrame with ``confidence``, ``calibrated_confidence`` and
            the label column
        title: Figure-level title
        label_column: Column whose truthiness marks a row as correct

    Returns:
        The newly created figure
    """
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    # Left panel: raw scores; right panel: calibrated scores.
    panels = (
        ("confidence", "Raw confidence distribution"),
        ("calibrated_confidence", "Calibrated confidence distribution"),
    )
    for panel_ax, (score_column, panel_title) in zip(axes, panels):
        plot_confidence_distribution_on_axes(
            metadata,
            panel_ax,
            score_column,
            panel_title,
            label_column=label_column,
        )

    # The figure just created by plt.subplots is the current one, so this
    # targets the same figure that pyplot's suptitle would.
    fig.suptitle(title)
    return fig


def get_plot_dataframe(
    features_df: pd.DataFrame,
    winnow_metrics_df: pd.DataFrame,
    decoy_metrics_df: pd.DataFrame,
) -> pd.DataFrame:
    """Join calibrated confidence with both sets of FDR metrics per spectrum.

    The two metrics frames are inner-joined on ``spectrum_id``; columns present
    in both get the suffix ``_dbg`` (decoy frame) or ``_winnow`` (Winnow frame).
    The result is inner-joined onto the ``spectrum_id`` and
    ``calibrated_confidence`` columns of ``features_df``.

    Args:
        features_df: DataFrame with ``spectrum_id`` and ``calibrated_confidence``
        winnow_metrics_df: Winnow metrics keyed by ``spectrum_id``
        decoy_metrics_df: Decoy (database-grounded) metrics keyed by ``spectrum_id``

    Returns:
        The joined DataFrame sorted by ascending ``calibrated_confidence``
    """
    combined_metrics = decoy_metrics_df.merge(
        winnow_metrics_df,
        on="spectrum_id",
        how="inner",
        suffixes=("_dbg", "_winnow"),
    )
    confidence_by_spectrum = features_df[["spectrum_id", "calibrated_confidence"]]
    joined = confidence_by_spectrum.merge(
        combined_metrics, on="spectrum_id", how="inner"
    )
    return joined.sort_values(by="calibrated_confidence")


def plot_fdr_accuracy_on_axes(
    metadata: pd.DataFrame,
    winnow_metrics_df: pd.DataFrame,
    decoy_metrics_df: pd.DataFrame,
    ax: plt.Axes,
    title: str = "FDR Accuracy",
) -> None:
    """Draw Winnow and decoy FDR against calibrated confidence on ``ax``.

    Args:
        metadata: DataFrame with ``spectrum_id`` and ``calibrated_confidence``
        winnow_metrics_df: Winnow metrics with ``spectrum_id`` and ``psm_fdr``
        decoy_metrics_df: Decoy metrics with ``spectrum_id`` and ``psm_fdr``
        ax: Matplotlib axes to draw on
        title: Axes title
    """
    plot_frame = get_plot_dataframe(
        features_df=metadata,
        winnow_metrics_df=winnow_metrics_df,
        decoy_metrics_df=decoy_metrics_df,
    )

    # One line per FDR source: (suffixed column, legend label, colour).
    fdr_lines = (
        ("psm_fdr_winnow", "Winnow FDR", COLORS["sky"]),
        ("psm_fdr_dbg", "Decoy FDR", COLORS["ebony"]),
    )
    for fdr_column, legend_label, line_color in fdr_lines:
        ax.plot(
            plot_frame["calibrated_confidence"],
            plot_frame[fdr_column],
            label=legend_label,
            color=line_color,
            zorder=2,
        )

    # Dashed reference line at FDR = 0.05.
    ax.axhline(y=0.05, color="black", linestyle="--", alpha=0.5, zorder=2)

    ax.set_axisbelow(True)
    ax.grid(True, color="lightgray", zorder=0)
    ax.set_xlabel("Calibrated Confidence")
    ax.set_ylabel("False discovery rate (FDR)")
    ax.set_title(title)
    ax.legend()

def create_fdr_accuracy_plot(
    metadata: pd.DataFrame,
    winnow_metrics_df: pd.DataFrame,
    decoy_metrics_df: pd.DataFrame,
    title: str = "FDR Accuracy",
) -> plt.Figure:
    """Create an 8x6 figure holding the FDR accuracy plot and return it."""
    figure, axes = plt.subplots(figsize=(8, 6))
    plot_fdr_accuracy_on_axes(
        metadata=metadata,
        winnow_metrics_df=winnow_metrics_df,
        decoy_metrics_df=decoy_metrics_df,
        ax=axes,
        title=title,
    )
    return figure


def create_pr_curve_plot(
    metadata: pd.DataFrame,
    title: str = "Precision-Recall Curve",
    label_column: str = "correct",
) -> plt.Figure:
    """Create an 8x6 figure holding the precision-recall plot and return it."""
    figure, axes = plt.subplots(figsize=(8, 6))
    plot_pr_curve_on_axes(
        metadata=metadata, ax=axes, title=title, label_column=label_column
    )
    return figure


def find_data_files(base_dir: str = "new_model/results") -> dict:
    """Find the result CSV files directly inside ``base_dir``, grouped by prefix.

    Files are matched by the patterns ``*.csv`` and ``*.csv.*`` and grouped by
    file-name prefix: ``labelled_``, ``de_novo_`` or ``raw_``. Files with any
    other prefix are ignored.

    Args:
        base_dir: Directory to search (not searched recursively)

    Returns:
        Dictionary with the keys ``"labelled"``, ``"de_novo"`` and ``"raw"``,
        each mapping to a sorted list of unique matching paths
    """
    files: dict[str, list[str]] = {"labelled": [], "de_novo": [], "raw": []}

    # A name such as "x.csv.csv" matches both patterns, so collect the paths in
    # a set to avoid listing the same file twice.
    matched_paths: set[str] = set()
    for pattern in ("*.csv", "*.csv.*"):  # Include files with suffixes
        matched_paths.update(glob.glob(os.path.join(base_dir, pattern)))

    # glob order depends on the filesystem; sort for a reproducible result.
    for file_path in sorted(matched_paths):
        file_name = os.path.basename(file_path)

        if file_name.startswith("labelled_"):
            files["labelled"].append(file_path)
        elif file_name.startswith("de_novo_"):
            files["de_novo"].append(file_path)
        elif file_name.startswith("raw_"):
            files["raw"].append(file_path)

    return files


def extract_dataset_name(file_path: str) -> str:
    """Derive the dataset name from a result file path.

    The base name is matched against the prefixes ``labelled_``, ``de_novo_``
    and ``raw_``. On a match, the prefix is dropped, everything from the first
    ``.csv`` onwards is cut off, and the marker ``_results`` (plus ``_preds``
    for ``de_novo_`` files) is removed wherever it occurs.

    Args:
        file_path: Path to the data file

    Returns:
        The dataset name, or the unchanged base name when no prefix matches
    """
    # Prefix -> markers stripped from the remaining stem, in this order.
    strip_rules = (
        ("labelled_", ("_results",)),
        ("de_novo_", ("_preds", "_results")),
        ("raw_", ("_results",)),
    )

    base_name = os.path.basename(file_path)
    for prefix, markers in strip_rules:
        if not base_name.startswith(prefix):
            continue
        stem = base_name[len(prefix):].split(".csv")[0]
        for marker in markers:
            stem = stem.replace(marker, "")
        return stem

    return base_name


def convert_object_columns(metadata: pd.DataFrame) -> pd.DataFrame:
    """Convert object columns that might contain string representations of Python objects.

    Every string cell in an object-dtype column is passed through
    ``ast.literal_eval``. Cells that are not valid Python literals, and cells
    that are not strings at all (e.g. NaN), are left unchanged.

    Args:
        metadata: Frame whose object columns should be parsed. The frame is
            modified in place.

    Returns:
        The same ``metadata`` object, with its object columns converted.
    """
    # Imported locally so the helper does not depend on another notebook cell
    # having imported ``ast`` first.
    import ast

    def try_convert(value):
        # Only strings can hold a literal representation; skip everything else
        # instead of relying on literal_eval to reject it.
        if not isinstance(value, str):
            return value
        try:
            return ast.literal_eval(value)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # literal_eval documents all of these for malformed input (for
            # example TypeError for "{[1]: 2}"); one bad cell must not abort
            # the conversion of the whole frame.
            return value  # Return original if conversion fails

    # Apply conversion to object (string) columns
    for col in metadata.select_dtypes(include=["object"]).columns:
        metadata[col] = metadata[col].apply(try_convert)

    return metadata

In [None]:
# Load labelled test set
# This cell evaluates against the database-search label (column "correct") and
# writes every figure as both PNG and PDF into the "plots" sub-directory.
metadata_path = f"../{SPECIES}_results/test_dataset.csv"
output_dir = os.path.dirname(metadata_path) + "/plots"
# NOTE(review): basename keeps the ".csv" extension, so every output file name
# starts with "test_dataset.csv_". Left unchanged to avoid renaming outputs.
dataset_name = os.path.basename(metadata_path)
metadata = pd.read_csv(metadata_path)

# FDR metrics for both methods are read from CSV files in the results directory.
winnow_metrics_df = pd.read_csv(
    f"../{SPECIES}_results/test_dataset_calibrated_confidence_winnow_fdr.csv"
)
decoy_metrics_df = pd.read_csv(
    f"../{SPECIES}_results/test_dataset_calibrated_confidence_dbg_fdr.csv"
)

# -- Precision-Recall curve
pr_fig = create_pr_curve_plot(
    metadata,
    "Precision-recall curve for labelled data using database search",
    "correct",
)
pr_fig.savefig(
    os.path.join(
        output_dir, f"{dataset_name}_labelled_precision_recall_with_db_search.png"
    ),
    dpi=300,
    bbox_inches="tight",
)
pr_fig.savefig(
    os.path.join(
        output_dir, f"{dataset_name}_labelled_precision_recall_with_db_search.pdf"
    ),
    dpi=300,
    bbox_inches="tight",
)
plt.show()
plt.close(pr_fig)

# -- Calibration curves
cal_fig = plot_combined_calibration_curves(
    metadata, "Calibration curves for labelled data using database search", "correct"
)
cal_fig.savefig(
    os.path.join(
        output_dir, f"{dataset_name}_labelled_calibration_curves_with_db_search.png"
    ),
    dpi=300,
    bbox_inches="tight",
)
cal_fig.savefig(
    os.path.join(
        output_dir, f"{dataset_name}_labelled_calibration_curves_with_db_search.pdf"
    ),
    dpi=300,
    bbox_inches="tight",
)
plt.show()
plt.close(cal_fig)

# -- FDR accuracy
fdr_fig = create_fdr_accuracy_plot(
    metadata,
    winnow_metrics_df,
    decoy_metrics_df,
    "FDR accuracy for labelled data using database search",
)
fdr_fig.savefig(
    os.path.join(
        output_dir, f"{dataset_name}_labelled_fdr_accuracy_with_db_search.png"
    ),
    dpi=300,
    bbox_inches="tight",
)
fdr_fig.savefig(
    os.path.join(
        output_dir, f"{dataset_name}_labelled_fdr_accuracy_with_db_search.pdf"
    ),
    dpi=300,
    bbox_inches="tight",
)
plt.show()
plt.close(fdr_fig)

# -- Confidence distributions
confidence_fig = plot_combined_confidence_distributions(
    metadata,
    "Confidence distribution for labelled data using database search",
    "correct",
)
# Save the confidence figure itself. Previously these two calls went through
# `fdr_fig`, so the FDR plot was written under the confidence-distribution
# file names and this figure was never saved.
confidence_fig.savefig(
    os.path.join(
        output_dir,
        f"{dataset_name}_labelled_confidence_distributions_with_db_search.png",
    ),
    dpi=300,
    bbox_inches="tight",
)
confidence_fig.savefig(
    os.path.join(
        output_dir,
        f"{dataset_name}_labelled_confidence_distributions_with_db_search.pdf",
    ),
    dpi=300,
    bbox_inches="tight",
)
plt.show()
plt.close(confidence_fig)

In [None]:
# Labelled test set, evaluated against the proteome-mapping label
# (column "proteome_hit"). Each figure is written as PNG and PDF.
metadata_path = f"../{SPECIES}_results/test_dataset.csv"
output_dir = os.path.dirname(metadata_path) + "/plots"
dataset_name = os.path.basename(metadata_path)
metadata = pd.read_csv(metadata_path)

# -- Winnow metrics are read from a CSV in the results directory
winnow_metrics_df = pd.read_csv(
    f"../{SPECIES}_results/test_dataset_calibrated_confidence_winnow_fdr.csv"
)

# -- Decoy (database-grounded) metrics are computed in this cell; only the
#    spectrum id and PSM-level FDR columns are kept.
database_grounded_fdr_control = DatabaseGroundedFDRControl(
    confidence_feature="calibrated_confidence"
)
database_grounded_fdr_control.fit(
    dataset=metadata, residue_masses=RESIDUE_MASSES, correct_column="proteome_hit"
)
decoy_metrics_df = database_grounded_fdr_control.add_psm_fdr(
    metadata, "calibrated_confidence"
)[["spectrum_id", "psm_fdr"]]

# -- Precision-Recall curve
pr_fig = create_pr_curve_plot(
    metadata,
    "Precision-recall curve for labelled data using proteome mapping",
    "proteome_hit",
)
for plot_file in (
    f"{dataset_name}_labelled_precision_recall.png",
    f"{dataset_name}_labelled_precision_recall.pdf",
):
    pr_fig.savefig(os.path.join(output_dir, plot_file), dpi=300, bbox_inches="tight")
plt.show()
plt.close(pr_fig)

# -- Calibration curves
cal_fig = plot_combined_calibration_curves(
    metadata,
    "Calibration curves for labelled data using proteome mapping",
    "proteome_hit",
)
for plot_file in (
    f"{dataset_name}_labelled_calibration_curves.png",
    f"{dataset_name}_labelled_calibration_curves.pdf",
):
    cal_fig.savefig(os.path.join(output_dir, plot_file), dpi=300, bbox_inches="tight")
plt.show()
plt.close(cal_fig)

# -- FDR accuracy
fdr_fig = create_fdr_accuracy_plot(
    metadata,
    winnow_metrics_df,
    decoy_metrics_df,
    "FDR accuracy for labelled data using proteome mapping",
)
for plot_file in (
    f"{dataset_name}_labelled_fdr_accuracy.png",
    f"{dataset_name}_labelled_fdr_accuracy.pdf",
):
    fdr_fig.savefig(os.path.join(output_dir, plot_file), dpi=300, bbox_inches="tight")
plt.show()
plt.close(fdr_fig)

# -- Confidence distributions
confidence_fig = plot_combined_confidence_distributions(
    metadata,
    "Confidence distribution for labelled data using proteome mapping",
    "proteome_hit",
)
for plot_file in (
    f"{dataset_name}_labelled_confidence_distributions.png",
    f"{dataset_name}_labelled_confidence_distributions.pdf",
):
    confidence_fig.savefig(
        os.path.join(output_dir, plot_file), dpi=300, bbox_inches="tight"
    )
plt.show()
plt.close(confidence_fig)

In [None]:
# Raw dataset ("raw_less_train.csv"), evaluated against the proteome-mapping
# label (column "proteome_hit"). Each figure is written as PNG and PDF.
metadata_path = f"../{SPECIES}_results/raw_less_train.csv"
output_dir = os.path.dirname(metadata_path) + "/plots"
dataset_name = os.path.basename(metadata_path)
metadata = pd.read_csv(metadata_path)

# -- Winnow metrics are read from a CSV in the results directory
winnow_metrics_df = pd.read_csv(
    f"../{SPECIES}_results/raw_less_train_dataset_calibrated_confidence_winnow_fdr.csv"
)

# -- Decoy (database-grounded) metrics are computed in this cell; only the
#    spectrum id and PSM-level FDR columns are kept.
database_grounded_fdr_control = DatabaseGroundedFDRControl(
    confidence_feature="calibrated_confidence"
)
database_grounded_fdr_control.fit(
    dataset=metadata, residue_masses=RESIDUE_MASSES, correct_column="proteome_hit"
)
decoy_metrics_df = database_grounded_fdr_control.add_psm_fdr(
    metadata, "calibrated_confidence"
)[["spectrum_id", "psm_fdr"]]

# -- Precision-Recall curve
pr_fig = create_pr_curve_plot(
    metadata,
    "Precision-recall curve for raw data using proteome mapping",
    "proteome_hit",
)
for plot_file in (
    f"{dataset_name}_raw_precision_recall.png",
    f"{dataset_name}_raw_precision_recall.pdf",
):
    pr_fig.savefig(os.path.join(output_dir, plot_file), dpi=300, bbox_inches="tight")
plt.show()
plt.close(pr_fig)

# -- Calibration curves
cal_fig = plot_combined_calibration_curves(
    metadata, "Calibration curves for raw data using proteome mapping", "proteome_hit"
)
for plot_file in (
    f"{dataset_name}_raw_calibration_curves.png",
    f"{dataset_name}_raw_calibration_curves.pdf",
):
    cal_fig.savefig(os.path.join(output_dir, plot_file), dpi=300, bbox_inches="tight")
plt.show()
plt.close(cal_fig)

# -- FDR accuracy
fdr_fig = create_fdr_accuracy_plot(
    metadata,
    winnow_metrics_df,
    decoy_metrics_df,
    "FDR accuracy for raw data using proteome mapping",
)
for plot_file in (
    f"{dataset_name}_raw_fdr_accuracy.png",
    f"{dataset_name}_raw_fdr_accuracy.pdf",
):
    fdr_fig.savefig(os.path.join(output_dir, plot_file), dpi=300, bbox_inches="tight")
plt.show()
plt.close(fdr_fig)

# -- Confidence distributions
confidence_fig = plot_combined_confidence_distributions(
    metadata,
    "Confidence distribution for raw data using proteome mapping",
    "proteome_hit",
)
for plot_file in (
    f"{dataset_name}_raw_confidence_distributions.png",
    f"{dataset_name}_raw_confidence_distributions.pdf",
):
    confidence_fig.savefig(
        os.path.join(output_dir, plot_file), dpi=300, bbox_inches="tight"
    )
plt.show()
plt.close(confidence_fig)

In [None]:
# De novo dataset, evaluated against the proteome-mapping label (column
# "proteome_hit"). Each figure is written as PNG and PDF.
# NOTE(review): the metadata comes from "raw_filtered_dataset.csv" while the
# Winnow metrics come from a "de_novo_dataset_..." file - confirm the two
# files describe the same spectra.
metadata_path = f"../{SPECIES}_results/raw_filtered_dataset.csv"
output_dir = os.path.dirname(metadata_path) + "/plots"
dataset_name = os.path.basename(metadata_path)
metadata = pd.read_csv(metadata_path)

# -- Winnow metrics are read from a CSV in the results directory
winnow_metrics_df = pd.read_csv(
    f"../{SPECIES}_results/de_novo_dataset_calibrated_confidence_winnow_fdr.csv"
)

# -- Decoy (database-grounded) metrics are computed in this cell; only the
#    spectrum id and PSM-level FDR columns are kept.
database_grounded_fdr_control = DatabaseGroundedFDRControl(
    confidence_feature="calibrated_confidence"
)
database_grounded_fdr_control.fit(
    dataset=metadata, residue_masses=RESIDUE_MASSES, correct_column="proteome_hit"
)
decoy_metrics_df = database_grounded_fdr_control.add_psm_fdr(
    metadata, "calibrated_confidence"
)[["spectrum_id", "psm_fdr"]]

# -- Precision-Recall curve (titles italicise "de novo" via mathtext)
pr_fig = create_pr_curve_plot(
    metadata,
    "Precision-recall curve for "
    + r"$\mathit{de\ novo}$"
    + " data using proteome mapping",
    "proteome_hit",
)
for plot_file in (
    f"{dataset_name}_de_novo_precision_recall.png",
    f"{dataset_name}_de_novo_precision_recall.pdf",
):
    pr_fig.savefig(os.path.join(output_dir, plot_file), dpi=300, bbox_inches="tight")
plt.show()
plt.close(pr_fig)

# -- Calibration curves
cal_fig = plot_combined_calibration_curves(
    metadata,
    "Calibration curves for " + r"$\mathit{de\ novo}$" + " data using proteome mapping",
    "proteome_hit",
)
for plot_file in (
    f"{dataset_name}_de_novo_calibration_curves.png",
    f"{dataset_name}_de_novo_calibration_curves.pdf",
):
    cal_fig.savefig(os.path.join(output_dir, plot_file), dpi=300, bbox_inches="tight")
plt.show()
plt.close(cal_fig)

# -- FDR accuracy
fdr_fig = create_fdr_accuracy_plot(
    metadata,
    winnow_metrics_df,
    decoy_metrics_df,
    "FDR accuracy for " + r"$\mathit{de\ novo}$" + " data using proteome mapping",
)
for plot_file in (
    f"{dataset_name}_de_novo_fdr_accuracy.png",
    f"{dataset_name}_de_novo_fdr_accuracy.pdf",
):
    fdr_fig.savefig(os.path.join(output_dir, plot_file), dpi=300, bbox_inches="tight")
plt.show()
plt.close(fdr_fig)

# -- Confidence distributions
confidence_fig = plot_combined_confidence_distributions(
    metadata,
    "Confidence distribution for "
    + r"$\mathit{de\ novo}$"
    + " data using proteome mapping",
    "proteome_hit",
)
for plot_file in (
    f"{dataset_name}_de_novo_confidence_distributions.png",
    f"{dataset_name}_de_novo_confidence_distributions.pdf",
):
    confidence_fig.savefig(
        os.path.join(output_dir, plot_file), dpi=300, bbox_inches="tight"
    )
plt.show()
plt.close(confidence_fig)

In [None]:
# for feature in feature_columns:
#     plt.hist(raw_metadata[raw_metadata["proteome_hit"] == True][feature], label="Unlabelled", color="green", alpha=0.5, bins=50, density=True)
#     plt.hist(labelled_metadata[labelled_metadata["proteome_hit"] == True][feature], label="Labelled", color="red", alpha=0.5, bins=50, density=True)
#     plt.xlabel(feature)
#     plt.ylabel("Density")
#     plt.legend()
#     plt.show()

# FDR Metrics for External Datasets

In [None]:
import ast


def _map_l_to_i_in_sequences(metadata: pd.DataFrame) -> pd.DataFrame:
    """Replace every ``L`` with ``I`` in the ``sequence`` and ``prediction`` columns.

    Cells may be plain strings or lists of tokens; within a list only string
    tokens are rewritten. Any other cell value is passed through unchanged.
    Columns that are absent are skipped. The frame is modified in place and
    also returned.
    """
    logger.info("Mapping L to I in sequences and predictions")

    def _replace_l_with_i(value):
        """Return ``value`` with L -> I applied to a string or to a list's string tokens."""
        if isinstance(value, str):
            return value.replace("L", "I")
        if not isinstance(value, list):
            return value
        swapped = []
        for token in value:
            swapped.append(token.replace("L", "I") if isinstance(token, str) else token)
        return swapped

    target_columns = [
        name for name in ("sequence", "prediction") if name in metadata.columns
    ]
    for name in target_columns:
        metadata[name] = metadata[name].apply(_replace_l_with_i)

    return metadata


def _convert_object_columns(metadata: pd.DataFrame) -> pd.DataFrame:
    """Convert object columns that might contain string representations of Python objects."""

    def try_convert(value):
        try:
            return ast.literal_eval(value)
        except (ValueError, SyntaxError):
            return value  # Return original if conversion fails

    # Apply conversion to object (string) columns
    for col in metadata.select_dtypes(include=["object"]).columns:
        metadata[col] = metadata[col].apply(try_convert)

    return metadata

## C. elegans (PXD014877)

### Labelled

In [None]:
# Read the labelled PXD014877 results, then apply the two notebook helpers.
# Both helpers return the frame they were given.
labelled_df = pd.read_csv("../new_model/results/labelled_PXD014877_results.csv")

# Convert object columns
labelled_df = _convert_object_columns(labelled_df)

# Map L to I in sequences and predictions
labelled_df = _map_l_to_i_in_sequences(labelled_df)

In [None]:
# Fit NonParametricFDRControl on the `confidence` column of labelled_df and log
# the confidence cutoff it returns for threshold=0.05.
non_parametric_fdr_control = NonParametricFDRControl()
non_parametric_fdr_control.fit(dataset=labelled_df["confidence"])

logger.info(
    f"Winnow FDR threshold for raw confidence: {non_parametric_fdr_control.get_confidence_cutoff(threshold=0.05)}"
)

In [None]:
# Fit NonParametricFDRControl on the `calibrated_confidence` column of
# labelled_df and log the confidence cutoff it returns for threshold=0.05.
non_parametric_fdr_control = NonParametricFDRControl()
non_parametric_fdr_control.fit(dataset=labelled_df["calibrated_confidence"])

logger.info(
    f"Winnow FDR threshold for calibrated confidence: {non_parametric_fdr_control.get_confidence_cutoff(threshold=0.05)}"
)

In [None]:
# Fit DatabaseGroundedFDRControl on labelled_df with `confidence` as the
# confidence feature and log the confidence cutoff it returns for threshold=0.05.
database_grounded_fdr_control = DatabaseGroundedFDRControl(
    confidence_feature="confidence"
)
database_grounded_fdr_control.fit(dataset=labelled_df, residue_masses=RESIDUE_MASSES)

logger.info(
    f"Database-grounded FDR threshold for raw confidence: {database_grounded_fdr_control.get_confidence_cutoff(threshold=0.05)}"
)

In [None]:
# Fit DatabaseGroundedFDRControl on labelled_df with `calibrated_confidence` as
# the confidence feature and log the confidence cutoff it returns for
# threshold=0.05.
database_grounded_fdr_control = DatabaseGroundedFDRControl(
    confidence_feature="calibrated_confidence"
)
database_grounded_fdr_control.fit(dataset=labelled_df, residue_masses=RESIDUE_MASSES)

logger.info(
    f"Database-grounded FDR threshold for calibrated confidence: {database_grounded_fdr_control.get_confidence_cutoff(threshold=0.05)}"
)

### Raw

In [None]:
# Read the raw PXD014877 results, then apply the two notebook helpers.
# NOTE(review): the frame is named de_novo_df although the file read is
# raw_PXD014877_results.csv - confirm the naming is intentional.
de_novo_df = pd.read_csv("../new_model/results/raw_PXD014877_results.csv")

# Convert object columns
de_novo_df = _convert_object_columns(de_novo_df)

# Map L to I in sequences and predictions
de_novo_df = _map_l_to_i_in_sequences(de_novo_df)

In [None]:
# Fit NonParametricFDRControl on the `confidence` column of de_novo_df and log
# the confidence cutoff it returns for threshold=0.05.
non_parametric_fdr_control = NonParametricFDRControl()
non_parametric_fdr_control.fit(dataset=de_novo_df["confidence"])

logger.info(
    f"Winnow FDR threshold for raw confidence: {non_parametric_fdr_control.get_confidence_cutoff(threshold=0.05)}"
)

In [None]:
# Fit NonParametricFDRControl on the `calibrated_confidence` column of
# de_novo_df and log the confidence cutoff it returns for threshold=0.05.
non_parametric_fdr_control = NonParametricFDRControl()
non_parametric_fdr_control.fit(dataset=de_novo_df["calibrated_confidence"])

logger.info(
    f"Winnow FDR threshold for calibrated confidence: {non_parametric_fdr_control.get_confidence_cutoff(threshold=0.05)}"
)

## Immuno-2 (PXD023064)

### Labelled

In [None]:
# Read the labelled PXD023064 results, then apply the two notebook helpers.
# This rebinds labelled_df, replacing the frame loaded for the previous dataset.
labelled_df = pd.read_csv("../new_model/results/labelled_PXD023064_results.csv")

# Convert object columns
labelled_df = _convert_object_columns(labelled_df)

# Map L to I in sequences and predictions
labelled_df = _map_l_to_i_in_sequences(labelled_df)

In [None]:
# Fit NonParametricFDRControl on the `confidence` column of labelled_df and log
# the confidence cutoff it returns for threshold=0.05.
non_parametric_fdr_control = NonParametricFDRControl()
non_parametric_fdr_control.fit(dataset=labelled_df["confidence"])

logger.info(
    f"Winnow FDR threshold for raw confidence: {non_parametric_fdr_control.get_confidence_cutoff(threshold=0.05)}"
)

In [None]:
# Fit NonParametricFDRControl on the `calibrated_confidence` column of
# labelled_df and log the confidence cutoff it returns for threshold=0.05.
non_parametric_fdr_control = NonParametricFDRControl()
non_parametric_fdr_control.fit(dataset=labelled_df["calibrated_confidence"])

logger.info(
    f"Winnow FDR threshold for calibrated confidence: {non_parametric_fdr_control.get_confidence_cutoff(threshold=0.05)}"
)

In [None]:
# Fit DatabaseGroundedFDRControl on labelled_df with `confidence` as the
# confidence feature and log the confidence cutoff it returns for threshold=0.05.
database_grounded_fdr_control = DatabaseGroundedFDRControl(
    confidence_feature="confidence"
)
database_grounded_fdr_control.fit(dataset=labelled_df, residue_masses=RESIDUE_MASSES)

logger.info(
    f"Database-grounded FDR threshold for raw confidence: {database_grounded_fdr_control.get_confidence_cutoff(threshold=0.05)}"
)

In [None]:
# Fit DatabaseGroundedFDRControl on labelled_df with `calibrated_confidence` as
# the confidence feature and log the confidence cutoff it returns for
# threshold=0.05.
database_grounded_fdr_control = DatabaseGroundedFDRControl(
    confidence_feature="calibrated_confidence"
)
database_grounded_fdr_control.fit(dataset=labelled_df, residue_masses=RESIDUE_MASSES)

logger.info(
    f"Database-grounded FDR threshold for calibrated confidence: {database_grounded_fdr_control.get_confidence_cutoff(threshold=0.05)}"
)

### Raw

In [None]:
# Read the raw PXD023064 results, then apply the two notebook helpers.
# NOTE(review): the frame is named de_novo_df although the file read is
# raw_PXD023064_results.csv - confirm the naming is intentional.
de_novo_df = pd.read_csv("../new_model/results/raw_PXD023064_results.csv")

# Convert object columns
de_novo_df = _convert_object_columns(de_novo_df)

# Map L to I in sequences and predictions
de_novo_df = _map_l_to_i_in_sequences(de_novo_df)

In [None]:
# Fit NonParametricFDRControl on the `confidence` column of de_novo_df and log
# the confidence cutoff it returns for threshold=0.05.
non_parametric_fdr_control = NonParametricFDRControl()
non_parametric_fdr_control.fit(dataset=de_novo_df["confidence"])

logger.info(
    f"Winnow FDR threshold for raw confidence: {non_parametric_fdr_control.get_confidence_cutoff(threshold=0.05)}"
)

In [None]:
# Fit NonParametricFDRControl on the `calibrated_confidence` column of
# de_novo_df and log the confidence cutoff it returns for threshold=0.05.
non_parametric_fdr_control = NonParametricFDRControl()
non_parametric_fdr_control.fit(dataset=de_novo_df["calibrated_confidence"])

logger.info(
    f"Winnow FDR threshold for calibrated confidence: {non_parametric_fdr_control.get_confidence_cutoff(threshold=0.05)}"
)

## General model test set

In [None]:
# Read the labelled general-model results, then apply the two notebook helpers.
# This rebinds labelled_df, replacing the frame loaded for the previous dataset.
labelled_df = pd.read_csv("../new_model/results/labelled_general_results.csv")

# Convert object columns
labelled_df = _convert_object_columns(labelled_df)

# Map L to I in sequences and predictions
labelled_df = _map_l_to_i_in_sequences(labelled_df)

In [None]:
# Fit NonParametricFDRControl on the `confidence` column of labelled_df and log
# the confidence cutoff it returns for threshold=0.05.
non_parametric_fdr_control = NonParametricFDRControl()
non_parametric_fdr_control.fit(dataset=labelled_df["confidence"])

logger.info(
    f"Winnow FDR threshold for raw confidence: {non_parametric_fdr_control.get_confidence_cutoff(threshold=0.05)}"
)

In [None]:
# Fit NonParametricFDRControl on the `calibrated_confidence` column of
# labelled_df and log the confidence cutoff it returns for threshold=0.05.
non_parametric_fdr_control = NonParametricFDRControl()
non_parametric_fdr_control.fit(dataset=labelled_df["calibrated_confidence"])

logger.info(
    f"Winnow FDR threshold for calibrated confidence: {non_parametric_fdr_control.get_confidence_cutoff(threshold=0.05)}"
)

In [None]:
# Fit DatabaseGroundedFDRControl on labelled_df with `confidence` as the
# confidence feature and log the confidence cutoff it returns for threshold=0.05.
database_grounded_fdr_control = DatabaseGroundedFDRControl(
    confidence_feature="confidence"
)
database_grounded_fdr_control.fit(dataset=labelled_df, residue_masses=RESIDUE_MASSES)

logger.info(
    f"Database-grounded FDR threshold for raw confidence: {database_grounded_fdr_control.get_confidence_cutoff(threshold=0.05)}"
)

In [None]:
# Fit DatabaseGroundedFDRControl on labelled_df with `calibrated_confidence` as
# the confidence feature and log the confidence cutoff it returns for
# threshold=0.05.
database_grounded_fdr_control = DatabaseGroundedFDRControl(
    confidence_feature="calibrated_confidence"
)
database_grounded_fdr_control.fit(dataset=labelled_df, residue_masses=RESIDUE_MASSES)

logger.info(
    f"Database-grounded FDR threshold for calibrated confidence: {database_grounded_fdr_control.get_confidence_cutoff(threshold=0.05)}"
)