# Count Processing

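This notebook processes the per-sample barcode counts generated by `count_bcs.sh`. It does four things:

- normalizes the counts to counts per million and to the plasmid (DNA) pool;
- checks reproducibility between replicates;
- compares each library member's activity against the basal and scrambled controls;
- classifies each library member's activity from that comparison.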

In [None]:
import os

# POLARS_MAX_THREADS must be set before polars is imported (next cell).
# Use the SLURM CPU allocation when running under SLURM. Otherwise fall back to the
# machine's CPU count instead of failing with a KeyError.
n_threads = os.environ.get("SLURM_CPUS_PER_TASK") or str(os.cpu_count() or 1)
os.environ["POLARS_MAX_THREADS"] = n_threads
print(f"{n_threads} thread(s) available")

In [None]:
from IPython.display import display, Markdown, HTML
from tqdm.notebook import tqdm

# Render matplotlib figures inline in the notebook.
%matplotlib inline
import matplotlib_inline.backend_inline
# Emit inline figures as JPEG (presumably to keep the saved notebook small).
matplotlib_inline.backend_inline.set_matplotlib_formats("jpeg")

import json

import matplotlib
import matplotlib.pyplot as pyplot
import seaborn

import numpy
import polars
import pandas
import scipy
# A bare `import scipy` does not guarantee that its submodules are loaded; import the ones used below.
import scipy.interpolate
import scipy.stats

# Enable polars' global string cache. Categorical columns created in separate frames
# (e.g. each per-sample frame that is concatenated below) then share one string
# encoding and can be concatenated and joined with each other.
polars.toggle_string_cache(True)

In [None]:
# Style Matplotlib plots: global rcParams, reusable two-color colormaps, and shared boxplot styling.

plot_style = {
    # Typography and figure canvas
    "font.size": 12,
    "font.family": "sans-serif",
    "font.sans-serif": ["Inter"],
    "figure.figsize": [2, 2],
    "figure.dpi": 150,
    "savefig.dpi": 300,
    "figure.facecolor": (1, 1, 1, 0),  # fully transparent figure background
    "text.usetex": False,
    "lines.markersize": 3,
    # Axes: bold titles, no top/right spines, transparent background, faint grid
    "axes.titleweight": "bold",
    "axes.labelweight": 600,
    "axes.labelsize": 9,
    "axes.facecolor": "none",
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.linewidth": 1.5,
    "grid.color": "#303030",
    "grid.alpha": 0.1,
    # Ticks: minor ticks hidden
    "xtick.labelsize": 9,
    "xtick.major.width": 1.5,
    "xtick.minor.width": 1,
    "xtick.minor.visible": False,
    "ytick.labelsize": 9,
    "ytick.major.width": 1.5,
    "ytick.minor.width": 1,
    "ytick.minor.visible": False,
    # Frames and legends: frameless, compact legends
    "figure.frameon": False,
    "legend.frameon": False,
    "legend.fancybox": False,
    "legend.fontsize": 9,
    "legend.scatterpoints": 1,
    "legend.markerscale": 1,
    "legend.handlelength": 1.0,
    "legend.handletextpad": 0.2,
    # Default line/marker color cycle
    "axes.prop_cycle": matplotlib.cycler(color=["k", "b", "g", "r", "c", "y"]),
}

pyplot.style.use(plot_style)
# Interactive mode off: figures are shown explicitly with display() and closed afterwards.
_ = pyplot.ioff()


def _two_color_cmap(name, low, high):
    """Linear colormap named `name` that runs from color `low` to color `high`."""
    return matplotlib.colors.LinearSegmentedColormap.from_list(name, [low, high])


# Colormaps starting at pure white
cm_white_to_gray = _two_color_cmap("white_to_gray", "#ffffff", "#404040")
cm_white_to_red = _two_color_cmap("white_to_red", "#ffffff", "#ff0000")
cm_white_to_blue = _two_color_cmap("white_to_blue", "#ffffff", "#0000ff")
cm_white_to_green = _two_color_cmap("white_to_green", "#ffffff", "#009000")

# Colormaps starting at a light tint of the end color
cm_light_to_gray = _two_color_cmap("light_to_gray", "#e0e0e0", "#404040")
cm_light_to_red = _two_color_cmap("light_to_red", "#ffe0e0", "#ff0000")
cm_light_to_blue = _two_color_cmap("light_to_blue", "#e0e0ff", "#0000ff")
cm_light_to_green = _two_color_cmap("light_to_green", "#e0ffe0", "#009000")

# Keyword arguments that draw every boxplot element in black
_BOX_BLACK = "#000000"
boxplot_style_args = {
    "whiskerprops": {"color": _BOX_BLACK},
    "boxprops": {"edgecolor": _BOX_BLACK},
    "flierprops": {"marker": "o", "markerfacecolor": _BOX_BLACK, "markersize": 1, "markeredgecolor": None},
    "medianprops": {"color": _BOX_BLACK},
    "capprops": {"color": _BOX_BLACK},
}

In [None]:
"""
Common plotting functions
"""

# Adapted from https://stackoverflow.com/a/53865762

def calc_weights_for_scatter(x, y, bins=20):
    """Estimate a local point density at every (x, y) point, for density-colored scatter plots.

    The points are binned with a 2-D histogram. The bin counts are then interpolated with a
    bivariate spline, through the bin centers, back onto each point's coordinates.

    Parameters
    ----------
    x, y : numpy.ndarray
        Point coordinates (1-D, same length). The exact type is required; subclasses are rejected.
    bins : int, array_like, or pair of either
        Passed straight to `numpy.histogram2d`.

    Returns
    -------
    weights : numpy.ndarray
        Interpolated bin count at each point. Points outside the grid of bin centers get 0.
    normalizer : matplotlib.colors.Normalize
        Spans min..max of `weights`. For empty input it is an autoscaling Normalize.

    Raises
    ------
    ValueError
        If `x` or `y` is not a `numpy.ndarray`.
    """
    if type(x) is not numpy.ndarray or type(y) is not numpy.ndarray:
        raise ValueError("x and y must be `numpy.ndarray`s")

    if x.size == 0 and y.size == 0:
        # Nothing to weight; numpy.min/numpy.max below would raise on a zero-size array.
        return numpy.zeros(0), matplotlib.colors.Normalize()

    bin_weights, xbins, ybins = numpy.histogram2d(x, y, bins=bins, density=False)
    x_centers = 0.5 * (xbins[1:] + xbins[:-1])
    y_centers = 0.5 * (ybins[1:] + ybins[:-1])
    weights = scipy.interpolate.interpn(
        points=(x_centers, y_centers),
        values=bin_weights,
        xi=numpy.column_stack([x, y]),
        method="splinef2d",
        bounds_error=False,
    )
    # Points outside the bin-center grid come back as NaN (interpn's default fill_value); weight them 0.
    weights[numpy.isnan(weights)] = 0.0

    normalizer = matplotlib.colors.Normalize(vmin=numpy.min(weights), vmax=numpy.max(weights))

    return weights, normalizer

def reorder_points_and_weights(x, y, w, f = None):
    """Sort points by weight, so the heaviest points are drawn last (on top) in a scatter plot.

    If `f` is given, it is applied to x, y and w (in that order) before sorting. The sort
    uses the transformed weights.

    Returns a dict with keys "x", "y", "c", ready to be splatted into `Axes.scatter`.
    Raises ValueError if x, y or w is not exactly a `numpy.ndarray`.
    """
    arrays = (x, y, w)
    if any(type(array) is not numpy.ndarray for array in arrays):
        raise ValueError("x, y, and w must be `numpy.ndarray`s")

    if f is not None:
        arrays = tuple(map(f, arrays))
    xs, ys, ws = arrays

    by_weight = numpy.argsort(ws)
    return dict(x=xs[by_weight], y=ys[by_weight], c=ws[by_weight])

# Display labels for short genotype codes (plus a friendlier heading for "chip_peak_id").
NAME_TRANSLATION_TABLE = dict(
    WT="+/+",
    Rhet="R90W/+",
    Rhom="R90W/R90W",
    Ehet="E168d2/+",
    Ehom="E168d2/E168d2",
    CrxKO="-/-",
    chip_peak_id="CRE",
)

def translate_genotype_to_label(genotype):
    """Return the display label for `genotype`; unknown names raise KeyError."""
    label = NAME_TRANSLATION_TABLE[genotype]
    return label

In [None]:
# Read the per-sample metadata table. Derive each sample's file stem ("computed_name"),
# which the next cell uses to locate that sample's barcode-count file.

_sample_name_expr = polars.format(
    "{}.{}.{}.batch{}.rep{}",
    "genotype", "library", "source_molecule", "sequencing_group", "replicate",
).alias("computed_name")

_categorical_sample_columns = ("genotype", "source_molecule", "library")

samples_information = (
    polars.scan_csv("sample_information.tsv", sep="\t")
    .with_columns(
        [_sample_name_expr]
        + [polars.col(column).cast(polars.Categorical) for column in _categorical_sample_columns]
    )
    .collect()
)

display(Markdown("### Read File Information"))
display(samples_information)

In [None]:
"""
Load barcode counts for each sample
"""

# Each sample's metadata row is broadcast onto every barcode row of its counts file,
# so that all samples can be stacked into one long table.
raw_samples_counts = []

# Progress total comes from the table itself rather than a hard-coded sample count.
for sample_info in tqdm(samples_information.iter_rows(named=True), total=len(samples_information)):
    sample_counts = polars.read_csv(f"Barcodes/{sample_info['computed_name']}.barcode_counts.tsv", sep="\t")

    raw_samples_counts.append(sample_counts.with_columns([
        polars.lit(polars.DataFrame(sample_info).to_struct('sample_info')).alias("sample_info")
    ]).unnest("sample_info").with_columns([
        # The metadata is rebuilt from the Python row values, so restore the Categorical dtypes
        # (the global string cache lets these per-sample categoricals be concatenated).
        polars.col("genotype").cast(polars.Categorical),
        polars.col("source_molecule").cast(polars.Categorical),
        polars.col("library").cast(polars.Categorical),
    ]))

samples_counts = polars.concat(raw_samples_counts, rechunk=True, how="vertical", parallel=True)

display(Markdown("### Sample Counts"))
display(samples_counts)

In [None]:
display(Markdown("### Total Counts per Sample"))

# Total reads per sample, summed over barcodes (and over sequencing runs).
total_counts_per_sample = (
    samples_counts
    .groupby(["genotype", "source_molecule", "library", "replicate"])
    .agg([polars.col("count").sum()])
    .sort(["library", "source_molecule", "genotype"])
    .to_pandas()
)

grid = seaborn.catplot(
    data=total_counts_per_sample,
    kind="bar",
    row="library",
    col="genotype",
    x="replicate",
    y="count",
    order=[1, 2, 3],
    height=3,
    palette="Set1",
)
display(grid.figure)
pyplot.close(grid.figure)

In [None]:
# Load the BC -> library_id map. It is joined onto the counts by "BC" in the normalization cell.
bc_to_library_id_map = polars.read_csv("bc_to_library_id_map.tsv", sep="\t")

display(Markdown("### BC-to-Library ID Map"), bc_to_library_id_map)

In [None]:
pipeline = samples_counts.lazy()

"""
Sum barcode counts for samples sequenced in multiple runs (multiple sequencing_group values)
"""

# Columns that identify one sample once its sequencing runs are merged
SAMPLE_KEY_COLUMNS = ["library", "genotype", "source_molecule", "replicate"]

pipeline = pipeline.groupby(SAMPLE_KEY_COLUMNS + ["BC"]).agg([
    polars.sum("count")
])

"""
Extract list of library members with zero DNA counts (not observed in sequencing of plasmid pool)
"""

missing_from_plasmid = pipeline.filter((polars.col("genotype") == "Plasmid") & (polars.col("count") == 0)).collect().groupby(["library"]).all().select(["library", "BC"]).with_columns([
    polars.col("BC").arr.lengths().alias("BC_count")
])
display(Markdown("### BCs missing from Plasmid"))
display(missing_from_plasmid)

"""
Compute counts per million reads within each sample
"""

# The denominator is each sample's own total (a window over SAMPLE_KEY_COLUMNS). A frame-wide
# polars.sum("count") would divide every sample by the same grand total. That constant cancels
# in the plasmid ratio below, leaving activity uncorrected for per-sample sequencing depth.
pipeline = pipeline.with_columns([
    ((polars.col("count") * 1_000_000) / polars.col("count").sum().over(SAMPLE_KEY_COLUMNS)).alias("cpm"),
])

"""
Normalize cpm for each measurement by dividing by cpm in the plasmid (DNA) sample.
Measurements for which plasmid (DNA) counts are 0 are dropped.
"""

# NOTE(review): the join is on (library, BC) only, so this assumes one Plasmid sample per
# library. Several plasmid replicates would duplicate rows. TODO confirm.
# BCs absent from the plasmid sample get a null plasmid_count and are dropped by the filter too.
pipeline = pipeline.join(
    pipeline
        .filter(polars.col("genotype") == "Plasmid")
        .select(["library", "BC", "cpm", "count"])
        .rename({"cpm": "plasmid_cpm", "count": "plasmid_count"})
, on=["library", "BC"], how="left").with_columns([
    (polars.col("cpm") / polars.col("plasmid_cpm")).alias("activity")
]).filter((polars.col("genotype") != "Plasmid") & (polars.col("plasmid_count") > 0)).drop(["source_molecule", "plasmid_cpm", "plasmid_count"])

"""
Join counts with library ID map
"""

pipeline = pipeline.join(bc_to_library_id_map.lazy(), on="BC", how="left")

pipeline = pipeline.sort(["library", "genotype", "library_id", "BC", "replicate"])

normalized_samples_counts = pipeline.collect()

display(Markdown("### Plasmid-normalized Counts"))
display(normalized_samples_counts)

In [None]:
display(Markdown("## basal (all barcodes and replicates)"))
data = normalized_samples_counts.filter(polars.col("library_id") == "basal").to_pandas()
# remove_unused_categories() returns a new Series, so assign it back to the column.
# Setting `.cat` on the temporary Series from data["genotype"] would leave `data` unchanged,
# and seaborn would still draw slots for genotypes with no rows.
data["genotype"] = data["genotype"].cat.remove_unused_categories()

grid = seaborn.catplot(kind="strip", data=data, row="library", x="genotype", hue="replicate", y="count", height=3, aspect=2, palette="Set1", alpha=0.4, linewidth=0, dodge=True)

for axis in grid.axes.flat:
    # Dotted reference line at 10 reads
    axis.axhline(10, color="#00000080", linestyle=":")
    axis.set_yscale("symlog", linthresh=1)
    axis.set_ylim(bottom=0)

display(grid.figure)
pyplot.close(grid.figure)

In [None]:
display(Markdown("## scrambled sequences (all barcodes and replicates)"))

data = normalized_samples_counts.filter(polars.col("library_id").str.contains("Scramble")).to_pandas()
# remove_unused_categories() returns a new Series, so assign it back to the column.
# Setting `.cat` on the temporary Series from data["genotype"] would leave `data` unchanged.
data["genotype"] = data["genotype"].cat.remove_unused_categories()

grid = seaborn.catplot(kind="strip", data=data, row="library", x="genotype", y="count", hue="replicate", dodge=True, height=3, aspect=2, palette="Set1", alpha=0.1, linewidth=0)

for axis in grid.axes.flat:
    # Dotted reference line at 10 reads
    axis.axhline(10, color="#00000080", linestyle=":")
    axis.set_yscale("symlog", linthresh=10)
    axis.set_ylim(bottom=0)

display(grid.figure)
pyplot.close(grid.figure)

In [None]:
# Replicate reproducibility. For each (library, genotype), compare mean per-library_id activity
# between every pair of replicates on log-log axes. Report R² of a least-squares fit in
# log10 space, then the average R² per library.
all_r2s = {}

comparisons = [("1", "2"), ("1", "3"), ("2", "3")]
bounds = (1e-5, 1e2)
# Loop-invariant log-spaced bin edges for the density estimate, spanning the plotted range.
bins = numpy.geomspace(start=bounds[0], stop=bounds[1], num=150)
# Pseudocount added to activity so that zero activities survive the log transform.
activity_pseudocount = 1e-4

# The directory name is kept as-is (figures are written there); make sure it exists.
os.makedirs("Reproduciblilty", exist_ok=True)

for (library, genotype), dataframe in normalized_samples_counts.sort(["library", "genotype"]).groupby(["library", "genotype"], maintain_order=True):
    display(Markdown(f"## {genotype} ({library})"))
    # One row per library_id and one column per replicate; each value is the mean over barcodes.
    data = dataframe.with_columns([
        polars.col("activity") + activity_pseudocount
    ]).pivot(values="activity", index="library_id", columns="replicate", aggregate_function="mean")

    figure, axes = pyplot.subplots(nrows=1, ncols=3, sharex=False, sharey=False, figsize=(5.25, 1.75))

    for (rep1, rep2), axis in zip(comparisons, axes.flat):
        if rep1 not in data.columns or rep2 not in data.columns:
            # This genotype lacks one of the replicates; leave the panel empty.
            axis.axis("off")
            continue

        x = data[rep1].to_numpy()
        y = data[rep2].to_numpy()
        # A library_id missing from one replicate pivots to null/NaN, which would make the fit
        # (and so R²) NaN. Compare only members measured in both replicates.
        measured = numpy.isfinite(x) & numpy.isfinite(y)
        x, y = x[measured], y[measured]
        if x.size < 2:
            axis.axis("off")
            continue

        # R² of a least-squares fit in log10 space
        fit = scipy.stats.linregress(numpy.log10(x), numpy.log10(y))
        r_squared = fit.rvalue ** 2
        all_r2s.setdefault(library, []).append(r_squared)

        # Scatter of all library members, colored by local point density (densest drawn last)
        weights, color_normalizer = calc_weights_for_scatter(x, y, bins=bins)
        axis.scatter(**reorder_points_and_weights(x, y, weights), cmap=cm_light_to_gray, norm=color_normalizer, s=6, marker="o", rasterized=True, zorder=1, label=f"R²={r_squared:.2f}")

        axis.set(xlim=bounds, ylim=bounds, xlabel=f"Rep{rep1}", ylabel=f"Rep{rep2}")
        axis.set_xscale("log", base=10)
        axis.set_yscale("log", base=10)
        axis.tick_params(which="major", bottom=True, top=False, left=True, right=False, labelbottom=False, labeltop=False, labelleft=False, labelright=False)
        axis.tick_params(which="minor", bottom=False, top=False, left=False, right=False)

        axis.set_aspect(aspect=1, adjustable="box", anchor="C")

        axis.legend(loc="upper left")

    for axis in axes.flat[len(comparisons):]:
        axis.axis("off")

    figure.tight_layout()
    figure.savefig(f"Reproduciblilty/reproducibility_{library}_{genotype}.pdf", bbox_inches="tight")
    display(figure)
    pyplot.close(figure)

for library, r2_values in all_r2s.items():
    display(Markdown(f"#### {library} avg. R^2"))
    print("{:.2f}".format(numpy.mean(r2_values)))

In [None]:
pipeline = normalized_samples_counts.lazy()

"""
Aggregate counts for each library member (CRE) across all barcodes and replicates; compute summary statistics
"""

# Every barcode x replicate measurement is one observation. activity_mu and activity_sigma are the
# natural-log parameters of the log-normal distribution whose mean and coefficient of variation
# match activity_mean and activity_cov (method of moments).
pipeline = pipeline.groupby(["library", "genotype", "library_id"]).agg([
    polars.col("activity").mean().alias("activity_mean"),
    polars.col("activity").std().alias("activity_std"),
    polars.count().alias("n_observations"),
]).with_columns([
    (polars.col("activity_std") / polars.col("activity_mean")).alias("activity_cov")
]).with_columns([
    (polars.col("activity_mean") / (polars.col("activity_cov").pow(2) + 1).sqrt()).log().alias("activity_mu"),
    (polars.col("activity_cov").pow(2) + 1).log().sqrt().alias("activity_sigma"),
])

"""
Within each genotype, calculate t-tests for each CRE vs basal
"""

def _calculate_t_tests(arguments):
    """Welch's t-test p-value for every row, computed from per-row summary statistics.

    `arguments` is the list of Series passed by `polars.map`. Their order is
    (mean1, std1, nobs1, mean2, std2, nobs2), as expected by `scipy.stats.ttest_ind_from_stats`.
    Returns a Series named "pvalue". Rows with missing or degenerate statistics get NaN.
    """
    # ttest_ind_from_stats broadcasts over arrays: test all rows in one call instead of a Python loop.
    columns = [numpy.asarray(series.to_numpy(), dtype=float) for series in arguments]
    return polars.Series("pvalue", scipy.stats.ttest_ind_from_stats(*columns, equal_var=False)[1])

def _benjamini_hochberg(pvalues):
    """Benjamini-Hochberg adjusted p-values (q-values) for a Series of p-values.

    Sort the m non-NaN p-values ascending as p_(1) <= ... <= p_(m). Then
    q_(i) = min(1, min over j >= i of p_(j) * m / j). NaN p-values get NaN q-values.
    Returns a Series named "qvalue", aligned with the input.
    """
    p = numpy.asarray(pvalues.to_numpy(), dtype=float)
    q = numpy.full(p.shape, numpy.nan)
    valid = numpy.flatnonzero(~numpy.isnan(p))
    m = valid.size
    if m:
        ascending = valid[numpy.argsort(p[valid], kind="stable")]
        scaled = p[ascending] * m / numpy.arange(1, m + 1)
        # A running minimum from the largest p-value down makes q monotone in p. The bare
        # p * m / rank formula omits this step and under-rejects. Tied p-values end up
        # sharing the q of their highest rank.
        q[ascending] = numpy.minimum(numpy.minimum.accumulate(scaled[::-1])[::-1], 1.0)
    return polars.Series("qvalue", q)

pipeline = pipeline.join(
    pipeline.filter(polars.col("library_id") == "basal").select([
        polars.col("library"),
        polars.col("genotype"),
        polars.col("activity_mean").prefix("basal_"),
        polars.col("activity_mu").prefix("basal_"),
        polars.col("activity_sigma").prefix("basal_"),
        polars.col("n_observations").prefix("basal_")
    ]),
    on=["library", "genotype"], how="left"
).with_columns([
    # Welch's t-test on the log-normal parameters, basal vs each CRE; untestable rows get p = 1
    polars.map(["basal_activity_mu", "basal_activity_sigma", "basal_n_observations", "activity_mu", "activity_sigma", "n_observations"], _calculate_t_tests).fill_nan(1).alias("pvalue")
]).with_columns([
    # Benjamini-Hochberg correction over all rows at once (all libraries and genotypes pooled)
    polars.map(["pvalue"], lambda columns: _benjamini_hochberg(columns[0])).alias("qvalue")
])

"""
Compute 5th and 95 %ile of scrambled library members within each genotype; classify CREs
"""

# Classes are checked in order and the first match wins. All classes require qvalue < 0.05.
# "strong" classes lie beyond the scrambled 5th/95th percentile; "weak" classes lie only beyond basal.
pipeline = pipeline.join(
    pipeline.filter(polars.col("library_id").str.contains("Scramble")).groupby(["library", "genotype"]).agg([
        polars.col("activity_mean").apply(lambda series: numpy.percentile(series, 5)).alias("scrambled_5th_percentile"),
        polars.col("activity_mean").apply(lambda series: numpy.percentile(series, 95)).alias("scrambled_95th_percentile")
    ]),
    on=["library", "genotype"], how="left"
).with_columns([
    polars.when((polars.col("activity_mean") > polars.col("scrambled_95th_percentile")) & (polars.col("qvalue") < 0.05)).then(polars.lit("strong_enhancer"))
        .when((polars.col("activity_mean") > polars.col("basal_activity_mean")) & (polars.col("qvalue") < 0.05)).then(polars.lit("weak_enhancer"))
        .when((polars.col("activity_mean") < polars.col("scrambled_5th_percentile")) & (polars.col("qvalue") < 0.05)).then(polars.lit("strong_silencer"))
        .when((polars.col("activity_mean") < polars.col("basal_activity_mean")) & (polars.col("qvalue") < 0.05)).then(polars.lit("weak_silencer"))
        .otherwise(polars.lit("inactive")).cast(polars.Categorical).alias("activity_class")
]).drop([
    "scrambled_5th_percentile",
    "scrambled_95th_percentile",
    "activity_cov",
    "basal_activity_mean",
    "basal_activity_mu",
    "basal_activity_sigma",
    "basal_n_observations"
])

pipeline = pipeline.sort(["library", "library_id", "genotype"])

processed_samples_counts = pipeline.collect()

display(Markdown("### Aggregated Normalized Counts"))
display(processed_samples_counts)

"""
Write out processed counts
"""

processed_samples_counts.write_parquet("processed_library_counts.parquet")

In [None]:
display(Markdown("## scrambled sequences (means across replicates and barcodes)"))

data = processed_samples_counts.filter(polars.col("library_id").str.contains("Scramble")).to_pandas()
# remove_unused_categories() returns a new Series, so assign it back to the column.
# Setting `.cat` on the temporary Series from data["genotype"] would leave `data` unchanged.
data["genotype"] = data["genotype"].cat.remove_unused_categories()

grid = seaborn.catplot(kind="strip", data=data, row="library", x="genotype", hue="genotype", y="activity_mean", height=3, aspect=2, palette="Set1", alpha=0.4, linewidth=0)
for axis in grid.axes.flat:
    axis.set_yscale("symlog", linthresh=1e-1)
    axis.set_ylim(bottom=0)
display(grid.figure)
pyplot.close(grid.figure)