# Peak calling testing

In [None]:
from subprocess import Popen, PIPE, DEVNULL
from tempfile import TemporaryDirectory, NamedTemporaryFile
from pathlib import Path
from itertools import product
from io import StringIO
from collections import deque, namedtuple, defaultdict

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import pyranges as pr
import pysam
from scripts.get_labels import read_knrgl
from myutils.rmsk import read_rmsk

from tqdm import tqdm

In [None]:
# read in repeatmasker output
rmsk = read_rmsk("/iblm/netapp/data4/mcuoco/sz_slavseq/resources/hs38d1.fa.out")

# repeat families to keep (3' ends of L1HS and L1PA2-L1PA6)
rep_names = [
    "L1HS_3end",
    "L1PA2_3end",
    "L1PA3_3end",
    "L1PA4_3end",
    "L1PA5_3end",
    "L1PA6_3end",
]

# keep the selected families whose repEnd is past 860
rmsk = rmsk.loc[(rmsk["repName"].isin(rep_names)) & (rmsk["repEnd"] > 860), :]

# keep elements whose strand-aware repeat start is before 765
# (repStart on "+", repLeft on "-"); boolean masks keep the column dtypes
# intact and do not drop rows because of NaNs in unrelated columns
plus_ok = (rmsk["strand"] == "+") & (rmsk["repStart"] < 765)
minus_ok = (rmsk["strand"] == "-") & (rmsk["repLeft"] < 765)
rmsk = rmsk.loc[plus_ok | minus_ok, :].copy()

# extend each element by 1000 bp on one side:
# lower the start of "-" elements, raise the end of "+" elements
rmsk.loc[rmsk["strand"] == "-", "genoStart"] -= 1000
rmsk.loc[rmsk["strand"] == "+", "genoEnd"] += 1000

rmsk = rmsk.rename(
    columns={
        "genoName": "Chromosome",
        "genoStart": "Start",
        "genoEnd": "End",
        "strand": "Strand",
    }
).loc[:, ["Chromosome", "Start", "End", "Strand"]]

# the 1000 bp extension can push a start below zero; drop those intervals
rmsk = rmsk.loc[(rmsk.Start >= 0) & (rmsk.End >= 0), :].copy()

# give each interval a unique ID taken from its row label
rmsk["Name"] = rmsk.index.values.astype(str)

## Define functions to run `bedtools intersect`

In [None]:
def bam_intersect(bam, bed, view_args=""):
    """
    Intersect a BAM file with a BED file, optionally filtering the BAM first.

    Runs ``samtools view -b`` on ``bam`` (with ``view_args`` appended to the
    command line) and pipes the result into
    ``bedtools intersect -abam stdin -b {bed} -bed -wa -wb``.

    Parameters
    ----------
    bam : path of the BAM file
    bed : path of the BED file to intersect with
    view_args : extra command-line arguments for ``samtools view``

    Returns
    -------
    pandas.DataFrame with one row per (alignment, BED entry) overlap; the
    BED entry's columns carry a ``B`` suffix.

    The command's stderr is discarded and its exit status is not checked.
    """
    cmd = f"samtools view -b {bam} {view_args} | bedtools intersect -abam stdin -b {bed} -bed -wa -wb"

    names = [
        "Chromosome",
        "Start",
        "End",
        "read_id",
        "mapq",
        "Strand",
        "intersect_start",
        "intersect_end",
        "x",
        "y",
        "Flag",
        "z",
        "ChromosomeB",
        "StartB",
        "EndB",
        "NameB",
        "ScoreB",
        "StrandB",
    ]

    # communicate() drains stdout and waits for the shell to exit; the
    # context manager closes the pipe, so no child process or file
    # descriptor is left behind
    with Popen(cmd, shell=True, stdout=PIPE, stderr=DEVNULL) as proc:
        stdout, _ = proc.communicate()

    df = pd.read_csv(StringIO(stdout.decode()), sep="\t", header=None, names=names)

    df = df.drop(columns=["x", "y", "z"])
    # NOTE(review): the value in this column carries a trailing comma, which
    # looks like the BED12 blockSizes field rather than a SAM flag - confirm
    df["Flag"] = df["Flag"].str.rstrip(",").astype(int)
    return df


