In [None]:
from pathlib import Path
import json

# parallelization
from joblib import Parallel, delayed

# data science / ML
import pyarrow.parquet as pq
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedGroupKFold, StratifiedKFold
from flaml import AutoML
from scipy.stats import pearsonr, spearmanr


# plotting
import seaborn as sns
import matplotlib.pyplot as plt

# my own modules
from scripts.pyslavseq.model_selection import SampleChrSplitter, Model, parse_log

# Load the data

## Read

In [None]:
# Read every "*nonrefonly.pqt" label table, filter and downcast it, then cache
# the concatenated table as "data.pkl". Takes a long time to run.
data = []
# sorted() makes the concatenation (row) order independent of filesystem order
for f in sorted(Path("../results/model/get_labels/").rglob("*nonrefonly.pqt")):
    print(f"Reading {f}")
    ddata = pq.read_table(f).to_pandas()
    # explicit copy: the column assignments below must not write into a slice
    ddata = ddata.loc[ddata.rpm >= 2, :].copy()
    if "bulk_peaks" in ddata.columns:
        ddata = ddata.drop(columns=["bulk_peaks", "bulk_peaks_id"]).drop_duplicates()
    # convert float32 to float16 (rpm keeps its precision)
    for c in ddata.columns:
        if (ddata[c].dtype == "float32") and (c != "rpm"):
            ddata[c] = ddata[c].astype("float16")
            # the downcast must not have overflowed to inf or produced NaN
            assert not np.isinf(ddata[c]).any(), f"{c} column contains inf values"
            assert not ddata[c].isna().any(), f"{c} column contains nan values"
    data.append(ddata)

# fail with a clear message instead of pd.concat's "No objects to concatenate"
assert len(data) > 0, "no *nonrefonly.pqt files found in ../results/model/get_labels/"

data = pd.concat(data)
# every (window, cell) pair must appear exactly once
assert (
    data.shape[0]
    == data[["Chromosome", "Start", "End", "cell_id"]].drop_duplicates().shape[0]
), "some rows have been duplicated during labeling!"

# save
data.to_pickle("data.pkl")

## Load

In [None]:
# Reload the cached table from "data.pkl" (the file the read cell saves).
# NOTE(review): unpickling can execute arbitrary code — only load a data.pkl you created yourself.
data = pd.read_pickle("data.pkl")

In [None]:
# sanity check: the loaded table must contain exactly 38 distinct donor_id values
assert data.donor_id.nunique() == 38, "Wrong number of donors"

## Get metadata

In [None]:
# read metadata: per-cell sample sheet joined with donor- and tissue-level annotations
config_dir = Path("/iblm/logglun02/mcuoco/workflows/sz_slavseq/config")

meta = pd.read_csv(config_dir / "slavseq_metadata.tsv", sep="\t")
meta.columns = [col.lower() for col in meta.columns]
donors = pd.read_csv(config_dir / "all_donors.tsv", sep="\t")
cells = pd.read_csv(config_dir / "all_samples.tsv", sep="\t")

cells = pd.merge(cells, donors, on="donor_id", how="left")
# De-duplicate the tissue lookup before merging: if the metadata sheet lists a
# tissue on several rows, a left merge on tissue_id would repeat every cell once
# per matching row. NOTE(review): presumably sequencing/region are constant
# within a tissue — confirm against the metadata sheet.
tissue_meta = meta[["tissue_id", "sequencing", "region"]].drop_duplicates()
cells = pd.merge(cells, tissue_meta, on="tissue_id", how="left")

# Visualize window reads by labels

In [None]:
def compute_ecdf(data):
    """Return the values sorted in descending order together with their 1-based ranks.

    The pair (values, ranks) gives, for each value, how many observations are
    greater than or equal to it (counts, not fractions).
    """
    descending = np.flip(np.sort(data))
    ranks = np.arange(len(data)) + 1
    return descending, ranks


def label_ecdf(data, label):
    """Compute the `compute_ecdf` curve of "rpm" separately for rows where `label` is True and False.

    `label` must be a boolean column of `data`. Returns a dict with keys "pos"
    and "neg", each holding the tuple returned by `compute_ecdf`.
    """
    assert label in data.columns, f"{label} not in data"
    assert data[label].dtype == bool, f"{label} must be boolean"

    is_positive = data[label]
    curves = {}
    for key, subset in (("pos", data[is_positive]), ("neg", data[~is_positive])):
        curves[key] = compute_ecdf(subset["rpm"].values)
    return curves

In [None]:
# One row of panels per donor, one column per label; each panel overlays the
# per-cell RPM rank curves of positive and negative windows.
# Only plot labels that exist in `data`: the read cell drops "bulk_peaks", and
# label_ecdf asserts that the label column is present.
labels = [
    label for label in ["xtea", "xtea_1kb_3end", "bulk_peaks"] if label in data.columns
]
assert labels, "none of the expected label columns are present in data"
n_donors = data.donor_id.nunique()

