In [None]:
# Standard-library helpers.
from pathlib import Path
import warnings

# Silence FutureWarning messages before the libraries below are imported.
warnings.filterwarnings("ignore", category=FutureWarning)
import numpy as np
import pandas as pd
import seaborn as sns

# Two-colour "tab10" palette; the plotting cells index into it as cols[0] / cols[1].
cols = sns.color_palette("tab10", n_colors=2)
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import pyranges as pr
from collections import defaultdict
from tqdm.notebook import tqdm
from sklearn.metrics import roc_curve, roc_auc_score, RocCurveDisplay
from pysam import AlignmentFile
from tempfile import NamedTemporaryFile

In [None]:
def read_megane_bed(bed_fn):
    """Read a headerless, tab-separated megane BED file and keep filtered L1HS calls.

    Only the first 11 columns are read. A row is kept when all three hold:

    * ``filter`` contains the literal text ``PASS;H``
    * ``subfamily`` contains ``L1HS``
    * ``confidence`` contains ``high;PASS``

    Rows where any of those fields is missing are dropped rather than raising.

    Parameters
    ----------
    bed_fn : str or path-like
        Path of the BED file to read.

    Returns
    -------
    pandas.DataFrame
        The surviving rows with an extra integer ``donor_AC`` column, parsed
        from the first ``;``-separated token of ``filter``.
    """
    column_names = [
        "Chromosome",
        "Start",
        "End",
        "family",
        "left_bp",
        "right_bp",
        "confidence",
        "length",
        "subfamily",
        "transduction",
        "filter",
    ]
    df = pd.read_csv(
        bed_fn,
        sep="\t",
        header=None,
        names=column_names,
        usecols=list(range(len(column_names))),
    )

    # Literal substring matches; na=False keeps the mask strictly boolean so a
    # row with an empty field is filtered out instead of breaking the indexing.
    keep = (
        df["filter"].str.contains("PASS;H", regex=False, na=False)
        & df["subfamily"].str.contains("L1HS", regex=False, na=False)
        & df["confidence"].str.contains("high;PASS", regex=False, na=False)
    )

    # Copy so that adding a column does not write into a view of the raw frame.
    df = df[keep].copy()
    df["donor_AC"] = df["filter"].str.split(";").str[0].astype(int)
    return df


# read meta data
samples = pd.read_csv(
    snakemake.config["samples"],
    sep="\t",
    dtype={"sample_id": str, "tissue_id": str, "donor_id": str},
)
meta = pd.read_csv(snakemake.config["donors"], sep="\t", dtype={"donor_id": str})
samples = samples.merge(meta, on=["donor_id"]).set_index(
    ["donor_id", "sample_id"], drop=False
)

# read coverage: map sample_id -> coverage file, keyed by the file name without
# its ".coverage.bed" suffix.
# NOTE: str.rstrip() removes a *set of characters*, not a suffix, so it would
# also eat trailing characters of the sample id itself (e.g. a final "d" or
# "e"); strip the exact suffix instead.
COV_SUFFIX = ".coverage.bed"
cov = {}
for c in snakemake.input.cov:
    name = Path(c).name
    if name.endswith(COV_SUFFIX):
        name = name[: -len(COV_SUFFIX)]
    cov[name] = c
samples["coverage"] = samples["sample_id"].map(cov)

# drop the cells listed (one id per line) in the bad-cells file
with open("resources/bad_cells.txt", "r") as f:
    bad_cells = [line.strip() for line in f]

samples = samples[~samples["sample_id"].isin(bad_cells)]

In [None]:
# Build one coverage table per sample, annotated with the donor's WGS call
# count (donor_AC) for every interval; intervals without a call get 0.
cov = []
key_cols = ["Chromosome", "Start", "End"]
for r in tqdm(samples.itertuples(), total=len(samples)):
    # donor WGS calls, intervals widened with PyRanges.extend(100)
    calls = read_megane_bed(r.megane_30x)[key_cols + ["donor_AC"]]
    calls = pr.PyRanges(calls).extend(100).df.set_index(key_cols)

    # per-interval coverage of this sample
    sample_cov = pd.read_csv(
        r.coverage, sep="\t", header=None, names=key_cols + ["reads", "rpm"]
    )
    sample_cov["cell_id"] = r.sample_id
    sample_cov["donor_id"] = r.donor_id
    chrom, start, end = (sample_cov[k].astype(str) for k in key_cols)
    sample_cov["locus"] = chrom + ":" + start + "-" + end

    # exact-coordinate join against the calls
    sample_cov = sample_cov.join(calls, on=key_cols).fillna({"donor_AC": 0})
    cov.append(sample_cov)