def bed_intersect(peaks, annotation, **kwargs):
    """
    Intersect two interval tables with ``bedtools intersect -wa -wb``.

    Both tables are written to temporary BED files via pyranges, intersected,
    and the result is read back into a DataFrame.

    Parameters
    ----------
    peaks : DataFrame of intervals (the ``-a`` file), written with all columns
    annotation : DataFrame of intervals (the ``-b`` file); only its
        Chromosome, Start, End and Name columns are written
    **kwargs : accepted for compatibility; currently ignored

    Returns
    -------
    pandas.DataFrame with one row per overlap. Peak columns keep their names;
    the annotation's BED columns carry a ``b`` suffix (e.g. ``Nameb``).
    """
    bed_names = ["Chromosome", "Start", "End", "Name", "Score", "Strand"]
    peak_names = bed_names + [c for c in peaks.columns if c not in bed_names]
    annotation_names = [f"{c}b" for c in bed_names] + [
        c for c in annotation.columns if c not in bed_names
    ]

    # context managers guarantee the temp files are removed and the child
    # process is reaped even if writing or intersecting fails
    with NamedTemporaryFile() as peak_tmp, NamedTemporaryFile() as annotation_tmp:
        pr.PyRanges(peaks).to_bed(peak_tmp.name)
        pr.PyRanges(annotation[["Chromosome", "Start", "End", "Name"]]).to_bed(
            annotation_tmp.name
        )

        cmd = f"bedtools intersect -a {peak_tmp.name} -b {annotation_tmp.name} -wa -wb"
        with Popen(cmd, shell=True, stdout=PIPE) as proc:
            stdout, _ = proc.communicate()

    return pd.read_csv(
        StringIO(stdout.decode()),
        sep="\t",
        header=None,
        names=peak_names + annotation_names,
    )

## Define peak caller

In [None]:
# My custom peak caller
# iterate over reads, cluster overlapping together
# TODO: link r1 and r2 peaks


# Lightweight record holding the alignment attributes the peak caller needs.
# Each field is its own string: without the comma, adjacent string literals
# are concatenated and "query_name"/"reference_name" fuse into a single field.
Read = namedtuple(
    "Read",
    [
        "query_name",
        "reference_name",
        "reference_start",
        "reference_end",
        "is_read1",
        "is_read2",
        "is_forward",
        "is_reverse",
        "mapping_quality",
    ],
)


