In [1]:
import sys, pickle
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

## Read in data

In [None]:
# Load the pickled per-fold table handed over by the Snakemake rule.
folds_df = pd.read_pickle(snakemake.input.folds)

# Distinct labels (order of first appearance) and their count; the plotting
# cells reuse both to keep hue order and colors the same across figures.
labels = folds_df["label"].unique()
nlabels = folds_df["label"].nunique()
sns.set_palette(sns.husl_palette(nlabels))

# Keep only the rows whose stage is "test".
test_df = folds_df.loc[folds_df["stage"] == "test"]

## Label Distributions in test data

### Per donor

In [None]:
# Count test rows for every (label, donor) pair and draw them as horizontal
# bars grouped by donor, one bar per label.
plot_df = test_df.value_counts(["label", "donor_id"]).to_frame("count").reset_index()
fig = sns.barplot(
    plot_df,
    x="count",
    y="donor_id",
    hue="label",
    hue_order=labels,
    # husl_palette() already returns a palette, so it is passed directly
    # (same form as the other plotting cells).
    palette=sns.husl_palette(nlabels),
)
# `fig` holds the Axes returned by sns.barplot; put its legend outside the
# plotting area, anchored at the top-right corner of the axes.
sns.move_legend(fig, "upper left", bbox_to_anchor=(1, 1))
sns.despine()

### Per cell

In [None]:
# One row per (label, donor, cell) with the number of test rows it holds.
plot_df = test_df.value_counts(["label", "donor_id", "cell_id"]).reset_index(
    name="count"
)

# Distribution of the per-cell counts for each donor, split by label.
fig = sns.boxplot(
    data=plot_df,
    y="donor_id",
    x="count",
    hue="label",
    palette=sns.husl_palette(nlabels),
    hue_order=labels,
)
fig.set(xscale="log")  # log-scale the count axis
sns.move_legend(fig, "upper left", bbox_to_anchor=(1, 1))
sns.despine()

## Per donor per fold

In [None]:
# Row counts for every (label, donor, fold, stage) combination in all folds.
plot_df = folds_df.value_counts(["label", "donor_id", "fold", "stage"]).reset_index(
    name="count"
)

# One panel per (stage, fold); axes are independent between panels.
fig = sns.FacetGrid(
    data=plot_df,
    row="stage",
    col="fold",
    sharex=False,
    sharey=False,
)
fig.map_dataframe(
    sns.barplot,
    y="donor_id",
    x="count",
    hue="label",
    palette=sns.husl_palette(nlabels),
    hue_order=labels,
)
fig.add_legend()
sns.despine()

## Per donor per fold per cell

In [None]:
def logplot(**kwargs):
    """Draw a seaborn boxplot and switch its x axis to a log scale.

    The data to plot is taken from the required ``data`` keyword (a KeyError
    is raised if it is missing); every remaining keyword is forwarded
    unchanged to ``sns.boxplot``.
    """
    frame = kwargs.pop("data")
    sns.boxplot(frame, **kwargs).set_xscale("log")


# Row counts for every (label, donor, cell, fold, stage) combination.
plot_df = folds_df.value_counts(
    ["label", "donor_id", "cell_id", "fold", "stage"]
).reset_index(name="count")

# One log-scaled boxplot panel per (stage, fold); axes are independent.
fig = sns.FacetGrid(
    data=plot_df,
    row="stage",
    col="fold",
    sharex=False,
    sharey=False,
)
fig.map_dataframe(
    logplot,
    y="donor_id",
    x="count",
    hue="label",
    palette=sns.husl_palette(nlabels),
    hue_order=labels,
)
fig.add_legend()
sns.despine()

## Feature PCA

In [None]:
# Downsample every class in the test set to the same number of rows: the size
# of the smallest class, capped at 10,000 rows per class.
from imblearn.under_sampling import RandomUnderSampler

