This notebook, and the two that follow, are largely inspired by the Allen Brain Cell Atlas's
MERFISH atlas. The code was sourced and adapted from this GitHub notebook: https://github.com/ZhuangLab/whole_mouse_brain_MERFISH_atlas_scripts_2023/blob/main/scripts/integrate_MERFISH_with_scRNA-seq/create_integration_workspace.ipynb

Download the Allen Brain Cell Atlas data (20230521 release, 10x v3 processed counts) from: https://data.nemoarchive.org/other/grant/aibs_internal/zeng/transcriptome/scell/10x_v3/mouse/processed/counts/

In [None]:
import os
import gc
import numpy as np
import pandas as pd
import anndata
import scanpy as sc
import igraph
import networkx as nx
import metis
import ALLCools
from ALLCools.integration.seurat_class import SeuratIntegration
from scipy.sparse import csr_matrix
import scipy.sparse

In [None]:
adata_file = 'AIT17.0.rawcount_logCPM_10Xv3/AIT17.0.rawcount_logCPM_10Xv3.h5ad'
annotation_file = 'AIT17.0.cl.df.v6_lock_230504/AIT17.0.cl.df.v6_lock_230504.tsv'

# Cluster annotation table; cluster ids ('cl') are handled as strings from here on.
annotation_df = pd.read_csv(annotation_file, sep='\t')
annotation_df['cl'] = annotation_df['cl'].astype(str)

# Lookup table: cluster id (str) -> subclass label.
subclass_map = dict(zip(annotation_df['cl'], annotation_df['subclass_label']))

In [None]:
adata = sc.read_h5ad(adata_file)
# subclass_map is keyed by str cluster ids ('cl' was cast to str when it was built),
# so look cells up by str as well; a non-str 'cl' dtype would otherwise map every
# cell to NaN and the filter below would drop all of them.
adata.obs['subclass_label'] = adata.obs['cl'].astype(str).map(subclass_map)

# Filter out cells without subclass annotations or annotated as low quality
mask_not_na = ~adata.obs['subclass_label'].isna()
mask_not_lq = adata.obs['subclass_label'] != 'LQ'
adata = adata[mask_not_na & mask_not_lq].copy()
gc.collect()

# Save a copy whose X is the 'rawcount' layer, keeping the filtered obs/var.
adata_raw = anndata.AnnData(X=adata.layers['rawcount'], obs=adata.obs, var=adata.var)
adata_raw.write_h5ad('AIT17_10Xv3.h5ad', compression='gzip')

In [None]:
# Workspace creation
def partition_subclasses(adata, n_partitions):
    '''Assign integration partitions as the adata.obs['integration_partition'] column.

    Cells are grouped by ``adata.obs['subclass_label']`` (must be categorical).
    The kNN graph in ``adata.obsp['distances']`` is binarised and contracted to
    a subclass-level graph (igraph ``cluster_graph`` with ``combine_edges='sum'``).
    That graph is split into ``n_partitions`` parts with METIS, using equal
    target weights and each subclass weighted by its number of cells. Every
    cell receives the label ``'p<k>'`` of its subclass's part.

    Modifies ``adata.obs`` in place (categorical column); returns None.
    '''
    adata.obs['integration_partition'] = 'p0'

    subclasses = adata.obs['subclass_label']
    codes = subclasses.cat.codes.values
    categories = subclasses.cat.categories

    # Binarise the kNN graph so each neighbor link counts once, regardless of distance.
    ones = adata.obsp['distances'].copy()
    ones.data = np.ones(len(ones.data))

    # Contract cells into subclasses; edge weights are summed cell-level links.
    g_neighbors = sc._utils.get_igraph_from_adjacency(ones, directed=True)
    vc = igraph.VertexClustering(g_neighbors, membership=codes)
    cluster_graph = vc.cluster_graph(combine_edges='sum')
    edges = cluster_graph.get_edgelist()
    weights = [e['weight'] for e in cluster_graph.es]
    num_clusters = len(vc)
    # zip(*[]) cannot be unpacked into two names, so handle an edgeless graph explicitly.
    row_indices, col_indices = zip(*edges) if edges else ((), ())
    cluster_mtx = csr_matrix((weights, (row_indices, col_indices)), shape=(num_clusters, num_clusters))

    # Partition the cluster level graph, weighting each subclass node by its cell count.
    # Counts are computed once per subclass; iterating over every cell's code and
    # re-summing over all cells each time was O(n_cells^2).
    G_cluster = nx.from_scipy_sparse_array(cluster_mtx)
    cell_counts = subclasses.value_counts()
    for i in np.unique(codes):
        G_cluster.nodes[i]['weight'] = int(cell_counts[categories[i]])
    G_cluster.graph['node_weight_attr'] = 'weight'

    _cut, parts = metis.part_graph(G_cluster, n_partitions, recursive=False,
                                   tpwgts=[1 / n_partitions] * n_partitions)

    # Each cell inherits the partition of its subclass (node i <-> category i).
    for i, part in enumerate(parts):
        mask = subclasses == categories[i]
        adata.obs.loc[mask, 'integration_partition'] = 'p' + str(part)

    adata.obs['integration_partition'] = adata.obs['integration_partition'].astype('category')
    
