In [2]:
from tqdm import tqdm
import polars as pl
import pandas as pd
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from upsetplot import UpSet, from_memberships
from itertools import combinations
from typing import List, Dict
import warnings
import marsilea as ma
from scipy.stats import gaussian_kde

# Ignore future warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

plt.rcParams['figure.dpi'] = 300

In [3]:
DETECTION_DIR = "detection"
OUT_DIR = "../chapters/4_results_and_discussion/figures/detection"
tools = [tool_csv[:-4] for tool_csv in os.listdir(DETECTION_DIR)]
tools

['find_circ', 'segemehl', 'dcc', 'circexplorer2', 'ciri2']

In [4]:
def parse_locstring(locstring: str):
    chrom, coords, strand = locstring.split(":")
    start, end = coords.split("-")
    return {"chr": chrom, "start": int(start), "end": int(end), "strand": strand}

In [5]:
tool_dfs = [
    pl.scan_csv(os.path.join(DETECTION_DIR, f"{tool}.csv"), separator='\t')
        .with_columns(tool=pl.lit(tool))
    for tool in tools
]

samples = tool_dfs[0].drop("id", "tool").collect_schema()

tool_dfs = [df.select("id", "tool", *samples).with_columns(**{sample: pl.col(sample).cast(int) for sample in samples}) for df in tool_dfs]

df = pl.concat(tool_dfs)

df = df.with_columns(
    total_counts=pl.sum_horizontal(samples),
    n_samples=pl.sum_horizontal(pl.col(samples).is_not_null()),
    location=pl.col("id").map_elements(parse_locstring, return_dtype=pl.Struct)
)

df = df.with_columns(
    mean_counts=pl.col("total_counts") / pl.col("n_samples")
)

df = df.with_columns(
    chr=pl.col("location").map_elements(lambda x: x["chr"], return_dtype=str),
    start=pl.col("location").map_elements(lambda x: x["start"], return_dtype=int),
    end=pl.col("location").map_elements(lambda x: x["end"], return_dtype=int),
    strand=pl.col("location").map_elements(lambda x: x["strand"], return_dtype=str)
)

df = df.select("tool", "chr", "start", "end", "total_counts", "n_samples", "mean_counts")

In [6]:
tool_hits = df.group_by("tool").len().rename({"len": "n_bsjs"})

ax = sns.barplot(tool_hits.collect(), x="tool", y="n_bsjs")
ax.bar_label(ax.containers[0])
plt.ylabel("Number of BSJs detected")
plt.xlabel("Tool")
plt.title("Number of BSJs detected by each tool")
plt.savefig(os.path.join(OUT_DIR, "n_bsjs_detected.png"))
plt.close()

In [7]:
df = df.collect().lazy()

In [8]:
def identify_shift_partners(df: pl.LazyFrame, max_shift: int = 0):
    df = df.select("chr", "start", "end", "tool")
    df = df.group_by("chr", "start", "end").agg(tools= pl.col("tool").unique())

    df = df.sort("chr", "end"  ).with_columns(end_group  =pl.col("end"  ).diff().fill_null(0).gt(max_shift).cum_sum())
    df = df.sort("chr", "start").with_columns(start_group=pl.col("start").diff().fill_null(0).gt(max_shift).cum_sum())

    df = df.join(df, on=["chr", "start_group", "end_group"], how="inner")
    df = df.select("chr", "start", "end", "start_right", "end_right", "tools_right")
    df = df.filter((pl.col("start") - pl.col("start_right")).abs() <= max_shift)
    df = df.filter((pl.col("end") - pl.col("end_right")).abs() <= max_shift)
    df = df.group_by("chr", "start", "end").agg(
        tools = pl.col("tools_right").flatten().unique()
    ).sort("chr", "start", "end")
    df = df.with_columns(n_tools=pl.col("tools").map_elements(lambda x: len(x), return_dtype=int))

    return df

In [9]:
shift_df = {max_shift: identify_shift_partners(df, max_shift).collect() for max_shift in tqdm([0, 1, 2, 3, 4, 5, 10, 20, 50])}

  0%|          | 0/9 [00:00<?, ?it/s]

 67%|██████▋   | 6/9 [01:16<00:51, 17.09s/it]

In [14]:
outdir = os.path.join(OUT_DIR, "upset")
os.makedirs(outdir, exist_ok=True)

for shift, df_shift in shift_df.items():
    plotdata = from_memberships(df_shift["tools"])

    upset = UpSet(plotdata, subset_size="count", min_degree=2, min_subset_size=10)
    upset.plot()
    # plt.title(f"Max shift: {diff}, {"considering" if include_strand else "ignoring"} strand", fontsize=16)

    plt.savefig(os.path.join(outdir, f"shift_{shift}.png"))
    plt.close()

In [29]:
dfs = []
for shift, df_shift in shift_df.items():
    df_shift = df_shift.group_by("n_tools").len().to_pandas()
    df_shift = df_shift.sort_values("n_tools")
    df_shift.index = df_shift["n_tools"]
    df_shift = df_shift.drop("n_tools", axis=1)
    df_shift = df_shift.rename({"len": shift}, axis=1)
    dfs.append(df_shift)

df_agg = pd.concat(dfs, axis=1).fillna(0).T
# Divide by row sum and multiply by 100 to get percentage
df_agg = df_agg.div(df_agg.sum(axis=1), axis=0) * 100
ax = df_agg.plot(kind='bar', stacked=True)
# Set legend position to right
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), title="Number of tools")
# Rotate x labels
plt.xticks(rotation=0)
plt.xlabel("Max. shift")
plt.ylabel("Percentage of BSJs")
plt.title("Influence of max. shift on tool agreement", fontsize=16)
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, "shift_agreement.png"))
plt.close()