In [None]:
import warnings

import matplotlib.pyplot as plt
import ipywidgets as widgets
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.optimize import curve_fit
from sklearn.linear_model import HuberRegressor
from IPython.display import display, Latex

import rp2
from rp2 import hagai_2018, create_gene_symbol_map
from rp2.paths import get_txburst_results_csv_path

# Run the project's environment check before any analysis starts.
# NOTE(review): what is checked (and whether it raises or only warns) is defined in rp2 — confirm there.
rp2.check_environment()

In [None]:
# --- Data selection -------------------------------------------------------
# Species, count scaling, and the treatment / time-point subset to analyse.
# Time points are kept as strings because they are matched with `isin`
# against the `time_point` column of the loaded results table.
analysis_species = "mouse"
analysis_counts = "median"
analysis_treatments = ["unst", "lps"]
analysis_time_points = ["0", "2", "4", "6"]

# --- Thresholds -----------------------------------------------------------
# Minimum number of conditions with valid burst parameters for a gene to be kept.
min_conditions = 6
# Initial Huber epsilon offered by the interactive mean-variance plot.
mv_outlier_epsilon = 1.35

# --- Keys -----------------------------------------------------------------
# Columns identifying one condition, and the same key extended with the gene.
index_columns = ["replicate", "treatment", "time_point"]
all_index_columns = ["gene", *index_columns]

In [None]:
# Gene ID -> symbol lookup for the analysis species; later cells call
# `gene_symbol_map.lookup(...)` to label genes in widgets and printed output.
gene_symbol_map = create_gene_symbol_map(analysis_species)

In [None]:
# Load the per-condition burst-parameter results for the configured species / count type.
condition_info_df = pd.read_csv(
    get_txburst_results_csv_path(analysis_species, index_columns, count_type=analysis_counts)
)
# Replicate identifiers are treated as labels, not numbers.
condition_info_df["replicate"] = condition_info_df["replicate"].astype(str)
condition_info_df["k_burstiness"] = condition_info_df["k_off"] / condition_info_df["k_on"]

# Restrict to the configured treatments and time points. `.copy()` gives an
# independent frame so the column assignment below does not write to a slice
# of the loaded table (which raises SettingWithCopyWarning).
# NOTE(review): `analysis_time_points` holds strings; if `read_csv` parsed
# `time_point` as integers this filter matches nothing — confirm the column dtype.
condition_info_df = condition_info_df.loc[
    condition_info_df["treatment"].isin(analysis_treatments)
    & condition_info_df["time_point"].isin(analysis_time_points)
].copy()
print(f"{len(condition_info_df):,} conditions with {condition_info_df.gene.nunique():,} genes")

# A condition has valid burst parameters when both point estimates are present.
condition_info_df["valid_bp"] = (
    condition_info_df["bs_point"].notna() & condition_info_df["bf_point"].notna()
)

# Keep genes with enough valid conditions (summing booleans counts the True values).
valid_counts = condition_info_df.groupby("gene")["valid_bp"].sum()
valid_gene_ids = valid_counts.index[valid_counts >= min_conditions]
print(f"{len(valid_gene_ids):,} genes have {min_conditions} or more conditions with valid burst parameters")

# Only genes are filtered here; conditions with invalid burst parameters are
# deliberately kept, because the mean-variance analysis uses every condition.
# Burst-parameter fits select `valid_bp` rows themselves.
condition_info_df = condition_info_df.loc[condition_info_df["gene"].isin(valid_gene_ids)]

