# Figure displaying the proportion of age-associated features per cell type and the proportion of those age effects mediated by cis ATAC features

In [None]:
# print the current date and time using the shell `date` command
!date

#### import libraries

In [None]:
from pandas import read_csv, read_parquet, DataFrame as PandasDF
from scanpy import read_h5ad
from os.path import exists
from seaborn import scatterplot, barplot
import matplotlib.pyplot as plt
from matplotlib.pyplot import rc_context

%matplotlib inline
# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'

#### set notebook variables

In [None]:
# parameters
# name of the cell-type annotation level; it selects the 'broad' or 'specific'
# prefix used in the input and output file names
category = 'cluster_name' # 'curated_type' for broad and 'cluster_name' for specific
# label of the regression method; it is embedded in the input and output file names
REGRESSION_TYPE = 'glm_tweedie'

In [None]:
# parameters
project = 'aging_phase2'
# translate the annotation level into the prefix used in the file names
if category == 'curated_type':
    prefix_type = 'broad'
elif category == 'cluster_name':
    prefix_type = 'specific'
else:
    # stop here with a clear message; otherwise prefix_type stays unbound and
    # the first file name below fails with an unrelated looking NameError
    raise ValueError(f"unsupported category '{category}'; "
                     "expected 'curated_type' or 'cluster_name'")
modality = 'GEX'

# directories
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase2'
results_dir = f'{wrk_dir}/results'
figures_dir = f'{wrk_dir}/figures'

# in files
results_file = f'{results_dir}/{project}.{modality}.{prefix_type}.{REGRESSION_TYPE}_fdr_filtered.age.csv'
age_sum_file = f'{figures_dir}/{project}.{modality}.{prefix_type}.{REGRESSION_TYPE}.summary.csv'
med_age_sum_file = f'{figures_dir}/{project}.{modality}.{prefix_type}.{REGRESSION_TYPE}.conditioned.age.summary.csv'

# out files
fig_filename = f'{figures_dir}/{project}.{modality}.{prefix_type}.{REGRESSION_TYPE}.mediated_summary.png'

# constants
DEBUG = False
ALPHA = 0.05
# when debugging, echo every resolved path so a wrong file name is easy to spot
if DEBUG:
    print(f'results_file = {results_file}')
    print(f'age_sum_file = {age_sum_file}')
    print(f'med_age_sum_file = {med_age_sum_file}')
    print(f'fig_filename = {fig_filename}')

#### functions

In [None]:
def load_quantification(cell_name: str, verbose: bool=False) -> PandasDF:
    """
    Read the quantification parquet file (`*.pb.parquet`) for one cell type.

    The file path is assembled from the notebook-level variables quants_dir,
    project, modality and prefix_type together with cell_name.
    NOTE(review): quants_dir is not assigned in any cell visible in this
    notebook; confirm it is defined before this function is called.

    Args:
        cell_name: cell-type label inserted into the file name.
        verbose: when True, print the shape of the loaded table and display
            a random sample of five rows.

    Returns:
        The table read from the parquet file, or None when the file
        does not exist.
    """
    quant_path = f'{quants_dir}/{project}.{modality}.{prefix_type}.{cell_name}.pb.parquet'
    if exists(quant_path):
        quant_df = read_parquet(quant_path)
        if verbose:
            print(f'shape of read {cell_name} quantifications {quant_df.shape}')
            display(quant_df.sample(5))
        return quant_df
    # no file for this cell type
    return None

## load input data

In [None]:
# read the summary table; the first csv column is used as the index so it can
# be joined on the index later
# NOTE(review): presumably one row per cell type with the share of age
# associated features (the plot reads a 'percent_aaf' column) - confirm
age_sum_df = read_csv(age_sum_file, index_col=0)
print(f'age_sum_df shape is {age_sum_df.shape}')
if DEBUG:
    display(age_sum_df.head())

In [None]:
# load the conditioned age summary, index it by the 'tissue' column and rename
# 'percent' to 'percent_med' so the column keeps a distinct name after merging
med_sum_df = (
    read_csv(med_age_sum_file, index_col=0)
    .set_index('tissue')
    .rename(columns={'percent': 'percent_med'})
)
print(f'med_sum_df shape is {med_sum_df.shape}')
if DEBUG:
    display(med_sum_df.head())

## load the modality's results

In [None]:
# read the regression results for this modality; per the file name these are
# the FDR filtered age results
results_df = read_csv(results_file)
print(f'shape of {modality} results {results_df.shape}')
if DEBUG:
    # random sample of five result rows
    display(results_df.sample(5))
    # number of result rows per value of the 'type' column
    display(results_df.type.value_counts())

## summarize the detected effects per cell-type

In [None]:
# magnitude of each coefficient, stored as a new column on results_df
results_df['abs_coef'] = results_df['coef'].abs()
# mean absolute coefficient per tissue; the result is a Series named 'abs_coef'
effects_df = results_df.groupby('tissue')['abs_coef'].mean()
print(f'shape of effects_df is {effects_df.shape}')
if DEBUG:
    display(effects_df)

## merge the summary tables

In [None]:
# both merges are inner joins on the index, so only labels present in every
# table are kept
_index_join = dict(how='inner', left_index=True, right_index=True)
props_df = age_sum_df.merge(med_sum_df, **_index_join)
print(f'props_df shape is {props_df.shape}')
# add the mean absolute coefficient per tissue as the 'abs_coef' column
props_df = props_df.merge(effects_df, **_index_join)
print(f'new props_df shape is {props_df.shape}')
if DEBUG:
    display(props_df)

## visualize the proportions

In [None]:
# Apply the style sheet first and the explicit rc overrides second: the
# 'seaborn-v0_8-talk' sheet sets its own figure.figsize, so calling
# plt.style.use() inside rc_context would replace the requested figure size.
# Both context managers restore the previous rcParams on exit.
with plt.style.context('seaborn-v0_8-talk'), \
        rc_context({'figure.figsize': (15, 11), 'figure.dpi': 100}):
    # one point per row of props_df; reset_index() exposes the row labels as
    # the 'index' column used for the hue
    scatterplot(data=props_df.sort_values('percent_aaf', ascending=False).reset_index(),
                x='percent_aaf', y='percent_med', size='abs_coef', hue='index')
    plt.title(f'{modality} features that are age associated ')
    plt.xlabel('% of age associated genes')
    plt.ylabel('% of ATAC mediated age effects')
    # place the legend outside the axes, to the right of the plot
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2,
               borderaxespad=0, prop={'size': 10})
    # the legend sits outside the axes, so save with a tight bounding box to
    # keep it from being cropped out of the image file
    plt.savefig(fig_filename, bbox_inches='tight')
    plt.show()

In [None]:
# print the current date and time using the shell `date` command
!date