In [1]:
from typing import Callable, Optional, Tuple, Union

import h5py
import hdf5plugin  # не удалять!
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

In [2]:
# Root directory of the competition files; change it here to run the notebook elsewhere.
DATASET_DIR = '/home/mks/PycharmProjects/multimodal_single_cell_integration/dataset'
meta_file = f'{DATASET_DIR}/metadata.csv'
# One row per cell, indexed by the cell_id column of the CSV.
df_meta = pd.read_csv(meta_file, index_col='cell_id')

In [3]:
# Preview the first ten metadata rows (day, donor, cell_type, technology per cell_id).
df_meta.iloc[:10]

Unnamed: 0_level_0,day,donor,cell_type,technology
cell_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
c2150f55becb,2,27678,HSC,citeseq
65b7edf8a4da,2,27678,HSC,citeseq
c1b26cb1057b,2,27678,EryP,citeseq
917168fa6f83,2,27678,NeuP,citeseq
2b29feeca86d,2,27678,EryP,citeseq
0fd801488185,2,27678,EryP,citeseq
526647a698f8,2,27678,HSC,citeseq
ab8f207a3dec,2,27678,HSC,citeseq
57f730249c87,2,27678,MasP,citeseq
08df3dcce25c,2,27678,HSC,citeseq


------------------------------------------------------------------------------