# squeeze=False keeps axs 2-D even when there is a single donor or label
fig, axs = plt.subplots(
    n_donors, len(labels), figsize=(17, 6 * n_donors), squeeze=False
)

plt.subplots_adjust(hspace=0.3)
pos_color, neg_color = sns.color_palette()[:2]

for (d, ddata), row_axes in zip(data.groupby("donor_id"), axs):
    for ax, label in zip(row_axes, labels):
        print(f"Running donor {d}, label {label}")
        # one dict of curves per cell
        ecdf = Parallel(n_jobs=4, verbose=2)(
            delayed(label_ecdf)(df, label) for _, df in ddata.groupby("cell_id")
        )
        for e in ecdf:
            ax.plot(e["pos"][0], e["pos"][1], c=pos_color, alpha=0.5)
            ax.plot(e["neg"][0], e["neg"][1], c=neg_color, alpha=0.5)

        ax.set_yscale("log")
        ax.set_xscale("log")
        ax.set_xlabel("RPM")
        ax.set_ylabel("Count")
        # the first two lines drawn are the positive and negative curve of the first cell
        ax.legend(["True", "False"]).set_title(label)
        ax.set_title(f"Donor {d}")
plt.show()

# Test the model

In [None]:
# define features: every column whose name contains one of `keys`, plus "rpm";
# "_score"/"_length" columns are only eligible in their "_normed" form
features = []
keys = ["_mean", "frac", "gini", "bias"]
for c in data.columns:
    if (("_score" in c) or ("_length" in c)) and ("_normed" not in c):
        continue
    # any() adds a column once even when its name matches several keys
    if any(k in c for k in keys):
        features.append(c)
if "rpm" not in features:
    features.append("rpm")
print("Features:", features)

# classifier: FLAML AutoML search restricted to XGBoost, optimizing the "ap" metric
clf = AutoML(
    task="classification",
    metric="ap",
    estimator_list=["xgboost"],
    eval_method="cv",
    early_stop=True,
    time_budget=120,  # time budget in seconds
    starting_points="static",  # use data-independent hyperparameter starting points
    skip_transform=True,  # don't preprocess data
    auto_augment=False,  # don't augment rare classes
    verbose=4,
)

# make sure the directory that receives the model logs exists
Path("model_logs").mkdir(exist_ok=True)

## Naive splits

In [None]:
# helper functions
def cv(
    clf: AutoML, data: pd.DataFrame, features: list, label_col: str, rpm_filter: int
):
    """Run 5-fold StratifiedKFold cross-validation of `clf` on rows of `data`.

    Rows are split without regard to cell or chromosome; folds are stratified
    on the model's label column. Returns a flat list holding, for each fold,
    the train metrics followed by the test metrics, each tagged with a 1-based
    "fold" and merged with the model metrics returned by `Model.fit`.
    """
    mdl = Model(
        data=data, features=features, label_col=label_col, rpm_filter=rpm_filter
    )
    folds = StratifiedKFold(n_splits=5).split(mdl.data, mdl.data[mdl.label_col])
    fit_kwargs = dict(sample_col="cell_id", n_chr_splits=2, n_sample_splits=2)

    result = []
    for i, (train_idx, test_idx) in enumerate(folds):
        print(f"Fold {i+1}")
        train_metrics, test_metrics, model_metrics = mdl.fit(
            train_idx, test_idx, clf, **fit_kwargs
        )
        for metrics in (train_metrics, test_metrics):
            metrics["fold"] = i + 1
        for metrics in (train_metrics, test_metrics):
            metrics.update(model_metrics)
            result.append(metrics)

    return result


# restrict to the CommonBrain donor and collect the per-fold metrics in a DataFrame
ddata = data.loc[data.donor_id == "CommonBrain"]
out = pd.DataFrame(cv(clf, ddata, features, label_col="xtea_1kb_3end", rpm_filter=5))

In [None]:
# precision vs. adjusted locus recall: one line per fold, one panel per stage
sns.relplot(
    out.explode(["precision", "adjusted_locus_recall"]),
    x="adjusted_locus_recall",
    y="precision",
    hue="fold",
    col="stage",
    kind="line",
).set(xlim=(0, 1), ylim=(0, 1))

## Split by cells

In [None]:
# helper functions
def cv(
    clf: AutoML, data: pd.DataFrame, features: list, label_col: str, rpm_filter: int
):
    """Run 5-fold StratifiedGroupKFold cross-validation of `clf`, grouping rows by cell.

    Rows sharing a "cell_id" are kept in the same fold; folds are stratified on
    the model's label column. Returns a flat list holding, for each fold, the
    train metrics followed by the test metrics, each tagged with a 1-based
    "fold" and merged with the model metrics returned by `Model.fit`.
    """
    mdl = Model(
        data=data, features=features, label_col=label_col, rpm_filter=rpm_filter
    )
    folds = StratifiedGroupKFold(n_splits=5).split(
        mdl.data, mdl.data[mdl.label_col], groups=mdl.data["cell_id"]
    )
    fit_kwargs = dict(sample_col="cell_id", n_chr_splits=2, n_sample_splits=2)

    result = []
    for i, (train_idx, test_idx) in enumerate(folds):
        print(f"Fold {i+1}")
        train_metrics, test_metrics, model_metrics = mdl.fit(
            train_idx, test_idx, clf, **fit_kwargs
        )
        for metrics in (train_metrics, test_metrics):
            metrics["fold"] = i + 1
        for metrics in (train_metrics, test_metrics):
            metrics.update(model_metrics)
            result.append(metrics)

    return result


