In [None]:
import os
import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt

# Directory numba should use for its on-disk cache.
# NOTE(review): this is assigned after scanpy has already been imported above;
# confirm numba still picks the value up at this point.
os.environ.update({'NUMBA_CACHE_DIR': '/tmp/'})

# Scanpy verbosity levels: errors (0), warnings (1), info (2), hints (3).
sc.settings.verbosity = 3
sc.logging.print_header()

# Figure defaults shared by every plot in this notebook.
sc.settings.set_figure_params(dpi=80, facecolor='white')
sns.set_style('ticks')

def count_barplot(adata, axes, skip_x = False, sample_col = 'sample_name', title = ''):
    """Draw a bar chart of the number of cells per sample and print each count on its bar.

    Parameters
    ----------
    adata
        Object whose ``obs`` table contains the column ``sample_col``.
    axes
        Matplotlib axes to draw on.
    skip_x
        If True, remove the x ticks and x label; otherwise rotate the x tick
        labels by 45 degrees.
    sample_col
        Name of the ``adata.obs`` column whose values are counted.
    title
        Title set on ``axes``.
    """
    # Name both columns explicitly so the frame looks the same regardless of
    # how the installed pandas version labels the output of value_counts().
    counts = (
        adata.obs[sample_col]
        .value_counts()
        .rename_axis(sample_col)
        .reset_index(name='count')
    )
    sns.barplot(data=counts, x=sample_col, y='count', color='orange', ax = axes)
    axes.set_title(title)
    if skip_x:
        axes.set_xticks([])
        axes.set_xlabel('')
    else:
        axes.tick_params(axis='x', labelrotation=45)
    # Annotate every bar with its count, positioned by the category label.
    for label, n_cells in zip(counts[sample_col], counts['count']):
        axes.text(label, n_cells, n_cells, color='black', ha="center")

def library_size_plot(adata, axes, skip_x = False, title = '', sample_col = 'sample_name'):
    """Draw a violin plot of ``log1p_total_counts`` grouped by sample.

    Parameters
    ----------
    adata
        Data object passed straight to ``sc.pl.violin``.
    axes
        Matplotlib axes to draw on.
    skip_x
        If True, remove the x ticks and x label; otherwise rotate the x tick
        labels by 45 degrees.
    title
        Title set on ``axes``.
    sample_col
        Column used as the ``groupby`` key. Defaults to ``'sample_name'``.
    """
    sc.pl.violin(adata, ['log1p_total_counts'], groupby = sample_col,
                 jitter=0.4, multi_panel=True, show=False, ax = axes)
    if skip_x:
        axes.set_xticks([])
        axes.set_xlabel('')
    else:
        axes.tick_params(axis='x', labelrotation=45)
    axes.set_title(title)

def gene_counts(adata, axes, skip_x = False, title = '', sample_col = 'sample_name'):
    """Draw a violin plot of ``log1p_n_genes_by_counts`` grouped by sample.

    Parameters
    ----------
    adata
        Data object passed straight to ``sc.pl.violin``.
    axes
        Matplotlib axes to draw on.
    skip_x
        If True, remove the x ticks and x label; otherwise rotate the x tick
        labels by 45 degrees.
    title
        Title set on ``axes``.
    sample_col
        Column used as the ``groupby`` key. Defaults to ``'sample_name'``.
    """
    sc.pl.violin(adata, ['log1p_n_genes_by_counts'], groupby = sample_col,
                 jitter=0.4, multi_panel=True, show=False, ax = axes)
    if skip_x:
        axes.set_xticks([])
        axes.set_xlabel('')
    else:
        axes.tick_params(axis='x', labelrotation=45)
    axes.set_title(title)

def mitochondiral_content(adata, axes, skip_x = False, title = '', sample_col = 'sample_name'):
    """Draw a violin plot of ``mito_frac`` grouped by sample.

    Parameters
    ----------
    adata
        Data object passed straight to ``sc.pl.violin``.
    axes
        Matplotlib axes to draw on.
    skip_x
        If True, remove the x ticks and x label; otherwise rotate the x tick
        labels by 45 degrees.
    title
        Title set on ``axes``.
    sample_col
        Column used as the ``groupby`` key. Defaults to ``'sample_name'``.
    """
    sc.pl.violin(adata, ['mito_frac'], groupby = sample_col,
                 jitter=0.4, multi_panel=True, show=False, ax = axes)
    if skip_x:
        axes.set_xticks([])
        axes.set_xlabel('')
    else:
        axes.tick_params(axis='x', labelrotation=45)
    axes.set_title(title)
    
