## Notebook to push the identified donor IDs from demultiplexing back into the data
- This notebook should only be used for the GEX pools and is run once per pool; for ATAC, the donor IDs should instead be added using the aggregated data for the ATAC pools.

In [None]:
# print the shell date/time at the start of the run
!date

#### import libraries

In [None]:
from scanpy import read_10x_h5
from pandas import read_csv, concat
from numpy import where
from seaborn import barplot
import matplotlib.pyplot as plt
from matplotlib.pyplot import rc_context

%matplotlib inline
# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'

#### set notebook variables

In [None]:
# parameters
# NOTE(review): these look like placeholder defaults meant to be overridden per
# run (e.g. by a notebook parameterization tool) — TODO confirm
# modality is compared against 'GEX' when the input paths are built
modality = ''
# pool_num and lane_num are combined into the pool name and the input file paths
pool_num = 0
lane_num = 0

In [None]:
# variables and constants
project = 'aging_phase2'
# pool identifier shared by the demuxlet results, the 10X inputs and the output
pool_name = f'{modality}_P{pool_num}_{lane_num}'
DEBUG = False
dpi_value = 50
# when True read the CellBender output, otherwise the filtered 10X matrix
USE_CELLBENDER = True

# directories
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase2'
demux_dir = f'{wrk_dir}/demux'
info_dir = f'{wrk_dir}/sample_info'
src_dir = f'{wrk_dir}/src_data/{modality.lower()}'
cellbend_dir = f'{wrk_dir}/cellbender'

# in files
info_file = f'{info_dir}/{project}.sample_info.csv'
# fail fast for anything other than GEX; previously only a message was printed
# and the cell then died with a NameError on the undefined path_to_10x_h5
if modality != 'GEX':
    raise ValueError('only GEX pools are supported in this notebook '
                     f'(ATAC is not supported), got {modality=}')
if USE_CELLBENDER:
    path_to_10x_h5 = f'{cellbend_dir}/sample_ec_{pool_name}_out_filtered.h5'
else:
    path_to_10x_h5 = (f'{src_dir}/sample_ec_{pool_name}/'
                      'outs/filtered_feature_bc_matrix.h5')

# out files
output_file = f'{demux_dir}/{pool_name}.h5ad'

print(f'{info_file=}')
print(f'{path_to_10x_h5=}')
print(f'{output_file=}')

#### visualization functions

In [None]:
# function to plot the barcode counts by sample
def plot_sample_barcode_counts(this_df, sample_name, id_col='sample_id'):
    """Draw a bar plot of the barcode counts per ID and print the same counts.

    Args:
        this_df: DataFrame holding the ``id_col`` column, one row per barcode.
        sample_name: text used as the figure title.
        id_col: name of the column whose values are counted.
    """
    # count once and reuse for both the figure and the printed table
    id_counts = this_df[id_col].value_counts()
    # enter the style first so the explicit figsize/dpi overrides are applied
    # on top of it instead of being replaced by values the style sheet defines
    with plt.style.context('seaborn-v0_8-talk'), \
            rc_context({'figure.figsize': (9, 9), 'figure.dpi': dpi_value}):
        barplot(x=id_counts.index, y=id_counts.values, palette='Blues_d')
        plt.xticks(rotation=70)
        plt.title(sample_name)
        plt.ylabel('barcode counts')
        # lay out after the title and labels exist so they are accounted for
        plt.tight_layout()
    print(id_counts)

### load the sample info data

In [None]:
# read the per-sample metadata table
info_df = read_csv(info_file)
print(f'shape of info {info_df.shape}')
if DEBUG:
    display(info_df.head())
    # number of samples assigned to each GEX and ATAC pool
    for pool_col in ('gex_pool', 'atac_pool'):
        display(info_df[pool_col].value_counts())

#### load the 10X matrix files

In [None]:
%%time
# read the 10X HDF5 matrix chosen in the settings cell
adata = read_10x_h5(path_to_10x_h5)
# make duplicated variable names unique
adata.var_names_make_unique()

# summary of the loaded object
print(adata)

#### load the demuxlet results

In [None]:
best_file = f'{demux_dir}/{pool_name}.best'
# the .best table is whitespace delimited; the raw string keeps the regex
# separator from being parsed as an invalid '\s' string escape
demux_df = read_csv(best_file, sep=r'\s+')
# set another best sample column and if doublet or ambiguous set that as ID
droplet_type = demux_df['DROPLET.TYPE']
demux_df['sample_id'] = where(droplet_type == 'SNG', demux_df['SNG.BEST.GUESS'],
                              where(droplet_type == 'DBL',
                                    'doublet', 'ambiguous'))
print(f'shape of demux {demux_df.shape}')
if DEBUG:
    display(demux_df.sample(5))
    display(demux_df['DROPLET.TYPE'].value_counts())
    display(demux_df.sample_id.value_counts())

#### merge other info with obs IDs

