# Testing the SLAV-calling model

In [None]:
from pathlib import Path, PosixPath
from tempfile import NamedTemporaryFile
import warnings

# data
import numpy as np

print(f"numpy: {np.__version__}")
import pandas as pd

print(f"pandas: {pd.__version__}")
import pyarrow
import pyarrow.parquet as pq

print(f"pyarrow: {pyarrow.__version__}")

# ML
from sklearn import metrics, model_selection, ensemble
from flaml import AutoML
from flaml.automl.data import get_output_from_log

# plotting
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import seaborn as sns

# custom
from scripts.get_labels import label_windows
from scripts.fit import SampleChrSplitter

## Define functions for model testing

Evaluation Strategy: Train on 1/2 donors and 1/2 chromosomes, test on the other half of donors and chromosomes.

Tuning Splitting Strategy:
	
1. Tune within each training set? Splitting again on chromosomes and donors
2. Tune on CommonBrain donor, splitting on chromosomes and cells

In [None]:
def auprc(y_true, y_pred, pos_label: int, sample_weight=None):
    """
    Area under the precision-recall curve.

    :param y_true: true labels, forwarded to sklearn's precision_recall_curve
    :param y_pred: predicted scores, forwarded to sklearn's precision_recall_curve
    :param pos_label: label of the positive class
    :param sample_weight: optional per-sample weights
    :return: area under the curve with recall on x and precision on y
    """
    precision, recall, _ = metrics.precision_recall_curve(
        y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight
    )
    # integrate precision over recall
    return metrics.auc(recall, precision)

In [None]:
def auprc_flaml(
    X_val,
    y_val,
    estimator,
    labels,
    X_train,
    y_train,
    weight_val=None,
    weight_train=None,
    config=None,
    groups_val=None,
    groups_train=None,
):
    """
    Custom AUPRC metric for FLAML.

    FLAML minimizes the first returned value, so the loss is ``1 - AUPRC`` on the
    validation data. The second returned value is a dict of metrics to log.

    :param X_val: validation features
    :param y_val: validation labels (positive class encoded as 1)
    :param estimator: fitted estimator exposing ``predict_proba``
    :param labels: unused here, part of the FLAML custom-metric signature
    :param X_train: training features
    :param y_train: training labels (positive class encoded as 1)
    :param weight_val: optional sample weights for the validation data
    :param weight_train: optional sample weights for the training data
    :param config: unused
    :param groups_val: unused
    :param groups_train: unused
    :return: tuple ``(1 - val_auprc, {"train_auprc": ..., "val_auprc": ...})``
    """
    # predict_proba returns one column per class; precision_recall_curve needs a
    # 1d score vector, so keep only the column of the positive class (label 1).
    val_scores = np.asarray(estimator.predict_proba(X_val))[:, 1]
    val_auprc = auprc(y_val, val_scores, pos_label=1, sample_weight=weight_val)

    train_scores = np.asarray(estimator.predict_proba(X_train))[:, 1]
    train_auprc = auprc(y_train, train_scores, pos_label=1, sample_weight=weight_train)

    # the loss is 1 - AUPRC, but the logged values are the AUPRCs themselves so
    # that the dict keys describe what is stored under them
    return 1 - val_auprc, {"train_auprc": train_auprc, "val_auprc": val_auprc}

## Load and Clean Data

