# Test SLAV-calling model on 5 individuals

In [None]:
# basic
import os

print(f"Number of CPUs in this system: {os.cpu_count()}")

from pathlib import Path
from collections import defaultdict
from tqdm import tqdm

# data
import numpy as np

print(f"numpy: {np.__version__}")

import pandas as pd

print(f"pandas: {pd.__version__}")
import pyranges as pr

import pyarrow
import pyarrow.parquet as pq

print(f"pyarrow: {pyarrow.__version__}")

# ML
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import PrecisionRecallDisplay, RocCurveDisplay
from flaml import AutoML

# import ray

# print(f"ray: {ray.__version__}")
# print("initializing ray...")
# if ray.is_initialized():
#     ray.shutdown()
# ray.init()

# plotting
import matplotlib.pyplot as plt
import seaborn as sns


# custom
from scripts.get_labels import label_windows

## Read data

### RepeatMasker, KNRGL, and blacklists

In [None]:
# RepeatMasker-derived intervals: keep only the first three BED columns.
rmsk = pd.read_csv(
    "/iblm/netapp/data4/mcuoco/sz_slavseq/resources/rmsk_1kb_3end.bed",
    sep="\t",
    names=["Chromosome", "Start", "End"],
    usecols=[0, 1, 2],
)

# GRC exclusion regions, fetched from NCBI. The first three lines of the file
# are skipped and the coordinate columns are renamed to the
# Chromosome/Start/End names used for every other interval table here.
GRC_blacklist = pd.read_csv(
    "ftp://ftp.ncbi.nlm.nih.gov/genomes/all/GCA/000/001/405/GCA_000001405.15_GRCh38/seqs_for_alignment_pipelines.ucsc_ids/GCA_000001405.15_GRCh38_GRC_exclusions.bed",
    sep="\t",
    skiprows=3,
).rename(
    columns={"#sequence": "Chromosome", "sequenceStart": "Start", "sequenceEnd": "End"},
)

# 10x Genomics supplementary tracks, parsed by pyranges and converted to
# plain DataFrames.
segdups = pr.read_bed(
    "/iblm/logglun02/mcuoco/workflows/sz_slavseq/cf.10xgenomics.com/supp/genome/GRCh38/segdups.bedpe"
).df
sv_blacklist = pr.read_bed(
    "/iblm/logglun02/mcuoco/workflows/sz_slavseq/cf.10xgenomics.com/supp/genome/GRCh38/sv_blacklist.bed"
).df

In [None]:
# get chromosomes and lengths from the FASTA index
# (tab-separated; field 0 is the sequence name, field 1 its length)
chr_lengths = {}
with open("/iblm/netapp/data4/mcuoco/sz_slavseq/resources/hs38d1.fa.fai") as f:
    for line in f:
        fields = line.strip().split("\t")
        if len(fields) < 2:
            # tolerate blank / truncated lines instead of raising IndexError
            continue
        chr_lengths[fields[0]] = int(fields[1])

# remove chromosomes not in rmsk
# The set is built once so the membership test is not recomputed per contig.
# chr_rm stays a list because read_data() passes it to a parquet "not in" filter.
rmsk_chromosomes = set(rmsk.Chromosome.unique())
chr_rm = [c for c in chr_lengths if c not in rmsk_chromosomes]

for c in chr_rm:
    del chr_lengths[c]

### Read and label bam windows (takes ~3 hrs)

In [None]:
# Locate, for each donor, its KNRGL insertion BED file and its per-cell
# feature parquet files.
# full donor list: ["CommonBrain",1,3,4,8,27]
donors = ["CommonBrain"]
donor_files = {}
for d in donors:
    files = defaultdict(list)
    donor_files[d] = files

    # "KNRGL" holds a single path string; if several files match, the last wins
    knrgl_matches = Path("/iblm/netapp/data4/mcuoco/sz_slavseq/resources").rglob(
        f"{d}_insertions_1kb_3nd.bed"
    )
    for f in knrgl_matches:
        files["KNRGL"] = str(f)

    # "features" accumulates every parquet file found under the donor directory
    feature_matches = Path(
        f"/iblm/netapp/data4/mcuoco/sz_slavseq/results/model/get_features/{d}"
    ).rglob("*.pqt")
    for f in feature_matches:
        files["features"].append(str(f))

