# This notebook must be run with the ABC_Download conda environment within the ABC.sif singularity container

This notebook and the two that follow it are largely based on the Allen Brain Cell Atlas whole-mouse-brain MERFISH atlas analysis. The code was adapted from the following notebook in the ZhuangLab GitHub repository: https://github.com/ZhuangLab/whole_mouse_brain_MERFISH_atlas_scripts_2023/blob/main/scripts/integrate_MERFISH_with_scRNA-seq/integration_round2.ipynb

In [1]:
import os
import shutil
import numpy as np
import scipy.sparse
import scipy.spatial
import pandas as pd
import matplotlib.pyplot as plt

import anndata
import scanpy as sc
sc.settings.n_jobs = 56
sc.settings.set_figure_params(dpi=180, dpi_save=300, frameon=False, figsize=(4, 4), fontsize=8, facecolor='white')

import ALLCools
from ALLCools.integration.seurat_class import SeuratIntegration

In [2]:
# Root folder of the integration workspace. The partition, imputation and merged label-transfer
# paths used in this notebook are all built by joining onto it.
# NOTE(review): this is an absolute, site-specific path - adjust it when running elsewhere.
workspace_path = '/hpc/projects/group.quake/doug/references/ABC/whole/integration_workspace'

In [3]:
def impute_gene_expression(integrator, adata_list, X_ref, ref=0, qry=1, npc=30, kweight=30, sd=1,
                           chunk_size=1000, random_state=0):
    '''Impute expression for the query cells as a weighted sum of reference anchor cells.

    Parameters
    ----------
    integrator
        Object exposing ``find_nearest_anchor``; it is called with ``adata_list``, the query
        PCA coordinates and the ``ref``/``qry``/``npc``/``kweight``/``sd``/``random_state``
        settings, and must return ``(anchor, G, D, cum_qry)``.
    adata_list
        Sequence of AnnData-like objects; ``adata_list[qry].obsm['X_pca']`` is used as the
        query embedding.
    X_ref
        Dense or scipy-sparse expression matrix of the reference cells. Rows are selected
        with ``anchor[:, 0]``, so its row order has to match what the integrator indexes.
    ref, qry
        Positions of the reference and query datasets in ``adata_list``.
    npc, kweight, sd, random_state
        Forwarded unchanged to ``integrator.find_nearest_anchor``.
    chunk_size
        Number of query cells processed per step; bounds the size of the dense
        (cells x neighbours x genes) intermediate.

    Returns
    -------
    scipy sparse matrix (float32) with one row per query cell and one column per
    column of ``X_ref``.
    '''
    data_qry = adata_list[qry].obsm['X_pca']

    # cum_qry is returned by the integrator but not needed here.
    # NOTE(review): G and D are used below as, per query cell, indices into the anchor list and
    # the matching weights - confirm against the ALLCools implementation.
    anchor, G, D, cum_qry = integrator.find_nearest_anchor(adata_list, data_qry=data_qry, ref=[ref], qry=[qry],
                                                      npc=npc, k_weight=kweight, sd=sd, random_state=random_state)

    # Pick the anchor rows first and only then densify: converting the whole reference
    # matrix to dense just to keep a subset of its rows wastes memory.
    anchor_rows = anchor[:, 0]
    if scipy.sparse.issparse(X_ref):
        X_anchor = X_ref.tocsr()[anchor_rows].toarray()
    else:
        X_anchor = X_ref[anchor_rows]

    imputed_chunks = []
    for chunk_start in np.arange(0, data_qry.shape[0], chunk_size):
        chunk = slice(chunk_start, chunk_start + chunk_size)
        # (cells, k, 1) * (cells, k, genes) summed over k -> (cells, genes)
        weighted_sum = (D[chunk, :, None] * X_anchor[G[chunk]]).sum(axis=1)
        imputed_chunks.append(scipy.sparse.csr_matrix(weighted_sum.astype(np.float32)))

    if not imputed_chunks:
        # No query cells: vstack would raise on an empty list, so return an empty matrix instead.
        return scipy.sparse.csr_matrix((0, X_anchor.shape[1]), dtype=np.float32)

    return scipy.sparse.vstack(imputed_chunks)