In [None]:
def read_data(filename: PosixPath):
    """
    Reads a parquet file, processes labels, shrinks data types, and checks for inf/nan values, returning a pandas dataframe.

    Rows with rpm < 2 and all boolean columns are dropped.

    :param filename: path to parquet file (suffix ``.parquet`` or ``.pqt``)
    :raises ValueError: if filename does not have a parquet suffix
    :raises AssertionError: if a float16 or int32 column contains inf or nan values
    :return: pandas dataframe with "label" collapsed to knrgl/unknown and a 0/1 "label_encoded" column
    """

    # read with pyarrow
    filename = Path(filename)
    # NOTE: `suffix == ".parquet" or ".pqt"` is always truthy, so test membership instead
    if filename.suffix not in (".parquet", ".pqt"):
        raise ValueError("filename must be a parquet file")
    df = pq.read_table(filename).to_pandas()

    # fix labels: xtea_1kb_3end counts as knrgl, everything else becomes unknown
    df["label"] = df["label"].replace({"xtea_1kb_3end": "knrgl"})
    df["label"] = df["label"].apply(lambda x: "unknown" if x != "knrgl" else x)
    df["label_encoded"] = df["label"].map({"knrgl": 1, "unknown": 0})

    # cleanup columns
    df = df.drop(columns=[c for c in df.columns if df[c].dtype == bool])
    # copy so the column assignments below write to an independent frame
    df = df.loc[df["rpm"] >= 2].copy()
    float32_info = np.finfo(np.float32)
    for c in df.columns:
        if c == "rpm":
            # rpm is kept at float32, clipped into the representable range first
            df[c] = np.clip(df[c], float32_info.min, float32_info.max).astype(
                np.float32
            )
        elif df[c].dtype == np.float64:
            # values outside the float16 range turn into inf and are caught below
            df[c] = df[c].astype(np.float16)
        elif df[c].dtype == np.int64:
            df[c] = df[c].astype(np.int32)  # must use int32 for chromosomal positions
        if (df[c].dtype == np.float16) or (df[c].dtype == np.int32):
            assert not np.isinf(df[c]).any(), f"{c} column contains inf values"
            assert not df[c].isna().any(), f"{c} column contains nan values"

    return df

In [None]:
# takes ~40 min, 120 GB RAM
DATADIR = Path("../results/model/get_labels/")
data = pd.concat([read_data(f) for f in DATADIR.rglob("*pqt")])

# remove low quality cells
with open("../resources/bad_cells.txt", "r") as f:
    bad_cells = [line.strip() for line in f]
data = data[~data["cell_id"].isin(bad_cells)]

# keep autosomes
autosomes = [f"chr{i}" for i in range(1, 23)]
data = data[data["Chromosome"].isin(autosomes)]

# keep windows with at least 2 reads per million
data = data[data["rpm"] >= 2]

# remove blacklist region
# all GIAB stratification BED files are read with the same options
bed_kwargs = dict(
    sep="\t",
    header=None,
    skiprows=1,
    names=["Chromosome", "Start", "End"],
)
mhc = pd.read_csv(
    "https://ftp-trace.ncbi.nlm.nih.gov/ReferenceSamples/giab/release/genome-stratifications/v3.0/GRCh38/OtherDifficult/GRCh38_MHC.bed.gz",
    **bed_kwargs,
)
kir = pd.read_csv(
    "https://ftp-trace.ncbi.nlm.nih.gov/ReferenceSamples/giab/release/genome-stratifications/v3.0/GRCh38/OtherDifficult/GRCh38_KIR.bed.gz",
    **bed_kwargs,
)
trs = pd.read_csv(
    "https://ftp-trace.ncbi.nlm.nih.gov/ReferenceSamples/giab/release/genome-stratifications/v3.0/GRCh38/LowComplexity/GRCh38_AllTandemRepeats_201to10000bp_slop5.bed.gz",
    **bed_kwargs,
)
segdups = pd.read_csv(
    "https://ftp-trace.ncbi.nlm.nih.gov/ReferenceSamples/giab/release/genome-stratifications/v3.0/GRCh38/SegmentalDuplications/GRCh38_segdups.bed.gz",
    **bed_kwargs,
)
gaps = pd.read_csv(
    "https://ftp-trace.ncbi.nlm.nih.gov/ReferenceSamples/giab/release/genome-stratifications/v3.0/GRCh38/OtherDifficult/GRCh38_gaps_slop15kb.bed.gz",
    **bed_kwargs,
)
false_dup = pd.read_csv(
    "https://ftp-trace.ncbi.nlm.nih.gov/ReferenceSamples/giab/release/genome-stratifications/v3.0/GRCh38/OtherDifficult/GRCh38_false_duplications_correct_copy.bed.gz",
    **bed_kwargs,
)
blacklist = pd.concat([mhc, trs, segdups, gaps, false_dup, kir])

