## GSE181279 raw data merging and batch effect correction using scanpy and SCALEX

In [None]:
import scalex
from scalex import SCALEX
from matplotlib import pyplot as plt
from matplotlib.pyplot import rc_context
import anndata as ad
import scanpy as sc
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib
import scipy
import hdf5plugin
import os
import csv
import triku as tk
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Accession number of the dataset; interpolated into every input/output path below.
DATASET_ACCESSION_NUMBER = "GSE181279"

### Initialization of important utility functions

In [None]:
def annotate_by_batch(batch_name: str):
    """Return the disease label implied by a batch name.

    The name is searched for the markers ``'_AD'`` and ``'_NC'`` (in that
    order); the label of the first marker found is returned. A name that
    contains neither marker gets the fallback label.
    """
    # Ordered pairs: '_AD' is checked before '_NC'.
    marker_labels = (
        ('_AD', "Alzheimer's disease"),
        ('_NC', "Normal control"),
    )
    for marker, label in marker_labels:
        if marker in batch_name:
            return label
    return "Uknown class"

def create_anndata_list(main_dir):
    """Load every 10x batch found directly inside ``main_dir``.

    A batch is detected through a file whose name ends in ``matrix.mtx.gz``;
    the part of the file name before that suffix is handed to
    ``sc.read_10x_mtx`` as ``prefix``. Each loaded object receives two
    ``obs`` columns: ``batch_name`` (the prefix without its last character)
    and ``disease_type`` (the label returned by ``annotate_by_batch``).

    Directory entries are visited in sorted order. ``os.listdir`` makes no
    ordering guarantee, so without sorting the order of the returned list
    could differ between machines and runs.

    Parameters:
        main_dir: directory that holds the per-batch 10x files.

    Returns:
        list of the loaded AnnData objects, one per detected batch
        (empty when no matching file exists).
    """
    postfix = "matrix.mtx.gz"
    anndata_list = []
    for file in sorted(os.listdir(main_dir)):
        if not file.endswith(postfix):
            continue
        prefix = file[:-len(postfix)]
        print(f"Processing batch with prefix: {prefix} ...", end=" ")
        adata = sc.read_10x_mtx(main_dir, prefix=prefix)
        # Drops the last character of the prefix (presumably a trailing
        # separator such as '_' -- verify against the file names).
        adata.obs['batch_name'] = prefix[:-1]
        adata.obs['disease_type'] = annotate_by_batch(prefix)
        anndata_list.append(adata)
        print("done.")
    return anndata_list

In [None]:
def sc_preprocess(adata_raw):
    """Run the scanpy preprocessing pipeline and compute embeddings.

    Steps, in order:
      1. filter cells (``min_genes=600``) and genes (``min_cells=3``); both
         calls receive the object passed by the caller directly;
      2. drop genes whose name starts with ``ERCC``, ``MT-`` or ``mt-``;
      3. normalize totals to 1e4 and apply ``log1p``;
      4. flag highly variable genes and store the full object in ``.raw``;
      5. keep only the highly variable genes, scale (clipped at 10), run PCA;
      6. compute the neighbor graph, UMAP and t-SNE.

    Returns the object restricted to the highly variable genes.
    """
    excluded_prefixes = ('ERCC', 'MT-', 'mt-')

    sc.pp.filter_cells(adata_raw, min_genes=600)
    sc.pp.filter_genes(adata_raw, min_cells=3)

    kept_genes = [
        name for name in adata_raw.var_names
        if not str(name).startswith(excluded_prefixes)
    ]
    filtered = adata_raw[:, kept_genes]

    sc.pp.normalize_total(filtered, target_sum=1e4)
    sc.pp.log1p(filtered)
    sc.pp.highly_variable_genes(filtered, min_mean=0.0125, max_mean=3, min_disp=0.5)
    # Keep the object with all remaining genes reachable through .raw.
    filtered.raw = filtered

    hvg_only = filtered[:, filtered.var.highly_variable]
    sc.pp.scale(hvg_only, max_value=10)
    sc.pp.pca(hvg_only)

    # The three calls below may be commented out when this preprocessor is
    # used together with other integration techniques.
    sc.pp.neighbors(hvg_only)
    sc.tl.umap(hvg_only)
    sc.tl.tsne(hvg_only)
    return hvg_only