class OverlapPeakCaller:
    """
    Call peaks by clustering overlapping reads, then pair concordant peaks.

    Reads are clustered separately for each of four read groups
    (read1_fwd, read1_rev, read2_fwd, read2_rev). Peaks from concordant
    groups (read1_fwd with read2_rev, read1_rev with read2_fwd) that lie
    close together are then merged.
    """

    # read group -> the read group its peaks may be merged with
    _PARTNER = {
        "read1_fwd": "read2_rev",
        "read1_rev": "read2_fwd",
        "read2_rev": "read1_fwd",
        "read2_fwd": "read1_rev",
    }

    # read groups whose peaks merge_peaks reports on the "-" strand
    _MINUS_GROUPS = ("read1_fwd", "read2_rev")

    # column order of the per-read-group peak table built in run()
    _PEAK_COLUMNS = ["read_group", "Chromosome", "Start", "End", "Strand", "nreads"]

    def __init__(self, bam: "pysam.AlignmentFile") -> None:
        self.bam = bam

    def read_converter(self, read: "pysam.AlignedSegment") -> "Read":
        """Convert a pysam.AlignedSegment to a Read namedtuple."""
        return Read(
            read.query_name,
            read.reference_name,
            read.reference_start,
            read.reference_end,
            read.is_read1,
            read.is_read2,
            read.is_forward,
            read.is_reverse,
            read.mapping_quality,
        )

    @staticmethod
    def _read_group(read):
        """Return the read group name for a read, or None if it has none."""
        if read.is_read1:
            mate = "read1"
        elif read.is_read2:
            mate = "read2"
        else:
            return None

        if read.is_forward:
            orientation = "fwd"
        elif read.is_reverse:
            orientation = "rev"
        else:
            return None

        return f"{mate}_{orientation}"

    def _strand(self, read_group):
        """Strand reported by merge_peaks for a peak of the given read group."""
        return "-" if read_group in self._MINUS_GROUPS else "+"

    def _as_unmerged(self, peak):
        """Format a peak that found no partner as a merge_peaks output record."""
        return {
            "Chromosome": peak.Chromosome,
            "Start": peak.Start,
            "End": peak.End,
            "Strand": self._strand(peak.read_group),
            "nreads": peak.nreads,
        }

    def make_peaks(self, reads, r1_bandwidth=0, min_mapq=60):
        """
        Cluster overlapping reads into peaks, separately per read group.

        Parameters
        ----------
        reads : iterable of reads sorted by reference position; each needs
            reference_name, reference_start, reference_end, is_read1,
            is_read2, is_forward and is_reverse
        r1_bandwidth : extra distance past the end of an open read1 peak
            within which the next read1 read still joins it (read2 peaks
            require true overlap)
        min_mapq : unused here, kept for interface compatibility; mapping
            quality filtering happens in run()

        Yields
        ------
        dict with keys read_group, Chromosome, Start, End, Strand, nreads.
        Peaks are yielded as they close, so output is not globally sorted.
        """
        open_peaks = {}  # read group -> peak currently being extended

        for read in reads:
            rg = self._read_group(read)
            if rg is None:
                continue

            peak = open_peaks.get(rg)
            if peak is not None:
                reach = peak["End"]
                if rg.startswith("read1"):
                    reach += r1_bandwidth

                if (
                    read.reference_name == peak["Chromosome"]
                    and read.reference_start <= reach
                ):
                    # extend to the furthest end seen so far, so a short read
                    # nested inside a longer one does not shrink the peak
                    peak["End"] = max(peak["End"], read.reference_end)
                    peak["nreads"] += 1
                    continue

                # read is on another chromosome or beyond reach: close the peak
                yield peak

            open_peaks[rg] = {
                "read_group": rg,
                "Chromosome": read.reference_name,
                "Start": read.reference_start,
                "End": read.reference_end,
                "Strand": "-" if read.is_reverse else "+",
                "nreads": 1,
            }

        # emit the peaks still open once the reads are exhausted
        yield from open_peaks.values()

    def merge_peaks(self, peaks, bandwidth=0):
        """
        Merge peaks from concordant read groups that lie within bandwidth.

        Parameters
        ----------
        peaks : iterable of records sorted by (Chromosome, Start, End) with
            attributes read_group, Chromosome, Start, End and nreads
        bandwidth : a peak is merged with the pending peak of its partner
            group when its Start is below that peak's End + bandwidth

        Yields
        ------
        dict with keys Chromosome, Start, End, Strand, nreads; one per merged
        pair and one per peak that found no partner.
        """
        pending = dict.fromkeys(self._PARTNER)  # read group -> unmatched peak

        for p in peaks:
            group = p.read_group
            partner = self._PARTNER[group]

            # a newer peak of the same group supersedes the pending one,
            # which therefore never found a partner
            if pending[group] is not None:
                yield self._as_unmerged(pending[group])
                pending[group] = None

            mate = pending[partner]
            if (
                mate is not None
                and p.Chromosome == mate.Chromosome
                and p.Start < (mate.End + bandwidth)
            ):
                yield {
                    "Chromosome": p.Chromosome,
                    "Start": min(mate.Start, p.Start),
                    "End": max(mate.End, p.End),
                    "Strand": self._strand(group),
                    "nreads": p.nreads + mate.nreads,
                }
                # both peaks are consumed; p must not become pending
                pending[partner] = None
            else:
                pending[group] = p

        # peaks still pending at the end never found a partner
        for leftover in pending.values():
            if leftover is not None:
                yield self._as_unmerged(leftover)

    def run(self, r1_bandwidth=0, min_mapq=60, contig="chr1"):
        """
        Call and merge peaks on one contig of the BAM file.

        Parameters
        ----------
        r1_bandwidth : passed to make_peaks, and used as the merge bandwidth
        min_mapq : minimum mapping quality for a read to be used
        contig : contig passed to ``bam.fetch``

        Returns
        -------
        (peaks, merged) : ``peaks`` is a numpy record array of the
        per-read-group peaks sorted by Chromosome, Start, End (it includes
        the DataFrame index as field "index"); ``merged`` is the list of
        dicts produced by merge_peaks.
        """
        # primary, mapped alignments at or above the quality threshold
        reads = (
            self.read_converter(r)
            for r in self.bam.fetch(contig=contig)
            if r.is_mapped
            and not (r.is_secondary or r.is_supplementary)
            and r.mapping_quality >= min_mapq
        )

        found = list(self.make_peaks(reads, r1_bandwidth, min_mapq))

        # explicit columns keep the sort working when no peak was found
        peaks = (
            pd.DataFrame(found, columns=self._PEAK_COLUMNS)
            .sort_values(["Chromosome", "Start", "End"])
            .to_records()
        )

        merged = list(self.merge_peaks(peaks, r1_bandwidth))

        return peaks, merged