# flag windows against the blacklist intervals, then keep the unflagged ones
data = label_windows(data, blacklist, "blacklist")
data = data.loc[data["blacklist"] == False]

In [None]:
# complementary ECDFs (counts, log-log axes) of reads per million (left) and
# number of reads (right); one set of curves per donor, split by label

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

for donor_id, df in data.groupby("donor_id"):
    for column, axis in zip(("rpm", "n_reads"), (ax1, ax2)):
        sns.ecdfplot(
            df,
            x=column,
            hue="label",
            ax=axis,
            label=donor_id,
            alpha=0.5,
            complementary=True,
            stat="count",
            log_scale=(True, True),
        )

In [None]:
# complementary ECDFs (counts, log-log axes) of reads per million (left) and
# number of reads (right); one set of curves per cell of the CommonBrain donor,
# split by label

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

for cell_id, df in data[data["donor_id"] == "CommonBrain"].groupby("cell_id"):
    for column, axis in zip(("rpm", "n_reads"), (ax1, ax2)):
        sns.ecdfplot(
            df,
            x=column,
            hue="label",
            ax=axis,
            alpha=0.5,
            complementary=True,
            stat="count",
            log_scale=(True, True),
        )

## Filtering and Feature selection

In [None]:
# define the features: every column whose name contains one of the key substrings
keys = ["_mean", "frac", "gini", "bias"]
# any() adds a column once even when its name matches several keys; appending
# once per matching key would put duplicate columns into the feature matrix
features = [c for c in data.columns if any(k in c for k in keys)]
print("Features:", features)

In [None]:
def fit(
    train: pd.DataFrame,
    test: pd.DataFrame,
    features: list,
    time_budget: int,
    axes=None,
    label: str = None,
    **kwargs,
):
    """
    Tune a classifier with FLAML on `train` and plot diagnostics for `train` and `test`.

    Tuning uses cross-validation on `train` with a SampleChrSplitter built with
    sample_col="donor_id". `test` is only used for the class balance,
    precision-recall and ROC panels.

    :param train: training windows; needs the feature columns plus "label", "label_encoded" and "donor_id"
    :param test: held-out windows; needs the feature columns plus "label" and "label_encoded"
    :param features: list of feature names, corresponding to columns in train and test
    :param time_budget: time budget in seconds for tuning
    :param axes: optional 2x2 array of matplotlib axes; a new 10x10 figure is created when None
    :param label: currently unused
    :param kwargs: forwarded to AutoML.fit
    :return: the fitted AutoML object
    """
    # NOTE: this changes the process-wide warning filter, not just this call
    warnings.filterwarnings("ignore", category=UserWarning)

    # check that the needed columns are in the frames that were passed in
    # (not in the notebook-level `data` frame)
    for f in list(features) + ["label", "label_encoded", "donor_id"]:
        assert f in train.columns, f"{f} not in train.columns"
    for f in list(features) + ["label", "label_encoded"]:
        assert f in test.columns, f"{f} not in test.columns"

    # split data into train and tune for hyperparameter tuning
    sgkf = SampleChrSplitter(X=train, y=train["label_encoded"], sample_col="donor_id")

    with NamedTemporaryFile() as logfile:
        clf = AutoML()
        print("fitting model...")
        clf.fit(
            task="classification",
            X_train=train[features],
            y_train=train["label_encoded"],
            n_jobs=16,
            estimator_list=["xgboost", "rf", "xgb_limitdepth"],
            early_stop=True,
            eval_method="cv",
            split_type=sgkf,
            log_file_name=logfile.name,
            time_budget=time_budget,  # time budget in seconds
            verbose=0,
            **kwargs,
        )
        print("done")
        # get learning curve; must happen before the temporary log file is removed
        (
            time_history,
            best_valid_loss_history,
            valid_loss_history,
            config_history,
            metric_history,
        ) = get_output_from_log(filename=logfile.name, time_budget=time_budget)

    print("best estimator:", clf.best_estimator)
    print("best config:", clf.best_config)
    print("best loss:", clf.best_loss)

    # make subplots
    print("plotting...")
    # `not axes` raises for an array of axes, so compare against None explicitly
    if axes is None:
        _, axes = plt.subplots(2, 2, figsize=(10, 10))
        plt.tight_layout()
    else:
        assert axes.shape == (2, 2), "axes must be 2x2"

    # class balance
    train_counts = train.value_counts("label").reset_index(name="count")
    test_counts = test.value_counts("label").reset_index(name="count")
    plot_df = pd.concat(
        [train_counts.assign(stage="train"), test_counts.assign(stage="test")]
    )
    # percent of each label within its stage
    plot_df["percent"] = (
        plot_df["count"] / plot_df.groupby("stage")["count"].transform("sum")
    ) * 100
    sns.barplot(
        x="stage",
        y="percent",
        data=plot_df,
        order=["train", "test"],
        hue="label",
        ax=axes[0, 0],
    )
    axes[0, 0].title.set_text("Class Balance")
    axes[0, 0].set_xlabel("Label")
    axes[0, 0].set_ylabel("Percent")
    axes[0, 0].set_yscale("log")
    # add percentages to bars
    for p in axes[0, 0].patches:
        axes[0, 0].annotate(
            f"{p.get_height():.2f}%",
            (p.get_x() + p.get_width() / 2.0, p.get_height()),
            ha="center",
            va="center",
            xytext=(0, 10),
            textcoords="offset points",
        )
    axes[0, 0].yaxis.set_major_formatter(mtick.PercentFormatter())

    # learning curve: the logged loss is plotted as 1 - loss
    axes[0, 1].step(time_history, 1 - np.array(best_valid_loss_history), where="post")
    axes[0, 1].title.set_text("Learning Curve")
    axes[0, 1].set_xlabel("Wall Clock Time (s)")
    axes[0, 1].set_ylabel("Test Average Precision")

    # metrics on the held-out data
    metrics.PrecisionRecallDisplay.from_estimator(
        clf, test[features], test["label_encoded"], ax=axes[1, 0]
    )
    metrics.RocCurveDisplay.from_estimator(
        clf, test[features], test["label_encoded"], ax=axes[1, 1]
    )

    return clf

