In [None]:
import ipywidgets as widgets
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import statsmodels.api as sm
from IPython.display import display, Latex
from scipy import stats
from sklearn import metrics

import rp2.data
from rp2 import hagai_2018, create_gene_symbol_map
from rp2.paths import get_output_path
from rp2.regression import power_function, calculate_curve_fit

# Run the rp2 package's environment check before any analysis (what is checked is defined in rp2 itself).
rp2.check_environment()

### Settings controlling downstream analysis

In [None]:
# --- Which part of the data set to analyse ---
analysis_species = "mouse"
analysis_counts = "median"  # passed to the loaders as count_type / scaling
analysis_treatments = ["unst", "lps", "pic"]
analysis_time_points = ["0", "2", "4", "6"]

# --- Robust mean-variance regression and gene filtering ---
default_huber_epsilon = 1.345  # same value as statsmodels' HuberT default threshold
mv_rlm_factor = default_huber_epsilon  # Huber threshold for the per-gene mean-variance fits
min_mv_r2 = 0.60  # genes whose (unweighted) mean-variance R^2 is below this are dropped
min_conditions = 10  # minimum number of conditions with valid burst parameters per gene

# --- Burst-parameter curve fits (passed to calculate_curve_fit) ---
bp_curve_loss = "huber"
bp_curve_f_scale = 1.00

# Columns identifying a condition; together with the gene they identify one row of results.
condition_columns = ["replicate", "treatment", "time_point"]
index_columns = ["gene", *condition_columns]

In [None]:
# Mapping between gene ids and gene symbols for the analysed species; used below to label gene selectors,
# to export gene lists and to annotate tables with symbols.
gene_symbol_map = create_gene_symbol_map(analysis_species)

### Determine which genes have a sufficient number of conditions with valid burst parameters

In [None]:
# Load txburst results for every condition and derive burstiness (k_off / k_on) and its natural log.
condition_info_df = rp2.data.load_and_recalculate_txburst_results(analysis_species, condition_columns, count_type=analysis_counts)
condition_info_df["k_burstiness"] = condition_info_df.k_off / condition_info_df.k_on
condition_info_df["log_burstiness"] = np.log(condition_info_df.k_burstiness)

# Keep only the configured treatments and time points. The explicit copy makes the filtered frame
# independent, so adding the "valid_bp" column below does not trigger pandas' SettingWithCopyWarning.
condition_info_df = condition_info_df.loc[
    condition_info_df.treatment.isin(analysis_treatments)
    & condition_info_df.time_point.isin(analysis_time_points)
].copy()
print(f"{len(condition_info_df):,} conditions for {condition_info_df.gene.nunique():,} genes have been processed by txburst")

# A condition has valid burst parameters when both point estimates (size and frequency) are present.
condition_info_df["valid_bp"] = condition_info_df.bs_point.notna() & condition_info_df.bf_point.notna()

print(f"{np.count_nonzero(condition_info_df.valid_bp):,} conditions have valid burst parameters")

# Number of valid conditions per gene (summing booleans counts the True values).
valid_counts = condition_info_df.groupby("gene").valid_bp.sum()
valid_gene_ids = valid_counts.index[valid_counts >= min_conditions]
print(f"{len(valid_gene_ids):,} genes have {min_conditions} or more conditions with valid burst parameters")

# All conditions (valid or not) of the retained genes are kept.
condition_info_df = condition_info_df.loc[condition_info_df.gene.isin(valid_gene_ids)]

### Calculate statistics of RNA counts

In [None]:
def calculate_count_stats(condition_subset):
    """Compute per-condition RNA count statistics for the genes and conditions in ``condition_subset``.

    Counts for ``analysis_species`` (scaled with ``analysis_counts``) are loaded via
    ``hagai_2018.load_counts``. They are restricted to the genes listed in ``condition_subset.gene`` and to
    cells whose value in each of ``condition_columns`` occurs in ``condition_subset``. The statistics are
    then computed per condition group by ``hagai_2018.calculate_counts_condition_stats``.

    Note: each condition column is filtered independently, so cells are kept for any combination of the
    observed values, not only the exact combinations present in ``condition_subset``.
    """
    adata = hagai_2018.load_counts(analysis_species, scaling=analysis_counts)
    print(f"Counts available for {adata.n_obs:,} cells and {adata.n_vars:,} genes")

    gene_mask = adata.var_names.isin(condition_subset.gene)
    adata = adata[:, gene_mask]
    for column_name in condition_columns:
        wanted_values = condition_subset[column_name]
        adata = adata[adata.obs[column_name].isin(wanted_values)]

    # Take an explicit copy of the subset (AnnData indexing presumably returns views) before computing stats.
    adata = adata.copy()
    print(f"Calculating count statistics for {adata.n_obs:,} cells and {adata.n_vars:,} genes")

    return hagai_2018.calculate_counts_condition_stats(adata, group_columns=condition_columns)


# Attach the per-condition count statistics (presumably including the "mean" and "variance" columns used
# below) to each condition row, matching on gene + condition columns; conditions without stats keep NaN.
condition_info_df = (
    condition_info_df
    .set_index(index_columns)
    .join(calculate_count_stats(condition_info_df).set_index(index_columns), how="left")
    .reset_index()
)

### Display an interactive mean-variance plot for genes with sufficient conditions with valid burst parameters

Although the list of genes is restricted to those with a minimum number of conditions with valid burst parameters, all conditions are plotted and used to fit the regression line. The solid line shows the fit to all points and dotted lines show fits for individual treatments (including unstimulated points in each case).

Plotted points are sized according to the weight they are assigned by the robust linear model. The sensitivity of the model to outliers can be adjusted with the rlm_factor slider. Changes to this value are for illustration only and will not change downstream analysis (for this, change the value of mv_rlm_factor above and re-run all cells).

In [None]:
def apply_huber_regressor(df, x_var, y_var, epsilon, include_weights=False):
    """Fit ``y_var = slope * x_var + intercept`` with a Huber robust linear model (statsmodels RLM).

    Parameters
    ----------
    df : pandas.DataFrame containing the ``x_var`` and ``y_var`` columns.
    epsilon : threshold ``t`` of the Huber norm; smaller values down-weight outliers more strongly.
    include_weights : if true, also return the weighted R^2 and the per-point robust weights.

    Returns
    -------
    dict
        Always "slope" and "intercept". Then either "r2" (unweighted R^2 of the robust fit), or, when
        ``include_weights`` is true, "r2_unweighted", "r2_weighted" (R^2 weighted by the robust
        weights) and "weights".
    """
    values = df.loc[:, [x_var, y_var]].to_numpy()
    design = sm.add_constant(values[:, 0])
    response = values[:, 1]

    fit = sm.RLM(response, design, sm.robust.norms.HuberT(t=epsilon)).fit()
    intercept, slope = fit.params[0], fit.params[1]
    predicted = fit.predict(design)
    r2_unweighted = metrics.r2_score(response, predicted)

    if not include_weights:
        return {"slope": slope, "intercept": intercept, "r2": r2_unweighted}

    return {
        "slope": slope,
        "intercept": intercept,
        "r2_unweighted": r2_unweighted,
        "r2_weighted": metrics.r2_score(response, predicted, sample_weight=fit.weights),
        "weights": fit.weights,
    }