In [None]:
# Load every batch found under datasets/<accession>_RAW, then show the list as the cell output.
gse_list = create_anndata_list(f"datasets/{DATASET_ACCESSION_NUMBER}_RAW")
gse_list

In [None]:
# Merge all per-batch objects into a single AnnData object.
# NOTE(review): AnnData.concatenate is reported as deprecated in recent anndata
# releases in favour of anndata.concat -- confirm before upgrading. Later cells
# use an obs column named 'batch', which is presumably created by this call.
gse_merged = ad.AnnData.concatenate(*gse_list)
gse_merged

### (Optional) save merged anndata object to disk

In [None]:
# Persist the merged object (zstd-compressed) under ./datasets/, which is where
# the reload cell and the SCALEX call in this notebook look for
# "<accession>_merged.h5ad".
gse_merged.write_h5ad(
    f"./datasets/{DATASET_ACCESSION_NUMBER}_merged.h5ad",
    compression=hdf5plugin.FILTERS["zstd"],
    compression_opts=hdf5plugin.Zstd(clevel=5).filter_options
)

### (Optional) get merged dataset from h5ad file

In [None]:
# Reload the merged object from ./datasets/<accession>_merged.h5ad.
# Usable once the merged file has been written to that location.
gse_merged = sc.read(f"./datasets/{DATASET_ACCESSION_NUMBER}_merged.h5ad")
gse_merged

In [None]:
# Number of rows in obs for each batch_name value.
gse_merged.obs.batch_name.value_counts()

In [None]:
# Show the obs table of the merged object as the cell output.
gse_merged.obs

### Preprocess merged dataset using the function created above

In [None]:
# Run sc_preprocess on the merged object, then show the result as the cell output.
gse_merged_preprocessed = sc_preprocess(gse_merged)
gse_merged_preprocessed

### (Optional) Store preprocessed data into .h5ad file

In [None]:
# Save the preprocessed object, zstd-compressed (clevel=5), under ./datasets/.
gse_merged_preprocessed.write_h5ad(
    f"./datasets/{DATASET_ACCESSION_NUMBER}_merged_qcdr.h5ad",
    compression=hdf5plugin.FILTERS["zstd"],
    compression_opts=hdf5plugin.Zstd(clevel=5).filter_options
)

### Visualize merged and preprocessed dataset before performing batch effect correction

#### Note: preprocessing is only executed for the purposes of visualization; SCALEX performs scanpy's preprocessing steps before integrating the data.

In [None]:
# t-SNE of the preprocessed (uncorrected) data coloured by batch_name; 6x6 figure, no legend.
with rc_context({'figure.figsize': (6, 6)}):
    sc.pl.tsne(gse_merged_preprocessed,color=['batch_name'], title=' ', legend_loc=None)

In [None]:
# t-SNE of the preprocessed (uncorrected) data coloured by disease_type; 6x6 figure, no legend.
with rc_context({'figure.figsize': (6, 6)}):
    sc.pl.tsne(gse_merged_preprocessed,color=['disease_type'], title=' ', legend_loc=None)

### Execute SCALEX batch effect correction method

In [None]:
# Run SCALEX on the merged .h5ad file stored under ./datasets/, with the obs
# column 'batch' passed as batch_name and the same thresholds as sc_preprocess
# (min_features=600, min_cells=3).
# NOTE(review): gpu=7 is a hard-coded device index -- confirm it matches the
# machine this notebook runs on.
gse_merged_corrected = SCALEX(f'./datasets/{DATASET_ACCESSION_NUMBER}_merged.h5ad',batch_name='batch',min_features=600, min_cells=3, outdir='gse181279_output/',show=False,gpu=7)