In [None]:
# Stack the per-sample tables and split them into bulk and single-cell data.
cov = pd.concat(cov).reset_index(drop=True)
# positive label: the donor carries a call at this locus
cov["label"] = cov["donor_AC"] > 0

# bulk samples are those whose id contains "gDNA" or "Bulk"
is_bulk = cov["cell_id"].str.contains("gDNA") | cov["cell_id"].str.contains("Bulk")

bulk_cov = cov[is_bulk].copy()
# AC: donor_AC summed over all bulk rows of a locus
af = bulk_cov.groupby("locus")["donor_AC"].sum()
bulk_cov["AC"] = bulk_cov["locus"].map(af)

cell_cov = cov[~is_bulk].copy()
# Label-based lookup (not .map) so a locus missing from the bulk table still
# raises KeyError, as the per-row af[x] lookup did.
cell_cov["AC"] = af.loc[cell_cov["locus"]].to_numpy()
# cell ids containing "ush" (case-insensitive) are labelled HIP, all others DLPFC
cell_cov["tissue"] = np.where(
    cell_cov["cell_id"].str.lower().str.contains("ush"), "HIP", "DLPFC"
)

## Pairwise concordance of variant calls with SLAVseq

### Bulk

In [None]:
# make df of zeros with donor_id as index and columns
# Build the mask on the already-sorted frame so that it is aligned with it.
meta_by_race = meta.sort_values("race")
donors = meta_by_race.loc[meta_by_race["donor_id"] != "CommonBrain", "donor_id"]
# Float zeros: the cells are filled with fractions, and writing floats into an
# integer frame is an incompatible-dtype assignment.
df = pd.DataFrame(0.0, index=donors, columns=donors)

# rows with rpm >= 10 (the filter does not depend on the donor pair)
well_covered = bulk_cov[bulk_cov["rpm"].ge(10)]

for d1 in donors:
    # loci called in donor d1
    d1_knrgls = bulk_cov.query("donor_id == @d1 and donor_AC > 0")["locus"].unique()
    if len(d1_knrgls) == 0:
        # no calls for this donor: the fraction is undefined
        df.loc[d1] = np.nan
        continue
    # number of well-covered rows at d1's loci, per donor
    n_covered = (
        well_covered[well_covered["locus"].isin(d1_knrgls)]
        .groupby("donor_id")
        .size()
        .reindex(donors.to_numpy(), fill_value=0)
    )
    df.loc[d1] = n_covered.to_numpy() / len(d1_knrgls)

In [None]:
# race of every donor, in the same order as the heatmap rows/columns
race_cols = {"CAUC": cols[0], "AA": cols[1]}
races = meta.set_index("donor_id")["race"].loc[donors]
row_colors = races.map(race_cols)

plt.figure(figsize=(10, 10))
plt.title("% 30x WGS LINE-1 calls covered by donor X bulk SLAVseq")
ax = sns.heatmap(df, square=True)
ax.set(xlabel="Donor Bulk SLAVseq", ylabel="Donor WGS LINE-1 calls")

# extra tick padding on both axes to leave room for the colour strips
for axis_name in ("y", "x"):
    ax.tick_params(axis=axis_name, which="major", pad=12, length=0)

# one coloured patch per donor along the left edge and along the bottom edge
strip = 0.02
for i, color in enumerate(row_colors):
    shared = dict(color=color, lw=0, clip_on=False)
    left_patch = plt.Rectangle(
        xy=(-strip, i),
        width=strip,
        height=1,
        transform=ax.get_yaxis_transform(),
        **shared,
    )
    bottom_patch = plt.Rectangle(
        xy=(i, -strip),
        width=1,
        height=strip,
        transform=ax.get_xaxis_transform(),
        **shared,
    )
    ax.add_patch(left_patch)
    ax.add_patch(bottom_patch)

### Single-cell

In [None]:
# reshape in 2d: loci as rows, cells as columns, read counts as values
cell_cov_2d = cell_cov.pivot(index="locus", columns="cell_id", values="reads")

# get donors and cells
# Build the mask on the already-sorted frame so that it is aligned with it.
meta_by_race = meta.sort_values("race")
donors = meta_by_race.loc[meta_by_race["donor_id"] != "CommonBrain", "donor_id"]
cells = (
    cell_cov[cell_cov.donor_id != "CommonBrain"]
    .set_index("donor_id")
    .sort_values("tissue")
    .loc[donors]["cell_id"]
    .unique()
)

