In [1]:
from collections import defaultdict
from itertools import product

from tqdm.contrib.itertools import product
import polars as pl
import pandas as pd
import os
import seaborn as sns
import matplotlib.pyplot as plt
from upsetplot import UpSet, from_memberships
import warnings

# Ignore future warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

plt.rcParams['figure.dpi'] = 300
sns.set_palette("colorblind")
# Set font size
plt.rcParams.update({'font.size': 16})

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DETECTION_DIR = "detection"
QUANTIFICATION_DIR = "quantification"
OUT_DIR = "../chapters/4_results_and_discussion/figures/detection"
os.makedirs(OUT_DIR, exist_ok=True)
tools = [tool_csv[:-4] for tool_csv in os.listdir(DETECTION_DIR)]
tools

['ciri2', 'circexplorer2', 'find_circ', 'dcc', 'segemehl']

In [3]:
def parse_locstring(locstring: str):
    chrom, coords, strand = locstring.split(":")
    start, end = coords.split("-")
    return {"chr": chrom, "start": int(start), "end": int(end), "strand": strand}

In [4]:
tool_dfs = [
    pl.scan_csv(os.path.join(DETECTION_DIR, f"{tool}.csv"), separator='\t')
        .with_columns(tool=pl.lit(tool))
    for tool in tools
]

samples = tool_dfs[0].drop("id", "tool").collect_schema()

tool_dfs = [df.select("id", "tool", *samples).with_columns(**{sample: pl.col(sample).cast(int) for sample in samples}) for df in tool_dfs]

df = pl.concat(tool_dfs)

df = df.with_columns(
    total_counts=pl.sum_horizontal(samples),
    n_samples=pl.sum_horizontal(pl.col(samples).is_not_null()),
    location=pl.col("id").map_elements(parse_locstring, return_dtype=pl.Struct)
)

df = df.with_columns(
    mean_counts=pl.col("total_counts") / pl.col("n_samples")
)

df = df.with_columns(
    chr=pl.col("location").map_elements(lambda x: x["chr"], return_dtype=str),
    start=pl.col("location").map_elements(lambda x: x["start"], return_dtype=int),
    end=pl.col("location").map_elements(lambda x: x["end"], return_dtype=int),
    strand=pl.col("location").map_elements(lambda x: x["strand"], return_dtype=str)
)

df = df.select("tool", "chr", "start", "end", "strand", "total_counts", "n_samples", "mean_counts")

In [None]:
df = df.with_columns(start=pl.when(pl.col("tool").is_in(["ciri2", "dcc"])).then(pl.col("start")-1).otherwise(pl.col("start")))
df = df.with_columns(strand=pl.when(pl.col("tool").is_in(["dcc", "segemehl"])).then(pl.col("strand").map_elements(lambda x: "-" if x == "+" else "+", return_dtype=str)).otherwise(pl.col("strand")))

In [6]:
tool_hits = df.group_by("tool").len().rename({"len": "n_bsjs"}).sort("n_bsjs")

ax = sns.barplot(tool_hits.collect(), x="tool", y="n_bsjs")
ax.bar_label(ax.containers[0])
plt.ylabel("Number of BSJs detected", fontsize=20)
plt.xlabel("Tool", fontsize=20)
plt.title("Number of BSJs detected by each tool", fontsize=25)
# Increase width of plot
plt.gcf().set_size_inches(10, 5)
plt.ylim(0, 1.2e6)
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, "n_bsjs_detected.png"))
plt.close()

In [7]:
df = df.collect().lazy()

In [8]:
def identify_shift_partners(df: pl.LazyFrame, max_shift: int = 0, consider_strand = True):
    df = df.select("chr", "start", "end", "strand", "tool")
    df = df.group_by("chr", "start", "end", "strand").agg(tools= pl.col("tool").unique())

    df = df.sort("end"  ).with_columns(end_group  =pl.col("end"  ).diff().fill_null(0).gt(max_shift).cum_sum())
    df = df.sort("start").with_columns(start_group=pl.col("start").diff().fill_null(0).gt(max_shift).cum_sum())

    group_cols = ["chr", "start_group", "end_group"] + (["strand"] if consider_strand else [])
    df = df.join(df, on=group_cols, how="inner")
    df = df.select("chr", "start", "end", "strand", "start_right", "end_right", "tools_right")
    df = df.filter((pl.col("start") - pl.col("start_right")).abs() <= max_shift)
    df = df.filter((pl.col("end") - pl.col("end_right")).abs() <= max_shift)
    df = df.group_by("chr", "start", "end", "strand").agg(
        tools = pl.col("tools_right").flatten().unique()
    ).sort("chr", "start", "end", "strand")
    df = df.with_columns(n_tools=pl.col("tools").map_elements(lambda x: len(x), return_dtype=int))
    df = df.with_columns(shift=pl.lit(max_shift), consider_strand=pl.lit(consider_strand))

    return df

In [9]:
shift_stranded_df = defaultdict(dict)
df_list = []