In [None]:
# This is going to remain unrun so that it can be stored in .git
%%time
imputation_path = os.path.join(workspace_path, 'gene_expression_imputation')
partition_path = os.path.join(workspace_path, 'partitions')
partitions = sorted(os.listdir(partition_path))

for pn in partitions:
    print(f'Integrate for {pn}')
    integration_path = os.path.join(partition_path, pn)
    input_adata_seq_file = os.path.join(integration_path, 'adata_seq.h5ad')
    input_adata_merfish_file = os.path.join(integration_path, 'adata_merfish_integrated.h5ad')
    
    if not (os.path.exists(input_adata_seq_file) and os.path.exists(input_adata_merfish_file)):
        print('Missing input file. Skip the integration.')
        continue
        
    ## Load and preprocess the data
    adata_seq_raw = sc.read_h5ad(input_adata_seq_file)
    adata_seq = adata_seq_raw.copy()
    adata_merfish_raw = sc.read_h5ad(input_adata_merfish_file)
    adata_merfish = adata_merfish_raw.copy()
    adata_merfish.X = adata_merfish.layers['counts']

    # Consider the common genes
    common_genes = np.array(adata_seq.var_names.intersection(adata_merfish.var_names))
    adata_seq._inplace_subset_var(common_genes)
    
    # Normalize
    sc.pp.normalize_total(adata_seq, target_sum=1000)

    sc.pp.log1p(adata_seq)

    sc.pp.normalize_total(adata_merfish, target_sum=1000)

    sc.pp.log1p(adata_merfish)
    
    # Select variable genes
    sc.pp.highly_variable_genes(adata_seq)
    hv_genes = list(adata_seq.var.index[adata_seq.var['dispersions'] > 0])
    adata_seq = adata_seq[:, adata_seq.var.index.isin(hv_genes)]
    adata_merfish = adata_merfish[:, adata_merfish.var.index.isin(hv_genes)]
    print(f'Use {len(hv_genes)} highly variable genes for integration.')
    
    # Scale the data

    sc.pp.scale(adata_seq)

    sc.pp.scale(adata_merfish)
    
    # Merge the datasets
    adata_merge = adata_seq.concatenate(adata_merfish,
                                    batch_categories=['seq', 'merfish'],
                                    batch_key='modality',
                                    index_unique=None)
    
    # PCA
    n_pcs = min(100, len(hv_genes) - 1)

    sc.tl.pca(adata_merge, svd_solver='arpack', n_comps=n_pcs)
    
    adata_list = [adata_merge[adata_merge.obs['modality'] == 'seq'],
                  adata_merge[adata_merge.obs['modality'] == 'merfish']]
    
    ## Integration
    # Find the integration anchors

    integrator = SeuratIntegration()
    integrator.find_anchor(adata_list,
                       k_local=None,
                       key_local='X_pca',
                       k_anchor=5,
                       key_anchor='X',
                       dim_red='cca',
                       max_cc_cells=100000,
                       k_score=30,
                       k_filter=None, #why?
                       scale1=False,
                       scale2=False,
                       n_components=n_pcs,
                       n_features=200,
                       alignments=[[[0], [1]]])
    
    # Label transfer
    #cell_type_cols = ['Level1_label', 'Level2_label', 'cl']
    cell_type_cols = ['subclass_label', 'cl']
    transfer_results = integrator.label_transfer(
        ref=[0],
        qry=[1],
        categorical_key=cell_type_cols,
        key_dist='X_pca',
        k_weight=30,
        npc=n_pcs
    )
    integrator.save_transfer_results_to_adata(adata_merge, transfer_results)
    
    # Assign the transfered labels and the confidence
    for cell_type_col in cell_type_cols:
        adata_merfish_raw.obs[cell_type_col + '_transfer'] = transfer_results[cell_type_col].idxmax(axis=1
                                                                                           ).astype('category')
        adata_merfish_raw.obs[cell_type_col + '_confidence'] = transfer_results[cell_type_col].max(axis=1)
        
        n_transfered = len(np.unique(adata_merfish_raw.obs[cell_type_col + '_transfer']))
        n_total = len(np.unique(adata_merge.obs[cell_type_col + '_transfer']))
        print(f'Transfered {n_transfered}/{n_total} {cell_type_col}.')
    
    # Save the label transfer results
    adata_merfish_raw.write_h5ad(os.path.join(integration_path, 'adata_merfish_label_transfer.h5ad'), 
                                 compression='gzip')     
    
    # Impute gene expression
    X_imputed = impute_gene_expression(integrator, adata_list, adata_seq_raw.X, 
                                       ref=0, qry=1, npc=n_pcs, chunk_size=5000)
    adata_imputed = anndata.AnnData(X=X_imputed, obs=adata_merfish_raw.obs.copy(), 
                                var=adata_seq_raw.var.copy(), dtype=np.float32)
    
    # Create the imputation gene partition table
    gene_partition_file = os.path.join(imputation_path, 'gene_partition.csv')
    if not os.path.exists(gene_partition_file):
        gene_partition_df = adata_seq_raw.var.copy()
        gene_partition_df.index.name = 'gene'
        gene_partition_df['gene_partition'] = np.arange(gene_partition_df.shape[0], dtype=int) // 1000
        gene_partition_df.to_csv(gene_partition_file)
        
    else:
        gene_partition_df = pd.read_csv(gene_partition_file).set_index('gene')
    
    gene_partitions = np.unique(gene_partition_df['gene_partition'])
    
    # Save the imputation results
    for sn in np.unique(adata_imputed.obs['subclass_label_transfer']):
        adata_subset = adata_imputed[adata_imputed.obs['subclass_label_transfer'] == sn]
        
        os.makedirs(os.path.join(imputation_path, sn.replace('/', '-').replace(' ', '_')), exist_ok=True)
        for g_p in gene_partitions:
            p_genes = np.array(gene_partition_df[gene_partition_df['gene_partition'] == g_p].index)
            adata_ct_p = adata_subset[:, adata_subset.var.index.isin(p_genes)]
            adata_ct_p.write_h5ad(os.path.join(imputation_path, sn.replace('/', '-').replace(' ', '_'), 
                                               f'{g_p}.h5ad'), compression='gzip')

        
    ## Co-embedding
    # Correct the PCs using the integration anchors
    corrected = integrator.integrate(key_correct='X_pca',
                                 row_normalize=True,
                                 n_components=n_pcs,
                                 k_weight=100,
                                 sd=1,
                                 alignments=[[[0], [1]]])

    adata_merge.obsm['X_pca_integrate'] = np.concatenate(corrected)
    
    # Calculate KNN using the integrated PCs
    sc.pp.neighbors(adata_merge, use_rep='X_pca_integrate')
    
    if len(np.unique(adata_merge.obs['subclass_label_transfer'])) > 1:
        # Generate the PAGA plot for the initial arrangement of the UMAP
        sc.tl.paga(adata_merge, groups='subclass_label_transfer')
        sc.pl.paga(adata_merge, save='_tmp.png', cmap='gist_ncar')
        shutil.move('figures/paga_tmp.png', os.path.join(integration_path, 'integration_paga_round2.png'))
    
        sc.tl.umap(adata_merge, init_pos='paga', min_dist=0.5)
    else:
        sc.tl.umap(adata_merge, min_dist=0.5)
        
    # Save the umap
    sc.pl.umap(adata_merge, color='modality', save='_tmp.png')
    shutil.move('figures/umap_tmp.png', os.path.join(integration_path, 'integration_umap_modality.png'))
    #sc.pl.umap(adata_merge, color='Level1_label_transfer', save='_tmp.png')
    #shutil.move('figures/umap_tmp.png', os.path.join(integration_path, 'integration_umap_Level1_label.png'))
    sc.pl.umap(adata_merge, color='subclass_label_transfer', save='_tmp.png', palette='gist_ncar')
    shutil.move('figures/umap_tmp.png', os.path.join(integration_path, 'integration_umap_subclass_label.png'))
    sc.pl.umap(adata_merge, color='cl_transfer', save='_tmp.png', palette='gist_ncar')
    shutil.move('figures/umap_tmp.png', os.path.join(integration_path, 'integration_umap_cl.png'))
    
    
    # Save the merged adata
    adata_merge.write_h5ad(os.path.join(integration_path, 'adata_merged_round2.h5ad'), compression='gzip')

