In [None]:
import pickle
import shutil

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# ML
import shap
from flaml import AutoML
from flaml import tune
from flaml.automl.data import get_output_from_log
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import (
    confusion_matrix,
    ConfusionMatrixDisplay,
    precision_score,
    recall_score,
    precision_recall_curve,
    average_precision_score,
)

## Define model and hyperparameter space

In [None]:
# Custom search space for FLAML's "xgboost" learner: each hyperparameter gets a
# sampling domain plus a low-cost starting value for the search.
custom_hp = {
    "xgboost": dict(
        n_estimators=dict(
            domain=tune.randint(lower=100, upper=500),
            low_cost_init_value=250,
        ),
        max_leaves=dict(
            domain=tune.lograndint(lower=6, upper=1024),
            low_cost_init_value=50,
        ),
        max_depth=dict(
            domain=tune.randint(lower=3, upper=100),
            low_cost_init_value=5,
        ),
    )
}

# If `nvidia-smi` is found on PATH, treat that as "a GPU is available" and pin
# the GPU-related XGBoost parameters by giving each a single-choice domain.
if shutil.which("nvidia-smi"):
    custom_hp["xgboost"].update(
        {
            param: {"init_value": value, "domain": tune.choice([value])}
            for param, value in (("tree_method", "gpu_hist"), ("device", "cuda"))
        }
    )

## Define features

In [None]:
# Load the input table; it must contain every feature column listed below.
data = pd.read_parquet(snakemake.input.data)  # type: ignore

# The feature list is supplied through the workflow parameters.
features = snakemake.params.features

# Fail fast on a bad feature list: each name must be a column of `data` ...
for feature_name in features:
    assert (
        feature_name in data.columns
    ), f"Feature {feature_name} not in dataframe columns"

# ... and no name may be listed twice.
n_features = len(features)
assert n_features == len(set(features)), "some features are duplicated"

print(f"Training model with {n_features} features: {features}")

## Tune the model's hyperparameters

In [None]:
# Binary target: True where the row's label is "KNRGL", False otherwise.
data["KNRGL"] = data["label"] == "KNRGL"

# CV splitter used during tuning: stratified on the target and grouped by
# chromosome (via `groups` below), so a chromosome never straddles train and
# validation within a fold.
sgkf = StratifiedGroupKFold(
    n_splits=5,
    shuffle=True,
    random_state=snakemake.params.random_state,  # type: ignore
)

# Hyperparameter search restricted to XGBoost, scored by average precision.
automl = AutoML(
    task="binary",
    estimator_list=["xgboost"],
    metric="ap",
    eval_method="cv",
    max_iter=snakemake.params.max_iter,  # type: ignore
    n_jobs=snakemake.threads,  # type: ignore
    # NOTE(review): False leaves FLAML's own preprocessing switched ON. The
    # previous comment here said "don't preprocess data", which would be
    # skip_transform=True -- confirm which behaviour is intended.
    skip_transform=False,
    auto_augment=False,  # don't augment rare classes
    early_stop=True,
    retrain_full=False,
    verbose=4,
    seed=123,
    log_training_metric=True,  # the tuning-curve cell reads "train_loss" from the log
    custom_hp=custom_hp,
    log_file_name=snakemake.output.history,  # type: ignore
)

# Run the search on the selected features.
automl.fit(
    X_train=data[features],
    y_train=data["KNRGL"],
    groups=data["Chromosome"],
    split_type=sgkf,
)

# Trailing expression so the notebook displays the selected estimator.
print("Final model:")
automl.model.estimator

In [None]:
# Plot the tuning curve from the search log written by `automl.fit`.
# `get_output_from_log` is imported in the notebook's import cell so that all
# dependencies are declared in one place.
(
    time_history,
    best_valid_loss_history,
    valid_loss_history,
    config_history,
    metric_history,
) = get_output_from_log(
    # time_budget is presumably just "large enough to keep every logged trial"
    filename=snakemake.output.history, time_budget=1e6  # type: ignore
)

# "train_loss" is only present because AutoML ran with log_training_metric=True.
train_loss_history = [x["train_loss"] for x in metric_history]

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
# label each curve where it is drawn so the legend cannot drift out of order
ax.plot(valid_loss_history, label="validation")
ax.plot(train_loss_history, label="train")
ax.set_xlabel("hyperparameter iteration")
ax.set_ylabel("loss")
ax.set_title("Hyperparameter tuning")

