# Peak evaluation

In [None]:
import pandas as pd
import seaborn as sns
import pyranges as pr
import pysam

## Read in data

In [None]:
# Load the peak calls as headerless, tab-separated text. The ten column names
# match the narrowPeak field names (presumably peak-caller output -- confirm).
peaks = pd.read_csv(
    snakemake.input.peaks,
    sep="\t",
    header=None,
    names=[
        "chrom",
        "start",
        "end",
        "name",
        "score",
        "strand",
        "signalValue",
        "pValue",
        "qValue",
        "peak",
    ],
)
peaks["size"] = peaks["end"] - peaks["start"]

# get read counts
# A single tabix handle serves both the per-peak counts and the total count;
# the context manager closes it once both passes are done.
with pysam.TabixFile(snakemake.input.reads) as tbx:
    # Iterating the three columns in lockstep avoids a row-wise apply and
    # also works when there are no peaks at all.
    peaks["num_reads"] = [
        sum(1 for _ in tbx.fetch(chrom, start, end))
        for chrom, start, end in zip(peaks["chrom"], peaks["start"], peaks["end"])
    ]
    total_reads = sum(1 for _ in tbx.fetch())

# calculate Fraction of Reads in Peaks (FRiP)
# NOTE(review): a read that falls in two overlapping peaks is counted twice
# here -- confirm the peak set is non-overlapping.
peak_reads = peaks["num_reads"].sum()
frip = peak_reads / total_reads if total_reads else float("nan")
print(f"Fraction of Reads in Peaks (FRiP): {frip:.4f}")
print(f"Total peaks: {len(peaks)}")

min_reads = peaks['num_reads'].min()
max_reads = peaks['num_reads'].max()
print(f"Peaks have from {min_reads}-{max_reads} reads")
print(f"{len(peaks[peaks['num_reads'] == min_reads])} peaks have {min_reads} reads, {len(peaks[peaks['num_reads'] == max_reads])} peaks have {max_reads} reads")

In [None]:
# Joint distribution of peak width and per-peak read count, with marginal
# histograms on both axes.
fig = sns.jointplot(
    data=peaks,
    x="size",
    y="num_reads",
    joint_kws={"alpha": 0.2, "size": 1, "legend": False},
    marginal_kws={"bins": 50, "fill": False},
    marginal_ticks=True,
)
fig.ax_joint.set_xlabel("Peak width (bp)")
# read counts are shown on a log scale
fig.ax_joint.set_yscale("log")
fig.ax_joint.set_ylabel("Number of reads")

In [None]:
# functions to read in germline insertions
def read_non_ref_db():
    """
    Load the table at ``snakemake.input.non_ref_l1``.

    The file is parsed as headerless, tab-separated text whose three columns
    are named chrom, start and end.

    Returns a DataFrame with chrom as str and start/end as int.
    """
    column_types = {"chrom": str, "start": int, "end": int}
    return pd.read_csv(
        snakemake.input.non_ref_l1,
        sep="\t",
        header=None,
        names=list(column_types),
        dtype=column_types,
    )

def read_rmsk():
    """
    Read the repeatmasker output table and return locations of L1HS and L1PA2-6

    The table at ``snakemake.input.ref_l1`` is parsed as whitespace-separated
    text after skipping its first three lines. Only rows whose repeat name is
    one of the ``*_3end`` entries listed below are kept.

    Returns a DataFrame with columns chrom (str), start and end, where each
    row is a 1 bp interval: end is the row's start coordinate on the "+"
    strand and its end coordinate otherwise, and start is end - 1. The result
    is empty (with the same three columns) when no row matches.
    """
    # read the rmsk file
    df0 = pd.read_csv(
        snakemake.input.ref_l1,
        skiprows=3,
        sep=r"\s+",  # same parsing as the deprecated delim_whitespace=True
        names=["chrom", "start", "end", "strand", "repeat"],
        usecols=[4, 5, 6, 8, 9],
    )

    # filter for rep_names
    rep_names = [
        "L1HS_3end",
        "L1PA2_3end",
        "L1PA3_3end",
        "L1PA4_3end",
        "L1PA5_3end",
        "L1PA6_3end",
    ]
    df0 = df0[df0["repeat"].isin(rep_names)]

    # set positions depending on strand: keep "start" on "+", otherwise take
    # "end". Done column-wise so an empty selection still yields a Series.
    anchor = df0["start"].where(df0["strand"] == "+", df0["end"])

    # save to new dataframe
    df1 = pd.DataFrame(
        {
            "chrom": df0["chrom"].astype(str),
            "start": anchor - 1,  # make zero-based
            "end": anchor,
        }
    )

    return df1

