In [None]:
import math
import os
import warnings

import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib import rcParams

# import function
from cellbender.remove_background.downstream import anndata_from_h5
from cellbender.remove_background.downstream import load_anndata_from_input_and_output

# Relative paths used below (e.g. '../../data') are resolved against the kernel's
# current working directory (normally the directory containing this notebook).
# Changing to os.getcwd() would be a no-op, so the working directory is left as is.

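## Add metadata to the CellBender results
* Load each batch's CellBender output into a separate AnnData object
* Add the metadata from the corresponding `<ref_genome>_<batch>_raw.h5ad` file and save the result
* Collect the objects in a nested dictionary, `anndata_dict[ref_genome][batch]`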

In [None]:
## Aligned data: one sub-directory per reference genome
base_dir='/data/gpfs/projects/punim2121/Atherosclerosis/aligned_data'
## Relative locations of the per-batch raw .h5ad files and of the CellBender output
data_dir=os.path.join('..','..','data')
cellbender_dir=os.path.join(data_dir,'cellbender_output')
## Metadata columns copied from <ref_genome>_<batch>_raw.h5ad onto the CellBender output
meta_colnames=['patient','condition','batch','ribo_frac','mt_frac']

anndata_dict={}
# sorted(): os.listdir order is arbitrary; sorting makes the dict order reproducible
for ref_genome in sorted(os.listdir(base_dir)):
    if ref_genome.startswith('.'):
        continue
    anndata_dict[ref_genome]={}

    ## Loop over batches and open the cellbender output + <ref_genome>_<batch>_raw.h5ad
    for batch in sorted(os.listdir(os.path.join(cellbender_dir,ref_genome))):
        if batch.startswith('.'):
            continue
        input_file=os.path.join(data_dir,'_'.join([ref_genome,batch,'raw.h5ad']))
        output_file=os.path.join(cellbender_dir,ref_genome,batch,batch+'_cb.h5')

        adata_batch=load_anndata_from_input_and_output(
            input_file=input_file,
            output_file=output_file,
            input_layer_key='raw',  # this will be the raw data layer
            retain_input_metadata=True,
            analyzed_barcodes_only=True)

        ## Read the raw data and copy its metadata for the barcodes kept in the cellbender output
        adata_batch_raw=sc.read_h5ad(input_file)
        adata_batch.obs[meta_colnames]=adata_batch_raw.obs.loc[adata_batch.obs.index,meta_colnames]

        ## Save the metadata-updated cellbender output
        adata_batch.write_h5ad(os.path.join(data_dir,'_'.join([ref_genome,batch,'cellbender_corrected.h5ad'])))

        ## Total raw counts per barcode; np.asarray(...).ravel() works for dense arrays,
        ## scipy sparse matrices and sparse arrays alike (unlike .A1)
        adata_batch.obs['n_counts_raw']=np.asarray(adata_batch.layers['raw'].sum(axis=1)).ravel()

        anndata_dict[ref_genome][batch]=adata_batch
            

In [None]:
## Peek at the per-gene metadata of the last batch loaded above.
## 'counts_removed_cellbender' is only created in the next section, so selecting
## that column here would raise a KeyError on a fresh kernel.
adata_batch.var.head()

## Check how many counts CellBender removed from each gene

In [None]:
def _plot_top_genes(ax, top, column, ylabel, title):
    """Draw one bar per gene (the index of `top`) showing `top[column]`, in row order.

    Args:
        ax: matplotlib Axes to draw on.
        top: DataFrame indexed by gene, already sorted in the desired bar order.
        column: column of `top` giving the bar heights.
        ylabel: y-axis label.
        title: axes title.
    """
    if top.empty:
        ax.text(0.5, 0.5, 'no genes pass the filter', ha='center', va='center', transform=ax.transAxes)
    else:
        # Genes must be passed explicitly as x: a bare Series given as `data`
        # is treated as one flat vector and drawn as a single (mean) bar.
        sns.barplot(x=top.index.astype(str), y=top[column].to_numpy(), ax=ax)
        ax.tick_params(axis='x', labelrotation=45)
    ax.set_ylabel(ylabel)
    ax.set_title(title)


ncols=2
top_n=10            # genes shown per batch
min_raw_counts=100  # only genes with more than this many raw counts are plotted

for ref_genome, batches in anndata_dict.items():
    batch_names=sorted(batches)
    nrows=math.ceil(len(batch_names)/ncols)
    fig=plt.figure(figsize=(ncols*9,nrows*5))
    fig2=plt.figure(figsize=(ncols*9,nrows*5))
    fig.suptitle(ref_genome+': top genes by fraction of counts removed',fontweight='bold',fontsize=16,y=1)
    fig2.suptitle(ref_genome+': top genes by number of counts removed',fontweight='bold',fontsize=16,y=1)

    for n,batch in enumerate(batch_names):
        ax=fig.add_subplot(nrows,ncols,n+1)
        ax2=fig2.add_subplot(nrows,ncols,n+1)
        adata_batch=batches[batch]

        ## Total counts per gene for raw and cellbender corrected counts
        ## (np.asarray(...).ravel() works for dense and sparse layers)
        raw_counts=np.asarray(adata_batch.layers['raw'].sum(axis=0)).ravel()
        cb_counts=np.asarray(adata_batch.layers['cellbender'].sum(axis=0)).ravel()
        removed=raw_counts-cb_counts

        ## Genes whose counts were changed by cellbender
        filt=raw_counts!=cb_counts

        ## Absolute removed counts (0 for untouched genes) and ratio of removed to raw
        ## counts (NaN for untouched genes). Whole columns are assigned so that float
        ## values are never written into an int-initialised column via .loc.
        ratio=np.full(removed.shape,np.nan)
        ratio[filt]=removed[filt]/raw_counts[filt]
        adata_batch.var['counts_removed_cellbender']=removed
        adata_batch.var['ratio_of_counts_removed_cellbender']=ratio

        plot_df=adata_batch.var.loc[filt&(raw_counts>min_raw_counts),
                                    ['counts_removed_cellbender','ratio_of_counts_removed_cellbender']]
        top_by_ratio=plot_df.sort_values(by='ratio_of_counts_removed_cellbender',ascending=False).head(top_n)
        top_by_count=plot_df.sort_values(by='counts_removed_cellbender',ascending=False).head(top_n)

        ## Each figure shows the quantity it is sorted by, so bars appear in descending order
        _plot_top_genes(ax,top_by_ratio,'ratio_of_counts_removed_cellbender','Fraction of counts removed',batch)
        _plot_top_genes(ax2,top_by_count,'counts_removed_cellbender','Counts removed',batch)

    fig.tight_layout()
    fig2.tight_layout()


## Plot barcode rank vs. UMI count (**raw and CellBender-corrected counts**) for each batch

In [None]:
for ref_genome, batches in anndata_dict.items():
    for batch in sorted(batches):
        adata_batch=batches[batch]

        for n_count in ['n_raw','n_cellbender']:
            ## Rank barcodes by their total count (rank 1 = largest count)
            batch_plot_df=(adata_batch.obs[['cell_probability','cell_size','droplet_efficiency','n_raw', 'n_cellbender']]
                           .sort_values(by=n_count,ascending=False)
                           .copy())
            print('batch',batch,'min n_count',batch_plot_df[n_count].min())
            batch_plot_df['id']=np.arange(1,len(batch_plot_df)+1)
            batch_plot_df['log_id']=np.log10(batch_plot_df['id'])
            batch_plot_df['log_n_counts']=np.log10(batch_plot_df[n_count]+1)
            ## include_lowest=True: barcodes with cell_probability == 0 are labelled 'empty'
            ## instead of falling outside the (0, 0.5] bin and becoming NaN
            batch_plot_df['droplet']=pd.cut(x=batch_plot_df['cell_probability'],bins=[0,0.5,1],
                                            labels=['empty','cell'],include_lowest=True)

            p=sns.jointplot(data=batch_plot_df,x='log_id',y='log_n_counts',hue='droplet',height=8,kind='scatter',s=10,
                            ylim=(0,batch_plot_df['log_n_counts'].max()),
                            xlim=(batch_plot_df['log_id'].min(),batch_plot_df['log_id'].max()))

            ## Second y-axis on the right (probability scale) with a red line at the
            ## highest cell probability among droplets labelled 'empty'
            ax2=p.ax_joint.twinx()
            ax2.set_ylabel("Cell probability")
            max_empty_prob=batch_plot_df.loc[batch_plot_df['droplet']=='empty','cell_probability'].max()
            if pd.notna(max_empty_prob):  # NaN when no droplet is labelled 'empty'
                ax2.axhline(max_empty_prob,color='red')
            ax2.set_ylim(0,1)

            p.fig.suptitle(ref_genome + '-'+str(batch)+' '+n_count.split('_')[1]+' cell counts',fontsize=25,fontweight='bold')
            ax=p.ax_joint
            ax.set_ylabel('Log number of counts',fontsize=15,fontweight='bold')
            ax.set_xlabel('Log barcode rank',fontsize=15,fontweight='bold')


## Log-normalize the data, compute a UMAP for each batch and store the results in a dictionary

In [None]:
anndata_dict_norm={}

for ref_genome, batches in anndata_dict.items():
    anndata_dict_norm[ref_genome]={}

    for batch in sorted(batches):
        adata_pp=batches[batch].copy()

        ## Log-normalize the raw counts. normalize_total(layer='raw') would normalize
        ## only layers['raw'] while log1p() acts on .X (the cellbender counts), so the
        ## two steps would transform different matrices. Putting the raw counts into
        ## .X keeps normalization, log1p, the count sums and the UMAP consistent;
        ## layers['raw'] itself stays untouched.
        adata_pp.X=adata_pp.layers['raw'].copy()
        sc.pp.normalize_total(adata_pp, target_sum=1e4, exclude_highly_expressed=True)
        sc.pp.log1p(adata_pp)

        ## np.asarray(...).ravel() works for dense and sparse .X alike (unlike .A1)
        adata_pp.obs['n_counts_raw_norm']=np.asarray(adata_pp.X.sum(axis=1)).ravel()
        adata_pp.obs['log_n_counts_raw_norm']=np.log10(adata_pp.obs['n_counts_raw_norm']+1)
        ## include_lowest=True so cell_probability == 0 is labelled 'empty' rather than NaN
        adata_pp.obs['droplet']=pd.cut(x=adata_pp.obs['cell_probability'],bins=[0,0.5,1],
                                       labels=['empty','cell'],include_lowest=True)

        sc.pp.neighbors(adata_pp, n_neighbors=15, n_pcs=15, metric='euclidean')
        sc.tl.umap(adata_pp)
        anndata_dict_norm[ref_genome][batch]=adata_pp
    

## Plot UMAPs: erythrocyte markers, droplet identity, cell counts

In [None]:
warnings.simplefilter(action='ignore', category=FutureWarning)

## obs columns and marker genes to colour the UMAPs by
ery_genes=['log_n_counts_raw_norm','droplet','HBB','HBA1','HBA2']
rcParams['figure.figsize']=(6,4.5)
ncols=2
nrows=math.ceil(len(ery_genes)/ncols)

for ref_genome, batches in anndata_dict_norm.items():
    for batch in sorted(batches):
        adata_pp=batches[batch]
        ## include_lowest=True so cell_probability == 0 is labelled 'empty' rather than NaN
        adata_pp.obs['droplet']=pd.cut(x=adata_pp.obs['cell_probability'],bins=[0,0.5,1],
                                       labels=['empty','cell'],include_lowest=True)

        ## Only plot keys present in this batch: sc.pl.umap raises on an unknown key,
        ## e.g. a marker gene missing from this reference genome
        keys=[key for key in ery_genes if key in adata_pp.obs.columns or key in adata_pp.var_names]
        missing=[key for key in ery_genes if key not in keys]
        if missing:
            print(ref_genome,batch,'keys not found, skipped:',missing)

        fig=plt.figure(figsize=(ncols*6,nrows*5))
        fig.suptitle('-'.join([ref_genome,batch]))

        for n,key in enumerate(keys):
            ax=fig.add_subplot(nrows,ncols,n+1)
            sc.pl.umap(adata_pp, color=key,show=False,size=15,ax=ax)
        fig.tight_layout()
