In [50]:
import pandas as pd
import polars as pl
import itertools
from tqdm import tqdm
import os
import marsilea as ma
import matplotlib.pyplot as plt

plt.rcParams['figure.dpi'] = 300

In [51]:
QUANTIFICATION_DIR = "quantification"
OUT_DIR = "../chapters/4_results_and_discussion/figures/quantification"
tools = [tool_csv[:-4] for tool_csv in os.listdir(QUANTIFICATION_DIR)]
tools

['segemehl',
 'find_circ',
 'dcc',
 'psirc',
 'circexplorer2',
 'ciriquant',
 'ciri2']

In [52]:
def parse_locstring(locstring: str):
    parts = locstring.split(":")
    chrom, coords, strand = parts if len(parts) == 3 else (parts[0], parts[1], ".")
    chrom = chrom[len("circ_"):] if chrom.startswith("circ_") else chrom
    start, end = coords.split("-")
    return {"chr": chrom, "start": int(start), "end": int(end), "strand": strand}

In [53]:
tool_dfs = []

for tool in tools:
    file = os.path.join(QUANTIFICATION_DIR, f"{tool}.tsv")
    df = pl.scan_csv(file, separator="\t")
    df = df.with_columns(tool=pl.lit(tool))
    df = df.drop("gene_id")
    df = df.rename({("tx" if tool == "psirc" else "circ_id"): "id"})
    tool_dfs.append(df)

samples = tool_dfs[0].drop("id", "tool").collect_schema()

tool_dfs = [df.select("id", "tool", *samples).with_columns(**{sample: pl.col(sample).cast(float) for sample in samples}) for df in tool_dfs]
tool_dfs = [df.with_columns(**{sample: pl.col(sample) / pl.col(sample).sum() for sample in samples}) for df in tool_dfs]

df = pl.concat(tool_dfs)

df = df.with_columns(location=pl.col("id").map_elements(parse_locstring, return_dtype=pl.Struct))
df = df.with_columns(
    chr=pl.col("location").map_elements(lambda x: x["chr"], return_dtype=str),
    start=pl.col("location").map_elements(lambda x: x["start"], return_dtype=int),
    end=pl.col("location").map_elements(lambda x: x["end"], return_dtype=int)
).drop("location")

df = df.unpivot(on=samples, index=["chr", "start", "end", "id", "tool"], variable_name="sample", value_name="count")
df = df.select("chr", "start", "end", "tool", "sample", "count")
df = df.filter(pl.col("count").is_not_null() & (pl.col("count") > 0))

In [54]:
df = df.collect().lazy()
df.collect()

chr,start,end,tool,sample,count
str,i64,i64,str,str,f64
"""chr1""",12856302,12856922,"""segemehl""","""antiHormonal_20m_ESR1_tamoxife…",0.000645
"""chr1""",16460044,16464434,"""segemehl""","""antiHormonal_20m_ESR1_tamoxife…",0.000645
"""chr1""",34223096,34223225,"""segemehl""","""antiHormonal_20m_ESR1_tamoxife…",0.000645
"""chr1""",38146648,38147548,"""segemehl""","""antiHormonal_20m_ESR1_tamoxife…",0.000645
"""chr1""",38573986,38576027,"""segemehl""","""antiHormonal_20m_ESR1_tamoxife…",0.000645
…,…,…,…,…,…
"""chrX""",151440420,151444714,"""ciri2""","""aging_18m_CYP19A1_no_1""",0.000272
"""chrX""",156355855,156374551,"""ciri2""","""aging_18m_CYP19A1_no_1""",0.000272
"""chrX""",158068262,158100998,"""ciri2""","""aging_18m_CYP19A1_no_1""",0.000544
"""chrX""",158599573,158600294,"""ciri2""","""aging_18m_CYP19A1_no_1""",0.000272


In [55]:
quantification_tools = ["ciriquant", "psirc"]
detection_tools = ["segemehl", "find_circ", "ciri2", "dcc", "circexplorer2"]


def plot(df: pl.LazyFrame, max_shift: int, fillnull: bool):
    df = df.sort("chr", "end"  ).with_columns(end_group  =pl.col("end"  ).diff().fill_null(0).gt(max_shift).cum_sum())
    df = df.sort("chr", "start").with_columns(start_group=pl.col("start").diff().fill_null(0).gt(max_shift).cum_sum())

    group_cols = ["chr", "start_group", "end_group"]
    df = df.group_by(group_cols + ["start", "end"]).len().join(df, on=group_cols, how="inner")
    df = df.select("chr", "start", "end", "sample", "start_right", "end_right", "tool", "count")

    df = df.filter((pl.col("start") - pl.col("start_right")).abs() <= max_shift)
    df = df.filter((pl.col("end") - pl.col("end_right")).abs() <= max_shift)
    df = df.collect().pivot(on="tool", values="count", index=["chr", "start", "end", "sample"], aggregate_function="sum", sort_columns=True).lazy()
    df = df.collect().to_pandas().set_index(["chr", "start", "end", "sample"])

    if fillnull:
        df.fillna(0, inplace=True)

    df_detection = df[detection_tools]
    df_quantification = df[quantification_tools]
    df_agg = pd.DataFrame(index=df.index)
    df_agg["sum"]    = df_detection.sum(axis=1)
    df_agg["min"]    = df_detection.min(axis=1)
    df_agg["max"]    = df_detection.max(axis=1)
    df_agg["mean"]   = df_detection.mean(axis=1)
    df_agg["median"] = df_detection.median(axis=1)

    df_plot = pd.concat([df_detection, df_agg, df_quantification], axis=1)

    # Calculate the correlation matrix
    corr = df_plot.corr()

    # Invert column order
    corr = corr.iloc[::-1]

    h = ma.Heatmap(corr, annot=True, fmt=".2f")

    categories = ["Detection", "Aggregations", "Quant."]
    colors = ["#54F0F0", "#F05454", "#F0F054"]

    h.group_cols([categories[0] if group in detection_tools else categories[2] if group in quantification_tools else categories[1] for group in corr.columns], order=categories)
    h.group_rows([categories[0] if group in detection_tools else categories[2] if group in quantification_tools else categories[1] for group in corr.index], order=categories[::-1])
    h.add_bottom(ma.plotter.Chunk(categories, fill_colors=colors), pad=0.05)
    h.add_bottom(ma.plotter.Labels(corr.columns), pad=0.05)
    h.add_left(ma.plotter.Chunk(categories[::-1], fill_colors=colors[::-1]), pad=0.05)
    h.add_left(ma.plotter.Labels(corr.index), pad=0.05)

    h.add_legends("right")
    h.add_title(f"Correlation matrix; max shift: {max_shift} nt; {"Missing as 0" if fillnull else "Missing as NaN"}")
    h.save(os.path.join(OUT_DIR, f"correlation_heatmap_{max_shift}_{"0" if fillnull else "na"}.png"))
    plt.close()

In [56]:
for max_shift, fillnull in tqdm(itertools.product([0, 1, 3, 5, 10], [True, False])):
    plot(df, max_shift, fillnull)

0it [00:00, ?it/s]

10it [01:38,  9.86s/it]
