In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

In [None]:
def read_salmon_counts(ground_truth_path, new_counts_paths) -> pd.DataFrame:
    """
    Read in ground truth and Salmon-estimated counts and compute their difference.

    Parameters
    ----------
    ground_truth_path : str or Path
        Tab-separated matrix of true counts. The first column holds transcript
        IDs; every other column is one sample, labelled with a 0-based integer
        (``0``, ``1``, ...).
    new_counts_paths : str, Path, or sequence of str/Path
        One Salmon quant file per sample. A single path is accepted as well as
        a list. Files are matched to truth samples by position: the first path
        is sample 1 (truth column ``0``), the second is sample 2, and so on, so
        pass them in the same order as the truth columns.

    Returns
    -------
    pd.DataFrame
        One row per (transcript, sample) of the truth table with columns
        ``tx_id``, ``sample`` (1-based int), ``count_true``,
        ``count_estimated``, ``difference`` (true - estimated) and
        ``log2_abs_difference``. Transcripts absent from a quant file get NaN
        estimates; an exact match (difference 0) yields ``-inf``.

    Raises
    ------
    ValueError
        If ``new_counts_paths`` is empty.
    """
    # A bare path must not be iterated character-by-character; wrap it.
    if isinstance(new_counts_paths, (str, Path)):
        new_counts_paths = [new_counts_paths]
    new_counts_paths = list(new_counts_paths)
    if not new_counts_paths:
        raise ValueError("new_counts_paths must contain at least one file")

    # Truth: wide (tx x sample) -> long; shift 0-based sample labels to 1-based.
    truth = (
        pd.read_csv(ground_truth_path, sep="\t", index_col=0)
        .rename_axis("tx_id")
        .melt(value_name="count", var_name="sample", ignore_index=False)
    )
    truth["sample"] = truth["sample"].astype(int) + 1
    truth = truth.reset_index().set_index(["tx_id", "sample"])

    # Estimates: after using the first column as the index, column 3 is the
    # 5th file column (NumReads in a standard Salmon quant.sf layout).
    # Key by 1-based position so the sample labels are ints matching `truth`;
    # keying by file path could never join against the integer truth samples.
    estimate = {
        sample: pd.read_csv(path, sep="\t", index_col=0).iloc[:, 3]
        for sample, path in enumerate(new_counts_paths, start=1)
    }
    estimate = (
        pd.DataFrame(estimate)
        .rename_axis("tx_id")
        .melt(value_name="count", var_name="sample", ignore_index=False)
        .reset_index()
        .set_index(["tx_id", "sample"])
    )

    # Left join keeps every truth transcript; both sides have a "count" column.
    benchmark = truth.join(estimate, lsuffix="_true", rsuffix="_estimated")
    benchmark["difference"] = benchmark["count_true"] - benchmark["count_estimated"]
    # log2(0) = -inf for exact matches is intended; silence the divide warning.
    with np.errstate(divide="ignore"):
        benchmark["log2_abs_difference"] = np.log2(np.abs(benchmark["difference"]))

    return benchmark.reset_index()

In [None]:
def plot_difference(benchmark, ax=None) -> "plt.Axes":
    """
    Plot the log2 absolute difference of estimated vs true counts per sample.

    Transcripts whose ``tx_id`` does not contain "L1HS" are summarised as one
    box per sample; those that do are overlaid as individual red points.

    Parameters
    ----------
    benchmark : pd.DataFrame
        Long-format table with at least ``tx_id``, ``sample`` and
        ``log2_abs_difference`` columns (the output of ``read_salmon_counts``).
        It is not modified.
    ax : matplotlib Axes, optional
        Axes to draw on. Defaults to the current axes (``plt.gca()``).

    Returns
    -------
    matplotlib Axes
        The axes that were drawn on.
    """
    if ax is None:
        ax = plt.gca()

    # An exact estimate gives log2(0) = -inf. Non-finite values corrupt the box
    # quartiles/whiskers, so mask them as NaN on a copy; NaNs are left out.
    plot_df = benchmark.assign(
        log2_abs_difference=benchmark["log2_abs_difference"].replace(
            [np.inf, -np.inf], np.nan
        )
    )
    is_l1hs = plot_df["tx_id"].str.contains("L1HS")

    sns.boxplot(
        data=plot_df[~is_l1hs],
        x="sample",
        y="log2_abs_difference",
        ax=ax,
    )
    sns.stripplot(
        data=plot_df[is_l1hs],
        x="sample",
        y="log2_abs_difference",
        ax=ax,
        color="red",
    )

    # Title and subtitle are placed in axes coordinates, just above the plot.
    ax.text(
        x=0.5,
        y=1.1,
        s="Difference of Estimated Counts from True Counts",
        fontsize=12,
        weight="bold",
        ha="center",
        va="bottom",
        transform=ax.transAxes,
    )
    ax.text(
        x=0.5,
        y=1.05,
        s="Red points = L1 transcripts",
        fontsize=8,
        alpha=0.75,
        ha="center",
        va="bottom",
        transform=ax.transAxes,
    )
    return ax

In [None]:
def plot_l1hs(benchmark) -> "plt.Figure":
    """
    Plot estimated vs true counts for L1HS transcripts, one line per transcript.

    Rows are selected where ``tx_id`` contains "L1HS". The left panel shows
    the estimated count against the true count. The right panel shows the
    difference (true - estimated) against the true count.

    Parameters
    ----------
    benchmark : pd.DataFrame
        Long-format table with ``tx_id``, ``count_true``, ``count_estimated``
        and ``difference`` columns (the output of ``read_salmon_counts``).

    Returns
    -------
    matplotlib Figure
        The newly created two-panel figure.
    """
    plot_df = benchmark[benchmark["tx_id"].str.contains("L1HS")]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))

    # Both panels use the same hue mapping, so only ax2 keeps a legend.
    # legend=False avoids removing a legend that may not exist (empty selection).
    sns.lineplot(
        data=plot_df,
        x="count_true",
        y="count_estimated",
        hue="tx_id",
        palette="colorblind",
        ax=ax1,
        legend=False,
    )
    sns.lineplot(
        data=plot_df,
        x="count_true",
        y="difference",
        hue="tx_id",
        palette="colorblind",
        ax=ax2,
    )

    # Move the ax2 legend outside the axes, but only if seaborn created one.
    if ax2.get_legend() is not None:
        handles, labels = ax2.get_legend_handles_labels()
        ax2.legend(
            handles=handles,
            labels=labels,
            loc="center left",
            bbox_to_anchor=(1, 0.5),
            frameon=False,
        )

    return fig

In [None]:
# snakemake.input.new_counts may be a single path string or a list of paths;
# normalise it to a list before reading.
if isinstance(snakemake.input.new_counts, str):
    new_counts_paths = [snakemake.input.new_counts]
else:
    new_counts_paths = snakemake.input.new_counts

benchmark = read_salmon_counts(snakemake.input.truth, new_counts_paths)

plot_difference(benchmark)
# Trailing ';' suppresses the repr of any returned object; figures still render.
plot_l1hs(benchmark);

In [None]:
# Echo the input paths Snakemake handed to this notebook (quant files, then truth).
for input_value in (snakemake.input.new_counts, snakemake.input.truth):
    print(input_value)