In [1]:
from typing import Callable, Dict, List, Optional, Tuple, Union

import h5py
# не удалять! import hdf5plugin !
import hdf5plugin
import pandas as pd
import torch
from torch.utils.data import Dataset

---------------------------------------------------

In [2]:
# Load the metadata table and index it by `cell_id`.
# NOTE(review): absolute, machine-specific path — the notebook will not run elsewhere without editing it.
meta_file = '/home/mks/PycharmProjects/multimodal_single_cell_integration/dataset/metadata.csv'
df_meta = pd.read_csv(meta_file, index_col='cell_id')

In [3]:
df_meta.head(10)

Unnamed: 0_level_0,day,donor,cell_type,technology
cell_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
e0dde41ed6f2,3,27678,MasP,citeseq
25b1de7f18f6,3,27678,MkP,citeseq
59e175749a4c,3,27678,MkP,citeseq
cc43f415f240,3,27678,NeuP,citeseq
cf6cb48a1aca,3,27678,HSC,citeseq
7d03cdc2150c,3,27678,EryP,citeseq
ed27b16f6b29,3,27678,NeuP,citeseq
20a5293b5a5f,3,27678,NeuP,citeseq
9c110ee995b5,3,27678,HSC,citeseq
655fb0bf81df,3,27678,HSC,citeseq


In [15]:
# Ids of every cell whose metadata row has technology == 'citeseq', in metadata order.
ids = df_meta.index[df_meta['technology'] == 'citeseq'].tolist()
# Show how many ids were selected, then a short preview of them.
print(len(ids), ids[:10], sep='\n')

119191
['e0dde41ed6f2', '25b1de7f18f6', '59e175749a4c', 'cc43f415f240', 'cf6cb48a1aca', '7d03cdc2150c', 'ed27b16f6b29', '20a5293b5a5f', '9c110ee995b5', '655fb0bf81df']


------------------------------------------------------------------------------

In [5]:
# Load the evaluation id table and index it by `cell_id`.
# NOTE(review): `eval_ids` is not used by any later cell visible in this notebook.
eval_ids = pd.read_csv('/home/mks/PycharmProjects/multimodal_single_cell_integration/dataset/evaluation_ids.csv', index_col='cell_id')

In [6]:
# Paths to the training HDF5 files.
# Naming: m* = 'multi' files, c* = 'cite' files; *x = inputs, *y = targets.
mx = '/home/mks/PycharmProjects/multimodal_single_cell_integration/dataset/train_multi_inputs.h5'
my = '/home/mks/PycharmProjects/multimodal_single_cell_integration/dataset/train_multi_targets.h5'
cx = '/home/mks/PycharmProjects/multimodal_single_cell_integration/dataset/train_cite_inputs.h5'
cy = '/home/mks/PycharmProjects/multimodal_single_cell_integration/dataset/train_cite_targets.h5'

In [29]:
# Paths to the test HDF5 files.
# NOTE(review): despite the `_my` / `_cy` suffixes (which elsewhere mean "targets"),
# both paths point at *inputs* files.
test_my = '/home/mks/PycharmProjects/multimodal_single_cell_integration/dataset/test_multi_inputs.h5'
test_cy = '/home/mks/PycharmProjects/multimodal_single_cell_integration/dataset/test_cite_inputs.h5'

------------------------------------------------------------------------------

## Изучение структуры

In [8]:
# Open the file at `mx` read-only and collect the member names of its first top-level group.
mi_flow = h5py.File(mx, mode='r')
col_names = [member for member in mi_flow[list(mi_flow.keys())[0]]]

In [9]:
col_names

['axis0', 'axis1', 'block0_items', 'block0_values']

In [12]:
mi_flow[list(mi_flow.keys())[0]]['axis0'].shape

(228942,)

In [24]:
mi_flow[list(mi_flow.keys())[0]]['axis0'][145]

b'chr10:100653097-100653634'

In [13]:
mi_flow[list(mi_flow.keys())[0]]['axis1'].shape

(105942,)

In [17]:
mi_flow[list(mi_flow.keys())[0]]['axis1'][0]