# add legend (trailing semicolon keeps the Legend repr out of the cell output)
ax.legend();

## Save model and best hp

In [None]:
# Persist the tuned estimator and the best hyperparameter configuration.
# `clf` is the estimator object held by `automl` (not a copy); it is pickled
# here, before any later cell refits it.
clf = automl.model.estimator
with open(snakemake.output.model, mode="wb") as model_file:  # type: ignore
    pickle.dump(clf, model_file, protocol=pickle.HIGHEST_PROTOCOL)
automl.save_best_config(snakemake.output.best_hp)  # type: ignore

## Make predictions in CV

TODO: add SHAP feature importances for the held-out (test) folds; the loop below currently computes them on the training folds only.

In [None]:
# get CV chromosomes
sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=snakemake.params.random_state)  # type: ignore
X_all = data[features]
splitter = sgkf.split(X_all, data["KNRGL"], groups=data["Chromosome"])
folds = {}
shap_values = pd.DataFrame(np.nan, index=data.index, columns=features)

# `split` yields POSITIONAL row numbers, not index labels. Every selection and
# assignment below is therefore positional (`.iloc` / plain numpy indexing), so
# the cell is correct even when `data` does not carry a default 0..n-1 index.
cv_proba = {name: np.full(len(data), np.nan) for name in ("train", "test")}

for i, (train_idx, test_idx) in enumerate(splitter):

    # get chromosomes, add to folds dictionary (fold numbers start at 1)
    train_chroms = data["Chromosome"].iloc[train_idx].unique().tolist()
    test_chroms = data["Chromosome"].iloc[test_idx].unique().tolist()
    folds[i + 1] = (train_chroms, test_chroms)

    # refit the tuned classifier on this fold's training rows
    # (this mutates `clf` in place)
    print(f"Fold {i+1}/{sgkf.n_splits}:")
    print(f"train set: {train_chroms}")
    print(f"test set: {test_chroms}")
    print("training...")
    X_train = X_all.iloc[train_idx]
    clf.fit(X_train, data["KNRGL"].iloc[train_idx])

    # SHAP values are computed on the training rows only. A row is a training
    # row in several folds, so its stored values come from the last such fold.
    print("getting SHAP feature importance values on training set...")
    explainer = shap.TreeExplainer(clf)
    shap_values.iloc[train_idx] = explainer.shap_values(X_train)

    # make predictions; each row's "test" probability is written exactly once,
    # while its "train" probability is overwritten by later folds
    for s, idx in [("train", train_idx), ("test", test_idx)]:
        print(f"making predictions on {s} set, {len(idx)} peaks...")
        cv_proba[s][idx] = clf.predict_proba(X_all.iloc[idx])[:, 1]

# attach the collected probabilities as `train_proba` / `test_proba` columns
for s, proba in cv_proba.items():
    data[f"{s}_proba"] = proba

In [None]:
# Violin-style SHAP summary over all features. `shap_values` shares its index
# with `data`, so the two matrices are row-aligned.
shap_matrix = shap_values.to_numpy()
feature_matrix = data[features].to_numpy()

shap.summary_plot(
    shap_matrix,
    feature_matrix,
    feature_names=features,
    plot_type="violin",
    plot_size=(10, 10),
)

## Plot performance

