In [None]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.linear_model import HuberRegressor as Regressor

import rp2
from rp2 import hagai_2018, create_gene_symbol_map
from rp2.paths import get_txburst_results_csv_path

# Run rp2's environment check before any data is loaded.
# NOTE(review): what exactly is verified (paths, data files, versions) is not visible here — confirm.
rp2.check_environment()

In [None]:
# Columns identifying one condition; per-gene rows additionally carry "gene" in front.
index_columns = ["replicate", "treatment", "time_point"]
all_index_columns = ["gene", *index_columns]

# Which dataset slice to analyse. `analysis_counts` is forwarded both as the
# txburst results `count_type` and as the count-loading `scaling`.
analysis_species = "mouse"
analysis_counts = "median"
analysis_treatments = ["unst", "lps"]

# A gene is kept only if it has at least this many conditions with both point estimates.
min_conditions = 6

In [None]:
# Lookup object used by the plotting loop further down to print a readable
# name for each gene ID (via `gene_symbol_map.lookup(gene_id)`).
gene_symbol_map = create_gene_symbol_map(analysis_species)

In [None]:
# Load the per-condition txburst results; replicate IDs are cast to str
# (presumably so they match the replicate keys of the count stats joined later — confirm).
condition_info_df = pd.read_csv(
    get_txburst_results_csv_path(analysis_species, index_columns, count_type=analysis_counts)
).astype({"replicate": str})

condition_info_df = condition_info_df.loc[condition_info_df.treatment.isin(analysis_treatments)]
print(f"{len(condition_info_df):,} conditions with {condition_info_df.gene.nunique():,} genes")

# A condition is usable only when both `bs_point` and `bf_point` are present.
# `.assign` returns a new frame instead of writing a column into the `.loc` slice
# above, which pandas can flag as chained assignment (SettingWithCopyWarning).
condition_info_df = condition_info_df.assign(
    valid=condition_info_df.bs_point.notna() & condition_info_df.bf_point.notna()
)
condition_info_df = condition_info_df.loc[condition_info_df.valid]

# Every remaining row is valid, so the group size is the number of usable conditions.
valid_counts = condition_info_df.groupby("gene").size()
valid_gene_ids = valid_counts.index[valid_counts >= min_conditions]
print(f"{len(valid_gene_ids):,} filtered genes")

condition_info_df = condition_info_df.loc[condition_info_df.gene.isin(valid_gene_ids)]
print(f"{len(condition_info_df):,} conditions with {condition_info_df.gene.nunique():,} genes")

In [None]:
def calculate_count_stats():
    """Compute per-condition count statistics for the analysed treatments.

    Loads the ``analysis_species`` counts with ``analysis_counts`` scaling,
    keeps only observations whose ``obs.treatment`` is in
    ``analysis_treatments``, and returns the result of
    ``hagai_2018.calculate_counts_condition_stats`` grouped by ``index_columns``.
    """
    counts_adata = hagai_2018.load_counts(analysis_species, scaling=analysis_counts)
    in_analysis = counts_adata.obs.treatment.isin(analysis_treatments)
    # Rebind to an independent copy so this function no longer holds the unfiltered counts.
    counts_adata = counts_adata[in_analysis].copy()
    return hagai_2018.calculate_counts_condition_stats(
        counts_adata,
        group_columns=index_columns,
    )


# Attach the count statistics to each condition, matching on gene + condition keys.
# Left join: conditions without matching count stats are kept (their stats become NaN).
condition_info_df = (
    condition_info_df
    .set_index(all_index_columns)
    .join(calculate_count_stats().set_index(all_index_columns), how="left")
    .reset_index()
)

In [None]:
def calculate_regression(df, x_var, y_var):
    """Fit a robust linear regression of ``y_var`` on ``x_var``.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame containing the numeric columns ``x_var`` and ``y_var``.
        NOTE(review): values are passed straight to the regressor, so missing
        values are presumably rejected by it — they are not dropped here.
    x_var, y_var : str
        Predictor and response column names.

    Returns
    -------
    dict
        ``slope`` and ``intercept`` as Python floats, and ``r2``, the
        regressor's ``score`` on the same data it was fitted on.
    """
    # sklearn regressors expect a 2-D feature matrix (n_samples, 1) and a 1-D
    # target (n_samples,). A column-vector target triggers a conversion warning
    # that the blanket filter below would silently hide.
    lr_x = df[[x_var]].to_numpy()
    lr_y = df[y_var].to_numpy()

    # Silence fit-time warnings (e.g. optimizer convergence) across the many per-gene fits.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        lr = Regressor().fit(lr_x, lr_y)

    return {
        "slope": float(lr.coef_[0]),
        "intercept": float(lr.intercept_),
        "r2": lr.score(lr_x, lr_y),
    }


