In [None]:
from pathlib import Path
import warnings

warnings.filterwarnings("ignore", category=FutureWarning)

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import colors
import pyranges as pr
from pyslavseq.preprocessing import collate_labels, df2tabix
from upsetplot import UpSet

# Category order intended for hue-based plots. Not referenced in the cells
# visible here -- presumably consumed elsewhere; TODO confirm it is still needed.
HUE_ORDER = ["KNRGL", "OTHER", "KRGL"]

In [None]:
# Names of the boolean annotation columns attached to every peak.
# `rlabels` / `nrlabels` presumably stand for reference / non-reference
# annotations (rlabels are the columns filtered out under
# "Remove reference insertions").
rlabels = ["primer_sites", "l1hs", "l1pa2", "l1pa3", "l1pa4", "l1pa5", "l1pa6"]
nrlabels = ["megane"]

# Full list, in plotting order. Concatenation builds a new list, so later
# in-place edits of `labels` do not touch `rlabels` or `nrlabels`.
labels = rlabels + ["polyA", "polyT"] + nrlabels

## 1. Read data

In [None]:
# Donor metadata; donor_id is read as a string so it matches the string ids
# used for the peak table below.
meta = pd.read_csv(snakemake.config["donors"], sep="\t", dtype={"donor_id": str})  # type: ignore

# Bulk peaks: keep peaks with at least 5 reads, normalise donor_id to str,
# strip "#" from the column names, then attach donor metadata.
bdata = (
    pd.read_csv(snakemake.input.bulk[0], sep="\t")
    .query("n_reads >= 5")
    .astype({"donor_id": str})
    .rename(columns=lambda name: name.replace("#", ""))
    .merge(meta, on="donor_id")
)

# Derived columns: peak width, boolean annotation flags, and a
# "chrom:start-end" identifier for each peak.
bdata["Width"] = bdata["End"] - bdata["Start"]
bdata[labels] = bdata[labels].astype(bool)
coords = bdata["Start"].astype(str) + "-" + bdata["End"].astype(str)
bdata["locus"] = bdata["Chromosome"] + ":" + coords
print(f"Loaded {len(bdata)} peaks from {bdata['donor_id'].nunique()} donors")

# l1hs annotations (BED)
l1hs = pr.read_bed(snakemake.input.l1hs_rmsk)  # type: ignore
print(f"Loaded {len(l1hs)} l1hs annotations")

# megane annotations (BED); only the first megane input is read here
meg = pr.read_bed(snakemake.input.megane[0])  # type: ignore
print(f"Loaded {len(meg)} megane variants")

In [None]:
# Complementary ECDF of read counts for peaks flagged `megane`,
# one curve per donor, on a log-scaled axis.
g, ax = plt.subplots(figsize=(10, 10))
megane_peaks = bdata.loc[bdata["megane"]]
sns.ecdfplot(
    megane_peaks,
    x="n_reads",
    hue="donor_id",
    complementary=True,
    log_scale=True,
    alpha=0.5,
    ax=ax,
)

In [None]:
# Three panels sharing the y axis, one per read-count threshold.
g, axs = plt.subplots(1, 3, figsize=(21, 7), sharey=True)
# narrow the horizontal gap between panels
g.subplots_adjust(wspace=0.1)

# cell_id -> donor_id lookup; independent of the threshold, so built once
cell_donor = bdata[["cell_id", "donor_id"]].drop_duplicates().set_index("cell_id")

for n, ax in zip([0, 100, 200], axs):
    # number of megane-flagged peaks per cell among peaks above the threshold
    counts = bdata.query("n_reads > @n").groupby("cell_id")["megane"].sum()
    df = counts.reset_index().join(cell_donor, on="cell_id")
    sns.barplot(data=df, x="megane", y="donor_id", ax=ax)
    ax.set_title(f"peaks overlapping megane with n_reads > {n}")

