In [None]:
import pandas as pd
import scanpy as sc
import anndata
import scanorama
import scipy.sparse as ss
import matplotlib.pyplot as plt
import seaborn as sns
from ALLCools.plot import *
import pathlib
from cemba_data.tools.integration.utilities import calculate_direct_confusion

%matplotlib inline

## Parameters

In [None]:
# Input AnnData files: mC cluster markers and ATAC pseudo-cell cluster markers.
mc_adata_path = 'Markers/mc.cluster_markers.h5ad'
atac_adata_path = 'Markers/atac.pseudo_cell.cluster_markers.h5ad'

# Scanorama settings, forwarded to scanorama.correct_scanpy as dimred / sigma / alpha / knn.
scanorama_dim, sigma, alpha, knn = 30, 100, 0, 20

# Clustering / embedding settings.
k = 30       # n_neighbors for sc.pp.neighbors
n_pcs = 20   # number of PCs used for the neighbor graph
n_jobs = 40  # worker count passed to sc.tl.tsne

In [None]:
# All integration outputs are written under this directory (created if it does not exist).
output_dir = pathlib.Path('Integration')
output_dir.mkdir(exist_ok=True)

In [None]:
# Category -> value lookup tables (presumably colors, given the *_palette names) for the
# annotation columns. Each CSV has no header; column 0 is the key, column 1 the value.
palette_dir = pathlib.Path('/home/hanliu/project/mouse_rostral_brain/metadata/palette')


def _read_palette(name):
    """Read ``<palette_dir>/<name>.palette.csv`` into a ``{key: value}`` dict.

    ``pd.read_csv(..., squeeze=True)`` was deprecated in pandas 1.4 and removed in 2.0;
    ``DataFrame.squeeze('columns')`` returns the same single-column Series on every version.
    """
    palette_df = pd.read_csv(palette_dir / f'{name}.palette.csv', header=None, index_col=0)
    return palette_df.squeeze('columns').to_dict()


region_palette = _read_palette('dissection_region')
sub_region_palette = _read_palette('sub_region')
major_region_palette = _read_palette('major_region')

cell_class_palette = _read_palette('cell_class')
major_type_palette = _read_palette('major_type')
sub_type_palette = _read_palette('sub_type')

## Load Data

### mC

In [None]:
# Load the mC cluster-marker AnnData and keep an independent copy of its obs table,
# which later cells extend with integration coordinates and co-cluster labels.
mc_adata = anndata.read_h5ad(filename=mc_adata_path)
mc_cell_tidy_data = mc_adata.obs.copy(deep=True)

In [None]:
# Display a summary of the loaded mC AnnData.
mc_adata

### ATAC

In [None]:
# Load the ATAC pseudo-cell marker AnnData and keep an independent copy of its obs table.
atac_adata = anndata.read_h5ad(filename=atac_adata_path)
atac_cell_tidy_data = atac_adata.obs.copy(deep=True)

In [None]:
# (rows, columns) of the ATAC pseudo-cell obs table.
atac_cell_tidy_data.shape

In [None]:
# Average the ATAC pseudo-cells of each SubType into one profile per cluster.
# Densify only when X is actually sparse: ``.todense()`` does not exist on a dense
# ndarray, and on sparse input it returns an np.matrix rather than an ndarray.
atac_X = atac_adata.X
if ss.issparse(atac_X):
    atac_X = atac_X.toarray()
total_df = pd.DataFrame(atac_X,
                        index=atac_adata.obs_names,
                        columns=atac_adata.var_names)
# observed=True: if SubType is categorical, unused categories would otherwise become
# all-NaN rows in the cluster centers.
cluster_center = total_df.groupby(atac_adata.obs['SubType'], observed=True).mean()

In [None]:
# Replace the pseudo-cell AnnData with one observation per SubType cluster center.
atac_adata = anndata.AnnData(
    cluster_center.to_numpy(),
    obs=pd.DataFrame(index=cluster_center.index),
    var=pd.DataFrame(index=cluster_center.columns),
)
atac_adata

## Preprocess

### Shared genes (intersection of mC and ATAC genes)

In [None]:
# Genes measured in both modalities. ``Index.intersection`` is used instead of ``&``:
# since pandas 2.0, ``&`` between two Index objects is an element-wise logical AND,
# not a set intersection. (The name ``union`` is kept because later cells use it.)
union = atac_adata.var_names.intersection(mc_adata.var_names)

In [None]:
# Report how many genes are shared between the two modalities.
print(f'{union.size} genes in common')

### Subset ATAC to shared genes and scale

In [None]:
# Restrict ATAC to the genes shared with mC; .copy() turns the AnnData view into a standalone object.
atac_adata = atac_adata[:, union].copy()

In [None]:
# Scale each gene of the ATAC cluster centers in place (zero_center=True is scanpy's default).
sc.pp.scale(atac_adata, zero_center=True)

### Reverse mC and scale

