In [1]:
import json
import logging
import os
import re
from pathlib import Path

import anndata
import numpy as np
import pandas as pd
import scanpy as sc

In [3]:
import pollock.utils as utils

In [27]:
# Root directory for the generated folds: the fold-writing loop further down
# creates one sub-directory per data type here and writes the per-fold
# train/val .h5ad files into it.
out_dir = '/data/pollock/benchmarking/pollock_datasets_with_folds'

In [19]:
# Build `fmap`: {data type -> {disease -> path}} for every full (unsplit)
# dataset found under `data_dir`. Paths are expected to look like
# <data_dir>/<data type>/<file>.h5ad, so the data type is the parent folder name.
data_dir = '/data/pollock/benchmarking/pollock_datasets/'

# Paths containing 'train' or 'val' anywhere are dropped.
# NOTE(review): if `listfiles` treats `regex` as a regular expression, the '.'
# is unescaped and matches any character before 'h5ad' — confirm this is intended.
fps = [
    fp
    for fp in utils.listfiles(data_dir, regex=r'.h5ad')
    if not ('train' in fp or 'val' in fp)
]

fmap = {}
for fp in fps:
    parts = fp.split('/')
    dtype, root = parts[-2], parts[-1]

    # File names are either "<disease>.h5ad" or "<disease>_<suffix>.h5ad":
    # keep the text before the first '_', or before the first '.' if there is none.
    disease = root.split('_')[0] if '_' in root else root.split('.')[0]

    # The data type is registered even when the file itself is skipped below.
    per_disease = fmap.setdefault(dtype, {})

    # For snATACseq, files whose path mentions 'peaks' or 'motif' are not recorded.
    if dtype == 'snATACseq' and ('peaks' in fp or 'motif' in fp):
        continue
    per_disease[disease] = fp
fmap