# make df of zeros with donor_id as index and cells as columns
# Float zeros: the cells are filled with fractions, and writing floats into an
# integer frame is an incompatible-dtype assignment.
df = pd.DataFrame(0.0, index=donors, columns=cells)
with tqdm(total=len(donors) * len(cells)) as pbar:
    for d in donors:
        # loci called in donor d
        d_knrgls = bulk_cov[(bulk_cov.donor_id == d) & (bulk_cov.donor_AC > 0)][
            "locus"
        ].unique()
        if len(d_knrgls) == 0:
            # no calls for this donor: the fraction is undefined
            df.loc[d] = np.nan
        else:
            # per cell: number of d's loci with >= 10 reads (NaN counts as not covered)
            n_covered = cell_cov_2d.loc[d_knrgls, cells].ge(10).sum(axis=0)
            df.loc[d] = n_covered.to_numpy() / len(d_knrgls)
        pbar.update(len(cells))

In [None]:
# display the donor x cell table of covered-call fractions
df

In [None]:
# display the single-cell coverage table
cell_cov

In [None]:
cols = sns.color_palette("tab10", n_colors=4)

# Donor and tissue of every cell, looked up in the heatmap's own column order.
# The colour strip and the tick groups below are positional, so they must
# follow df.columns rather than the row order of cell_cov.
cell_info = (
    cell_cov[["cell_id", "donor_id", "tissue"]]
    .drop_duplicates("cell_id")
    .set_index("cell_id")
    .loc[df.columns]
)
tissues = cell_info["tissue"]
tissue_cols = {"HIP": cols[3], "DLPFC": cols[2]}
col_colors = tissues.map(tissue_cols)

races = meta.set_index("donor_id").loc[donors]["race"]
race_cols = {"CAUC": cols[0], "AA": cols[1]}
row_colors = races.map(race_cols)

# Number of columns in each (donor, tissue) group; sort=False keeps the groups
# in their order of appearance along the heatmap's x axis.
group_sizes = (
    cell_info.groupby(["donor_id", "tissue"], sort=False)
    .size()
    .to_frame(name="size")
)
group_sizes["cum"] = group_sizes["size"].cumsum()

# one tick per group, centred on the group's columns
positions = (group_sizes["cum"] - group_sizes["size"] / 2).tolist()
labels = group_sizes.index.tolist()

plt.figure(figsize=(100, 10))
plt.title("% 30x WGS LINE-1 calls covered by >=10 reads in single-cell SLAVseq")
ax = sns.heatmap(df)
ax.set_xlabel("Single-cell SLAVseq")
ax.set_ylabel("Donor WGS LINE-1 calls")
ax.set_xticks(positions)
ax.set_xticklabels(labels)

ax.tick_params(
    axis="y", which="major", pad=24, length=0
)  # extra padding to leave room for the row colors
for i, color in enumerate(row_colors):
    ax.add_patch(
        plt.Rectangle(
            xy=(-0.004, i),
            width=0.004,
            height=1,
            color=color,
            lw=0,
            transform=ax.get_yaxis_transform(),
            clip_on=False,
        )
    )

ax.tick_params(
    axis="x", which="major", pad=12, length=0
)  # extra padding to leave room for the column colors
for i, color in enumerate(col_colors):
    ax.add_patch(
        plt.Rectangle(
            xy=(i, -0.02),
            width=1,
            height=0.02,
            color=color,
            lw=0,
            transform=ax.get_xaxis_transform(),
            clip_on=False,
        )
    )

## Bulk Sensitivity

In [None]:
g, axs = plt.subplots(2, 3, figsize=(18, 12), sharey="row", sharex="col")
plt.subplots_adjust(hspace=0.1, wspace=0.1)
thresholds = [10, 50, 100, 500]

# One figure column per subset of loci:
# (row filter on bulk_cov, title of the count panel, title of the fraction panel)
panels = [
    ("donor_AC > 0", "# KNRGLs", "% KNRGLs (sensitivity)"),
    ("donor_AC == 1", "# het KNRGLs", "% het KNRGLs (sensitivity)"),
    (
        "donor_AC == 1 and AC > 1",
        "# het KNRGLs in > 1 donor",
        "% het KNRGLs in > 1 donor (sensitivity)",
    ),
]