In [None]:
def calculate_count_stats(condition_subset):
    """Compute per-condition count statistics for the genes/conditions in a subset.

    Loads the counts for the configured species and scaling, narrows them to the
    genes present in ``condition_subset.gene`` and to cells whose value in each of
    ``index_columns`` occurs in the matching column of ``condition_subset``, then
    delegates to ``hagai_2018.calculate_counts_condition_stats``.

    Each index column is filtered independently, so this is a pre-filter on the
    individual values rather than on exact (replicate, treatment, time_point)
    combinations.

    Parameters
    ----------
    condition_subset : pandas.DataFrame
        Must contain a ``gene`` column and every column named in ``index_columns``.

    Returns
    -------
    Whatever ``hagai_2018.calculate_counts_condition_stats`` returns when grouped
    by ``index_columns`` (the caller treats it as a DataFrame keyed by gene plus
    the index columns).
    """
    # NOTE(review): the loaded object is used like an AnnData (n_obs, n_vars,
    # var_names, obs) — confirm against hagai_2018.load_counts.
    adata = hagai_2018.load_counts(analysis_species, scaling=analysis_counts)
    print(f"Counts for {adata.n_obs:,} cells and {adata.n_vars:,} genes")

    # Genes first (columns), then cells (rows) one index column at a time.
    gene_mask = adata.var_names.isin(condition_subset.gene)
    adata = adata[:, gene_mask]
    for column in index_columns:
        cell_mask = adata.obs[column].isin(condition_subset[column])
        adata = adata[cell_mask]

    # Materialise the selection before computing statistics on it.
    adata = adata.copy()
    print(f"Calculating count statistics for {adata.n_obs:,} cells and {adata.n_vars:,} genes")

    return hagai_2018.calculate_counts_condition_stats(adata, group_columns=index_columns)


# Attach the count statistics to every condition row, matching on gene plus the
# condition key. A left join keeps all conditions even if no statistics exist.
# Note: this cell rebinds `condition_info_df`, so it is meant to run once per
# kernel session.
condition_info_df = (
    condition_info_df
    .set_index(all_index_columns)
    .join(
        calculate_count_stats(condition_info_df).set_index(all_index_columns),
        how="left",
    )
    .reset_index()
)

In [None]:
def calculate_regression(df, x_var, y_var, regressor=None, output_outliers=False):
    """Fit a single-feature linear regression of ``y_var`` on ``x_var``.

    Parameters
    ----------
    df : pandas.DataFrame
        Table holding both variables.
    x_var, y_var : str
        Column names of the predictor and the response.
    regressor : estimator, optional
        Object with ``fit``/``score`` and ``coef_``/``intercept_`` attributes.
        A default ``HuberRegressor()`` is created when omitted (or falsy).
    output_outliers : bool, default False
        When true, also report the regressor's ``outliers_`` mask and the score
        computed on the non-outlier rows only.

    Returns
    -------
    dict
        ``slope``, ``intercept`` and ``r2``; plus ``outliers`` and
        ``r2_without_outliers`` when ``output_outliers`` is set.
    """
    if not regressor:
        regressor = HuberRegressor()

    # Both variables are handed to the regressor as (n, 1) column views.
    xy = df.loc[:, [x_var, y_var]].to_numpy()
    features, target = (column.reshape(-1, 1) for column in xy.T)

    # All warnings raised while fitting are silenced.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        fitted = regressor.fit(features, target)

    summary = dict(
        slope=fitted.coef_.item(),
        intercept=fitted.intercept_.item(),
        r2=fitted.score(features, target),
    )

    if not output_outliers:
        return summary

    summary["outliers"] = fitted.outliers_
    summary["r2_without_outliers"] = fitted.score(
        features[~fitted.outliers_],
        target[~fitted.outliers_],
    )
    return summary


def make_gene_selector(gene_ids):
    """Build a three-row ``Select`` widget listing genes by symbol.

    The symbols for ``gene_ids`` are looked up through ``gene_symbol_map`` and
    sorted; each option shows the symbol and carries the gene ID as its value.
    """
    symbols_by_id = gene_symbol_map.lookup(gene_ids).sort_values()
    choices = [(symbol, gene_id) for gene_id, symbol in symbols_by_id.items()]
    return widgets.Select(options=choices, rows=3)