In [None]:
# Show the object returned by SCALEX as the cell output.
gse_merged_corrected

### (Optional) store merged and corrected dataset to h5ad file

In [None]:
# Persist the corrected object (zstd-compressed) under ./datasets/, which is
# where the reload cell in this notebook looks for
# "<accession>_merged_corrected.h5ad".
gse_merged_corrected.write_h5ad(
    f"./datasets/{DATASET_ACCESSION_NUMBER}_merged_corrected.h5ad",
    compression=hdf5plugin.FILTERS["zstd"],
    compression_opts=hdf5plugin.Zstd(clevel=5).filter_options
)

### (Optional) get integrated/corrected dataset from h5ad file

In [None]:
# Reload the corrected object from ./datasets/<accession>_merged_corrected.h5ad.
# Usable once the corrected file has been written to that location.
gse_merged_corrected = sc.read(f"./datasets/{DATASET_ACCESSION_NUMBER}_merged_corrected.h5ad")
gse_merged_corrected

In [None]:
# Show the obs table of the corrected object as the cell output.
gse_merged_corrected.obs

In [None]:
# One-column table ("genes") of the corrected object's variable names, with a
# default 0..n-1 row index that later cells export as the gene position.
# Built with to_frame(name=...) so the column is always called "genes":
# pd.DataFrame(index) labels the column with the index's own name when it has
# one, in which case renaming column 0 would silently do nothing.
gse181279_corrected_genes = gse_merged_corrected.var.index.to_frame(index=False, name="genes")
gse181279_corrected_genes

In [None]:
# Recompute PCA, the neighbor graph and t-SNE on the corrected object.
sc.pp.pca(gse_merged_corrected)
# Cosine metric; n_neighbors is half the square root of len(gse_merged_corrected), truncated to int.
sc.pp.neighbors(gse_merged_corrected, metric='cosine', n_neighbors=int(0.5 * len(gse_merged_corrected) ** 0.5))
sc.tl.tsne(gse_merged_corrected)
gse_merged_corrected

### Visualize results after batch effect correction with SCALEX (using tsne)

In [None]:
# Global scanpy figure settings used by the plots below (80 dpi, white background, 6x6, framed).
sc.settings.set_figure_params(dpi=80, facecolor='white',figsize=(6,6),frameon=True)

In [None]:
# t-SNE of the corrected data coloured by batch_name.
with rc_context({'figure.figsize': (6, 6)}):
    sc.pl.tsne(gse_merged_corrected,color=['batch_name'], title=' ',legend_fontsize=10)

In [None]:
# t-SNE of the corrected data coloured by disease_type.
with rc_context({'figure.figsize': (6, 6)}):
    sc.pl.tsne(gse_merged_corrected,color=['disease_type'], title = ' ', legend_fontsize=10)

### Find differentially expressed genes

In [None]:
# Rank genes by disease_type against the "Normal control" reference using the
# 'logreg' method on X (use_raw=False); top 500 genes, results stored in uns['logreg'].
sc.tl.rank_genes_groups(adata=gse_merged_corrected, groupby="disease_type",use_raw=False, reference="Normal control", n_genes=500, method='logreg', key_added="logreg")

In [None]:
# Plot the top 25 ranked genes from the 'logreg' results.
sc.pl.rank_genes_groups(gse_merged_corrected, n_genes=25, sharey=False, key = "logreg", ncols=2)

In [None]:
# Show the ranked gene names stored under uns['logreg']['names'].
gse_merged_corrected.uns['logreg']['names']

In [None]:
# Ranked gene names as a table: the "Alzheimer's disease" column becomes
# "genes", and the former row index is kept as the column "degs_index".
degs = (
    pd.DataFrame(gse_merged_corrected.uns['logreg']['names'])
    .rename(columns={"Alzheimer's disease": "genes"})
    .reset_index()
    .rename(columns={"index": "degs_index"})
)

In [None]:
# Show the DEG table as the cell output.
degs