In [None]:
def _clamp_read_counts(joined, donor):
    """Zero out negative read counts after a left join and tag the donor.

    Args:
        joined: DataFrame produced by ``<ranges>.join(peaks, how="left").df``;
            must contain ``n_reads`` and ``n_proper_pairs``.
        donor: donor id written to the ``donor_id`` column.

    Returns:
        The same DataFrame, modified in place and returned for convenience.
    """
    for col in ("n_reads", "n_proper_pairs"):
        # Vectorised form of `0 if x < 0 else x`. Negative values are presumably
        # the join's fill for intervals without an overlapping peak.
        joined[col] = joined[col].clip(lower=0)
    joined["donor_id"] = donor
    return joined


def _max_reads_per_locus(frames, meta, extra_keys=()):
    """Concatenate per-donor frames and keep the max ``n_reads`` per locus.

    Args:
        frames: list of DataFrames returned by ``_clamp_read_counts``.
        meta: donor metadata, merged on ``donor_id`` (provides ``race``).
        extra_keys: additional grouping columns kept in the output.

    Returns:
        DataFrame with one row per (donor_id, race, locus, *extra_keys) and the
        maximum ``n_reads``; ``locus`` is the string form of the
        (Chromosome, Start, End) tuple.
    """
    df = pd.concat(frames).merge(meta, on="donor_id")
    df["locus"] = list(zip(df["Chromosome"], df["Start"], df["End"]))
    keys = ["donor_id", "race", "locus", *extra_keys]
    df = df.groupby(keys)["n_reads"].max().reset_index()
    df["locus"] = df["locus"].astype(str)
    return df


# Separate names for the per-donor lists so `ldata` / `mdata` only ever refer
# to the final DataFrames.
lframes, mframes = [], []

for m in snakemake.input.megane:  # type: ignore
    # the donor id is the name of the directory holding the megane file
    d = Path(m).parent.name
    bdf = bdata.query("donor_id == @d")[
        ["Chromosome", "Start", "End", "n_reads", "n_proper_pairs"]
    ]
    bdf = pr.PyRanges(bdf)

    # megane: the BED "Strand" column is read as an integer AC value and only
    # rows whose "Score" column mentions L1HS are kept
    meg = pr.read_bed(m).df
    meg["AC"] = meg["Strand"].astype(int)
    meg = pr.PyRanges(meg[meg["Score"].str.contains("L1HS")])
    mdf = _clamp_read_counts(meg.join(bdf, how="left").df, d)
    mframes.append(mdf)

    # l1hs: same join against the l1hs annotation intervals
    ldf = _clamp_read_counts(l1hs.join(bdf, how="left").df, d)
    lframes.append(ldf)

ldata = _max_reads_per_locus(lframes, meta)
mdata = _max_reads_per_locus(mframes, meta, extra_keys=("AC",))

## Known germline coverage

In [None]:
# Per-donor ECDFs (as counts) of the highest read count at each locus:
# `ldata` on the left panel, `mdata` on the right.
g, axs = plt.subplots(1, 2, figsize=(16, 8))
cols = sns.color_palette("tab10", n_colors=2)

opts = dict(
    hue="race",
    hue_order=["CAUC", "AA"],
    palette={"CAUC": cols[0], "AA": cols[1]},
    alpha=0.5,
    stat="count",
    log_scale=True,
)

for d in bdata["donor_id"].unique():
    for ax, data in zip(axs, [ldata, mdata]):
        df = data.query("donor_id == @d")
        # row with the largest n_reads for every locus
        best = df.groupby("locus")["n_reads"].idxmax()
        # add one for log scale
        df = df.loc[best].assign(n_reads=lambda t: t["n_reads"] + 1)
        sns.ecdfplot(df, x="n_reads", ax=ax, **opts)

axs[0].set_title("# Reference L1HS")
axs[1].set_title("# Non-Reference L1HS (detected from WGS)")