# Count rows per label once. Only labels that actually occur in the test set
# are kept, so a label that exists elsewhere in folds_df (or an unused
# categorical level) cannot end up in the sampling dictionary.
class_counts = test_df["label"].value_counts()
class_counts = class_counts[class_counts > 0]

down_to = min(class_counts.min(), 10000)
sample_dict = {label: down_to for label in class_counts.index}

# Fixed random_state keeps the selected rows the same between runs.
df, _ = RandomUnderSampler(sampling_strategy=sample_dict, random_state=42).fit_resample(
    test_df, test_df["label"]
)

In [None]:
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

# make input data: drop the listed coordinate / identifier / label columns;
# every remaining column is standardized and used as PCA input
pca_df = df.drop(
    ["chrom", "start", "end", "donor_id", "cell_id", "label", "build", "db", "stage"],
    axis=1,
)
pca_df = StandardScaler().fit_transform(pca_df)

# fit PCA and project the rows onto the first 50 components.
# scikit-learn uses random_state for the "arpack" solver, so it is fixed here
# (same seed as the other stochastic steps) to make the result repeatable.
pca = PCA(n_components=50, svd_solver="arpack", random_state=42).fit_transform(pca_df)

# prepare for plotting: one column per component (PC1, PC2, ...) plus the label.
# to_numpy() assigns labels by position rather than by index alignment.
plot_df = pd.DataFrame(pca, columns=[f"PC{i}" for i in range(1, pca.shape[1] + 1)])
plot_df["label"] = df["label"].to_numpy()

In [None]:
# Pairwise scatter plots of the first five principal components, colored by label.
sns.pairplot(
    plot_df[["PC1", "PC2", "PC3", "PC4", "PC5", "label"]],  # first 5 PCs
    hue="label",
    hue_order=labels,
    # `s` is the matplotlib marker size. `size` is seaborn's semantic mapping
    # variable and does not set a fixed marker size.
    plot_kws={"alpha": 0.5, "s": 2},
)

In [None]:
# tSNE
# TODO: color by donor, other covariates
# TODO: try different resolutions
from sklearn.manifold import TSNE

# Embed the principal-component columns (everything except the label);
# the seed is fixed and the embedding is initialized randomly.
tsne = TSNE(init="random", random_state=42).fit_transform(
    plot_df.drop(columns="label")
)

# Wrap the embedding in a frame and carry the labels over from plot_df.
tsne_df = pd.DataFrame(tsne, columns=["tSNE1", "tSNE2"])
tsne_df["label"] = plot_df["label"]

In [None]:
# Scatter of the two tSNE coordinates, colored by label.
sns.scatterplot(
    data=tsne_df,
    x="tSNE1",
    y="tSNE2",
    hue="label",
    hue_order=labels,
    s=3,
    alpha=0.5,
)

## Precision/Recall

In [None]:
# Collect the precision/recall curve of every (model, fold, stage, label)
# combination into one long data frame.
# NOTE: pickle.load can execute arbitrary code; only load metrics files that
# were produced by this workflow.
pr_list = []
for i, model_id in enumerate(snakemake.params.model_ids):
    with open(snakemake.input.metrics[i], "rb") as f:
        metrics = pickle.load(f)
    for fold in metrics.keys():
        for stage in ["train", "test", "test_shuffled"]:
            for label in labels:
                label_metrics = metrics[fold][stage][label]
                pr = pd.DataFrame(
                    {
                        "precision": label_metrics["prcurve"]["precision"],
                        "recall": label_metrics["prcurve"]["recall"],
                    }
                )
                # Thin long curves to 1000 points to keep plotting light; the
                # seed is fixed so the same points are kept on every run.
                if len(pr) > 1000:
                    pr = pr.sample(1000, random_state=42)
                # Constant columns identifying which curve each row belongs to.
                pr = pr.assign(
                    auprc=label_metrics["auprc"],
                    fold=fold,
                    stage=stage,
                    model_id=model_id,
                    label=label,
                )
                pr_list.append(pr)