In [None]:
# sweep the minimum rpm threshold: tune on the first split for each threshold,
# record the average precision on the held-out part and save the diagnostic figure
res = []

out_dir = Path("../results/20210731_model_testing")
out_dir.mkdir(exist_ok=True)

for rpm in [10, 50, 100, 500, 1000]:
    my_data = data[data["rpm"] >= rpm]
    splitter = SampleChrSplitter(
        X=my_data, y=my_data["label_encoded"], sample_col="donor_id"
    )
    sgkf = splitter.split(my_data)
    # only the first train/test split is used
    train_idx, test_idx = next(sgkf)
    clf = fit(
        train=my_data.iloc[train_idx].reset_index(drop=True),
        test=my_data.iloc[test_idx].reset_index(drop=True),
        features=features,
        time_budget=600,
        metric="ap",
        skip_transform=True,  # don't preprocess data
        auto_augment=False,  # don't augment rare classes
        starting_points="static",  # use data-independent hyperparameter starting points
    )

    # score the held-out windows with the probability of class 1
    y_test = my_data.iloc[test_idx]["label_encoded"]
    y_pred = clf.predict_proba(my_data.iloc[test_idx][features])
    ap = metrics.average_precision_score(y_test, y_pred[:, 1])
    res.append(dict(rpm=rpm, ap=ap))

    # saves the current figure, i.e. the one fit() just created
    plt.savefig(f"../results/20210731_model_testing/rpm_{rpm}.png")

sns.lineplot(data=pd.DataFrame(res), x="rpm", y="ap")

