In [31]:
import pandas as pd
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from upsetplot import UpSet, from_memberships
from itertools import combinations
from typing import List, Dict
import warnings
import marsilea as ma
from scipy.stats import gaussian_kde

# Suppress FutureWarning messages so they do not clutter the cell output.
warnings.simplefilter("ignore", category=FutureWarning)

# Default resolution (dots per inch) for every figure created below.
plt.rc("figure", dpi=300)

In [32]:
# Directory that holds one tab-separated count table per detection tool (<tool>.csv).
DETECTION_DIR = "detection"
# Root folder that receives every figure written by this notebook.
OUT_DIR = "../chapters/4_results_and_discussion/figures/detection"
# Tool names are the file names without their ".csv" extension. Entries that are not
# CSV files (hidden files, checkpoint folders, ...) are ignored instead of being turned
# into bogus tool names. The order is whatever os.listdir returns, which is arbitrary.
tools = [
    os.path.splitext(entry)[0]
    for entry in os.listdir(DETECTION_DIR)
    if entry.endswith(".csv")
]
tools

['ciri2', 'circexplorer2', 'find_circ', 'dcc', 'segemehl']

In [33]:
def parse_locstring(locstring: str):
    """Split a location string of the form ``chrom:start-end:strand``.

    The string is split from the right, so chromosome / contig names that
    themselves contain ``:`` are kept intact.

    Args:
        locstring: Location string, e.g. ``"chr1:100-200:+"``.

    Returns:
        Tuple ``(chrom, start, end, strand)`` with ``start`` and ``end`` as ints.

    Raises:
        ValueError: If the string does not have the expected layout or the
            coordinates are not integers.
    """
    # Only the last two colons are separators; anything before belongs to the name.
    chrom, coords, strand = locstring.rsplit(":", 2)
    start, end = coords.split("-")
    return chrom, int(start), int(end), strand

In [34]:
def get_tool_data(tool: str, min_samples=1):
    """Load the count table of one tool and derive per-row summary statistics.

    Reads ``<DETECTION_DIR>/<tool>.csv`` (tab-separated, first column used as
    index), replaces missing values with 0 and computes for every row:

    * ``n_samples``    - number of columns with a value greater than 0
    * ``total_counts`` - sum over all columns
    * ``mean_counts``  - ``total_counts / n_samples``

    Rows are kept only when ``n_samples`` is strictly greater than
    ``min_samples``. The index labels are parsed with ``parse_locstring`` and
    the resulting ``chrom``/``start``/``end``/``strand`` columns plus a
    ``tool`` column are appended to the statistics.

    Args:
        tool: Tool name, i.e. the CSV file name without extension.
        min_samples: Exclusive lower bound on ``n_samples``.

    Returns:
        Tuple ``(counts, stats)`` of the filtered count table and the
        statistics/location table; both share the same index.
    """
    table_path = os.path.join(DETECTION_DIR, f"{tool}.csv")
    counts = pd.read_csv(table_path, sep='\t', index_col=0).fillna(0)

    detected_in = counts.gt(0).sum(axis=1)
    summed = counts.sum(axis=1)

    stats = pd.DataFrame(index=counts.index)
    stats["n_samples"] = detected_in
    stats["total_counts"] = summed
    stats["mean_counts"] = summed / detected_in

    # Strict comparison: a row must be detected in MORE than `min_samples` columns.
    keep = detected_in > min_samples
    counts = counts[keep]
    stats = stats[keep]

    locations = pd.DataFrame(
        list(map(parse_locstring, counts.index)),
        columns=["chrom", "start", "end", "strand"],
        index=counts.index,
    )
    locations["tool"] = tool

    # Statistics and parsed coordinates side by side, one row per index label.
    return counts, pd.concat([stats, locations], axis=1)

In [35]:
def bar_plot(tool_stats: Dict[str, pd.DataFrame], outfile: str):
    """Draw one bar per tool showing how many rows its table holds and save it.

    Args:
        tool_stats: Mapping of tool name to that tool's statistics table.
        outfile: Path of the image file to write.
    """
    names = list(tool_stats)
    n_detected = [len(tool_stats[name]) for name in names]

    ax = sns.barplot(x=names, y=n_detected)
    # Print the exact count on top of every bar.
    ax.bar_label(ax.containers[0])
    ax.set_ylabel("Number of BSJs detected")
    ax.set_xlabel("Tool")
    ax.set_title("Number of BSJs detected by each tool")

    plt.savefig(outfile)
    plt.close()

