In [1]:
import gc
import importlib
import os
from pathlib import Path

import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc
import scvelo as scv
from sklearn.metrics import pairwise_distances
from sklearn.model_selection import StratifiedKFold

import utils

importlib.reload(utils)

In [2]:
# Configuration. DATA_DIR defaults to the original machine's location but can be
# overridden with the NOTEBOOK_DATA_DIR environment variable, so the notebook
# runs elsewhere without editing code.
DATA_DIR = Path(os.environ.get("NOTEBOOK_DATA_DIR", "/root/autodl-tmp/dataset"))
DATASET = "repro_mouse_schiebinger"
K_FOLD = 20  # n_splits passed to utils.split_anndata_stratified
CLUSTER_KEY = 'day'  # cluster_key passed to utils.split_anndata_stratified
SEED = 1234  # random_state passed to sc.tl.umap

In [3]:
# Whether processed AnnData objects are written under <DATASET>/processed.
SAVE_DATA = True
# raw/ holds the inputs and intermediate merges; processed/ is only needed when saving.
(DATA_DIR / DATASET / "raw").mkdir(exist_ok=True, parents=True)
if SAVE_DATA:
    (DATA_DIR / DATASET / "processed").mkdir(exist_ok=True, parents=True)

## Load the full dataset

In [4]:
adata = sc.read_h5ad(DATA_DIR / DATASET / "raw" / f"{DATASET}.h5ad")
# Keep only cells that carry a 'day' annotation.
adata = adata[adata.obs['day'].notna(), :]

In [5]:
# Read one loom per run listed in repro_name.csv and rename its cells so they
# match the obs_names of the annotated AnnData ("<old_name>_<barcode>-1").
names = pd.read_csv(DATA_DIR / DATASET / "raw" / "repro_name.csv")
loom_list = []
for old_name, new_name in zip(names['old_name'], names['new_name']):
    ldata = sc.read(DATA_DIR / DATASET / "raw" / f"run_{new_name}.loom")
    # Take the text after ':' and drop its final character
    # (presumably velocyto's trailing suffix on barcodes — TODO confirm).
    ldata.obs_names = [
        f"{old_name}_{cell_name.split(':')[1][:-1]}-1"
        for cell_name in ldata.obs_names
    ]
    ldata.var_names_make_unique()
    loom_list.append(ldata)

In [None]:
# One key per loom, in loom_list order; stored in the `batch` obs column.
batch_key = names['old_name'].values
loom_merged = ad.concat(
    loom_list,
    axis=0,
    join="outer",
    label="batch",
    keys=batch_key,
)
loom_merged.write_h5ad(DATA_DIR / DATASET / "raw" / "loom_merged.h5ad")

In [None]:
loom_merged = sc.read_h5ad(DATA_DIR / DATASET / "raw" / "loom_merged.h5ad")
# Restrict both objects to the cells and genes they have in common, using the
# same name lists for each so their rows/columns line up.
shared_obs_names = loom_merged.obs_names.intersection(adata.obs_names)
shared_var_names = loom_merged.var_names.intersection(adata.var_names)
adata = adata[shared_obs_names, shared_var_names].copy()
loom_sub = loom_merged[shared_obs_names, shared_var_names].copy()

In [None]:
# Copy the velocyto count layers onto the annotated AnnData. Layer assignment
# only validates shape, not cell/gene names, so confirm the two objects are
# aligned first (both were subset with the same obs/var name lists).
assert loom_sub.obs_names.equals(adata.obs_names), "cell order differs between loom_sub and adata"
assert loom_sub.var_names.equals(adata.var_names), "gene order differs between loom_sub and adata"
adata.layers['spliced'] = loom_sub.layers['spliced']
adata.layers['unspliced'] = loom_sub.layers['unspliced']

In [None]:
# Materialize the serum-only subset. Without .copy() this is an AnnData view
# that reads X/layers lazily from `adata`; the preprocessing cell below
# modifies `adata` in place (gene filtering, normalization), which would leak
# into, or break, this subset and any folds split from it.
# NOTE(review): compares against the *string* 'True'; if obs['serum'] is a
# boolean column this selects nothing, hence the assertion.
adata_serum = adata[adata.obs['serum'] == 'True'].copy()
assert adata_serum.n_obs > 0, "no cells matched obs['serum'] == 'True'"
adata_serum.write(DATA_DIR / DATASET / "raw" / "merged_only_serum.h5ad")