In [None]:
# per-barcode sample assignment taken from the demuxlet results
obs_id_df = demux_df[['sample_id', 'BARCODE']].copy()
print(obs_id_df.shape)
# if none of the samples IDs match then file is copied from phase1
# that used donor ID instead of project sample ID, here is a hacky work around
if len(set(obs_id_df.sample_id) & set(info_df.sample_id)) == 0:
    pool_info_df = info_df.loc[info_df.gex_pool == pool_num]
    id_map_dict = pool_info_df.set_index('hbcc_id')['sample_id'].to_dict()
    # map only the sample_id column so BARCODE values can never be rewritten
    obs_id_df['sample_id'] = obs_id_df['sample_id'].replace(id_map_dict)
    # for any sample ID's that are still the HBCC donor ID's set to unknown
    obs_id_df.loc[obs_id_df.sample_id.str.startswith('NHBCC-'), 'sample_id'] = 'unknown'

n_barcodes = obs_id_df.shape[0]
obs_id_df = obs_id_df.merge(info_df, how='left', on='sample_id')
# a left merge only adds rows when a sample_id matches several info rows; stop
# here with a clear message rather than failing later on duplicated barcodes
if obs_id_df.shape[0] != n_barcodes:
    raise ValueError(f'merge with sample info changed the row count from '
                     f'{n_barcodes} to {obs_id_df.shape[0]}; '
                     'check the sample info for duplicated sample_id values')
print(f'obs IDs shape {obs_id_df.shape}')
if DEBUG:
    display(obs_id_df.head())
    # numeric covariates are summarized, everything else is tabulated
    numeric_cols = {'age', 'pmi', 'ph', 'bmi', 'rin'}
    for col in ('sample_id', 'hbcc_id', 'sex', 'ancestry', 'age', 'gex_pool',
                'atac_pool', 'smoker', 'pmi', 'ph', 'bmi', 'rin'):
        col_values = obs_id_df[col]
        print(col_values.describe() if col in numeric_cols
              else col_values.value_counts())

#### check that we aren't missing any barcodes

In [None]:
# barcodes present in the matrix but absent from the demuxlet results; the obs
# index holds the barcodes, whereas iterating the obs DataFrame itself only
# yields its column labels
set(adata.obs.index) - set(demux_df['BARCODE'])

In [None]:
# preview the per-barcode IDs merged with the sample info
obs_id_df.head()

#### index the demultiplexed IDs with the anndata obs barcodes

In [None]:
# key the table by barcode and align it row for row with the anndata obs
obs_id_df = obs_id_df.set_index('BARCODE').reindex(adata.obs.index)
# add columns for phase1 migrated pools compatibility later
for phase1_col in ('phase1_cluster', 'phase1_celltype'):
    obs_id_df[phase1_col] = 'NA'
# fill any missing barcode IDs
obs_id_df['sample_id'] = obs_id_df['sample_id'].fillna('unknown')
# move hbcc_id to a trailing donor_id column
obs_id_df['donor_id'] = obs_id_df.pop('hbcc_id')
print(f'modified obs IDs shape{obs_id_df.shape}')
if DEBUG:
    display(obs_id_df.head())

In [None]:
# column dtypes and non-null counts of the barcode-aligned table
obs_id_df.info()

#### add the sample info from demultiplexing to the obs

In [None]:
# append the sample info columns to obs; obs_id_df was reindexed to
# adata.obs.index beforehand, so the column-wise concat lines up by barcode
adata.obs = concat([adata.obs, obs_id_df], axis='columns')

In [None]:
# barcode counts per donor ID
adata.obs['donor_id'].value_counts()

In [None]:
# barcode counts per sample ID, including the doublet/ambiguous/unknown labels
adata.obs['sample_id'].value_counts()

#### visualize the counts by sample

In [None]:
# bar plot of the barcode counts per sample ID, titled with the pool name
plot_sample_barcode_counts(adata.obs, pool_name, 'sample_id')

In [None]:
# bar plot of the barcode counts per donor ID, titled with the pool name
plot_sample_barcode_counts(adata.obs, pool_name, 'donor_id')

#### filter out the doublet, ambiguous, and unknowns

In [None]:
# NOTE(review): missing sample_id values were already filled with 'unknown'
# when the IDs were indexed to the obs barcodes, so this mask appears to keep
# every barcode, including the 'doublet', 'ambiguous' and 'unknown' labels; a
# stricter isin(['doublet', 'unknown']) exclusion was previously disabled here
# — confirm that keeping them is intended
has_sample_id = adata.obs['sample_id'].notna()
filtd_adata = adata[has_sample_id].copy()
filtd_adata

In [None]:
# barcode counts per sample ID in the filtered object
filtd_adata.obs['sample_id'].value_counts()

In [None]:
# barcode counts per donor ID in the filtered object
filtd_adata.obs['donor_id'].value_counts()

#### visualize the counts by sample again without the doublets and unknowns

In [None]:
# bar plot per sample ID for the filtered object (id_col defaults to 'sample_id')
plot_sample_barcode_counts(filtd_adata.obs, pool_name)

In [None]:
# bar plot per donor ID for the filtered object
plot_sample_barcode_counts(filtd_adata.obs, pool_name, 'donor_id')

#### save the modified anndata object

In [None]:
# write the annotated object to the .h5ad path built in the settings cell
filtd_adata.write(output_file)

In [None]:
# display the summary of the object that was written
filtd_adata

In [None]:
# print the shell date/time at the end of the run
!date