In [None]:
# define functions to evaluate performance
def prc_cv_plot(data, features, clf, retrain=False, cv_folds=None):
    """
    Plot precision-recall curve and precision-recall vs. threshold for each fold
    :param data: pd.DataFrame with "Chromosome" and "KNRGL" columns, plus
        "train_proba" / "test_proba" columns when retrain is False
    :param features: list of feature column names
    :param clf: sklearn classifier; refit IN PLACE on every fold when retrain
        is True
    :param retrain: bool, whether to retrain the classifier
    :param cv_folds: dict mapping fold id -> (train_chroms, test_chroms);
        defaults to the notebook-level `folds` built by the CV cell
    :return: the matplotlib figure holding both panels
    """
    if cv_folds is None:
        cv_folds = folds  # notebook-level dict filled by the CV cell

    # setup figure
    g, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    colors = {"train": "blue", "test": "orange"}

    # held-out labels/scores pooled over folds, used for the AP in the title
    pooled_true, pooled_score = [], []

    # iterate over folds
    for fold_pos, (train_chroms, test_chroms) in enumerate(cv_folds.values()):
        # boolean masks avoid a round trip through index labels
        masks = {
            "train": data["Chromosome"].isin(train_chroms),
            "test": data["Chromosome"].isin(test_chroms),
        }

        if retrain:  # retrain classifier
            clf.fit(
                data.loc[masks["train"], features],
                data.loc[masks["train"], "KNRGL"],
            )

        # make predictions on train and test set
        for s, mask in masks.items():
            y_true = data.loc[mask, "KNRGL"]
            if retrain:
                y_score = clf.predict_proba(data.loc[mask, features])[:, 1]
            else:
                y_score = data.loc[mask, f"{s}_proba"]

            if s == "test":
                pooled_true.append(np.asarray(y_true))
                pooled_score.append(np.asarray(y_score))

            prec, rec, thresh = precision_recall_curve(y_true, y_score)

            # only the first fold contributes legend entries
            label = s if fold_pos == 0 else None

            ax1.plot(rec, prec, color=colors[s], label=label)
            ax2.plot(thresh, prec[:-1], color=colors[s], label=label)
            ax2.plot(thresh, rec[:-1], color=colors[s], label=label, linestyle="--")

    # AP over the pooled held-out predictions that were actually plotted, so the
    # title also reflects the refitted models when retrain=True
    if pooled_score:
        ap = average_precision_score(
            np.concatenate(pooled_true), np.concatenate(pooled_score)
        )
    else:  # no folds given: fall back to the stored test probabilities
        ap = average_precision_score(data["KNRGL"], data["test_proba"])

    ax1.set(
        xlabel="Recall",
        ylabel="Precision",
        title=f"Precision-Recall curve: AP={ap:.3f}",
        xlim=(0, 1),
        ylim=(0, 1),
    )

    # add chance line (positive-class prevalence)
    ax1.plot(
        [0, 1], [data["KNRGL"].mean()] * 2, linestyle="--", color="gray", label="chance"
    )
    ax1.legend()
    ax2.set(
        xlabel="Threshold",
        ylabel="Score",
        title="Precision-Recall vs. threshold",
        xlim=(0, 1),
        ylim=(0, 1),
    )

    return g


def cm_plot(data: pd.DataFrame, thresholds=(0.5, 0.6, 0.7, 0.8, 0.9)):
    """
    Plot confusion matrices for different thresholds
    :param data: pd.DataFrame with a boolean "KNRGL" column and a "test_proba"
        column of predicted probabilities
    :param thresholds: probability cut-offs, one panel each; a row is called
        positive when its "test_proba" is strictly greater than the cut-off
    :return: the matplotlib figure holding all panels
    """
    thresholds = list(thresholds)

    # setup figure: one 5x5 panel per threshold; squeeze=False keeps `axes`
    # two-dimensional even when a single threshold is given
    g, axes = plt.subplots(
        1,
        len(thresholds),
        figsize=(5 * len(thresholds), 5),
        sharey=True,
        sharex=True,
        squeeze=False,
    )

    y_true = data["KNRGL"]

    # iterate over thresholds
    for ax, p in zip(axes[0], thresholds):
        y_pred = data["test_proba"] > p

        # fix the label order so the matrix is always 2x2 and lines up with
        # display_labels, even if a class is absent from y_true and y_pred
        cm = confusion_matrix(y_true, y_pred, labels=[False, True])

        # plot
        ConfusionMatrixDisplay(cm, display_labels=["OTHER", "KNRGL"]).plot(ax=ax)
        ax.images[-1].colorbar.remove()  # remove colorbar

        # zero_division=0 reports 0 without a warning when nothing is
        # predicted positive at a high threshold
        precision = precision_score(y_true, y_pred, zero_division=0)
        recall = recall_score(y_true, y_pred, zero_division=0)
        ax.set_title(f"{p}: precision = {precision:.3f}, recall = {recall:.3f}")
    return g

In [None]:
# Draw the per-fold precision-recall curves from the stored CV probabilities
# (no refitting), then the confusion matrices of the held-out predictions.
# `g` ends up bound to the confusion-matrix figure.
g = prc_cv_plot(data=data, features=features, clf=clf, retrain=False)
g = cm_plot(data=data)

TODO: plot germline dist and clonality vs score

In [None]:
# Write the full table, including the columns this notebook added to `data`,
# to the predictions output.
predictions_path = snakemake.output.predictions  # type: ignore
data.to_parquet(predictions_path)