In [5]:
def get_cluster_mean_expression_matrix(adata, cluster_column):
    '''Get a dataframe of mean gene expression of each cluster.

    Parameters
    ----------
    adata
        AnnData-like object; ``adata.X`` (dense or scipy-sparse) holds the expression values,
        ``adata.obs[cluster_column]`` the cluster label of each cell and ``adata.var.index``
        the gene names.
    cluster_column
        Name of the ``obs`` column to group the cells by.

    Returns
    -------
    pandas.DataFrame indexed by cluster label with one column per gene, holding the mean
    of ``adata.X`` over the cells of each cluster.
    '''
    X = adata.X.toarray() if scipy.sparse.issparse(adata.X) else adata.X

    cell_exp_mtx = pd.DataFrame(X, index=adata.obs[cluster_column], columns=adata.var.index)
    # Pass `observed` explicitly: relying on the implicit default for categorical labels is
    # deprecated in pandas and would silently change which rows are returned once the default
    # flips. observed=False keeps the historical output (unused categories appear as NaN rows).
    return cell_exp_mtx.groupby(by=cluster_column, observed=False).mean()


def calc_cell_cosines_to_cluster_mean_exps(adata, cluster_col, cluster_mean_exp_df):
    '''Cosine similarity between each cell and the mean expression of its assigned cluster.

    Parameters
    ----------
    adata
        AnnData-like object; ``adata.X`` (dense or scipy-sparse) holds the expression values,
        ``adata.obs[cluster_col]`` the cluster label of each cell and ``adata.var.index``
        the gene names.
    cluster_col
        Name of the ``obs`` column holding the cluster each cell is compared against.
    cluster_mean_exp_df
        DataFrame indexed by cluster label with (at least) the genes of ``adata`` as columns.

    Returns
    -------
    list of float, one entry per cell. Cells whose expression sums to zero (or less), and
    cells whose cluster mean vector is all zero, get a similarity of 0.

    Raises
    ------
    ValueError
        If a cell carries a label that has no row in ``cluster_mean_exp_df``.
    '''
    # Align the gene order of the cluster means with the cells.
    cluster_mean_exp_df = cluster_mean_exp_df.loc[:, adata.var.index]
    cluster_id_map = {label: i for i, label in enumerate(cluster_mean_exp_df.index)}

    mapped_ids = adata.obs[cluster_col].map(cluster_id_map)
    unmatched = np.asarray(mapped_ids.isna())
    if unmatched.any():
        # Fail with the offending labels instead of an opaque indexing error further down.
        examples = list(pd.unique(np.asarray(adata.obs[cluster_col], dtype=object)[unmatched])[:5])
        raise ValueError(f'{int(unmatched.sum())} cells have a {cluster_col} label without a cluster mean '
                         f'(e.g. {examples}).')
    cell_cluster_ids = np.asarray(mapped_ids).astype(int)

    X_cluster_mean = cluster_mean_exp_df.values[cell_cluster_ids]

    if scipy.sparse.issparse(adata.X):
        X = adata.X.toarray()
    else:
        X = np.asarray(adata.X)

    # Row-wise dot products and norms; einsum avoids full-size temporaries and the per-cell loop.
    dots = np.einsum('ij,ij->i', X, X_cluster_mean)
    norm_products = np.sqrt(np.einsum('ij,ij->i', X, X) *
                            np.einsum('ij,ij->i', X_cluster_mean, X_cluster_mean))

    # Same "has expression" criterion as before, plus a guard against a zero norm, for which
    # the cosine is undefined.
    valid = (X.sum(axis=1) > 0) & (norm_products > 0)

    cosines = np.zeros(X.shape[0], dtype=float)
    cosines[valid] = dots[valid] / norm_products[valid]

    return cosines.tolist()