In [36]:
def violin_tools(df_stats: pd.DataFrame, outfile: str):
    """Save a violin plot of ``end - start`` per tool on a logarithmic y axis.

    Args:
        df_stats: Table with at least ``start``, ``end`` and ``tool`` columns.
            It is not modified.
        outfile: Path of the image file to write.
    """
    with_length = df_stats.assign(length=df_stats["end"] - df_stats["start"])

    ax = sns.violinplot(data=with_length, x="tool", y="length")
    ax.set_ylabel("BSJ length")
    ax.set_xlabel("Tool")
    ax.set_yscale("log")
    ax.set_title("BSJ length distribution by tool")

    plt.savefig(outfile)
    plt.close()

In [37]:
def get_diff_groups(df_locs: pd.DataFrame, max_diff: int = 0):
    """Assign cluster ids to end and start coordinates that lie close together.

    For each coordinate the rows are sorted by ``chrom`` and that coordinate;
    a new group id starts whenever the gap to the previous row on the same
    chromosome exceeds ``max_diff``. Ids are running counters over the whole
    table, so they are only meaningful in combination with ``chrom``.

    Args:
        df_locs: Table with ``chrom``, ``start`` and ``end`` columns. It is
            not modified.
        max_diff: Largest gap between neighbouring coordinates that still
            keeps them in the same group.

    Returns:
        Copy of ``df_locs`` sorted by ``chrom`` and ``start`` with the added
        columns ``end_group`` and ``start_group``.
    """
    grouped = df_locs
    # Ends first, starts second: the result keeps the ordering of the last sort.
    for coord, label in (("end", "end_group"), ("start", "start_group")):
        grouped = grouped.sort_values(["chrom", coord])
        gap_too_large = grouped.groupby("chrom")[coord].diff() > max_diff
        grouped[label] = gap_too_large.cumsum()

    return grouped

In [38]:
def plot_diff_upset(diff_df: Dict[int, pd.DataFrame], outdir: str):
    """Write UpSet plots of tool agreement for every max-shift setting.

    For each entry of ``diff_df`` the rows are clustered by
    ``chrom``/``start_group``/``end_group`` - once with and once without
    ``strand`` as an additional key - and the list of tools per cluster is
    turned into UpSet membership data. One image per combination is written to
    ``outdir`` as ``diff_<diff>_strand.png`` / ``diff_<diff>_nostrand.png``.

    Args:
        diff_df: Mapping of max shift to a table carrying ``chrom``,
            ``start_group``, ``end_group``, ``strand`` and ``tool`` columns.
        outdir: Directory for the images; created if missing.

    Returns:
        Mapping of max shift to the membership data computed WITHOUT the
        strand key.
    """
    os.makedirs(outdir, exist_ok=True)
    diff_plotdata = {}
    for diff, df in diff_df.items():
        for include_strand in [True, False]:
            group_keys = ["chrom", "start_group", "end_group"]
            if include_strand:
                group_keys.append("strand")

            df_grouped = df.groupby(group_keys).aggregate({"tool": list})
            plotdata = from_memberships(df_grouped["tool"])
            if not include_strand:
                # Only the strand-agnostic memberships are handed back to the caller.
                diff_plotdata[diff] = plotdata

            upset = UpSet(plotdata, subset_size="count", min_degree=2, min_subset_size=10)
            upset.plot()
            # plt.title(f"Max shift: {diff}, {'considering' if include_strand else 'ignoring'} strand", fontsize=16)

            # Computed outside the f-string: reusing the enclosing quote character
            # inside an f-string is a syntax error before Python 3.12.
            strand_tag = "strand" if include_strand else "nostrand"
            plt.savefig(os.path.join(outdir, f"diff_{diff}_{strand_tag}.png"))
            plt.close()

    return diff_plotdata

