In [120]:
import pandas as pd
import os
import seaborn as sns
import matplotlib.pyplot as plt
import marsilea as ma

In [121]:
DETECTION_DIR = "detection"
QUANTIFICATION_DIR = "quantification"
OUT_DIR = "../chapters/4_results_and_discussion/figures/quantification"

# Tool names are the file stems in each directory. Only non-hidden files with the
# extension that is actually read for that directory (".csv" for detection, ".tsv" for
# quantification) are considered, so stray files are skipped. Sorting makes the tool
# order (and hence the column order of every downstream table/plot) independent of
# os.listdir's arbitrary ordering.
detection_tools = sorted(
    os.path.splitext(name)[0]
    for name in os.listdir(DETECTION_DIR)
    if name.endswith(".csv") and not name.startswith(".")
)
quantification_tools = sorted(
    os.path.splitext(name)[0]
    for name in os.listdir(QUANTIFICATION_DIR)
    if name.endswith(".tsv") and not name.startswith(".")
)

In [122]:
def parse_locstring(locstring: str):
    """Split a locus string into ``(chrom, start, end, strand)``.

    Accepted forms are ``"<chrom>:<start>-<end>:<strand>"`` and, for any other
    number of ``:``-separated fields, the first two fields are used as chromosome and
    coordinates with strand ``"."``. A leading ``"circ_"`` on the chromosome is removed.
    ``start`` and ``end`` are returned as ints.
    """
    fields = locstring.split(":")
    if len(fields) == 3:
        chrom, coords, strand = fields
    else:
        chrom, coords, strand = fields[0], fields[1], "."

    prefix = "circ_"
    if chrom.startswith(prefix):
        chrom = chrom[len(prefix):]

    start_text, end_text = coords.split("-")
    return chrom, int(start_text), int(end_text), strand

In [123]:
def get_tool_data(tool: str, is_detection: bool = True):
    """Load one tool's per-locus, per-sample counts and return them in long format.

    Detection results are read from ``DETECTION_DIR/<tool>.csv`` and quantification
    results from ``QUANTIFICATION_DIR/<tool>.tsv``. Both are parsed as tab-separated
    with the first column (a locus string understood by ``parse_locstring``) as index.
    Missing counts become 0. For quantification tables the first remaining column is
    dropped before melting (NOTE(review): presumably a non-sample annotation column —
    confirm against the file format).

    Args:
        tool: file stem of the tool's result table.
        is_detection: read from the detection directory (``.csv``) if True, otherwise
            from the quantification directory (``.tsv``).

    Returns:
        DataFrame indexed by locus string, one row per (locus, sample), with columns
        ``chrom``, ``start``, ``end``, ``strand``, ``tool``, ``sample`` and ``count``.
    """
    # The extension is chosen outside the f-string: nesting same-type quotes inside an
    # f-string replacement field only parses on Python >= 3.12 (PEP 701).
    directory, extension = (DETECTION_DIR, "csv") if is_detection else (QUANTIFICATION_DIR, "tsv")
    path = os.path.join(directory, f"{tool}.{extension}")
    counts = pd.read_csv(path, sep="\t", index_col=0).fillna(0)
    if not is_detection:
        counts = counts.iloc[:, 1:]

    loci = pd.DataFrame(
        [parse_locstring(loc) for loc in counts.index],
        columns=["chrom", "start", "end", "strand"],
        index=counts.index,
    )
    loci["tool"] = tool

    # Prepend the parsed locus columns to the per-sample count columns (same index).
    combined = pd.concat([loci, counts], axis=1)

    return pd.melt(
        combined,
        id_vars=["chrom", "start", "end", "strand", "tool"],
        var_name="sample",
        value_name="count",
        ignore_index=False,
    )

In [124]:
def get_diff_groups(df_locs: pd.DataFrame, max_diff: int = 0):
    """Assign chained group ids to loci by end coordinate and by start coordinate.

    Within each chromosome, rows sorted by ``end`` stay in the same ``end_group`` as long
    as consecutive ends differ by at most ``max_diff``; ``start_group`` is built the same
    way from ``start``. Every chromosome's first row also opens a new group, so group
    ids never collide across chromosomes.

    Args:
        df_locs: frame with at least ``chrom``, ``start`` and ``end`` columns. It is
            not modified.
        max_diff: largest gap between consecutive coordinates that still joins a group.

    Returns:
        A copy of ``df_locs`` sorted by ``["chrom", "start"]`` with integer
        ``end_group`` and ``start_group`` columns, both numbered from 0.
    """
    def _group_ids(frame: pd.DataFrame, column: str) -> pd.Series:
        # Expects `frame` already sorted by ["chrom", column].
        gaps = frame.groupby("chrom")[column].diff()
        # diff() is NaN on each chromosome's first row; NaN > max_diff is False, so it
        # must be counted as a break explicitly to keep ids unique per chromosome.
        return (gaps.isna() | gaps.gt(max_diff)).cumsum() - 1

    df_diff = df_locs.sort_values(["chrom", "end"])
    df_diff["end_group"] = _group_ids(df_diff, "end")
    df_diff = df_diff.sort_values(["chrom", "start"])
    df_diff["start_group"] = _group_ids(df_diff, "start")
    return df_diff