def apply_mv_regressor(df, x_var, y_var, epsilon=mv_rlm_factor):
    """Robust mean-variance fit returning unweighted/weighted R^2 and the per-point robust weights.

    ``epsilon`` defaults to ``mv_rlm_factor`` as it was when this function was defined.
    """
    return apply_huber_regressor(df, x_var, y_var, epsilon=epsilon, include_weights=True)


def apply_standard_regressor(df, x_var, y_var):
    """Robust linear fit with the default Huber threshold; returns "slope", "intercept" and unweighted "r2"."""
    huber_threshold = default_huber_epsilon
    return apply_huber_regressor(df, x_var, y_var, huber_threshold, include_weights=False)


def make_gene_selector(gene_ids):
    """Return a three-row Select widget listing ``gene_ids`` by gene symbol, sorted by symbol.

    Each option is a (label, value) pair: the symbol is displayed and the gene id is the selected value.
    """
    symbols_by_id = gene_symbol_map.lookup(gene_ids).sort_values()
    options = [(symbol, gene_id) for gene_id, symbol in symbols_by_id.items()]
    return widgets.Select(options=options, rows=3)


def format_plus_c(c):
    """Format the constant term of an equation with an explicit sign and two decimals.

    E.g. 1.234 -> "+1.23", -1.5 -> "-1.50", 0 -> "+0.00" (a zero intercept is no longer shown as "-0.00").
    """
    return f"{c:+.2f}"


def plot_mean_var(gene_id, scale, plot_treatment_lines, rlm_factor):
    """Interactive mean-variance plot for a single gene with robust (Huber) regression lines.

    Parameters
    ----------
    gene_id : gene identifier as used in ``condition_info_df.gene``.
    scale : "linear" or "log". For "log", both axes are log-scaled and plotted values are shifted by +1
        (presumably so that zero values remain plottable).
    plot_treatment_lines : if true, also draw a dotted fit for each stimulated treatment. Each of these
        fits uses that treatment's conditions together with the "unst" conditions.
    rlm_factor : Huber threshold used for the overall fit and for the per-treatment fits.

    Point sizes scale with the robust weight of each condition. A side panel reports the number of
    down-weighted conditions, the fitted equation and the R^2 values.
    """
    treatment_colour_map = {"unst": "black", "lps": "red", "pic": "green"}

    condition_info_subset = condition_info_df.loc[condition_info_df.gene == gene_id]

    lr_results = apply_mv_regressor(condition_info_subset, "mean", "variance", rlm_factor)
    lr_weights = lr_results["weights"]

    # Each stimulated treatment is fitted together with the unstimulated conditions. It uses the same
    # slider-controlled threshold as the overall fit, so that every plotted line responds to rlm_factor.
    treatment_lr_results_map = {
        treatment: apply_mv_regressor(
            condition_info_subset.loc[condition_info_subset.treatment.isin(["unst", treatment])],
            "mean",
            "variance",
            rlm_factor,
        )
        for treatment in analysis_treatments
        if treatment != "unst"
    }

    plot_output = widgets.Output()
    info_output = widgets.Output()

    with plot_output:
        colours = condition_info_subset.treatment.map(treatment_colour_map)
        legend_handles = [
            matplotlib.lines.Line2D([], [], marker="o", color=colour, label=treatment_name, linestyle="None", markersize=8)
            for treatment_name, colour in treatment_colour_map.items()
        ]

        is_log = scale == "log"
        log_shift = 1 if is_log else 0
        space_function = np.geomspace if is_log else np.linspace

        x, y = condition_info_subset.loc[:, ["mean", "variance"]].to_numpy().T
        # Map robust weights (expected in [0, 1]) to marker sizes between 10 and 50.
        s = np.interp(lr_weights, (0, 1), (10, 50))
        plt.scatter(x + log_shift, y + log_shift, c=colours, s=s)

        # Evaluate the fit on unshifted x, then shift the result for display.
        lr_x = space_function(log_shift, x.max() + log_shift)
        lr_y = ((lr_x - log_shift) * lr_results["slope"]) + lr_results["intercept"]
        plt.plot(lr_x, lr_y + log_shift, "-")

        if plot_treatment_lines:
            for treatment, lr_res in treatment_lr_results_map.items():
                lr_y2 = ((lr_x - log_shift) * lr_res["slope"]) + lr_res["intercept"]
                plt.plot(lr_x, lr_y2 + log_shift, ":", c=treatment_colour_map[treatment])

        plt.xscale(scale)
        plt.xlim(left=log_shift)
        # Raw strings: "\m" and "\s" are not valid Python escape sequences.
        plt.xlabel(r"Mean count ($\mu$)")
        plt.yscale(scale)
        plt.ylim(bottom=log_shift)
        plt.ylabel(r"Variance ($\sigma^2$)")
        plt.legend(
            handles=legend_handles,
            loc="upper left",
            bbox_to_anchor=(1, 1)
        )
        plt.show()

    with info_output:
        print(f"No. of conditions with burst parameters: {np.count_nonzero(condition_info_subset.valid_bp)} / {len(condition_info_subset)}")
        print(f"No. of weights < 1: {np.count_nonzero(lr_weights < 1)}")
        for treatment in analysis_treatments:
            print(f"  {np.count_nonzero((lr_weights < 1) & (condition_info_subset.treatment == treatment))} {treatment}")
        display(Latex(rf"$\sigma^2={lr_results['slope']:.2f}\mu{format_plus_c(lr_results['intercept'])}$"))
        display(Latex(f"Weighted $R^2$: {lr_results['r2_weighted']:.3f}"))
        display(Latex(f"Unweighted $R^2$: {lr_results['r2_unweighted']:.3f}"))

        for name, treatment_lr_results in treatment_lr_results_map.items():
            display(Latex(f"Unweighted $R_{{{name}}}^2$: {treatment_lr_results['r2_unweighted']:.3f}"))

    display(widgets.HBox((plot_output, info_output)))