## Define analysis helper functions

In [None]:
# analysis helper functions
def annotate_peaks(peaks: pd.DataFrame, annotation: pd.DataFrame, name: str):
    """
    Label each peak with the annotation entries it overlaps.

    Adds three columns to a copy of ``peaks``:
    ``"{name} ID"`` (Name of the overlapping annotation entry, NaN if none),
    ``"{name} cov"`` (number of distinct annotation entries hit by any peak)
    and ``name`` (total number of annotation entries).
    A peak overlapping several entries appears once per entry.
    """
    coords = ["Chromosome", "Start", "End"]
    id_col = f"{name} ID"

    # overlaps between peaks and annotation, keyed by peak coordinates
    hits = bed_intersect(peaks, annotation)
    hits = hits[coords + ["Nameb"]]
    hits = hits.set_index(coords).rename(columns={"Nameb": id_col})

    # attach the annotation ID to every peak, keeping peaks without a hit
    annotated = peaks.join(hits, on=coords, how="left")
    annotated[f"{name} cov"] = hits[id_col].nunique()
    annotated[name] = len(annotation)

    return annotated

In [None]:
# define plotting helper function
def peak_cdf_plot(df, hue_col="indv"):
    """
    Draw ECDFs of peak ``nreads`` and ``width``, one row of panels per ``group``.

    ``df`` needs the columns Chromosome, Start, End, group, nreads, width and
    ``hue_col``. Returns the seaborn FacetGrid.
    """
    # long format: one row per (peak, metric)
    long_df = df.melt(
        id_vars=["Chromosome", "Start", "End", hue_col, "group"],
        value_vars=["nreads", "width"],
        var_name="metric",
        value_name="value",
    )

    g = sns.FacetGrid(
        long_df,
        col="metric",
        row="group",
        hue=hue_col,
        sharex=False,
        sharey=True,
        height=4,
        aspect=1,
    )
    g.map_dataframe(sns.ecdfplot, x="value", stat="proportion")
    g.add_legend()

    rows = (0, 1)

    # first column (nreads): log-scaled x axis
    for row in rows:
        g.axes[row, 0].set_xscale("log")

    # second column (width): x axis starts at 0
    for row in rows:
        g.axes[row, 1].set_xlim(0, None)

    return g

## Analysis

In [None]:
# locate the knrgl BED file and bulk BAM of every donor listed in the config
individuals = pd.read_csv(
    "/iblm/logglun02/mcuoco/workflows/sz_slavseq/config/bulk_donors.tsv", sep="\t"
)["donor_id"].values.astype(str)
indv_data = {i: dict.fromkeys(("knrgl", "bulk")) for i in individuals}

# bulk BAMs: the parent directory name is the donor ID
align_dir = Path("/iblm/netapp/data4/mcuoco/sz_slavseq/results/align/")
for bulk in align_dir.rglob("*/gDNA_usd*.sorted.bam"):
    donor_id = bulk.parts[-2]
    if donor_id in individuals:
        indv_data[donor_id]["bulk"] = str(bulk)

# knrgl BEDs: the file name starts with the donor ID followed by "_"
resources_dir = Path("/iblm/netapp/data4/mcuoco/sz_slavseq/resources/")
for knrgl in resources_dir.rglob("*_insertions.bed"):
    donor_id = knrgl.name.split("_")[0]
    if donor_id in individuals:
        indv_data[donor_id]["knrgl"] = str(knrgl)

### Explore `bandwidth` and `min_mapq` parameters

Use individual 27

In [None]:
# smoke-test the peak caller on the bulk BAM of individual 27
bam_path = indv_data["27"]["bulk"]
bam = pysam.AlignmentFile(bam_path, "rb")
peaks, merged = OverlapPeakCaller(bam).run(r1_bandwidth=750)