# restrict to the CommonBrain donor and collect the per-fold metrics in a DataFrame
ddata = data.loc[data.donor_id == "CommonBrain"]
out = pd.DataFrame(cv(clf, ddata, features, label_col="xtea_1kb_3end", rpm_filter=5))

In [None]:
# precision vs. adjusted locus recall: one line per fold, one panel per stage
sns.relplot(
    out.explode(["precision", "adjusted_locus_recall"]),
    x="adjusted_locus_recall",
    y="precision",
    hue="fold",
    col="stage",
    kind="line",
).set(xlim=(0, 1), ylim=(0, 1))

## Split by chromosome

### Optimize scale_pos_weight

In [None]:
def run_cv(scale_pos_weight: int = 1):
    """Cross-validate the CommonBrain donor with a given `scale_pos_weight` setting.

    The value is written into `clf._settings["scale_pos_weight"]` before the
    model is built. Results of a 5-split `Model.cv` run are returned with an
    added "scale_pos_weight" column; the log goes to
    "model_logs/CommonBrain_weight{scale_pos_weight}.log".

    NOTE: this mutates the notebook-level `clf`. When called outside a
    process-based Parallel run, the setting persists for later cells.
    """
    # define the data
    ddata = data[data.donor_id == "CommonBrain"]

    # add scale_pos_weight to settings
    clf._settings["scale_pos_weight"] = scale_pos_weight

    mdl = Model(
        clf=clf,
        data=ddata,
        features=features,
        label_col="xtea_1kb_3end",
        rpm_filter=5,
        outfile=f"model_logs/CommonBrain_weight{scale_pos_weight}.log",
    )
    mdl.cv(n_splits=5)
    out = mdl.get_results()
    out["scale_pos_weight"] = scale_pos_weight
    return out


# one worker per candidate weight; stack the per-weight result tables
results = pd.concat(
    Parallel(n_jobs=6, verbose=2)(
        delayed(run_cv)(w) for w in [1, 10, 20, 50, 100, 200]
    )
)

In [None]:
# precision vs. adjusted locus recall: one row of panels per scale_pos_weight,
# one column per stage, one line per fold
sns.relplot(
    results.explode(["precision", "adjusted_locus_recall"]),
    x="adjusted_locus_recall",
    y="precision",
    hue="fold",
    col="stage",
    row="scale_pos_weight",
    kind="line",
).set(xlim=(0, 1), ylim=(0, 1))

### On all data together

In [None]:
# takes 30 min
# 5-split Model.cv on the full table (no donor filter), then plot
# precision vs. adjusted locus recall per fold and stage.
# NOTE(review): the log file name mentions "41_posweight" and "no_static", while
# clf is created in this notebook with starting_points="static" — confirm which
# settings actually produced this log.
mdl = Model(
    clf=clf,
    data=data,
    features=features,
    label_col="xtea_1kb_3end",
    rpm_filter=5,
    outfile="model_logs/all_120_budget_41_posweight_no_static_concurrent.log",
)
mdl.cv(n_splits=5)
results = mdl.get_results()
sns.relplot(
    results.explode(["precision", "adjusted_locus_recall"]),
    x="adjusted_locus_recall",
    y="precision",
    hue="fold",
    col="stage",
    kind="line",
).set(xlim=(0, 1), ylim=(0, 1))

### On seq platforms separately

In [None]:
def run_cv(platform: str, cells: list):
    """Cross-validate on the rows of `data` whose cell_id is in `cells`.

    Runs a 5-split `Model.cv`, logging to "model_logs/{platform}.log", and
    returns the results with an added "platform" column.
    """
    in_platform = data.cell_id.isin(cells)
    mdl = Model(
        clf=clf,
        data=data[in_platform],
        features=features,
        label_col="xtea_1kb_3end",
        rpm_filter=5,
        outfile=f"model_logs/{platform}.log",
    )
    mdl.cv(n_splits=5)

    out = mdl.get_results()
    out["platform"] = platform
    return out


# one worker per sequencing platform, each given that platform's sample ids
results = pd.concat(
    Parallel(n_jobs=2, verbose=2)(
        delayed(run_cv)(
            platform, cells.loc[cells["sequencing"] == platform, "sample_id"].unique()
        )
        for platform in ["NOVASEQ", "HISEQ"]
    )
)

In [None]:
# precision vs. adjusted locus recall: one row of panels per platform,
# one column per stage, one line per fold
g = sns.relplot(
    results.explode(["precision", "adjusted_locus_recall"]),
    x="adjusted_locus_recall",
    y="precision",
    hue="fold",
    col="stage",
    row="platform",
    kind="line",
)
g.set(xlim=(0, 1), ylim=(0, 1), xlabel="Adjusted locus recall", ylabel="Precision")