In [4]:
class FlowDataset(Dataset):
    """Map-style torch ``Dataset`` that reads rows lazily from HDF5 files.

    Each file must contain at least one top-level group; the first one is used
    and must hold a 2-D ``block0_values`` dataset (for the files inspected in
    this notebook the group members are ``axis0``, ``axis1``, ``block0_items``
    and ``block0_values`` -- presumably the layout written by
    ``pandas.DataFrame.to_hdf``; confirm). Row ``i`` of the features file is
    paired with row ``i`` of the targets file.

    Attributes:
        features / targets: the h5py group of each file (``targets`` is None
            when no targets file was given).
        features_names / targets_names: names of the members of that group
            (e.g. ``axis0``, ``block0_values``), not data column names.
        features_shape / targets_shape: ``(n_rows, n_cols)`` of ``block0_values``.

    The HDF5 files stay open for the lifetime of the dataset so rows can be read
    on demand; call :meth:`close` (or use the dataset in a ``with`` block) to
    release them.

    NOTE(review): the open h5py handles are inherited by DataLoader worker
    processes when ``num_workers > 0``; h5py handles are generally not safe to
    share across fork -- consider opening the files lazily per worker if needed.
    """

    _reserve_name: str = "block0_values"  # name of the 2-D numeric dataset inside the group

    def __init__(self,
                 features_file: str,
                 targets_file: Optional[str] = None,
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None,
                 device: Optional[Union[str, torch.device]] = None):
        """
        Args:
            features_file: path to the HDF5 file holding the input matrix.
            targets_file: optional path to the HDF5 file holding the target
                matrix; any falsy value (``None`` or ``""``) means "no targets".
            transform: optional callable applied to every feature row tensor.
            target_transform: optional callable applied to every target row tensor.
            device: optional device every returned tensor is moved to.

        Raises:
            AssertionError: if a file has no top-level group, the group lacks
                ``block0_values``, or the two files differ in row count.
            Errors raised by h5py while opening a file propagate unchanged.
        """
        self._open_files = []  # h5py.File handles owned by this dataset; released by close()
        self.targets = None
        self.targets_names = None
        self.targets_shape = None
        try:
            self.features, self.features_names, self.features_shape = self.get_hdf5_flow(features_file)
            if targets_file:
                self.targets, self.targets_names, self.targets_shape = self.get_hdf5_flow(targets_file)
                assert self.targets_shape[0] == self.features_shape[0], (
                    f"row count mismatch: {features_file} has {self.features_shape[0]} rows, "
                    f"{targets_file} has {self.targets_shape[0]}")
        except BaseException:
            # Do not leak handles of files that were opened before the failure.
            self.close()
            raise

        self.device = device
        self.transform = transform
        self.target_transform = target_transform

    def get_hdf5_flow(self, file_path: str):
        """Open ``file_path`` read-only and return its first top-level group.

        Returns:
            ``(group, member_names, (n_rows, n_cols))`` where ``group`` is the
            first top-level object of the file, ``member_names`` lists the names
            of the objects inside it, and the shape is that of
            ``group["block0_values"]``.

        Raises:
            AssertionError: if the file has no top-level object or the group has
                no ``block0_values`` member (the file is closed before raising).
        """
        file_flow = h5py.File(file_path, 'r')
        try:
            top_level = list(file_flow.keys())
            assert top_level, f"{file_path}: HDF5 file has no top-level group"
            data_flow = file_flow[top_level[0]]
            col_names = list(data_flow)
            assert self._reserve_name in col_names, (
                f"{file_path}: group {top_level[0]!r} has no {self._reserve_name!r} dataset")
            lines, features_shape = data_flow[self._reserve_name].shape
        except BaseException:
            file_flow.close()
            raise
        # The file must stay open because rows are read from ``data_flow`` on demand;
        # setdefault keeps this working even on an instance built without __init__.
        vars(self).setdefault("_open_files", []).append(file_flow)
        return data_flow, col_names, (lines, features_shape)

    def close(self) -> None:
        """Close every HDF5 file opened by this dataset. Safe to call repeatedly."""
        open_files = getattr(self, "_open_files", [])
        while open_files:
            open_files.pop().close()

    def __enter__(self) -> "FlowDataset":
        return self

    def __exit__(self, *exc_info) -> None:
        self.close()

    def __len__(self):
        """Number of samples, i.e. the number of entries in the features group's ``axis1``."""
        return len(self.features['axis1'])

    def _load_row(self, flow, item: int, transform: Optional[Callable]) -> torch.Tensor:
        """Read row ``item`` of ``flow[block0_values]`` as a tensor, then transform and move it."""
        row = torch.from_numpy(flow[self._reserve_name][item])
        if transform:
            row = transform(row)
        if self.device:
            row = row.to(self.device)
        return row

    def __getitem__(self, item: int) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Return ``(x, y)`` for row ``item`` when targets exist, otherwise just ``x``."""
        x = self._load_row(self.features, item, self.transform)
        if self.targets is None:
            return x
        y = self._load_row(self.targets, item, self.target_transform)
        return x, y

In [5]:
# Root directory of the competition files; every training file below lives directly in it.
DATASET_DIR = '/home/mks/PycharmProjects/multimodal_single_cell_integration/dataset'
mx = f'{DATASET_DIR}/train_multi_inputs.h5'   # multiome inputs
my = f'{DATASET_DIR}/train_multi_targets.h5'  # multiome targets
cx = f'{DATASET_DIR}/train_cite_inputs.h5'    # citeseq inputs
cy = f'{DATASET_DIR}/train_cite_targets.h5'   # citeseq targets

In [6]:
# Paired (inputs, targets) datasets: multiome first, then citeseq.
train_m = FlowDataset(mx, my)
train_c = FlowDataset(cx, cy)

In [7]:
# Deterministic batches of two samples, used to peek at tensor shapes.
loader_m, loader_c = (DataLoader(ds, batch_size=2, shuffle=False) for ds in (train_m, train_c))

In [8]:
# Pull the first multiome batch and report its feature / label tensor shapes.
x, y = next(iter(loader_m))
print(f"Feature batch shape: {x.shape}")
print(f"Labels batch shape: {y.shape}")

Feature batch shape: torch.Size([2, 228942])
Labels batch shape: torch.Size([2, 23418])


In [9]:
# Pull the first citeseq batch and report its feature / label tensor shapes.
x, y = next(iter(loader_c))
print(f"Feature batch shape: {x.shape}")
print(f"Labels batch shape: {y.shape}")

Feature batch shape: torch.Size([2, 22050])
Labels batch shape: torch.Size([2, 140])


In [10]:
# Number of day-2 cells across all technologies.
int((df_meta['day'] == 2).sum())

62250

In [11]:
# Number of day-2 cells measured with citeseq.
int(((df_meta['day'] == 2) & (df_meta['technology'] == 'citeseq')).sum())

29418

In [12]:
# Number of day-2 cells measured with multiome.
int(((df_meta['day'] == 2) & (df_meta['technology'] == 'multiome')).sum())

32832

In [13]:
# Sanity check computed from the data instead of re-typing earlier outputs:
# every day-2 cell should be either citeseq or multiome.
day2_technology = df_meta.loc[df_meta['day'] == 2, 'technology']
n_day2_split = int(day2_technology.isin(['citeseq', 'multiome']).sum())
assert n_day2_split == len(day2_technology), "day-2 cells with an unexpected technology"
n_day2_split

62250

In [14]:
# Short alias for the h5py group behind train_m's features (train_multi_inputs.h5), for inspection.
z = train_m.features

In [15]:
# Members of the group; for this file they are axis0, axis1, block0_items and block0_values.
z.keys()

<KeysViewHDF5 ['axis0', 'axis1', 'block0_items', 'block0_values']>

In [16]:
# First axis0 labels: genomic intervals formatted "chrom:start-end" -- presumably the feature (column) labels; confirm.
z['axis0'][:3]

array([b'GL000194.1:114519-115365', b'GL000194.1:55758-56597',
       b'GL000194.1:58217-58957'], dtype='|S26')

In [17]:
# axis1 entries look like the 12-character cell ids used in df_meta.index; FlowDataset.__len__ counts rows from axis1.
z['axis1'][3:7]

array([b'81cccad8cd81', b'15cb3d85c232', b'a7791bcf1152', b'072790e768b1'],
      dtype='|S12')

In [18]:
# block0_items uses the same "chrom:start-end" format as axis0 -- presumably labels of the block0_values columns; confirm.
z['block0_items'][3:7]

array([b'GL000194.1:59535-60431', b'GL000195.1:119766-120427',
       b'GL000195.1:120736-121603', b'GL000195.1:137437-138345'],
      dtype='|S26')

In [19]:
# Short alias for the h5py group behind train_c's features (train_cite_inputs.h5), for inspection.
q = train_c.features

In [20]:
# Cell ids (axis1) of the citeseq inputs, stored as bytes; compare against df_meta.index.
q['axis1'][3:7]

array([b'ba7f733a4f75', b'fbcf2443ffb2', b'd80d84ca8e89', b'1ac2049b4c98'],
      dtype='|S12')

In [21]:
# First ten metadata cell ids, to compare with the axis1 ids of the HDF5 inputs.
df_meta.head(10).index

Index(['c2150f55becb', '65b7edf8a4da', 'c1b26cb1057b', '917168fa6f83',
       '2b29feeca86d', '0fd801488185', '526647a698f8', 'ab8f207a3dec',
       '57f730249c87', '08df3dcce25c'],
      dtype='object', name='cell_id')

In [24]:
# h5py returns the ids as bytes; decode one to str so it can be matched against df_meta.index.
str(train_c.features['axis1'][5], "utf-8")

'd80d84ca8e89'