# This notebook visualizes regression coefficients: for each LR pair, it plots the mean coefficient and its confidence interval for every cell type

In [1]:
import pandas as pd
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.utils import resample
from scipy.stats import chi2
import ast

In [None]:
# Load data
# celltype_top_10_CI, celltype_top_10_coef_mean

#### Parse the elements in celltype_top_10_CI from string to tuple

In [None]:
def parse_ci(ci_str):
    """Parse one confidence-interval cell into a Python object.

    Strings such as ``"(-0.12, 0.34)"`` are evaluated with ``ast.literal_eval``,
    which only accepts Python literals (no arbitrary code execution).

    If ``literal_eval`` rejects the string, the surrounding brackets are
    stripped and the comma-separated parts are parsed with ``float``. This
    covers tokens like ``nan`` or ``inf`` (e.g. ``"(nan, nan)"``), which are
    not Python literals. The result is returned as a tuple.

    Non-string cells (for example a float NaN for a missing value, or a tuple
    that is already parsed) are returned unchanged.

    Raises:
        ValueError: if a string is neither a Python literal nor a
            comma-separated list of floats.
    """
    if not isinstance(ci_str, str):
        # literal_eval only accepts strings/AST nodes; anything else would raise.
        return ci_str
    try:
        return ast.literal_eval(ci_str)
    except (ValueError, SyntaxError):
        # Fallback for float tokens that are not Python literals (nan, inf, -inf).
        inner = ci_str.strip().strip('()[]')
        try:
            return tuple(float(part) for part in inner.split(','))
        except ValueError:
            raise ValueError(f"Cannot parse confidence interval: {ci_str!r}") from None

# Build two display tables, each with an 'LR' name column first followed by one
# column per cell type:
#   confidence_df - parsed confidence intervals (via parse_ci)
#   coef_mean_df  - mean coefficients rounded to 2 decimals
# NOTE(review): columns[1:] assumes column 0 of each input table is 'LR_pair' — confirm.
confidence_df = pd.DataFrame()
coef_mean_df = pd.DataFrame()

# Apply the parse_ci function to all elements in celltype_top_10_CI
for col in celltype_top_10_CI.columns[1:]:  # Adjust this range according to your actual data structure
    confidence_df[col] = celltype_top_10_CI[col].apply(parse_ci)

# Round mean coefficients to 2 decimals for display
for col in celltype_top_10_coef_mean.columns[1:]:  # Adjust this range according to your actual data structure
    coef_mean_df[col] = celltype_top_10_coef_mean[col].apply(lambda x: round(x, 2))

# Add the LR-pair name as the first column and drop the table-specific suffix.
# The pattern is anchored with '$' so only a trailing suffix is removed. An
# unanchored replace would also delete '_CI' / '_mean' occurring inside a name.
# regex=True is explicit because pandas changed this default between 1.x and 2.0.
confidence_df.insert(0, 'LR', celltype_top_10_CI['LR_pair'])
confidence_df['LR'] = confidence_df['LR'].str.replace(r'_CI$', '', regex=True)

coef_mean_df.insert(0, 'LR', celltype_top_10_coef_mean['LR_pair'])
coef_mean_df['LR'] = coef_mean_df['LR'].str.replace(r'_mean$', '', regex=True)

#### Error-bar plot of mean coefficients with confidence intervals

In [None]:
def plot_combined_cis_and_means(df_cis, df_means, *, xticks=None):
    """Plot per-cell-type mean coefficients with confidence-interval error bars.

    One subplot is drawn per row of ``df_cis`` (one row per LR pair). Within a
    subplot, each cell type gets its own y position. The marker sits at the
    mean coefficient and the horizontal error bar spans the CI.

    Parameters
    ----------
    df_cis : pandas.DataFrame
        Column 0 holds the LR-pair name; every other column is a cell type
        whose cells are ``(lower, upper)`` pairs.
    df_means : pandas.DataFrame
        Same layout, holding the mean coefficient per cell type. Rows are
        looked up by ``df_cis``'s index labels, so both frames must share the
        same index.
    xticks : array-like, optional
        Tick positions for the shared x axis. Defaults to -3..3 in steps of 0.5.

    CI cells that cannot be unpacked into two values (e.g. a NaN for a
    missing interval) are reported with ``print`` and skipped.
    """
    num_lr_pairs = len(df_cis)
    if num_lr_pairs == 0:
        # plt.subplots(nrows=0) raises, so bail out before touching matplotlib.
        print("No LR pairs to plot.")
        return

    # The notebook's visible import cell does not import pyplot.
    import matplotlib.pyplot as plt

    if xticks is None:
        xticks = np.arange(-3, 3.5, 0.5)

    cell_types = df_cis.columns[1:]  # column 0 is the LR-pair name
    y_positions = range(len(cell_types))

    # squeeze=False always yields a 2-D array, so no special case for one row.
    fig, axes = plt.subplots(nrows=num_lr_pairs, figsize=(10, 8 * num_lr_pairs),
                             sharex=True, squeeze=False)

    for ax, (lr_pair, row_cis) in zip(axes[:, 0], df_cis.iterrows()):
        means = df_means.loc[lr_pair]
        # iloc: ``means[0]`` would rely on pandas' deprecated positional fallback.
        ax.set_title(f'LR Pair: {means.iloc[0]}')
        # Reference line at coefficient 0; coefficients are on the x axis.
        ax.axvline(0, color='gray', linewidth=0.8)

        for i, cell_type in enumerate(cell_types):
            mean = means[cell_type]
            ci = row_cis[cell_type]
            try:
                lower_ci, upper_ci = ci
            except (TypeError, ValueError):
                # TypeError: scalar cell (e.g. NaN); ValueError: wrong number of values.
                print(f"Error unpacking CI for {lr_pair} in {cell_type}: {ci}")
                continue
            # The mean may be rounded more coarsely than the CI bounds and so fall
            # marginally outside the interval. Matplotlib rejects negative error
            # lengths, so clip each side at zero.
            xerr = [[max(mean - lower_ci, 0)], [max(upper_ci - mean, 0)]]
            ax.errorbar(mean, i, xerr=xerr, fmt='o', label=cell_type, capsize=5)

        ax.set_yticks(y_positions)
        ax.set_yticklabels(cell_types)
        ax.set_xlabel('Coefficient Value')
        ax.set_xticks(xticks)

    plt.tight_layout()
    plt.show()
        
        


In [None]:
# Draw one subplot per LR pair from the CI and mean-coefficient tables built above
plot_combined_cis_and_means(df_cis=confidence_df, df_means=coef_mean_df)