# Interactive explorer for plot_mean_var; the returned widget is displayed because it is the cell's last
# expression. The rlm_factor slider starts at mv_rlm_factor but is for exploration only: the per-gene fits
# in the next section use mv_rlm_factor directly.
widgets.interactive(
    plot_mean_var,
    gene_id=make_gene_selector(valid_gene_ids),
    scale=widgets.RadioButtons(options=[["Linear", "linear"], ["Log-log (+1)", "log"]]),
    plot_treatment_lines=False,
    rlm_factor=widgets.FloatSlider(mv_rlm_factor, min=1.001, max=5, step=0.001, readout_format=".3f"),
)

### Fit trends to mean-variance relationships

A linear mean-variance trend is fitted for each gene, and genes are further filtered by the (unweighted) $R^2$ score of this fit.

In [None]:
# Fit a robust linear mean-variance trend for each gene and record each condition's robust-fit weight.
# The weight column starts as float NaN: the weights assigned below are floats, and writing them into a
# bool column would force an (in recent pandas, deprecated) incompatible-dtype change.
condition_info_df["mv_weight"] = np.nan
mv_gene_ids = []
mv_gene_records = []

for gene_id, gene_df in condition_info_df.groupby("gene"):
    results = apply_mv_regressor(gene_df, "mean", "variance", mv_rlm_factor)
    weights = results.pop("weights")

    results["n_valid_bp"] = np.count_nonzero(gene_df.valid_bp)

    mv_gene_ids.append(gene_id)
    mv_gene_records.append(results)
    condition_info_df.loc[gene_df.index, "mv_weight"] = weights

# Build the per-gene table once. DataFrame.append was removed in pandas 2.0, and appending inside the
# loop copied the whole frame on every iteration.
mv_gene_info_df = pd.DataFrame(mv_gene_records, index=pd.Index(mv_gene_ids, name="gene"))

# Distribution of both R^2 variants, with the filtering cutoff marked.
_, axes = plt.subplots(ncols=2, figsize=(12, 4))
for r2_name, ax in zip(["r2_unweighted", "r2_weighted"], axes.flat):
    mv_gene_info_df[r2_name].plot.hist(bins=30, ax=ax)
    ax.axvline(x=min_mv_r2, ls=":", label="cutoff")
    ax.set_xlabel(r2_name)
    ax.set_ylabel("Number of genes")
    ax.legend()
plt.show()

# Only the unweighted R^2 is kept (as "r2") and used for filtering.
mv_gene_info_df = mv_gene_info_df.drop(columns="r2_weighted").rename(columns={"r2_unweighted": "r2"})

print(f"Of {len(mv_gene_info_df):,} genes:")

subset_with_good_r2 = mv_gene_info_df.loc[mv_gene_info_df.r2 >= min_mv_r2]
print(f"  {len(subset_with_good_r2):,} have (unweighted) R2 >= {min_mv_r2}")
filtered_gene_ids = subset_with_good_r2.index

print(f"  {len(filtered_gene_ids):,} will be kept")
print("The following genes will be filtered out:")

# Export the kept genes (id and symbol, sorted by symbol).
gene_symbol_map.lookup(filtered_gene_ids).to_frame().sort_values(by="symbol").to_csv(get_output_path("burst_trends_gene_list.csv"), index_label="id")

display(gene_symbol_map.added_to(mv_gene_info_df.loc[~mv_gene_info_df.index.isin(filtered_gene_ids)]).sort_values(by="gene_symbol"))

### Display an interactive plot of burst parameters against mean count for filtered genes

In [None]:
def get_bp_condition_subset(condition_df):
    """Return only the rows of ``condition_df`` whose boolean "valid_bp" column is true (index preserved)."""
    return condition_df[condition_df["valid_bp"]]


@widgets.interact(
    gene_id=make_gene_selector(filtered_gene_ids),
    parameter=widgets.RadioButtons(options=[("Frequency", "f"), ("Size", "s")], rows=2),
    loss_function=widgets.Select(options=["linear", "cauchy", "huber", "soft_l1"], value=bp_curve_loss, rows=4),
    f_scale=widgets.FloatSlider(bp_curve_f_scale, min=0.01, max=5, step=0.01),
)
def plot_mean_bp(gene_id, parameter, loss_function, f_scale):
    """Plot one gene's burst frequency or size against mean count, with linear and power-function fits.

    Parameters
    ----------
    gene_id : gene identifier.
    parameter : "f" (burst frequency, column "bf_point") or "s" (burst size, column "bs_point").
    loss_function, f_scale : loss and scale passed to calculate_curve_fit.

    Only conditions with valid burst parameters are used. For burst size, a dashed horizontal line marks
    the gene's mean-variance slope minus one (alpha - 1).
    """
    x_var = "mean"
    y_var = f"b{parameter}_point"

    condition_info_subset = get_bp_condition_subset(condition_info_df.loc[condition_info_df.gene == gene_id])
    mv_info_row = mv_gene_info_df.loc[gene_id]

    curve_res = calculate_curve_fit(condition_info_subset, x_var, y_var, loss_function, f_scale)
    line_res = apply_standard_regressor(condition_info_subset, x_var, y_var)

    x_range = condition_info_subset[x_var].min(), condition_info_subset[x_var].max()

    plot_output = widgets.Output()
    info_output = widgets.Output()

    with plot_output:
        ax = sns.scatterplot(
            x=x_var,
            y=y_var,
            hue="treatment",
            data=condition_info_subset,
        )

        if parameter == "s":
            ax.axhline(y=mv_info_row.slope - 1, ls="--")

        line_x = np.asarray(x_range)
        line_y = (line_x * line_res["slope"]) + line_res["intercept"]
        ax.plot(line_x, line_y, ls=":")

        # A NaN "a" parameter marks a failed curve fit. pd.notna recognises any NaN (including np.float64),
        # whereas the identity test `is not np.nan` only matches the np.nan object itself.
        if pd.notna(curve_res["a"]):
            curve_x = np.linspace(*x_range)
            curve_y = power_function(curve_x, curve_res["a"], curve_res["b"], curve_res["c"])
            ax.plot(curve_x, curve_y, ls="-")

        ax.set_xlim(0, x_range[1])
        ax.set_ylim(0, condition_info_subset[y_var].max())

        plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
        plt.show()

    with info_output:
        display(Latex(f"Linear $R^2$ (unweighted): {line_res['r2']:.3f}"))
        display(Latex(f"Curve $R^2$: {curve_res['r2']:.3f}"))

    display(widgets.HBox((plot_output, info_output)))

### Fit trends to burst parameters

Lines and curves are fit for the filtered list of genes based on conditions with valid point estimates of burst parameters.

