# In [1]:
from typing import List, Tuple, Union
import os

import scanpy as sc
import anndata as ad

import torch
from torch_geometric.data import Dataset

# In [None]:

class SingleCellDataset(Dataset):
    """Single-cell dataset backed by the ``.h5ad`` files in a directory.

    Every ``*.h5ad`` file found directly inside ``root`` is one sample.
    Processing happens when the dataset is constructed and the processed
    copies are missing. Each raw file is read with ``sc.read_h5ad``. It is
    then optionally filtered and pre-transformed, and written under the same
    file name to ``self.processed_dir``. Indexing the dataset reads a
    processed file back as an AnnData object.

    Args:
        root: directory that directly contains the raw ``.h5ad`` files.
        transform: callable applied to each AnnData object on access.
        pre_transform: callable applied once to each AnnData object (one per
            file) before it is saved. When ``None``,
            ``default_pre_transform`` is used.
        pre_filter: predicate on an AnnData object. Files for which it
            returns a falsy value are not saved and are excluded from the
            dataset.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        # The base constructor already reads raw/processed_file_names and
        # calls process(). The file list must therefore exist before
        # super().__init__, and process() must not be called again here.
        # Sorting gives a stable index -> file mapping; os.listdir order is
        # arbitrary.
        self.filenames = sorted(
            name
            for name in os.listdir(os.path.expanduser(root))
            if name.endswith('.h5ad')
        )
        if pre_transform is None:
            # Use a plain function (staticmethod), not a bound method. A
            # bound method's str() embeds repr(self), and Dataset.__repr__
            # calls len(), which only works once __init__ has finished.
            pre_transform = self.default_pre_transform
        super().__init__(root, transform, pre_transform, pre_filter)
        # Files rejected by pre_filter have no processed copy, so only the
        # processed files that actually exist are indexed.
        # NOTE: when pre_filter rejects any file, the processed set is never
        # complete, and the base class re-runs process() on every
        # construction.
        self._processed = [p for p in self.processed_paths if os.path.exists(p)]

    @property
    def raw_dir(self):
        # Raw files live directly in ``root``, which is where __init__ lists
        # them, not in PyG's default ``<root>/raw``.
        return self.root

    @property
    def raw_file_names(self):
        return self.filenames

    @property
    def processed_file_names(self):
        # One processed copy per raw file, under the same file name.
        return self.filenames

    def process(self):
        """Read, filter, pre-transform and save every raw ``.h5ad`` file."""
        for raw_path, processed_path in zip(self.raw_paths, self.processed_paths):
            adata = sc.read_h5ad(raw_path)

            if self.pre_filter is not None and not self.pre_filter(adata):
                continue

            if self.pre_transform is not None:
                adata = self.pre_transform(adata)

            adata.write_h5ad(processed_path)

    def len(self):
        """Return the number of processed (kept) samples."""
        return len(self._processed)

    def get(self, idx):
        """Load the ``idx``-th processed sample as an AnnData object."""
        return sc.read_h5ad(self._processed[idx])

    @staticmethod
    def default_pre_transform(adata):
        """Normalize ``adata`` in place and return it.

        Steps:
        1. Drop genes whose total count is below 1, i.e. genes with zero
           expression.
        2. Apply log1p.
        3. Scale.
        """
        sc.pp.filter_genes(adata, min_counts=1)
        sc.pp.log1p(adata)
        sc.pp.scale(adata)
        return adata