In [None]:
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy
import seaborn as sns
import statsmodels.api as sm
import sys
import warnings
from IPython.display import display
from scipy import stats
from sklearn import metrics

import rp2.data
import rp2.paths
import rp2.regression
from rp2 import hagai_2018

# Run the project's environment check (implemented in rp2, not visible here) before any analysis.
rp2.check_environment()

In [None]:
# Report every NumPy floating-point problem as a warning, then promote all
# RuntimeWarnings to exceptions so numerical issues stop the notebook.
np.seterr(divide="warn", over="warn", under="warn", invalid="warn")
warnings.simplefilter("error", category=RuntimeWarning)

In [None]:
# Record interpreter and key package versions alongside the results.
for package_label, package_version in (
    ("Python", sys.version.split()[0]),
    ("Scanpy", scanpy.__version__),
    ("statsmodels", sm.__version__),
):
    print(f"{package_label} version: {package_version}")
del package_label, package_version

In [None]:
# Global seaborn styling for every figure in this report: white background,
# "bright" palette, sans-serif font, "paper" context.
sns.set(style="white", palette="bright", font="sans-serif", context="paper")

In [None]:
# All figures are written under the "report" output folder.
# NOTE(review): create_clean=True presumably empties any previous contents — confirm in rp2.create_folder.
figures_path = rp2.paths.get_output_path("report")
rp2.create_folder(figures_path, create_clean=True)

def figure_path(name):
    """Return the path of the SVG file for figure ``name`` (e.g. "2a") in the report folder."""
    file_name = f"figure_{name}.svg"
    return figures_path.joinpath(file_name)


def save_figure(name):
    """Save the current matplotlib figure to the path given by ``figure_path(name)``."""
    target_path = figure_path(name)
    plt.savefig(target_path)


def display_series(s, indent=0):
    """Print a Series as aligned ``label  value`` rows, one row per entry.

    Parameters
    ----------
    s : pandas.Series
        Values to print. Index labels of any type (str, int, categorical, ...)
        are rendered with ``str``.
    indent : int
        Number of leading spaces on each row.

    Prints nothing for an empty Series.
    """
    # Stringify labels so non-string indexes (e.g. integer replicate ids or a
    # RangeIndex) work; `Index.str` and `len(label)` fail on those.
    labels = [str(label) for label in s.index]
    if not labels:
        return
    # Pad each label to the longest one plus two spaces so the values line up.
    width = max(len(label) for label in labels) + 2
    prefix = " " * indent
    for label, (_, value) in zip(labels, s.items()):
        pad = width - len(label)
        print(f"{prefix}{label}{' ' * pad}{value}")


def label_subplots(axes, titles="abcd", font_size=16):
    """Title each axis with a panel label "(a)", "(b)", ... taken from ``titles``.

    Axes and titles are paired in order; the longer of the two is truncated.
    """
    panel_pairs = zip(axes, titles)
    for panel_ax, panel_letter in panel_pairs:
        panel_title = f"({panel_letter})"
        panel_ax.set_title(panel_title, fontsize=font_size)


def ensure_that(condition, message=None):
    """Raise ``AssertionError`` when ``condition`` is falsy.

    Unlike an ``assert`` statement, this check is not stripped when Python runs
    with ``-O``.

    Parameters
    ----------
    condition : object
        Value tested for truthiness.
    message : str, optional
        Included in the raised error to say which check failed. When omitted,
        the error is raised without arguments, as before.

    Raises
    ------
    AssertionError
        If ``condition`` is falsy.
    """
    if not condition:
        if message is None:
            raise AssertionError
        raise AssertionError(message)

In [None]:
# Columns identifying one experimental condition; a data point is keyed by gene + condition.
condition_columns = ["replicate", "treatment", "time_point"]
index_columns = ["gene", *condition_columns]
# Time points kept for analysis, as string labels matched against obs.time_point.
time_points = ["0", "2", "4", "6"]

gene_info_df = rp2.load_biomart_gene_symbols_df("mouse")

## Acquisition and preparation of RNA counts

In [None]:
# Raw UMI counts are only loaded to report the size of the full dataset.
mouse_umi_adata = hagai_2018.load_umi_counts("mouse")
keep_cells = mouse_umi_adata.obs["time_point"].isin(time_points)
mouse_umi_adata = mouse_umi_adata[keep_cells].copy()

print(
    "Full Hagai mouse dataset has:\n"
    f"  {mouse_umi_adata.n_obs:,} cells\n"
    f"  {mouse_umi_adata.n_vars:,} genes"
)

ensure_that(mouse_umi_adata.n_obs == 53_086)
ensure_that(mouse_umi_adata.n_vars == 22_048)

del mouse_umi_adata, keep_cells

In [None]:
# Median-scaled counts, restricted to the analysed time points.
mouse_counts_adata = hagai_2018.load_counts("mouse", scaling="median")
mouse_counts_adata = mouse_counts_adata[mouse_counts_adata.obs["time_point"].isin(time_points)].copy()

print("Scaled Hagai mouse dataset has:\n" f"  {mouse_counts_adata.n_vars:,} genes")

ensure_that(mouse_counts_adata.n_obs == 53_086)
ensure_that(mouse_counts_adata.n_vars == 16_798)

In [None]:
# Gene ids flagged by the boolean `lps_responsive` column of the dataset's .var table.
lps_responsive_gene_ids = mouse_counts_adata.var.index[mouse_counts_adata.var["lps_responsive"]]
print(f"{len(lps_responsive_gene_ids):,} genes are LPS-responsive")

ensure_that(len(lps_responsive_gene_ids) == 2_336)

In [None]:
# Tnf and Il1b are analysed in addition to the LPS-responsive genes.
additional_genes = gene_info_df.symbol.loc[lambda symbols: symbols.isin(["Tnf", "Il1b"])]
analysis_gene_ids = sorted(set(lps_responsive_gene_ids) | set(additional_genes.index))

print(f"{len(analysis_gene_ids):,} genes to be used in analysis")

ensure_that(len(analysis_gene_ids) == 2_338)

In [None]:
# Distinct (replicate, treatment, time_point) combinations present in the data.
condition_df = mouse_counts_adata.obs[condition_columns].drop_duplicates()

print(f"{len(condition_df)} conditions per gene")
print("Per replicate:")
display_series(condition_df.replicate.value_counts().sort_index(), indent=2)
print(f"{len(condition_df) * len(analysis_gene_ids):,} data points overall")

ensure_that(len(condition_df) == 20)
# Check the computed total, not a product of literals. With 2,338 analysis genes
# this is 20 * 2,338 = 46,760, the row count expected for the per-condition stats below.
ensure_that(len(condition_df) * len(analysis_gene_ids) == 46_760)

del condition_df

## Verification of linear mean-variance relationship of RNA response

In [None]:
# Restrict to the analysis genes, then compute per gene/condition statistics.
analysis_count_adata = mouse_counts_adata[:, analysis_gene_ids].copy()
gene_condition_stats_df = hagai_2018.calculate_counts_condition_stats(analysis_count_adata)

# 2,338 analysis genes x 20 conditions
ensure_that(gene_condition_stats_df.shape[0] == 46_760)

In [None]:
def fit_mean_variance_trends(df):
    """Robustly regress variance on mean for the rows of a single gene.

    Fits ``variance ~ intercept + slope * mean`` using statsmodels' RLM with a
    Huber T norm (t = 1.345).

    Parameters
    ----------
    df : pandas.DataFrame
        Rows for one gene with ``"mean"`` and ``"variance"`` columns
        (presumably one row per condition).

    Returns
    -------
    pandas.Series
        ``intercept``, ``slope``, their p-values, and the R^2 of the fitted values
        against the observed variance, both unweighted and weighted by the
        robust fit's final weights.
    """
    x = sm.add_constant(df["mean"])
    y = df["variance"]
    model = sm.RLM(y, x, M=sm.robust.norms.HuberT(t=1.345))
    ensure_that(model.M.t == 1.345)

    rlm_results = model.fit()
    # params/pvalues are Series labelled by design-matrix column ("const", "mean").
    # Integer keys on such a Series rely on pandas' deprecated positional fallback,
    # so index positionally through NumPy: constant first, then the mean coefficient.
    params = np.asarray(rlm_results.params)
    pvalues = np.asarray(rlm_results.pvalues)

    results = {
        "intercept": params[0],
        "slope": params[1],
        "intercept_pval": pvalues[0],
        "slope_pval": pvalues[1],
        "r2_unweighted": metrics.r2_score(y, rlm_results.fittedvalues),
        "r2_weighted": metrics.r2_score(y, rlm_results.fittedvalues, sample_weight=rlm_results.weights),
    }
    return pd.Series(results)


# Treatment groupings over which mean-variance trends are fitted.
treatment_sets = {
    "all": ["unst", "lps", "pic"],
    # Per-stimulus subsets, currently disabled:
    # "lps": ["unst", "lps"],
    # "pic": ["unst", "pic"],
}


def _fit_treatment_subset(treatments):
    """Fit per-gene mean-variance trends using only rows whose treatment is in ``treatments``."""
    subset_df = gene_condition_stats_df[gene_condition_stats_df.treatment.isin(treatments)]
    return subset_df.groupby("gene").apply(fit_mean_variance_trends)


mv_fit_map = {name: _fit_treatment_subset(treatments) for name, treatments in treatment_sets.items()}

In [None]:
# Acceptance flags: significant intercept/slope at 0.05, unweighted R^2 above 0.6.
all_treatment_mv_fit = mv_fit_map["all"].assign(
    accept_intercept=lambda fit: fit["intercept_pval"] < 0.05,
    accept_slope=lambda fit: fit["slope_pval"] < 0.05,
    accept_r2=lambda fit: fit["r2_unweighted"] > 0.6,
)
display_series(all_treatment_mv_fit.filter(regex="^accept_").agg(np.count_nonzero))

ensure_that(np.count_nonzero(all_treatment_mv_fit.accept_intercept) == 812)
ensure_that(np.count_nonzero(all_treatment_mv_fit.accept_slope) == 2_139)

In [None]:
def plot_mv_r2_hist(ax=None):
    """Histogram of unweighted R^2 for genes with a significant mean-variance slope.

    A dashed vertical line marks the 0.6 R^2 cut-off.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure and axes are created when omitted.

    Returns
    -------
    The axes drawn on.
    """
    if not ax:
        _, ax = plt.subplots()
    slope_accepted_df = all_treatment_mv_fit.loc[all_treatment_mv_fit.accept_slope]
    sns.distplot(slope_accepted_df.r2_unweighted, bins=50, kde=False, ax=ax)
    ax.axvline(x=0.6, ls="--", label="Cut-off")
    ax.set(xlabel="$R^2$ of slope", ylabel="Number of genes")
    ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
    ax.legend()
    return ax


# Figure 2a. save_figure saves the current pyplot figure (plt.savefig), so it is
# called right after plotting and before the figure is shown.
plot_mv_r2_hist()
save_figure("2a")
plt.show()

In [None]:
# Keep genes whose mean-variance trend has both a significant slope and R^2 > 0.6,
# and put the gene symbol in the first column.
all_treatment_good_mv_fit = all_treatment_mv_fit.loc[lambda fit: fit.accept_slope & fit.accept_r2]
all_treatment_good_mv_fit.insert(0, "symbol", gene_info_df.symbol.loc[all_treatment_good_mv_fit.index])

print(
    f"{len(all_treatment_good_mv_fit):,} mean-variance trends have a good fit (based on slope and unweighted R2)\n"
    f"  i.e. {100 * (len(all_treatment_good_mv_fit) / len(analysis_gene_ids)):.1f}%\n"
    f"  {np.count_nonzero(all_treatment_good_mv_fit.accept_intercept):,} have a significant intercept"
)

ensure_that(len(all_treatment_good_mv_fit) == 1_551)
ensure_that(np.count_nonzero(all_treatment_good_mv_fit.accept_intercept) == 564)
ensure_that(np.round(100 * (len(all_treatment_good_mv_fit) / len(analysis_gene_ids))) == 66)
ensure_that(np.count_nonzero(all_treatment_good_mv_fit.slope < 0) == 0)

In [None]:
def plot_mv_slope_hist(ax=None):
    """Histogram (100 bins) of log10 slopes of all well-fitted mean-variance trends.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure and axes are created when omitted.

    Returns
    -------
    The axes drawn on.
    """
    if not ax:
        _, ax = plt.subplots()
    log_slopes = np.log10(all_treatment_good_mv_fit.slope)
    sns.distplot(log_slopes, bins=100, kde=False, ax=ax)
    ax.set(xlabel="Slope (log$_{10}$)", ylabel="Number of genes")
    ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
    return ax


# Figure 2b. save_figure saves the current pyplot figure, so it follows the plot call.
plot_mv_slope_hist()
save_figure("2b")
plt.show()

#### QUESTION: should we address this outlier?

In [None]:
# Gene with the largest slope among the well-fitted mean-variance trends.
mv_slope_outlier = all_treatment_good_mv_fit.sort_values(by="slope").iloc[-1]
display(mv_slope_outlier.to_frame().T)

# x/y are passed by keyword: positional data arguments were deprecated in
# seaborn 0.11 and are rejected from 0.12 onwards.
sns.regplot(
    x="mean",
    y="variance",
    data=gene_condition_stats_df.loc[gene_condition_stats_df.gene == mv_slope_outlier.name],
)
plt.show()

In [None]:
# Mean-variance fit results for the explicitly added genes (left join keeps genes without a fit).
additional_genes_mv_fit_df = additional_genes.to_frame().join(all_treatment_mv_fit)
display(additional_genes_mv_fit_df)

## Fitting of bursting parameters

In [None]:
# txburst results restricted to the analysed time points and well-fitted genes.
txburst_df = (
    rp2.data.load_and_recalculate_txburst_results("mouse", condition_columns, "median")
    .loc[lambda results: results.time_point.isin(time_points)]
    .loc[lambda results: results.gene.isin(all_treatment_good_mv_fit.index)]
    .copy()
)

print(
    "For the well-fitted genes:\n"
    f"  txburst results are available for {len(txburst_df):,} data points\n"
    f"  (across {txburst_df.gene.nunique():,} genes)"
)

ensure_that(len(txburst_df) == 7_804)
ensure_that(len(txburst_df[condition_columns].drop_duplicates()) == 20)

In [None]:
# Number of txburst results (conditions) per gene.
txburst_gene_condition_counts = txburst_df.gene.value_counts()
# Despite the name, this is a Series of counts indexed by gene id; only genes with
# at least 10 results are kept (callers use its .index).
txburst_gene_id_subset = txburst_gene_condition_counts.loc[lambda counts: counts >= 10]

print(f"{len(txburst_gene_id_subset):,} genes have >= 10 txburst results")

ensure_that(len(txburst_gene_id_subset) == 99)

In [None]:
# One row per (gene, condition) data point for the selected genes, with the RNA mean
# of that condition joined in and burstiness defined as k_off / k_on.
data_point_info_df = (
    txburst_df.loc[txburst_df.gene.isin(txburst_gene_id_subset.index)]
    .set_index(index_columns)
    .join(gene_condition_stats_df.set_index(index_columns)[["mean"]].add_prefix("rna_"))
    .reset_index()
    .assign(burstiness=lambda points: points.k_off / points.k_on)
)

analysis_mv_fit_df = all_treatment_good_mv_fit.loc[txburst_gene_id_subset.index]

print(f"{len(data_point_info_df):,} data points are available")

ensure_that(data_point_info_df.gene.nunique() == 99)
ensure_that(len(analysis_mv_fit_df) == 99)
ensure_that(len(data_point_info_df) == 1_343)

In [None]:
def plot_mv_fit_slope_hist(ax=None):
    """Histogram (20 bins) of log10 mean-variance slopes for the genes used in the analysis.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure and axes are created when omitted.

    Returns
    -------
    The axes drawn on.
    """
    if not ax:
        _, ax = plt.subplots()
    log_slopes = np.log10(analysis_mv_fit_df.slope)
    sns.distplot(log_slopes, bins=20, kde=False, ax=ax)
    ax.set(xlabel="Slope (log$_{10}$)", ylabel="Number of genes")
    ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
    return ax


# Figure 2c. save_figure saves the current pyplot figure, so it follows the plot call.
plot_mv_fit_slope_hist()
save_figure("2c")
plt.show()

In [None]:
# Figure 2abc: the three mean-variance panels side by side, titled (a)-(c).
_, axes = plt.subplots(ncols=3, constrained_layout=True, figsize=(12, 3))
plot_mv_r2_hist(axes[0])
plot_mv_slope_hist(axes[1])
plot_mv_fit_slope_hist(axes[2])
label_subplots(axes)
save_figure("2abc")
plt.show()

### Plot inferred kinetic parameters

In [None]:
def plot_ksyn_hist(ax=None):
    """Histogram (20 bins) of log10 synthesis rate k_syn over all data points.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure and axes are created when omitted.

    Returns
    -------
    The axes drawn on.
    """
    if not ax:
        _, ax = plt.subplots()
    log_k_syn = np.log10(data_point_info_df.k_syn)
    sns.distplot(log_k_syn, bins=20, kde=False, ax=ax)
    ax.set(xlabel="$k_s$ (log$_{10}$)", ylabel="Number of data points")
    return ax


# Figure 3a. save_figure saves the current pyplot figure, so it follows the plot call.
plot_ksyn_hist()
save_figure("3a")
plt.show()

In [None]:
def plot_koff_kon_scatter(ax=None):
    """Scatter of log10 k_on against log10 k_off for every data point, with a dashed k_on = k_off line.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure and axes are created when omitted.

    Returns
    -------
    The axes drawn on.
    """
    ax = ax or plt.subplots()[1]
    log_k_off = np.log10(data_point_info_df.k_off)
    log_k_on = np.log10(data_point_info_df.k_on)
    # x/y by keyword: positional data arguments were deprecated in seaborn 0.11
    # and are rejected from 0.12 onwards.
    sns.scatterplot(x=log_k_off, y=log_k_on, ax=ax)
    # Span the identity line over the data range of both axes. A fixed segment
    # would not reach k_off values near the suspected 10^3 search bound.
    line_lo = min(log_k_off.min(), log_k_on.min())
    line_hi = max(log_k_off.max(), log_k_on.max())
    ax.plot((line_lo, line_hi), (line_lo, line_hi), ls="--")
    ax.set_xlabel("$k_{off}$ (log$_{10}$)")
    ax.set_ylabel("$k_{on}$ (log$_{10}$)")
    return ax


# Figure 3b. save_figure saves the current pyplot figure, so it follows the plot call.
plot_koff_kon_scatter()
save_figure("3b")
plt.show()

#### QUESTION: are the $k_{off}$ values being clipped at 1,000?

I believe the txburst optimisation restricts parameters to the search space [0, 1000]. However, as the sorted $k_{off}$ values below show, this does not appear to affect many points:

In [None]:
# Sorted k_off values. A plateau at the top would indicate clipping at the
# presumed 1,000 upper bound of the txburst search space (dashed line).
koff_ax = data_point_info_df.k_off.sort_values().reset_index(drop=True).plot.line()
koff_ax.axhline(1_000, ls="--", color="grey", label="Presumed upper bound (1,000)")
koff_ax.set_xlabel("Data point (sorted by $k_{off}$)")
koff_ax.set_ylabel("$k_{off}$")
koff_ax.legend()
plt.show()

In [None]:
def plot_burstiness_hist(ax=None):
    """Histogram (20 bins) of log10 burstiness (k_off / k_on) over all data points.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure and axes are created when omitted.

    Returns
    -------
    The axes drawn on.
    """
    if not ax:
        _, ax = plt.subplots()
    log_burstiness = np.log10(data_point_info_df["burstiness"])
    sns.distplot(log_burstiness, bins=20, kde=False, ax=ax)
    ax.set(xlabel="Burstiness (log$_{10}$)", ylabel="Number of data points")
    return ax


# Figure 3c. save_figure saves the current pyplot figure, so it follows the plot call.
plot_burstiness_hist()
save_figure("3c")
plt.show()

In [None]:
# Figure 3abc: k_s histogram, k_on vs k_off scatter and burstiness histogram side by side.
_, axes = plt.subplots(ncols=3, constrained_layout=True, figsize=(12, 4))
plot_ksyn_hist(axes[0])
plot_koff_kon_scatter(axes[1])
plot_burstiness_hist(axes[2])
label_subplots(axes)
save_figure("3abc")
plt.show()

#### Plot calculated burst parameters

In [None]:
def plot_bs_bf_scatter(ax=None):
    """Scatter of log10 burst frequency against log10 burst size for every data point.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure and axes are created when omitted.

    Returns
    -------
    The axes drawn on.
    """
    ax = ax or plt.subplots()[1]
    # x/y by keyword: positional data arguments were deprecated in seaborn 0.11
    # and are rejected from 0.12 onwards.
    sns.scatterplot(
        x=np.log10(data_point_info_df.bs_point),
        y=np.log10(data_point_info_df.bf_point),
        s=20,
        ax=ax,
    )
    ax.set_xlabel("Burst size (log$_{10}$)")
    ax.set_ylabel("Burst frequency (log$_{10}$)")
    return ax


def plot_bs_hist(ax=None):
    """Histogram (50 bins) of log10 burst size over all data points.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure and axes are created when omitted. The
        ``None`` default lets the ``ax or ...`` fallback below actually be used,
        matching the other ``plot_*`` helpers that can be called without arguments.

    Returns
    -------
    The axes drawn on.
    """
    ax = ax or plt.subplots()[1]
    sns.distplot(np.log10(data_point_info_df.bs_point), bins=50, kde=False, ax=ax)
    ax.set_xlabel("Burst size (log$_{10}$)")
    ax.set_ylabel("Number of data points")
    return ax


def plot_bf_hist(ax=None):
    """Histogram (50 bins) of log10 burst frequency over all data points.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure and axes are created when omitted. The
        ``None`` default lets the ``ax or ...`` fallback below actually be used,
        matching the other ``plot_*`` helpers that can be called without arguments.

    Returns
    -------
    The axes drawn on.
    """
    ax = ax or plt.subplots()[1]
    sns.distplot(np.log10(data_point_info_df.bf_point), bins=50, kde=False, ax=ax)
    ax.set_xlabel("Burst frequency (log$_{10}$)")
    ax.set_ylabel("Number of data points")
    return ax


# Figure 4: burst size vs burst frequency scatter occupying the left 2x2 grid cells,
# with the burst size (top) and burst frequency (bottom) histograms in the right column.
fig4 = plt.figure(constrained_layout=True, figsize=(12, 8))
gs = fig4.add_gridspec(2, 3)
fig4_ax_a = fig4.add_subplot(gs[0:2, 0:2])
fig4_ax_b = fig4.add_subplot(gs[0, 2])
fig4_ax_c = fig4.add_subplot(gs[1, 2])
plot_bs_bf_scatter(fig4_ax_a)
plot_bs_hist(fig4_ax_b)
plot_bf_hist(fig4_ax_c)
label_subplots([fig4_ax_a, fig4_ax_b, fig4_ax_c])

# save_figure saves the current pyplot figure (fig4, the most recently created one).
save_figure("4abc")

plt.show()

In [None]:
def calculate_spearman_r(df, x_var, y_var):
    """Spearman rank correlation between two columns of ``df``.

    Returns
    -------
    pandas.Series
        ``r`` (correlation coefficient) and ``r_pval`` (its p-value).
    """
    # Tuple unpacking works for both the older SpearmanrResult and newer SignificanceResult.
    rho, pval = stats.spearmanr(df[x_var], df[y_var])
    return pd.Series(data={"r": rho, "r_pval": pval})


def create_mean_trend_df(df, var, pval):
    """Classify, per gene, how ``var`` trends with RNA mean using Spearman correlation.

    Parameters
    ----------
    df : pandas.DataFrame
        Data points with ``gene``, ``rna_mean`` and ``var`` columns.
    var : str
        Column to correlate with ``rna_mean``.
    pval : float
        Significance threshold applied to the correlation's p-value.

    Returns
    -------
    pandas.DataFrame
        Indexed by gene, with ``r``, ``r_pval``, ``accept_r`` (p-value below
        ``pval``) and ``trend``: "increasing"/"decreasing" for significant
        positive/negative correlations, otherwise "uncertain".
    """
    trends_df = df.groupby("gene").apply(calculate_spearman_r, "rna_mean", var)
    # Significance is decided by the correlation's p-value. Comparing the
    # coefficient itself with the threshold is wrong.
    trends_df["accept_r"] = trends_df.r_pval < pval
    trends_df["trend"] = "uncertain"
    trends_df.loc[trends_df.accept_r & (trends_df.r > 0), "trend"] = "increasing"
    trends_df.loc[trends_df.accept_r & (trends_df.r < 0), "trend"] = "decreasing"
    return trends_df


# Per burst parameter, count genes by Spearman-based trend against RNA mean (alpha = 0.05).
for bp_column in ("bf_point", "bs_point", "burstiness"):
    print(f"Trends for {bp_column} based on Spearman rank correlation:")
    spearman_trend_counts = create_mean_trend_df(data_point_info_df, bp_column, 0.05)["trend"].value_counts()
    display_series(spearman_trend_counts.sort_index(), indent=2)

In [None]:
def fit_bp_curve(df, y_var):
    """Fit a power curve of ``y_var`` against RNA mean for one gene's data points.

    Uses ``rp2.regression.calculate_curve_fit`` with a Huber loss (``f_scale=1.0``).

    Parameters
    ----------
    df : pandas.DataFrame
        Data points for a single gene with ``rna_mean`` and ``y_var`` columns.
    y_var : str
        Burst-parameter column to fit.

    Returns
    -------
    pandas.Series or None
        The fit results (including coefficients ``a``, ``b``, ``c``) plus
        ``start``/``end``: the fitted curve evaluated at the smallest and largest
        observed RNA mean. ``None`` when coefficient ``a`` is missing/NaN
        (a failed fit).
    """
    x_var = "rna_mean"
    results = rp2.regression.calculate_curve_fit(df, x_var, y_var, loss_function="huber", f_scale=1.0)
    a, b, c = results["a"], results["b"], results["c"]
    # `a is np.nan` only matches the np.nan singleton. A NaN produced elsewhere
    # (e.g. np.float64("nan")) is a different object and would slip through.
    # pd.isna catches any NaN as well as None.
    if pd.isna(a):
        return None
    results["start"], results["end"] = rp2.regression.power_function((df[x_var].min(), df[x_var].max()), a, b, c)
    return pd.Series(data=results)


# Per-gene power-curve fits of burst frequency and burst size against RNA mean.
bp_curve_df_map = {
    bp_column: data_point_info_df.groupby("gene").apply(fit_bp_curve, bp_column)
    for bp_column in ("bf_point", "bs_point")
}

In [None]:
def determine_bp_curve_trends(df):
    """Label each fitted curve by comparing its value at the start and end of the RNA-mean range.

    Parameters
    ----------
    df : pandas.DataFrame
        Curve fits with ``start`` and ``end`` columns.

    Returns
    -------
    pandas.Series
        Indexed like ``df``: "increasing" if start < end, "decreasing" if
        start > end, otherwise (equal or NaN) "uncertain".
    """
    trend_labels = np.select(
        [df["start"] < df["end"], df["start"] > df["end"]],
        ["increasing", "decreasing"],
        default="uncertain",
    )
    return pd.Series(trend_labels, index=df.index)


# Per burst parameter, count genes by the direction of their fitted curve.
for bp_column, curve_fit_df in bp_curve_df_map.items():
    print(f"Trends for {bp_column} based on curve-fitting:")
    curve_trend_counts = determine_bp_curve_trends(curve_fit_df).value_counts()
    display_series(curve_trend_counts.sort_index(), indent=2)

## Relationship between burst size and mean-variance gradient of RNA response

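If a gene's variance follows the linear mean-variance fit $\sigma^2 = \sigma_0 + \alpha\mu$ (intercept $\sigma_0$, slope $\alpha$, RNA mean $\mu$), then its Fano factor is $\alpha + \frac{\sigma_0}{\mu}$. Equating this to $1 + b_p$ (the bursty limit of the two-state model) gives the predicted burst size $b_p$ and burst frequency $f_p$:

$b_p=\left(\alpha -1\right)+\frac{\sigma _0}{\mu }$

$f_p=\frac{\mu }{b_p}$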

TODO

In [None]:
def predict_bf_and_bs(rna_mean, slope, intercept):
    """Predict burst frequency and burst size from a linear mean-variance fit.

    With variance = intercept + slope * mean, burst size is
    ``(slope - 1) + intercept / rna_mean`` and burst frequency is
    ``rna_mean / burst_size``. Works element-wise on scalars, arrays or Series.

    Returns
    -------
    tuple
        ``(burst_frequency, burst_size)``.
    """
    burst_size = slope - 1 + intercept / rna_mean
    return rna_mean / burst_size, burst_size


def calculate_bp_prediction_df(df, mv_df, force_zero_intercepts=False):
    """Predict burst frequency and size for each data point from its gene's mean-variance fit.

    Parameters
    ----------
    df : pandas.DataFrame
        Data points with ``gene`` and ``rna_mean`` columns.
    mv_df : pandas.DataFrame
        Per-gene fits indexed by gene, with ``slope`` and ``intercept`` columns
        (plus ``accept_intercept`` when ``force_zero_intercepts`` is set).
        Not modified.
    force_zero_intercepts : bool
        When True, intercepts whose ``accept_intercept`` is False are replaced by 0.

    Returns
    -------
    pandas.DataFrame
        ``bf`` and ``bs`` columns, indexed like ``df``.
    """
    if force_zero_intercepts:
        mv_df = mv_df.assign(intercept=mv_df["intercept"].where(mv_df["accept_intercept"], 0))

    gene_params_df = mv_df.loc[df.gene, ["slope", "intercept"]]
    slope, intercept = gene_params_df.to_numpy().T
    bf, bs = predict_bf_and_bs(df.rna_mean, slope, intercept)
    return pd.DataFrame({"bf": bf, "bs": bs}, index=df.index)


# Per-data-point predicted burst frequency/size from each gene's mean-variance fit;
# intercepts with accept_intercept False are treated as 0.
predicted_bp_df = calculate_bp_prediction_df(data_point_info_df, analysis_mv_fit_df, force_zero_intercepts=True)

In [None]:
def determine_bs_prediction_trends(df, force_zero_intercepts=False):
    """Direction of predicted burst size with increasing RNA mean, from the fit intercept's sign.

    Predicted burst size is ``(slope - 1) + intercept / mean``, so a positive
    intercept makes it fall with mean and a negative one makes it rise.

    Parameters
    ----------
    df : pandas.DataFrame
        Per-gene fits with an ``intercept`` column (and ``accept_intercept`` when
        ``force_zero_intercepts`` is set). Not modified.
    force_zero_intercepts : bool
        When True, intercepts whose ``accept_intercept`` is False count as 0.

    Returns
    -------
    pandas.Series
        Indexed like ``df``: "decreasing", "increasing" or "uncertain" (zero intercept).
    """
    intercept = df.intercept
    if force_zero_intercepts:
        intercept = intercept.where(df.accept_intercept, 0)
    trend_labels = np.select(
        [intercept > 0, intercept < 0],
        ["decreasing", "increasing"],
        default="uncertain",
    )
    return pd.Series(trend_labels, index=df.index)


def plot_bs_prediction(gene_id):
    """Compare inferred, predicted and curve-fitted burst size against RNA mean for one gene.

    Draws the gene's txburst burst-size points, the burst size predicted from its
    mean-variance fit (using the raw intercept, not zeroed), and its fitted power
    curve from ``bp_curve_df_map["bs_point"]``, over the gene's observed RNA-mean range.

    Parameters
    ----------
    gene_id : str
        Gene identifier; must be in ``analysis_mv_fit_df`` and ``bp_curve_df_map["bs_point"]``.
    """
    data_point_info_subset = data_point_info_df.loc[data_point_info_df.gene == gene_id]
    lx = np.linspace(data_point_info_subset.rna_mean.min(), data_point_info_subset.rna_mean.max())

    slope, intercept = analysis_mv_fit_df.loc[gene_id, ["slope", "intercept"]]
    _, py = predict_bf_and_bs(lx, slope, intercept)

    a, b, c = bp_curve_df_map["bs_point"].loc[gene_id, ["a", "b", "c"]]
    cy = rp2.regression.power_function(lx, a, b, c)

    # x/y by keyword: positional data arguments were deprecated in seaborn 0.11
    # and are rejected from 0.12 onwards.
    sns.scatterplot(x=data_point_info_subset.rna_mean, y=data_point_info_subset.bs_point)
    plt.plot(lx, py, ls=":", label="Predicted")
    plt.plot(lx, cy, ls="--", label="Fit curve")
    plt.xlim(left=0)
    plt.ylim(bottom=0, top=data_point_info_subset.bs_point.max())
    plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
    plt.show()


# Compare the burst-size trend from curve fitting with the trend implied by the
# mean-variance intercept, and list the genes where they disagree.
trend_df = determine_bp_curve_trends(bp_curve_df_map["bs_point"]).to_frame("curve_trend").join(
    determine_bs_prediction_trends(analysis_mv_fit_df).to_frame("pred_trend")
)
mismatched_trend_df = trend_df.loc[trend_df.curve_trend != trend_df.pred_trend]
mismatched_trend_genes = gene_info_df.symbol[mismatched_trend_df.index].sort_values()

n_mismatched_genes = len(mismatched_trend_genes)
gene_noun = "gene" if n_mismatched_genes == 1 else "genes"
print(f"The trends of {n_mismatched_genes:,} {gene_noun} do not match (fit versus predicted)")

# Interactive per-gene comparison, shown only when there is something to select.
# An empty Select leaves gene_id as None, which plot_bs_prediction cannot plot.
if n_mismatched_genes:
    display(
        widgets.interactive(
            plot_bs_prediction,
            gene_id=widgets.Select(options=list(zip(mismatched_trend_genes.values, mismatched_trend_genes.index)), rows=3),
        )
    )

In [None]:
def plot_mean_bp_for_analyses_genes():
    """Plot inferred vs predicted burst size and frequency against RNA mean, one gene at a time.

    For every gene in ``data_point_info_df`` this displays the gene's mean-variance
    fit row, then a two-panel figure (burst size "bs", burst frequency "bf"). Each
    panel shows the txburst-inferred values, the per-data-point predictions from
    ``predicted_bp_df``, a dotted prediction curve over the observed RNA-mean range,
    and the R^2 of the predictions against the inferred values in its title.
    """
    # predicted_bp_df is indexed like data_point_info_df, so the join aligns row by row.
    for gene_id, gene_df in data_point_info_df.join(predicted_bp_df.add_suffix("_predicted")).groupby("gene"):
        gene_mv_row = analysis_mv_fit_df.loc[gene_id]
        display(gene_mv_row.to_frame().T)

        rna_mean = gene_df.rna_mean
        cx = np.linspace(rna_mean.min(), rna_mean.max())
        # Intercepts that are not significant are treated as zero for the curve.
        intercept = gene_mv_row.intercept if gene_mv_row.accept_intercept else 0
        bf, bs = predict_bf_and_bs(cx, gene_mv_row.slope, intercept)
        cy_map = {"bf": bf, "bs": bs}

        _, axes = plt.subplots(ncols=2, constrained_layout=True, figsize=(10, 4))
        for p, ax in zip(["bs", "bf"], axes):
            y_tx = gene_df[f"{p}_point"]
            y_pred = gene_df[f"{p}_predicted"]
            # R^2 of the predicted values scored against the txburst-inferred values.
            r2 = metrics.r2_score(y_tx, y_pred)

            ax.scatter(rna_mean, y_tx, marker="o", label="inferred")
            ax.scatter(rna_mean, y_pred, marker="x", label="predicted")

            ax.plot(cx, cy_map[p], ls=":")

            ax.set_title(f"$R^2$ = {r2:.2f}")
            ax.set_xlabel("RNA mean")
            ax.set_ylabel(p)
            ax.set_xlim(left=0)
        # plt.legend targets the current axes (presumably the last, "bf" panel) and
        # places the legend outside it, to the right.
        plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
        plt.show()


# One displayed fit row and one two-panel figure per gene in data_point_info_df.
plot_mean_bp_for_analyses_genes()

## Modulation of burst size and frequency

TODO

## Relationship between burst size and frequency

TODO