In [None]:
def calculate_per_gene_info(df):
    """Fit linear and power-function trends of burst parameters against mean count for one gene.

    Parameters
    ----------
    df : pandas.DataFrame
        Conditions of a single gene (the notebook passes only conditions with valid burst parameters),
        with "mean", "bf_point", "bs_point" and "log_burstiness" columns.

    Returns
    -------
    pandas.Series
        "n_conditions". Then, for each prefix "bf", "bs" and "kb" (burstiness):
        - "<prefix>_<key>": each result of apply_standard_regressor;
        - "<prefix>_pf_<key>": each result of calculate_curve_fit;
        - "<prefix>_pf_converge": the fitted curve evaluated at the largest mean count in ``df``.
    """
    # Column fitted against "mean" for each output prefix.
    response_columns = {"bf": "bf_point", "bs": "bs_point", "kb": "log_burstiness"}

    lr_dict = {
        prefix: apply_standard_regressor(df, "mean", column)
        for prefix, column in response_columns.items()
    }
    curve_dict = {
        prefix: calculate_curve_fit(df, "mean", column, bp_curve_loss, bp_curve_f_scale)
        for prefix, column in response_columns.items()
    }

    max_mean_count = df["mean"].max()

    results = {"n_conditions": len(df)}
    for lr_n, lr_v in lr_dict.items():
        for n, v in lr_v.items():
            results[f"{lr_n}_{n}"] = v
    for curve_n, curve_v in curve_dict.items():
        for n, v in curve_v.items():
            results[f"{curve_n}_pf_{n}"] = v
        # "Converged" value: the fitted curve at the highest observed mean count.
        results[f"{curve_n}_pf_converge"] = power_function(max_mean_count, curve_v["a"], curve_v["b"], curve_v["c"])

    # Build the Series in one go; the previous dtype=np.float relied on an alias removed in NumPy 1.24.
    return pd.Series(results)


# Fit lines and curves per gene, using the conditions of the filtered genes that have valid burst
# parameters, then attach the mean-variance fit results (columns prefixed with "mv_").
condition_info_gene_subset = (
    get_bp_condition_subset(condition_info_df.loc[condition_info_df.gene.isin(filtered_gene_ids)])
    .groupby("gene")
    .apply(calculate_per_gene_info)
    .join(mv_gene_info_df.add_prefix("mv_"))
)

# Score = curve-fit R^2 where the curve fit succeeded (non-NaN), otherwise the linear-fit R^2.
for prefix in ["bs", "bf", "kb"]:
    curve_r2 = condition_info_gene_subset[f"{prefix}_pf_r2"]
    condition_info_gene_subset[f"{prefix}_score"] = curve_r2.fillna(condition_info_gene_subset[f"{prefix}_r2"])

print("Curves successfully fit to:")
print(f"  Burst sizes of {condition_info_gene_subset.bs_pf_a.count():,} genes")
print(f"  Burst frequencies of {condition_info_gene_subset.bf_pf_a.count():,} genes")
print(f"  Burstiness of {condition_info_gene_subset.kb_pf_a.count():,} genes")

### Relationship between mean-variance gradient and burst size

Prediction for a bursty regime:

> ...burst size is necessarily constant (and equal to the slope of the mean-variance line) over the range of mean mRNA response (i.e., burst size $b_k=\alpha-1$)

#### Scatter plot of the relationship (one point per gene)

The burst size plotted for a gene is an estimate of its "convergence" taken to be the point on the fitted curve at the highest mean count value.

In [None]:
def plot_mv_slope_and_burst_size(scale, axes_range, colour_by, burstiness_range, mv_r2_range, exponent_range):
    """Compare alpha - 1 (mean-variance slope minus one) with converged burst size, one point per gene.

    Parameters
    ----------
    scale : "linear" or "log" for both axes.
    axes_range : (min, max) shared by both axes. A min of 0 on a log scale is replaced by the smallest
        positive value in the data.
    colour_by : column of ``condition_info_gene_subset`` to colour points by, or None.
    burstiness_range, mv_r2_range, exponent_range : inclusive (low, high) filters on "kb_pf_converge",
        "mv_r2" and "bs_pf_b" respectively.

    The side panel reports the Pearson correlation, and no-intercept OLS/RLM slopes (with 99% CIs) of
    converged burst size on alpha - 1. A slope of 1 corresponds to the dotted y = x line.
    """
    df = condition_info_gene_subset
    # copy=True: the array is modified in place below. It must never be (a possibly read-only) view of the
    # DataFrame's data.
    xy = df[["mv_slope", "bs_pf_converge"]].to_numpy(copy=True)
    xy[:, 0] -= 1

    axes_min, axes_max = axes_range
    if (scale == "log") and (axes_min == 0):
        # A log axis cannot start at or below 0, and alpha - 1 is negative whenever the mean-variance slope
        # is below 1. Start at the smallest positive value instead.
        positive_values = xy[xy > 0]
        axes_min = positive_values.min() if positive_values.size else 1
    if axes_min >= axes_max:
        axes_max = axes_min + 1

    # Comparisons with NaN are False, so genes with a missing filter value are excluded.
    plot_mask = (df["kb_pf_converge"] >= burstiness_range[0]) & (df["kb_pf_converge"] <= burstiness_range[1])
    plot_mask &= (df["mv_r2"] >= mv_r2_range[0]) & (df["mv_r2"] <= mv_r2_range[1])
    plot_mask &= (df["bs_pf_b"] >= exponent_range[0]) & (df["bs_pf_b"] <= exponent_range[1])
    plot_xy = xy[plot_mask]

    colours = None
    if isinstance(colour_by, str):
        colours = df.loc[plot_mask, colour_by]

    plot_output = widgets.Output()
    info_output = widgets.Output()

    with plot_output:
        plt.scatter(*plot_xy.T, c=colours, cmap="viridis")
        plt.plot((axes_min, axes_max), (axes_min, axes_max), ls=":")

        plt.xlabel(r"$\alpha$ - 1")
        plt.xscale(scale)
        plt.xlim(axes_min, axes_max)
        plt.ylabel("Converged burst size")
        plt.yscale(scale)
        plt.ylim(axes_min, axes_max)
        plt.gca().set_aspect(1)
        if colours is not None:
            plt.colorbar()
        plt.show()

    with info_output:
        print(f"No. of points: {len(plot_xy)}")
        if len(plot_xy) >= 2:
            corr = stats.pearsonr(*plot_xy.T)
            display(Latex(f"$\\rho$={corr[0]:.2f} (pval={corr[1]:.4f})"))

            # No constant term: the single fitted parameter is the proportionality factor y = k * x.
            for name, Model in [("OLS", sm.OLS), ("RLM", sm.RLM)]:
                lr_res = Model(plot_xy[:, 1], plot_xy[:, 0]).fit()
                ci = lr_res.conf_int(0.01).squeeze()
                print(f"{name}: {lr_res.params[0]:.2f} with 99% CI ({ci[0]:.2f}, {ci[1]:.2f})")

    display(widgets.HBox((plot_output, info_output)))