# Defaults to an empty string and is not read by any other cell visible in this notebook.
# NOTE(review): presumably overridden by an externally injected parameter — confirm.
experiment_name = ''

In [None]:
# Directory scanned for the .h5ad files this report is built from.
path = './'
filelist_all = os.listdir(path)
# os.listdir() returns entries in arbitrary order. Sort so every run loads the
# files in the same order, with files lacking 'scvi' in their name first; the
# two-panel figures below then show "Before SCVI" on the left.
filelist = sorted(
    (x for x in filelist_all if x.endswith(".h5ad")),
    key=lambda name: ('scvi' in name, name),
)
if not filelist:
    # Fail here with a clear message instead of an IndexError in a later cell.
    raise FileNotFoundError(f'No .h5ad files found in {os.path.abspath(path)}')

# A file counts as batch-corrected when its name contains 'scvi'.
batch_corrected = ['scvi' in x for x in filelist]
print('adatas: ' + '\n'.join([os.path.join(path, filename) for filename in filelist]))
adatas = [sc.read_h5ad(os.path.join(path, filename)) for filename in filelist]
# Panel titles, index-aligned with `adatas`.
adata_names = ['SCVI' if x else 'Before SCVI' for x in batch_corrected]

# Cell counts

In [None]:
# Pick the first loaded dataset that is flagged as batch-corrected.
scvi_positions = np.where(batch_corrected)[0]
adata_scvi = adatas[scvi_positions[0]]

# A cell is filtered out when it is marked as a doublet or as an outlier.
scvi_obs = adata_scvi.obs
filters = scvi_obs['doublet'] | scvi_obs['outlier']
adata_scvi_filtered = adata_scvi[~filters]

In [None]:
# Report how many cells were present, flagged, and kept.
n_cells_before = adata_scvi.shape[0]
n_cells_flagged = sum(filters)
n_cells_after = adata_scvi_filtered.shape[0]
print(
    f'The data contains {n_cells_before} cells before doublet and outlier removal.',
    f'{n_cells_flagged} cells were detected as doublets or outliers.',
    f'{n_cells_after} cells remained after filtering.',
    sep='\n',
)

# Violin plots of QC metrics before and after filtering

## Library size distribution

In [None]:
# Library-size violins side by side: all cells (left) and cells kept after filtering (right).
fig, axes = plt.subplots(ncols = 2, figsize = (11, 6))
qc_panels = [('Before filtering', adata_scvi), ('Filtered', adata_scvi_filtered)]
for qc_ax, (qc_title, qc_adata) in zip(axes, qc_panels):
    library_size_plot(qc_adata, axes = qc_ax, title = qc_title)

## Number of genes detected in each cell

In [None]:
# Gene-count violins side by side: all cells (left) and cells kept after filtering (right).
fig, axes = plt.subplots(ncols = 2, figsize = (11, 6))
qc_panels = [('Before filtering', adata_scvi), ('Filtered', adata_scvi_filtered)]
for qc_ax, (qc_title, qc_adata) in zip(axes, qc_panels):
    gene_counts(qc_adata, axes = qc_ax, title = qc_title)

## Mitochondrial (MT) content

In [None]:
# Mitochondrial-fraction violins side by side: all cells (left) and cells kept after filtering (right).
fig, axes = plt.subplots(ncols = 2, figsize = (11, 6))
qc_panels = [('Before filtering', adata_scvi), ('Filtered', adata_scvi_filtered)]
for qc_ax, (qc_title, qc_adata) in zip(axes, qc_panels):
    mitochondiral_content(qc_adata, axes = qc_ax, title = qc_title)

# UMAP before and after batch correction

### Custom parameters