In [None]:
# define function to read in the data
def read_data(filename: str):
    """Load one feature parquet file as a pandas DataFrame.

    Rows are filtered while reading: only windows with ``nr1 >= 1`` are kept,
    and windows on any chromosome listed in the module-level ``chr_rm`` are
    dropped.

    filename: path to the parquet file
    """
    row_filters = [
        ("nr1", ">=", 1),
        ("Chromosome", "not in", chr_rm),
    ]
    table = pq.read_table(filename, filters=row_filters)
    return table.to_pandas()


def collate_labels(x):
    """Collapse a row's boolean annotation flags into one label string.

    Flags are checked in priority order and the first truthy one wins:
    RMSK, RMSK_10kb, KNRGL, KNRGL_10kb, then any of
    GRC_blacklist / SV_blacklist / segdups ("blacklist"), else "other".

    The ``*_10kb`` flags are optional: rows that do not carry them are treated
    as unflagged. Previously a missing ``RMSK_10kb`` / ``KNRGL_10kb`` attribute
    raised AttributeError for every row whose RMSK flag was false. The other
    flags remain required.

    x: a row exposing the flags as attributes (e.g. the Series passed by
       ``DataFrame.apply(..., axis=1)``)
    """
    if x.RMSK:
        return "RMSK"
    if getattr(x, "RMSK_10kb", False):
        return "RMSK_10kb"
    if x.KNRGL:
        return "KNRGL"
    if getattr(x, "KNRGL_10kb", False):
        return "KNRGL_10kb"
    if x.GRC_blacklist or x.SV_blacklist or x.segdups:
        return "blacklist"
    return "other"


# For every donor: load all feature files, attach the annotation flags, and
# collapse them into a single "label" column.
per_donor_frames = []
for d in donors:
    print(f"Reading {len(donor_files[d]['features'])} files for donor {d} ..")
    window_frames = [read_data(f) for f in tqdm(donor_files[d]["features"])]
    ddf = pd.concat(window_frames)

    # donor-specific KNRGL intervals (first three BED columns only)
    knrgl = pd.read_csv(
        donor_files[d]["KNRGL"],
        sep="\t",
        usecols=[0, 1, 2],
        names=["Chromosome", "Start", "End"],
    )

    # Applied in this fixed order; the "KNRGL_10kb" and "RMSK_10kb" labels
    # are currently not generated.
    annotations = (
        (knrgl, "KNRGL"),
        (rmsk, "RMSK"),
        (GRC_blacklist, "GRC_blacklist"),
        (segdups, "segdups"),
        (sv_blacklist, "SV_blacklist"),
    )
    for intervals, column in annotations:
        ddf = label_windows(ddf, intervals, column)

    ddf["label"] = ddf.apply(collate_labels, axis=1)
    per_donor_frames.append(ddf)


df = pd.concat(per_donor_frames)

# remove low quality cells
bad_cells = [
    "USD3_F3_S151",
    "USD3_C3_S148",
    "USD3_E3_S150",
    "USD3_C5_S163",
    "USD3_G3_S152",
    "USD4E2_S141",
    "plate2_E7_S101",
    "ush8_D7_S51",
    "ush27_A7_S159",
]
is_bad_cell = df["cell_id"].isin(bad_cells)
df = df.loc[~is_bad_cell, :]

In [None]:
# # remove windows with any of the above labels
# df = df.loc[~df[["GRC_blacklist", "segdups", "sv_blacklist", "rmsk"]].any(axis=1), :]
# df.drop(["GRC_blacklist", "segdups", "sv_blacklist", "rmsk"], axis=1, inplace=True)

### Read labelled windows

In [None]:
# read in the previously labelled windows from disk
# NOTE: this rebinds `df`, replacing whatever the labelling cell above produced
df = pq.read_table("5donors_labelled.pqt").to_pandas()
# last expression of the cell: shows (n_rows, n_columns)
df.shape

## Check distribution of classes

TODO: 
- plot absolute window numbers as a function of read 1 threshold colored by class
- plot ratio
- include false negatives

In [None]:
# make subplots: one row of three panels per donor present in df
# The row count comes from df itself (not the `donors` list) so it always
# matches the loop below. squeeze=False keeps `axes` two-dimensional even for
# a single donor, so axes[i, j] indexing is always valid.
donor_ids = df["donor_id"].unique()
fig, axes = plt.subplots(
    len(donor_ids), 3, figsize=(13, 4 * len(donor_ids)), squeeze=False
)
hue_order = df["label"].unique()