def calculate_per_gene_info(df):
    """Summarise one gene's conditions as a flat Series of regression results.

    Regresses ``bf_point``, ``bs_point`` and ``variance`` on ``mean`` (in that
    order) with ``calculate_regression`` and returns a Series holding
    ``n_conditions`` followed by ``<prefix>_<stat>`` entries for the
    ``bf``, ``bs`` and ``mv`` fits.
    """
    # prefix -> response column regressed on "mean"
    regression_targets = {
        "bf": "bf_point",
        "bs": "bs_point",
        "mv": "variance",
    }

    summary = {"n_conditions": len(df)}
    for prefix, response in regression_targets.items():
        fit = calculate_regression(df, "mean", response)
        summary.update({f"{prefix}_{stat}": value for stat, value in fit.items()})

    return pd.Series(summary)


# One row per gene: apply calculate_per_gene_info (condition count plus the
# bf/bs/mv regressions on "mean") to each gene's filtered conditions.
# NOTE(review): newer pandas versions may warn that groupby.apply also passes the
# grouping column; calculate_per_gene_info only reads named columns — confirm on upgrade.
gene_info_df = condition_info_df.groupby("gene").apply(calculate_per_gene_info)

In [None]:
# Per-gene fit quality: R^2 of the bs regression (x) against the bf regression (y).
ax = sns.scatterplot(data=gene_info_df, x="bs_r2", y="bf_r2")
# Identity line: genes on it fit equally well for both quantities.
ax.plot((-0.5, 1), (-0.5, 1), "-")
# Dotted zero lines mark where a fit does worse than a constant (negative R^2).
ax.axvline(x=0, ls=":")
ax.axhline(y=0, ls=":")
ax.set_aspect(1)
plt.show()

In [None]:
def plot_relationship_scatter(ax, condition_info, lr_info, y_var_prefix, y_var=None):
    """Scatter one gene's conditions against ``mean`` and overlay its fitted line.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    condition_info : pandas.DataFrame
        Conditions for a single gene; must contain ``mean`` and ``y_var``.
    lr_info : Mapping
        Per-gene regression results holding ``<y_var_prefix>_slope`` and
        ``<y_var_prefix>_intercept`` (e.g. a row of ``gene_info_df``).
    y_var_prefix : str
        Regression prefix (``"bf"``, ``"bs"``, ``"mv"``).
    y_var : str, optional
        Column to plot on y; defaults to ``f"{y_var_prefix}_point"``.
    """
    y_var = y_var or f"{y_var_prefix}_point"
    x_range = np.asarray((0, condition_info["mean"].max()))

    sns.scatterplot(
        x="mean",
        y=y_var,
        data=condition_info,
        ax=ax,
    )

    # Both coefficients come from `lr_info`, so the line always matches the fit
    # the caller passed in (not whatever row a surrounding loop last bound).
    slope = lr_info[f"{y_var_prefix}_slope"]
    intercept = lr_info[f"{y_var_prefix}_intercept"]
    ax.plot(x_range, slope * x_range + intercept)


# Genes ordered by how well the bf regression fits, best first.
lr_to_plot = gene_info_df.sort_values(by="bf_r2", ascending=False)
for idx, (gene_id, lr_row) in enumerate(lr_to_plot.iterrows(), start=1):
    print(f"{idx}. {gene_symbol_map.lookup(gene_id)}")
    display(lr_row.to_frame().T)

    condition_info_subset = condition_info_df.loc[condition_info_df.gene == gene_id]

    fig, axes = plt.subplots(ncols=3, figsize=(12, 4), sharex=True)
    # (regression prefix, explicit y column — None means "<prefix>_point")
    panels = (("bf", None), ("bs", None), ("mv", "variance"))
    for panel_ax, (prefix, panel_y_var) in zip(axes, panels):
        plot_relationship_scatter(panel_ax, condition_info_subset, lr_row, prefix, y_var=panel_y_var)
        panel_ax.set_title(f"$R^2=${lr_row[f'{prefix}_r2']:.2f}")
    plt.tight_layout()
    plt.show()
    # One figure is created per gene; close it explicitly so they do not accumulate.
    plt.close(fig)