def make_axes_range_slider(df, x_var, y_var):
    """Return an IntRangeSlider from 0 up to the ceiling of the largest value in ``x_var``/``y_var``.

    The whole range is selected initially. NaN values are ignored when taking the maximum.
    """
    column_maxima = df[[x_var, y_var]].max()
    upper = np.ceil(column_maxima.max())
    return widgets.IntRangeSlider(value=(0, upper), min=0, max=upper)


def make_value_range_slider(df, var):
    """Return a FloatRangeSlider (step 0.01) spanning, and initially selecting, the min-max range of ``df[var]``."""
    column = df[var]
    lower, upper = column.min(), column.max()
    return widgets.FloatRangeSlider(value=(lower, upper), min=lower, max=upper, step=0.01)


# Interactive view of plot_mv_slope_and_burst_size; the returned widget is displayed because it is the cell's
# last expression. Each filter slider starts at the full observed range of its column. colour_by options
# are (label, column) pairs, where None means no colouring.
# NOTE(review): "kb_pf_converge" comes from curve fits to log_burstiness, which is computed with np.log
# (natural log), so the "Burstiness (log10)" label may be inaccurate — confirm.
widgets.interactive(
    plot_mv_slope_and_burst_size,
    scale=widgets.RadioButtons(options=[("Linear", "linear"), ("Log-log", "log")]),
    axes_range=make_axes_range_slider(condition_info_gene_subset, "mv_slope", "bs_pf_converge"),
    colour_by=[
        ("Nothing", None),
        ("Burstiness (log10)", "kb_pf_converge"),
        ("Burstiness R2", "kb_pf_r2"),
        ("Mean-variance R2", "mv_r2"),
        ("Burst size 'a'", "bs_pf_a"),
        ("Burst size 'b'", "bs_pf_b"),
        ("Burst size 'c'", "bs_pf_c"),
        ("Burst size R2", "bs_pf_r2"),
    ],
    burstiness_range=make_value_range_slider(condition_info_gene_subset, "kb_pf_converge"),
    mv_r2_range=make_value_range_slider(condition_info_gene_subset, "mv_r2"),
    exponent_range=make_value_range_slider(condition_info_gene_subset, "bs_pf_b"),
)

#### Scatter plot of burst size deviation from the mean-variance gradient against burstiness (one point per condition for all genes)

In [None]:
def create_burst_model_df(condition_df, model_fit_df):
    """Build a per-condition table of txburst burst parameters alongside their fitted-curve values.

    Parameters
    ----------
    condition_df : pandas.DataFrame
        Per-condition rows (the notebook passes conditions with valid burst parameters) with "gene",
        "mean", "mv_weight", "bs_point", "bf_point" and "k_burstiness" columns.
    model_fit_df : pandas.DataFrame
        Per-gene fit results indexed by gene, with "<bs|bf>_pf_<a|b|c|r2>" columns.

    Returns
    -------
    pandas.DataFrame
        One row per condition of the genes present in ``model_fit_df``, keeping the original index.
        Columns: gene, mean_count, mv_weight, bs_txburst, bf_txburst, burstiness (natural log of
        k_burstiness), bs_curve, bs_curve_r2, bf_curve, bf_curve_r2. Empty if no genes match.
    """
    condition_df = condition_df.loc[condition_df.gene.isin(model_fit_df.index)]

    gene_frames = []
    for gene_id, condition_subset in condition_df.groupby("gene"):
        model_fit_row = model_fit_df.loc[gene_id]

        gene_df = condition_subset[["gene", "mean", "mv_weight", "bs_point", "bf_point"]].copy()
        gene_df.columns = ["gene", "mean_count", "mv_weight", "bs_txburst", "bf_txburst"]
        gene_df["burstiness"] = np.log(condition_subset["k_burstiness"])

        # Evaluate the gene's fitted power-function curves at each condition's mean count.
        for prefix in ("bs", "bf"):
            pf_params = model_fit_row[[f"{prefix}_pf_{v}" for v in ["a", "b", "c"]]]
            gene_df[f"{prefix}_curve"] = power_function(condition_subset["mean"].values, *pf_params)
            gene_df[f"{prefix}_curve_r2"] = model_fit_row[f"{prefix}_pf_r2"]

        gene_frames.append(gene_df)

    if not gene_frames:
        return pd.DataFrame()
    # A single concat replaces repeated DataFrame.append: append was removed in pandas 2.0 and copied
    # the growing frame on every iteration.
    return pd.concat(gene_frames)


def plot_size_burstiness_relationship():
    """Plot each condition's relative burst-size deviation from alpha - 1 against burstiness.

    Reads the notebook globals ``bp_model_df`` (one row per condition) and ``condition_info_gene_subset``
    (per-gene fits, providing the mean-variance slope alpha). The plotted value is
    |(b_k - (alpha - 1)) / (alpha - 1)|. The side panel reports its Spearman correlation with burstiness.
    """
    df = bp_model_df.copy()
    # alpha - 1 for the gene of each condition row.
    slope = condition_info_gene_subset.loc[df.gene, "mv_slope"].values - 1
    df["y_val"] = ((df["bs_txburst"] - slope) / slope).abs()

    plot_output = widgets.Output()
    info_output = widgets.Output()

    with plot_output:
        _, ax = plt.subplots(figsize=(10, 6))
        df.plot.scatter("burstiness", "y_val", c="orange", ax=ax)
        plt.axhline(y=0, ls=":", c="black")
        # "burstiness" is np.log(k_burstiness) (natural log) in create_burst_model_df, not log10.
        plt.xlabel(r"$\ln(burstiness)$")
        plt.ylabel(r"$\left|\frac{b_{k}-\left(\alpha-1\right)}{\alpha-1}\right|$")
        plt.show()

    with info_output:
        corr = stats.spearmanr(df["burstiness"], df["y_val"])
        display(Latex(f"$\\rho_{{spearman}}$={corr.correlation:.2f} (pval={corr.pvalue:.4f})"))

    display(widgets.HBox((plot_output, info_output)))


# Per-condition comparison table (txburst estimates vs curve-fit values) for all genes with fitted trends.
bp_model_df = create_burst_model_df(
    condition_df=get_bp_condition_subset(condition_info_df),
    model_fit_df=condition_info_gene_subset,
)
plot_size_burstiness_relationship()

### Relationship between burst size and frequency

Predicted relationship:

> There is a reciprocal relationship between burst size and frequency, as the burst frequency is proportional to the inverse of the burst size ($1/\alpha$).

#### Scatter plot of the relationship (one point per condition for all genes)

The hypothesised relationship is $f_k=\mu/(\alpha-1)=\mu/b_k$, i.e. $f_k/\mu=1/b_k$, so the plotted burst frequency is divided by the mean RNA count ($\mu$) to test for a direct reciprocal relationship with burst size.