In [None]:
# single tuning run (100 s budget) on windows with rpm > 20, first split only
# NOTE(review): this rebinds the notebook-level `data`, so cells run afterwards
# that filter on rpm > 5 or rpm > 10 see the already filtered frame - confirm intended
data = data[data["rpm"] > 20]
splitter = SampleChrSplitter(X=data, y=data["label_encoded"], sample_col="donor_id")
sgkf = splitter.split(data)
train_idx, test_idx = next(sgkf)
clf = fit(
    train=data.iloc[train_idx].reset_index(drop=True),
    test=data.iloc[test_idx].reset_index(drop=True),
    features=features,
    time_budget=100,
    metric="ap",
    skip_transform=True,  # don't preprocess data
    auto_augment=False,  # don't augment rare classes
    starting_points="static",  # use data-independent hyperparameter starting points
)

## Tune and evaluate XGBoost model


In [None]:
def _plot_fold_class_balance(ax, frame, fold, class_names):
    "Draw one group of bars (one bar per label) with the label counts of a single fold."
    counts = frame["label"].value_counts()
    width = 0.8 / max(len(class_names), 1)
    for j, name in enumerate(class_names):
        ax.bar(
            # center the group of bars on the fold index
            fold + (j - (len(class_names) - 1) / 2) * width,
            counts.get(name, 0),
            width=width,
            color=f"C{j}",
            # only the first fold contributes legend entries, one per label
            label=name if fold == 0 else "_nolegend_",
        )


def nested_CV(data, features, time_budget=600, **kwargs):
    """
    Perform nested cross-validation to tune hyperparameters and evaluate model performance.

    Each outer fold produced by SampleChrSplitter is tuned with fit() on its training
    part and evaluated on its held-out part. The number of folds is whatever the
    splitter yields (4 according to the original notes - not verifiable from here).

    :param data: pd.DataFrame with features and labels ("label", "label_encoded", "donor_id", "Chromosome")
    :param features: list of feature names, corresponding to columns in data
    :param time_budget: time budget in seconds for tuning per fold
    :param kwargs: forwarded to fit()
    """

    # initialize subplots; the learning curve panel spans both rows
    _, axs = plt.subplot_mosaic(
        [
            ["TrClassBal", "TrPRC", "TrROC", "LC"],
            ["VaClassBal", "VaPRC", "VaROC", "LC"],
        ],
        figsize=(17, 8),
    )
    plt.subplots_adjust(wspace=0.3, hspace=0.3)

    # fixed label order so bar colors and legend entries agree across folds
    class_names = sorted(data["label"].unique())

    # split data into train and test for evaluation
    eval_sgkf = SampleChrSplitter(X=data, y=data["label_encoded"], sample_col="donor_id")

    n_folds = 0
    for i, (train_idx, test_idx) in enumerate(eval_sgkf.split(data)):
        n_folds = i + 1
        print(f"Fold {i+1}")
        train = data.iloc[train_idx, :].reset_index(drop=True)
        valid = data.iloc[test_idx, :].reset_index(drop=True)

        print(f"Training on {train.shape[0]} windows")
        print("Training on donors:", train["donor_id"].unique())
        print("Training on chromosomes:", train["Chromosome"].unique())

        # fit() requires the held-out frame as well; it also draws its own
        # 2x2 diagnostic figure for every fold
        clf = fit(
            train=train,
            test=valid,
            features=features,
            time_budget=time_budget,
            **kwargs,
        )

        # PLOTS
        label = f"Fold {i+1}: {clf.best_estimator}"

        metrics.PrecisionRecallDisplay.from_estimator(
            clf, train[features], train["label_encoded"], ax=axs["TrPRC"], name=label
        )
        metrics.RocCurveDisplay.from_estimator(
            clf, train[features], train["label_encoded"], ax=axs["TrROC"], name=label
        )
        metrics.PrecisionRecallDisplay.from_estimator(
            clf, valid[features], valid["label_encoded"], ax=axs["VaPRC"], name=label
        )
        metrics.RocCurveDisplay.from_estimator(
            clf, valid[features], valid["label_encoded"], ax=axs["VaROC"], name=label
        )

        # learning curve
        # NOTE(review): fit() in this notebook returns the AutoML object and keeps the
        # log histories local, so these attributes may be missing - confirm where the
        # histories should come from
        time_history = getattr(clf, "time_history", None)
        loss_history = getattr(clf, "best_valid_loss_history", None)
        if time_history is not None and loss_history is not None:
            axs["LC"].step(
                time_history, 1 - np.array(loss_history), where="post", label=label
            )
        else:
            print("learning curve not available for this fold")

        # class balance: one group of bars per fold
        _plot_fold_class_balance(axs["TrClassBal"], train, i, class_names)
        _plot_fold_class_balance(axs["VaClassBal"], valid, i, class_names)

    for key in ("TrClassBal", "VaClassBal"):
        axs[key].set_xticks(range(n_folds))
        axs[key].set_xticklabels([f"Fold {k + 1}" for k in range(n_folds)])

    axs["TrClassBal"].title.set_text("Training Class Balance")
    axs["TrClassBal"].legend()
    axs["TrClassBal"].set_yscale("log")

    axs["VaClassBal"].title.set_text("Validation Class Balance")
    axs["VaClassBal"].legend()
    axs["VaClassBal"].set_yscale("log")

    axs["TrPRC"].title.set_text("Training Precision-Recall")
    axs["TrPRC"].set_xlabel("Recall")
    axs["TrPRC"].set_ylabel("Precision")

    axs["TrROC"].title.set_text("Training ROC")
    axs["TrROC"].set_xlabel("False Positive Rate")
    axs["TrROC"].set_ylabel("True Positive Rate")

    axs["VaPRC"].title.set_text("Validation Precision-Recall")
    axs["VaPRC"].set_xlabel("Recall")
    axs["VaPRC"].set_ylabel("Precision")

    axs["VaROC"].title.set_text("Validation ROC")
    axs["VaROC"].set_xlabel("False Positive Rate")
    axs["VaROC"].set_ylabel("True Positive Rate")

    axs["LC"].title.set_text("Learning curve")
    axs["LC"].set_xlabel("Wall Clock Time (s)")
    axs["LC"].set_ylabel("Test Average Precision")
	