In [6]:
# `partition_path` and `partitions` are otherwise only defined in the integration cell, which is
# intentionally left unrun; define them here as well so this cell works on a fresh kernel.
partition_path = os.path.join(workspace_path, 'partitions')
partitions = sorted(os.listdir(partition_path))

# Collect the per-partition label-transfer results.
adatas = []
for pn in partitions:
    adata_file = os.path.join(partition_path, pn, 'adata_merfish_label_transfer.h5ad')
    # Partitions that were skipped during integration (missing inputs) have no result file.
    if not os.path.exists(adata_file):
        print(f'No label transfer result for {pn}. Skip.')
        continue
    adatas.append(sc.read_h5ad(adata_file))

if not adatas:
    raise FileNotFoundError(f'No adata_merfish_label_transfer.h5ad found under {partition_path}')

# Stack all partitions into a single AnnData.
adata = anndata.concat(adatas)

In [7]:
# Get the mean expressions of each cluster: load the sequencing data, apply the same
# normalize_total(target_sum=1000) + log1p transform as used for the MERFISH cells below,
# then average per 'cl' cluster.
# NOTE: both scanpy calls modify adata_seq in place, so this cell is not safe to re-run
# without reloading the file first (the first line does exactly that).
adata_seq = sc.read_h5ad(os.path.join(workspace_path, 'adata_seq_common_genes.h5ad'))
sc.pp.normalize_total(adata_seq, target_sum=1000)
sc.pp.log1p(adata_seq)