In [None]:
# `plots` is not assigned anywhere in the visible notebook.
# NOTE(review): presumably it is injected from outside as the path to a CSV file — confirm.
# Look it up defensively so a fresh kernel without it skips this section
# instead of raising NameError.
plots_csv = globals().get('plots')
if plots_csv is not None:
    # The first row of the header-less CSV lists the keys to colour the UMAPs by.
    # Kept under a separate name so `plots` is not overwritten and the cell can be re-run.
    plot_keys = pd.read_csv(plots_csv, header=None).iloc[0, :].values.tolist()
    n_panels = len(adatas)
    for plot in plot_keys:
        print(plot)
        # One panel per loaded dataset, so nothing is drawn over an earlier
        # panel when more than two files are present.
        fig, axes = plt.subplots(1, n_panels, figsize = (5 * n_panels, 5), squeeze = False)
        for i, adata in enumerate(adatas):
            sc.pl.umap(adata, color = plot, title = adata_names[i], show = False, ax = axes[0, i], frameon = False)
        plt.tight_layout()
        plt.show()

### Sample names

In [None]:
# One UMAP panel per loaded dataset, coloured by sample name. Sizing the grid
# from len(adatas) keeps a third (or later) dataset from being drawn on top of
# an earlier panel; with exactly two datasets the layout is unchanged.
n_panels = len(adatas)
fig, axes = plt.subplots(1, n_panels, figsize = (5 * n_panels, 5), squeeze = False)
for i, adata in enumerate(adatas):
    sc.pl.umap(adata, color = ['sample_name'], title = adata_names[i], frameon = False, show = False, ax = axes[0, i])
plt.tight_layout()
plt.show()

### Leiden clusters

In [None]:
# One UMAP panel per loaded dataset, coloured by Leiden cluster. Sizing the grid
# from len(adatas) keeps a third (or later) dataset from being drawn on top of
# an earlier panel; with exactly two datasets the layout is unchanged.
n_panels = len(adatas)
fig, axes = plt.subplots(1, n_panels, figsize = (5 * n_panels, 5), squeeze = False)
for i, adata in enumerate(adatas):
    sc.pl.umap(adata, color = ['leiden'], title = adata_names[i], frameon = False, show = False, ax = axes[0, i])
plt.tight_layout()
plt.show()

### Outliers (in orange)

In [None]:
# One UMAP panel per loaded dataset, coloured by the outlier flag. Sizing the grid
# from len(adatas) keeps a third (or later) dataset from being drawn on top of
# an earlier panel; with exactly two datasets the layout is unchanged.
n_panels = len(adatas)
fig, axes = plt.subplots(1, n_panels, figsize = (5 * n_panels, 5), squeeze = False)
for i, adata in enumerate(adatas):
    sc.pl.umap(adata, color = ['outlier'], title = adata_names[i], frameon = False, show = False, ax = axes[0, i])
plt.tight_layout()
plt.show()

### Doublets (in orange)

In [None]:
# One UMAP panel per loaded dataset, coloured by the doublet flag. Sizing the grid
# from len(adatas) keeps a third (or later) dataset from being drawn on top of
# an earlier panel; with exactly two datasets the layout is unchanged.
n_panels = len(adatas)
fig, axes = plt.subplots(1, n_panels, figsize = (5 * n_panels, 5), squeeze = False)
for i, adata in enumerate(adatas):
    sc.pl.umap(adata, color = ['doublet'], title = adata_names[i], frameon = False, show = False, ax = axes[0, i])
plt.tight_layout()
plt.show()

# Marker gene plot before and after batch correction

In [None]:
# For every loaded dataset: print its label, then plot its ranked genes (20 per group).
for dataset_label, ranked_adata in zip(adata_names, adatas):
    print(dataset_label)
    sc.pl.rank_genes_groups(ranked_adata, n_genes=20, groupby="leiden", title = dataset_label)

# Celltypist annotations

In [None]:
# Use the first loaded dataset whose file name contains 'celltypist'.
celltypist_matches = [
    candidate
    for fname, candidate in zip(filelist, adatas)
    if 'celltypist' in os.path.basename(fname)
]
adata_celltypist = celltypist_matches[0]

# Restrict the plot to cells carrying one of the ten most frequent predicted labels.
label_counts = adata_celltypist.obs['predicted_labels'].value_counts()
top_labels = label_counts.nlargest(10).index
has_top_label = adata_celltypist.obs['predicted_labels'].isin(top_labels)
top_labels_adata = adata_celltypist[has_top_label].copy()

# Store the labels as plain strings on the copied subset before plotting.
top_labels_adata.obs['predicted_labels'] = top_labels_adata.obs['predicted_labels'].astype(str)

sc.pl.umap(top_labels_adata, color=['leiden', "predicted_labels"], wspace=0.4)