In [None]:
# Flat list of ranked gene names: every record of the 'names' array is
# converted to a tuple of strings and its fields are concatenated.
degs_lst = [
    ''.join(record)
    for record in gse_merged_corrected.uns['logreg']['names'].tolist()
]

In [None]:
# Rows of the gene table whose name is in degs_lst, together with their
# row label in that table (column 'index').
_deg_mask = gse181279_corrected_genes['genes'].isin(degs_lst)
gse181279_corrected_degs = pd.DataFrame({
    'index': gse181279_corrected_genes.index[_deg_mask],
    'genes': gse181279_corrected_genes.loc[_deg_mask, 'genes'],
})

In [None]:
# Join gene positions with the DEG table on 'genes', ordered by degs_index.
degmerge = pd.merge(gse181279_corrected_degs, degs, on='genes').sort_values(by=['degs_index'])

In [None]:
# Show the merged DEG table as the cell output.
degmerge

In [None]:
# Export the 'index' and 'genes' columns of the DEG table to <accession>_corrected_degs.csv (no row labels).
degmerge[['index', 'genes']].to_csv(f'{DATASET_ACCESSION_NUMBER}_corrected_degs.csv', index=False)

### Gene selection using triku

In [None]:
# Show the corrected object as the cell output.
gse_merged_corrected

In [None]:
# Run triku feature selection on the corrected object (n_features=500, use_raw=False).
# A later cell reads the var column 'triku_highly_variable'.
tk.tl.triku(object_triku=gse_merged_corrected, n_features=500, use_raw=False)

In [None]:
# Names of the variables flagged in 'triku_highly_variable', as an Index and as a plain list.
gse_merged_triku_var = gse_merged_corrected.var
triku_selected_genes = gse_merged_triku_var.index[
    gse_merged_triku_var['triku_highly_variable'] == True
]
triku_selected_genes_lst = triku_selected_genes.tolist()
triku_selected_genes_lst

In [None]:
# Wrap the selected gene names in a DataFrame.
triku_genes_df = pd.DataFrame(triku_selected_genes)

In [None]:
# Rename the column labelled 0 to "genes" (in place).
triku_genes_df.rename(columns={0: "genes"}, inplace=True)

In [None]:
# Show the triku gene table as the cell output.
triku_genes_df

In [None]:
# Rows of the gene table whose name was selected by triku, together with their
# row label in that table (column 'index'). Rebinds triku_selected_genes.
_triku_mask = gse181279_corrected_genes['genes'].isin(triku_selected_genes_lst)
triku_selected_genes = pd.DataFrame({
    'index': gse181279_corrected_genes.index[_triku_mask],
    'genes': gse181279_corrected_genes.loc[_triku_mask, 'genes'],
})

In [None]:
# Show the triku selection table as the cell output.
triku_selected_genes

In [None]:
# Export the 'index' and 'genes' columns of the triku selection to <accession>_corrected_triku_genes.csv (no row labels).
triku_selected_genes[['index', 'genes']].to_csv(f'{DATASET_ACCESSION_NUMBER}_corrected_triku_genes.csv', index=False)

### Get tsne vectors from the two versions of the merged dataset (before and after batch effect correction with SCALEX)

In [None]:
# t-SNE coordinates of the preprocessed (uncorrected) data, split into columns 0 and 1.
tsne_before_corr = gse_merged_preprocessed.obsm['X_tsne']
tsne_before_corr1, tsne_before_corr2 = tsne_before_corr[:, 0], tsne_before_corr[:, 1]

In [None]:
# t-SNE coordinates of the corrected data, split into columns 0 and 1.
tsne_after_corr = gse_merged_corrected.obsm['X_tsne']
tsne_after_corr1, tsne_after_corr2 = tsne_after_corr[:, 0], tsne_after_corr[:, 1]

In [None]:
# Number of entries in the first t-SNE column before correction.
len(tsne_before_corr1)

### Get and store obs from anndata objects