cluster_mean_exp_df = get_cluster_mean_expression_matrix(adata_seq, 'cl')

# Get the MERFISH log1p expressions; work on a copy so that `adata` itself keeps its
# original X when it is written to disk later.
adata_merfish_log1p = adata.copy()
sc.pp.normalize_total(adata_merfish_log1p, target_sum=1000)
sc.pp.log1p(adata_merfish_log1p)

# Calculate the cluster cosine similarities: each MERFISH cell is compared against the mean
# expression of the cluster it was assigned in 'cl_transfer'. The result is stored on `adata`,
# not on the normalized copy.
adata.obs['cluster_cosine_similarity'] = calc_cell_cosines_to_cluster_mean_exps(adata_merfish_log1p, 
                                                    'cl_transfer', cluster_mean_exp_df)

  utils.warn_names_duplicates("obs")
  return cell_exp_mtx.groupby(by=cluster_column).mean()
  np.log1p(X, out=X)
  np.log1p(X, out=X)


In [8]:
# Scale each label-transfer confidence by the confidence of the partition assignment.
partition_confidence = adata.obs['integration_partition_confidence']
adata.obs['adjusted_subclass_label_confidence'] = partition_confidence * adata.obs['subclass_label_confidence']
adata.obs['adjusted_cl_confidence'] = partition_confidence * adata.obs['cl_confidence']

# Write the merged object (gzip-compressed) and a CSV copy of its per-cell metadata.
merged_h5ad_file = os.path.join(workspace_path, 'adata_merfish_label_transfer.h5ad')
merged_metadata_file = os.path.join(workspace_path, 'adata_merfish_label_transfer_metadata.csv')
adata.write_h5ad(merged_h5ad_file, compression='gzip')
adata.obs.to_csv(merged_metadata_file)

In [9]:
# Write a second copy of the merged object under a relative path, i.e. into the current
# working directory; no compression argument is passed here, unlike the workspace copy.
adata.write_h5ad('ABC_transferred.h5ad')