for col, (row_filter, count_title, frac_title) in enumerate(panels):
    ax_count, ax_frac = axs[0, col], axs[1, col]
    ecdf_kws = dict(
        x="rpm", alpha=0.5, c=cols[0], log_scale=True, complementary=True
    )

    # one complementary ECDF per bulk sample: counts on top, proportions below
    sens = defaultdict(list)
    for _, sample_df in bulk_cov.query(row_filter).groupby("cell_id"):
        sns.ecdfplot(sample_df, stat="count", ax=ax_count, **ecdf_kws)
        sns.ecdfplot(sample_df, ax=ax_frac, **ecdf_kws)
        for t in thresholds:
            # fraction of this sample's loci with rpm above the threshold
            sens[t].append((sample_df["rpm"] > t).mean())

    # mark each threshold and annotate it with the mean fraction over samples
    for t, s in sens.items():
        ax_frac.axvline(t, color="black", linestyle="--")
        ax_frac.text(t, 0.5, f"{np.mean(s):.2f}")

    ax_count.set_title(count_title)
    ax_frac.set_title(frac_title)

## Bulk Contamination

In [None]:
g, ax1 = plt.subplots(1, 1, figsize=(6, 6))

# One complementary count-ECDF of rpm per donor, restricted to the rows where
# that donor has no call (donor_AC == 0).
not_called = bulk_cov.query("donor_AC == 0")
for _, donor_df in not_called.groupby("donor_id"):
    sns.ecdfplot(
        data=donor_df,
        x="rpm",
        stat="count",
        complementary=True,
        log_scale=True,
        alpha=0.5,
        c=cols[0],
        ax=ax1,
    )
ax1.set_title("# KNRGLs in other donors")

## Bulk ROC curve

In [None]:
g, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))

# Per donor: TPR/FPR against the rpm threshold (left), ROC curve (middle),
# and the ROC AUC collected for the bar chart (right).
auc = {}
for d, donor_df in bulk_cov.groupby("donor_id"):
    y_true, y_score = donor_df["label"], donor_df["rpm"]
    fpr, tpr, threshold = roc_curve(y_true, y_score)
    for rate, color, name in ((tpr, cols[0], "TPR"), (fpr, cols[1], "FPR")):
        sns.lineplot(x=threshold, y=rate, ax=ax1, alpha=0.5, c=color, label=name)
    sns.lineplot(x=fpr, y=tpr, ax=ax2, alpha=0.5, c=cols[0])
    auc[d] = roc_auc_score(y_true, y_score)

# remove legend duplicates (every donor adds its own "TPR"/"FPR" entry)
handles, labels = ax1.get_legend_handles_labels()
by_label = dict(zip(labels, handles))
ax1.legend(by_label.values(), by_label.keys())

ax1.set(
    title="TPR/FPR vs RPM threshold",
    xlabel="RPM threshold",
    ylabel="TPR/FPR",
    xscale="log",
)
ax2.set(
    title="ROC curve",
    xlabel="False positive rate",
    ylabel="True positive rate",
)

sns.barplot(x=auc.values(), y=auc.keys(), ax=ax3)
ax3.set(title="AUC ROC", xlabel="AUC", ylabel="Donor ID")
plt.show()

## Bulk Sensitivity in 90x

In [None]:

meta90x = pd.read_csv("/iblm/netapp/data3/mcuoco/sz_slavseq/config/U01_LIBD_wgs.tsv", sep="\t")
indvs = meta90x.set_index("ID")[["WGS_90x_HIPPO","WGS_90x_DLPFC"]].to_dict('index')
ids = meta
bulk_cov90x = []
for i in indvs:
	did = meta.set_index("libd_id")["donor_id"].loc[i]
	# check if the files exist
	wgs = Path("/iblm/netapp/data3/mcuoco/sz_slavseq/resources/chm13v2.0.XY/wgs_calls/90x") / indvs[i]["WGS_90x_DLPFC"] / "MEI_final_gaussian_genotyped.bed"
	bam = next(Path(f"/iblm/netapp/data3/mcuoco/sz_slavseq/results/chm13v2.0.XY/filtered/align/{did}").rglob("gDNA*tagged.sorted.bam"))
	assert bam.exists(), f"{bam} does not exist"
	# get one finished individual
	if Path(wgs).exists():
		# read the bed file and filter for L1HS
		mdf = (
			read_megane_bed(wgs)
			.query("family == 'LINE/L1'")
			.query("subfamily.str.contains('L1HS')")
			.query("confidence.str.contains('PASS')")
			.reset_index(drop=True)
		)
		# save to temp file and perform bash operations
		with NamedTemporaryFile() as tmp:
			pr.PyRanges(mdf).extend(100).to_bed(tmp.name)
			libsize = !samtools view -F 1412 -c {bam}
			cdf = !samtools view -F 1412 -b {bam} -q 30 | bedtools coverage -a {tmp.name} -b stdin -counts

		# wrangle
		cdf = pd.DataFrame([r.split("\t") for r in cdf])[[0,1,2,13,14]]
		cdf.columns = ["Chromosome","Start","End","filter","reads"]
		cdf["AC"] = cdf["filter"].str.split(";").apply(lambda x: x[0]).astype(int)
		cdf["locus"] = cdf["Chromosome"].astype(str) + ":" + cdf["Start"].astype(str) + "-" + cdf["End"].astype(str)
		cdf["reads"] = cdf["reads"].astype(int)
		libsize = int(libsize[0]) 
		cdf["rpm"] =  cdf["reads"] / libsize * 1e6
		cdf["donor_id"] = did
		bulk_cov90x.append(cdf)