In [None]:
# KRGL heatmap: locus x donor matrix of the max read count, with 0 where a
# locus/donor combination is absent.
per_pair = ldata.groupby(["locus", "donor_id"])["n_reads"].max().reset_index()
df = per_pair.pivot_table(index="locus", columns="donor_id", values="n_reads").fillna(0)

# NOTE(review): only vmin is shifted by one; the zero-filled cells themselves
# stay at 0 -- check how LogNorm renders them and whether `df + 1` was intended.
log_norm = colors.LogNorm(vmin=df.min().min() + 1, vmax=df.max().max())
sns.clustermap(
    df,
    cmap="viridis",
    norm=log_norm,
    yticklabels=False,
    method="ward",
)

In [None]:
# KNRGL heatmap
# Sanity check: each (locus, donor) pair may appear only once in mdata.
pairs = mdata[["locus", "donor_id"]]
assert not pairs.duplicated().any(), "duplicate locus-donor pairs found!"

# Rows are ordered by the summed AC of each locus (ascending).
locus_order = mdata.groupby("locus")["AC"].sum().sort_values().index

per_pair = mdata.groupby(["locus", "donor_id"])["n_reads"].max().reset_index()
df = per_pair.pivot_table(index="locus", columns="donor_id", values="n_reads").fillna(0)

# NOTE(review): only vmin is shifted by one; the zero-filled cells themselves
# stay at 0 -- check how LogNorm renders them and whether `df + 1` was intended.
log_norm = colors.LogNorm(vmin=df.min().min() + 1, vmax=df.max().max())
# No clustering on either axis: rows follow `locus_order`, columns keep the pivot order.
sns.clustermap(
    df.loc[locus_order],
    cmap="viridis",
    norm=log_norm,
    yticklabels=False,
    row_cluster=False,
    col_cluster=False,
)

## 2. Peak summary stats - unlabelled

In [None]:
# Peak count for every (libd_id, race, diagnosis, donor_id, age) combination.
library_keys = ["libd_id", "race", "diagnosis", "donor_id", "age"]
data = bdata.groupby(library_keys).size().rename("n_peaks").reset_index()

In [None]:
# (feature column, plot on a log scale?) pairs shared by all summary plots.
features_scale = [
    ("Width", False),
    ("n_reads", True),
    ("n_ref_reads", True),
    ("n_unique_5end", True),
    ("n_unique_3end", True),
    ("n_unique_clipped_3end", True),
    ("n_duplicates", True),
    ("three_end_clipped_length_mean", False),
    ("three_end_clipped_length_q0", False),
    ("three_end_clipped_length_q1", False),
    ("three_end_clippedA_mean", False),
    ("three_end_clippedA_q0", False),
    ("three_end_clippedA_q1", False),
    ("alignment_score_normed_mean", False),
    ("alignment_score_normed_q0", False),
    ("alignment_score_normed_q1", False),
    ("L1_alignment_score_mean", False),
    ("L1_alignment_score_q0", False),
    ("L1_alignment_score_q1", False),
]

# subplots: four columns and as many rows as the feature list needs
# (5 rows / 24x30 inches for the 19 features above)
n_cols = 4
n_rows = -(-len(features_scale) // n_cols)  # ceiling division
g, axs = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 6 * n_rows))
axs = axs.flatten()

# setup colors
cols = sns.color_palette("tab10", n_colors=2)
opts = {
    "hue_order": ["CAUC", "AA"],
    "hue": "race",
    "palette": {"CAUC": cols[0], "AA": cols[1]},
    "alpha": 0.5,
}

for i, (f, s) in enumerate(features_scale):
    # one ECDF per library; the group key gets its own name so it no longer
    # overwrites the figure handle `g`
    for libd_id, df in bdata.groupby("libd_id"):
        sns.ecdfplot(data=df, x=f, log_scale=s, ax=axs[i], **opts)

# hide the panels left over when the grid has more cells than features
for ax in axs[len(features_scale):]:
    ax.set_axis_off()