### On each tissue separately

In [None]:
def run_cv(tissue_id: str, cells: list):
    """Cross-validate on the rows of `data` whose cell_id is in `cells`.

    Runs a 5-split `Model.cv`, logging to "model_logs/{tissue_id}_scale_pos.log",
    and returns the results with an added "tissue_id" column.
    """
    in_tissue = data.cell_id.isin(cells)
    mdl = Model(
        clf=clf,
        data=data[in_tissue],
        features=features,
        label_col="xtea_1kb_3end",
        rpm_filter=5,
        outfile=f"model_logs/{tissue_id}_scale_pos.log",
    )
    mdl.cv(n_splits=5)

    out = mdl.get_results()
    out["tissue_id"] = tissue_id
    return out


# one job per tissue, each given the sample ids belonging to that tissue
results = pd.concat(
    Parallel(n_jobs=32, verbose=2)(
        delayed(run_cv)(tissue_id, df.sample_id.values)
        for tissue_id, df in cells.groupby("tissue_id")
    )
)

In [None]:
# Load the per-tissue model logs and annotate them with tissue-level metadata.
# Log files are named "{tissue_id}_scale_pos.log" by the per-tissue run; plain
# "{tissue_id}.log" files are accepted as well.
suffix = "_scale_pos"
known_tissues = set(cells.tissue_id.unique())

# pick one log per tissue, preferring the "_scale_pos" variant when both exist
log_files = {}
for file in sorted(Path("model_logs").rglob("*.log")):
    stem = file.stem
    has_suffix = stem.endswith(suffix)
    # slice the suffix off: str.rstrip() strips a *set of characters* and would
    # also eat trailing letters of the tissue id itself
    tissue_id = stem[: -len(suffix)] if has_suffix else stem
    if tissue_id not in known_tissues:
        continue
    if has_suffix or tissue_id not in log_files:
        log_files[tissue_id] = file

results = []
for tissue_id, file in log_files.items():
    with open(file) as f:
        log = json.load(f)
    log = parse_log(log)
    log["tissue_id"] = tissue_id
    results.append(log)
assert results, "no per-tissue logs found in model_logs/"
results = pd.concat(results)

# one metadata row per tissue: drop the per-cell columns, then de-duplicate
tissues = (
    cells.drop(columns=["sample_id", "R1", "R2", "xtea"])
    .drop_duplicates()
    .set_index("tissue_id")
)
results = results.set_index("tissue_id").join(tissues, how="left").reset_index()

In [None]:
# test-stage rows only: flatten the per-row precision/recall lists, then average
# precision, adjusted locus recall and total_loci_train within each tissue
# (the extra grouping keys carry the tissue metadata into the result)
df = (
    results[results["stage"] == "test"]
    .explode(["precision", "adjusted_locus_recall"])
    .groupby(
        ["tissue_id", "stage", "sequencing", "race", "diagnosis", "age", "region"]
    )[["precision", "adjusted_locus_recall", "total_loci_train"]]
    .mean()
    .reset_index()
)

In [None]:
# evaluate knrgl in training set: three panels relating per-tissue mean
# performance to total_loci_train
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(17, 5))

# panel 1: precision vs. recall, coloured by total_loci_train
g = sns.scatterplot(
    df, x="adjusted_locus_recall", y="precision", hue="total_loci_train", ax=ax1
)
g.set(xlabel="Mean Adjusted Locus Recall", ylabel="Mean Precision")
g.set_title("Mean XGBoost performance across 5-fold Chromosome CV for each tissue")
g.legend(title="Mean KNRGL in training set")

# panel 2: precision vs. total_loci_train
g = sns.scatterplot(df, y="precision", x="total_loci_train", ax=ax2)
g.set(xlabel="Mean KNRGL in training set", ylabel="Mean Precision")

# add correlation coefficient from scipy (text placed in axes coordinates)
r, p = pearsonr(df["total_loci_train"], df["precision"])
g.text(0.05, 0.15, f"Pearson: r = {r:.2f}, p = {p:.2e}", transform=g.transAxes)
r, p = spearmanr(df["total_loci_train"], df["precision"])
g.text(0.05, 0.1, f"Spearman: r = {r:.2f}, p = {p:.2e}", transform=g.transAxes)


# panel 3: adjusted locus recall vs. total_loci_train
g = sns.scatterplot(df, y="adjusted_locus_recall", x="total_loci_train", ax=ax3)
g.set(xlabel="Mean KNRGL in training set", ylabel="Mean Adjusted Locus Recall")

# add correlation coefficient from scipy (text placed in axes coordinates)
r, p = pearsonr(df["total_loci_train"], df["adjusted_locus_recall"])
g.text(0.05, 0.15, f"Pearson: r = {r:.2f}, p = {p:.2e}", transform=g.transAxes)
r, p = spearmanr(df["total_loci_train"], df["adjusted_locus_recall"])
g.text(0.05, 0.1, f"Spearman: r = {r:.2f}, p = {p:.2e}", transform=g.transAxes)

