In [None]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.optimize import curve_fit
from sklearn.linear_model import HuberRegressor

import rp2
from rp2 import hagai_2018, create_gene_symbol_map
from rp2.paths import get_txburst_results_csv_path

# Environment check provided by the project's rp2 package (implementation not shown here).
rp2.check_environment()

In [None]:
# Dataset selection: species, count scaling, and the treatment / time-point
# labels used further down to filter conditions.
analysis_species = "mouse"
analysis_counts = "median"
analysis_treatments = ["unst", "lps"]
analysis_time_points = ["0", "2", "4", "6"]

# A gene is kept only when it has at least this many valid conditions.
min_conditions = 6

# Columns that identify one condition; prepending the gene gives the full row key.
index_columns = ["replicate", "treatment", "time_point"]
all_index_columns = ["gene", *index_columns]

In [None]:
# Gene symbol map for the analysis species; its lookup() is used when printing per-gene results.
gene_symbol_map = create_gene_symbol_map(analysis_species)

In [None]:
# Load the per-condition results for the selected species and count type.
condition_info_df = pd.read_csv(
    get_txburst_results_csv_path(analysis_species, index_columns, count_type=analysis_counts)
)
# Replicate identifiers are handled as strings from here on.
condition_info_df = condition_info_df.assign(replicate=condition_info_df.replicate.astype(str))

# Keep only the treatments and time points selected for this analysis.
in_analysis_scope = (
    condition_info_df.treatment.isin(analysis_treatments)
    & condition_info_df.time_point.isin(analysis_time_points)
)
condition_info_df = condition_info_df.loc[in_analysis_scope]
print(f"{len(condition_info_df):,} conditions with {condition_info_df.gene.nunique():,} genes")

# A condition is valid when both point estimates are present. assign() returns
# a new frame, so the flag is never written into a filtered selection (which
# would raise SettingWithCopyWarning).
condition_info_df = condition_info_df.assign(
    valid=condition_info_df.bs_point.notna() & condition_info_df.bf_point.notna()
)
condition_info_df = condition_info_df.loc[condition_info_df.valid]

# Count the remaining (valid) conditions per gene and keep the genes with enough of them.
valid_counts = condition_info_df.groupby("gene").valid.count()
valid_gene_ids = valid_counts.index[valid_counts >= min_conditions]
print(f"Reduced to {len(valid_gene_ids):,} genes with {min_conditions} or more conditions")

condition_info_df = condition_info_df.loc[condition_info_df.gene.isin(valid_gene_ids)]
print(f"{len(condition_info_df):,} conditions with {condition_info_df.gene.nunique():,} genes")

In [None]:
def calculate_count_stats():
    """Compute per-condition count statistics for the analysed treatments.

    Loads the counts for ``analysis_species`` (scaled per ``analysis_counts``),
    keeps the observations whose treatment is listed in
    ``analysis_treatments`` and hands a copy of that selection to
    ``hagai_2018.calculate_counts_condition_stats``, grouped by
    ``index_columns``.

    Returns:
        The result of ``hagai_2018.calculate_counts_condition_stats``. The
        caller indexes it by ``all_index_columns``, so it is presumably a
        DataFrame containing those columns.
    """
    all_counts = hagai_2018.load_counts(analysis_species, scaling=analysis_counts)
    treatment_mask = all_counts.obs.treatment.isin(analysis_treatments)
    # Work on an independent copy rather than on the filtered selection itself.
    selected_counts = all_counts[treatment_mask].copy()
    return hagai_2018.calculate_counts_condition_stats(selected_counts, group_columns=index_columns)


# Attach the count statistics to every condition row, matching on the full
# (gene, replicate, treatment, time_point) key. The left join keeps every
# condition row, whether or not statistics exist for its key.
condition_info_df = (
    condition_info_df
    .set_index(all_index_columns)
    .join(calculate_count_stats().set_index(all_index_columns), how="left")
    .reset_index()
)

In [None]:
def calculate_regression(df, x_var, y_var, Regressor=HuberRegressor, include_outliers=False):
    """Fit a one-variable linear regression of ``y_var`` on ``x_var``.

    Args:
        df: DataFrame containing the ``x_var`` and ``y_var`` columns.
        x_var: Name of the predictor column.
        y_var: Name of the response column.
        Regressor: Estimator class; it is instantiated with no arguments and
            must provide ``fit``, ``score``, ``coef_`` and ``intercept_``.
        include_outliers: When true, the fitted estimator's ``outliers_``
            attribute is returned as well (the estimator must provide it).

    Returns:
        dict with ``slope``, ``intercept`` and ``r2`` (the estimator's
        ``score`` on the data it was fitted to), plus ``outliers`` when
        ``include_outliers`` is true.
    """
    xy = df.loc[:, [x_var, y_var]].to_numpy()
    rx = xy[:, [0]]  # feature matrix, shape (n_samples, 1)
    ry = xy[:, 1]  # target vector, shape (n_samples,), as estimators expect

    with warnings.catch_warnings():
        # Warnings raised while fitting are deliberately silenced so that a
        # per-gene loop over many fits does not flood the notebook output.
        warnings.simplefilter("ignore")
        r = Regressor().fit(rx, ry)

    results = {
        # .item() insists on exactly one coefficient / intercept value.
        "slope": np.asarray(r.coef_).item(),
        "intercept": np.asarray(r.intercept_).item(),
        "r2": r.score(rx, ry),
    }
    if include_outliers:
        results["outliers"] = r.outliers_

    return results