In [None]:
# display the per-read-group peaks (first element returned by run())
peaks

In [None]:
# display the merged peaks (second element returned by run())
merged

In [None]:
# parameter grid to explore
bandwidths = [250, 500, 1000, 2000]
min_mapqs = [5, 10, 20, 30, 40, 50, 60]
data = indv_data["27"]

total_reads = int(pysam.view("-c", data["bulk"]).rstrip("\n"))
bam = pysam.AlignmentFile(data["bulk"], "rb")
knrgl = read_knrgl(data["knrgl"])
knrgl["Name"] = knrgl.index.values.astype(str)  # give each knrgl a unique ID

res = []
pc = OverlapPeakCaller(bam)
for bw, mq in tqdm(
    product(bandwidths, min_mapqs), total=(len(bandwidths) * len(min_mapqs))
):
    # run() returns (per-read-group peak records, merged peaks);
    # the per-read-group records are what gets annotated below
    records, _merged = pc.run(r1_bandwidth=bw, min_mapq=mq)
    peaks = pd.DataFrame.from_records(records, index="index")

    # annotate
    peaks = annotate_peaks(peaks, knrgl, "knrgl")
    peaks = annotate_peaks(peaks, rmsk, "rmsk")

    # columns that peak_cdf_plot facets on / plots
    peaks["width"] = peaks["End"] - peaks["Start"]
    peaks["group"] = peaks["read_group"].str.split("_").str[0]

    peaks["bw"] = bw
    peaks["min_mapq"] = mq

    # collect results
    res.append(peaks)

res = pd.concat(res)

In [None]:
# fraction of knrgl / rmsk entries covered by at least one peak
res["frac knrgl cov"] = res["knrgl cov"].div(res["knrgl"])
res["frac rmsk cov"] = res["rmsk cov"].div(res["rmsk"])

# one row per parameter combination with its peak count, then long format
per_setting = (
    res.groupby(["bw", "min_mapq", "frac knrgl cov", "frac rmsk cov"])
    .size()
    .reset_index(name="total peaks")
)
plot_df = per_setting.melt(
    id_vars=["bw", "min_mapq"],
    value_vars=["total peaks", "frac knrgl cov", "frac rmsk cov"],
    value_name="value",
    var_name="var",
)

# one bar panel per metric, bandwidth on x, bars split by min_mapq
sns.catplot(
    data=plot_df,
    x="bw",
    y="value",
    col="var",
    hue="min_mapq",
    kind="bar",
    palette="Set2",
    sharey=False,
    dodge=True,
)

In [None]:
# ECDFs of nreads and width for all peaks, one hue per min_mapq value
peak_cdf_plot(res, hue_col="min_mapq")

In [None]:
# same plot, restricted to peaks that have a knrgl ID (i.e. overlap a knrgl entry)
peak_cdf_plot(res[res["knrgl ID"].notnull()], hue_col="min_mapq")

### Analyze across individuals

In [None]:
res = []
for ind, data in tqdm(indv_data.items(), total=len(indv_data)):
    # total_reads = int(pysam.view("-c", data["bulk"], "chr22").rstrip("\n")) # subset to chr22 (for testing)
    total = int(pysam.view("-c", data["bulk"]).rstrip("\n"))

    # call peaks: run() returns (per-read-group peak records, merged peaks);
    # the context manager closes the BAM handle after each individual
    with pysam.AlignmentFile(data["bulk"], "rb") as bam:
        records, _merged = OverlapPeakCaller(bam).run()
    peaks = pd.DataFrame.from_records(records, index="index")

    ## Compare with knrgl and rmsk
    # lowercase names so the resulting columns ("knrgl cov", "rmsk cov", ...)
    # match what the plotting cells read
    knrgl = read_knrgl(data["knrgl"])
    knrgl["Name"] = knrgl.index.values.astype(str)  # give each knrgl a unique ID
    peaks = annotate_peaks(peaks, knrgl, "knrgl")
    peaks = annotate_peaks(peaks, rmsk, "rmsk")

    # columns that peak_cdf_plot facets on / plots
    peaks["width"] = peaks["End"] - peaks["Start"]
    peaks["group"] = peaks["read_group"].str.split("_").str[0]

    # collect results
    peaks["indv"] = int(ind)
    peaks["total alignments"] = total
    res.append(peaks)

