In [1]:
import sys, pickle
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

## Read in data

In [None]:
# Load the pickled table supplied by Snakemake as the `folds` input.
df = pd.read_pickle(snakemake.input.folds)

# One husl colour per distinct label; installed as the default seaborn palette.
nlabels = df["label"].nunique()
label_palette = sns.husl_palette(nlabels)
sns.set_palette(label_palette)

## Label Distributions

### Per donor

In [None]:
# Number of rows per (label, donor) pair, as a long-form table with a "count" column.
plot_df = df.value_counts(["label", "donor_id"]).to_frame("count").reset_index()

# Pin the hue order and palette explicitly, as the other label-coloured plots in this
# notebook do. Without them the hue levels follow row order in `plot_df` (which is
# sorted by count), so a label could get a different colour here than elsewhere.
fig = sns.barplot(
    plot_df,
    x="count",
    y="donor_id",
    hue="label",
    hue_order=["KNRGL", "RL1", "OTHER"],
    palette=sns.husl_palette(nlabels),
)
fig.set_xscale("log")  # counts are shown on a log axis
sns.move_legend(fig, "upper left", bbox_to_anchor=(1, 1))  # legend outside the axes
sns.despine()

### Per cell

In [None]:
# Number of rows per (label, donor, cell) combination, flattened to long form.
cell_counts = df.value_counts(["label", "donor_id", "cell_id"])
plot_df = cell_counts.reset_index(name="count")

# Box plot of the per-cell counts for each donor, split by label.
label_order = ["KNRGL", "RL1", "OTHER"]
fig = sns.boxplot(
    plot_df,
    x="count",
    y="donor_id",
    hue="label",
    hue_order=label_order,
    palette=sns.husl_palette(nlabels),
)
fig.set_xscale("log")

# Anchor the legend's upper-left corner at the top-right corner of the axes.
sns.move_legend(fig, "upper left", bbox_to_anchor=(1, 1))
sns.despine()

### Per fold

In [None]:
def logplot(**kwargs):
    """Draw a seaborn bar plot and put its x-axis on a log scale.

    The ``data`` keyword is removed from ``kwargs`` and handed to
    ``sns.barplot`` as its first positional argument; all remaining keywords
    are forwarded unchanged. Raises ``KeyError`` if ``data`` is not supplied.
    Returns ``None``.
    """
    frame = kwargs.pop("data")
    sns.barplot(frame, **kwargs).set_xscale("log")


# Number of rows per (label, donor, fold, stage) combination, in long form.
fold_counts = df.value_counts(["label", "donor_id", "fold", "stage"])
plot_df = fold_counts.reset_index(name="count")

# One panel per (stage, fold); each panel keeps its own y-axis.
fig = sns.FacetGrid(plot_df, col="fold", row="stage", sharey=False)

# Keyword arguments forwarded to `logplot` for every facet.
bar_kwargs = dict(
    x="count",
    y="donor_id",
    hue="label",
    hue_order=["KNRGL", "RL1", "OTHER"],
    palette=sns.husl_palette(nlabels),
)
fig.map_dataframe(logplot, **bar_kwargs)
fig.add_legend()
sns.despine()

## Feature Distributions

In [6]:
# Reshape to long form: the columns listed in `id_vars` are kept as identifiers and
# every remaining column becomes one (feature, value) row per original row.
plot_df = df.reset_index().melt(
    id_vars=[
        "chrom",
        "start",
        "end",
        "donor_id",
        "cell_id",
        "label",
        "build",
        "db",
        "fold",
        "stage",
    ],
    var_name="feature",
)

# One panel per feature, five panels per row, each with its own x-range.
fig = sns.FacetGrid(plot_df, col="feature", col_wrap=5, sharex=False)
fig.map_dataframe(
    sns.boxplot,
    x="value",
    y="label",
    # `order` (not `hue_order`) controls the categorical axis when no `hue` is given;
    # `hue_order` without `hue` has no effect. Fixing the order keeps the box
    # positions and palette assignment identical in every facet.
    order=["KNRGL", "RL1", "OTHER"],
    fliersize=0,  # do not draw outlier markers
    palette=sns.husl_palette(nlabels),
)
sns.despine()

