In [101]:
from collections import defaultdict
from itertools import product

from tqdm import tqdm
import polars as pl
import pandas as pd
import os
import seaborn as sns
import matplotlib.pyplot as plt
from upsetplot import UpSet, from_memberships
import warnings

# Ignore future warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

plt.rcParams['figure.dpi'] = 300

In [102]:
DETECTION_DIR = "detection"
QUANTIFICATION_DIR = "quantification"
OUT_DIR = "../chapters/4_results_and_discussion/figures/detection"
os.makedirs(OUT_DIR, exist_ok=True)
tools = [tool_csv[:-4] for tool_csv in os.listdir(DETECTION_DIR)]
tools

['find_circ', 'segemehl', 'dcc', 'circexplorer2', 'ciri2']

In [103]:
def parse_locstring(locstring: str):
    chrom, coords, strand = locstring.split(":")
    start, end = coords.split("-")
    return {"chr": chrom, "start": int(start), "end": int(end), "strand": strand}

In [104]:
tool_dfs = [
    pl.scan_csv(os.path.join(DETECTION_DIR, f"{tool}.csv"), separator='\t')
        .with_columns(tool=pl.lit(tool))
    for tool in tools
]

samples = tool_dfs[0].drop("id", "tool").collect_schema()

tool_dfs = [df.select("id", "tool", *samples).with_columns(**{sample: pl.col(sample).cast(int) for sample in samples}) for df in tool_dfs]

df = pl.concat(tool_dfs)

df = df.with_columns(
    total_counts=pl.sum_horizontal(samples),
    n_samples=pl.sum_horizontal(pl.col(samples).is_not_null()),
    location=pl.col("id").map_elements(parse_locstring, return_dtype=pl.Struct)
)

df = df.with_columns(
    mean_counts=pl.col("total_counts") / pl.col("n_samples")
)

df = df.with_columns(
    chr=pl.col("location").map_elements(lambda x: x["chr"], return_dtype=str),
    start=pl.col("location").map_elements(lambda x: x["start"], return_dtype=int),
    end=pl.col("location").map_elements(lambda x: x["end"], return_dtype=int),
    strand=pl.col("location").map_elements(lambda x: x["strand"], return_dtype=str)
)

df = df.select("tool", "chr", "start", "end", "strand", "total_counts", "n_samples", "mean_counts")

In [105]:
tool_hits = df.group_by("tool").len().rename({"len": "n_bsjs"})

ax = sns.barplot(tool_hits.collect(), x="tool", y="n_bsjs")
ax.bar_label(ax.containers[0])
plt.ylabel("Number of BSJs detected")
plt.xlabel("Tool")
plt.title("Number of BSJs detected by each tool")
plt.savefig(os.path.join(OUT_DIR, "n_bsjs_detected.png"))
plt.close()

In [106]:
df = df.collect().lazy()

In [107]:
def identify_shift_partners(df: pl.LazyFrame, max_shift: int = 0, consider_strand = True):
    df = df.select("chr", "start", "end", "strand", "tool")
    df = df.group_by("chr", "start", "end", "strand").agg(tools= pl.col("tool").unique())

    df = df.sort("end"  ).with_columns(end_group  =pl.col("end"  ).diff().fill_null(0).gt(max_shift).cum_sum())
    df = df.sort("start").with_columns(start_group=pl.col("start").diff().fill_null(0).gt(max_shift).cum_sum())

    group_cols = ["chr", "start_group", "end_group"] + (["strand"] if consider_strand else [])
    df = df.join(df, on=group_cols, how="inner")
    df = df.select("chr", "start", "end", "strand", "start_right", "end_right", "tools_right")
    df = df.filter((pl.col("start") - pl.col("start_right")).abs() <= max_shift)
    df = df.filter((pl.col("end") - pl.col("end_right")).abs() <= max_shift)
    df = df.group_by("chr", "start", "end", "strand").agg(
        tools = pl.col("tools_right").flatten().unique()
    ).sort("chr", "start", "end", "strand")
    df = df.with_columns(n_tools=pl.col("tools").map_elements(lambda x: len(x), return_dtype=int))
    df = df.with_columns(shift=pl.lit(max_shift), consider_strand=pl.lit(consider_strand))

    return df

