# SLAVseq model report

In [None]:
import pickle, json
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import (
    confusion_matrix,
    ConfusionMatrixDisplay,
    precision_score,
    recall_score,
    precision_recall_curve,
    average_precision_score,
)
from flaml.automl.data import get_output_from_log
from sklearn.model_selection import StratifiedGroupKFold

## Read Data

In [None]:
# read the table and cast the KNRGL column to boolean in one step
data = pd.read_parquet(snakemake.input.data).astype({"KNRGL": bool})  # type: ignore

# unpickle the classifier and take the feature names it exposes
with open(snakemake.input.model, "rb") as model_fh:  # type: ignore
    clf = pickle.load(model_fh)
features = clf.feature_names_in_.tolist()

# parse the json file holding the selected hyperparameters
with open(snakemake.input.best_hp, "r") as hp_fh:  # type: ignore
    config = json.load(hp_fh)

## Hyperparameter tuning

In [None]:
# plot tuning curve
# get_output_from_log returns, in order: the time stamps of the logged trials,
# the best validation loss so far, the per-trial validation loss, the per-trial
# config and the per-trial logged metric.
(
    time_history,
    best_valid_loss_history,
    valid_loss_history,
    config_history,
    metric_history,
) = get_output_from_log(
    filename=snakemake.input.history, time_budget=1e6  # type: ignore
)

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
# The best-so-far loss is piecewise constant between trials, so draw it as a
# step function instead of interpolating linearly between points.
# NOTE(review): 1 - loss is reported as average precision; this assumes the
# tuning metric was average precision - confirm against the training script.
ax.step(time_history, 1 - np.array(best_valid_loss_history), where="post")
# time_history holds wall-clock time, not an iteration counter
ax.set_xlabel("wall clock time (s)")
ax.set_ylabel("average precision")
ax.set_title("Hyperparameter tuning")

print("Best model: ", config["class"])
print("Best hyperparameters: ", config["hyperparameters"])

## Setup CV for model evaluation

In [None]:
# 5-fold splitter stratified on KNRGL and grouped by chromosome.
# NOTE(review): the next cell rebuilds both objects before using them, so this
# cell looks redundant - confirm before removing.
sgkf = StratifiedGroupKFold(
    n_splits=5,
    shuffle=True,
    random_state=snakemake.params.random_state,  # type: ignore
)
splitter = sgkf.split(
    X=data[features],
    y=data["KNRGL"],
    groups=data["Chromosome"],
)

In [None]:
# get CV chromosomes
sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=snakemake.params.random_state)  # type: ignore
splitter = sgkf.split(data[features], data["KNRGL"], groups=data["Chromosome"])
folds = {}

# split() yields POSITIONAL row indices, so every lookup below is positional
# (.iloc / numpy indexing). Label-based .loc would only be correct when the
# frame happens to carry a default 0..n-1 index.
X_all = data[features]
y_all = data["KNRGL"]
chrom_all = data["Chromosome"]

# per-row predicted probabilities, filled fold by fold and attached to `data`
# once the loop is done
proba = {s: np.full(len(data), np.nan) for s in ("train", "test")}

fig, ax1 = plt.subplots(1, 1, figsize=(5, 5))
for i, (train_idx, test_idx) in enumerate(splitter):

    # get chromosomes, add to folds dictionary (fold numbers start at 1)
    train_chroms = chrom_all.iloc[train_idx].unique().tolist()
    test_chroms = chrom_all.iloc[test_idx].unique().tolist()
    folds[i + 1] = (train_chroms, test_chroms)

    # fit classifier, plot feature importances (one strip of points per fold)
    print(f"Training in fold {i+1}/{sgkf.n_splits} on chromosomes: {train_chroms}")
    clf.fit(X_all.iloc[train_idx], y_all.iloc[train_idx])
    sns.stripplot(x=clf.feature_importances_, y=features, ax=ax1, alpha=0.5, color="blue")

    # make predictions
    # NOTE: "train" values are overwritten on every fold, so the stored
    # train_proba of a row comes from the last fold that trained on it.
    for s, idx, chroms in [("train", train_idx, train_chroms), ("test", test_idx, test_chroms)]:
        print(f"Making predictions on {s} chromosomes: {chroms}")
        proba[s][idx] = clf.predict_proba(X_all.iloc[idx])[:, 1]

# attach the collected probabilities as the train_proba / test_proba columns
for s, values in proba.items():
    data[f"{s}_proba"] = values

ax1.set_xlabel("Feature importance")
ax1.set_xlim(0, None)