for i, d in enumerate(donor_ids):
    ddf = df.loc[(df["donor_id"] == d), :]
    print(f"Number of windows for {d}: {ddf.shape[0]}")
    title = f"donor {d}"

    # panel 0: cumulative proportion of windows by read-1 count, per label
    sns.ecdfplot(
        data=ddf, x="nr1", hue="label", hue_order=hue_order, ax=axes[i, 0], legend=False
    ).set(xscale="log", title=title, xlabel="# Read 1")

    # panel 1: number of windows with at least x read-1s (complementary count)
    sns.ecdfplot(
        data=ddf,
        x="nr1",
        hue="label",
        hue_order=hue_order,
        ax=axes[i, 1],
        legend=False,
        stat="count",
        complementary=True,
    ).set(xscale="log", yscale="log", title=title, xlabel="# Read 1")

    # panel 2: windows per cell for each label
    plot_df = ddf.groupby(["label", "cell_id"]).size().reset_index(name="count")

    sns.stripplot(
        data=plot_df,
        y="label",
        x="count",
        alpha=0.5,
        hue="label",
        hue_order=hue_order,
        ax=axes[i, 2],
        legend=False,
    ).set(xscale="log", title=title, xlabel="# Windows / cell", ylabel=None)

# lay out after drawing so titles and axis labels are taken into account
fig.tight_layout(w_pad=5.0, h_pad=4.0)

## Tune model's hyperparameters using [Microsoft's FLAML library](https://microsoft.github.io/FLAML/)

In [None]:
# Select feature columns: any column whose name contains one of `keys`,
# excluding names that contain "r2" or "prop".
# Each column is added at most once. The previous nested loop appended a
# column once per matching key, which produced duplicate feature columns.
keys = ["ML", "MA", "MS", "ME", "AS", "frac", "bias"]
features = [
    c
    for c in df.columns
    if any(k in c for k in keys) and ("r2" not in c) and ("prop" not in c)
]

# encode labels
# NOTE(review): astype(int) requires `label` to be boolean/numeric here; it
# raises on string labels such as those produced by collate_labels. Confirm
# the dtype stored in the labelled parquet file.
df["label_encoded"] = df["label"].astype(int)

In [None]:
# Shared keyword arguments for every AutoML.fit call below.
# NOTE: Don't try logistic regression, it's too slow, doesn't converge, and doesn't perform well
flaml_settings = {
    "task": "classification",
    "n_jobs": 16,
    "estimator_list": ["xgboost", "rf"],
    "early_stop": True,
    "skip_transform": True,  # don't preprocess data
    "auto_augment": False,  # don't augment rare classes
    "starting_points": "static",  # use data-independent hyperparameter starting points
    "log_training_metric": True,
}

In [None]:
# define data for final evaluation: every window on the held-out chromosome
# (all donors); reset_index() keeps the old index as an "index" column
eval_chr = "chr1"
eval_data = df.loc[(df["Chromosome"] == eval_chr), :].reset_index()
# eval_data = df.loc[(df["donor_id"] != "CommonBrain") & (df["Chromosome"] == eval_chr) & (df["label"] != "RMSK"), :].reset_index()

# define data to tune on: CommonBrain windows on all chromosomes except eval_chr
tune_data = df.loc[
    (df["donor_id"] == "CommonBrain") & (df["Chromosome"] != eval_chr), :
].reset_index()
# tune_data = df.loc[(df["donor_id"] == "CommonBrain") & (df["Chromosome"] != eval_chr) & (df["label"] != "RMSK"), :].reset_index()