df = pd.concat(pr_list)

In [None]:
# Precision/recall curves: one row of panels per model, one column per stage;
# color encodes the label and line style encodes the fold.
sns.set_style("ticks")
fig = sns.relplot(
    data=df,
    kind="line",
    x="recall",
    y="precision",
    row="model_id",
    col="stage",
    hue="label",
    hue_order=labels,
    style="fold",
    errorbar=None,
)
sns.despine()

In [None]:
# Area under the PR curve: one panel per model (wrapping after three),
# stage on the x axis, one point series per label.
# auprc is constant within a curve, so de-duplicating leaves one row per
# (label, stage, model, fold).
fig = sns.FacetGrid(
    data=df.loc[:, ["label", "stage", "model_id", "fold", "auprc"]].drop_duplicates(),
    col="model_id",
    col_wrap=3,
)
fig.map_dataframe(
    sns.pointplot,
    y="auprc",
    x="stage",
    hue="label",
    palette=sns.husl_palette(nlabels),
    hue_order=labels,
)
fig.add_legend()
sns.despine()

## Confusion Matrix

In [None]:
from sklearn.metrics import ConfusionMatrixDisplay

# Unpickled object whose `classes_` attribute provides the tick labels.
with open(snakemake.input.label_encoder, "rb") as f:
    le = pickle.load(f)

# One figure per model, with one confusion-matrix panel per stage.
for i, model_id in enumerate(snakemake.params.model_ids):
    with open(snakemake.input.metrics[i], "rb") as f:
        metrics = pickle.load(f)

    fig, ax = plt.subplots(nrows=1, ncols=3, sharey="row")
    for j, stage in enumerate(["train", "test", "test_shuffled"]):
        # element-wise mean of the per-fold confusion matrices
        cm_list = [fold_metrics[stage]["cm"] for fold_metrics in metrics.values()]
        cm = np.mean(cm_list, axis=0)
        disp = ConfusionMatrixDisplay(cm, display_labels=le.classes_)
        disp.plot(ax=ax[j], xticks_rotation=45)
        disp.im_.colorbar.remove()
        # per-panel x labels are blanked; a single shared one is drawn below
        disp.ax_.set(title=stage, xlabel="")
        if j != 0:
            # only the left-most panel keeps its y label
            disp.ax_.set_ylabel("")
    fig.text(0.4, 0.75, model_id)
    fig.text(0.4, 0.15, "Predicted label", ha="left")
    plt.show()

## Feature Importances

In [None]:
from functools import reduce

# For every model, average the per-fold feature importances and plot the
# features ordered from most to least important.
# NOTE: pickle.load can execute arbitrary code; only load metrics files that
# were produced by this workflow.
for i, model_id in enumerate(snakemake.params.model_ids):
    with open(snakemake.input.metrics[i], "rb") as f:
        metrics = pickle.load(f)

    # Check every fold instead of assuming a fold keyed `0` exists.
    if not metrics or not all(
        "feature_importances" in fold_metrics for fold_metrics in metrics.values()
    ):
        print(f"{model_id} does not have feature importances")
        continue
    fi_list = [metrics[fold]["feature_importances"] for fold in metrics.keys()]
    # Join the per-fold tables side by side on the feature name.
    fi_df = reduce(lambda x, y: pd.merge(x, y, on="feature name"), fi_list)
    # numeric_only keeps the row-wise mean from touching a string column such
    # as "feature name", which raises in recent pandas.
    fi_df["mean_importance"] = fi_df.mean(axis=1, numeric_only=True)
    fi_df = fi_df.sort_values("mean_importance", ascending=False).reset_index()
    plt.figure(figsize=(6, 14))
    # NOTE(review): `join` appears to be deprecated in newer seaborn releases in
    # favor of linestyle="none" -- confirm against the pinned seaborn version.
    sns.pointplot(data=fi_df, x="mean_importance", y="feature name", join=False).set(
        title=model_id
    )