{'snRNAseq': {'ccrcc': '/data/pollock/benchmarking/pollock_datasets/snRNAseq/ccrcc.h5ad',
  'brca': '/data/pollock/benchmarking/pollock_datasets/snRNAseq/brca.h5ad',
  'gbm': '/data/pollock/benchmarking/pollock_datasets/snRNAseq/gbm.h5ad'},
 'snATACseq': {'brca': '/data/pollock/benchmarking/pollock_datasets/snATACseq/brca_gene_activity.h5ad',
  'gbm': '/data/pollock/benchmarking/pollock_datasets/snATACseq/gbm_gene_activity.h5ad',
  'ccrcc': '/data/pollock/benchmarking/pollock_datasets/snATACseq/ccrcc_gene_activity.h5ad'},
 'scRNAseq': {'cesc': '/data/pollock/benchmarking/pollock_datasets/scRNAseq/cesc.h5ad',
  'melanoma': '/data/pollock/benchmarking/pollock_datasets/scRNAseq/melanoma.h5ad',
  'brca': '/data/pollock/benchmarking/pollock_datasets/scRNAseq/brca.h5ad',
  'hnscc': '/data/pollock/benchmarking/pollock_datasets/scRNAseq/hnscc.h5ad',
  'myeloma': '/data/pollock/benchmarking/pollock_datasets/scRNAseq/myeloma.h5ad',
  'pdac': '/data/pollock/benchmarking/pollock_datasets/scRNAseq/

In [29]:
# Write N_FOLDS independent train/validation resamples ("folds") of every
# dataset in `fmap` to
#     <out_dir>/<dtype>/<disease>_fold<i>_train.h5ad
#     <out_dir>/<dtype>/<disease>_fold<i>_val.h5ad
#
# The splits are random, so the RNG is seeded per fold; without a seed a
# Restart & Run All would silently overwrite the files with different cells.
BASE_SEED = 0  # arbitrary; any fixed integer makes the folds repeatable
N_FOLDS = 5

for dtype, d in fmap.items():
    for disease, fp in d.items():
        a = sc.read_h5ad(fp)
        out = os.path.join(out_dir, dtype)
        Path(out).mkdir(parents=True, exist_ok=True)
        for i in range(N_FOLDS):
            print(dtype, disease, i)
            # A distinct seed per fold keeps the folds different from one another
            # while making fold i reproducible regardless of dataset order.
            # NOTE(review): this only pins the result if utils.get_splits draws
            # from NumPy's global RNG — confirm; otherwise seed the RNG it uses.
            np.random.seed(BASE_SEED + i)
            # First call yields the training ids plus the ids left over (`rest`);
            # the second call picks validation ids only from those leftovers, so
            # the two sets come from disjoint pools.
            train_ids, rest = utils.get_splits(a, 'cell_type', 500, oversample=False, split=.8)
            val_ids, _ = utils.get_splits(a[rest], 'cell_type', 500, oversample=False, split=1.)
            train, val = a[train_ids], a[val_ids]
            train.write_h5ad(os.path.join(out, f'{disease}_fold{i}_train.h5ad'))
            val.write_h5ad(os.path.join(out, f'{disease}_fold{i}_val.h5ad'))

snRNAseq ccrcc 0
snRNAseq ccrcc 1
snRNAseq ccrcc 2
snRNAseq ccrcc 3
snRNAseq ccrcc 4
snRNAseq brca 0
snRNAseq brca 1
snRNAseq brca 2
snRNAseq brca 3
snRNAseq brca 4
snRNAseq gbm 0
snRNAseq gbm 1
snRNAseq gbm 2
snRNAseq gbm 3
snRNAseq gbm 4
snATACseq brca 0
snATACseq brca 1
snATACseq brca 2
snATACseq brca 3
snATACseq brca 4
snATACseq gbm 0
snATACseq gbm 1
snATACseq gbm 2
snATACseq gbm 3
snATACseq gbm 4
snATACseq ccrcc 0
snATACseq ccrcc 1
snATACseq ccrcc 2
snATACseq ccrcc 3
snATACseq ccrcc 4
scRNAseq cesc 0
scRNAseq cesc 1
scRNAseq cesc 2
scRNAseq cesc 3
scRNAseq cesc 4
scRNAseq melanoma 0
scRNAseq melanoma 1
scRNAseq melanoma 2
scRNAseq melanoma 3
scRNAseq melanoma 4
scRNAseq brca 0
scRNAseq brca 1
scRNAseq brca 2
scRNAseq brca 3
scRNAseq brca 4
scRNAseq hnscc 0
scRNAseq hnscc 1
scRNAseq hnscc 2
scRNAseq hnscc 3
scRNAseq hnscc 4
scRNAseq myeloma 0
scRNAseq myeloma 1
scRNAseq myeloma 2
scRNAseq myeloma 3
scRNAseq myeloma 4
scRNAseq pdac 0
scRNAseq pdac 1
scRNAseq pdac 2
scRNAseq pdac 3
s

In [20]:
# Load one dataset (scRNAseq / brca) to spot-check the splitting logic by hand.
brca_fp = fmap['scRNAseq']['brca']
adata = sc.read_h5ad(brca_fp)
adata

AnnData object with n_obs × n_vars = 98564 × 27131
    obs: 'cell_type', 'barcode', 'sample'

In [24]:
# Repeat the two-stage split used when writing the folds, on `adata` only:
# stage 1 returns the training ids plus the leftover ids (`rest`);
# stage 2 selects validation ids from the leftover cells alone.
train_ids, rest = utils.get_splits(
    adata, 'cell_type', 500, oversample=False, split=.8
)
leftover = adata[rest]
val_ids, _ = utils.get_splits(
    leftover, 'cell_type', 500, oversample=False, split=1.
)

In [25]:
from collections import Counter

# Cells per cell type in the training selection, most frequent first.
train_labels = adata[train_ids].obs['cell_type']
Counter(train_labels).most_common()

[('CD8 T cell', 500),
 ('Endothelial', 500),
 ('Fibroblast', 500),
 ('Malignant', 500),
 ('NK', 500),
 ('Monocyte', 500),
 ('Treg', 500),
 ('CD4 T cell', 500),
 ('B cell', 500),
 ('Plasma', 500),
 ('Mast', 500),
 ('Dendritic', 473),
 ('Erythrocyte', 412)]

In [26]:
# Cells per cell type in the validation selection, most frequent first.
val_labels = adata[val_ids].obs['cell_type']
Counter(val_labels).most_common()

[('CD8 T cell', 500),
 ('Endothelial', 500),
 ('Fibroblast', 500),
 ('Malignant', 500),
 ('NK', 500),
 ('Monocyte', 500),
 ('CD4 T cell', 500),
 ('Treg', 500),
 ('B cell', 500),
 ('Plasma', 500),
 ('Mast', 246),
 ('Dendritic', 119),
 ('Erythrocyte', 103)]