In [None]:
# the five tissues with the lowest mean precision
df.sort_values("precision").head()

CommonBrain results

In [None]:
# precision vs. adjusted locus recall for the CommonBrain tissue only:
# one line per fold, one panel per stage
sns.relplot(
    results[results["tissue_id"] == "CommonBrain"].explode(
        ["precision", "adjusted_locus_recall"]
    ),
    x="adjusted_locus_recall",
    y="precision",
    hue="fold",
    col="stage",
    kind="line",
).set(xlim=(0, 1), ylim=(0, 1))

Calculate average KNRGL coverage per tissue

In [None]:
# Per-cell KNRGL coverage: for every coverage table, count the intervals with
# zero reads, then average the per-cell numbers within each donor/region.
cov = []
for f in Path("../results/qc/l1_coverage").rglob("*/*xtea_1kb_3end.r1.txt"):
    # local name chosen so the notebook-level `df` is not overwritten
    knrgl = pd.read_csv(
        f,
        sep="\t",
        header=None,
        names=[
            "Chromosome",
            "Start",
            "End",
            "Name",
            "Score",
            "Strand",
            "n_reads",
            "n_bases_overlapped",
            "l1_length",
            "frac_overlap",
        ],
    )
    total = knrgl.shape[0]
    zeros = int((knrgl["n_reads"] == 0).sum())
    cell_id = f.stem.replace(".xtea_1kb_3end.r1", "")
    cov.append(
        {
            "knrgl_zeros": zeros,
            "cell_id": cell_id,
            # the parent directory name is used as the donor id
            "donor_id": f.parent.name,
            "region": "HIPPO" if "USH" in cell_id.upper() else "DLPFC",
            "total_knrgl": total,
            "knrgl_covered": total - zeros,
            # an empty table has no defined fraction; avoid ZeroDivisionError
            "frac_missing": zeros / total if total else np.nan,
        }
    )

assert cov, "no coverage files found in ../results/qc/l1_coverage"
cov = pd.DataFrame(cov)
cov = (
    cov.groupby(["donor_id", "total_knrgl", "region"])[
        ["knrgl_zeros", "frac_missing", "knrgl_covered"]
    ]
    .mean()
    .sort_values(by=["knrgl_covered"], ascending=False)
    .reset_index()
)

In [None]:
# plot total_knrgl per donor, sorted by total_knrgl
g = sns.barplot(
    y="donor_id", x="total_knrgl", data=cov.sort_values("total_knrgl"), color="gray"
)
g.set(xlabel="Total KNRGL detected from WGS", ylabel="Donor ID")
# set the figure size to 5 x 7 inches
g.figure.set_size_inches(5, 7)

In [None]:
# attach the per-donor/region coverage summary (cov) to the results table
# NOTE(review): not idempotent — `results` is overwritten, so re-running this
# cell joins cov onto a table that already has its columns and will likely fail.
results = (
    results.set_index(["donor_id", "region"])
    .join(cov.set_index(["donor_id", "region"]), how="left")
    .reset_index()
)

In [None]:
# test-stage rows only: flatten the per-row precision/recall lists, then average
# precision and adjusted locus recall within each tissue; the extra grouping
# keys carry the metadata and coverage columns into the result
df = (
    results[results["stage"] == "test"]
    .explode(["precision", "adjusted_locus_recall"])
    .groupby(
        [
            "tissue_id",
            "stage",
            "region",
            "diagnosis",
            "race",
            "sex",
            "sequencing",
            "donor_id",
            "knrgl_covered",
            "total_knrgl",
            "libd_id",
        ]
    )[["precision", "adjusted_locus_recall"]]
    .mean()
    .reset_index()
)

In [None]:
# one point per row of df (per-tissue means): precision vs. adjusted locus recall
g = sns.scatterplot(df, x="adjusted_locus_recall", y="precision")
g.set(xlabel="Mean Adjusted Locus Recall", ylabel="Mean Precision")
g.set_title("Mean XGBoost performance across 5-fold Chromosome CV for each tissue")

In [None]:
# same scatter, coloured by knrgl_covered; legend placed outside the axes
g = sns.scatterplot(df, x="adjusted_locus_recall", y="precision", hue="knrgl_covered")
g.set(xlabel="Mean Adjusted Locus Recall", ylabel="Mean Precision")
g.set_title("Mean XGBoost performance across 5-fold Chromosome CV for each tissue")
g.legend(
    title="Mean KNRGL Covered / cell",
    bbox_to_anchor=(1.05, 1),
    loc=2,
    borderaxespad=0.0,
)

In [None]:
# same scatter, coloured by total_knrgl; legend placed outside the axes
g = sns.scatterplot(df, x="adjusted_locus_recall", y="precision", hue="total_knrgl")
g.set(xlabel="Mean Adjusted Locus Recall", ylabel="Mean Precision")
g.set_title("Mean XGBoost performance across 5-fold Chromosome CV for each tissue")
g.legend(
    title="Total KNRGL detected from WGS",
    bbox_to_anchor=(1.05, 1),
    loc=2,
    borderaxespad=0.0,
)