b'56390cf1b95e'

In [14]:
mi_flow[list(mi_flow.keys())[0]]['block0_items'].shape

(228942,)

In [23]:
mi_flow[list(mi_flow.keys())[0]]['block0_items'][145]

b'chr10:100653097-100653634'

In [15]:
mi_flow[list(mi_flow.keys())[0]]['block0_values'].shape

(105942, 228942)

In [19]:
mi_flow[list(mi_flow.keys())[0]]['block0_values'][0]

array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)

In [20]:
mi_flow[list(mi_flow.keys())[0]]['block0_values'][0].shape

(228942,)

Итак, файл состоит из **названий участков генома, уникальных идентификаторов клеток и большой numpy-матрицы**, содержащей ATAC-данные.
Строк в numpy-массиве столько же, сколько уникальных id клеток; таким образом, строка матрицы является вектором фичей конкретной клетки. Мы предполагаем, что порядок строк совпадает с порядком id и выстроен без ошибок. Также у нас есть название для каждого столбца матрицы (`block0_items`), указывающее на позицию ATAC-фичи, но этот список полностью совпадает со списком позиций (`axis0`). Информация об участке ДНК нам сейчас бесполезна, однако её тоже можно обрабатывать.

**dna_pos:** mi_flow[list(mi_flow.keys())[0]]['axis0'] (список строк в байтовом виде)

**cell_id:** mi_flow[list(mi_flow.keys())[0]]['axis1'] (список строк в байтовом виде)

**atac_array:** mi_flow[list(mi_flow.keys())[0]]['block0_values']

In [25]:
list(mi_flow.keys())

['train_multi_inputs']

--------------------------------------------------------

In [33]:
# Open the file at `my` read-only and print the names of its top-level groups.
my_flow = h5py.File(my, mode='r')
print([group_name for group_name in my_flow.keys()])

['train_multi_targets']


In [36]:
col_names = list(my_flow['train_multi_targets'])

In [37]:
col_names

['axis0', 'axis1', 'block0_items', 'block0_values']

In [38]:
my_flow['train_multi_targets']['block0_items'][0]

b'ENSG00000121410'

In [41]:
my_flow['train_multi_targets']['axis0'][0]

b'ENSG00000121410'

--------------------------------------------------------

--------------------------------------------------------

# Test new dataset class

In [1]:
from typing import Callable, Dict, List, Optional, Tuple, Union
from pathlib import Path

import h5py
# не удалять! import hdf5plugin !
import hdf5plugin
import pandas as pd
import torch
from torch.utils.data import Dataset
import numpy as np

