# Combine the genetic demultiplexing results (from demuxlet) for all pools and lanes and generate summary information.

In [None]:
# record when this notebook run started
!date

#### import libraries and set notebook variables

In [None]:
from typing import Optional, Tuple

from pandas import DataFrame, read_csv, concat
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.pyplot import rc_context

# render matplotlib figures inline in the notebook output
%matplotlib inline
# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
# high-resolution (retina) inline figure output
%config InlineBackend.figure_format='retina'

In [None]:
# variables and constants
cohort = 'aging'
NUM_POOLS = 6   # pools P001..P006
NUM_LANES = 8   # lanes per pool, name suffixes _1.._8
# pool names look like 'Aging_P001_SCRN_1'; the pool number is zero-padded to
# three digits so the names stay well formed if there are ever 10+ pools
pool_names = [f'Aging_P{pnum:03d}_SCRN_{lane}'
              for pnum in range(1, NUM_POOLS + 1)
              for lane in range(1, NUM_LANES + 1)]
DEBUG = True
# NOTE: matplotlib >= 3.6 renamed this bundled style to 'seaborn-v0_8-bright'
# (the old name was removed in matplotlib 3.8)
SEABORN_STYLE = 'seaborn-bright'

# directories
wrk_dir = '/home/jupyter/brain_aging_phase1'
results_dir = f'{wrk_dir}/demux'
info_dir = f'{wrk_dir}/sample_info'

# in files
info_file = f'{info_dir}/{cohort}.pool_patient_sample_info.csv'

# out files
output_file = f'{results_dir}/{cohort}_demultiplexing.csv'

if DEBUG:
    print(info_file)
    print(output_file)
    print(pool_names)

#### functions

In [None]:
def peek_dataframe(df: DataFrame, message: Optional[str] = None, verbose: bool = False) -> None:
    """Print a quick summary of a data frame.

    Args:
        df: frame to summarize.
        message: heading printed before the shape; skipped when None or empty.
        verbose: when True, also show ``df.head()`` via IPython's ``display``.
    """
    if message:  # false for both None and ''
        print(message)
    print(f'df shape = {df.shape}')
    if verbose:
        display(df.head())
        
def parse_pool_name(name):
    """Split an underscore-separated pool name into its pool and lane fields.

    For 'Aging_P001_SCRN_1' this returns ('P001', '1'): the pool id is the
    2nd field and the lane the 4th. Both are returned as strings.
    """
    fields = name.split('_')
    pool_id = fields[1]
    lane_id = fields[3]
    return pool_id, lane_id

def _resolve_style(style):
    """Return a matplotlib style name that exists in the installed matplotlib.

    matplotlib 3.6 renamed its bundled 'seaborn*' styles to 'seaborn-v0_8*'
    (and 3.8 removed the old names), so fall back to the renamed style when
    the requested one is unavailable. Anything else is returned unchanged so
    ``plt.style.use`` still reports a genuinely unknown name.
    """
    if style in plt.style.available or not style.startswith('seaborn'):
        return style
    renamed = 'seaborn-v0_8' + style[len('seaborn'):]
    return renamed if renamed in plt.style.available else style


# function to plot the barcode counts by sample
def plot_demux_counts(this_df, title, id_col='Sample_id'):
    """Bar plot of the number of rows (barcodes) per value of ``id_col``.

    Bars follow ``value_counts`` order (descending count).

    Args:
        this_df: demultiplexing frame; each row is counted as one barcode.
        title: figure title.
        id_col: column whose values are counted, e.g. 'Sample_id',
            'donor_id' or 'pool_name'.
    """
    counts = this_df[id_col].value_counts()  # computed once, used for both axes
    with rc_context({'figure.figsize': (12, 12), 'figure.dpi': 100}):
        plt.style.use(_resolve_style(SEABORN_STYLE))
        sns.barplot(x=counts.index, y=counts.values, palette='Blues_d')
        plt.xticks(rotation=90)
        plt.title(title)
        plt.ylabel('barcode counts')
        # lay out last so the title and axis label are taken into account
        plt.tight_layout()

## load the sample information

In [None]:
# load the pool/sample info and drop the columns not needed for demultiplexing
cols_to_drop = ['Pool_no', 'Sample_no', 'Sequence_type', 'Source_id']
info_df = read_csv(info_file).drop(columns=cols_to_drop)
peek_dataframe(df=info_df, message='loaded sample info df', verbose=DEBUG)

## load the demuxlet results and merge in the sample information

In [None]:
best_df_list = []
for pool_name in pool_names:
    pool_num, lane_num = parse_pool_name(pool_name)
    best_file = f'{results_dir}/{pool_name}.best'
    # demuxlet '.best' output is whitespace delimited; the raw string keeps the
    # regex backslash from being read as an (invalid) string escape
    this_demux_df = read_csv(best_file, sep=r'\s+')
    # best_id = singlet best guess for SNG droplets, 'doublet' for DBL and
    # 'ambiguous' for any other droplet type
    this_demux_df['best_id'] = np.where(this_demux_df['DROPLET.TYPE'] == 'SNG',
                                        this_demux_df['SNG.BEST.GUESS'],
                                        np.where(this_demux_df['DROPLET.TYPE'] == 'DBL',
                                                 'doublet', 'ambiguous'))
    # merge in the sample info of the donors listed for this pool; rows whose
    # best_id matches no donor in the pool (doublets, ambiguous) get NaN info
    pool_info_df = info_df.loc[info_df['pool_name'] == pool_num]
    this_demux_df = this_demux_df.merge(pool_info_df, how='left',
                                        left_on='best_id', right_on='donor_id')
    # also add the lane
    this_demux_df['lane_num'] = lane_num
    # set the pool on every row, since it is missing for doublet and ambiguous
    # rows; item assignment (not attribute assignment) reliably sets the column
    this_demux_df['pool_name'] = pool_num
    peek_dataframe(this_demux_df, f'loaded {pool_name} demux best', False)
    best_df_list.append(this_demux_df)
    if DEBUG:
        # singlets left without sample info after the merge (e.g. best guess
        # is not a donor listed for this pool)
        unmatched_sng_df = this_demux_df.loc[(this_demux_df['DROPLET.TYPE'] == 'SNG')
                                             & this_demux_df['Sample_id'].isna()]
        print(unmatched_sng_df.shape)
        display(unmatched_sng_df)