In [None]:
# plotly express: interactive version of the precision/recall scatter,
# coloured by knrgl_covered with tissue details on hover
# NOTE(review): this import sits mid-notebook; consider moving it to the import cell.
import plotly.express as px

px.scatter(
    df,
    x="adjusted_locus_recall",
    y="precision",
    color="knrgl_covered",
    hover_name="tissue_id",
    hover_data=[
        "tissue_id",
        "knrgl_covered",
        "adjusted_locus_recall",
        "precision",
        "libd_id",
    ],
)

### How does performance change with different cutoffs?

In [None]:
# helper functions
def cv(
    clf: AutoML, data: pd.DataFrame, features: list, label_col: str, rpm_filter: int
):
    """Cross-validate `clf` over the folds produced by `SampleChrSplitter`.

    The splitter and `Model.fit` both receive sample_col="cell_id",
    n_chr_splits=4 and n_sample_splits=2. Returns a flat list holding, for each
    fold, the train metrics followed by the test metrics, each tagged with a
    1-based "fold" and merged with the model metrics returned by `Model.fit`.
    """
    split_kwargs = dict(sample_col="cell_id", n_chr_splits=4, n_sample_splits=2)

    mdl = Model(
        data=data, features=features, label_col=label_col, rpm_filter=rpm_filter
    )
    splitter = SampleChrSplitter(X=mdl.data, y=mdl.data[label_col], **split_kwargs)

    result = []
    for i, (train_idx, test_idx) in enumerate(splitter.split(mdl.data)):
        print(f"Fold {i+1}")
        train_metrics, test_metrics, model_metrics = mdl.fit(
            train_idx, test_idx, clf, **split_kwargs
        )
        for metrics in (train_metrics, test_metrics):
            metrics["fold"] = i + 1
        for metrics in (train_metrics, test_metrics):
            metrics.update(model_metrics)
            result.append(metrics)

    return result

In [None]:
# baseline run on the full table with the rpm_filter=5 cutoff
out = cv(
    clf=clf,
    data=data,
    features=features,
    label_col="xtea_1kb_3end",
    rpm_filter=5,
)

In [None]:
# precision vs. adjusted locus recall: one line per fold, one panel per stage
sns.relplot(
    pd.DataFrame(out).explode(["precision", "adjusted_locus_recall"]),
    x="adjusted_locus_recall",
    y="precision",
    hue="fold",
    col="stage",
    kind="line",
).set(xlim=(0, 1), ylim=(0, 1))

In [None]:
# repeat the cross-validation for several rpm_filter cutoffs and stack the results
result = []
for rpm in [2, 3, 5, 7, 10]:
    print(f"Running {rpm} RPM")

    out = pd.DataFrame(
        cv(
            clf=clf,
            data=data,
            features=features,
            label_col="xtea_1kb_3end",
            rpm_filter=rpm,
        )
    )
    out["rpm"] = rpm  # tag every row with the cutoff that produced it
    result.append(out)

result = pd.concat(result)

In [None]:
# precision vs. adjusted locus recall: one line per rpm cutoff,
# one column of panels per fold, one row per stage
sns.relplot(
    result.explode(["precision", "adjusted_locus_recall"]),
    x="adjusted_locus_recall",
    y="precision",
    hue="rpm",
    col="fold",
    row="stage",
    kind="line",
).set(xlim=(0, 1), ylim=(0, 1))

In [None]:
# heatmap of feature importances: mean importance per (rpm, feature),
# pivoted to features (rows) x rpm cutoffs (columns)
fresult = (
    result.explode(["features", "feature_importances"])
    .groupby(["rpm", "features"])
    .agg({"feature_importances": "mean"})
    .reset_index()
    .pivot(index="features", columns="rpm", values="feature_importances")
)
sns.heatmap(fresult.astype(float)).set(xlabel="rpm filter", ylabel=None)

In [None]:
# same sweep over a wider range of rpm_filter cutoffs
result = []
for rpm in [2, 5, 10, 20, 50, 100]:
    print(f"Running {rpm} RPM")

    out = pd.DataFrame(
        cv(
            clf=clf,
            data=data,
            features=features,
            label_col="xtea_1kb_3end",
            rpm_filter=rpm,
        )
    )
    out["rpm"] = rpm  # tag every row with the cutoff that produced it
    result.append(out)

result = pd.concat(result)

In [None]:
# precision vs. adjusted locus recall: one line per rpm cutoff,
# one column of panels per fold, one row per stage
sns.relplot(
    result.explode(["precision", "adjusted_locus_recall"]),
    x="adjusted_locus_recall",
    y="precision",
    hue="rpm",
    col="fold",
    row="stage",
    kind="line",
).set(xlim=(0, 1), ylim=(0, 1))