## 3. Peak summary stats - labelled by annotation

In [None]:
def plot_ecdf(data, features_scale, labels):
    """Plot per-library ECDFs of every feature, split by annotation label.

    The figure has one row per feature and one column per label; each panel
    overlays one ECDF per ``libd_id`` group, coloured by ``race``.

    Args:
        data: DataFrame with ``libd_id`` and ``race`` columns, one boolean
            column per entry of ``labels`` (used as a row mask) and one column
            per feature name.
        features_scale: sequence of ``(feature, log_scale)`` pairs; the second
            item is forwarded to ``sns.ecdfplot(log_scale=...)``.
        labels: names of the boolean columns that select the rows of each
            column of panels.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    _, axs = plt.subplots(
        len(features_scale),
        len(labels),
        figsize=(len(labels) * 6, len(features_scale) * 6),
        sharex="row",
        # always return a 2-D array so axs[i, j] also works when there is a
        # single feature or a single label
        squeeze=False,
    )

    cols = sns.color_palette("tab10", n_colors=2)
    opts = {
        "hue_order": ["CAUC", "AA"],
        "hue": "race",
        "palette": {"CAUC": cols[0], "AA": cols[1]},
        "alpha": 0.5,
    }

    for i, (f, s) in enumerate(features_scale):
        for j, label in enumerate(labels):
            ax = axs[i, j]
            for _, df in data.groupby("libd_id"):
                sns.ecdfplot(data=df[df[label]], x=f, log_scale=s, ax=ax, **opts)
            # set once per panel instead of once per library
            ax.set_title(label)

    plt.show()


def plot_2dhist(data, features_scale, labels):
    """Plot 2-D histograms of ``n_reads`` against every other feature.

    The figure has one row per feature (``n_reads`` itself is left out because
    it is the x axis of every panel) and one column per label.

    Args:
        data: DataFrame with an ``n_reads`` column, one boolean column per
            entry of ``labels`` (used as a row mask) and one column per
            feature name.
        features_scale: sequence of ``(feature, log_scale)`` pairs; the second
            item controls the scale of the y axis.
        labels: names of the boolean columns that select the rows of each
            column of panels.

    Returns:
        None.
    """
    # Drop n_reads before building the grid so it no longer leaves an empty
    # row of axes in the figure.
    rows = [(f, s) for f, s in features_scale if f != "n_reads"]
    if not rows:
        return

    _, axs = plt.subplots(
        len(rows),
        len(labels),
        figsize=(len(labels) * 6, len(rows) * 6),
        sharex="row",
        sharey="row",
        # always 2-D so axs[i, j] also works for a single row or column
        squeeze=False,
    )

    for i, (f, s) in enumerate(rows):
        for j, label in enumerate(labels):
            df = data[data[label]]
            if s:
                # a log-scaled y axis cannot represent non-positive values
                df = df[df[f] > 0]
            sns.histplot(
                data=df, x="n_reads", y=f, log_scale=(True, s), bins=50, ax=axs[i, j]
            )
            axs[i, j].set_title(label)

In [None]:
# UpSet plot of the label combinations found in bdata. Categories keep the
# order of `labels`, and min_subset_size is set to 100.
label_combos = bdata[labels].value_counts()
upset = UpSet(
    label_combos,
    sort_categories_by="input",
    min_subset_size=100,
    show_counts=True,
)
ax_dict = upset.plot()

In [None]:
# Flag peaks that carry none of the annotation labels as "other".
# "other" is excluded from the sum and only appended once, so re-running this
# cell neither duplicates the label nor counts the previous "other" column.
annotation_labels = [c for c in labels if c != "other"]
bdata["other"] = bdata[annotation_labels].sum(axis=1) == 0
if "other" not in labels:
    labels.append("other")

In [None]:
# ECDF grid for all peaks: one row per feature, one column per entry of `labels`.
plot_ecdf(bdata, features_scale, labels)

In [None]:
# 2-D histograms of n_reads against each feature for all peaks, one column per label.
plot_2dhist(bdata, features_scale, labels)

## Remove reference insertions

In [None]:
# Keep only peaks that carry none of the labels in `rlabels` (primer_sites,
# l1hs, l1pa2-l1pa6), neither polyA nor polyT, and have no reference reads.
# The mask is built from `rlabels` so the list of columns lives in one place.
ref_labels = [*rlabels, "polyA", "polyT"]
is_reference = bdata[ref_labels].any(axis=1)
nrdata = bdata[~is_reference & (bdata.n_ref_reads == 0)].copy()

In [None]:
# UpSet plot of the `nrlabels` combinations among the retained peaks.
ax_dict = UpSet(nrdata[nrlabels].value_counts()).plot()

In [None]:
# Flag retained peaks that carry none of the `nrlabels` annotations as "other".
# "other" is excluded from the sum and only appended once, so re-running this
# cell neither duplicates the label nor counts the previous "other" column.
nr_annotations = [c for c in nrlabels if c != "other"]
nrdata["other"] = nrdata[nr_annotations].sum(axis=1) == 0
if "other" not in nrlabels:
    nrlabels.append("other")

In [None]:
# ECDF grid restricted to nrdata: one row per feature, one column per entry of `nrlabels`.
plot_ecdf(nrdata, features_scale, nrlabels)

In [None]:
# 2-D histograms of n_reads against each feature for nrdata, one column per entry of `nrlabels`.
plot_2dhist(nrdata, features_scale, nrlabels)

## Apply final filter

In [None]:
# The final quality filter is defined once so the two tables cannot drift apart.
# NOTE(review): `alignment_score_q1` is not among the columns listed in
# `features_scale` (which has `alignment_score_normed_q1` and
# `L1_alignment_score_q1`) -- confirm this is the intended column.
FINAL_FILTER = (
    "max_mapq == 60 and n_reads > 10 and alignment_score_q1 > 60 "
    "and three_end_clipped_length_q1 > 50 and three_end_clippedA_q1 > 0 "
    "and n_unique_5end > 0"
)

# Apply the same filter to both tables and order the peaks by position.
bdata = bdata.query(FINAL_FILTER).sort_values(["Chromosome", "Start"])
nrdata = nrdata.query(FINAL_FILTER).sort_values(["Chromosome", "Start"])
plot_2dhist(nrdata, features_scale, nrlabels)

In [None]:
# Drop "other" from `labels` before the UpSet plot of the filtered peaks.
# Guarded so that re-running this cell does not raise ValueError once the
# entry has already been removed.
if "other" in labels:
    labels.remove("other")
ax_dict = UpSet(
    bdata[labels].value_counts(),
    sort_categories_by="input",
    min_subset_size=100,
    show_counts=True,
).plot()

## Examine sharing across individuals

In [None]:
# Cluster the retained peaks with pyranges, which adds a `Cluster` column,
# then count how many distinct donors contribute to each cluster.
nrdata = pr.PyRanges(nrdata).cluster().df
by_cluster = nrdata.groupby("Cluster", observed=True)
donors_per_peak = by_cluster["donor_id"].nunique()
# broadcast the per-cluster donor count back onto every peak
nrdata = nrdata.assign(n_donors=nrdata["Cluster"].map(donors_per_peak))

In [None]:
# Donor sharing shown two ways on a shared y axis.
g, (ax1, ax2) = plt.subplots(ncols=2, figsize=(12, 6), sharey=True)

# left: every peak contributes one observation
sns.ecdfplot(data=nrdata, x="n_donors", ax=ax1)
ax1.set_xlabel("n_donors (all peaks)")

# right: every cluster contributes one observation
sns.ecdfplot(data=donors_per_peak, ax=ax2)
ax2.set_xlabel("n_donors (unique peaks)")