In [2]:
import anndata
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from ALLCools.clustering import *
from ALLCools.integration.seurat_class import SeuratIntegration
from ALLCools.plot import *
from wmb import aibs, brain, cemba

In [3]:
categorical_key = ["L4Region", 'CellType',"DissectionRegion"]
ref_dataset = "mc"

## Input LSI before integration

In [4]:
if ref_dataset.lower() == 'mc':
    ref_adata = anndata.read_h5ad("mc_pca.h5ad")
    query_adata = anndata.read_h5ad("merfish_pca.h5ad")
else:
    ref_adata = anndata.read_h5ad("merfish_pca.h5ad")
    query_adata = anndata.read_h5ad("mc_pca.h5ad")

In [5]:
adata_list = [ref_adata, query_adata]

### Init empty adata_merge

In [6]:
mc_annot = cemba.get_mc_annot()

In [7]:
for _key in categorical_key:
    ref_adata.obs[_key] = mc_annot[_key].to_pandas()


In [8]:
from scipy.sparse import csr_matrix

cells = sum([a.shape[0] for a in adata_list])
features = adata_list[0].shape[1]

adata_merge = anndata.AnnData(
    X=csr_matrix((cells, features), dtype=np.float32),
    obs=pd.concat([a.obs for a in adata_list]),
    var=adata_list[0].var,
)

In [9]:
n_pc = adata_list[0].obsm["X_pca"].shape[1]
if n_pc < 10:
    n_cca_components = n_pc
else:
    n_cca_components = max(n_pc - 10, 10)

n_cca_components

40

In [10]:
n_adata_features = adata_merge.shape[1]

n_features = 60
n_features = min(int(n_adata_features * n_cca_components / 11), n_features)

In [11]:
min_sample = adata_merge.obs["Modality"].value_counts().min()

## Integration and transform

In [12]:
integrator = SeuratIntegration()

In [13]:
# take ~2.5-3h for 300K mC + 4M 10X merfish
anchor = integrator.find_anchor(
        adata_list,
        k_local=None,
        key_local="X_pca",
        k_anchor=5,
        key_anchor="X",
        dim_red="cca",
        max_cc_cells=100000,
        k_score=30,
        k_filter=min(100, min_sample),
        scale1=True,
        scale2=True,
        n_components=n_cca_components,
        n_features=n_features,
        alignments=[[[0], [1]]],
    )

Find anchors across datasets.
Run CCA
non zero dims 40
Find Anchors using k=30


  self._set_arrayXarray(i, j, x)


Anchor selected with high CC feature graph: 9934 / 23144
Score Anchors
Identified 9934 anchors between datasets 0 and 1.


In [14]:
if min_sample < 500:
    k_weight = 50
elif min_sample < 300:
    k_weight = 30
else:
    k_weight = 100

In [15]:
try:
    corrected = integrator.integrate(
        key_correct="X_pca",
        row_normalize=True,
        k_weight=k_weight,
        sd=1,
        alignments=[[[0], [1]]],
    )
except BaseException:
    for k_weight in range(50, 0, -5):
        if k_weight < 2:
            raise
        print(k_weight)
        try:
            corrected = integrator.integrate(
            key_correct="X_pca",
            row_normalize=True,
            k_weight=k_weight,
            sd=1,
            alignments=[[[0], [1]]],
            )
        except BaseException:
            pass
adata_merge.obsm["X_pca_integrate"] = np.concatenate(corrected)

Merge datasets
[[0], [1]]
Initialize
Find nearest anchors. 

  data=np.array(corrected),


k_weight:  100
Normalize graph
Transform data


## Label transfer

In [16]:
# transfer_results = integrator.label_transfer(
#     ref=[0],
#     qry=[1],
#     categorical_key=categorical_key,
#     key_dist='X_pca'
# )
# for k, v in transfer_results.items():
#     v.to_hdf(f'{k}_transfer.hdf', key='data')
# integrator.save_transfer_results_to_adata(adata_merge, transfer_results)

## Save

In [17]:
adata_merge.write_h5ad("final.h5ad")

In [18]:
adata_merge

AnnData object with n_obs × n_vars = 60198 × 448
    obs: 'blank_count', 'n_counts', 'n_genes', 'Modality', 'L4Region', 'CellType', 'DissectionRegion'
    var: 'chrom-mC', 'cov_mean-mC', 'end-mC', 'start-mC', 'cef-mC', 'id-merfish'
    obsm: 'X_pca_integrate'

In [19]:
# integrator.save("integration")

In [20]:
import subprocess
subprocess.run(['rm', '-f', 'mc_pca.h5ad', 'merfish_pca.h5ad'])

CompletedProcess(args=['rm', '-f', 'mc_pca.h5ad', 'merfish_pca.h5ad'], returncode=0)