In [None]:
# Long-format counts for every tool, keyed by tool name. Quantification tools are
# inserted after detection tools, so a name present in both lists ends up mapped to
# its quantification table.
tool_counts = {tool: get_tool_data(tool) for tool in detection_tools}
tool_counts.update(
    (tool, get_tool_data(tool, is_detection=False)) for tool in quantification_tools
)

tool_counts["circexplorer2"]

In [None]:
# Stack every tool's long-format table, then group loci (per chromosome) whose start /
# end coordinates differ by at most MAX_COORD_DIFF from their sorted neighbour.
MAX_COORD_DIFF = 1
df_stats = pd.concat(list(tool_counts.values()), axis=0)
df_diff = get_diff_groups(df_stats, max_diff=MAX_COORD_DIFF)
df_diff

In [None]:
# For each (chromosome, start group, end group, sample) collect the reporting tools and
# their counts. get_tool_data stores counts in a column named "count"; the aggregated
# list is exposed as "total_counts", the name the following cell reads.
MIN_DETECTION_TOOLS = 4  # keep only groups reported by at least this many detection tools

df_grouped = df_diff.groupby(["chrom", "start_group", "end_group", "sample"]).agg(
    tool=("tool", list),
    total_counts=("count", list),
)

detection_tool_set = set(detection_tools)
n_detection_tools = df_grouped["tool"].apply(
    lambda tools: len({t for t in tools if t in detection_tool_set})
)
df_grouped = df_grouped[n_detection_tools >= MIN_DETECTION_TOOLS]
df_grouped

In [None]:
# Turn each group's parallel (tool, count) lists into one row with one column per tool.
# NOTE: dict(zip(...)) keeps only the last count when a single tool contributes several
# loci to the same group.
count_dicts = [
    dict(zip(tools, counts))
    for tools, counts in zip(df_grouped["tool"], df_grouped["total_counts"])
]
df_counts = pd.DataFrame(count_dicts, index=df_grouped.index)

# Zero-fill only after expansion: missing values first appear when a tool has no entry
# in a group's dict. reindex guarantees every tool column exists even if a tool is
# absent from all retained groups. df_grouped itself is left untouched so the cell can
# be re-run.
df_detection = df_counts.reindex(columns=detection_tools).fillna(0)
df_quantification = df_counts.reindex(columns=quantification_tools).fillna(0)

# Per-group summaries across the detection tools only.
df_agg = pd.DataFrame(index=df_counts.index)
df_agg["sum"]    = df_detection.sum(axis=1)
df_agg["min"]    = df_detection.min(axis=1)
df_agg["max"]    = df_detection.max(axis=1)
df_agg["median"] = df_detection.median(axis=1)

df_plot = pd.concat([df_detection, df_agg, df_quantification], axis=1)
df_plot

In [None]:
# Pairwise correlation (pandas' default method) between all count columns:
# detection tools, their aggregations and quantification tools.
corr = df_plot.corr()

# Reverse the row order; columns keep their original order.
corr = corr.iloc[::-1]

h = ma.Heatmap(corr, annot=True, fmt=".2f")

categories = ["Detection", "Aggregations", "Quant."]
colors = ["#54F0F0", "#F05454", "#F0F054"]


def _category(label):
    """Map a correlation-matrix label to its category (detection checked first)."""
    if label in detection_tools:
        return categories[0]
    if label in quantification_tools:
        return categories[2]
    return categories[1]


h.group_cols([_category(c) for c in corr.columns], order=categories)
# Rows were reversed above, so their categories are drawn in reversed order too.
h.group_rows([_category(r) for r in corr.index], order=categories[::-1])
h.add_bottom(ma.plotter.Chunk(categories, fill_colors=colors), pad=0.05)
h.add_bottom(ma.plotter.Labels(corr.columns), pad=0.05)
h.add_left(ma.plotter.Chunk(categories[::-1], fill_colors=colors[::-1]), pad=0.05)
h.add_left(ma.plotter.Labels(corr.index), pad=0.05)

h.add_legends("right")
h.add_title("Correlation between detection tools and aggregations")
# Make sure the figure directory exists so saving does not fail on a fresh checkout.
os.makedirs(OUT_DIR, exist_ok=True)
h.save(os.path.join(OUT_DIR, "correlation_heatmap.png"))