In [39]:
def plot_pies(diff_plotdata: Dict[int, pd.Series], outfile: str):
    """Draw one two-ring pie per tool showing how many other tools agree with it.

    For every tool in the module-level ``tools`` list, the counts in
    ``diff_plotdata`` are summed per number of additional tools (0 .. len(tools)-1)
    that share a membership combination with it. The outer ring uses
    ``diff_plotdata[1]``, the inner ring ``diff_plotdata[0]``. The last subplot
    holds only the legend.

    Args:
        diff_plotdata: Mapping of max shift to a count Series; entries for the
            keys 0 and 1 are required. The index level names are used as keys
            into a tool -> bool mapping, so they are presumably the tool names
            (as produced by ``from_memberships``) - confirm against the caller.
        outfile: Path of the image file to write.
    """
    font_size = 25

    n_cols = 3
    # Always one cell more than there are tools, so the last cell is free for the legend.
    # NOTE(review): when len(tools) % 3 != 2 further cells stay empty but keep their frame.
    n_rows = len(tools) // n_cols + 1

    _, axs = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))

    def get_count(investigation, diff=0):
        """Summed count for the exact membership pattern ``investigation``, or 0 if it is absent."""
        # One boolean per index level, in index-level order.
        index = [[investigation[tool] for tool in diff_plotdata[diff].index.names]]
        return diff_plotdata[diff].loc[index].sum() if diff_plotdata[diff].index.isin(index).any() else 0

    cmap = plt.colormaps["tab20c"]
    # Every 4th entry of the colormap: one colour per possible number of agreeing tools.
    colors = cmap(np.arange(len(tools))*4)

    for i, tool in enumerate(tools):
        others = [t for t in tools if t != tool]
        diff_overlaps = {}
        for diff in [0,1]:
            n_counts = {}
            for n_others in range(len(tools)):
                sum_count = 0
                # Every combination of exactly `n_others` other tools together with `tool`.
                for other_tools in combinations(others, n_others):
                    allowed_tools = other_tools + (tool,)
                    investigation = {t: t in allowed_tools for t in tools}
                    sum_count += get_count(investigation, diff)
                n_counts[n_others] = sum_count
            diff_overlaps[diff] = n_counts
        
        # Rows: max shift (0, 1); columns: number of other agreeing tools.
        df_overlaps = pd.DataFrame(diff_overlaps).T
        
        size = 0.4

        ax = axs.flatten()[i]
        # Outer ring: max shift 1; inner ring: max shift 0.
        ax.pie(df_overlaps.loc[1], radius=1.2, wedgeprops=dict(width=size, edgecolor='w'), colors=colors)
        wedges, _ = ax.pie(df_overlaps.loc[0], radius=1.2-size, wedgeprops=dict(width=size, edgecolor='w'), colors=colors)
        ax.set_title(tool, fontsize=font_size)

    # The legend reuses the wedges of the last pie drawn; `wedges` is undefined if `tools` is empty.
    ax = axs.flatten()[-1]
    legend = ax.legend(wedges, range(len(tools)), title="Number of agreeing tools", loc='center', fontsize=font_size)

    plt.setp(legend.get_title(), fontsize=font_size)

    # Strip frame and ticks from the legend-only subplot.
    for direction in ['top', 'right', 'left', 'bottom']:
        ax.spines[direction].set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])

    plt.tight_layout()
    plt.savefig(outfile)
    plt.close()