In [None]:
# define functions to evaluate performance
def prc_cv_plot(data, features, clf, retrain=False):
    """
    Plot precision-recall curve and precision-recall vs. threshold for each fold.

    Fold membership is read from the module-level ``folds`` dictionary, which
    maps a fold number to a ``(train chromosomes, test chromosomes)`` tuple.

    :param data: pd.DataFrame with "Chromosome" and "KNRGL" columns, plus
        "train_proba" and "test_proba" columns (always required when retrain
        is False)
    :param features: list of feature columns passed to the classifier
    :param clf: sklearn classifier; refit in place on every fold when retrain
        is True
    :param retrain: bool, whether to retrain the classifier
    :return: matplotlib figure holding both panels
    """
    # setup figure
    g, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    palette = {"train": "blue", "test": "orange"}

    # held-out labels/scores of every fold, pooled for the AP in the title
    pooled_true, pooled_score = [], []

    # iterate over folds
    for i, (train_chroms, test_chroms) in folds.items():
        test_idx = data[data["Chromosome"].isin(test_chroms)].index
        train_idx = data[data["Chromosome"].isin(train_chroms)].index

        if retrain:  # retrain classifier
            clf.fit(data.loc[train_idx, features], data.loc[train_idx, "KNRGL"])

        # make predictions on train and test set
        for s, idx in [("train", train_idx), ("test", test_idx)]:
            y_true = data.loc[idx, "KNRGL"]
            if retrain:
                y_score = clf.predict_proba(data.loc[idx, features])[:, 1]
            else:
                y_score = data.loc[idx, f"{s}_proba"]

            if s == "test":
                pooled_true.append(np.asarray(y_true))
                pooled_score.append(np.asarray(y_score))

            prec, rec, thresh = precision_recall_curve(y_true, y_score)

            # only fold 1 carries a legend label, so each set is listed once
            label = s if i == 1 else None
            color = palette[s]

            ax1.plot(rec, prec, color=color, label=label)
            # solid line: precision, dashed line: recall
            ax2.plot(thresh, prec[:-1], color=color, label=label)
            ax2.plot(thresh, rec[:-1], color=color, label=label, linestyle="--")

    if retrain and pooled_score:
        # report the AP of the retrained predictions that were actually
        # plotted, not of the stored test_proba column
        ap = average_precision_score(
            np.concatenate(pooled_true), np.concatenate(pooled_score)
        )
    else:
        ap = average_precision_score(data["KNRGL"], data["test_proba"])
    ax1.set(
        xlabel="Recall",
        ylabel="Precision",
        title=f"Precision-Recall curve: AP={ap:.3f}",
        xlim=(0, 1),
        ylim=(0, 1),
    )

    # add chance line (precision of a random classifier = positive fraction)
    ax1.plot(
        [0, 1], [data["KNRGL"].mean()] * 2, linestyle="--", color="gray", label="chance"
    )
    ax1.legend()
    ax2.set(
        xlabel="Threshold",
        ylabel="Score",
        title="Precision-Recall vs. threshold",
        xlim=(0, 1),
        ylim=(0, 1),
    )

    return g


def cm_plot(data: pd.DataFrame, thresholds=(0.5, 0.6, 0.7, 0.8, 0.9)):
    """
    Plot confusion matrices for different thresholds
    :param data: pd.DataFrame with a "KNRGL" column (true class) and a
        "test_proba" column (predicted probability)
    :param thresholds: probability cutoffs, one panel each; a row is called
        positive when test_proba is strictly greater than the cutoff
    :return: matplotlib figure with one confusion matrix per threshold
    :raises ValueError: if thresholds is empty
    """
    thresholds = list(thresholds)
    if not thresholds:
        raise ValueError("thresholds must contain at least one value")

    # setup figure; squeeze=False keeps `axes` 2-D even for a single threshold
    g, axes = plt.subplots(
        1,
        len(thresholds),
        figsize=(5 * len(thresholds), 5),
        sharey=True,
        sharex=True,
        squeeze=False,
    )

    y_true = data["KNRGL"]

    # iterate over thresholds
    for ax, p in zip(axes[0], thresholds):
        y_pred = data["test_proba"] > p

        # fix the label order so the matrix is always 2x2 and matches the
        # display labels, even when one class is absent from the data
        cm = confusion_matrix(y_true, y_pred, labels=[False, True])

        # plot without a colorbar
        ConfusionMatrixDisplay(cm, display_labels=["OTHER", "KNRGL"]).plot(
            ax=ax, colorbar=False
        )

        # zero_division=0 reports 0 without a warning when nothing is
        # predicted positive (or there are no true positives)
        precision = precision_score(y_true, y_pred, zero_division=0)
        recall = recall_score(y_true, y_pred, zero_division=0)
        ax.set_title(f"{p}: precision = {precision:.3f}, recall = {recall:.3f}")
    return g

## Precision/Recall

In [None]:
# draw the per-fold precision/recall panels from the stored probabilities,
# then the confusion matrices; `g` ends up bound to the confusion-matrix figure
g = prc_cv_plot(data=data, features=features, clf=clf, retrain=False)
g = cm_plot(data=data)

## Precision/Recall after removing rows within a germline distance cutoff

In [None]:
# TODO: add recall penalty for removing these
# TODO: test with and without retraining
# NOTE(review): the counts below filter on a `label` column, while the rest of
# the notebook uses the boolean `KNRGL` column - confirm both exist and agree.

# number of KNRGL rows before filtering; it does not depend on the cutoff
total = data.query("label == 'KNRGL'").shape[0]

for gd in [5000, 10000, 20000]:
    # KNRGL rows that the cutoff removes
    nrm = data.query("germline_dist <= @gd and label == 'KNRGL'").shape[0]
    # guard against a table without any KNRGL rows
    pct = (nrm / total) * 100 if total else 0.0
    removed = f"{nrm}/{total} ({pct:.2f}%) germline variants with distance <= {gd}bp"
    print(f"Removing {removed}")

    # keep rows beyond the cutoff; the fresh 0..n-1 index is what the plotting
    # helpers use for their label-based lookups
    df = data.query("germline_dist > @gd").reset_index(drop=True)
    g = prc_cv_plot(df, features, clf, retrain=False)
    g.suptitle(f"Performance after removing {removed}")
    g = cm_plot(df)

In [None]:
# write the table to the first declared workflow output
output_path = snakemake.output[0]  # type: ignore
data.to_parquet(output_path)