In [2]:
class FSCCDataset(Dataset):
    """Map-style dataset over the HDF5 files of one (task, mode) pair plus ``metadata.csv``.

    ``dataset_path`` must be a folder that contains ``metadata.csv`` and the files
    ``{mode}_{task}_inputs.h5`` (plus ``{mode}_{task}_targets.h5`` in ``'train'`` mode).
    Each HDF5 file must hold exactly one group whose name is listed in
    ``h5_reserved_names`` and which contains the datasets ``axis0`` (column labels),
    ``axis1`` (cell ids) and ``block0_values`` (2-D matrix, one row per cell id).

    ``dataset[i]`` returns ``(x, y, meta)`` when a targets file is loaded and
    ``(x, meta)`` otherwise. ``x`` / ``y`` are the matrix rows of the i-th selected cell
    (whatever the HDF5 read returns, passed through ``transform`` /
    ``target_transform`` when those are given). ``meta`` is a dict with the cell id and
    its metadata labels, encoded according to ``meta_transform``:

    * ``None``      -> raw values of every column in ``meta_names``;
    * ``'index'``   -> integer position of the value among the column's unique values,
      for every column in ``meta_transform_names``;
    * ``'one_hot'`` -> one-hot ``numpy`` vector over the column's unique values,
      for every column in ``meta_transform_names``.

    Column labels (``axis0``) describe matrix columns, not samples, so they are not put
    into the per-sample ``meta`` dict; use :meth:`column_names` to read them.

    NOTE(review): rows of the targets matrix are assumed to be in the same cell order
    as the rows of the inputs matrix — this is not checked here; confirm for new data.
    """

    # Kinds of files that may exist for a (task, mode) pair.
    file_types = ['inputs', 'targets']
    # Group names that a valid HDF5 file is allowed to contain.
    h5_reserved_names: List[str] = ['train_multi_inputs', 'train_multi_targets', 'train_cite_inputs',
                                    'train_cite_targets', 'test_multi_inputs', 'test_cite_inputs']

    # Template of the task -> mode -> file_type slots. Every instance gets its own copy in
    # __init__, so open file handles are never shared between instances through the class.
    dataflows = {'cite': {'train': {'inputs': None, 'targets': None},
                          'test': {'inputs': None}},
                 'multi': {'train': {'inputs': None, 'targets': None},
                           'test': {'inputs': None}}}

    metadata = None
    # Class-level default only; instances replace it with their own dict in _read_metadata.
    meta_unique_vals: Dict = {}
    metadata_file: str = 'metadata.csv'
    # Metadata columns that get encoded when meta_transform is 'index' or 'one_hot'.
    meta_transform_names: List[str] = ['day', 'donor', 'cell_type']
    # Metadata columns returned raw when meta_transform is None.
    meta_names: List[str] = ['day', 'donor', 'cell_type', 'technology']
    meta_keys: List[str] = ['cell_id', 'day', 'donor', 'cell_type', 'technology']

    col_name: str = 'axis0'
    pos_name: str = 'position'
    index_name: str = 'cell_id'
    cell_id_name: str = "axis1"
    target_name: str = 'gene_id'
    features_name: str = "block0_values"

    def __init__(self,
                 dataset_path: Union[str, Path],
                 task: str, mode: str,
                 meta_transform: Optional[str] = None,
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None):
        """Open the files of ``task`` / ``mode`` found in ``dataset_path``.

        Args:
            dataset_path: folder with ``metadata.csv`` and the ``.h5`` files.
            task: ``'cite'`` or ``'multi'``.
            mode: ``'train'`` (inputs + targets) or ``'test'`` (inputs only).
            meta_transform: ``None``, ``'index'`` or ``'one_hot'``.
            transform: optional callable applied to every input row.
            target_transform: optional callable applied to every target row.

        Raises:
            ValueError: on an unknown ``task``, ``mode`` or ``meta_transform``.
        """
        # Fail before any file is opened if the encoding name is wrong.
        if meta_transform and meta_transform not in ('index', 'one_hot'):
            raise ValueError(f"The argument 'meta_transform' can only take values from a list "
                             f"['index', 'one_hot', None], but '{meta_transform}' was found.")
        self.task = task
        self.mode = mode
        self.data_ids = None
        self.data_shapes = None
        self.dataset_path = dataset_path

        self.transform = transform
        self.target_transform = target_transform
        self.meta_transform = meta_transform

        # Per-instance state (the class attributes of the same names are only templates).
        self.dataflows = {task_name: {mode_name: dict(slots) for mode_name, slots in modes.items()}
                          for task_name, modes in type(self).dataflows.items()}
        self.meta_unique_vals = {}
        self._row_index = {}
        # init dataset
        self._read_task_dataset(dataset_path)

    def _validate_task_mode(self, task: str, mode: str) -> None:
        """Raise ``ValueError`` unless ``task`` / ``mode`` name a known slot."""
        if mode not in ('train', 'test'):
            raise ValueError(f"Argument 'mode' can only take values from a list: ['train', 'test'], "
                             f"but {mode} was found.")
        if task not in self.dataflows:
            raise ValueError(f"Argument 'task' can only take values from a list: {list(self.dataflows)}, "
                             f"but {task} was found.")

    def _read_metadata(self, path: str) -> pd.DataFrame:
        """Read the metadata CSV (indexed by cell id) and record each column's unique values."""
        df = pd.read_csv(path, index_col=self.index_name)
        # Assign a fresh dict instead of mutating, so the class-level default is never touched.
        self.meta_unique_vals = {key: list(df[key].unique()) for key in self.meta_names}

        return df

    def _transform_metalabels(self, meta_dict: Dict, cell_id: str) -> Dict:
        """Return a copy of ``meta_dict`` extended with the metadata labels of ``cell_id``.

        The encoding follows ``self.meta_transform`` (see the class docstring). Entries
        already present in ``meta_dict`` are kept in every mode.

        Raises:
            ValueError: if ``self.meta_transform`` is not ``None``, ``'index'`` or ``'one_hot'``.
        """
        labels = dict(meta_dict)

        if not self.meta_transform:
            for key in self.meta_names:
                labels[key] = self.metadata.at[cell_id, key]
            return labels

        if self.meta_transform not in ('index', 'one_hot'):
            raise ValueError(f"The argument 'meta_transform' can only take values from a list "
                             f"['index', 'one_hot', None], but '{self.meta_transform}' was found.")

        for key in self.meta_transform_names:
            known_values = self.meta_unique_vals[key]
            value_position = known_values.index(self.metadata.at[cell_id, key])
            if self.meta_transform == 'index':
                labels[key] = value_position
            else:
                one_hot_vector = np.zeros((len(known_values),))
                one_hot_vector[value_position] = 1
                labels[key] = one_hot_vector

        return labels

    @staticmethod
    def _close_flow(flow) -> None:
        """Close the HDF5 file behind ``flow``; do nothing for ``None`` or an already closed handle."""
        # h5py objects evaluate to False once their file is closed.
        h5_file = getattr(flow, 'file', None) if flow else None
        if h5_file is not None:
            h5_file.close()

    def _get_task_flow(self, folder_path: Path, mode: str, task: str, file_type: str) -> None:
        """Open ``{mode}_{task}_{file_type}.h5`` and store its group and matrix shape."""
        file_name = '_'.join([mode, task, file_type])
        print(f"[ Reading {file_name}.h5 file ... ]")
        f_path = str(folder_path.joinpath(f"{file_name}.h5").absolute())
        flow, feature_shape = self.get_hdf5_flow(f_path)
        # The slot may still hold a handle from an earlier load (e.g. after reset()); release it.
        self._close_flow(self.dataflows[task][mode].get(file_type))
        # write data in structure
        self.dataflows[task][mode][file_type] = flow
        self.data_shapes[task][mode][file_type] = feature_shape
        print(f"[ Reading {file_name}.h5 file is complete. ]")

    def _read_task_dataset(self, folder_path: Union[str, Path]) -> None:
        """Load metadata and the HDF5 files of the current task/mode, then select every cell."""
        self._validate_task_mode(self.task, self.mode)
        self.data_shapes = {self.task: {self.mode: {s: None for s in self.file_types}}}

        if isinstance(folder_path, str):
            folder_path = Path(folder_path)
        # read metadata file
        self.metadata = self._read_metadata(str(folder_path.joinpath(self.metadata_file)))
        # read all h5 files: 'train' has inputs and targets, 'test' has inputs only
        wanted_types = self.file_types if self.mode == 'train' else self.file_types[:1]
        for file_type in wanted_types:
            self._get_task_flow(folder_path, self.mode, self.task, file_type)

        self.data_ids = self._set_data_ids()

    def _set_data_ids(self):
        """Return all cell ids of the inputs file in file order and rebuild the id -> row map."""
        feature_flow = self.dataflows[self.task][self.mode]['inputs']
        # One bulk read of the id dataset instead of one read per element.
        cell_ids = [raw_id.decode("utf-8") for raw_id in feature_flow[self.cell_id_name][:]]
        # __getitem__ needs the matrix row of a cell id even after the selection was filtered.
        self._row_index = {cell_id: row for row, cell_id in enumerate(cell_ids)}
        return cell_ids

    def __len__(self):
        """Number of currently selected cells."""
        return len(self.data_ids)

    def __getitem__(self, item: int) -> Union[Tuple[torch.Tensor, Dict[str, str]],
                                              Tuple[torch.Tensor, torch.Tensor, Dict[str, str]]]:
        """Return ``(x, y, meta)`` (targets loaded) or ``(x, meta)`` for the ``item``-th selected cell.

        Raises:
            IndexError: if ``item`` is outside the current selection.
        """
        cell_id = self.data_ids[item]
        # Position in the selection and row in the matrix differ once reindex_dataset() was used.
        row = self._row_index[cell_id]
        flows = self.dataflows[self.task][self.mode]

        meta_data = self._transform_metalabels({self.index_name: cell_id}, cell_id)

        x = flows['inputs'][self.features_name][row]
        if self.transform is not None:
            x = self.transform(x)

        targets = flows.get('targets')
        if targets is None:
            return x, meta_data

        y = targets[self.features_name][row]
        if self.target_transform is not None:
            y = self.target_transform(y)

        return x, y, meta_data

    def column_names(self, file_type: str = 'inputs') -> List[str]:
        """Return the decoded column labels (``axis0``) of the loaded ``file_type`` file.

        Raises:
            ValueError: if no such file is loaded for the current task/mode.
        """
        flow = self.dataflows[self.task][self.mode].get(file_type)
        if flow is None:
            raise ValueError(f"No '{file_type}' file is loaded for task '{self.task}' in mode '{self.mode}'.")
        return [label.decode("utf-8") for label in flow[self.col_name][:]]

    def get_hdf5_flow(self, file_path: str):
        """Open ``file_path`` read-only, validate its layout and return ``(group, matrix_shape)``.

        The file stays open (the returned group reads from it lazily); it is closed again
        if validation fails.

        Raises:
            AssertionError: if the file does not have the expected single group / datasets.
        """
        file_flow = h5py.File(file_path, 'r')
        try:
            file_keys = list(file_flow.keys())
            if len(file_keys) != 1:
                raise AssertionError(f"Incorrect file format, '{file_path}' file must contain exactly one "
                                     f"group, but found: {file_keys}.")

            file_name = file_keys[0]
            if file_name not in self.h5_reserved_names:
                raise AssertionError(f"Incorrect file format, group name must be in {self.h5_reserved_names}, "
                                     f"but {file_name} was found.")

            group = file_flow[file_name]
            datasets_names = list(group)
            for required in (self.features_name, self.cell_id_name, self.col_name):
                if required not in datasets_names:
                    raise AssertionError(f"Incorrect file format, dataset name {required} "
                                         f"was not found in hdf5 file datasets list.")

            lines, features_shape = group[self.features_name].shape
        except Exception:
            # Do not leak the handle of a file that turned out to be unusable.
            file_flow.close()
            raise

        return group, (lines, features_shape)

    def reindex_dataset(self,
                        day: Optional[Union[int, List[int]]] = None,
                        donor: Optional[Union[int, List[int]]] = None,
                        cell_type: Optional[Union[str, List[str]]] = None) -> None:
        """Restrict the selection to cells matching every given metadata filter.

        Each filter is a single value or an iterable of accepted values; ``None`` means
        "do not filter on this column". Filters always start from all cells of the
        inputs file (they are not cumulative across calls) and the file order of the
        cells is kept. With no filter given the current selection is left untouched.
        """
        mask = None
        for column, wanted in (('day', day), ('donor', donor), ('cell_type', cell_type)):
            if wanted is None:
                continue
            # A str is a single value even though it is iterable.
            if isinstance(wanted, (str, bytes)) or not hasattr(wanted, '__iter__'):
                accepted = [wanted]
            else:
                accepted = list(wanted)
            column_mask = self.metadata[column].isin(accepted)
            mask = column_mask if mask is None else (mask & column_mask)

        if mask is None:
            return

        matching_ids = set(self.metadata.index[mask.to_numpy()])
        # Iterate the id -> row map (file order) so the result is deterministic.
        self.data_ids = [cell_id for cell_id in self._row_index if cell_id in matching_ids]

    def reset(self, task: Optional[str] = None, mode: Optional[str] = None):
        """Reload the dataset, optionally switching ``task`` and/or ``mode``; clears any filter.

        Raises:
            ValueError: on an unknown ``task`` or ``mode`` (the current state is then kept).
        """
        new_task = self.task if task is None else task
        new_mode = self.mode if mode is None else mode
        # Validate before assigning so a bad argument cannot leave the object half-switched.
        self._validate_task_mode(new_task, new_mode)
        self.task, self.mode = new_task, new_mode

        self._read_task_dataset(self.dataset_path)

    def close(self) -> None:
        """Close every HDF5 file this instance has opened and empty all slots."""
        for modes in self.dataflows.values():
            for slots in modes.values():
                for file_type in slots:
                    self._close_flow(slots[file_type])
                    slots[file_type] = None