In [None]:
# Restrict mC to the shared genes, flip the values (max - x) so that larger numbers mean
# lower mC (presumably to match the direction of ATAC accessibility; TODO confirm), then
# scale each gene the same way as the ATAC data.
mc_adata = mc_adata[:, union].copy()
mc_X = mc_adata.X
if ss.issparse(mc_X):
    # scipy does not support "scalar - sparse matrix"; densify before flipping.
    mc_X = mc_X.toarray()
mc_adata.X = mc_X.max() - mc_X
sc.pp.scale(mc_adata)

## Integration

In [None]:
# Joint Scanorama correction of the two modalities over the shared genes.
# (scanorama is already imported in the top import cell; no re-import here.)
results = scanorama.correct_scanpy([mc_adata, atac_adata],
                                   metric='angular',
                                   dimred=scanorama_dim,
                                   sigma=sigma,
                                   alpha=alpha,
                                   knn=knn)
# Results come back in input order: mC first, then ATAC.
mc_adata, atac_adata = results

In [None]:
# Tag each modality's observations so they can be told apart after merging.
for modality_adata, modality_label in ((mc_adata, 'mC'), (atac_adata, 'atac')):
    modality_adata.obs['Modality'] = modality_label

In [None]:
# Stack the corrected matrices (mC rows first, then ATAC) into one AnnData over the
# shared genes, and carry over each row's modality label.
combined_obs_names = pd.Index(mc_adata.obs_names.tolist() + atac_adata.obs_names.tolist())
stacked_X = ss.vstack([mc_adata.X, atac_adata.X])
adata = anndata.AnnData(X=stacked_X,
                        obs=pd.DataFrame([], index=combined_obs_names),
                        var=pd.DataFrame([], index=union))
modality_labels = mc_adata.obs['Modality'].tolist()
modality_labels.extend(atac_adata.obs['Modality'].tolist())
adata.obs['Modality'] = modality_labels

## Clustering Routine

In [None]:
# PCA on the combined, corrected matrix with scanpy defaults; the result lands in adata.obsm['X_pca'].
sc.tl.pca(adata)

In [None]:
# Scatter consecutive PC pairs (1,2), (3,4), ..., (39,40), colored by modality,
# presumably to inspect how well the two modalities overlap after correction.
pc_pairs = [f'{first_pc},{first_pc + 1}' for first_pc in range(1, 40, 2)]
sc.pl.pca(adata, components=pc_pairs, color='Modality')

In [None]:
# kNN graph on the first n_pcs PCs, Leiden co-clusters on that graph, and 2-D embeddings.
sc.pp.neighbors(adata, n_neighbors=k, n_pcs=n_pcs)
sc.tl.leiden(adata)
sc.tl.umap(adata)
# Pass the same n_pcs as the neighbor graph. Without it, t-SNE uses all computed PCs,
# so it would be built from a different representation than UMAP/Leiden.
sc.tl.tsne(adata, n_pcs=n_pcs, n_jobs=n_jobs)

In [None]:
# Gather the first two coordinates of each embedding into one tidy table (cell_tidy_data)
# and copy them onto the mC table as agg<coord>_0 / agg<coord>_1 (aligned by cell name).
# No ATAC columns are written here: the ATAC table is rebuilt from cell_tidy_data in the
# next cell, and its pseudo-cell index from loading does not match adata's per-SubType rows.
records = [adata.obs]
for coord in ['umap', 'pca', 'tsne']:
    coord_df = pd.DataFrame(adata.obsm[f'X_{coord}'][:, :2],
                            index=adata.obs_names,
                            columns=[f'{coord}_0', f'{coord}_1'])
    # Index-aligned assignment: only mC cells present in adata receive values.
    mc_cell_tidy_data[f'agg{coord}_0'] = coord_df[f'{coord}_0']
    mc_cell_tidy_data[f'agg{coord}_1'] = coord_df[f'{coord}_1']
    records.append(coord_df)
cell_tidy_data = pd.concat(records, axis=1)

In [None]:
# ATAC rows of the integrated table. Their row names are the SubType cluster-center
# names, so the index is reused as the SubType label.
is_atac_row = cell_tidy_data['Modality'] == 'atac'
atac_cell_tidy_data = cell_tidy_data.loc[is_atac_row].copy()
atac_cell_tidy_data['SubType'] = atac_cell_tidy_data.index

In [None]:
# Copy the Leiden co-cluster label onto both per-modality tables (index-aligned).
leiden_labels = adata.obs['leiden']
for tidy_table in (mc_cell_tidy_data, atac_cell_tidy_data):
    tidy_table['co_cluster'] = leiden_labels

## Plot

In [None]:
fig, ax = plt.subplots(figsize=(6, 6), dpi=300)

# mC cells on the integrated UMAP, colored and labeled by MajorType.
hue_col = 'MajorType'
categorical_scatter(ax=ax,
                    data=mc_cell_tidy_data,
                    coord_base='aggumap',
                    hue=hue_col,
                    text_anno=hue_col,
                    max_points=None,
                    scatter_kws={'s': 5},
                    text_anno_kws={'fontsize': 4});

In [None]:
fig, ax = plt.subplots(figsize=(6, 6), dpi=300)