# Regress variance on mean separately for every gene, record the fit, and flag
# the conditions the regressor reports as outliers.
condition_info_df["mv_outlier"] = False

# One result record per gene is collected and the frame is built once at the
# end: DataFrame.append was removed in pandas 2.0, and growing a frame inside
# a loop re-copies it on every iteration.
mv_gene_ids = []
mv_records = []

for gene_id, gene_df in condition_info_df.groupby("gene"):
    results = calculate_regression(gene_df, "mean", "variance", include_outliers=True)

    # The per-condition mask goes onto the condition rows; the per-gene table
    # only keeps how many conditions were flagged.
    outliers = results.pop("outliers")
    results["n_outliers"] = np.count_nonzero(outliers)

    mv_gene_ids.append(gene_id)
    mv_records.append(results)
    condition_info_df.loc[gene_df.index, "mv_outlier"] = outliers

# Rows are indexed by gene id; columns follow the key order of the result dicts.
mv_gene_info_df = pd.DataFrame(mv_records, index=mv_gene_ids)

In [None]:
def power_function(x, a, b, c):
    """Evaluate the offset power law ``a * x**b + c``.

    Args:
        x: Scalar or array of input values.
        a: Multiplicative coefficient.
        b: Exponent applied to ``x``.
        c: Additive offset.

    Returns:
        The value of ``a * x**b + c``, computed with ``np.power``.
    """
    scaled_power = a * np.power(x, b)
    return scaled_power + c


def calculate_curve_fit(df, x_var, y_var):
    """Fit ``power_function`` (``a * x**b + c``) to two columns of ``df``.

    Args:
        df: DataFrame containing the ``x_var`` and ``y_var`` columns.
        x_var: Name of the column used as x data.
        y_var: Name of the column used as y data.

    Returns:
        dict with the fitted coefficients under ``a``, ``b`` and ``c``. When
        the fit cannot be computed all three are NaN.
    """
    x, y = df.loc[:, [x_var, y_var]].to_numpy().T
    try:
        with warnings.catch_warnings():
            # Optimizer warnings are deliberately silenced; a poor fit is
            # still returned as-is.
            warnings.simplefilter("ignore")
            params, _ = curve_fit(power_function, x, y, p0=[1, 1, 0])
    except (RuntimeError, ValueError, TypeError):
        # Best effort: a failed fit (no convergence, non-finite data, too few
        # points) yields NaN coefficients. Only these fit-failure exception
        # types are caught, so KeyboardInterrupt and SystemExit still propagate.
        params = [np.nan] * 3

    a, b, c = params
    return {"a": a, "b": b, "c": c}


def calculate_per_gene_info(df):
    """Summarise how ``bf_point`` and ``bs_point`` relate to ``mean`` for one gene.

    For each of the two quantities a linear regression
    (``calculate_regression``) and a power-law fit (``calculate_curve_fit``)
    against the ``mean`` column are computed.

    Args:
        df: The condition rows of a single gene.

    Returns:
        pandas Series holding ``n_conditions`` (the number of rows), then the
        regression results as ``bf_<name>`` / ``bs_<name>``, then the
        power-law coefficients as ``bf_pf_<name>`` / ``bs_pf_<name>``.
    """
    # prefix -> column that is fitted against the mean
    targets = {"bf": "bf_point", "bs": "bs_point"}

    linear_fits = {prefix: calculate_regression(df, "mean", column) for prefix, column in targets.items()}
    power_fits = {prefix: calculate_curve_fit(df, "mean", column) for prefix, column in targets.items()}

    summary = {"n_conditions": len(df)}
    summary.update(
        (f"{prefix}_{stat}", value)
        for prefix, fit in linear_fits.items()
        for stat, value in fit.items()
    )
    summary.update(
        (f"{prefix}_pf_{coef}", value)
        for prefix, fit in power_fits.items()
        for coef, value in fit.items()
    )
    return pd.Series(summary)


# Per-gene summary table: the fits computed by calculate_per_gene_info, joined
# on the gene index with the mean-variance regression results, whose columns
# are given an "mv_" prefix.
gene_info_df = (
    condition_info_df
    .groupby("gene")
    .apply(calculate_per_gene_info)
    .join(mv_gene_info_df.rename(columns=lambda name: f"mv_{name}"))
)