In [None]:
# Per-cell table with the two t-SNE coordinates before correction.
# NOTE(review): the name is bound to gse_merged_preprocessed.obs itself (no
# .copy()), so the two column assignments below presumably also add these
# columns to the AnnData object's obs -- confirm that is intended.
gse_merged_preprocessed_obs_df = gse_merged_preprocessed.obs
gse_merged_preprocessed_obs_df['tsne_before1'] = tsne_before_corr1.tolist()
gse_merged_preprocessed_obs_df['tsne_before2'] = tsne_before_corr2.tolist()
# drop() returns a new frame without the annotation columns.
gse_merged_preprocessed_obs_df = gse_merged_preprocessed_obs_df.drop(['batch', 'n_genes', 'batch_name', 'disease_type'], axis=1)

In [None]:
# Show the pre-correction t-SNE table as the cell output.
gse_merged_preprocessed_obs_df

In [None]:
# Per-cell table with the two t-SNE coordinates after correction.
# NOTE(review): the name is bound to gse_merged_corrected.obs itself (no
# .copy()), so the two column assignments below presumably also add these
# columns to the AnnData object's obs, which is exported to CSV at the end of
# the notebook -- confirm that is intended.
gse_merged_corrected_obs_df = gse_merged_corrected.obs
gse_merged_corrected_obs_df['tsne_after1'] = tsne_after_corr1.tolist()
gse_merged_corrected_obs_df['tsne_after2'] = tsne_after_corr2.tolist()
# Keep the annotation columns aside before they are dropped.
gse_merged_corrected_batch_name = gse_merged_corrected_obs_df['batch_name']
gse_merged_corrected_disease_type = gse_merged_corrected_obs_df['disease_type']
# drop() returns a new frame without the annotation columns.
gse_merged_corrected_obs_df = gse_merged_corrected_obs_df.drop(['batch', 'n_genes', 'leiden', 'batch_name', 'disease_type'], axis=1)

In [None]:
# Show the post-correction t-SNE table as the cell output.
gse_merged_corrected_obs_df

### Get the intersection of the two dataframes (keep only the cells that remain after batch effect correction)

In [None]:
# Keep only the cells present in both tables (inner join on the row index).
gse_common = gse_merged_preprocessed_obs_df.join(gse_merged_corrected_obs_df, how='inner')
# Attach the annotations by cell label rather than by position: the joined
# frame follows the row order of the left (pre-correction) table, so a plain
# positional copy of the corrected table's columns is only right when both
# tables happen to list the same cells in the same order.
gse_common['disease_type'] = gse_merged_corrected_disease_type.reindex(gse_common.index).tolist()
gse_common['batch_name'] = gse_merged_corrected_batch_name.reindex(gse_common.index).tolist()
gse_common

In [None]:
# Sanity check: True when the joined table has exactly the corrected table's row labels, in the same order.
gse_common.index.tolist() == gse_merged_corrected_obs_df.index.tolist()

### Create a .csv file from the gse_common dataframe (contains: batch_name, disease_type, t-SNE vectors before and after correction)

In [None]:
# Export the joined table (row labels included) to ./datasets/<accession>_tsne_annotated.csv.
gse_common.to_csv(f'./datasets/{DATASET_ACCESSION_NUMBER}_tsne_annotated.csv')

### Create .csv files from the corrected anndata object

In [None]:
# Export the corrected object as three CSV files in the working directory.
# The matrix frame is built without index/columns arguments, so its CSV carries
# default integer labels; obs and var are written to their own files.
# NOTE(review): .todense() assumes X is a scipy sparse matrix -- confirm; a
# dense array has no such method.
gse_merged_corrected_df = pd.DataFrame(gse_merged_corrected.X.todense())
gse_merged_corrected_df.to_csv(f'{DATASET_ACCESSION_NUMBER}_corrected_matrix.csv')
gse_merged_corrected.obs.to_csv(f'{DATASET_ACCESSION_NUMBER}_corrected_observations.csv')
gse_merged_corrected.var.to_csv(f'{DATASET_ACCESSION_NUMBER}_corrected_variables.csv')