In [None]:
# heatmap of feature importances: mean importance per (rpm, feature),
# pivoted to features (rows) x rpm cutoffs (columns)
fresult = (
    result.explode(["features", "feature_importances"])
    .groupby(["rpm", "features"])
    .agg({"feature_importances": "mean"})
    .reset_index()
    .pivot(index="features", columns="rpm", values="feature_importances")
)
sns.heatmap(fresult.astype(float)).set(xlabel="rpm filter", ylabel=None)

### How does performance change with different metrics for optimization?

In [None]:
# define the features: every column whose name contains one of `keys`, plus
# "rpm"; "_score" columns are only eligible in their "_normed" form
features = []
keys = ["_mean", "frac", "gini", "bias"]
for c in data.columns:
    if ("_score" in c) and ("_normed" not in c):
        continue
    # any() adds a column once even when its name matches several keys
    if any(k in c for k in keys):
        features.append(c)
if "rpm" not in features:
    features.append("rpm")
print("Features:", features)

# classifier: FLAML AutoML search restricted to XGBoost, optimizing "ap", silent logging
clf = AutoML(
    task="classification",
    metric="ap",
    estimator_list=["xgboost"],
    eval_method="cv",
    early_stop=True,
    time_budget=120,  # time budget in seconds
    starting_points="static",  # use data-independent hyperparameter starting points
    skip_transform=True,  # don't preprocess data
    auto_augment=False,  # don't augment rare classes
    verbose=0,
)

In [None]:
# helper functions
def cv(
    clf: AutoML, data: pd.DataFrame, features: list, label_col: str, rpm_filter: int
):
    """Cross-validate `clf` over the folds produced by a 4x4 `SampleChrSplitter`.

    Returns a flat list holding, for each fold, the train metrics followed by
    the test metrics, each tagged with a 1-based "fold" and merged with the
    model metrics returned by `Model.fit`.

    NOTE(review): the splitter is built with n_chr_splits=4 / n_sample_splits=4
    while `Model.fit` is passed 2 / 2 — confirm the mismatch is intended.
    """
    mdl = Model(
        data=data, features=features, label_col=label_col, rpm_filter=rpm_filter
    )
    splitter = SampleChrSplitter(
        X=mdl.data,
        y=mdl.data[label_col],
        sample_col="cell_id",
        n_chr_splits=4,
        n_sample_splits=4,
    )
    fit_kwargs = dict(sample_col="cell_id", n_chr_splits=2, n_sample_splits=2)

    result = []
    for i, (train_idx, test_idx) in enumerate(splitter.split(mdl.data)):
        print(f"Fold {i+1}")
        train_metrics, test_metrics, model_metrics = mdl.fit(
            train_idx, test_idx, clf, **fit_kwargs
        )
        for metrics in (train_metrics, test_metrics):
            metrics["fold"] = i + 1
        for metrics in (train_metrics, test_metrics):
            metrics.update(model_metrics)
            result.append(metrics)

    return result

In [None]:
# cross-validate once per optimization metric and stack the results
result = []
for m in ["ap", "f1"]:

    # fresh AutoML object per metric; only `metric` differs between runs
    clf = AutoML(
        task="classification",
        metric=m,
        estimator_list=["xgboost"],
        eval_method="cv",
        early_stop=True,
        time_budget=120,  # time budget in seconds
        starting_points="static",  # use data-independent hyperparameter starting points
        skip_transform=True,  # don't preprocess data
        auto_augment=False,  # don't augment rare classes
        verbose=0,
    )

    out = pd.DataFrame(
        cv(
            clf=clf,
            data=data,
            features=features,
            label_col="xtea_1kb_3end",
            rpm_filter=5,
        )
    )
    out["metric"] = m  # tag every row with the metric that was optimized
    result.append(out)

result = pd.concat(result)

In [None]:
# precision vs. adjusted locus recall: one line per optimization metric, one panel per stage
sns.relplot(
    result.explode(["precision", "adjusted_locus_recall"]),
    x="adjusted_locus_recall",
    y="precision",
    hue="metric",
    col="stage",
    kind="line",
).set(xlim=(0, 1), ylim=(0, 1))

### How does performance change with increasing numbers of donors?

In [None]:
# Train on a growing number of donors and always test on the last donor, using a
# 2-fold chromosome-grouped split so train and test rows never share a chromosome.
# (StratifiedGroupKFold is already imported in the notebook's import cell.)

# initialize my custom model class
print("Initializing Model object")
mdl = Model(data=data, features=features, label_col="xtea", rpm_filter=5)

# initialize splitter
sgkf = StratifiedGroupKFold(n_splits=2)

donor_ids = mdl.data.donor_id.unique()
test_donor = donor_ids[-1]
train_donors = donor_ids[:-1]

# split() yields *positional* row indices, so the donor masks must be turned
# into positions as well; index labels need not match positions in mdl.data
donor_col = mdl.data["donor_id"]
test_donor_idx = np.flatnonzero((donor_col == test_donor).to_numpy())

# the chromosome folds do not depend on n_donors: compute them once
chr_folds = list(sgkf.split(mdl.data, mdl.data["xtea"], mdl.data["Chromosome"]))