In [40]:
def plot_distances(diff_df: Dict[int, pd.DataFrame], outdir: str, min_tools: int = 4):
    """Plot how far start/end coordinates spread inside clusters many tools agree on.

    For every max shift except 0 (where all coordinates inside a cluster are
    identical, so every distance is 0) the rows are clustered by
    ``chrom``/``start_group``/``end_group``. Clusters seen by at least
    ``min_tools`` distinct tools are kept and two images are written to
    ``outdir``:

    * ``diff_<diff>_scatter.png`` - spread (max - min) of the start / end
      coordinates against the cluster length (smallest end - largest start),
      coloured by the binned length/spread ratio.
    * ``diff_<diff>_bar.png`` - number of non-zero spreads per value 1..10 and
      ``11+``.

    Shift settings without any non-zero spread are skipped with a warning
    instead of failing while the legend is being moved.

    Args:
        diff_df: Mapping of max shift to a table carrying ``chrom``, ``start``,
            ``end``, ``start_group``, ``end_group`` and ``tool`` columns.
        outdir: Directory for the images; created if missing.
        min_tools: Minimum number of distinct tools required for a cluster.
    """
    os.makedirs(outdir, exist_ok=True)
    for diff, df in diff_df.items():
        if diff == 0:
            continue
        df = df.groupby(["chrom", "start_group", "end_group"]).aggregate({"start": list, "end": list, "tool": "nunique"})
        # Explicit copy so the column assignments below do not act on a slice.
        df = df[df["tool"] >= min_tools].copy()
        df["Start dist"] = df["start"].apply(lambda x: max(x) - min(x))
        df["End dist"] = df["end"].apply(lambda x: max(x) - min(x))
        # Innermost coordinates of the cluster, i.e. its shortest member span.
        df["start"] = df["start"].apply(max)
        df["end"] = df["end"].apply(min)
        df["length"] = df["end"] - df["start"]
        df_melted = df.melt(id_vars=["length"], value_vars=["Start dist", "End dist"], var_name="Distance type", value_name="distance")
        df_melted = df_melted[df_melted["distance"] > 0].copy()
        if df_melted.empty:
            # Without data seaborn draws no legend and sns.move_legend would raise.
            warnings.warn(f"plot_distances: nothing to plot for max shift {diff}, skipping")
            continue
        df_melted["goodness"] = df_melted["length"] / df_melted["distance"]
        # Right-closed bins over the ratio; ratios of 2 or below get no category.
        categories = [2, 5, 10, 20, 50, 100, 200, 500, 100000000]
        df_melted["Length/Intra-cluster distance"] = pd.cut(df_melted["goodness"], bins=categories, labels=[f"{categories[i]}-{categories[i+1]}" for i in range(len(categories)-2)] + [f"{categories[-2]}+"])

        # Scatter plot
        plt.figure(figsize=(7,5))
        ax = sns.scatterplot(df_melted, x="length", y="distance", hue="Length/Intra-cluster distance", style="Distance type")
        sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
        plt.xscale("log")
        plt.title(f"Max shift: {diff}", fontsize=16)
        plt.xlabel("Minimum BSJ length", fontsize=14)
        plt.ylabel("Maximum intra-cluster distance", fontsize=14)
        # Set legend font size to 14
        for t in ax.legend_.texts:
            t.set_fontsize(14)
        # plt.axhline(y=5, color = 'r', linestyle = '-')
        plt.tight_layout()
        plt.savefig(os.path.join(outdir, f"diff_{diff}_scatter.png"))
        plt.close()

        # Bar plot
        value_counts = df_melted["distance"].value_counts().to_dict()
        value_counts = {str(k): value_counts.get(k, 0) for k in range(1, 11)}
        # Everything not covered by the explicit 1..10 buckets.
        value_counts["11+"] = len(df_melted) - sum(value_counts.values())
        ax = sns.barplot(x=list(value_counts.keys()), y=list(value_counts.values()))
        ax.bar_label(ax.containers[0])
        plt.yscale("log")
        plt.title(f"Max shift: {diff}", fontsize=16)
        plt.xlabel("Intra-cluster distance", fontsize=14)
        plt.ylabel("Number of start/end clusters", fontsize=14)
        plt.tight_layout()
        plt.savefig(os.path.join(outdir, f"diff_{diff}_bar.png"), )
        plt.close()