## combine the per pool demultiplexing into a single data frame

In [None]:
# stack the per-pool/lane frames row-wise; each keeps its own row index, so
# index labels repeat across pools
demux_df = concat(objs=best_df_list, axis=0)
peek_dataframe(df=demux_df, message='combined demux best df', verbose=DEBUG)

### replace the missing values

In [None]:
# mark every remaining NaN (e.g. sample info of rows that matched no donor in
# their pool) as 'unknown'.
# NOTE: this applies to all columns, so any numeric column that contains NaN
# becomes object dtype
demux_df = demux_df.fillna(value='unknown')
peek_dataframe(df=demux_df, message='replaced missing values', verbose=DEBUG)

### update the sample and donor IDs to reflect a doublet or ambiguous prediction, and label unmatched singlets as incorrect

In [None]:
# doublet/ambiguous droplets have no merged sample info, so use the droplet
# call itself as both their Sample_id and donor_id
for droplet_call in ('ambiguous', 'doublet'):
    is_call = demux_df['best_id'] == droplet_call
    for id_col in ('Sample_id', 'donor_id'):
        demux_df.loc[is_call, id_col] = droplet_call
# demultiplexing used the full set of genotypes rather than only the donors
# expected in each pool, so a small number of singlets were assigned to a donor
# outside their pool; those still read 'unknown' here and are relabelled
# 'incorrect' (kept, not dropped, for the summary computations)
for id_col in ('Sample_id', 'donor_id'):
    demux_df.loc[demux_df[id_col] == 'unknown', id_col] = 'incorrect'

In [None]:
# inspect the incorrect matches: for each out-of-pool best guess, show the
# pools its barcodes came from and the pools that donor is listed in
if DEBUG:
    incorrect_df = demux_df.loc[demux_df['Sample_id'] == 'incorrect']
    for guess in incorrect_df['best_id'].unique():
        print(f'## {guess}')
        display(incorrect_df.loc[incorrect_df['best_id'] == guess, 'pool_name'].value_counts())
        display(info_df.loc[info_df['donor_id'] == guess, 'pool_name'].unique())

In [None]:
if DEBUG:
    # DataFrame.info() prints its summary itself and returns None; wrapping it
    # in display() would additionally render a stray 'None'
    demux_df.info()

### check some of the counts

In [None]:
# barcodes per pool
display(demux_df['pool_name'].value_counts())

In [None]:
# barcodes per donor (including doublet/ambiguous/incorrect labels)
display(demux_df['donor_id'].value_counts())

In [None]:
# barcodes per raw demuxlet best guess
display(demux_df['best_id'].value_counts())

In [None]:
# barcodes per sample (including doublet/ambiguous/incorrect labels)
display(demux_df['Sample_id'].value_counts())

## save the full dataframe

In [None]:
%%time
demux_df.to_csv(output_file)

## what are some of the important summary numbers

In [None]:
def compute_summary(df: DataFrame, label: str, id_col: str = 'Sample_id') -> Tuple[int, float]:
    """Count the rows whose ``id_col`` equals ``label`` and their share of all rows.

    Args:
        df: demultiplexing frame.
        label: value to count, e.g. 'ambiguous', 'doublet' or 'incorrect'.
        id_col: column compared against ``label``; defaults to 'Sample_id'.

    Returns:
        (count, percent) with percent = 100 * count / len(df). For an empty
        frame the percentage is undefined, so NaN is returned instead of
        raising ZeroDivisionError.
    """
    total = len(df)
    cnt = int((df[id_col] == label).sum())
    if total == 0:
        return cnt, float('nan')
    return cnt, (cnt / total) * 100


# for each droplet label: overall count and percent, then the per-pool
# percentages with their mean and population (ddof=0) standard deviation
for label_name in ('ambiguous', 'doublet', 'incorrect'):
    print(f'## {label_name}')
    total_cnt, total_percent = compute_summary(demux_df, label_name)
    print(f'{label_name} count = {total_cnt}')
    print(f'{label_name} percent = {total_percent:.1f}%')
    percentages = [compute_summary(demux_df.loc[demux_df['pool_name'] == pool], label_name)[1]
                   for pool in demux_df['pool_name'].unique()]
    print(percentages)
    pool_percents = np.array(percentages)
    print(f'pool mean {pool_percents.mean():.1f}%')
    print(f'pool std {pool_percents.std():.1f}%')


## visualize the counts

In [None]:
# barcode counts per pool
plot_demux_counts(this_df=demux_df, title='Demultiplexing by Pool', id_col='pool_name')

In [None]:
# barcode counts per donor
plot_demux_counts(this_df=demux_df, title='Demultiplexing by Donor', id_col='donor_id')

In [None]:
# barcode counts per sample
plot_demux_counts(this_df=demux_df, title='Demultiplexing by sample', id_col='Sample_id')

In [None]:
# record when this notebook run finished
!date