In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from rp2 import get_output_path

In [None]:
# Columns used as the shared key when the per-condition tables are joined.
index_cols = ["gene", "replicate", "treatment", "time_point"]

# Study selection. The species and treatment set are interpolated into the
# input file names read below; the method selects which rows of the
# linear-regression fit table are kept.
study_species = "mouse"
study_treatment_set = "lps"
study_lr_method = "ols"

In [None]:
# Load the list of gene IDs from the "common treatment set genes" text file,
# one ID per line.
# splitlines() plus the blank-line filter keeps a trailing newline (or CRLF
# line endings) from producing an empty-string / "\r"-suffixed entry, which
# would inflate the printed count and is not a valid label for the
# .loc[common_genes] lookup performed on the fit table later on.
common_genes_path = get_output_path(f"{study_species}_common_treatment_set_genes.txt")
common_genes = [
    gene_line.strip()
    for gene_line in common_genes_path.read_text().splitlines()
    if gene_line.strip()
]
print(f"{len(common_genes):,} genes")

In [None]:
# Read the per-gene linear-regression fits, keep only the rows produced by the
# selected method, index them by gene, and select the common genes (rows come
# back in common_genes order; a gene missing from the table raises KeyError).
lr_fit_path = get_output_path(f"{study_species}_{study_treatment_set}_lr_fit_per_gene.csv")
lr_fit_df = (
    pd.read_csv(lr_fit_path)
    .loc[lambda fits: fits.method == study_lr_method]
    .set_index("gene")
    .loc[common_genes]
)
print(lr_fit_df.shape)
display(lr_fit_df.head())

In [None]:
# Read the per-condition, per-gene statistics. time_point is converted to text
# and only rows whose gene is in common_genes are retained.
stats_path = get_output_path(f"{study_species}_{study_treatment_set}_stats_per_condition_per_gene.csv")
stats_df = (
    pd.read_csv(stats_path)
    .assign(time_point=lambda stats: stats.time_point.astype(str))
    .loc[lambda stats: stats.gene.isin(common_genes)]
)
#stats_df = stats_df.loc[~stats_df.outlier]
print(stats_df.shape)
display(stats_df.head())

In [None]:
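# Disabled: stats_df is left with a flat index here; set_index(index_cols) is
# applied inline in the cell that joins it with the burst parameters.
#stats_df = stats_df.set_index(index_cols)
#print(stats_df.shape)
#display(stats_df.head())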

In [None]:
# Locate the txburst output directory and read the burst-parameter table for
# the responsive genes of the selected species.
txburst_path = get_output_path("txburst")
burst_params_csv = txburst_path.joinpath(f"{study_species}_responsive_genes.csv")

burst_params_df = pd.read_csv(burst_params_csv)
print(burst_params_df.shape)
display(burst_params_df.head())

In [None]:
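# Disabled: burst_params_df is left with a flat index here; the same key
# columns are set as the index inline in the join cell below.
#burst_params_df = burst_params_df.set_index(["gene", "replicate", "treatment", "time_point"])
#print(burst_params_df.shape)
#display(burst_params_df.head())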

In [None]:
# Inner-join the per-condition statistics with the burst parameters on
# index_cols, then order the rows for display.
# stats_df.time_point was cast to str when it was loaded, so the same cast is
# applied to the burst-parameter key here: if read_csv parsed that column as
# numeric, its values would not match the string keys on the stats side.
# String values are left unchanged by the cast.
burst_keyed_df = burst_params_df.assign(
    time_point=lambda params: params["time_point"].astype(str)
).set_index(index_cols)
concat_df = (
    stats_df.set_index(index_cols)
    .join(burst_keyed_df, how="inner")
    .reset_index()
    .sort_values(by=["gene", "replicate", "time_point", "treatment"])
)
print(concat_df.shape)
display(concat_df.head())

In [None]:
# Number of joined rows per gene, largest first.
rows_per_gene = concat_df["gene"].value_counts()
display(rows_per_gene.sort_values(ascending=False))

In [None]:
# Scatter each burst parameter against the `mean` column for the genes with the
# highest linear-regression r2: one column of panels per gene, one row per
# burst parameter.
n_plot_genes = 5  # how many top-r2 genes to show

# Rank the genes present in concat_df by the r2 of their fit, best first.
sorted_r2 = lr_fit_df.loc[concat_df.gene.unique()].r2.sort_values(ascending=False)

plot_genes = sorted_r2.head(n_plot_genes).index
plot_columns = ["k_on", "k_off", "k_syn"]

# squeeze=False keeps `axes` two-dimensional even when only a single gene is
# available, so the axes[ri, ci] lookup below is always valid.
_, axes = plt.subplots(
    len(plot_columns),
    len(plot_genes),
    figsize=(20, 12),
    squeeze=False,
)
last_row = len(plot_columns) - 1
for ci, gene_id in enumerate(plot_genes):
    # Filter once per gene rather than once per panel.
    gene_df = concat_df.loc[concat_df.gene == gene_id]
    for ri, column_name in enumerate(plot_columns):
        ax = axes[ri, ci]
        sns.scatterplot(x="mean", y=column_name, data=gene_df, ax=ax)
        if ri == 0:
            # Gene ID as the column header, on the top row only.
            ax.set_title(gene_id)
        if ri < last_row:
            # Only the bottom row keeps its x-axis label.
            ax.set_xlabel("")
        if ci > 0:
            # Only the left-most column keeps its y-axis label.
            ax.set_ylabel("")

plt.show()
#g = sns.FacetGrid(plot_df, col="gene", hue="replicate")
#g.map(plt.scatter, "mean", "k_on")
#g.add_legend();