In [108]:
shift_stranded_df = defaultdict(dict)
df_list = []

for max_shift, consider_strand in tqdm(product([0, 1, 2, 3, 4, 5, 10, 20, 50], [True, False])):
    df_current = identify_shift_partners(df, max_shift, consider_strand).collect().lazy()
    shift_stranded_df[max_shift][consider_strand] = df_current
    df_list.append(df_current)

0it [00:00, ?it/s]

8it [00:50,  6.31s/it]


In [109]:
outdir = os.path.join(OUT_DIR, "upset")
os.makedirs(outdir, exist_ok=True)

for shift, stranded_df in shift_stranded_df.items():
    for stranded, df_current in stranded_df.items():
        plotdata = from_memberships(df_current.collect()["tools"])

        upset = UpSet(plotdata, subset_size="count", min_degree=2 if shift > 0 else None, min_subset_size=10)
        upset.plot()

        plt.savefig(os.path.join(outdir, f"shift_{shift}_{"stranded" if stranded else "unstranded"}.png"))
        plt.close()

In [110]:
df_concat = pl.concat(df_list)
df_concat = df_concat.group_by("shift", "consider_strand", "n_tools").len()

In [111]:
df_stranded = df_concat.filter(pl.col("consider_strand")).collect().to_pandas()
df_unstranded = df_concat.filter(~pl.col("consider_strand")).collect().to_pandas()

In [112]:
def plot(df: pd.DataFrame):
    ax = df.pivot(index="shift", columns="n_tools", values="len").fillna(0).plot.bar(stacked=True)
    plt.xlabel("Shift")
    plt.ylabel("Number of BSJs")
    ax.set_ylim(0, 1.3e6)

    plt.legend(title="Number of tools", bbox_to_anchor=(1.05, 1), loc='upper left')

In [121]:
plot(df_stranded)
plt.title("Number of BSJs detected by multiple tools, considering strand")
plt.savefig(os.path.join(OUT_DIR, "shift_agreement_stranded.png"))
plt.close()

In [122]:
plot(df_unstranded)
plt.title("Number of BSJs detected by multiple tools, ignoring strand")
plt.savefig(os.path.join(OUT_DIR, "shift_agreement_unstranded.png"))
plt.close()

In [125]:
selected_shift = 3
selected_stranded = False
min_tools = 4

In [126]:
filtered = shift_stranded_df[selected_shift][selected_stranded].filter(pl.col("n_tools") >= min_tools).collect().lazy()
len(filtered.collect())

106563

In [127]:
for tool in tools:
    df_tool = pl.scan_csv(os.path.join(DETECTION_DIR, f"{tool}.csv"), separator='\t')

    df_tool = df_tool.select("id", *samples)
    df_tool = df_tool.with_columns(**{sample: pl.col(sample).cast(int) for sample in samples})

    df_tool = df_tool.with_columns(
        location=pl.col("id").map_elements(parse_locstring, return_dtype=pl.Struct)
    )

    df_tool = df_tool.with_columns(
        chr=pl.col("location").map_elements(lambda x: x["chr"], return_dtype=str),
        start=pl.col("location").map_elements(lambda x: x["start"], return_dtype=int),
        end=pl.col("location").map_elements(lambda x: x["end"], return_dtype=int),
        strand=pl.col("location").map_elements(lambda x: x["strand"], return_dtype=str)
    )

    df_tool = df_tool.join(filtered, on=["chr", "start", "end", "strand"], how="inner")
    df_tool = df_tool.with_columns(circ_id=pl.col("chr") + pl.lit(":") + pl.col("start").cast(str) + pl.lit("-") + pl.col("end").cast(str) + pl.lit(":") + pl.col("strand"))
    df_tool = df_tool.with_columns(gene_id=None)
    df_tool = df_tool.select("circ_id", "gene_id", *samples)
    df_tool.sink_csv(os.path.join(QUANTIFICATION_DIR, f"{tool}.tsv"), include_header=True, separator='\t')