In [41]:
def plot_densities(df: pd.DataFrame, outdir: str):
    """Write one density figure per tool relating sample support to tool agreement.

    Each figure (``<outdir>/<tool>.png``) has a central scatter of ``n_samples``
    against ``mean_counts`` (log y axis) coloured by the log of a Gaussian KDE
    evaluated at every point, plus four stacked marginal bar charts showing,
    per bin, the share of rows detected by 1, 2, ... distinct tools - once for
    identical coordinates (``d0``) and once for coordinate groups (``d1``).

    Args:
        df: Table with ``chrom``, ``start``, ``end``, ``start_group``,
            ``end_group``, ``tool``, ``n_samples`` and ``mean_counts`` columns.
            NOTE: the columns ``d0`` and ``d1`` are added to this frame in place.
        outdir: Directory for the images; created if missing.
    """
    os.makedirs(outdir, exist_ok=True)

    # Number of distinct tools reporting the exact same coordinates ...
    df['d0'] = df.groupby(["chrom", "start", "end"])['tool'].transform('nunique')
    # ... and the same start/end groups.
    df['d1'] = df.groupby(["chrom", "start_group", "end_group"])['tool'].transform('nunique')

    def get_binned(df: pd.DataFrame, bin_col: str, n_bins: int, match_col: str = 'd0', log: bool = False):
        """Bin ``bin_col`` and return, per bin, the normalised value counts of ``match_col``.

        ``n_bins`` equally spaced edges are used (optionally on log1p-transformed
        values). A ``bin`` column is written into the passed frame. The result is
        indexed by the bin mid points, with one column per ``match_col`` value.
        """
        c_bin = df[bin_col].copy()
        if log:
            c_bin = np.log1p(c_bin)
        bins = np.linspace(c_bin.min(), c_bin.max(), n_bins)
        df['bin'] = pd.cut(c_bin, bins, include_lowest=True)
        # NOTE(review): aggregate(<column name>) presumably resolves to the grouped
        # column selection here - confirm for the pandas version in use.
        df_grouped = df.groupby('bin').aggregate(match_col).value_counts(normalize=True).unstack().fillna(0)

        df_grouped.index = df_grouped.index.map(lambda x: x.mid)

        return df_grouped

    for tool in tools:
        # The trailing `and False` makes this condition always false, so no tool is skipped.
        if tool == "segemehl" and False:
            continue

        df_tool = df[df['tool'] == tool].copy()

        wb = ma.WhiteBoard(width=3, height=3)
        # Reserve empty canvas for drawing later
        barplot_size = 0.8
        barplot_pad = 0.1
        wb.add_canvas("top", size=barplot_size, pad=barplot_pad, name="x1")
        wb.add_canvas("left", size=barplot_size, pad=barplot_pad, name="y1")
        wb.add_canvas("top", size=barplot_size, pad=barplot_pad, name="x2")
        wb.add_canvas("left", size=barplot_size, pad=barplot_pad, name="y2")
        wb.render()

        # Log of the KDE evaluated at every point; used as the scatter colour.
        values = np.vstack([df_tool["n_samples"], df_tool["mean_counts"]])
        kernel = np.log(gaussian_kde(values)(values))
        main_ax = wb.get_main_ax()
        main_ax.yaxis.tick_right()
        main_ax.scatter(df_tool["n_samples"], df_tool["mean_counts"], c=kernel, s=10)
        main_ax.set_yscale("log")
        main_ax.set_xlabel("Number of samples")
        main_ax.set_ylabel("Mean count in detecting samples")
        main_ax.yaxis.set_label_position("right")

        cmap = plt.colormaps["tab20c"]
        # Every 4th entry of the colormap, one colour per tool count.
        colors = cmap(np.arange(len(tools))*4)

        # Shared styling of the four stacked marginal bar charts.
        kwargs = {
            'stacked': True,
            'legend': False,
            'width': 1,
            'color': colors
        }

        n_bins = 25

        # Top marginal, exact coordinates (d0), binned by number of samples.
        x1_ax = wb.get_ax("x1")
        get_binned(df_tool, "n_samples", n_bins, 'd0').plot(kind='bar', ax=x1_ax, **kwargs)
        sns.despine(ax=x1_ax, bottom=False, top=True)
        x1_ax.tick_params(bottom=False, labelbottom=False)
        x1_ax.set_xlabel("")
        x1_ax.set_ylabel("Max shift: 0", rotation=0)
        x1_ax.yaxis.set_label_coords(-0.15, 0.4)

        # Left marginal, exact coordinates (d0), binned by log1p of the mean count.
        y1_ax = wb.get_ax("y1")
        get_binned(df_tool, "mean_counts", n_bins, 'd0', log=True).plot(kind='barh', ax=y1_ax, **kwargs)
        sns.despine(ax=y1_ax, left=True, right=False)
        y1_ax.tick_params(right=False, labelright=False)
        y1_ax.set_ylabel("")
        y1_ax.set_xlabel("Max shift: 0")

        for tick in y1_ax.get_xticklabels():
            tick.set_rotation(90)
        y1_ax.invert_xaxis()

        # Top marginal, coordinate groups (d1), binned by number of samples.
        x2_ax = wb.get_ax("x2")
        get_binned(df_tool, "n_samples", n_bins, 'd1').plot(kind='bar', ax=x2_ax, **kwargs)
        sns.despine(ax=x2_ax, bottom=False, top=True)
        x2_ax.tick_params(bottom=False, labelbottom=False)
        x2_ax.set_xlabel("")
        x2_ax.set_ylabel("Max shift: 1", rotation=0)
        x2_ax.yaxis.set_label_coords(-0.15, 0.4)

        # Left marginal, coordinate groups (d1), binned by log1p of the mean count.
        y2_ax = wb.get_ax("y2")
        get_binned(df_tool, "mean_counts", n_bins, 'd1', log=True).plot(kind='barh', ax=y2_ax, **kwargs)
        sns.despine(ax=y2_ax, left=True, right=False)
        y2_ax.tick_params(right=False, labelright=False)
        y2_ax.set_ylabel("")
        y2_ax.set_xlabel("Max shift: 1")

        for tick in y2_ax.get_xticklabels():
            tick.set_rotation(90)
        y2_ax.invert_xaxis()

        wb.figure.savefig(os.path.join(outdir, f"{tool}.png"), bbox_inches='tight')
        plt.close()