In [None]:
# Nested CV on windows with rpm > 5, with "rpm" appended to the feature list.
# NOTE(review): an earlier cell rebinds `data` to rows with rpm > 20; if that cell
# has been run, this rpm > 5 filter removes nothing - confirm intended.
nested_CV(
    data[data["rpm"] > 5],
    features + ["rpm"],
    time_budget=600,
    metric="ap",
    skip_transform=True,  # don't preprocess data
    auto_augment=False,  # don't augment rare classes
    starting_points="static",  # use data-independent hyperparameter starting points
)

In [None]:
# Nested CV on windows with rpm > 10, using the selected features only.
# NOTE(review): an earlier cell rebinds `data` to rows with rpm > 20; if that cell
# has been run, this rpm > 10 filter removes nothing - confirm intended.
nested_CV(
    data[data["rpm"] > 10],
    features,
    time_budget=600,
    metric="ap",
    skip_transform=True,  # don't preprocess data
    auto_augment=False,  # don't augment rare classes
    starting_points="static",  # use data-independent hyperparameter starting points
)

In [None]:
# Nested CV on windows with rpm > 10, with "rpm" appended to the feature list.
# NOTE(review): an earlier cell rebinds `data` to rows with rpm > 20; if that cell
# has been run, this rpm > 10 filter removes nothing - confirm intended.
nested_CV(
    data[data["rpm"] > 10],
    features + ["rpm"],
    time_budget=600,
    metric="ap",
    skip_transform=True,  # don't preprocess data
    auto_augment=False,  # don't augment rare classes
    starting_points="static",  # use data-independent hyperparameter starting points
)

TODOs:

Evaluate performance as a function of:
1. donors included
2. windows/donor
3. Chromosomes
4. n_reads (or rpm) 


With/without
1. rpm as a feature
2. skip_transform=False
3. auto_augment=True
4. starting_points="dynamic"
5. EN score as a feature
6. quantiles vs mean

Compute FP/FN rate!
