In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from rp2 import hagai_2018, GeneSymbolMap, create_gene_symbol_map
from rp2.paths import get_data_path, get_output_path

In [None]:
# Study configuration: species, treatment set and linear-regression method to analyse.
study_species, study_treatment_set, study_lr_method = "mouse", "lps", "ols"

# Every column that may identify a condition; each txburst result is joined on the subset it contains.
all_index_columns = [
    "gene",
    "replicate",
    "treatment",
    "time_point",
]

Load the list of LPS-responsive genes

In [None]:
# Load the LPS-responsive phagocyte genes via the project's hagai_2018 helpers and report how many there are.
responsive_phagocyte_genes = hagai_2018.load_lps_responsive_genes()
n_responsive_phagocyte_genes = len(responsive_phagocyte_genes)
print(f"{n_responsive_phagocyte_genes:,} responsive phagocyte genes")

Load the burst parameters calculated by txburst for each set of results, discarding time point "6A"

In [None]:
# Load every txburst result file for the study species, keyed by file stem.
# Files are visited in sorted order: Path.glob makes no ordering guarantee, and the
# insertion order of `txburst_results` drives the order of the summaries and plots below.
txburst_path = get_output_path("txburst")

txburst_results = {}
for csv_path in sorted(txburst_path.glob(f"{study_species}_*.csv")):
    print(f"Loading: {csv_path.name}")
    # Read time_point as text so its dtype does not depend on whether this file happens to
    # contain a "6A" row. It is later joined against per-condition statistics whose
    # time_point is cast to str. (Presumably the other labels are numeric — TODO confirm.)
    raw_params_df = pd.read_csv(csv_path, dtype={"time_point": str})
    txburst_results[csv_path.stem] = raw_params_df.loc[raw_params_df.time_point != "6A"]

# Fail loudly instead of silently skipping every downstream cell.
assert txburst_results, f"No {study_species}_*.csv txburst results found in {txburst_path}"

Summarise, for each set of results, the number of genes and conditions with burst parameters

In [None]:
# For each txburst result set: report how many conditions/genes kept burst parameters and
# how many responsive genes have none, then plot conditions-per-gene and burst size vs frequency.
for name, txburst_params_df in txburst_results.items():
    # Rows with `keep` True form the kept subset used throughout.
    txburst_params_kept_subset = txburst_params_df.loc[txburst_params_df.keep]
    n_kept_genes = txburst_params_kept_subset.gene.nunique()

    print(f"{name}:")
    print(f"  {len(txburst_params_kept_subset):,} conditions across {n_kept_genes:,} genes have burst parameters")
    # Count responsive genes that are actually absent from the kept results. Subtracting gene
    # totals miscounts, and can go negative, whenever kept genes include non-responsive ones.
    # NOTE(review): assumes responsive_phagocyte_genes is an iterable of gene IDs comparable
    # to the `gene` column — confirm against hagai_2018.load_lps_responsive_genes.
    n_responsive_genes_without_params = len(set(responsive_phagocyte_genes) - set(txburst_params_kept_subset.gene))
    print(f"  {n_responsive_genes_without_params:,} responsive genes have no burst params")
    print("  Distribution of those that do:")
    sns.countplot(
        x="n_conditions",
        data=txburst_params_kept_subset.gene.value_counts().to_frame("n_conditions"),
    )
    plt.title(f"{name}: conditions with burst parameters per gene")
    plt.xlabel("Conditions with burst parameters")
    plt.ylabel("Genes")
    plt.show()

    # All conditions are shown; discarded ones (keep == False) are drawn as crosses.
    sns.scatterplot(
        x="bs_point",
        y="bf_point",
        hue="time_point",
        style="keep",
        markers={True: "o", False: "X"},
        data=txburst_params_df,
    )
    plt.title(f"{name}: burst size vs burst frequency")
    plt.xlabel("Burst size")
    plt.ylabel("Burst frequency")
    plt.show()

Load the per-condition statistics calculated for the QCed genes

In [None]:
# Per-condition, per-gene statistics for the QCed genes of this species / treatment set.
condition_stats_path = get_output_path(f"{study_species}_{study_treatment_set}_stats_per_condition_per_gene.csv")
condition_stats_df = pd.read_csv(condition_stats_path)
# Store time_point as text; presumably so it matches the txburst time_point labels in the later join.
condition_stats_df["time_point"] = condition_stats_df["time_point"].astype(str)