# mC cells on the integrated UMAP, colored and labeled by SubType.
hue_col = 'SubType'
categorical_scatter(ax=ax,
                    data=mc_cell_tidy_data,
                    coord_base='aggumap',
                    hue=hue_col,
                    text_anno=hue_col,
                    max_points=None,
                    scatter_kws={'s': 5},
                    text_anno_kws={'fontsize': 4});

In [None]:
fig, ax = plt.subplots(figsize=(6, 6), dpi=300)

# mC cells on the integrated UMAP, colored and labeled by MajorRegion.
hue_col = 'MajorRegion'
categorical_scatter(ax=ax,
                    data=mc_cell_tidy_data,
                    coord_base='aggumap',
                    hue=hue_col,
                    text_anno=hue_col,
                    max_points=None,
                    scatter_kws={'s': 5},
                    text_anno_kws={'fontsize': 4});

In [None]:
fig, ax = plt.subplots(figsize=(6, 6), dpi=300)

# ATAC SubType centers on the integrated UMAP, colored by SubType (no text labels).
categorical_scatter(ax=ax,
                    data=atac_cell_tidy_data,
                    coord_base='umap',
                    hue='SubType',
                    max_points=None,
                    scatter_kws={'s': 5},
                    text_anno_kws={'fontsize': 4});

In [None]:
fig, ax = plt.subplots(figsize=(6, 6), dpi=300)

# Overlay both modalities on the shared UMAP: mC cells first (small, orange), then the
# ATAC SubType centers on top (larger, steelblue, labeled).
# NOTE(review): a fixed `color` in scatter_kws combined with `hue` may conflict depending
# on how categorical_scatter forwards these to matplotlib; confirm which one wins.
mc_points = cell_tidy_data[cell_tidy_data['Modality'] == 'mC'].copy()
mc_points['SubType'] = mc_cell_tidy_data['SubType']
categorical_scatter(ax=ax,
                    data=mc_points,
                    hue='SubType',
                    coord_base='umap',
                    max_points=None,
                    scatter_kws=dict(color='orange'),
                    s=5)

atac_points = cell_tidy_data[cell_tidy_data['Modality'] == 'atac'].copy()
# Keep only the part of each ATAC SubType name before the first '+'.
atac_points['SubType'] = atac_cell_tidy_data['SubType'].str.split('+').str[0]
categorical_scatter(ax=ax,
                    data=atac_points,
                    hue='SubType',
                    text_anno='SubType',
                    coord_base='umap',
                    max_points=None,
                    scatter_kws=dict(color='steelblue'),
                    s=20);

In [None]:
# Keep only the part of each ATAC SubType name before the first '+'.
atac_cell_tidy_data['SubType'] = atac_cell_tidy_data['SubType'].str.split('+').str[0]

## Confusion matrix

In [None]:
# Confusion table from cemba_data's calculate_direct_confusion, fed with the SubType and
# co_cluster columns of the mC and ATAC tables.
confusion_cols = ['SubType', 'co_cluster']
mc_labels = mc_cell_tidy_data[confusion_cols]
atac_labels = atac_cell_tidy_data[confusion_cols]
cfm = calculate_direct_confusion(mc_labels, atac_labels)

In [None]:
fig, ax = plt.subplots(figsize=(20, 10), dpi=300)
sns.heatmap(cfm, cbar=None, ax=ax)
# NOTE(review): ylim (0, n_rows) puts the first row at the bottom, whereas seaborn draws
# heatmaps top-down, i.e. ylim (n_rows, 0). Confirm the flipped orientation is intended.
ax.set_ylim(0, cfm.shape[0])


## Save

In [None]:
# Persist the integrated AnnData (both modalities, embeddings, Leiden labels).
integrated_h5ad_path = output_dir / 'Integration.h5ad'
adata.write_h5ad(integrated_h5ad_path)

In [None]:
# Save the integrated per-row table (adata.obs plus 2-D umap/pca/tsne coordinates).
# DataFrame.to_msgpack was removed in pandas 1.0; on such versions, fall back to a
# pickle with the same stem and a .pkl suffix so the cell still runs.
tidy_path = output_dir / 'Integration_cell_tidy_data.msg'
if hasattr(cell_tidy_data, 'to_msgpack'):
    cell_tidy_data.to_msgpack(tidy_path)
else:
    cell_tidy_data.to_pickle(tidy_path.with_suffix('.pkl'))
cell_tidy_data.head()

In [None]:
# Save the per-modality tables annotated with co-cluster labels and embedding coordinates.
# DataFrame.to_msgpack was removed in pandas 1.0; on such versions, fall back to a
# pickle with the same stem and a .pkl suffix.
for tidy_df, tidy_name in ((mc_cell_tidy_data, 'mc_cell_tidy_data'),
                           (atac_cell_tidy_data, 'atac_cell_tidy_data')):
    tidy_path = output_dir / f'{tidy_name}.with_integration_info.msg'
    if hasattr(tidy_df, 'to_msgpack'):
        tidy_df.to_msgpack(tidy_path)
    else:
        tidy_df.to_pickle(tidy_path.with_suffix('.pkl'))