for max_shift, consider_strand in product([0, 1, 2, 3, 4, 5, 10, 20, 50], [True, False]):
    df_current = identify_shift_partners(df, max_shift, consider_strand).collect().lazy()
    shift_stranded_df[max_shift][consider_strand] = df_current
    df_list.append(df_current)

  0%|          | 0/18 [00:00<?, ?it/s]

100%|██████████| 18/18 [12:12<00:00, 40.72s/it] 


In [10]:
outdir = os.path.join(OUT_DIR, "upset")
os.makedirs(outdir, exist_ok=True)

for shift, stranded_df in shift_stranded_df.items():
    for stranded, df_current in stranded_df.items():
        plotdata = from_memberships(df_current.collect()["tools"])

        upset = UpSet(plotdata, subset_size="count", min_degree=2 if shift > 0 else None, min_subset_size=10)
        upset.plot()

        plt.savefig(os.path.join(outdir, f"shift_{shift}_{'stranded' if stranded else 'unstranded'}.png"))
        plt.close()

In [11]:
df_concat = pl.concat(df_list)
df_concat = df_concat.group_by("shift", "consider_strand", "n_tools").len()

In [12]:
df_concat = df_concat.collect()
df_stranded = pd.DataFrame(df_concat.filter(pl.col("consider_strand")), columns=df_concat.columns)
df_unstranded = pd.DataFrame(df_concat.filter(~pl.col("consider_strand")), columns=df_concat.columns)

In [13]:
def plot(df: pd.DataFrame):
    ax = df.pivot(index="shift", columns="n_tools", values="len").fillna(0).plot.bar(stacked=True)
    plt.xlabel("Max shift")
    plt.ylabel("Number of BSJs")
    ax.set_ylim(0, 1.3e6)

    plt.legend(title="Number of tools", bbox_to_anchor=(1.05, 1), loc='upper left')

In [14]:
plot(df_stranded)
plt.title("Number of BSJs detected by multiple tools, considering strand")
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, "shift_agreement_stranded.png"))
plt.close()

In [15]:
plot(df_unstranded)
plt.title("Number of BSJs detected by multiple tools, ignoring strand")
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, "shift_agreement_unstranded.png"))
plt.close()

In [16]:
selected_shift = 3
selected_stranded = False
min_tools = 4

In [17]:
filtered = shift_stranded_df[selected_shift][selected_stranded].filter(pl.col("n_tools") >= min_tools).collect().lazy()
len(filtered.collect())

33660

In [18]:
for tool in tools:
    df_tool = pl.scan_csv(os.path.join(DETECTION_DIR, f"{tool}.csv"), separator='\t')

    df_tool = df_tool.select("id", *samples)
    df_tool = df_tool.with_columns(**{sample: pl.col(sample).cast(int) for sample in samples})

    df_tool = df_tool.with_columns(
        location=pl.col("id").map_elements(parse_locstring, return_dtype=pl.Struct)
    )

    df_tool = df_tool.with_columns(
        chr=pl.col("location").map_elements(lambda x: x["chr"], return_dtype=str),
        start=pl.col("location").map_elements(lambda x: x["start"], return_dtype=int),
        end=pl.col("location").map_elements(lambda x: x["end"], return_dtype=int),
        strand=pl.col("location").map_elements(lambda x: x["strand"], return_dtype=str)
    )

    df_tool = df_tool.join(filtered, on=["chr", "start", "end", "strand"], how="inner")
    df_tool = df_tool.with_columns(circ_id=pl.col("chr") + pl.lit(":") + pl.col("start").cast(str) + pl.lit("-") + pl.col("end").cast(str) + pl.lit(":") + pl.col("strand"))
    df_tool = df_tool.with_columns(gene_id=None)
    df_tool = df_tool.select("circ_id", "gene_id", *samples)
    df_tool.sink_csv(os.path.join(QUANTIFICATION_DIR, f"{tool}.tsv"), include_header=True, separator='\t')

thread 'polars-8' panicked at /home/conda/feedstock_root/build_artifacts/polars_1736665968752/work/crates/polars-ops/src/chunked_array/gather/chunked.rs:172:22:
infallible: ShapeMismatch(ErrString("expected struct fields to have given length. given = 41238, field length = 2."))
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
thread 'polars-35' panicked at /home/conda/feedstock_root/build_artifacts/polars_1736665968752/work/crates/polars-ops/src/chunked_array/gather/chunked.rs:172:22:
infallible: ShapeMismatch(ErrString("expected struct fields to have given length. given = 41238, field length = 3."))
thread 'polars-5' panicked at /home/conda/feedstock_root/build_artifacts/polars_1736665968752/work/crates/polars-ops/src/chunked_array/gather/chunked.rs:172:22:
infallible: ShapeMismatch(ErrString("expected struct fields to have given length. given = 41238, field length = 2."))
thread 'polars-12' panicked at /home/conda/feedstock_root/build_artifacts/polars_173

PanicException: infallible: ShapeMismatch(ErrString("expected struct fields to have given length. given = 41238, field length = 3."))