In [None]:
from pathlib import Path
from typing import Union

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

In [None]:
class BraTSDataset(Dataset):
    """Dataset that serves one randomly chosen slice per subject folder.

    ``source_folder`` is scanned for sub-folders; each sub-folder is one
    dataset entry and is expected to hold ``<slice>.npy`` images next to
    ``<slice>_mask.npy`` masks. Indexing picks a random slice of the entry,
    so repeated access to the same index may return different samples.
    """

    def __init__(self, source_folder: Union[str, Path], transform=None):
        """
        Args:
            source_folder: Root folder containing one sub-folder per entry.
            transform: Optional callable applied to the ``(image, mask)`` tuple.
        """
        source_folder = Path(source_folder)

        # Keep sub-folders only: plain files in the root (for example the
        # meta.csv that this notebook later writes there) are not entries.
        self.images = sorted(p for p in source_folder.glob('*') if p.is_dir())
        self.transform = transform

    def __len__(self):
        """Return the number of sub-folders found under the root."""
        return len(self.images)

    def __getitem__(self, i):
        """Load a random ``(image, mask)`` pair from the ``i``-th sub-folder.

        Raises:
            FileNotFoundError: If the sub-folder holds no image slices.
        """
        subject_dir = self.images[i]

        # Image slices are the .npy files without 'mask' in their name.
        # Sorting makes the draw reproducible under a fixed NumPy seed,
        # independent of the file system's listing order.
        slice_names = sorted(
            p.stem for p in subject_dir.glob('*.npy') if 'mask' not in p.stem
        )
        if not slice_names:
            # Deliberately not IndexError: Python's sequence iteration treats
            # IndexError from __getitem__ as "end of data", which would
            # silently truncate a plain `for sample in dataset` loop.
            raise FileNotFoundError(f'No image slices found in {subject_dir}')

        # Draw one name directly instead of shuffling the whole listing.
        j = slice_names[np.random.randint(len(slice_names))]

        # NOTE(security): allow_pickle=True lets np.load unpickle arbitrary
        # objects; only use this on files from a trusted source.
        image = np.load(subject_dir / f'{j}.npy', allow_pickle=True)
        mask = np.load(subject_dir / f'{j}_mask.npy', allow_pickle=True)
        sample = image, mask

        if self.transform:
            sample = self.transform(sample)

        return sample

In [None]:
# Root folder scanned by the dataset (machine-specific absolute path).
data_folder = Path('/home/anvar/work/data/brats_slices/')
# Build the folder-scanning dataset without a transform.
data = BraTSDataset(data_folder)

In [None]:
# Number of entries the dataset found under data_folder.
len(data)

In [None]:
%%timeit

# Time one sample fetch; every call lists the entry's folder on disk before loading two arrays.
data[100]

In [None]:
import os
from tqdm.notebook import tqdm

## Create a csv file with important metadata

In [None]:
# NOTE: machine-specific absolute path; point it at the local copy of the data.
data_folder = Path('/home/anvar/work/data/brats_slices/')
records = []

for path, _, files in tqdm(os.walk(data_folder)):
    subject_dir = Path(path)
    # The last '_'-separated token of the folder name is used as the subject id.
    # Path.name (not str.split('/') or .stem) keeps this independent of the
    # path separator and of dots in the folder name.
    subject_id = subject_dir.name.split('_')[-1]

    for file in files:
        # Index slice arrays only. Without this filter a re-run would also
        # pick up the meta.csv that this cell writes into data_folder and
        # record it as an image row.
        if not file.endswith('.npy'):
            continue

        slice_id = file.split('.')[0].split('_')[0]
        sample_id = f"{subject_id}_{slice_id}"  # SubjectID_SliceIndex
        is_mask = 'mask' in file
        if is_mask:
            # NOTE(security): allow_pickle=True is only safe for trusted files.
            mask = np.load(subject_dir / file, allow_pickle=True)
            is_nonzero_mask = bool(np.any(mask))
        else:
            # Not applicable to image rows.
            is_nonzero_mask = np.nan

        # Store the path relative to data_folder so the table stays valid
        # when the data folder is moved.
        relative_path = subject_dir.relative_to(data_folder) / file
        records.append([relative_path, sample_id, is_mask, subject_id, is_nonzero_mask])

df = pd.DataFrame(
    records,
    columns=['relative_path', 'sample_id', 'is_mask', 'subject_id', 'is_nonzero_mask'],
)
print(df.is_nonzero_mask.value_counts())

df.to_csv(data_folder / 'meta.csv')

> Важное преимущество такого метода в том, что вы убираете из класса, описывающего ваш датасет,
низкоуровневую работу со структурой ваших папок на диске: у вас просто есть таблица, в которой для каждого
файла указан путь к нему и набор его id-полей (например, номер слайса и номер пациента, но могут быть и другие поля).

In [None]:
# Reload the saved table; index_col=0 reads the index column that to_csv wrote back as the index.
df = pd.read_csv(data_folder / 'meta.csv', index_col=0)

In [None]:
# Preview the first rows of the metadata table.
df.head()

## Edit BraTSDataset class

In [None]:
class BraTSDataset(Dataset):
    """Dataset of paired image/mask arrays driven by a metadata table.

    The table replaces any directory scanning: every file is one row, image
    rows and mask rows are each sorted by ``sample_id``, and the ``i``-th
    image is paired with the ``i``-th mask.
    """

    def __init__(
        self,
        meta: pd.DataFrame,
        source_folder: Union[str, Path],
        transform=None,
        image_shape=(1, 240, 240),
    ):
        """
        Args:
            meta: Table with at least the columns ``relative_path``,
                ``sample_id`` and ``is_mask``.
            source_folder: Folder that ``relative_path`` entries are joined to.
            transform: Optional callable mapping an ``(image, mask)`` tuple to
                a new ``(image, mask)`` tuple.
            image_shape: Shape the image tensor is reshaped to. The default
                matches the previously hard-coded ``(1, 240, 240)``.

        Raises:
            ValueError: If image rows and mask rows do not list the same
                ``sample_id`` values in the same order.
        """
        self.source_folder = Path(source_folder)
        self.meta_images = meta.query('is_mask == False').sort_values(by='sample_id').reset_index(drop=True)
        self.meta_masks = meta.query('is_mask == True').sort_values(by='sample_id').reset_index(drop=True)

        # Pairing is positional, so a missing or extra row on either side
        # would silently attach masks to the wrong images. Fail early instead.
        if not self.meta_images['sample_id'].equals(self.meta_masks['sample_id']):
            raise ValueError(
                'Image and mask rows do not pair up by sample_id '
                f'({len(self.meta_images)} images, {len(self.meta_masks)} masks)'
            )

        self.transform = transform
        self.image_shape = tuple(image_shape)

    def __len__(self):
        """Return the number of image rows in the table."""
        return self.meta_images.shape[0]

    def __getitem__(self, i):
        """Load the ``i``-th pair and return it as ``(image, mask)`` tensors."""
        # Column-first access avoids building a full row Series per lookup.
        image_path = self.source_folder / self.meta_images['relative_path'].iloc[i]
        mask_path = self.source_folder / self.meta_masks['relative_path'].iloc[i]

        # NOTE(security): allow_pickle=True is only safe for trusted files.
        image = np.load(image_path, allow_pickle=True)
        mask = np.load(mask_path, allow_pickle=True)

        if self.transform is not None:
            image, mask = self.transform((image, mask))

        # as_tensor accepts both NumPy arrays and tensors a transform may
        # return; reshape replaces the deprecated Tensor.resize.
        return torch.as_tensor(image).reshape(self.image_shape), torch.as_tensor(mask)