def read_germline():
    """
    Stack the non-reference table (read_non_ref_db) on top of the reference
    table (read_rmsk) and return the combined DataFrame.

    Row indices of the two inputs are kept as they are, not renumbered.
    """
    return pd.concat([read_non_ref_db(), read_rmsk()])

In [None]:
def rmsk_in_calls(rmsk_df, calls_df):
    """
    Count, per chromosome, how many rmsk insertions are overlapped by a call.

    Both arguments are DataFrames that pyranges can wrap (the code relies on
    Chromosome, Start and End columns).

    Returns a DataFrame with one row per chromosome present in both tallies
    and the columns Chromosome, num_called (distinct overlapped rmsk
    intervals) and num_total (all rmsk intervals).
    """
    insertions = pr.PyRanges(rmsk_df)
    calls = pr.PyRanges(calls_df)

    # rmsk intervals hit by at least one call, one row per distinct coordinate
    called = insertions.overlap(calls).df.drop_duplicates(
        subset=["Chromosome", "Start", "End"]
    )

    # per-chromosome tallies; after reset_index the size column is named 0
    called_per_chrom = called.groupby("Chromosome").size().reset_index()
    total_per_chrom = insertions.df.groupby("Chromosome").size().reset_index()

    # the shared column 0 picks up the suffixes, giving "0_called"/"0_total"
    # NOTE(review): merge defaults to an inner join, so a chromosome missing
    # from the called tally is dropped rather than reported as 0 -- confirm.
    per_chrom = called_per_chrom.merge(
        total_per_chrom, on="Chromosome", suffixes=("_called", "_total")
    )
    return per_chrom.rename(
        columns={"0_called": "num_called", "0_total": "num_total"}
    )

In [None]:
# rmsk_in_calls works on Chromosome/Start/End columns, so rename both tables
peaks = peaks.rename(columns={"chrom": "Chromosome", "start": "Start", "end": "End"})
rmsk = read_rmsk().rename(columns={"chrom": "Chromosome", "start": "Start", "end": "End"})
peak_calls = rmsk_in_calls(rmsk, peaks)
peak_calls["calls"] = "peaks"
# num_called counts distinct rmsk insertions hit by a peak (not peaks), so it
# is reported against the number of distinct rmsk insertions.
print(f"{peak_calls.num_called.sum()} / {len(rmsk.drop_duplicates())} rmsk L1 insertions overlapped by peaks")

# Disabled until the windowing comparison is implemented:
# windows = pd.read_pickle(snakemake.input.labels[0]).reset_index()
# windows = windows[["chrom", "start", "end"]]
# windows = windows.rename(columns={"chrom": "Chromosome", "start": "Start", "end": "End"})
# window_calls = rmsk_in_calls(rmsk, windows)
# window_calls["calls"] = "windows"

# calls = pd.concat([peak_calls, window_calls])

In [None]:
# one point per chromosome: rmsk insertions present vs. overlapped by a call
fig = sns.scatterplot(data=peak_calls, x="num_total", y="num_called")
fig.set_xlabel("# rmsk insertions / chr")
fig.set_ylabel("# calls overlapping rmsk insertions / chr")
sns.despine()

# TODO: 

1. Compute rmsk insertions covered by windowing strategy
2. Compare reads per peak between peaks that do and do not overlap rmsk insertions
3. Examine windows/peaks overlapping non-reference insertions called by xTea

## Try different parameters for the peak calls and repeat the above