result = []
for n_donors in range(1, mdl.data.donor_id.nunique()):
    print(f"Running n_donors: {n_donors}")
    train_donor_idx = np.flatnonzero(
        donor_col.isin(train_donors[0:n_donors]).to_numpy()
    )
    assert (
        len(np.intersect1d(train_donor_idx, test_donor_idx)) == 0
    ), "Donors in train and test set overlap"
    for i, (train_chrs, test_chrs) in enumerate(chr_folds):
        print(f"Fold {i+1}")
        # rows that are in both the selected donors and the fold's chromosomes
        train_idx = np.intersect1d(train_donor_idx, train_chrs)
        test_idx = np.intersect1d(test_donor_idx, test_chrs)
        train_metrics, test_metrics, model_metrics = mdl.fit(
            train_idx,
            test_idx,
            clf,
            sample_col="cell_id",
            n_chr_splits=2,
            n_sample_splits=2,
        )
        for metrics in [train_metrics, test_metrics]:
            metrics["fold"] = i + 1
            metrics["n_donors"] = n_donors
            metrics.update(model_metrics)
            result.append(metrics)
result = pd.DataFrame(result)

In [None]:
# precision vs. adjusted locus recall: one line per number of training donors,
# one column of panels per fold, one row per stage
sns.relplot(
    result.explode(["precision", "adjusted_locus_recall"]),
    x="adjusted_locus_recall",
    y="precision",
    hue="n_donors",
    col="fold",
    row="stage",
    kind="line",
).set(xlim=(0, 1), ylim=(0, 1))

In [None]:
# heatmap of feature importances: mean importance per (n_donors, feature),
# pivoted to features (rows) x n_donors (columns)
fresult = (
    result.explode(["features", "feature_importances"])
    .groupby(["n_donors", "features"])
    .agg({"feature_importances": "mean"})
    .reset_index()
    .pivot(index="features", columns="n_donors", values="feature_importances")
)
sns.heatmap(fresult.astype(float)).set(xlabel="n_donors", ylabel=None)

### How does performance change with increasing numbers of chromosomes?

In [None]:
# Train on a growing set of chromosome groups while testing on one fixed group,
# crossed with a 4-fold split over cells.

# initialize my custom model class
print("Initializing Model object")
mdl = Model(data=data, features=features, label_col="xtea", rpm_filter=5)

chr_sgkf = StratifiedGroupKFold(n_splits=5).split(
    mdl.data, mdl.data["xtea"], mdl.data["Chromosome"]
)
cell_sgkf = StratifiedGroupKFold(n_splits=4)

result = []
# the held-out rows of the first chromosome fold are the fixed test set
test_chrs_idx = next(chr_sgkf)[1]
# must be an integer array: np.array([]) is float64, np.append would then
# return float indices, and .iloc rejects those
train_chrs_idx = np.array([], dtype=int)
# each remaining fold's held-out rows are added to the training chromosomes
for i, (_, chrs_idx) in enumerate(chr_sgkf):
    print(f"Running chr group {i+1}")
    train_chrs_idx = np.append(train_chrs_idx, chrs_idx)
    print(f"Training on {mdl.data.iloc[train_chrs_idx,:].Chromosome.unique()}")
    print(f"Testing on {mdl.data.iloc[test_chrs_idx,:].Chromosome.unique()}")
    for j, (train_cell_idx, test_cell_idx) in enumerate(
        cell_sgkf.split(mdl.data, mdl.data["xtea"], mdl.data["cell_id"])
    ):
        print(f"Fold {j+1}")
        # rows that are in both the cell fold and the chromosome selection
        train_idx = np.intersect1d(train_cell_idx, train_chrs_idx)
        test_idx = np.intersect1d(test_cell_idx, test_chrs_idx)
        train_metrics, test_metrics, model_metrics = mdl.fit(
            train_idx,
            test_idx,
            clf,
            sample_col="cell_id",
            n_chr_splits=2,
            n_sample_splits=2,
        )
        for metrics in [train_metrics, test_metrics]:
            metrics["fold"] = j + 1
            metrics["chr_groups"] = i + 1
            metrics.update(model_metrics)
            result.append(metrics)

result = pd.DataFrame(result)

In [None]:
# precision vs. adjusted locus recall: one line per number of training chromosome
# groups, one column of panels per fold, one row per stage
sns.relplot(
    result.explode(["precision", "adjusted_locus_recall"]),
    x="adjusted_locus_recall",
    y="precision",
    hue="chr_groups",
    col="fold",
    row="stage",
    kind="line",
).set(xlim=(0, 1), ylim=(0, 1))

In [None]:
# heatmap of feature importances: mean importance per (chr_groups, feature),
# pivoted to features (rows) x chr_groups (columns)
fresult = (
    result.explode(["features", "feature_importances"])
    .groupby(["chr_groups", "features"])
    .agg({"feature_importances": "mean"})
    .reset_index()
    .pivot(index="features", columns="chr_groups", values="feature_importances")
)
sns.heatmap(fresult.astype(float)).set(xlabel="chr_groups", ylabel=None)