In [42]:
def load_data(min_samples = 0):
    """Load counts and statistics for every tool in the module-level ``tools`` list.

    Args:
        min_samples: Passed through to ``get_tool_data`` for each tool.

    Returns:
        Tuple ``(tool_counts, tool_stats, df_stats)``: two dicts keyed by tool
        name holding the count table and the statistics table, and all
        statistics tables concatenated row-wise.
    """
    loaded = {name: get_tool_data(name, min_samples) for name in tools}

    tool_counts = {name: pair[0] for name, pair in loaded.items()}
    tool_stats = {name: pair[1] for name, pair in loaded.items()}

    # One long table with the rows of all tools stacked on top of each other.
    df_stats = pd.concat(tool_stats.values(), axis=0)

    return tool_counts, tool_stats, df_stats

In [43]:
def plot(tool_counts: Dict[str, pd.DataFrame], tool_stats: Dict[str, pd.DataFrame], df_stats: pd.DataFrame, outdir: str):
    """Produce every figure of the detection analysis below ``outdir``.

    Writes ``n_bsjs_detected.png`` and ``pies.png`` directly into ``outdir``
    and fills the sub-directories ``distances``, ``upset`` and ``density``.

    Args:
        tool_counts: Mapping of tool name to its count table. Currently unused;
            kept so the tuple returned by ``load_data`` can be unpacked into
            this function.
        tool_stats: Mapping of tool name to its statistics table.
        df_stats: Statistics of all tools concatenated row-wise.
        outdir: Root output directory; created if missing.
    """
    # The sibling helpers create their own sub-directories, but the files written
    # directly into `outdir` need the root to exist as well.
    os.makedirs(outdir, exist_ok=True)

    bar_plot(tool_stats, os.path.join(outdir, "n_bsjs_detected.png"))

    # Maximum coordinate shifts for which clusters are computed.
    diffs = [0, 1, 2, 3, 4, 5, 10, 20]
    diff_df = {diff: get_diff_groups(df_stats, diff) for diff in diffs}

    plot_distances(diff_df, os.path.join(outdir, "distances"))
    diff_plotdata = plot_diff_upset(diff_df, os.path.join(outdir, "upset"))
    plot_pies(diff_plotdata, os.path.join(outdir, "pies.png"))

    plot_densities(diff_df[1], os.path.join(outdir, "density"))

In [44]:
# (tool_counts, tool_stats, df_stats) for all tools, using the default threshold explicitly.
data = load_data(min_samples=0)

In [45]:
# Make sure the figure directory exists, then render all plots into it.
os.makedirs(OUT_DIR, exist_ok=True)
plot(*data, outdir=OUT_DIR)

KeyboardInterrupt: 