def get_cluster_mean_expression_matrix(adata, cluster_column):
    '''Return the mean expression of every gene within each cluster.

    Rows are the distinct values of ``adata.obs[cluster_column]``; columns are
    ``adata.var.index``. A sparse ``adata.X`` is densified before averaging.
    '''
    expression = adata.X.toarray() if scipy.sparse.issparse(adata.X) else adata.X

    # One row per cell, indexed by its cluster label (the index takes the column's name).
    per_cell = pd.DataFrame(
        expression,
        index=adata.obs[cluster_column],
        columns=adata.var.index,
    )
    return per_cell.groupby(by=cluster_column).mean()

def calculate_correlation_matrix(adata, cluster_column):
    '''Calculate gene expression correlations between clusters.

    Computes the Pearson correlation between the mean expression profiles
    (from ``get_cluster_mean_expression_matrix``) of every pair of clusters.
    Returns a symmetric DataFrame indexed and columned by cluster id, with
    exactly 1.0 on the diagonal. Pairs involving a constant profile are NaN.
    '''
    cluster_mean_exp = get_cluster_mean_expression_matrix(adata, cluster_column)

    cluster_ids = np.array(cluster_mean_exp.index)
    n_clusters = len(cluster_ids)

    correlations = np.ones((n_clusters, n_clusters))
    if n_clusters > 1:
        # Each row is one cluster: np.corrcoef gives all pairwise Pearson r in one
        # vectorised call. This replaces the per-pair scipy.stats.pearsonr loop,
        # which also relied on scipy.stats being loaded by some other import.
        correlations = np.corrcoef(np.asarray(cluster_mean_exp, dtype=float))
        # Keep the exact-1.0 diagonal convention (also for constant profiles).
        np.fill_diagonal(correlations, 1.0)

    return pd.DataFrame(correlations, index=cluster_ids, columns=cluster_ids)

def merge_clusters(obs_df, cluster_col, clusters_to_merge):
    '''Relabel every row in ``clusters_to_merge`` with the sorted-first label.

    ``obs_df[cluster_col]`` must be categorical. Modifies ``obs_df`` in place
    and drops categories left unused. An empty ``clusters_to_merge`` is a no-op.
    Note that the surviving label is the lexicographically smallest one
    (e.g. 'p12' before 'p5'), not necessarily the largest cluster.
    '''
    if len(clusters_to_merge) == 0:
        return
    merged_cluster_id = sorted(clusters_to_merge)[0]
    mask = obs_df[cluster_col].isin(clusters_to_merge)
    # .loc writes into obs_df itself; chained indexing (obs_df[col][mask] = ...)
    # is a silent no-op under pandas Copy-on-Write.
    obs_df.loc[mask, cluster_col] = merged_cluster_id
    obs_df[cluster_col] = obs_df[cluster_col].cat.remove_unused_categories()

In [None]:
workspace_path = 'integration_workspace'
input_seq_file = 'AIT17_10Xv3.h5ad'

# Make the top level directories
os.makedirs(workspace_path, exist_ok=True)
partition_path = os.path.join(workspace_path, 'partitions')
os.makedirs(partition_path, exist_ok=True)
os.makedirs(os.path.join(workspace_path, 'gene_expression_imputation'), exist_ok=True)

%%time
# Load the sequencing data
adata_seq = sc.read_h5ad(input_seq_file)

adata_seq.obs['subclass_label'] = adata_seq.obs['subclass_label'].astype('category')
adata_seq.obs['cl'] = adata_seq.obs['cl'].astype('category')

# Load the merfish data
adata_merfish = sc.read_h5ad('FINAL_ANNOTATION.h5ad')

# Rename the genes with different synonyms names
adata_merfish.var.rename(index={
    'AC102910.1': 'Gm30564', 
    'BC030499': 'Rskr',
    'Ctgf': 'Ccn2',
    'Fam196b': 'Insyn2b',
    'Fam19a1': 'Tafa1',
    'Fam19a2': 'Tafa2',
    'Fam19a4': 'Tafa4',
    'Fam46a': 'Tent5a',
    'Fam84b': 'Lratd2',
    'Nov': 'Ccn3',
    'Wisp1': 'Ccn4',
}, inplace=True)