bulk_cov90x = pd.concat(bulk_cov90x).reset_index(drop=True)

g, axs = plt.subplots(2, 2, figsize=(12, 12), sharey="row", sharex="col")
plt.subplots_adjust(hspace=0.1, wspace=0.1)
thresholds = [10, 50, 100, 500]

# One figure column per subset of loci:
# (rows to plot, title of the count panel, title of the fraction panel)
panels = [
    (bulk_cov90x, "# KNRGLs", "% KNRGLs (sensitivity)"),
    (bulk_cov90x.query("AC == 1"), "# het KNRGLs", "% het KNRGLs (sensitivity)"),
]

for col, (subset, count_title, frac_title) in enumerate(panels):
    ax_count, ax_frac = axs[0, col], axs[1, col]
    ecdf_kws = dict(
        x="rpm", alpha=0.5, c=cols[0], log_scale=True, complementary=True
    )

    # one complementary ECDF per donor: counts on top, proportions below
    sens = defaultdict(list)
    for _, donor_df in subset.groupby("donor_id"):
        sns.ecdfplot(donor_df, stat="count", ax=ax_count, **ecdf_kws)
        sns.ecdfplot(donor_df, ax=ax_frac, **ecdf_kws)
        for t in thresholds:
            # fraction of this donor's loci with rpm above the threshold
            sens[t].append((donor_df["rpm"] > t).mean())

    # mark each threshold and annotate it with the mean fraction over donors
    for t, s in sens.items():
        ax_frac.axvline(t, color="black", linestyle="--")
        ax_frac.text(t, 0.5, f"{np.mean(s):.2f}")

    ax_count.set_title(count_title)
    ax_frac.set_title(frac_title)

## Cell Sensitivity

In [None]:
g, axs = plt.subplots(6, 7, figsize=(35, 30), sharex=True)
axs = axs.flatten()

# one panel per donor, in order of first appearance in meta
axs_dict = {d: axs[i] for i, d in enumerate(meta.donor_id.unique())}
for d, donor_ax in axs_dict.items():
    donor_ax.set(title=d, xlabel="reads", ylabel="% KNRGLs")

# one complementary ECDF of read counts per cell, over the loci its donor carries
called = cell_cov.query("donor_AC > 0")
for (c, d), cell_df in tqdm(called.groupby(["cell_id", "donor_id"])):
    sns.ecdfplot(
        data=cell_df,
        x="reads",
        complementary=True,
        log_scale=True,
        alpha=0.5,
        c=cols[0],
        ax=axs_dict[d],
    )

In [None]:
g, axs = plt.subplots(6, 7, figsize=(35, 30), sharex=True)
axs = axs.flatten()

# one panel per donor, in order of first appearance in meta
axs_dict = {d: axs[i] for i, d in enumerate(meta.donor_id.unique())}
for d, donor_ax in axs_dict.items():
    donor_ax.set(title=d, xlabel="reads", ylabel="# KNRGLs", ylim=(0, 50))

# one complementary count-ECDF of read counts per cell, over the loci its
# donor does not carry (donor_AC == 0)
not_called = cell_cov.query("donor_AC == 0")
for (c, d), cell_df in tqdm(not_called.groupby(["cell_id", "donor_id"])):
    sns.ecdfplot(
        data=cell_df,
        x="reads",
        stat="count",
        complementary=True,
        log_scale=True,
        alpha=0.5,
        c=cols[0],
        ax=axs_dict[d],
    )