@widgets.interact(
    gene_id=make_gene_selector(valid_gene_ids),
    huber_epsilon=widgets.FloatSlider(mv_outlier_epsilon, min=1.01, max=10, step=0.01),
)
def plot_mean_var(gene_id, huber_epsilon):
    """Interactively show the mean-variance Huber fit for one gene.

    Left panel: scatter of mean vs. variance over the gene's conditions, with
    the regressor's outliers drawn as crosses and the fitted line overlaid.
    Right panel: number of conditions, number of outliers, and R^2 with and
    without the outliers.

    Parameters
    ----------
    gene_id
        Gene chosen in the selector widget.
    huber_epsilon : float
        Epsilon passed to ``HuberRegressor`` (slider value).
    """
    # Side-by-side containers: figure on the left, text summary on the right.
    plot_output = widgets.Output()
    info_output = widgets.Output()
    output = widgets.HBox((plot_output, info_output))

    gene_conditions = condition_info_df.loc[condition_info_df.gene == gene_id]

    fit = calculate_regression(
        gene_conditions,
        "mean",
        "variance",
        regressor=HuberRegressor(epsilon=huber_epsilon),
        output_outliers=True,
    )
    is_outlier = fit["outliers"]

    with plot_output:
        means, variances = gene_conditions.loc[:, ["mean", "variance"]].to_numpy().T

        # Inliers as dots, outliers as labelled crosses.
        plt.scatter(means[~is_outlier], variances[~is_outlier], marker="o")
        plt.scatter(means[is_outlier], variances[is_outlier], marker="x", label="outlier")

        # Fitted line from the origin to the largest observed mean.
        line_x = np.asarray((0, means.max()))
        line_y = (line_x * fit["slope"]) + fit["intercept"]
        plt.plot(line_x, line_y)

        plt.xlim(left=0)
        plt.ylim(bottom=0)
        plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
        plt.show()

    with info_output:
        print(f"No. of conditions: {len(gene_conditions)}")
        print(f"No. of outliers: {np.count_nonzero(is_outlier)}")
        display(Latex(f"$R^2$ with outliers: {fit['r2']:.3f}"))
        display(Latex(f"$R^2$ without outliers: {fit['r2_without_outliers']:.3f}"))

    display(output)

In [None]:
# Flag mean-variance outliers per condition and collect one regression summary per gene.
condition_info_df["mv_outlier"] = False

# Rows are accumulated in plain lists and turned into a DataFrame once at the
# end: `DataFrame.append` inside a loop is quadratic and no longer exists in
# pandas 2.x.
mv_gene_ids = []
mv_records = []

for gene_id, gene_df in condition_info_df.groupby("gene"):
    # Use the configured epsilon so the notebook-level constant actually drives
    # outlier detection (1.35 is also HuberRegressor's default).
    results = calculate_regression(
        gene_df,
        "mean",
        "variance",
        regressor=HuberRegressor(epsilon=mv_outlier_epsilon),
        output_outliers=True,
    )

    outliers = results.pop("outliers")
    results["n_outliers"] = np.count_nonzero(outliers)
    # Report the R^2 computed without the outliers as the gene's `r2`.
    results["r2"] = results.pop("r2_without_outliers")

    mv_gene_ids.append(gene_id)
    mv_records.append(results)
    condition_info_df.loc[gene_df.index, "mv_outlier"] = outliers

# One row per gene, indexed by gene ID; columns: slope, intercept, r2, n_outliers.
mv_gene_info_df = pd.DataFrame(mv_records, index=mv_gene_ids)

In [None]:
def power_function(x, a, b, c):
    """Evaluate the offset power law ``a * x**b + c`` (element-wise for arrays)."""
    return (a * np.power(x, b)) + c


def calculate_curve_fit(df, x_var, y_var):
    """Fit ``power_function`` to ``y_var`` as a function of ``x_var``.

    Parameters
    ----------
    df : pandas.DataFrame
        Table holding both variables.
    x_var, y_var : str
        Column names of the independent and dependent variable.

    Returns
    -------
    dict
        Fitted coefficients under the keys ``a``, ``b`` and ``c``. If the fit
        fails for any reason (too few points, non-finite data, no convergence)
        all three are NaN instead of an exception being raised.
    """
    x, y = df.loc[:, [x_var, y_var]].to_numpy().T
    try:
        # Warnings raised during the fit are silenced.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            params, _ = curve_fit(
                power_function,
                x,
                y,
                p0=[1, 1, 0],
            )
    except Exception:
        # Best-effort fit: any fitting error yields NaN coefficients. Catching
        # `Exception` rather than using a bare `except` lets KeyboardInterrupt
        # and SystemExit propagate, so a long run can still be interrupted.
        params = [np.nan] * 3
    return dict(zip(("a", "b", "c"), params))