In [None]:
def plot_burst_size_vs_frequency(df, point_type, reciprocal_of_size, zoom_axes, colour_by, burstiness_range):
    """Scatter burst frequency / mean count against burst size (or its reciprocal), one point per condition.

    Parameters
    ----------
    df : per-condition table as built by create_burst_model_df.
    point_type : "txburst" or "curve"; selects the bs_<point_type> / bf_<point_type> columns.
    reciprocal_of_size : x axis is 1 / burst size (reference line y = x) instead of burst size (y = 1/x).
    zoom_axes : clamp both axes to a fixed range.
    colour_by : column to colour points by, or None.
    burstiness_range : inclusive (low, high) range of "burstiness" values to include.

    Whatever the x axis shows, the side panel reports a Pearson correlation and no-intercept OLS/RLM fits
    of burst frequency / mean count on 1 / burst size.
    """
    x_var = f"bs_{point_type}"
    y_var = f"bf_{point_type}"

    x = df[x_var].values
    rx = 1 / x
    if reciprocal_of_size:
        x = rx
    y = df[y_var] / df["mean_count"]

    # Conditions in the burstiness range whose values are usable. NaN or infinite values (presumably
    # from failed curve fits, or a zero burst size) would break the correlation and regressions below.
    in_range = (df["burstiness"] >= burstiness_range[0]) & (df["burstiness"] <= burstiness_range[1])
    plot_mask = in_range.to_numpy() & np.isfinite(rx) & np.isfinite(y.to_numpy())

    plot_output = widgets.Output()
    info_output = widgets.Output()

    with plot_output:
        colours = "orange"
        cmap = None
        if isinstance(colour_by, str):
            colours = df.loc[plot_mask, colour_by]
            cmap = "viridis"

        # Reference curve for the hypothesis f_k / mu = 1 / b_k.
        line_x = np.linspace(np.nanmin(x), np.nanmax(x), 1000)
        line_y = line_x if reciprocal_of_size else 1 / line_x

        plt.scatter(x[plot_mask], y[plot_mask], c=colours, cmap=cmap)
        plt.plot(line_x, line_y, ls="--", c="black")

        plt.xlabel("1 / burst size" if reciprocal_of_size else "Burst size")
        plt.ylabel(r"Burst frequency / $\mu$")
        if zoom_axes:
            if reciprocal_of_size:
                plt.xlim(0, 1.5)
                plt.ylim(0, 1.5)
            else:
                plt.xlim(0, 50)
                plt.ylim(0, 1.5)
        # Only add a colorbar when points were actually coloured by a column.
        if cmap is not None:
            plt.colorbar()
        plt.show()

    with info_output:
        n_points = np.count_nonzero(plot_mask)
        print(f"No. of points: {n_points:,}")
        if n_points >= 2:
            corr = stats.pearsonr(rx[plot_mask], y[plot_mask])
            display(Latex(f"$\\rho$={corr[0]:.2f} (pval={corr[1]:.4f})"))

            # No constant term: the fitted parameter k in y = k * (1 / burst size) should be 1 under the hypothesis.
            for name, Model in [("OLS", sm.OLS), ("RLM", sm.RLM)]:
                lr_res = Model(y[plot_mask], rx[plot_mask]).fit()
                ci = lr_res.conf_int(0.01).squeeze()
                print(f"{name}: {lr_res.params[0]:.2f} with 99% CI ({ci[0]:.2f}, {ci[1]:.2f})")

    display(widgets.HBox((plot_output, info_output)))


# Interactive view of plot_burst_size_vs_frequency; the returned widget is displayed because it is the
# cell's last expression. bp_model_df is passed as a fixed value rather than as a control. colour_by
# options are [label, column] pairs, where None means a single colour.
# NOTE(review): the "burstiness" column is computed with np.log (natural log) in create_burst_model_df,
# so the "Burstiness (log10)" label may be inaccurate — confirm.
widgets.interactive(
    plot_burst_size_vs_frequency,
    df=widgets.fixed(bp_model_df),
    point_type=widgets.RadioButtons(options=[["txburst fit", "txburst"], ["Curve fit", "curve"]]),
    reciprocal_of_size=True,
    zoom_axes=False,
    colour_by=[
        ["Nothing", None],
        ["Burstiness (log10)", "burstiness"],
        ["Mean-variance weight", "mv_weight"],
        ["Size curve R2", "bs_curve_r2"],
        ["Frequency curve R2", "bf_curve_r2"]
    ],
    burstiness_range=make_value_range_slider(bp_model_df, "burstiness"),
)

#### Scatter plot of the relationship for the points of a selected gene

In [None]:
def plot_gene_bp_f_v_s(gene_id, reciprocal_of_size, link_points, show_curve, colour_by):
    """Plot burst frequency / mean count against burst size (or its reciprocal) for one gene's conditions.

    Parameters
    ----------
    gene_id : gene identifier.
    reciprocal_of_size : x axis is 1 / burst size (reference line y = x) instead of burst size (y = 1/x).
    link_points : join each condition's txburst point to its curve-fit point with a dotted line.
    show_curve : also trace the fitted power-function curves over the gene's observed mean-count range.
    colour_by : bp_model_df column to colour points by, or None for fixed colours.

    A dash-dot vertical line marks alpha - 1 (the burst size implied by the mean-variance slope), or its
    reciprocal.
    """
    bp_df = bp_model_df.loc[bp_model_df.gene == gene_id].sort_values(by="mean_count")
    model_df = condition_info_gene_subset.loc[gene_id]

    xs = np.asarray(bp_df[["bs_txburst", "bs_curve"]])
    if reciprocal_of_size:
        xs = 1 / xs
    ys = np.asarray(bp_df[["bf_txburst", "bf_curve"]].div(bp_df.mean_count, axis=0))

    plt.subplots(figsize=(15, 6))

    # Reference line for the hypothesis f_k / mu = 1 / b_k, labelled with the form it takes on these axes.
    if reciprocal_of_size:
        lx = ly = (0, np.max((xs.max(), ys.max())))
        reference_label = "y=x"
    else:
        lx = np.linspace(xs.min(), xs.max(), 1000)
        ly = 1 / lx
        reference_label = "y=1/x"
    plt.plot(lx, ly, ls="--", c="black", label=reference_label)

    gradient = model_df["mv_slope"] - 1
    if reciprocal_of_size:
        gradient = 1 / gradient
    plt.axvline(x=gradient, ls="-.", c="grey")

    if link_points:
        # With 2-D input, each column of xs.T / ys.T is one condition: a segment from txburst to curve point.
        plt.plot(xs.T, ys.T, ls=":", c="purple")

    if show_curve:
        condition_info_subset = get_bp_condition_subset(condition_info_df.loc[condition_info_df.gene == gene_id])
        mean_min, mean_max = condition_info_subset["mean"].agg(["min", "max"])
        mean_vals = np.linspace(mean_min, mean_max, 1000)

        sa, sb, sc = model_df[["bs_pf_a", "bs_pf_b", "bs_pf_c"]]
        fa, fb, fc = model_df[["bf_pf_a", "bf_pf_b", "bf_pf_c"]]
        bs_vals = power_function(mean_vals, sa, sb, sc)
        bf_vals = power_function(mean_vals, fa, fb, fc)
        if reciprocal_of_size:
            bs_vals = 1 / bs_vals
        plt.plot(bs_vals, bf_vals / mean_vals, ls="-", c="gray", label="curve")

    txburst_colour = bp_df[colour_by] if isinstance(colour_by, str) else "green"
    curve_colour = bp_df[colour_by] if isinstance(colour_by, str) else "red"
    plt.scatter(xs[:, 0], ys[:, 0], marker="o", label="txburst", c=txburst_colour)
    plt.scatter(xs[:, 1], ys[:, 1], marker="P", label="curve", c=curve_colour)

    plt.xlabel("1 / burst size" if reciprocal_of_size else "Burst size")
    # Balanced "$" so that only mu is rendered as math (previously an odd "$" count printed literally).
    plt.ylabel(r"Burst frequency / $\mu$")

    if isinstance(colour_by, str):
        plt.colorbar()
        plt.legend()
    else:
        plt.legend(loc="upper left", bbox_to_anchor=(1, 1))

    plt.show()