## Single Cell ROC curve

In [None]:
g, axs = plt.subplots(6, 7, figsize=(35, 30), sharex=True)
axs = axs.flatten()

# one panel per donor, in order of first appearance in meta
axs_dict = {d: axs[i] for i, d in enumerate(meta.donor_id.unique())}
for d, donor_ax in axs_dict.items():
    donor_ax.set_title(d)

# ROC curve of every cell (label vs. rpm) drawn on its donor's panel; the AUC
# is stored under the (cell_id, donor_id) key
auc = {}
for (c, d), cell_df in tqdm(cell_cov.groupby(["cell_id", "donor_id"])):
    y_true, y_score = cell_df["label"], cell_df["rpm"]
    fpr, tpr, threshold = roc_curve(y_true, y_score)
    sns.lineplot(x=fpr, y=tpr, ax=axs_dict[d], alpha=0.5, c=cols[0])
    auc[(c, d)] = roc_auc_score(y_true, y_score)

## Cell coverage by donor

In [None]:
# reads per (locus, cell) at loci with donor_AC == 1, CommonBrain excluded
het_cov = cell_cov.query("donor_AC == 1 and donor_id != 'CommonBrain'")
df = (
    het_cov.groupby(["locus", "cell_id", "donor_id", "tissue"])["reads"]
    .sum()
    .reset_index(name="reads")
)
df["donor_tissue"] = df["donor_id"] + "_" + df["tissue"]

# one complementary count-ECDF per donor/tissue combination, legend hidden
g = sns.ecdfplot(
    data=df,
    x="reads",
    hue="donor_tissue",
    stat="count",
    complementary=True,
    alpha=0.5,
)
g.set_xscale("log")
g.legend_.remove()

In [None]:
# Minimum read count (exclusive) for a cell to count as covering a locus.
# The previous loop over [3, 5, 10, 20, 50, 100] hit `break` at the end of its
# first pass, so only this threshold was ever evaluated.
t = 3
is_covered = df["reads"].gt(t)

# number of distinct cells per donor, and per donor/tissue
donor_total = df.groupby(["donor_id"])["cell_id"].nunique()
donor_tissue_total = df.groupby(["donor_id", "tissue"])["cell_id"].nunique()

# number of cells covering each locus, per donor
n_cov_donor = is_covered.groupby([df["locus"], df["donor_id"]]).sum()
# frac_cov_donor = n_cov_donor / donor_total

# number of cells covering each locus, per donor and tissue (long format)
n_cov_tissue = (
    is_covered.groupby([df["locus"], df["donor_id"], df["tissue"]])
    .sum()
    .reset_index(name="n_cells")
)

# NOTE(review): hue_order restricts the plot to these three hard-coded donors.
sns.displot(
    n_cov_tissue,
    x="n_cells",
    col="tissue",
    hue="donor_id",
    kind="ecdf",
    complementary=True,
    stat="count",
    hue_order=["1", "10", "11"],
)

In [None]:
# n_cov_tissue is long-format (one row per locus/donor/tissue), so it has no
# "DLPFC"/"HIP" columns; spread the tissues into columns before plotting one
# against the other. A locus/donor pair with no rows for a tissue becomes NaN.
n_cov_wide = (
    n_cov_tissue.set_index(["locus", "donor_id", "tissue"])["n_cells"]
    .unstack("tissue")
    .reset_index()
)

sns.jointplot(
    n_cov_wide,
    x="DLPFC",
    y="HIP",
    kind="hist",
    joint_kws={"discrete": True},
    marginal_kws={"discrete": True},
    marginal_ticks=True,
)

In [None]:
# n_cov_tissue is long-format (one row per locus/donor/tissue), so it has no
# "DLPFC"/"HIP" columns; spread the tissues into columns before filtering and
# plotting. A locus/donor pair with no rows for a tissue becomes NaN.
n_cov_wide = (
    n_cov_tissue.set_index(["locus", "donor_id", "tissue"])["n_cells"]
    .unstack("tissue")
    .reset_index()
)

# drop the loci that no cell covers in either tissue
covered_somewhere = ~((n_cov_wide["DLPFC"] == 0) & (n_cov_wide["HIP"] == 0))

sns.jointplot(
    n_cov_wide[covered_somewhere],
    x="DLPFC",
    y="HIP",
    kind="hist",
    joint_kws={"discrete": True},
    marginal_kws={"discrete": True},
    marginal_ticks=True,
)