In [None]:
# define evaluation function
def precision_recall(pred: pd.DataFrame, insertions: pd.DataFrame):
    """
    Calculate precision and recall of window-level calls against annotated
    insertions.

    pred: genomic windows with columns Chromosome, Start, End and a binary
        column "pred" (1 = called positive, 0 = called negative)
    insertions: annotated insertions with columns Chromosome, Start, End

    Only insertions overlapping at least one window in `pred` are considered.
    Counts:
      tp: insertions overlapped by at least one positive window
      fn: insertions overlapped by no positive window
      fp: positive windows overlapping no insertion
    Note that tp/fn count insertions while fp counts windows.

    Returns (precision, recall). Either value is NaN when its denominator
    is zero.
    """
    assert "pred" in pred.columns, "pred must have column 'pred'"
    # a prediction set containing a single class is still binary
    assert set(pred.pred.unique()) <= {0, 1}, "pred must be binary"

    for col in ["Chromosome", "Start", "End"]:
        assert col in pred.columns, f"pred must have column {col}"
        # validate the argument itself, not unrelated module-level frames
        assert col in insertions.columns, f"insertions must have column {col}"

    # only consider insertions that have windows
    insertions = pr.PyRanges(insertions).overlap(pr.PyRanges(pred)).df

    y_pos = pred.loc[pred["pred"] == 1, :]

    if len(y_pos) == 0 or len(insertions) == 0:
        # nothing can overlap: avoid building empty PyRanges objects
        tp = 0
        fp = len(y_pos)
        fn = len(insertions)
    else:
        gr_insertions = pr.PyRanges(insertions)
        gr_pos = pr.PyRanges(y_pos)

        # how many insertions were detected?
        tp = len(gr_insertions.overlap(gr_pos).df)

        # how many insertions were missed?
        fn = len(gr_insertions.overlap(gr_pos, invert=True).df)

        # how many positive windows do not correspond to any insertion?
        fp = len(gr_pos.overlap(gr_insertions, invert=True).df)

    precision = tp / (tp + fp) if (tp + fp) else float("nan")
    recall = tp / (tp + fn) if (tp + fn) else float("nan")

    return precision, recall

### Use cross validation

In [None]:
# define custom splitter for AutoML split_type argument
# "A valid splitter object is an instance of a derived class of scikit-learn KFold and have split and get_n_splits methods with the same signatures. Set eval_method to "cv" to use the splitter object."
class SampleChromosomeSplitter:
    """Cross-validation splitter that holds out samples AND chromosomes.

    Every chromosome-grouped fold is crossed with every sample-grouped fold.
    The training indices of a combined fold are the rows in both training
    sets; the test indices are the rows in both test sets. Rows that are in
    train for one grouping and test for the other are left out of that fold.
    """

    def __init__(self, n_splits, X, y, sample_col):
        """
        Initialize the splitter object
        n_splits: number of folds used for each of the two groupings
        X: pandas dataframe with columns `sample_col` and "Chromosome"
        y: pandas series with labels
        sample_col: name of the column in X identifying the sample
        """
        assert sample_col in X.columns, f"X must have column {sample_col}"
        assert "Chromosome" in X.columns, "X must have column 'Chromosome'"
        self.n_splits = n_splits
        self.sample_array = X[sample_col].values
        self.chr_array = X["Chromosome"].values
        self.y = y

    def split(self, X, y=None, groups=None):
        """
        Yield (train_index, test_index) arrays for each combined fold.

        X: data to split; must have as many rows as the y given at
            construction
        y, groups: accepted for scikit-learn signature compatibility and
            ignored; the labels and groups stored at construction are used
        """
        assert (
            X.shape[0] == self.y.shape[0]
        ), "X and y must have the same number of rows"
        for chr_train_index, chr_test_index in StratifiedGroupKFold(
            n_splits=self.n_splits
        ).split(X, self.y, groups=self.chr_array):
            for sample_train_index, sample_test_index in StratifiedGroupKFold(
                n_splits=self.n_splits
            ).split(X, self.y, groups=self.sample_array):
                train_index = np.intersect1d(sample_train_index, chr_train_index)
                test_index = np.intersect1d(sample_test_index, chr_test_index)

                yield train_index, test_index

    def get_n_splits(self, X=None, y=None, groups=None):
        """
        Return the number of folds produced by split(): n_splits squared.

        X, y, groups: accepted for scikit-learn signature compatibility and
            ignored
        """
        # `**` is exponentiation; `^` is bitwise XOR (5 ^ 2 == 7, not 25)
        return self.n_splits**2


# Smoke-test the splitter: build it on the tuning data and report the
# composition of the first fold only.
sample_col = "cell_id"
splitter = SampleChromosomeSplitter(
    n_splits=5, X=tune_data, y=tune_data["label_encoded"], sample_col=sample_col
)
first_fold = next(splitter.split(tune_data[features]), None)
if first_fold is not None:
    train_index, test_index = first_fold
    train = tune_data.iloc[train_index]
    test = tune_data.iloc[test_index]

    print(f"Train: {len(train_index)}, Test: {len(test_index)}")
    print(
        f"Training on samples {train[sample_col].nunique()} and Chromosomes {','.join(train['Chromosome'].unique())}"
    )
    print(
        f"Testing on samples {test[sample_col].nunique()} and Chromosomes {','.join(test['Chromosome'].unique())}"
    )