## Split the data

In [None]:
# Split the serum cells into K_FOLD subsets (presumably stratified by
# CLUSTER_KEY, per the helper's name and arguments — see utils).
sub_adata_lst = utils.split_anndata_stratified(
    adata_serum,
    n_splits=K_FOLD,
    cluster_key=CLUSTER_KEY,
)

## Preprocess the full dataset and each fold

In [None]:
# NOTE(review): the "full" object below is `adata` (every cell with a 'day'
# annotation, not only serum cells), while the folds come from `adata_serum`.
# Confirm the full reference is meant to include the non-serum cells.
N_PROCESSED_FOLDS = 5  # only the first few of the K_FOLD splits are preprocessed
N_TOP_GENES = 2000


def _row_sums(matrix):
    """Return per-row totals of a dense or scipy-sparse 2-D matrix as a 1-D array.

    Summing directly avoids ``.toarray()``, which would allocate a dense
    n_obs x n_vars copy just to add up each row.
    """
    return np.asarray(matrix.sum(axis=1)).ravel()


def _preprocess_velocity(adata_in, n_top_genes=N_TOP_GENES, n_neighbors=30, n_pcs=30):
    """Run the shared scVelo preprocessing on ``adata_in`` in place and return it.

    Stores the unnormalized counts under ``raw_spliced``/``raw_unspliced``,
    filters/normalizes with ``scv.pp.filter_and_normalize``, drops stale
    PCA/neighbor results, recomputes moments and the UMAP (seeded with SEED),
    and records raw library sizes in ``obs['u_lib_size_raw']`` and
    ``obs['s_lib_size_raw']``.
    """
    # NOTE(review): these reference the same matrix objects as 'spliced' /
    # 'unspliced' at this point; they stay raw only as long as
    # filter_and_normalize replaces (rather than mutates) those layers.
    adata_in.layers['raw_spliced'] = adata_in.layers['spliced']
    adata_in.layers['raw_unspliced'] = adata_in.layers['unspliced']
    scv.pp.filter_and_normalize(adata_in, min_shared_counts=20, n_top_genes=n_top_genes)
    if adata_in.n_vars < n_top_genes:
        sc.pp.highly_variable_genes(adata_in, n_top_genes=adata_in.n_vars, subset=True)
    # Remove any PCA / neighbor graph so scv.pp.moments recomputes them on the
    # current genes. Each key is checked on its own so a missing entry cannot
    # raise KeyError.
    if 'X_pca' in adata_in.obsm:
        del adata_in.obsm['X_pca']
    if 'pca' in adata_in.uns:
        del adata_in.uns['pca']
    if 'neighbors' in adata_in.uns:
        del adata_in.uns['neighbors']
    scv.pp.moments(adata_in, n_neighbors=n_neighbors, n_pcs=n_pcs)
    utils.fill_in_neighbors_indices(adata_in)
    sc.tl.umap(adata_in, random_state=SEED)
    adata_in.obs['u_lib_size_raw'] = _row_sums(adata_in.layers['raw_unspliced'])
    adata_in.obs['s_lib_size_raw'] = _row_sums(adata_in.layers['raw_spliced'])
    return adata_in


_preprocess_velocity(adata)
if SAVE_DATA:
    adata.write_h5ad(DATA_DIR / DATASET / "processed" / "adata_preprocessed_full.h5ad")
del adata
gc.collect()

for i in range(min(N_PROCESSED_FOLDS, len(sub_adata_lst))):
    sub_adata = _preprocess_velocity(sub_adata_lst[i].copy())
    if SAVE_DATA:
        sub_adata.write_h5ad(DATA_DIR / DATASET / "processed" / f"adata_preprocessed_{i}.h5ad")
    del sub_adata
    gc.collect()

In [None]:
# Inspect one preprocessed fold. Use an explicit index rather than the loop
# variable `i` left over from the previous cell, so this cell also works after
# a kernel restart. 4 is the last fold written by the preprocessing loop.
PLOT_FOLD = 4
adata = sc.read_h5ad(DATA_DIR / DATASET / "processed" / f"adata_preprocessed_{PLOT_FOLD}.h5ad")
sc.pl.umap(adata, color=[CLUSTER_KEY])