In [None]:
# Consider the common genes
common_genes = np.array(adata_seq.var_names.intersection(adata_merfish.var_names))
# .copy() materialises the subset rather than writing an AnnData view.
adata_merfish = adata_merfish[:, adata_merfish.var.index.isin(common_genes)].copy()
adata_merfish.write_h5ad(os.path.join(workspace_path, 'adata_merfish.h5ad'), compression='gzip')

# Public indexing instead of the private AnnData._inplace_subset_var;
# genes end up in the order of common_genes, as before.
adata_seq = adata_seq[:, common_genes].copy()

# Remove the cells without cluster or subclass labels (one mask, one copy)
has_labels = adata_seq.obs['cl'].notna() & adata_seq.obs['subclass_label'].notna()
adata_seq = adata_seq[has_labels].copy()

In [None]:
%%time
# Define the integration partitions using the sequencing data
n_partitions = 80

sc.pp.normalize_total(adata_seq, target_sum=1000)
sc.pp.log1p(adata_seq)
sc.pp.scale(adata_seq)

n_pcs=100
sc.tl.pca(adata_seq, svd_solver='arpack', n_comps=n_pcs)
sc.pp.neighbors(adata_seq, use_rep='X_pca', n_neighbors=15, n_pcs=n_pcs)

partition_subclasses(adata_seq, n_partitions)

#Merge partitions with too few cells
partition_corr_df = calculate_correlation_matrix(adata_seq, 'integration_partition')
p_names, p_n_cells = np.unique(adata_seq.obs['integration_partition'], return_counts=True)
p_smalls = p_names[p_n_cells < 10000]
p_larges = p_names[p_n_cells >= 10000]
partition_corr_df = partition_corr_df.loc[p_smalls, p_larges]

for p1 in p_smalls:

    p2 = partition_corr_df.columns[np.argmax(partition_corr_df.loc[p1])]
    
    print(f'Merge partitions {p1} and {p2}')
    merge_clusters(adata_seq.obs, 'integration_partition', [p1, p2])
    
seq_annotation_df = adata_seq.obs.copy()

In [None]:
%%time
# Reload the sequencing data and assign the partitions
adata_seq = sc.read_h5ad(input_seq_file)
adata_seq.obs['integration_partition'] = seq_annotation_df['integration_partition']

adata_seq.obs['subclass_label'] = adata_seq.obs['subclass_label'].astype('category')
adata_seq.obs['cl'] = adata_seq.obs['cl'].astype('category')

adata_seq = adata_seq[~adata_seq.obs['cl'].isna()]
adata_seq = adata_seq[~adata_seq.obs['subclass_label'].isna()]

In [None]:
# Persist the per-cell annotations, including the integration partitions.
seq_annotation_df.to_csv(path_or_buf='seq_annotation_df.csv')

In [None]:
# Split the sequencing data for the input of the final round of integration.
# Because imputation is done at this round, all genes are included.
partitions = adata_seq.obs['integration_partition'].cat.categories

print('Partition, N_cells, subclasses')
for pn in partitions:
    adata_subset = adata_seq[adata_seq.obs['integration_partition'] == pn]
    # Replace '/' and ' ' so the label is usable as a directory name.
    p = os.path.join(partition_path, pn.replace('/', '-').replace(' ', '_'))
    os.makedirs(p, exist_ok=True)
    adata_subset.write_h5ad(os.path.join(p, 'adata_seq.h5ad'), compression='gzip')
    
    print(f'{pn}, {adata_subset.shape[0]}, {list(np.unique(adata_subset.obs["subclass_label"]))}')

In [None]:
# Release the in-memory datasets before reading the partition files back.
del adata_seq
del adata_merfish

In [None]:
# Re-assemble all partitions, restricted to the genes shared with MERFISH.
adata_seq_cg_list = []
# sorted() makes the concatenated cell order deterministic; os.listdir order is arbitrary.
for pn in sorted(os.listdir(partition_path)):
    partition_file = os.path.join(partition_path, pn, 'adata_seq.h5ad')
    if not os.path.isfile(partition_file):
        # e.g. hidden files or leftover directories without a partition file
        print(f'Skipping {pn}: no adata_seq.h5ad')
        continue
    print(pn)
    adata_seq = sc.read_h5ad(partition_file)
    adata_seq = adata_seq[:, adata_seq.var.index.isin(common_genes)].copy()
    adata_seq_cg_list.append(adata_seq)
    
adata_seq_cg = anndata.concat(adata_seq_cg_list)
adata_seq_cg.write_h5ad(os.path.join(workspace_path, 'adata_seq_common_genes.h5ad'), compression='gzip')