# Interactive per-gene view of plot_gene_bp_f_v_s; the returned widget is displayed because it is the cell's
# last expression.
# NOTE(review): "burstiness" is computed with np.log (natural log) in create_burst_model_df, so the
# "Burstiness (log10)" label may be inaccurate — confirm.
widgets.interactive(
    plot_gene_bp_f_v_s,
    gene_id=make_gene_selector(filtered_gene_ids),
    reciprocal_of_size=True,
    link_points=False,
    show_curve=False,
    colour_by=widgets.Dropdown(options=[["Nothing", None], ["Burstiness (log10)", "burstiness"]]),
)

### Modulation of burst size and frequency

Under a burst regime, the predicted constant burst size necessitates that:

> Changes of gene expression are controlled solely by frequency modulation [i.e., $f_k=\mu/(\alpha-1)$]

And, in the general case:

> ...both burst size and frequency may undergo modulation as the mean mRNA expression varies. The relative contribution of burst size and frequency modulation is related to the $k_{off}$ value (or $k_{off}/k_{on}$ ratio...)

#### Scatter plots of modulation as a function of burstiness

In [None]:
def plot_bp_modulation():
    """Plot per-gene variation of burst size and frequency across conditions against burstiness.

    For each filtered gene, over its conditions with valid burst parameters, this computes:
    - "bs_cv": the coefficient of variation (scipy.stats.variation) of the txburst burst size;
    - "bf_cv": the same for burst frequency;
    - "cv_ratio": bf_cv / bs_cv.
    Each quantity is plotted against the gene's converged burstiness ("kb_pf_converge") with a robust (RLM)
    regression line, followed by the regression summary.
    """
    condition_subset = get_bp_condition_subset(condition_info_df.loc[condition_info_df.gene.isin(condition_info_gene_subset.index)])
    cv_df = condition_subset.groupby("gene")[["bs_point", "bf_point"]].agg(stats.variation)
    cv_df.columns = ["bs_cv", "bf_cv"]
    cv_df = cv_df.join(condition_info_gene_subset["kb_pf_converge"].rename("burstiness"))
    cv_df["cv_ratio"] = cv_df["bf_cv"] / cv_df["bs_cv"]

    plot_columns = ["bs_cv", "bf_cv", "cv_ratio"]
    plot_outputs = [widgets.Output() for _ in plot_columns]

    for column, output in zip(plot_columns, plot_outputs):
        with output:
            # Some genes cannot be used in the regression: burstiness is presumably NaN where the burstiness
            # curve fit failed, and a zero bs_cv makes the ratio infinite.
            usable = np.isfinite(cv_df["burstiness"]) & np.isfinite(cv_df[column])
            x = cv_df.loc[usable, "burstiness"]
            y = cv_df.loc[usable, column]

            lr_res = sm.RLM(y, sm.add_constant(x)).fit()
            # add_constant prepends the constant, so params are (intercept, slope). Access them by position
            # explicitly, because integer keys on a label-indexed Series are deprecated in pandas.
            intercept, slope = lr_res.params.iloc[0], lr_res.params.iloc[1]
            lr_x = np.asarray((x.min(), x.max()))
            lr_y = (lr_x * slope) + intercept

            plt.scatter(x, y, c="orange")
            plt.plot(lr_x, lr_y, ls="--", c="black")
            # Burstiness is derived from log_burstiness, computed with np.log (natural log), not log10.
            plt.xlabel(r"$\ln(burstiness)$")
            plt.ylabel(column)
            plt.show()

            display(lr_res.summary2())

    display(widgets.HBox(plot_outputs))


# Draw the three modulation panels (burst size CV, burst frequency CV and their ratio vs burstiness).
plot_bp_modulation()

### Display a scatter plot of $R^2$ values for burst size against frequency to indicate dominant modulation

"Score" is the $R^2$ value of the curve fit, if successful, and $R^2$ of the linear fit otherwise.

In [None]:
# One point per gene: burst size score (x) vs burst frequency score (y). Points above the solid y = x line
# are genes whose burst frequency trend fits better than their burst size trend.
ax = sns.scatterplot(x="bs_score", y="bf_score", data=condition_info_gene_subset)
ax.plot((-0.5, 1), (-0.5, 1), "-")
ax.axvline(x=0, ls=":")
ax.axhline(y=0, ls=":")
ax.set_aspect(1)
plt.show()

bs_scores = condition_info_gene_subset.bs_score
bf_scores = condition_info_gene_subset.bf_score
print(f"{np.count_nonzero(bf_scores > bs_scores):,} points above line")
print(f"{np.count_nonzero(bf_scores < bs_scores):,} points below line")

### Show plots of trends against mean RNA for filtered genes

Genes are sorted by descending burst frequency score: the $R^2$ of the burst frequency curve fit where that fit succeeded, otherwise the $R^2$ of the linear regression fit.

In [None]:
def plot_relationship_scatter(ax, condition_info, y_var):
    """Scatter the ``y_var`` column of ``condition_info`` against its "mean" column on ``ax``.

    Both axes start at zero and are labelled with the plotted column names.
    """
    ax.scatter(condition_info.loc[:, "mean"], condition_info.loc[:, y_var])
    ax.set_xlim(left=0)
    ax.set_ylim(bottom=0)
    ax.set_xlabel("mean")
    ax.set_ylabel(y_var)


