## Notebook to combine the demultiplexed per-pool, per-lane GEX AnnData objects into a single AnnData object

- this notebook only combines the objects; it does not batch-correct them — see ComBat (old/simple), MNN, BBKNN, and scVI for that
- when loading each pool, before combining, remove any predicted ambient-RNA barcodes (those not called as cells by CellBender) that may still remain after demultiplexing

In [None]:
# timestamp the start of the run
!date

#### import libraries and set notebook variables

In [None]:
from pandas import read_csv
from scanpy import read_h5ad
from os.path import exists
from anndata import concat as ad_concat
from seaborn import barplot
import matplotlib.pyplot as plt
from matplotlib.pyplot import rc_context

# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'

In [None]:
# --- naming: the output set is "<project>_<modality>" ---
project = 'aging_phase2'
modality = 'GEX'
set_name = '_'.join([project, modality])

# --- directories: every input and output lives under the phase-2 working directory ---
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase2'
demux_dir, cellbender_dir, info_dir, quants_dir = (
    f'{wrk_dir}/{subdir}' for subdir in ('demux', 'cellbender', 'sample_info', 'quants')
)

# --- inputs ---
info_file = f'{info_dir}/{project}.sample_info.csv'

# --- outputs: combined (not batch-corrected) object ---
output_file = f'{quants_dir}/{set_name}.raw.h5ad'

# --- run options ---
DEBUG = False  # True prints per-file diagnostics and previews tables
lane_range = range(1, 9)  # lanes 1..8 are probed for every pool

### load the sample info data

In [None]:
# Load the per-sample table and keep only samples assigned to BOTH a GEX and an ATAC pool.
info_df = read_csv(info_file)
print(f'shape of info {info_df.shape}')
info_df = (
    info_df
    .dropna(subset=['gex_pool', 'atac_pool'])
    # pool columns are read as floats when they contain NaNs; cast back to int so the
    # pool numbers format as e.g. 'P1' (not 'P1.0') when building file names later
    .astype({'gex_pool': 'int', 'atac_pool': 'int'})
)
print(f'shape of info {info_df.shape}')
if DEBUG:
    display(info_df.head())
    display(info_df.gex_pool.value_counts())

#### combine the individual AnnData objects into a single large AnnData object

In [None]:
%%time
adata_list = []

pools = set(info_df.gex_pool.unique()) | set(info_df.atac_pool.unique())

for pool in pools:
    for lane in lane_range:
        gex_pool = f'{demux_dir}/{modality}_P{pool}_{lane}.h5ad'
        if exists(gex_pool):
            this_adata = read_h5ad(gex_pool)
            cellbender_file = f'{cellbender_dir}/sample_ec_GEX_P{int(pool)}_{lane}_out_cell_barcodes.csv'
            cb_barcodes = read_csv(cellbender_file, header=None)
            this_adata = this_adata[this_adata.obs.index.isin(cb_barcodes[0])]
            if DEBUG:
                print(f'{modality}_P{pool}_{lane}: {this_adata}')
                print(f'cellbender shape {cb_barcodes.shape}')
                print(len(set(cb_barcodes[0]) & set(this_adata.obs.index))/this_adata.obs.shape[0])
            adata_list.append(this_adata)

all_adata = ad_concat(adata_list)
all_adata.obs_names_make_unique()

## drop any GEX cells whose `sample_id` is 'unknown'

In [None]:
# Drop barcodes whose sample_id is 'unknown'. .copy() materializes the subset: a view
# keeps the full parent object alive in memory and anndata implicitly copies (with a
# warning) when a view is modified or written.
n_before_unknown = all_adata.n_obs
all_adata = all_adata[all_adata.obs.sample_id != 'unknown'].copy()
print(f"dropped {n_before_unknown - all_adata.n_obs} 'unknown' barcodes; {all_adata.n_obs} remain")

In [None]:
# Summary of the combined object (dimensions plus obs/var annotation keys).
print(all_adata)
if DEBUG:
    # peek at a few random cells' metadata; capped so small debug runs don't raise,
    # seeded so the peek is reproducible across re-runs
    display(all_adata.obs.sample(n=min(10, all_adata.n_obs), random_state=42))

#### save the combined anndata object

In [None]:
# Persist the combined (not batch-corrected) object; AnnData.write is an alias of write_h5ad.
all_adata.write_h5ad(output_file)

#### visualization functions

In [None]:
def plot_sample_barcode_counts(this_df, sample_name, id_col='sample_id'):
    """Bar-plot, then print, the number of barcodes (rows) per value of ``id_col``.

    Parameters
    ----------
    this_df : pandas.DataFrame
        Per-barcode table, e.g. ``adata.obs``; one row per barcode.
    sample_name : str
        Used as the plot title.
    id_col : str, default 'sample_id'
        Column whose value counts are plotted, most frequent first.
    """
    # computed once and reused for bar positions, heights, order and the printout
    counts = this_df[id_col].value_counts()
    with rc_context({'figure.figsize': (12, 12), 'figure.dpi': 100}):
        plt.style.use('seaborn-v0_8-talk')
        fig, ax = plt.subplots()
        # NOTE(review): seaborn >= 0.13 warns when `palette` is passed without `hue`
        barplot(x=counts.index, y=counts.values, order=counts.index,
                palette='Blues_d', ax=ax)
        ax.tick_params(axis='x', labelrotation=90, labelsize=8)
        ax.set_title(sample_name)
        ax.set_ylabel('barcode counts')
        # lay out last so the title and y-label are accounted for in the margins
        fig.tight_layout()
        plt.show()
    print(counts)

#### visualize the counts by sample

In [None]:
plot_sample_barcode_counts(this_df=all_adata.obs, sample_name=set_name)  # per sample_id (default column)

In [None]:
plot_sample_barcode_counts(this_df=all_adata.obs, sample_name=set_name, id_col='donor_id')

In [None]:
plot_sample_barcode_counts(this_df=all_adata.obs, sample_name=set_name, id_col='gex_pool')

In [None]:
plot_sample_barcode_counts(this_df=all_adata.obs, sample_name=set_name, id_col='age')

In [None]:
plot_sample_barcode_counts(this_df=all_adata.obs, sample_name=set_name, id_col='sex')

In [None]:
# timestamp the end of the run (compare with the opening timestamp for total runtime)
!date