## Precision/Recall

In [None]:
# Read every pickled table listed under the `prcurve` input and stack them row-wise.
# NOTE: this rebinds `df`; the table loaded from `snakemake.input.folds` is no longer
# reachable under that name after this cell.
pr_tables = [pd.read_pickle(path) for path in snakemake.input.prcurve]
df = pd.concat(pr_tables)

# Combined "<label>; fold <n>" key, used as the hue variable of the PR-curve plot.
fold_as_text = df["fold"].astype(str)
df["label_fold"] = df["label"] + "; fold " + fold_as_text

nfolds = df["fold"].nunique()

In [None]:
# Plot PR curves
# Build one husl palette of `nlabels` colours per fold. The lightness starts at 0.3
# for fold 0 and rises by 0.4 / nfolds with each subsequent fold.
colors = [
    shade
    for fold in range(nfolds)
    for shade in sns.husl_palette(nlabels, l=0.3 + (fold / nfolds) * 0.4)
]
sns.set_palette(colors)
sns.set_style("ticks")

# Precision vs. recall lines: one colour per `label_fold` value, one column of
# panels per stage and one row per model_id; no error bands are drawn.
relplot_kwargs = dict(
    x="recall",
    y="precision",
    hue="label_fold",
    col="stage",
    row="model_id",
    kind="line",
    errorbar=None,
)
fig = sns.relplot(data=df, **relplot_kwargs)
sns.despine()

In [None]:
# Plot Area Under PR Curve

# Keep one row per distinct (label, stage, model_id, fold, auprc) combination.
auprc_columns = ["label", "stage", "model_id", "fold", "auprc"]
auprc_table = df[auprc_columns].drop_duplicates()

# One panel per model, wrapping after three columns.
fig = sns.FacetGrid(auprc_table, col="model_id", col_wrap=3)

# Box plot of AUPRC per stage, split and coloured by label.
box_kwargs = dict(
    x="stage",
    y="auprc",
    hue="label",
    hue_order=["KNRGL", "RL1", "OTHER"],
    palette=sns.husl_palette(nlabels),
)
fig.map_dataframe(sns.boxplot, **box_kwargs)
fig.add_legend()
sns.despine()

## Confusion Matrix

In [None]:
from sklearn.metrics import ConfusionMatrixDisplay

# NOTE(review): pickle.load can execute arbitrary code from the file it reads. The
# inputs here are presumably produced by this same pipeline; never point them at
# untrusted files.
with open(snakemake.input.label_encoder, "rb") as f:
    le = pickle.load(f)

# The i-th confusion-matrix file is read for the i-th model id.
for i, model_id in enumerate(snakemake.params.model_ids):
    with open(snakemake.input.confusion[i], "rb") as f:
        cm_dict = pickle.load(f)

    # TODO: make 3-panel plot for each model
    for stage, cm_by_fold in cm_dict.items():
        cm_list = list(cm_by_fold.values())
        # Average confusion matrix across folds; astype(int) truncates the means.
        cm = np.mean(cm_list, axis=0).astype(int)

        # Create the axes first and draw into it. ConfusionMatrixDisplay.plot() opens
        # its own figure when no `ax` is given, so a preceding plt.title() would land
        # on a different figure, and plt.clf() afterwards would erase the matrix
        # before it is rendered.
        cm_fig, cm_ax = plt.subplots()
        ConfusionMatrixDisplay(cm, display_labels=le.classes_).plot(ax=cm_ax)
        cm_ax.set_title(f"Confusion Matrix for {model_id} model in {stage} stage")
        plt.show()
        plt.close(cm_fig)  # release the figure so the loop does not accumulate them

# TODO: add feature importance