def plot_relationship_line(ax, condition_info, lr_info, y_var_prefix, clip_x_range=False):
    """Draw the fitted straight line ``slope * mean + intercept`` on ``ax``.

    Parameters
    ----------
    ax : axes to draw on.
    condition_info : DataFrame whose "mean" column determines the x range.
    lr_info : mapping/Series with "<y_var_prefix>_slope" and "<y_var_prefix>_intercept" entries.
    y_var_prefix : prefix of the fit entries, e.g. "mv", "bf" or "bs".
    clip_x_range : start the line at the smallest mean instead of 0.
    """
    x_min = condition_info["mean"].min() if clip_x_range else 0
    x_range = np.asarray((x_min, condition_info["mean"].max()))
    # Both slope and intercept come from lr_info. The intercept was previously read from the notebook
    # global `lr_row`, which only worked when the caller happened to pass that same row.
    y_range = (x_range * lr_info[f"{y_var_prefix}_slope"]) + lr_info[f"{y_var_prefix}_intercept"]
    ax.plot(x_range, y_range)


def plot_relationship_curve(ax, condition_info, lr_info, y_var_prefix, clip_x_range=False):
    """Draw the fitted power-function curve for ``y_var_prefix`` on ``ax`` as a dashed line.

    Parameters
    ----------
    ax : axes to draw on.
    condition_info : DataFrame whose "mean" column determines the x range.
    lr_info : mapping/Series with "<y_var_prefix>_pf_a", "_pf_b" and "_pf_c" entries.
    y_var_prefix : prefix of the fit entries, e.g. "bf", "bs" or "kb".
    clip_x_range : start at the smallest mean instead of just above 0.

    Nothing is drawn when the "a" parameter is missing (NaN), which marks a failed curve fit.
    """
    a, b, c = [lr_info[f"{y_var_prefix}_pf_{coef}"] for coef in ["a", "b", "c"]]
    # pd.isna recognises any NaN, including the np.float64 values read from a Series row. The identity
    # test `a is np.nan` never matched those, so failed fits were "plotted" as NaN curves.
    if pd.isna(a):
        return

    # An unclipped range starts at machine epsilon rather than 0, presumably to avoid evaluating the
    # power function at zero. (np.float no longer exists in NumPy >= 1.24; the builtin float is equivalent.)
    x_min = condition_info["mean"].min() if clip_x_range else np.finfo(float).eps
    x = np.linspace(x_min, condition_info["mean"].max())
    y = power_function(x, a, b, c)

    ax.plot(x, y, "--")


def plot_gene_bp_size_and_frequency(bp_df, ax, point_type="curve"):
    """Scatter burst frequency / mean count against 1 / burst size for one gene's conditions.

    Both the txburst estimates (circles) and the curve-fit values (plus markers) are drawn, with a dotted
    y = x reference line from the origin.

    Parameters
    ----------
    bp_df : per-condition rows with bs_txburst, bs_curve, bf_txburst, bf_curve and mean_count columns.
    ax : axes to draw on.
    point_type : currently unused; kept so existing calls remain valid.
    """
    xs = np.asarray(1 / bp_df[["bs_txburst", "bs_curve"]])
    ys = np.asarray(bp_df[["bf_txburst", "bf_curve"]].div(bp_df.mean_count, axis=0))

    ax.scatter(xs[:, 0], ys[:, 0], marker="o", label="txburst")
    ax.scatter(xs[:, 1], ys[:, 1], marker="P", label="curve")

    max_point = np.max((xs.max(), ys.max()))
    ax.plot((0, max_point), (0, max_point), ":")

    ax.set_xlabel("1 / burst size")
    # Raw string: "\m" is not a valid Python escape sequence.
    ax.set_ylabel(r"Burst frequency / $\mu$")
    ax.legend()


# For each filtered gene (highest burst frequency score first), show its fitted values and five panels:
# mean-variance, burst frequency, burst size, burstiness, and frequency/mean vs 1/size.
lr_to_plot = condition_info_gene_subset.sort_values(by="bf_score", ascending=False)
for idx, (gene_id, lr_row) in enumerate(lr_to_plot.iterrows(), start=1):
    print(f"{idx}. {gene_symbol_map.lookup(gene_id)}")
    with pd.option_context("display.max_columns", None):
        display(lr_row.to_frame().T)

    is_gene = condition_info_df.gene == gene_id
    condition_info_subset = condition_info_df.loc[is_gene]
    bp_model_subset = bp_model_df.loc[bp_model_df.gene == gene_id]
    valid_bp_condition_info = condition_info_subset.loc[condition_info_subset.valid_bp]

    _, axes = plt.subplots(ncols=5, figsize=(20, 4))
    mv_ax, bf_ax, bs_ax, kb_ax, size_freq_ax = axes

    # Panel 1: variance vs mean with the robust mean-variance line (all conditions).
    plot_relationship_scatter(mv_ax, condition_info_subset, "variance")
    plot_relationship_line(mv_ax, condition_info_subset, lr_row, "mv")
    mv_ax.set_title(f"$R^2=${lr_row['mv_r2']:.2f}, $\\alpha=${lr_row['mv_slope']:.2f}")

    # Panels 2-3: burst frequency and burst size with line and curve fits (valid conditions only).
    for prefix, ax in (("bf", bf_ax), ("bs", bs_ax)):
        plot_relationship_scatter(ax, valid_bp_condition_info, f"{prefix}_point")
        plot_relationship_line(ax, valid_bp_condition_info, lr_row, prefix, clip_x_range=True)
        plot_relationship_curve(ax, valid_bp_condition_info, lr_row, prefix, clip_x_range=True)
        ax.set_title(f"Linear $R^2=${lr_row[f'{prefix}_r2']:.2f}, Curve $R^2=${lr_row[f'{prefix}_pf_r2']:.2f}")

    # Burst size implied by the mean-variance slope (alpha - 1).
    bs_ax.axhline(y=lr_row["mv_slope"] - 1, ls=":")

    # Panel 4: log burstiness with its curve fit.
    plot_relationship_scatter(kb_ax, valid_bp_condition_info, "log_burstiness")
    plot_relationship_curve(kb_ax, valid_bp_condition_info, lr_row, "kb", clip_x_range=True)
    kb_ax.set_title(f"Curve $R^2=${lr_row['kb_pf_r2']:.2f}")

    # Panel 5: frequency / mean vs 1 / size, with 1 / (alpha - 1) marked.
    plot_gene_bp_size_and_frequency(bp_model_subset, size_freq_ax)
    size_freq_ax.axvline(x=1 / (lr_row["mv_slope"] - 1), ls=":")

    plt.tight_layout()
    plt.show()