def calculate_per_gene_info(df):
    """Summarise how burst frequency and burst size relate to the mean for one gene.

    For each of ``bf_point`` and ``bs_point`` a linear regression on ``mean``
    (``calculate_regression``) and a power-function fit (``calculate_curve_fit``)
    are computed.

    Parameters
    ----------
    df : pandas.DataFrame
        Condition rows of a single gene.

    Returns
    -------
    pandas.Series
        ``n_conditions`` followed by the regression results as
        ``<bf|bs>_<key>`` and the curve-fit coefficients as ``<bf|bs>_pf_<key>``.
    """
    targets = {"bf": "bf_point", "bs": "bs_point"}

    linear_fits = {
        prefix: calculate_regression(df, "mean", column)
        for prefix, column in targets.items()
    }
    power_fits = {
        prefix: calculate_curve_fit(df, "mean", column)
        for prefix, column in targets.items()
    }

    summary = {"n_conditions": len(df)}
    # Flatten the nested results into prefixed scalar entries.
    summary.update(
        (f"{prefix}_{key}", value)
        for prefix, fit in linear_fits.items()
        for key, value in fit.items()
    )
    summary.update(
        (f"{prefix}_pf_{key}", value)
        for prefix, fit in power_fits.items()
        for key, value in fit.items()
    )

    return pd.Series(summary)


# Per-gene summary: burst-parameter fits use only conditions with valid burst
# parameters; the mean-variance regression columns are joined in with an
# "mv_" prefix.
gene_info_df = (
    condition_info_df.loc[condition_info_df.valid_bp]
    .groupby("gene")
    .apply(calculate_per_gene_info)
    .join(mv_gene_info_df.add_prefix("mv_"))
)

In [None]:
# Compare, per gene, the R^2 of the burst-size fit (x) with that of the
# burst-frequency fit (y).
ax = plt.gca()
sns.scatterplot(data=gene_info_df, x="bs_r2", y="bf_r2", ax=ax)

# Identity line plus dotted guides at R^2 = 0 on both axes.
ax.plot((-0.5, 1), (-0.5, 1), "-")
ax.axvline(x=0, ls=":")
ax.axhline(y=0, ls=":")

ax.set_aspect(1)
plt.show()

In [None]:
def plot_relationship_scatter(ax, condition_info, y_var):
    """Scatter ``y_var`` against ``mean`` on ``ax``, marking mean-variance outliers.

    Parameters
    ----------
    ax : matplotlib axes
        Target axes.
    condition_info : pandas.DataFrame
        Must contain ``mean``, ``y_var`` and a boolean ``mv_outlier`` column.
    y_var : str
        Column plotted on the y axis (also used as the y-axis label).
    """
    mean_values = condition_info.loc[:, "mean"]
    y_values = condition_info.loc[:, y_var]
    is_outlier = condition_info.mv_outlier

    # Regular conditions as dots, flagged conditions as labelled crosses.
    point_groups = (
        (~is_outlier, {"marker": "o"}),
        (is_outlier, {"marker": "x", "label": "mv outlier"}),
    )
    for mask, style in point_groups:
        ax.scatter(mean_values[mask], y_values[mask], **style)

    # Both axes start at zero.
    ax.set_xlim(left=0)
    ax.set_ylim(bottom=0)
    ax.set_xlabel("mean")
    ax.set_ylabel(y_var)


