In [None]:
import string
import sys
import warnings

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import statsmodels.api as sm
from IPython.display import display
from scipy import stats
from sklearn import metrics

import rp2.data
import rp2.paths
import rp2.regression
import rp2.txburst
from rp2 import hagai_2018

rp2.check_environment()

In [None]:
# Ask NumPy to issue a warning for every category of floating-point error.
np.seterr(all="warn")
# Turn any RuntimeWarning into an exception so the notebook stops instead of continuing silently.
warnings.filterwarnings("error", category=RuntimeWarning)

In [None]:
def report_package_versions(package_list):
    """Print "<package> version: <version>" for each named package.

    Args:
        package_list: iterable of importable module names.

    Each module is imported with ``importlib.import_module``. A module that
    does not define ``__version__`` is reported as "unknown" instead of
    aborting the whole report with an AttributeError.
    """
    import importlib  # local import: only this helper needs it

    for package in package_list:
        module = importlib.import_module(package)
        version = getattr(module, "__version__", "unknown")
        print(f"{package} version: {version}")


# Keep only the first whitespace-separated token of sys.version.
print(f"Python version: {sys.version.split()[0]}")

# Print the versions of the listed third-party packages.
report_package_versions(["matplotlib", "numpy", "pandas", "scanpy", "scipy", "seaborn", "statsmodels"])

In [None]:
# Global seaborn theme; set once here, before any figure is created.
sns.set(style="white", palette="bright", font="sans-serif", context="paper")

In [None]:
# Folder used by figure_path/save_table below for every file this notebook writes.
report_path = rp2.paths.get_output_path("report")
# NOTE(review): create_clean=True presumably empties an existing folder first — confirm in rp2.create_folder.
rp2.create_folder(report_path, create_clean=True)

def figure_path(name):
    """Return the path of the SVG file for figure `name` inside the report folder."""
    file_name = f"figure_{name}.svg"
    return report_path.joinpath(file_name)


def save_figure(name):
    """Save the current matplotlib figure to the report path for figure `name`."""
    destination = figure_path(name)
    plt.savefig(destination)


def save_table(df, name):
    """Write `df` as CSV to ``table_<name>.csv`` inside the report folder."""
    df.to_csv(report_path.joinpath(f"table_{name}.csv"))


def display_series(s, indent=0):
    """Print a string-indexed series as aligned "label  value" lines.

    Labels are left-justified in a column two characters wider than the
    longest label; `indent` spaces are printed before every line.
    """
    column_width = s.index.str.len().max() + 2
    margin = " " * indent
    for label, value in s.items():
        print(f"{margin}{label.ljust(column_width)}{value}")


def label_subplots(axes, titles=string.ascii_lowercase, font_size=16, join_infix=None):
    """Title each axes with a bracketed label: "(a)", "(b)", ...

    When `join_infix` is given, a non-empty existing title is kept and
    appended to the label, separated by `join_infix`. Without it, any
    existing title is replaced by the label.
    """
    for ax, letter in zip(axes, titles):
        label = f"({letter})"
        # The existing title is only consulted when it may be joined.
        existing = ax.title.get_text() if join_infix is not None else ""
        if existing != "":
            label = f"{label}{join_infix}{existing}"
        ax.set_title(label, fontsize=font_size)


def make_gene_selector(symbol_series, rows=3):
    """Build a Select widget listing genes sorted by symbol.

    Each option is a (symbol, index label) pair taken from `symbol_series`
    after sorting by its values.
    """
    ordered = symbol_series.sort_values()
    choices = list(zip(ordered.values, ordered.index))
    return widgets.Select(options=choices, rows=rows)


def ensure_that(condition):
    """Raise AssertionError unless `condition` is truthy."""
    if condition:
        return
    raise AssertionError

In [None]:
def test_pvalues(pval_series, alpha):
    """Correct a series of p-values with statsmodels' "fdr_bh" method.

    Returns a frame indexed like `pval_series` with the corrected p-values
    (``pvals_corrected``) and the rejection decision at `alpha` (``reject``).
    """
    reject, pvals_corrected = sm.stats.multipletests(pval_series, alpha, method="fdr_bh")[:2]
    return pd.DataFrame(index=pval_series.index, data={"pvals_corrected": pvals_corrected, "reject": reject})

In [None]:
# Columns that together identify one experimental condition.
condition_columns = ["replicate", "treatment", "time_point"]
# A data point is one gene under one condition.
index_columns = ["gene"] + condition_columns
# Time points kept in the analysis (given as strings).
time_points = ["0", "2", "4", "6"]

# Mouse gene annotation; used below through its `symbol` column, looked up by gene id.
gene_info_df = rp2.load_biomart_gene_symbols_df("mouse")

## Acquisition and preparation of RNA counts

In [None]:
# Load the mouse UMI counts and keep only the cells from the analysed time points.
mouse_umi_adata = hagai_2018.load_umi_counts("mouse")
mouse_umi_adata = mouse_umi_adata[mouse_umi_adata.obs.time_point.isin(time_points)].copy()

print("Full Hagai mouse dataset has:")
print(f"  {mouse_umi_adata.n_obs:,} cells")
print(f"  {mouse_umi_adata.n_vars:,} genes")

# Fail fast if the dataset no longer has the expected size.
ensure_that(mouse_umi_adata.n_obs == 53_086)
ensure_that(mouse_umi_adata.n_vars == 22_048)

# The UMI data is only used for the summary above.
del mouse_umi_adata

In [None]:
# Load the counts with scaling="median" and apply the same time-point filter as for the UMI data.
mouse_counts_adata = hagai_2018.load_counts("mouse", scaling="median")
mouse_counts_adata = mouse_counts_adata[mouse_counts_adata.obs.time_point.isin(time_points)].copy()

print("Scaled Hagai mouse dataset has:")
print(f"  {mouse_counts_adata.n_vars:,} genes")

# Same number of cells as the UMI data, fewer genes.
ensure_that(mouse_counts_adata.n_obs == 53_086)
ensure_that(mouse_counts_adata.n_vars == 16_798)

In [None]:
# Sum the count matrix along axis 1 to get one total per cell.
cell_counts = mouse_counts_adata.X.sum(axis=1)
counts_per_cell = cell_counts[0].item()
print(f"Counts per cell: {counts_per_cell:,}")

# Spot check on the first and last cell only: both must have the same expected total.
ensure_that(counts_per_cell == 9_161)
ensure_that(cell_counts[-1].item() == counts_per_cell)

In [None]:
# Gene ids for which the boolean `lps_responsive` annotation in .var is set.
lps_responsive_gene_ids = mouse_counts_adata.var.index[mouse_counts_adata.var.lps_responsive]
print(f"{len(lps_responsive_gene_ids):,} genes are LPS-responsive")

ensure_that(len(lps_responsive_gene_ids) == 2_336)

In [None]:
# Symbols of two extra genes (Tnf, Il1b), indexed by gene id.
additional_genes = gene_info_df.symbol[gene_info_df.symbol.isin(["Tnf", "Il1b"])]
# Analysis set: LPS-responsive genes plus the extra genes, de-duplicated and sorted.
analysis_gene_ids = sorted(set(lps_responsive_gene_ids).union(additional_genes.index))

print(f"{len(analysis_gene_ids):,} genes to be used in analysis")

# 2,336 + 2 = 2,338, so neither extra gene was already in the LPS-responsive set.
ensure_that(len(analysis_gene_ids) == 2_338)

In [None]:
# One row per distinct (replicate, treatment, time_point) combination present in the data.
condition_df = mouse_counts_adata.obs[condition_columns].drop_duplicates()

print(f"{len(condition_df)} conditions per gene")
print("Per replicate:")
display_series(condition_df.replicate.value_counts().sort_index(), indent=2)
print(f"{len(condition_df) * len(analysis_gene_ids):,} data points overall")

ensure_that(len(condition_df) == 20)
# Check the product that was actually printed: 20 conditions x 2,338 analysis genes.
# (The previous check compared two literals, so it could never fail, and it used 2,337 genes.)
ensure_that(len(condition_df) * len(analysis_gene_ids) == 46_760)

del condition_df

## Verification of linear mean-variance relationship of RNA response

In [None]:
# Significance level applied to the slope p-values in test_mv_fit.
mv_slope_alpha = 0.05
# Minimum unweighted R^2 for a mean-variance fit to be accepted.
mv_slope_r2_min = 0.6
# Significance level applied to the intercept p-values in test_mv_fit.
mv_intercept_alpha = 0.05

In [None]:
# Restrict the counts to the analysis genes, then compute per-condition statistics.
analysis_count_adata = mouse_counts_adata[:, analysis_gene_ids].copy()
gene_condition_stats_df = hagai_2018.calculate_counts_condition_stats(analysis_count_adata)

# 2,338 genes x 20 conditions = 46,760 rows.
ensure_that(len(gene_condition_stats_df) == 46_760)

In [None]:
def fit_mean_variance_trends(df):
    """Fit a robust linear mean-variance trend to the rows of `df`.

    ``variance`` is regressed on ``mean`` plus a constant term with
    statsmodels' RLM and the Huber T norm (t=1.345).

    Returns:
        A Series with ``intercept``, ``slope``, their p-values, and the
        R^2 of the fitted values computed without (``r2_unweighted``) and
        with (``r2_weighted``) the final RLM weights.
    """
    x = sm.add_constant(df["mean"])
    y = df["variance"]
    model = sm.RLM(y, x, M=sm.robust.norms.HuberT(t=1.345))
    ensure_that(model.M.t == 1.345)

    rlm_results = model.fit()

    # With pandas inputs, params and pvalues are label-indexed Series (constant first,
    # then "mean"). Use .iloc so the integer keys are explicitly positional rather
    # than relying on the deprecated positional fallback of Series.__getitem__.
    params = rlm_results.params
    pvalues = rlm_results.pvalues

    results = {
        "intercept": params.iloc[0],
        "slope": params.iloc[1],
        "intercept_pval": pvalues.iloc[0],
        "slope_pval": pvalues.iloc[1],
        "r2_unweighted": metrics.r2_score(y, rlm_results.fittedvalues),
        "r2_weighted": metrics.r2_score(y, rlm_results.fittedvalues, sample_weight=rlm_results.weights),
    }
    return pd.Series(results)


# Named groups of treatments to fit; only "all" is active, the per-treatment groups are disabled.
treatment_sets = {
    "all": ["unst", "lps", "pic"],
#    "lps": ["unst", "lps"],
#    "pic": ["unst", "pic"],
}

# For each treatment group: one mean-variance fit per gene, using that gene's rows for the listed treatments.
mv_fit_map = {set_name: gene_condition_stats_df[gene_condition_stats_df.treatment.isin(set_list)].groupby("gene").apply(fit_mean_variance_trends)
              for set_name, set_list in treatment_sets.items()}

In [None]:
def test_mv_fit(mv_fit_df):
    """Return a copy of `mv_fit_df` extended with acceptance-test columns.

    For the intercept and the slope, the p-values are corrected with
    test_pvalues at the matching alpha; the corrected p-value goes in
    ``<term>_pval_adj`` and the rejection decision in ``accept_<term>``.
    ``accept_r2`` is True where the unweighted R^2 exceeds mv_slope_r2_min.
    """
    tested_df = mv_fit_df.copy()

    for term, alpha in (("intercept", mv_intercept_alpha), ("slope", mv_slope_alpha)):
        outcome = test_pvalues(tested_df[f"{term}_pval"], alpha)
        tested_df[f"{term}_pval_adj"] = outcome.pvals_corrected
        tested_df[f"accept_{term}"] = outcome.reject

    tested_df["accept_r2"] = tested_df["r2_unweighted"] > mv_slope_r2_min
    return tested_df


all_treatment_mv_fit = test_mv_fit(mv_fit_map["all"])
# Number of genes passing each test: count of True values per accept_* column.
display_series(all_treatment_mv_fit[[c for c in all_treatment_mv_fit.columns if c.startswith("accept_")]].agg(np.count_nonzero))

# Pin the thresholds together with the counts asserted below.
ensure_that(mv_intercept_alpha == 0.05)
ensure_that(mv_slope_alpha == 0.05)
ensure_that(mv_slope_r2_min == 0.6)

ensure_that(np.count_nonzero(all_treatment_mv_fit.accept_intercept) == 614)
ensure_that(np.count_nonzero(all_treatment_mv_fit.accept_slope) == 2_133)

In [None]:
def plot_mv_r2_hist(ax=None):
    """Histogram of unweighted R^2 for genes with an accepted slope.

    A dashed vertical line marks the acceptance cut-off. A new figure is
    created when `ax` is not supplied. Returns the axes drawn on.
    """
    ax = ax or plt.subplots()[1]
    sns.distplot(all_treatment_mv_fit.loc[all_treatment_mv_fit.accept_slope].r2_unweighted, bins=50, kde=False, ax=ax)
    # Use the configured threshold so the line always matches the accept_r2 test.
    ax.axvline(x=mv_slope_r2_min, ls="--", label="Cut-off")
    ax.set_xlabel("$R^2$ of slope")
    ax.set_ylabel("Number of genes")
    ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
    ax.legend()
    return ax


plot_mv_r2_hist()
plt.show()

In [None]:
# Genes whose fit passes both the slope significance test and the R^2 cut-off.
# .copy() makes this an independent frame, so the in-place insert below cannot
# touch (or warn about) a slice of all_treatment_mv_fit.
all_treatment_good_mv_fit = all_treatment_mv_fit.loc[all_treatment_mv_fit.accept_slope & all_treatment_mv_fit.accept_r2].copy()
all_treatment_good_mv_fit.insert(0, "symbol", gene_info_df.loc[all_treatment_good_mv_fit.index].symbol)

# Subset of the good fits whose intercept test was rejected (accept_intercept is True).
accepted_intercept_df = all_treatment_good_mv_fit.loc[all_treatment_good_mv_fit.accept_intercept]

print(f"{len(all_treatment_good_mv_fit):,} mean-variance trends have a good fit (based on slope and unweighted R2)")
print(f"  i.e. {100 * (len(all_treatment_good_mv_fit) / len(analysis_gene_ids)):.1f}% of all analysis genes")
print(f"  {len(accepted_intercept_df):,} have a significant intercept")
print(f"    i.e. {100 * (len(accepted_intercept_df) / len(all_treatment_good_mv_fit)):.1f}% of good fits")
print(f"    {np.count_nonzero(accepted_intercept_df.intercept < 0)} are negative")
print(f"    {np.count_nonzero(accepted_intercept_df.intercept > 0)} are positive")

ensure_that(len(all_treatment_good_mv_fit) == 1_551)
ensure_that(len(accepted_intercept_df) == 430)
ensure_that(np.round(100 * (len(all_treatment_good_mv_fit) / len(analysis_gene_ids))) == 66)
# No accepted fit has a negative slope.
ensure_that(np.count_nonzero(all_treatment_good_mv_fit.slope < 0) == 0)
ensure_that(np.count_nonzero(accepted_intercept_df.intercept < 0) == 414)
ensure_that(np.count_nonzero(accepted_intercept_df.intercept > 0) == 16)

del accepted_intercept_df

In [None]:
def make_value_range_str(series):
    """Format the minimum and maximum of `series` as "<min> to <max>", two decimals each."""
    lowest, highest = series.min(), series.max()
    return f"{lowest:.2f} to {highest:.2f}"


# Slope range over all good fits; intercept range only over those with an accepted intercept.
print("Slope range:", make_value_range_str(all_treatment_good_mv_fit.slope))
print("Intercept range:", make_value_range_str(all_treatment_good_mv_fit.intercept[all_treatment_good_mv_fit.accept_intercept]))

In [None]:
def plot_mv_slope_hist(ax=None):
    """Histogram of log10 slope over the well-fitted genes.

    A new figure is created when `ax` is not supplied. Returns the axes.
    """
    if not ax:
        ax = plt.subplots()[1]
    log_slopes = np.log10(all_treatment_good_mv_fit.slope)
    sns.distplot(log_slopes, bins=100, kde=False, ax=ax)
    ax.set_xlabel("Slope (log$_{10}$)")
    ax.set_ylabel("Number of genes")
    # Gene counts are whole numbers, so use integer ticks.
    ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
    return ax


plot_mv_slope_hist()
plt.show()

In [None]:
def plot_mv_max_rna_hist(ax=None):
    """Histogram of log10 of each well-fitted gene's largest per-condition mean.

    A new figure is created when `ax` is not supplied. Returns the axes.
    """
    if not ax:
        ax = plt.subplots()[1]

    is_good_gene = gene_condition_stats_df.gene.isin(all_treatment_good_mv_fit.index)
    gene_rna_max = gene_condition_stats_df.loc[is_good_gene].groupby("gene")["mean"].agg("max")
    # One maximum per well-fitted gene.
    ensure_that(len(gene_rna_max) == 1_551)

    sns.distplot(np.log10(gene_rna_max), bins=50, kde=False, ax=ax)
    ax.set_xlabel("Maximum mean RNA (log$_{10}$)")
    ax.set_ylabel("Number of genes")
    ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
    return ax


plot_mv_max_rna_hist()
plt.show()

In [None]:
# Figure 2: the three histograms above side by side, labelled (a) to (c).
_, axes = plt.subplots(ncols=3, constrained_layout=True, figsize=(12, 3))
plot_mv_r2_hist(axes[0])
plot_mv_slope_hist(axes[1])
plot_mv_max_rna_hist(axes[2])
label_subplots(axes)
save_figure("2")
plt.show()

In [None]:
def get_closest_index(series, value):
    """Return the integer position of the element of `series` closest to `value`.

    The result is positional (suitable for ``.iloc``) whatever the index
    labels of `series` are. NaN entries are ignored.

    Raises:
        ValueError: if `series` is empty or contains only NaN.
    """
    abs_diff = (series - value).abs()
    # Work on the underlying array: np.argmin on a Series delegates to
    # Series.argmin, which returned a label in older pandas and a position in newer ones.
    return int(np.nanargmin(abs_diff.to_numpy()))


# Example fits closest to R^2 = 0.55, 0.65 and 0.8, taken from all fits (not only the accepted ones).
mv_r2_examples_df = all_treatment_mv_fit.iloc[[get_closest_index(all_treatment_mv_fit.r2_unweighted, r2) for r2 in [0.55, 0.65, 0.8]]]
display(mv_r2_examples_df)

In [None]:
def select_mv_examples(df, var_name):
    """Return the rows with the smallest, middle and largest `var_name`.

    The rows are picked at positions 0, len(df) // 2 and -1 after sorting
    `df` by `var_name`.
    """
    ordered = df.sort_values(by=var_name)
    positions = [0, len(ordered) // 2, -1]
    return ordered.iloc[positions]


# Well-fitted genes with the lowest, middle and highest slope.
mv_slope_examples_df = select_mv_examples(all_treatment_good_mv_fit, "slope")
display(mv_slope_examples_df)

In [None]:
# Same selection by intercept, restricted to genes with an accepted intercept.
mv_intercept_examples_df = select_mv_examples(all_treatment_good_mv_fit.loc[all_treatment_good_mv_fit.accept_intercept], "intercept")
display(mv_intercept_examples_df)

In [None]:
# Same selection by each gene's largest per-condition mean, joined on as `rna_max`.
mv_rna_max_examples = select_mv_examples(all_treatment_good_mv_fit.join(gene_condition_stats_df.groupby("gene")["mean"].agg("max").rename("rna_max")), "rna_max")
display(mv_rna_max_examples)

In [None]:
def format_intercept_str(intercept, fmt_func):
    """Format `intercept` with `fmt_func`, adding an explicit "+" when it is >= 0."""
    if intercept >= 0:
        sign = "+"
    else:
        sign = ""
    return sign + f"{fmt_func(intercept)}"


def format_mv_str(mv_series):
    """Return a mathtext string for a fitted mean-variance line.

    `mv_series` supplies ``slope``, ``intercept`` and ``accept_intercept``.
    Both numbers use two decimals when the slope is below 10 and none
    otherwise. An intercept that was not accepted is shown in parentheses.
    """
    fmt_func = "{:.2f}".format if mv_series.slope < 10 else "{:.0f}".format
    intercept_str = format_intercept_str(mv_series.intercept, fmt_func)
    if not mv_series.accept_intercept:
        intercept_str = f"({intercept_str})"
    slope_str = fmt_func(mv_series.slope)
    # Both backslashes are escaped; the former "\m" was an invalid escape sequence.
    return f"$\\sigma^2={slope_str}\\mu{intercept_str}$"


def plot_mv_fit(mv_fit, ax=None):
    """Scatter one gene's per-condition mean and variance with its fitted line.

    Args:
        mv_fit: one row of a mean-variance fit frame; its ``name`` is the gene id.
        ax: axes to draw on; a new figure is created when omitted.

    Returns:
        The axes drawn on.
    """
    ax = ax or plt.subplots()[1]

    gene_id = mv_fit.name
    gene_symbol = gene_info_df.symbol[gene_id]
    # Backslashes are escaped explicitly; "\i" was an invalid escape sequence.
    ax.set_title(f"$\\it{{{gene_symbol}}}$: $R^2={mv_fit.r2_unweighted:.2f}$\n{format_mv_str(mv_fit)}")

    stats_df = gene_condition_stats_df.loc[gene_condition_stats_df.gene == gene_id]
    m = stats_df["mean"]
    v = stats_df["variance"]
    # x/y are passed by keyword: newer seaborn releases reject positional data vectors.
    sns.scatterplot(x=m, y=v, hue=stats_df["treatment"], ax=ax)

    # Fitted line from mean 0 to the largest observed mean.
    lx = np.asarray((0, m.max()))
    ly = (lx * mv_fit.slope) + mv_fit.intercept
    ax.plot(lx, ly, ls=":")

    ax.set_xlim(left=0)
    ax.set_ylim(bottom=0)
    ax.set_xlabel("Mean RNA count ($\\mu$)")
    ax.set_ylabel("Variance ($\\sigma^2$)")
    return ax


# Figure 3: one row of three example fits per criterion (R^2, slope, intercept).
_, axes = plt.subplots(nrows=3, ncols=3, constrained_layout=True, figsize=(12, 9))
for axes_row, examples_df in zip(axes, (mv_r2_examples_df, mv_slope_examples_df, mv_intercept_examples_df)):
    for column in range(3):
        plot_mv_fit(examples_df.iloc[column], axes_row[column])
label_subplots(axes.flat, join_infix=" ")
save_figure("3")
plt.show()

In [None]:
# Mean-variance fit results for the two additional genes.
mv_additional_genes = additional_genes.to_frame().join(all_treatment_mv_fit)
display(mv_additional_genes)

# Both additional genes must be among the well-fitted genes.
ensure_that(np.count_nonzero(mv_additional_genes.index.isin(all_treatment_good_mv_fit.index)) == 2)

In [None]:
# Figure 4: mean-variance fits of the two additional genes.
_, axes = plt.subplots(ncols=2, constrained_layout=True, figsize=(8, 3))
plot_mv_fit(mv_additional_genes.iloc[0], axes[0])
plot_mv_fit(mv_additional_genes.iloc[1], axes[1])
save_figure("4")
plt.show()

## Fitting of bursting parameters

In [None]:
# txburst results restricted to the analysed time points and the well-fitted genes.
txburst_df = rp2.data.load_and_recalculate_txburst_results("mouse", condition_columns, "median")
txburst_df = txburst_df.loc[txburst_df.time_point.isin(time_points)]
txburst_df = txburst_df.loc[txburst_df.gene.isin(all_treatment_good_mv_fit.index)]
txburst_df = txburst_df.copy()

print("For the well-fitted genes:")
print(f"  txburst results are available for {len(txburst_df):,} data points")
print(f"  Across {txburst_df.gene.nunique():,} genes")
print(f"  {len(all_treatment_good_mv_fit) - txburst_df.gene.nunique()} genes have no conditions")

ensure_that(len(txburst_df) == 7_804)
ensure_that(len(txburst_df[condition_columns].drop_duplicates()) == 20)
ensure_that(txburst_df.gene.nunique() == 1_519)
# NOTE(review): the check below is disabled; with 1,551 well-fitted genes asserted earlier and 1,519 here, it should hold (32) — confirm and re-enable.
#ensure_that(len(all_treatment_good_mv_fit) - txburst_df.gene.nunique() == 32)

In [None]:
# Results loaded with load_txburst_results (not the recalculated loader), filtered the same way as txburst_df.
attempted_txburst_df = rp2.data.load_txburst_results("mouse", condition_columns, "median")
attempted_txburst_df = attempted_txburst_df.loc[attempted_txburst_df.time_point.isin(time_points)]
attempted_txburst_df = attempted_txburst_df.loc[attempted_txburst_df.gene.isin(all_treatment_good_mv_fit.index)]

# Failed = attempted data points that are absent from txburst_df.
failed_txburst_df = attempted_txburst_df.set_index(index_columns)
failed_txburst_df = failed_txburst_df.drop(index=txburst_df.set_index(index_columns).index)

print("For the well-fitted genes:")
print(f"  txburst inference was attempted for {len(attempted_txburst_df):,} data points")
print(f"  {len(failed_txburst_df):,} failed")

# 1,551 genes x 20 conditions = 31,020, and 31,020 - 7,804 = 23,216.
ensure_that(len(attempted_txburst_df) == 31_020)
ensure_that(len(gene_condition_stats_df.loc[gene_condition_stats_df.gene.isin(all_treatment_good_mv_fit.index)]) == 31_020)
ensure_that(len(failed_txburst_df) == 23_216)

del attempted_txburst_df

In [None]:
# Number of txburst results per gene.
txburst_gene_condition_counts = txburst_df.gene.value_counts()
# Keep genes with at least 10 results; the series stays indexed by gene id.
txburst_gene_id_subset = txburst_gene_condition_counts[txburst_gene_condition_counts >= 10]

print(f"{len(txburst_gene_id_subset):,} genes have >= 10 txburst results")

ensure_that(len(txburst_gene_id_subset) == 99)

In [None]:
# txburst results of the selected genes, with the per-condition mean joined on as `rna_mean`.
data_point_info_df = txburst_df.loc[txburst_df.gene.isin(txburst_gene_id_subset.index)].set_index(index_columns)
data_point_info_df = data_point_info_df.join(gene_condition_stats_df.set_index(index_columns)[["mean"]].add_prefix("rna_")).reset_index()
# Burstiness is defined here as k_off / k_on.
data_point_info_df["burstiness"] = data_point_info_df.k_off / data_point_info_df.k_on
data_point_info_df["log_burstiness"] = np.log10(data_point_info_df.burstiness)

# Mean-variance fits of the selected genes.
analysis_mv_fit_df = all_treatment_good_mv_fit.loc[txburst_gene_id_subset.index]

print(f"{len(data_point_info_df):,} data points are available")

ensure_that(data_point_info_df.gene.nunique() == 99)
ensure_that(len(analysis_mv_fit_df) == 99)
ensure_that(len(data_point_info_df) == 1_343)

In [None]:
def plot_mv_fit_slope_hist(ax=None):
    """Histogram of log10 slope over the genes in analysis_mv_fit_df.

    A new figure is created when `ax` is not supplied. Returns the axes.
    """
    if not ax:
        ax = plt.subplots()[1]
    log_slopes = np.log10(analysis_mv_fit_df.slope)
    sns.distplot(log_slopes, bins=20, kde=False, ax=ax)
    ax.set_xlabel("Slope (log$_{10}$)")
    ax.set_ylabel("Number of genes")
    ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
    return ax


plot_mv_fit_slope_hist()
plt.show()

### Plot inferred kinetic parameters

In [None]:
def plot_ksyn_hist(ax=None):
    """Histogram of log10 k_syn over all data points.

    A new figure is created when `ax` is not supplied. Returns the axes.
    """
    if not ax:
        ax = plt.subplots()[1]
    log_k_syn = np.log10(data_point_info_df.k_syn)
    sns.distplot(log_k_syn, bins=20, kde=False, ax=ax)
    ax.set_xlabel("$k_s$ (log$_{10}$)")
    ax.set_ylabel("Number of data points")
    return ax


plot_ksyn_hist()
plt.show()

In [None]:
def plot_koff_kon_scatter(ax=None):
    """Scatter log10 k_on against log10 k_off for all data points.

    A dashed diagonal marks k_on = k_off. A new figure is created when
    `ax` is not supplied. Returns the axes.
    """
    ax = ax or plt.subplots()[1]
    # x/y are passed by keyword: newer seaborn releases reject positional data vectors.
    sns.scatterplot(x=np.log10(data_point_info_df.k_off), y=np.log10(data_point_info_df.k_on), ax=ax)
    ax.plot((-2, 1.5), (-2, 1.5), ls="--", label="$k_{on}=k_{off}$")
    ax.set_xlabel("$k_{off}$ (log$_{10}$)")
    ax.set_ylabel("$k_{on}$ (log$_{10}$)")
    ax.legend(loc="upper left")
    return ax


plot_koff_kon_scatter()
plt.show()

#### QUESTION: are $k_{on}<k_{off}$ cases expected?

#### QUESTION: are the $k_{off}$ values being clipped at 1,000?

I believe the txburst optimisation is limited to a parameter search space of [0, 1000]. However, this doesn't appear to be affecting too many points:

In [None]:
# Sorted k_off values plotted against rank, to see how many points reach the top of the range.
data_point_info_df.k_off.sort_values().reset_index(drop=True).plot.line()
plt.show()

In [None]:
def plot_burstiness_hist(ax=None):
    """Histogram of log10 burstiness over all data points.

    A new figure is created when `ax` is not supplied. Returns the axes.
    """
    if not ax:
        ax = plt.subplots()[1]
    log_values = np.log10(data_point_info_df["burstiness"])
    sns.distplot(log_values, bins=20, kde=False, ax=ax)
    ax.set_xlabel("Burstiness (log$_{10}$)")
    ax.set_ylabel("Number of data points")
    return ax


plot_burstiness_hist()
plt.show()

In [None]:
# Figure 5: k_syn histogram, k_off/k_on scatter and burstiness histogram, labelled (a) to (c).
_, axes = plt.subplots(ncols=3, constrained_layout=True, figsize=(12, 4))
plot_ksyn_hist(axes[0])
plot_koff_kon_scatter(axes[1])
plot_burstiness_hist(axes[2])
label_subplots(axes)
save_figure("5")
plt.show()

#### Plot example fits

In [None]:
def get_condition_subset(df, replicate, treatment, time_point):
    """Return the rows of `df` that match the given replicate, treatment and time point."""
    matches = (
        (df.replicate == replicate)
        & (df.treatment == treatment)
        & (df.time_point == time_point)
    )
    return df.loc[matches]


def format_k_param(value):
    """Format a kinetic parameter: two decimals below 5, no decimals otherwise."""
    template = "{:.0f}"
    if value < 5:
        template = "{:.2f}"
    return template.format(value)


def plot_txburst_fit_from_series(tx_series, ax=None):
    """Plot one data point's count histogram with its fitted Poisson-beta curve.

    Args:
        tx_series: row providing ``gene``, ``replicate``, ``treatment``,
            ``time_point`` and the fitted ``k_on``, ``k_off``, ``k_syn``.
        ax: axes to draw on; a new figure is created when omitted.

    Returns:
        The axes drawn on.
    """
    ax = ax or plt.subplots()[1]

    gene_id, replicate, treatment, time_point = tx_series[["gene", "replicate", "treatment", "time_point"]]

    # Counts of this gene in the cells of the selected condition.
    obs_subset = get_condition_subset(mouse_counts_adata.obs, replicate, treatment, time_point)
    adata = mouse_counts_adata[obs_subset.index, gene_id]
    counts = adata.X.A.squeeze()

    bin_values, bin_edges, _ = ax.hist(counts, bins=50)
    # Area of the histogram, used to scale the PMF to the same units.
    hist_area = np.sum(np.diff(bin_edges) * bin_values)

    k_on, k_off, k_syn = tx_series[["k_on", "k_off", "k_syn"]]

    max_count = np.max(counts)
    pmf_in = np.arange(max_count)
    # RuntimeWarnings are errors in this notebook; silence warnings for the PMF evaluation only.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        pmf_out = rp2.txburst.poisson_beta_pmf(pmf_in, k_on, k_off, k_syn)
    # Shift the curve by the centre of the first bin.
    ax.plot(pmf_in + np.mean(bin_edges[:2]), hist_area * pmf_out, ls="--", lw=2)

    gene_symbol = gene_info_df.symbol[gene_id]
    # The backslash is escaped explicitly; "\i" was an invalid escape sequence.
    gene_info_str = f"$\\it{{{gene_symbol}}}$: Individual {replicate}, {treatment.upper()}, {time_point}h"
    param_info_str = f"$k_{{on}}={format_k_param(k_on)}$, $k_{{off}}={format_k_param(k_off)}$, $k_s={format_k_param(k_syn)}$"
    ax.set_title(f"{gene_info_str}\n{param_info_str}")

    ax.set_xlabel("RNA count")
    ax.set_ylabel("Number of cells")
    return ax


def plot_txburst_fit(gene_id, replicate, treatment, time_point, ax=None):
    """Plot the txburst fit of one gene under one condition.

    The matching row of data_point_info_df is looked up, squeezed and
    passed to plot_txburst_fit_from_series.
    """
    condition_rows = get_condition_subset(data_point_info_df, replicate, treatment, time_point)
    gene_rows = condition_rows.loc[condition_rows.gene == gene_id]
    return plot_txburst_fit_from_series(gene_rows.squeeze(), ax=ax)


# Figure 6: txburst fits of the additional genes for replicate "3", treatment "lps", time point "6".
_, axes = plt.subplots(ncols=2, constrained_layout=True, figsize=(10, 3))
for gene_id, ax in zip(additional_genes.index, axes):
    plot_txburst_fit(gene_id, "3", "lps", "6", ax=ax)
label_subplots(axes, join_infix=" ")
save_figure("6")
plt.show()

#### QUESTION: can we find a bad k-terms fit?

In [None]:
# Pick the last failed data point after sorting by mean RNA in ascending order.
# NOTE(review): sort_values places NaN last by default, so a missing rna_mean would be selected here — confirm none occur.
bad_txburst_fit = failed_txburst_df.join(gene_condition_stats_df.set_index(index_columns)["mean"].rename("rna_mean")).sort_values("rna_mean").reset_index().iloc[-1]

plot_txburst_fit_from_series(bad_txburst_fit)
plt.show()

#### Plot calculated bursting parameters

In [None]:
def plot_bs_bf_scatter(ax=None):
    """Scatter log10 burst frequency against log10 burst size.

    Points of the additional genes are coloured by gene symbol; all other
    points share the empty label. A new figure is created when `ax` is not
    supplied. Returns the axes.
    """
    ax = ax or plt.subplots()[1]
    # Symbol for the additional genes, "" for every other gene.
    gene_labels = data_point_info_df.join(additional_genes.rename("Gene"), on="gene").Gene.fillna("")
    # x/y are passed by keyword: newer seaborn releases reject positional data vectors.
    sns.scatterplot(
        x=np.log10(data_point_info_df.bs_point),
        y=np.log10(data_point_info_df.bf_point),
        hue=gene_labels,
        hue_order=gene_labels.sort_values().unique(),
        s=20,
        ax=ax,
    )
    ax.set_xlabel("Burst size (log$_{10}$)")
    ax.set_ylabel("Burst frequency (log$_{10}$)")
    return ax


def plot_bs_hist(ax=None):
    """Histogram of log10 burst size over all data points.

    `ax` now defaults to None, like the other plot helpers, so the existing
    "create a figure when no axes is given" fallback is reachable without
    passing None explicitly. Returns the axes.
    """
    ax = ax or plt.subplots()[1]
    sns.distplot(np.log10(data_point_info_df.bs_point), bins=50, kde=False, ax=ax)

    ax.set_xlabel("Burst size (log$_{10}$)")
    ax.set_ylabel("Number of data points")
    return ax


def plot_bf_hist(ax=None):
    """Histogram of log10 burst frequency over all data points.

    `ax` now defaults to None, like the other plot helpers, so the existing
    "create a figure when no axes is given" fallback is reachable without
    passing None explicitly. Returns the axes.
    """
    ax = ax or plt.subplots()[1]
    sns.distplot(np.log10(data_point_info_df.bf_point), bins=50, kde=False, ax=ax)
    ax.set_xlabel("Burst frequency (log$_{10}$)")
    ax.set_ylabel("Number of data points")
    return ax


def plot_bp_fig():
    """Draw figure 7: a large scatter panel with two histograms stacked to its right."""
    fig = plt.figure(constrained_layout=True, figsize=(12, 8))
    grid = fig.add_gridspec(2, 3)
    # Scatter spans both rows of the first two columns; histograms fill the third column.
    panel_axes = [fig.add_subplot(spec) for spec in (grid[0:2, 0:2], grid[0, 2], grid[1, 2])]
    for draw_panel, panel_ax in zip((plot_bs_bf_scatter, plot_bs_hist, plot_bf_hist), panel_axes):
        draw_panel(panel_ax)
    label_subplots(panel_axes)


plot_bp_fig()
save_figure("7")
plt.show()

#### QUESTION: are points in the lower right quadrant of interest?

In [None]:
# Data points with log10 burst size > 1 and log10 burst frequency < -1.5, shown with their gene symbols.
display(data_point_info_df.loc[(np.log10(data_point_info_df.bs_point) > 1) & (np.log10(data_point_info_df.bf_point) < -1.5)].set_index("gene").join(gene_info_df.symbol).reset_index())

In [None]:
def plot_conditions_per_gene(ax=None):
    """Bar chart of how many genes have each number of txburst results.

    A new figure is created when `ax` is not supplied. Returns the axes.
    """
    if not ax:
        ax = plt.subplots()[1]
    genes_per_count = txburst_gene_condition_counts.value_counts().sort_index()
    sns.barplot(x=genes_per_count.index, y=genes_per_count.values, ax=ax)
    ax.set_xlabel("Number of conditions")
    ax.set_ylabel("Number of genes")
    return ax


# Figure 8: conditions-per-gene bar chart (wider panel) next to the slope histogram of the analysed genes.
_, axes = plt.subplots(ncols=2, constrained_layout=True, figsize=(10, 3), gridspec_kw={'width_ratios': [4, 3]})
plot_conditions_per_gene(axes[0])
plot_mv_fit_slope_hist(axes[1])
label_subplots(axes)
save_figure("8")
plt.show()

## Bursting parameter trends

### Trends based on Spearman rank correlation

In [None]:
# Correlations with a p-value below this are treated as a trend (see determine_spearman_trends).
bp_trends_spearman_alpha = 0.05

In [None]:
def calculate_spearman_r(df, x_var, y_var):
    """Spearman rank correlation between two columns of `df`.

    Returns a Series with the coefficient (``r``) and its p-value (``r_pval``).
    """
    x_values, y_values = df[x_var], df[y_var]
    result = stats.spearmanr(x_values, y_values)
    return pd.Series(data={"r": result.correlation, "r_pval": result.pvalue})


# For each bursting parameter: per-gene Spearman correlation with the mean RNA.
bp_spearman_df_map = {c: data_point_info_df.groupby("gene").apply(calculate_spearman_r, "rna_mean", c)
                      for c in ["bs_point", "bf_point", "log_burstiness"]}

In [None]:
def concat_bp_df_map(df_map):
    """Stack the frames of `df_map`, adding the dict key as outer index level ``param``."""
    labels = list(df_map.keys())
    frames = [df_map[label] for label in labels]
    return pd.concat(frames, keys=labels, names=["param"])


def determine_spearman_trends(df, pval):
    """Classify each row of Spearman results as a trend.

    ``possible`` holds the direction given by the sign of ``r`` ("constant"
    when r is neither below nor above zero), ``accept_r`` is True where
    ``r_pval`` is below `pval`, and ``trend`` equals ``possible`` for
    accepted rows and "uncertain" otherwise.
    """
    trend_df = pd.DataFrame(index=df.index, data={"possible": "constant"})
    for direction, has_direction in (("decreasing", df.r < 0), ("increasing", df.r > 0)):
        trend_df.loc[has_direction, "possible"] = direction
    significant = df.r_pval < pval
    trend_df["accept_r"] = significant
    trend_df["trend"] = trend_df.possible.where(significant, "uncertain")
    return trend_df


def determine_and_display_spearman_trends(pval):
    """Show Spearman trend counts per parameter at threshold `pval`.

    A bar chart and the table of counts are displayed side by side.
    """
    trend_df = determine_spearman_trends(concat_bp_df_map(bp_spearman_df_map), pval)
    counts = trend_df.groupby("param").trend.value_counts().sort_index(level=0).rename("count")
    count_table = counts.to_frame().reset_index()

    # Render the chart into its own output widget so it can sit beside the table.
    plot_output = widgets.Output()
    with plot_output:
        sns.barplot(x="param", y="count", hue="trend", data=count_table, hue_order=["decreasing", "increasing", "uncertain"])
        plt.show()

    info_output = widgets.Output()
    with info_output:
        display(counts)

    side_by_side = widgets.HBox([plot_output, info_output])
    display(side_by_side)


widgets.interactive(
    determine_and_display_spearman_trends,
    pval=widgets.BoundedFloatText(value=bp_trends_spearman_alpha, min=0, max=1, step=0.01),
)

In [None]:
# The trends used downstream are computed from the configured alpha, independently of the widget above.
ensure_that(bp_trends_spearman_alpha == 0.05)

bp_spearman_trends_df = determine_spearman_trends(concat_bp_df_map(bp_spearman_df_map), bp_trends_spearman_alpha)
print("Trends based on Spearman rank correlation:")
display(bp_spearman_trends_df.groupby("param").trend.value_counts().sort_index(level=0))

### Trends based on curve fitting

In [None]:
# Curve fits with an R^2 above this are treated as a trend (see determine_bp_curve_trends).
bp_trends_curve_r2_min = 0.4

In [None]:
def fit_bp_curve(df, y_var):
    """Fit a curve of `y_var` against ``rna_mean`` for the rows of `df`.

    The fit is delegated to rp2.regression.calculate_curve_fit with a Huber
    loss. The fitted function's values at the smallest and largest
    ``rna_mean`` are added to the results as ``start`` and ``end``.

    Returns:
        A Series of the fit results, or None when parameter ``a`` is NaN.
    """
    x_var = "rna_mean"
    results = rp2.regression.calculate_curve_fit(df, x_var, y_var, loss_function="huber", f_scale=1.0)
    a, b, c = results["a"], results["b"], results["c"]
    # pd.isna tests the value. The former "a is np.nan" identity test only matched
    # the np.nan object itself and missed any other NaN float.
    if pd.isna(a):
        return None
    results["start"], results["end"] = rp2.regression.power_function((df[x_var].min(), df[x_var].max()), a, b, c)
    return pd.Series(data=results)


# For each bursting parameter: per-gene curve fit against the mean RNA.
bp_curve_df_map = {c: data_point_info_df.groupby("gene").apply(fit_bp_curve, c)
                  for c in ["bs_point", "bf_point", "log_burstiness"]}

In [None]:
def determine_bp_curve_trends(df, r2):
    """Classify each curve fit as a trend.

    ``accept_r2`` is True where the fit's ``r2`` exceeds `r2`. Accepted
    fits are labelled by comparing ``start`` with ``end``; all other rows
    stay "uncertain".
    """
    trend_df = pd.DataFrame(index=df.index, data={"trend": "uncertain"})
    good_fit = df.r2 > r2
    trend_df["accept_r2"] = good_fit
    comparisons = (
        ("increasing", df["start"] < df["end"]),
        ("decreasing", df["start"] > df["end"]),
        ("constant", df["start"] == df["end"]),
    )
    for label, matches in comparisons:
        trend_df.loc[good_fit & matches, "trend"] = label
    return trend_df


def determine_and_display_curve_trends(r2):
    """Show curve-fit trend counts per parameter at threshold `r2`.

    A bar chart and the table of counts are displayed side by side.
    """
    trend_df = determine_bp_curve_trends(concat_bp_df_map(bp_curve_df_map), r2)
    counts = trend_df.groupby("param").trend.value_counts().sort_index(level=0).rename("count")
    count_table = counts.to_frame().reset_index()

    # Render the chart into its own output widget so it can sit beside the table.
    plot_output = widgets.Output()
    with plot_output:
        sns.barplot(x="param", y="count", hue="trend", data=count_table, hue_order=["decreasing", "increasing", "uncertain"])
        plt.show()

    info_output = widgets.Output()
    with info_output:
        display(counts)

    side_by_side = widgets.HBox([plot_output, info_output])
    display(side_by_side)


widgets.interactive(
    determine_and_display_curve_trends,
    r2=widgets.FloatSlider(value=bp_trends_curve_r2_min, min=0, max=1, step=0.05),
)

In [None]:
# The trends used downstream are computed from the configured threshold, independently of the widget above.
ensure_that(bp_trends_curve_r2_min == 0.4)

bp_curve_trends_df = determine_bp_curve_trends(concat_bp_df_map(bp_curve_df_map), bp_trends_curve_r2_min)

print("Trends based on curve-fitting:")
display(bp_curve_trends_df.groupby("param").trend.value_counts().sort_index())

### Comparison of trends

In [None]:
def plot_trend_comparison(param, ax=None, cbar_ax=None):
    """Heatmap of curve-fit trend against Spearman trend for one parameter.

    Args:
        param: key of the outer ``param`` index level, e.g. "bs_point".
        ax: axes to draw on; a new figure is created when omitted.
        cbar_ax: axes for the colour bar; no colour bar is drawn when None.

    Returns:
        The axes drawn on.
    """
    ax = ax or plt.subplots()[1]

    ct = pd.crosstab(
        bp_curve_trends_df.loc[param].trend.rename("Curve fit"),
        bp_spearman_trends_df.loc[param].trend.rename("Spearman correlation"),
        margins=True,
    )
    ct.index = ct.index.str.title()
    ct.columns = ct.columns.str.title()

    # Bodge for missing category: prepend an all-zero "Decreasing" row when the
    # first row is not "Decreasing". pd.concat replaces DataFrame.append, which
    # was removed in pandas 2.0.
    if ct.index[0] != "Decreasing":
        empty_row = pd.DataFrame([[0] * len(ct.columns)], index=["Decreasing"], columns=ct.columns)
        ct = pd.concat([empty_row, ct])

    show_cbar = cbar_ax is not None
    sns.heatmap(
        ct,
        annot=True,
        cbar=show_cbar,
        cbar_ax=cbar_ax,
        ax=ax,
    )

    # Workaround for misaligned labels
    ax.set_yticklabels(ax.get_yticklabels(), verticalalignment="center")

    return ax


# Figure 9: one comparison heatmap per parameter; the narrow fourth axes holds the colour bar of the last heatmap.
_, axes = plt.subplots(ncols=4, constrained_layout=True, figsize=(13, 4), gridspec_kw={'width_ratios': [1, 1, 1, 0.1]})
plot_trend_comparison("bs_point", axes[0])
axes[0].set_title("Burst size")
plot_trend_comparison("bf_point", axes[1])
axes[1].set_title("Burst frequency")
plot_trend_comparison("log_burstiness", axes[2], cbar_ax=axes[3])
axes[2].set_title("Burstiness (log$_{10}$)")
# Label only the three heatmaps, keeping their titles on a second line.
label_subplots(axes[:-1], join_infix="\n")
save_figure("9")
plt.show()

In [None]:
# Combined trend: keep the curve-fit trend where the Spearman trend agrees, otherwise "uncertain".
# NOTE(review): comparing two Series with == requires identical labels; confirm both frames cover the same (param, gene) pairs.
bp_trends = bp_curve_trends_df.trend.where(bp_curve_trends_df.trend == bp_spearman_trends_df.trend, "uncertain")

In [None]:
def plot_trend_counts():
    """Bar chart of the number of genes per combined trend, grouped by parameter."""
    count_table = bp_trends.groupby("param").value_counts().rename("count").to_frame().reset_index()
    sns.barplot(x="param", y="count", hue="trend", data=count_table, hue_order=["decreasing", "increasing", "uncertain"])
    plt.show()


plot_trend_counts()

In [None]:
def make_trend_frequency_table():
    """Count how many genes share each combination of the three parameter trends.

    Returns a frame with one column per parameter plus ``genes`` (the number
    of genes with that combination), sorted by descending ``genes`` and then
    by the trend columns.
    """
    columns = ["bs_point", "bf_point", "log_burstiness"]
    # One tuple of trends per gene, then the frequency of each distinct tuple.
    # NOTE(review): assumes each gene's tuple is ordered like `columns` — confirm against the order of bp_trends.
    df = bp_trends.groupby("gene").agg(tuple).value_counts().rename("genes").reset_index()
    # Expand the tuples (first column after reset_index) into one column per parameter.
    df[columns] = pd.DataFrame(df.iloc[:, 0].to_list())
    return df.loc[:, columns + ["genes"]].sort_values(by=["genes"] + columns, ascending=[False] + ([True] * len(columns)))


trend_frequency_df = make_trend_frequency_table()
display(trend_frequency_df)

# Written to the report folder as table_1.csv.
save_table(trend_frequency_df, "1")

In [None]:
def determine_bs_prediction_trends(df):
    """Predict the burst-size trend from the sign of the mean-variance intercept.

    A positive intercept gives "decreasing", a negative one "increasing",
    and anything else stays "uncertain". ``confident`` copies
    ``accept_intercept`` from `df`.
    """
    trend_df = pd.DataFrame(index=df.index, data={"trend": "uncertain"})
    # Assign to the "trend" column explicitly. A bare row mask would overwrite
    # every column of the matching rows.
    trend_df.loc[df.intercept > 0, "trend"] = "decreasing"
    trend_df.loc[df.intercept < 0, "trend"] = "increasing"
    trend_df["confident"] = df.accept_intercept
    return trend_df


def plot_bs_trends_contingency(ax=None):
    """Heatmap of combined burst-size trend against the predicted trend.

    Predictions whose intercept was not accepted count as "uncertain".
    A new figure is created when `ax` is not supplied. Returns the axes.
    """
    ax = ax or plt.subplots()[1]
    pred_trends_df = determine_bs_prediction_trends(analysis_mv_fit_df)
    # Reset only the trend. Assigning to whole rows would also overwrite the
    # boolean "confident" column with the string "uncertain".
    pred_trends_df.loc[~pred_trends_df.confident, "trend"] = "uncertain"
    ct = pd.crosstab(
        bp_trends.loc["bs_point"].rename("Established"),
        pred_trends_df.trend.rename("Predicted burst size ($b_p$)"),
        margins=True,
    )
    ct.index = ct.index.str.title()
    ct.columns = ct.columns.str.title()

    sns.heatmap(ct, annot=True, ax=ax)
    # Workaround for misaligned labels
    ax.set_yticklabels(ax.get_yticklabels(), verticalalignment="center")

    return ax


# Figure 10: contingency heatmap of combined versus predicted burst-size trends.
_, ax = plt.subplots(figsize=(5, 4))
plot_bs_trends_contingency(ax)
save_figure("10")
plt.show()

In [None]:
def estimate_gene_burstiness(gene_df):
    """Estimate one gene's burstiness as a quantile of its data points.

    The quantile depends on the gene's Spearman trend for log burstiness:
    0.9 for "increasing", 0.1 for "decreasing" and 0.5 for "uncertain".
    `gene_df.name` must be the gene id, as supplied by groupby.apply.
    """
    quantile_by_trend = {
        "increasing": 0.9,
        "decreasing": 0.1,
        "uncertain": 0.5,
    }
    gene_trend = bp_spearman_trends_df.loc["log_burstiness"].loc[gene_df.name, "trend"]
    return gene_df.burstiness.quantile(quantile_by_trend[gene_trend])


gene_burstiness_estimates = data_point_info_df.groupby("gene").apply(estimate_gene_burstiness)


def plot_burstiness_estimate(gene_id, y_scale):
    """Scatter one gene's burstiness against mean RNA with its estimate.

    Args:
        gene_id: id of the gene to plot.
        y_scale: matplotlib y-axis scale name, e.g. "linear" or "log".
    """
    gene_data_point_info = data_point_info_df.loc[data_point_info_df.gene == gene_id]
    display(bp_spearman_df_map["log_burstiness"].join(bp_spearman_trends_df.loc[[("log_burstiness", gene_id)]]))
    estimate = gene_burstiness_estimates.loc[gene_id]
    # x/y are passed by keyword: newer seaborn releases reject positional data arguments.
    sns.scatterplot(x="rna_mean", y="burstiness", data=gene_data_point_info)
    # Dotted line marks the estimated burstiness.
    plt.axhline(y=estimate, ls=":")
    plt.yscale(y_scale)
    plt.show()


widgets.interactive(
    plot_burstiness_estimate,
    gene_id=make_gene_selector(gene_info_df.symbol[gene_burstiness_estimates.index]),
    y_scale=widgets.RadioButtons(options=["linear", "log"]),
)

## Relationship between burst size and mean-variance gradient of RNA response

For a mean-variance trend $\sigma^2=\alpha\mu+\sigma_0$, the predicted burst size (in the bursty limit) is:
$$b_p=\left(\alpha -1\right)+\frac{\sigma _0}{\mu }$$
This expression is constant when $\sigma_0=0$ and strictly monotonic for $\alpha>0$ otherwise: increasing when $\sigma_0<0$ and decreasing when $\sigma_0>0$. It may be generalised to a power-law curve with the exponent fixed at -1. In all cases the function approaches $\alpha-1$ as $\mu\to\infty$.

In [None]:
def predict_bf_and_bs(rna_mean, slope, intercept):
    """Predict burst frequency and burst size from a mean-variance trend.

    Burst size is ``(slope - 1) + intercept / rna_mean`` and burst
    frequency is ``rna_mean / burst size``.

    Returns:
        The tuple ``(burst frequency, burst size)``.
    """
    intercept_term = intercept / rna_mean
    bs = (slope - 1) + intercept_term
    return rna_mean / bs, bs


def calculate_bp_prediction_df(df, mv_df, force_zero_intercepts=False):
    """Predict burst frequency and size for every data point from per-gene fits.

    Parameters
    ----------
    df : pandas.DataFrame
        Data points with at least ``gene``, ``rna_mean`` and the ``index_columns``.
    mv_df : pandas.DataFrame
        Per-gene mean-variance fit with ``slope`` and ``intercept`` columns (and
        ``accept_intercept`` when ``force_zero_intercepts`` is set). Not modified.
    force_zero_intercepts : bool, default False
        When true, the intercept is replaced by 0 for every gene whose
        ``accept_intercept`` flag is false (on a copy of ``mv_df``).

    Returns
    -------
    pandas.DataFrame
        Columns ``bf`` and ``bs``, indexed by ``index_columns``.
    """
    fit_params = mv_df
    if force_zero_intercepts:
        # Work on a copy so the caller's fit table keeps its original intercepts.
        fit_params = mv_df.copy()
        fit_params.loc[~fit_params.accept_intercept, "intercept"] = 0

    # Look up each data point's gene before re-indexing, giving one (slope, intercept)
    # pair per row in the same order as ``df``.
    per_point_fit = fit_params.loc[df.gene, ["slope", "intercept"]].to_numpy()
    slope, intercept = per_point_fit.T

    indexed_points = df.set_index(index_columns)
    bf, bs = predict_bf_and_bs(indexed_points.rna_mean, slope, intercept)
    return pd.DataFrame(index=indexed_points.index, data={"bf": bf, "bs": bs})


# Predictions for every data point; intercepts are zeroed for genes whose fit has
# accept_intercept == False (force_zero_intercepts=True).
predicted_bp_df = calculate_bp_prediction_df(data_point_info_df, analysis_mv_fit_df, force_zero_intercepts=True)

In [None]:
# Put the curve-fit burst-size trend columns ("curve_" prefix) next to the columns
# returned by determine_bs_prediction_trends ("pred_" prefix), then keep only the
# genes whose two trend labels differ.
bs_trend_df = bp_curve_trends_df.loc["bs_point"].add_prefix("curve_").join(determine_bs_prediction_trends(analysis_mv_fit_df).add_prefix("pred_"))
mismatched_bs_trend_df = bs_trend_df.loc[bs_trend_df.curve_trend != bs_trend_df.pred_trend]

print(f"The trends of {len(mismatched_bs_trend_df):,} genes do not match (fit versus predicted)")

In [None]:
def plot_bs_prediction(gene_id):
    """Compare predicted and curve-fit burst size against RNA mean for one gene.

    Parameters
    ----------
    gene_id
        Gene identifier used to select the data points, the mean-variance fit row
        and the fitted ``bs_point`` curve parameters.

    Draws the gene's ``bs_point`` values as a scatter plot and overlays two curves
    evaluated across the observed range of ``rna_mean``: the prediction from the
    mean-variance fit (dotted) and the fitted power function (dashed).
    """
    data_point_info_subset = data_point_info_df.loc[data_point_info_df.gene == gene_id]
    lx = np.linspace(data_point_info_subset.rna_mean.min(), data_point_info_subset.rna_mean.max())

    # NOTE(review): the fitted intercept is used as-is here, whereas predicted_bp_df
    # is built with force_zero_intercepts=True - confirm this difference is intended.
    slope, intercept = analysis_mv_fit_df.loc[gene_id, ["slope", "intercept"]]
    _, py = predict_bf_and_bs(lx, slope, intercept)

    a, b, c = bp_curve_df_map["bs_point"].loc[gene_id, ["a", "b", "c"]]
    cy = rp2.regression.power_function(lx, a, b, c)

    # x and y must be keywords: seaborn >= 0.12 rejects them positionally (TypeError),
    # and 0.11 emits a FutureWarning for the positional form.
    sns.scatterplot(x=data_point_info_subset.rna_mean, y=data_point_info_subset.bs_point)
    plt.plot(lx, py, ls=":", label="Predicted")
    plt.plot(lx, cy, ls="--", label="Fit curve")
    plt.xlim(left=0)
    # Cap the y-axis at the largest observed burst size so the curves cannot
    # stretch the view beyond the data.
    plt.ylim(bottom=0, top=data_point_info_subset.bs_point.max())
    plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
    plt.show()


# Interactive browser restricted to the genes whose curve-fit and predicted
# burst-size trends disagree.
widgets.interactive(
    plot_bs_prediction,
    gene_id=make_gene_selector(gene_info_df.symbol[mismatched_bs_trend_df.index]),
)

In [None]:
def plot_bs_limits_scatter(df, ax):
    """Scatter predicted against curve-fit burst size limits on log10 axes.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide ``pred_bs``, ``curve_bs`` and ``burstiness`` columns.
    ax : matplotlib.axes.Axes
        Axes to draw on.

    Returns
    -------
    matplotlib.axes.Axes
        The same ``ax``, with points coloured by log10 burstiness, a dotted identity
        line spanning the range shared by both coordinates, and a colour bar.
    """
    log_pred = np.log10(df.pred_bs)
    log_curve = np.log10(df.curve_bs)
    ax.scatter(log_pred, log_curve, c=np.log10(df.burstiness))

    # Identity line from the larger of the two minima to the smaller of the two maxima.
    line_ends = (
        np.max((log_pred.min(), log_curve.min())),
        np.min((log_pred.max(), log_curve.max())),
    )
    ax.plot(line_ends, line_ends, ":")

    ax.set_xlabel("Predicted burst size limit (log$_{10}$)")
    ax.set_ylabel("Curve fit burst size limit (log$_{10}$)")
    # The colour bar is driven by the first collection on the axes.
    plt.colorbar(ax.collections[0], ax=ax)
    return ax


# One row per gene: the data point with the largest rna_mean (last after sorting),
# its predicted bf/bs, the "end" values of the fitted bf/bs curves, and the gene's
# burstiness estimate.
bp_trend_limits_df = (
    data_point_info_df.set_index(index_columns)
    .join(predicted_bp_df.add_prefix("pred_"))
    .reset_index()
    .sort_values(by=["gene", "rna_mean"])
    .loc[lambda frame: ~frame.gene.duplicated(keep="last")]
    .set_index("gene")[["rna_mean", "pred_bf", "pred_bs"]]
    .join(bp_curve_df_map["bf_point"]["end"].rename("curve_bf"))
    .join(bp_curve_df_map["bs_point"]["end"].rename("curve_bs"))
    .join(gene_burstiness_estimates.rename("burstiness"))
)

# Left panel: every gene. Right panel: the last 50 rows after sorting by burstiness
# in ascending order.
fig, axes = plt.subplots(ncols=2, constrained_layout=True, figsize=(9, 4))
plot_bs_limits_scatter(bp_trend_limits_df, axes[0])
plot_bs_limits_scatter(bp_trend_limits_df.sort_values(by="burstiness").iloc[-50:], axes[1])
label_subplots(axes)
plt.show()

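The squared relative change (of burst size) is $\left(\frac{b_k-b_p}{b_p}\right)^2$, where $b_k$ is the burst size inferred from the data and $b_p$ is the predicted size. The plot below shows $\log_{10}$ of one plus this quantity.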

In [None]:
def plot_burstiness_bs_diff_scatter(ax=None):
    """Scatter the squared relative burst-size error against log10 burstiness.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; a new figure and axes are created when omitted.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn on.

    For every data point the y value is ``log10(1 + ((bs_point - bs_pred) / bs_pred) ** 2)``,
    where ``bs_pred`` comes from ``predicted_bp_df``.
    """
    if ax is None:
        ax = plt.subplots()[1]
    df = data_point_info_df.set_index(index_columns).join(predicted_bp_df.add_suffix("_pred"))
    relative_change = (df.bs_point - df.bs_pred) / df.bs_pred
    # Adding 1 before the log maps a perfect prediction to exactly 0 (the dotted line).
    bs_diff = np.log10(relative_change ** 2 + 1)
    # x and y must be keywords: seaborn >= 0.12 rejects them positionally (TypeError),
    # and 0.11 emits a FutureWarning for the positional form.
    sns.scatterplot(x=np.log10(df.burstiness), y=bs_diff, ax=ax)
    ax.axhline(y=0, ls=":", label="")
    ax.set_xlabel("Burstiness (log$_{10}$)")
    ax.set_ylabel(r"Sq. relative change (log$_{10}$)")
    return ax


# Draw the scatter on a freshly created figure (no axes passed) and show it inline.
plot_burstiness_bs_diff_scatter()
plt.show()

In [None]:
def plot_mean_bp_scatter(gene_id):
    """Plot inferred and predicted burst parameters against RNA mean for one gene.

    Parameters
    ----------
    gene_id
        Gene identifier used to select the data points and the mean-variance fit row.

    Displays the gene's mean-variance fit row, then draws two panels (burst size,
    then burst frequency). Each panel shows the inferred ``*_point`` values, the
    per-point predictions from ``predicted_bp_df`` and the predicted curve across
    the observed range of ``rna_mean``; the title reports R^2 between the inferred
    and predicted values. The curve uses a zero intercept unless the gene's
    ``accept_intercept`` flag is set.
    """
    joined_df = data_point_info_df.set_index(index_columns).join(predicted_bp_df.add_suffix("_predicted")).reset_index()
    gene_df = joined_df.loc[joined_df.gene == gene_id]

    gene_mv_row = analysis_mv_fit_df.loc[gene_id]
    display(gene_mv_row.to_frame().T)

    rna_mean = gene_df.rna_mean
    cx = np.linspace(rna_mean.min(), rna_mean.max())
    if gene_mv_row.accept_intercept:
        intercept = gene_mv_row.intercept
    else:
        intercept = 0
    # predict_bf_and_bs returns (bf, bs); key the curves by parameter name.
    curve_by_param = dict(zip(("bf", "bs"), predict_bf_and_bs(cx, gene_mv_row.slope, intercept)))

    _, axes = plt.subplots(ncols=2, constrained_layout=True, figsize=(10, 4))
    for ax, p in zip(axes, ("bs", "bf")):
        inferred = gene_df[f"{p}_point"]
        predicted = gene_df[f"{p}_predicted"]
        r2 = metrics.r2_score(inferred, predicted)

        ax.scatter(rna_mean, inferred, marker="o", label="inferred")
        ax.scatter(rna_mean, predicted, marker="x", label="predicted")

        ax.plot(cx, curve_by_param[p], ls=":")

        ax.set_title(f"$R^2$ = {r2:.2f} (data versus prediction)")
        ax.set_xlabel("RNA mean")
        ax.set_ylabel(p)
        ax.set_xlim(left=0)
    # One legend, attached to the current axes, placed outside the plot area.
    plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
    plt.show()


# Interactive browser over every gene that appears in data_point_info_df.
widgets.interactive(
    plot_mean_bp_scatter,
    gene_id=make_gene_selector(gene_info_df.symbol[data_point_info_df.gene.unique()]),
)

## Modulation of burst size and frequency

TODO

## Relationship between burst size and frequency

In [None]:
def plot_bs_bf_reciprocal_scatter(ax=None, x_max=None):
    """Scatter mean-normalised burst frequency against burst size with a 1/x guide.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; a new figure and axes are created when omitted.
    x_max : float, optional
        When given, the x-axis is limited to ``[-10, x_max]``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn on.

    The y value is ``bf_point / rna_mean`` for every row of ``data_point_info_df``;
    the dashed curve is ``y = 1 / x`` sampled at 1000 points across the observed
    burst-size range.
    """
    if ax is None:
        ax = plt.subplots()[1]
    bs = data_point_info_df.bs_point
    bf = data_point_info_df.bf_point / data_point_info_df.rna_mean
    lx = np.linspace(bs.min(), bs.max(), 1000)
    ly = 1 / lx
    # x and y must be keywords: seaborn >= 0.12 rejects them positionally (TypeError),
    # and 0.11 emits a FutureWarning for the positional form.
    sns.scatterplot(x=bs, y=bf, ax=ax)
    ax.plot(lx, ly, ls="--")
    ax.set_xlabel("Burst size")
    # Raw string: "\m" is not a valid escape sequence in a normal string literal.
    ax.set_ylabel(r"Burst frequency / $\mu$")
    if x_max is not None:
        ax.set_xlim(left=-10, right=x_max)
    return ax


def plot_bs_bf_reciprocal_scatter_line(ax=None):
    """Scatter mean-normalised burst frequency against reciprocal burst size.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; a new figure and axes are created when omitted.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn on.

    The x value is ``1 / bs_point`` and the y value is ``bf_point / rna_mean`` for
    every row of ``data_point_info_df``; the dashed segment is the identity line
    from (0, 0) to (1, 1).
    """
    if ax is None:
        ax = plt.subplots()[1]
    bs = 1 / data_point_info_df.bs_point
    bf = data_point_info_df.bf_point / data_point_info_df.rna_mean
    # x and y must be keywords: seaborn >= 0.12 rejects them positionally (TypeError),
    # and 0.11 emits a FutureWarning for the positional form.
    sns.scatterplot(x=bs, y=bf, ax=ax)
    ax.plot((0, 1), (0, 1), ls="--")
    ax.set_xlabel("1 / burst size")
    # Raw string: "\m" is not a valid escape sequence in a normal string literal.
    ax.set_ylabel(r"Burst frequency / $\mu$")
    return ax


# Three panels: the full burst-size range, the same plot limited to x <= 150, and
# the reciprocal form in which the 1/x relationship becomes the identity line.
_, axes = plt.subplots(ncols=3, constrained_layout=True, figsize=(12, 4))
plot_bs_bf_reciprocal_scatter(axes[0])
plot_bs_bf_reciprocal_scatter(axes[1], x_max=150)
plot_bs_bf_reciprocal_scatter_line(axes[2])
label_subplots(axes)
plt.show()