In [None]:
# Per gene: R^2 of the bs-vs-mean regression (x) against R^2 of the
# bf-vs-mean regression (y).
ax = sns.scatterplot(data=gene_info_df, x="bs_r2", y="bf_r2")
# Reference lines: the diagonal where both fits are equally good, and R^2 = 0
# on each axis.
ax.plot((-0.5, 1), (-0.5, 1), "-")
ax.axvline(x=0, ls=":")
ax.axhline(y=0, ls=":")
ax.set_aspect(1)
plt.show()

In [None]:
def plot_relationship_scatter(ax, condition_info, lr_info, y_var_prefix, y_var=None):
    """Scatter a per-condition quantity against the mean and overlay its linear fit.

    Conditions flagged in ``mv_outlier`` are drawn with an "x" marker and
    labelled "mv outlier"; all others are drawn as circles. The fitted line
    spans from 0 to the largest mean, and the axes title shows the fit's R^2.

    Args:
        ax: Matplotlib axes to draw on.
        condition_info: DataFrame with a ``mean`` column, a boolean
            ``mv_outlier`` column and the column plotted on the y axis.
        lr_info: Mapping (dict or Series) holding ``<prefix>_slope``,
            ``<prefix>_intercept`` and ``<prefix>_r2``.
        y_var_prefix: Prefix selecting the fit entries in ``lr_info``.
        y_var: Column plotted on the y axis; defaults to ``<prefix>_point``.
    """
    y_var = y_var or f"{y_var_prefix}_point"
    x_range = np.asarray((0, condition_info["mean"].max()))
    is_outlier = condition_info.mv_outlier

    ax.scatter(
        condition_info.loc[~is_outlier, "mean"],
        condition_info.loc[~is_outlier, y_var],
        marker="o",
    )
    ax.scatter(
        condition_info.loc[is_outlier, "mean"],
        condition_info.loc[is_outlier, y_var],
        marker="x",
        label="mv outlier",
    )

    # Both coefficients come from the lr_info argument, so the function works
    # for any caller and never depends on a notebook-level variable.
    slope = lr_info[f"{y_var_prefix}_slope"]
    intercept = lr_info[f"{y_var_prefix}_intercept"]
    ax.plot(x_range, (x_range * slope) + intercept)

    ax.set_xlim(left=0)
    ax.set_ylim(bottom=0)
    ax.set_xlabel("mean")
    ax.set_ylabel(y_var)

    r2 = lr_info[f"{y_var_prefix}_r2"]
    ax.set_title(f"$R^2=${r2:.2f}")


def plot_relationship_curve(ax, condition_info, lr_info, y_var_prefix):
    """Overlay the fitted power-law curve for ``y_var_prefix`` on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        condition_info: DataFrame with a ``mean`` column; its maximum sets the
            right-hand end of the curve.
        lr_info: Mapping (dict or Series) holding ``<prefix>_pf_a``,
            ``<prefix>_pf_b`` and ``<prefix>_pf_c``.
        y_var_prefix: Prefix selecting the coefficients in ``lr_info``.

    Nothing is drawn when any coefficient is missing (NaN), which is how a
    failed curve fit is recorded.
    """
    a, b, c = [lr_info[f"{y_var_prefix}_pf_{coef}"] for coef in ["a", "b", "c"]]
    # NaN must be tested by value: NaNs read back from pandas/NumPy are not
    # the np.nan object, so an identity check never matches them.
    if any(pd.isna(coef) for coef in (a, b, c)):
        return

    # Start just above zero rather than at zero (presumably so that a negative
    # exponent never evaluates 0**b). The builtin float replaces the np.float
    # alias, which NumPy has removed.
    x = np.linspace(np.finfo(float).eps, condition_info["mean"].max())
    y = power_function(x, a, b, c)

    ax.plot(x, y, "--")


# Walk through every gene, highest bf R^2 first. For each gene, print its
# symbol, show its summary row and draw three panels: bf vs mean, bs vs mean
# (each with the power-law curve overlaid) and variance vs mean.
lr_to_plot = gene_info_df.sort_values(by="bf_r2", ascending=False)
for idx, (gene_id, lr_row) in enumerate(lr_to_plot.iterrows(), start=1):
    print(f"{idx}. {gene_symbol_map.lookup(gene_id)}")
    display(lr_row.to_frame().T)

    condition_info_subset = condition_info_df.loc[condition_info_df.gene == gene_id]

    fig, axes = plt.subplots(ncols=3, figsize=(12, 4), sharex=True)
    for ax, prefix in ((axes[0], "bf"), (axes[1], "bs")):
        plot_relationship_scatter(ax, condition_info_subset, lr_row, prefix)
        plot_relationship_curve(ax, condition_info_subset, lr_row, prefix)
    plot_relationship_scatter(axes[2], condition_info_subset, lr_row, "mv", y_var="variance")

    fig.tight_layout()
    # The legend is attached to the last panel and placed outside its right edge.
    axes[2].legend(loc="upper left", bbox_to_anchor=(1, 1))
    plt.show()