In [3]:
# Folder that FSCCDataset reads from. The class resolves these names inside the folder:
#   metadata.csv
#   train_multi_inputs.h5, train_multi_targets.h5, test_multi_inputs.h5
#   train_cite_inputs.h5,  train_cite_targets.h5,  test_cite_inputs.h5
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
dataset_folder = '/home/mks/PycharmProjects/multimodal_single_cell_integration/dataset/'

In [16]:
# Build the dataset for the 'cite' task in 'train' mode; with meta_transform='one_hot'
# the 'day', 'donor' and 'cell_type' labels of each sample are returned as one-hot vectors.
dataset = FSCCDataset(dataset_folder, 'cite', 'train', meta_transform='one_hot')

[ Reading train_cite_inputs.h5 file ... ]
[ Reading train_cite_inputs.h5 file is complete. ]
[ Reading train_cite_targets.h5 file ... ]
[ Reading train_cite_targets.h5 file is complete. ]


In [17]:
len(dataset)

70988

In [18]:
dataset[100]

(array([0., 0., 0., ..., 0., 0., 0.], dtype=float32),
 array([-2.22691208e-01, -2.52655208e-01,  2.27418989e-01,  3.20292640e+00,
         8.51944625e-01,  7.98958540e-01,  1.33187664e+00, -2.11750388e-01,
         1.48527831e-01,  8.68058920e-01, -6.94457829e-01,  1.44752157e+00,
         1.12667215e+00,  4.21157300e-01,  1.00356808e+01,  2.43488699e-02,
         5.31223774e-01, -1.04394287e-01,  1.14951277e+00, -1.19618416e-01,
        -1.10979307e+00,  3.15820336e+00,  1.07843137e+00,  1.47565472e+00,
         4.69347954e+00, -5.56258976e-01, -3.09214503e-01, -5.44111073e-01,
         1.09092668e-01,  1.47972658e-01, -5.45077980e-01,  5.77424824e-01,
         3.26725483e-01, -6.73494756e-01, -2.35411733e-01, -4.06153679e-01,
         8.13318908e-01,  5.87703705e+00,  8.13279212e-01,  5.66352129e-01,
        -8.29805374e-01,  1.88042748e+00, -4.18716282e-01,  2.64636803e+00,
        -8.00555229e-01, -4.04821068e-01,  1.34808111e+00,  1.23556578e+00,
        -7.95719445e-01,  6.818538

In [19]:
dataset.meta_unique_vals

{'day': [3, 4, 7, 2, 10],
 'donor': [27678, 32606, 13176, 31800],
 'cell_type': ['MasP', 'MkP', 'NeuP', 'HSC', 'EryP', 'MoP', 'BP', 'hidden'],
 'technology': ['citeseq', 'multiome']}

In [8]:
dataset.meta_names

['day', 'donor', 'cell_type', 'technology']

In [8]:
dataset.data_shapes

{'cite': {'test': {'inputs': (48203, 22050), 'targets': None}}}

In [9]:
dataset.reindex_dataset(day=3, donor=[27678, 32606, 13176])

In [10]:
len(dataset)

6488

In [11]:
dataset.reset('multi', 'train')

[ Reading train_multi_inputs.h5 file ... ]
[ Reading train_multi_inputs.h5 file is complete. ]
[ Reading train_multi_targets.h5 file ... ]
[ Reading train_multi_targets.h5 file is complete. ]


In [12]:
len(dataset)

105942