In [None]:
# Tune with cross-validation, using the sample x chromosome splitter to
# define the folds.
# TODO: use holdout data for testing
cv_splitter = SampleChromosomeSplitter(
    n_splits=5, X=tune_data, y=tune_data["label_encoded"], sample_col="cell_id"
)

clf = AutoML()
clf.fit(
    X_train=tune_data[features],
    y_train=tune_data["label_encoded"],
    metric="f1",
    eval_method="cv",
    split_type=cv_splitter,
    log_file_name="flaml_cv.log",
    time_budget=600,
    **flaml_settings,
)

### Use single train and test sets

In [None]:
# get train and test chromosomes (first fold of a chromosome-grouped split)
chrom_sgkf = StratifiedGroupKFold(n_splits=5).split(
    tune_data, tune_data["label_encoded"], groups=tune_data["Chromosome"]
)
chr_train_index, chr_test_index = next(chrom_sgkf)
train_chrs = tune_data.iloc[chr_train_index, :]["Chromosome"].unique()
test_chrs = tune_data.iloc[chr_test_index, :]["Chromosome"].unique()
assert (
    np.intersect1d(train_chrs, test_chrs).size == 0
), "Train and test chromosomes must be mutually exclusive"
# np.intersect1d takes only two arrays (its third positional argument is
# `assume_unique`), so the held-out chromosome is checked separately.
assert (
    eval_chr not in train_chrs and eval_chr not in test_chrs
), "Evaluation chromosome must not be used for training or testing"

# get train and test cells (first fold of a cell-grouped split)
cell_sgkf = StratifiedGroupKFold(n_splits=5).split(
    tune_data, tune_data["label_encoded"], groups=tune_data["cell_id"]
)
cell_train_index, cell_test_index = next(cell_sgkf)
train_cells = tune_data.iloc[cell_train_index, :]["cell_id"].unique()
test_cells = tune_data.iloc[cell_test_index, :]["cell_id"].unique()
assert (
    np.intersect1d(train_cells, test_cells).size == 0
), "Train and test cells must be mutually exclusive"


# train on (train chromosomes x train cells), test on (test chromosomes x
# test cells); rows mixing the two are used for neither
train_df = tune_data.loc[
    (tune_data["Chromosome"].isin(train_chrs))
    & (tune_data["cell_id"].isin(train_cells)),
    :,
]
test_df = tune_data.loc[
    (tune_data["Chromosome"].isin(test_chrs)) & (tune_data["cell_id"].isin(test_cells)),
    :,
]

# keep only windows with at least 10 read 1s
train_df = train_df.loc[train_df["nr1"] >= 10, :]
test_df = test_df.loc[test_df["nr1"] >= 10, :]

In [None]:
train_chrs

In [None]:
test_chrs

In [None]:
# Tune against one fixed validation set: train_df is used for fitting and
# test_df is passed as the explicit validation data.
X_fit, y_fit = train_df[features], train_df["label_encoded"]
X_holdout, y_holdout = test_df[features], test_df["label_encoded"]

clf = AutoML()
clf.fit(
    X_train=X_fit,
    y_train=y_fit,
    X_val=X_holdout,
    y_val=y_holdout,
    metric="f1",
    time_budget=600,
    eval_method="holdout",
    log_file_name="flaml_holdout.log",
    **flaml_settings,
)

In [None]:
clf.best_config

In [None]:
# Precision-recall (left) and ROC (right) curves on the held-out chromosome,
# one curve per donor, restricted to windows with at least 10 read 1s.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
# iterate over donors actually present in the evaluation data
for d in eval_data["donor_id"].unique():
    val_data = eval_data.loc[(eval_data["donor_id"] == d) & (eval_data["nr1"] >= 10), :]
    print(f"donor {d} {eval_chr} windows: {val_data.shape[0]}")

    # both curves need positive and negative examples; skip donors that lack
    # either after filtering (this also covers an empty selection)
    if val_data["label_encoded"].nunique() < 2:
        print(f"donor {d}: fewer than two classes on {eval_chr}, skipping curves")
        continue

    PrecisionRecallDisplay.from_estimator(
        clf, val_data[features], val_data["label_encoded"], name=f"donor {d}", ax=ax[0]
    )
    RocCurveDisplay.from_estimator(
        clf, val_data[features], val_data["label_encoded"], name=f"donor {d}", ax=ax[1]
    )