res = pd.concat(res)

In [None]:
# add metadata
meta = pd.read_csv(
    "/iblm/logglun02/mcuoco/workflows/sz_slavseq/config/slavseq_metadata.tsv", sep="\t"
)
# .copy() so the column assignment below acts on an independent frame
meta = meta.loc[~meta["TISSUE_ID"].isin(["CommonBrain"])].copy()

# individual ID = TISSUE_ID without its first three characters
meta["indv"] = meta["TISSUE_ID"].str[3:].astype(int)

# keep columns of interest; de-duplicate while "indv" is still a column,
# otherwise individuals sharing RACE/AGE/DIAGNOSIS collapse into one row
meta = (
    meta[["indv", "RACE", "AGE", "DIAGNOSIS"]].drop_duplicates().set_index("indv")
)

# join with res
res = res.join(meta, on="indv")

In [None]:
# plot per-individual summary metrics
res["frac knrgl cov"] = res["knrgl cov"].div(res["knrgl"])
res["frac rmsk cov"] = res["rmsk cov"].div(res["rmsk"])

# one row per individual with its peak count
group_cols = [
    "indv",
    "total alignments",
    "frac knrgl cov",
    "frac rmsk cov",
    "AGE",
    "RACE",
    "DIAGNOSIS",
]
per_indv = res.groupby(group_cols).size().reset_index(name="total peaks")

# long format: one row per (individual, metric)
plot_df = per_indv.melt(
    id_vars=["indv", "AGE", "RACE", "DIAGNOSIS"],
    value_vars=[
        "total alignments",
        "total peaks",
        "frac knrgl cov",
        "frac rmsk cov",
    ],
    value_name="value",
    var_name="var",
)

# one bar panel per metric, one bar per individual
sns.catplot(
    data=plot_df,
    x="indv",
    y="value",
    hue="indv",
    col="var",
    kind="bar",
    palette="Set2",
    sharey=False,
    dodge=False,
)

In [None]:
# ECDFs for every peak, one hue per individual
g = peak_cdf_plot(res)
g.fig.suptitle("All peaks")
g.fig.subplots_adjust(top=0.90)  # leave room for the title

In [None]:
# annotate_peaks names its ID column "{name} ID"; select peaks with a knrgl hit
g = peak_cdf_plot(res[res["knrgl ID"].notnull()])
g.fig.subplots_adjust(top=0.90)  # leave room for the title
g.fig.suptitle("Peaks overlapping KNRGL")

In [None]:
# annotate_peaks names its ID column "{name} ID"; select peaks with an rmsk hit
g = peak_cdf_plot(res[res["rmsk ID"].notnull()])
g.fig.subplots_adjust(top=0.90)  # leave room for the title
g.fig.suptitle("Peaks overlapping RMSK")

In [None]:
# how many peaks of each type per knrgl insertion? (maximum per read group)
read_groups = ["read1_fwd", "read1_rev", "read2_fwd", "read2_rev"]
count = (
    # count rows directly instead of counting a helper column;
    # "knrgl ID" and "read_group" are the column names produced by
    # annotate_peaks and OverlapPeakCaller.make_peaks
    res.groupby(["indv", "knrgl ID", "read_group"])
    .size()
    .unstack("read_group", fill_value=0)
    # guarantee all four columns exist even if a read group has no peaks
    .reindex(columns=read_groups, fill_value=0)
)
count.max(axis=0)

In [None]:
# how many peaks in total per knrgl insertion? (sum over read groups)
read_groups = ["read1_fwd", "read1_rev", "read2_fwd", "read2_rev"]
count = (
    # count rows directly instead of counting a helper column;
    # "knrgl ID" and "read_group" are the column names produced by
    # annotate_peaks and OverlapPeakCaller.make_peaks
    res.groupby(["indv", "knrgl ID", "read_group"])
    .size()
    .unstack("read_group", fill_value=0)
    # guarantee all four columns exist even if a read group has no peaks
    .reindex(columns=read_groups, fill_value=0)
)
count.sum(axis=1)

Where are the peaks with nreads = 1? Should they be grouped with other peaks or are they artifactual alignments?