def plot_relationship_line(ax, condition_info, lr_info, y_var_prefix):
    """Draw a fitted regression line on ``ax`` and show its R^2 as the title.

    Parameters
    ----------
    ax : matplotlib axes
        Target axes.
    condition_info : pandas.DataFrame
        Conditions being plotted; only the maximum of ``mean`` is used, to set
        the right end of the line.
    lr_info : mapping
        Must provide ``<prefix>_slope``, ``<prefix>_intercept`` and
        ``<prefix>_r2``.
    y_var_prefix : str
        Prefix selecting which fit in ``lr_info`` to draw (e.g. ``"bf"``).
    """
    # Slope and intercept both come from `lr_info`, so the function does not
    # depend on any notebook-level variable.
    slope = lr_info[f"{y_var_prefix}_slope"]
    intercept = lr_info[f"{y_var_prefix}_intercept"]

    # Line from x = 0 to the largest observed mean.
    x_range = np.asarray((0, condition_info["mean"].max()))
    ax.plot(x_range, (x_range * slope) + intercept)

    r2 = lr_info[f"{y_var_prefix}_r2"]
    ax.set_title(f"$R^2=${r2:.2f}")


def plot_relationship_curve(ax, condition_info, lr_info, y_var_prefix):
    """Overlay the fitted power-function curve on ``ax`` as a dashed line.

    Parameters
    ----------
    ax : matplotlib axes
        Target axes.
    condition_info : pandas.DataFrame
        Conditions being plotted; only the maximum of ``mean`` is used, to set
        the right end of the curve.
    lr_info : mapping
        Must provide ``<prefix>_pf_a``, ``<prefix>_pf_b`` and ``<prefix>_pf_c``.
    y_var_prefix : str
        Prefix selecting which fit in ``lr_info`` to draw (e.g. ``"bf"``).

    Nothing is drawn when any coefficient is NaN (i.e. the fit failed).
    """
    a, b, c = [lr_info[f"{y_var_prefix}_pf_{coef}"] for coef in ["a", "b", "c"]]
    # Test NaN by value: coefficients read from a pandas Series are NumPy
    # floats, for which an identity check against `np.nan` is always False.
    if np.isnan([a, b, c]).any():
        return

    # Start just above zero so a negative exponent does not evaluate at x = 0.
    x = np.linspace(np.finfo(float).eps, condition_info["mean"].max())
    y = power_function(x, a, b, c)

    ax.plot(x, y, "--")


# Plot every gene in gene_info_df, ordered from the highest to the lowest
# R^2 of the burst-frequency regression ("bf_r2").
lr_to_plot = gene_info_df.sort_values(by="bf_r2", ascending=False)
for idx, (gene_id, lr_row) in enumerate(lr_to_plot.iterrows(), start=1):
    # Header: running number, gene symbol, and the gene's summary row as a one-row table.
    print(f"{idx}. {gene_symbol_map.lookup(gene_id)}")
    display(lr_row.to_frame().T)

    # All conditions of this gene (including those without valid burst parameters).
    condition_info_subset = condition_info_df.loc[condition_info_df.gene == gene_id]

    # One row of four panels sharing the x axis (mean).
    _, axes = plt.subplots(ncols=4, figsize=(16, 4), sharex=True)

    # Panel 1: mean vs. variance over all conditions, with the "mv" regression line.
    plot_relationship_scatter(axes[0], condition_info_subset, "variance")
    plot_relationship_line(axes[0], condition_info_subset, lr_row, "mv")

    # The remaining panels only use conditions with valid burst parameters.
    valid_bp_condition_info = condition_info_subset.loc[condition_info_subset.valid_bp]

    # Panels 2-3: bf_point and bs_point vs. mean, each with its linear fit
    # (solid) and power-function fit (dashed).
    for prefix, ax in zip(("bf", "bs"), axes[1:3]):
        plot_relationship_scatter(ax, valid_bp_condition_info, f"{prefix}_point")
        plot_relationship_line(ax, valid_bp_condition_info, lr_row, prefix)
        plot_relationship_curve(ax, valid_bp_condition_info, lr_row, prefix)

    # Panel 4: k_burstiness vs. mean on a log y axis. The lower limit set by the
    # scatter helper (0) is replaced by the smallest observed value before
    # switching to the log scale.
    plot_relationship_scatter(axes[3], valid_bp_condition_info, "k_burstiness")
    axes[3].set_ylim(bottom=valid_bp_condition_info.k_burstiness.min())
    axes[3].set_yscale("log")

    plt.tight_layout()
    # plt.legend acts on pyplot's current axes — presumably the last panel
    # created by plt.subplots; confirm if the legend placement matters.
    plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
    plt.show()