n_stats_conditions = len(condition_stats_df)
n_stats_genes = condition_stats_df["gene"].nunique()
print(f'QCed "{study_treatment_set}" treatment has {n_stats_conditions:,} conditions across {n_stats_genes:,} genes')

Load the linear regression parameters calculated per gene

In [None]:
# Per-gene linear-regression fits (indexed by gene), restricted to the method chosen for this study.
lr_fit_all_methods_df = pd.read_csv(
    get_output_path(f"{study_species}_{study_treatment_set}_lr_fit_per_gene.csv"),
    index_col="gene",
)
is_study_method = lr_fit_all_methods_df["method"] == study_lr_method
lr_fit_df = lr_fit_all_methods_df[is_study_method]

Create a gene ID-to-symbol map

In [None]:
# Build the gene ID -> symbol lookup for the study species; it is used to title the
# per-gene panels via symbol_map.lookup(gene_id).
symbol_map = create_gene_symbol_map(study_species)

Display plots combining the per-condition statistics and burst parameters for each set of txburst results

In [None]:
# Per-gene ranking settings for the panels below.
min_conditions_for_ranking = 9  # genes need at least this many shared conditions to be ranked by r2
n_top_genes = 5  # number of highest-r2 genes to plot
plot_column_names = ["k_on", "k_off", "k_syn", "bs_point", "bf_point"]

# For each txburst result set, join its kept burst parameters onto the per-condition statistics,
# summarise the overlap, and plot each burst parameter against the condition mean for the
# best-fitting genes.
for name, txburst_params_df in txburst_results.items():
    txburst_params_kept_subset = txburst_params_df.loc[txburst_params_df.keep]
    print(f"{name}:")

    # Join on whichever condition-identifying columns this txburst result actually carries.
    index_columns = [column for column in txburst_params_kept_subset.columns if column in all_index_columns]
    combined_df = (
        condition_stats_df.set_index(index_columns)
        .join(txburst_params_kept_subset.set_index(index_columns), how="inner")
        .reset_index()
    )
    print(f"  Shares {len(combined_df):,} conditions across {combined_df.gene.nunique():,} genes")

    # One row per gene: number of shared conditions plus that gene's linear-regression fit.
    genes_df = combined_df.gene.value_counts().to_frame("n_conditions").join(lr_fit_df, how="inner")
    if genes_df.empty:
        print("  No genes with both burst parameters and a linear-regression fit; skipping plots")
        continue

    sns.countplot(
        x="n_conditions",
        data=genes_df,
    )
    plt.title(f"{name}: shared conditions per gene")
    plt.xlabel("Shared conditions")
    plt.ylabel("Genes")
    plt.show()

    sns.boxplot(
        x="n_conditions",
        y="r2",
        data=genes_df,
    )
    plt.title(f"{name}: linear-regression r2 by number of shared conditions")
    plt.xlabel("Shared conditions")
    plt.show()

    sorted_r2 = genes_df.loc[genes_df.n_conditions >= min_conditions_for_ranking].r2.sort_values(ascending=False)
    plot_genes = sorted_r2.index[:n_top_genes]
    if len(plot_genes) == 0:
        # plt.subplots(n, 0) would raise, so skip the grid entirely.
        print(f"  No genes with at least {min_conditions_for_ranking} shared conditions; skipping per-gene plots")
        continue

    n_plot_rows = len(plot_column_names)
    n_plot_columns = len(plot_genes)

    # squeeze=False keeps `axes` 2-D even when only one gene is plotted, so the
    # axes[ri, ci] and axes[:, 1:] indexing below is always valid.
    _, axes = plt.subplots(
        n_plot_rows,
        n_plot_columns,
        sharex="col",
        squeeze=False,
        figsize=(4 * n_plot_columns, 4 * n_plot_rows),
    )
    for ci, gene_id in enumerate(plot_genes):
        # Filter once per gene rather than once per panel.
        gene_conditions_df = combined_df.loc[combined_df.gene == gene_id]
        for ri, column_name in enumerate(plot_column_names):
            ax = axes[ri, ci]
            sns.scatterplot(
                x="mean",
                y=column_name,
                style="outlier",
                data=gene_conditions_df,
                ax=ax,
            )
            ax.set_xlim(left=0)
            ax.set_ylim(bottom=0)
            if ri == 0:
                ax.set_title(symbol_map.lookup(gene_id))
    # Every panel in a row plots the same parameter, so only the left-most column keeps its y label.
    for ax in axes[:, 1:].flat:
        ax.set_ylabel(None)

    plt.show()