In [38]:
from torch.utils.data import Dataset, DataLoader
from utils.readers import VtkReader
import os
import numpy as np
import math
import functools
import collections
import re

In [2]:
main_directory = os.path.dirname(os.getcwd())
general_path =  os.path.join(main_directory, 'patients_data', '{}', 'train')
images_path = general_path.format('MRI_volumes', 'MRI')
masks_path = general_path.format('masks', 'mask')
print(images_path)
print(masks_path)

/home/jupyter/tfm/patients_data/MRI_volumes/train
/home/jupyter/tfm/patients_data/masks/train


In [3]:
patient_id = 8
patient_image_path = os.path.join(images_path, 'MRI_{}.vtk'.format(patient_id))
patient_mask_path = os.path.join(masks_path, 'mask_{}.vtk'.format(patient_id))
print(patient_image_path)
print(patient_mask_path)

/home/jupyter/tfm/patients_data/MRI_volumes/train/MRI_8.vtk
/home/jupyter/tfm/patients_data/masks/train/mask_8.vtk


In [4]:
image = VtkReader(patient_image_path)
mask = VtkReader(patient_mask_path)
print(image.shape)
print(image.shape)

(320, 320, 60)
(320, 320, 60)


In [85]:
class GeneratePatches(Dataset):
    def __init__(self, images_path, masks_path, patch_shape, shuffle_images = False):
        self.images_path = images_path
        self.masks_path = masks_path
        self.patch_shape = patch_shape
        self.shuffle_images = shuffle_images
        
        self.images_files = self.get_files(self.images_path)
        self.masks_files = self.get_files(self.masks_path)
        
        assert len(self.images_files) == len(self.masks_files), 'Num.images is different from num. masks'
        
        self.patients_information = self.get_patients_information()
        self.num_patients = len(self.patients_information)
        self.num_patches = self.calculate_num_patches()
    
    @staticmethod
    def get_files(path):
        return [os.path.join(path, file) for file in os.listdir(path) if file.split('.')[-1] == 'vtk']  

    @staticmethod
    def get_shape_image(file):
        return VtkReader(file).shape
        
    def get_patients_information(self):
        patient_info = collections.namedtuple('patient_info', 'patient_id image_location mask_location shape')
        list_patient_info = []
        for file in self.images_files:
            image_path = file
            patient_id = int(file.split('/')[-1].split('.')[0].split('_')[1])
            pattern = re.compile('mask_{}'.format(patient_id))
            mask_path = list(filter(pattern.search, self.masks_files))[0]
            shape = self.get_shape_image(file)
            list_patient_info.append(patient_info(patient_id, image_path, mask_path, shape))
        return list_patient_info
    
    def calculate_num_patches(self):
        num_patches = float('-inf')
        for image in self.patients_information:
            num = functools.reduce(lambda x, y: x*y, (map(lambda x, y: int(np.ceil(x/y)), image.shape, self.patch_shape)))
            if num > num_patches:
                num_patches = num
        return num_patches    
    
    def __len__(self):
        return self.num_patients * self.num_patches
    
    def __getitem__(self, idx):
        patch_id, image_id = divmod(idx, self.num_patients) 
        print("image_id:"", image_id, "patch_id: ", patch_id)

In [90]:
ddd = GeneratePatches(images_path, masks_path, (320,320, 64))
print(ddd.calculate_num_patches())
print(ddd.get_patients_information())
len(ddd)

1
[patient_info(patient_id=6, image_location='/home/jupyter/tfm/patients_data/MRI_volumes/train/MRI_6.vtk', mask_location='/home/jupyter/tfm/patients_data/masks/train/mask_6.vtk', shape=(320, 320, 60)), patient_info(patient_id=7, image_location='/home/jupyter/tfm/patients_data/MRI_volumes/train/MRI_7.vtk', mask_location='/home/jupyter/tfm/patients_data/masks/train/mask_7.vtk', shape=(320, 320, 64)), patient_info(patient_id=10, image_location='/home/jupyter/tfm/patients_data/MRI_volumes/train/MRI_10.vtk', mask_location='/home/jupyter/tfm/patients_data/masks/train/mask_10.vtk', shape=(320, 320, 60)), patient_info(patient_id=4, image_location='/home/jupyter/tfm/patients_data/MRI_volumes/train/MRI_4.vtk', mask_location='/home/jupyter/tfm/patients_data/masks/train/mask_4.vtk', shape=(320, 320, 60)), patient_info(patient_id=8, image_location='/home/jupyter/tfm/patients_data/MRI_volumes/train/MRI_8.vtk', mask_location='/home/jupyter/tfm/patients_data/masks/train/mask_8.vtk', shape=(320, 320, 

6

In [247]:
ee = np.random.randint(0, 100, (4, 5, 7))
print(ee)

[[[90 11 55 97 89 63 60]
  [38 50  5 17 33 74 44]
  [10 72 53 67 40 65 98]
  [34  9 10 55 15 84 78]
  [95 80 53 27 36 12 56]]

 [[ 1 28 27 17 54 86 55]
  [24 21 41 31 73  5 87]
  [35 68  4 68 52 54 56]
  [68 52  3 87  9 87 75]
  [36 16 68 94 23 81 51]]

 [[49 91 39 13 50 82 52]
  [ 7 74 35 47 69 81 88]
  [12 42 90  9 26 41 61]
  [82 83 45 12 29  8 49]
  [64 17 84 94 11  6  2]]

 [[92 38 40 80 80 15 47]
  [74 59 28 53 10 32 19]
  [66 68 46 89 11 10 37]
  [35  4 57 74 34 62 68]
  [52 32  7 42 17 51 74]]]


In [250]:
patch_shape = (2, 2, 2)
num_patches = functools.reduce(lambda x, y: x*y, (map(lambda x, y: int(np.ceil(x/y)), ee.shape, patch_shape)))
print("Num patches", num_patches)
patch_id = 0
num_channels, num_rows, num_cols = ee.shape
total_tensors_row = int(np.ceil(num_cols/patch_shape[2]))
total_tensors_col = int(np.ceil(num_rows/patch_shape[1]))

total_tensors_channel = total_tensors_row * total_tensors_col

start_row, start_col = divmod(patch_id%total_tensors_channel, total_tensors_row) 

start_row = min(patch_shape[1] * start_row, ee.shape[1])
end_row = min(start_row + patch_shape[1], ee.shape[1])

start_col = min(patch_shape[2]*start_col, ee.shape[2])
end_col = min(start_col + patch_shape[2], ee.shape[2])

start_channel, _ =  divmod(patch_id, total_tensors_channel)
start_channel = min(start_channel * patch_shape[0], ee.shape[0]) 
end_channel = min(start_channel + patch_shape[0], ee.shape[0])

ee[start_channel:end_channel, start_row:end_row, start_col:end_col]

Num patches 24


array([[[90, 11],
        [38, 50]],

       [[ 1, 28],
        [24, 21]]])