In [2]:
import numpy as np

from monai.transforms import (
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    ScaleIntensityRanged,
    EnsureTyped,
)

from monai.data import (
    load_decathlon_datalist,
    set_track_meta,
    ThreadDataLoader,
    CacheDataset
)

files = load_decathlon_datalist('raw_data/dataset_0.json', True, "validation")
print(files)

[{'image': 'raw_data\\imagesTr\\img0035.nii.gz', 'label': 'raw_data\\labelsTr\\label0035.nii.gz'}, {'image': 'raw_data\\imagesTr\\img0036.nii.gz', 'label': 'raw_data\\labelsTr\\label0036.nii.gz'}, {'image': 'raw_data\\imagesTr\\img0037.nii.gz', 'label': 'raw_data\\labelsTr\\label0037.nii.gz'}, {'image': 'raw_data\\imagesTr\\img0038.nii.gz', 'label': 'raw_data\\labelsTr\\label0038.nii.gz'}, {'image': 'raw_data\\imagesTr\\img0039.nii.gz', 'label': 'raw_data\\labelsTr\\label0039.nii.gz'}, {'image': 'raw_data\\imagesTr\\img0040.nii.gz', 'label': 'raw_data\\labelsTr\\label0040.nii.gz'}]


In [6]:
import torchvision
import torch

class DictTransform:
    def __init__(self, keys, transform):
        self.keys = keys
        self.transform = transform

    def __call__(self, x):
        x = x.copy()
        for key in self.keys:
            x[key] = self.transform(x[key])
        return x
    
class PreprocessForModel:
    pixel_mean=(torch.Tensor([123.675, 116.28, 103.53]) / 255).view(-1, 1, 1)
    pixel_std=(torch.Tensor([58.395, 57.12, 57.375]) / 255).view(-1, 1, 1)
    img_size=1024

    def __init__(self, normalize=False):
        self.normalize = normalize

    def get_preprocess_shape(self, oldh: int, oldw: int, long_side_length: int):
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

    def __call__(self, x):
        x = x.copy()
        target_size = self.get_preprocess_shape(x['image'].shape[1], x['image'].shape[2], self.img_size)
        tr_img = torchvision.transforms.Resize(target_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True)
        tr_label = torchvision.transforms.Resize(target_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST_EXACT, antialias=False)
        x['image'] = tr_img(x['image'])
        x['label'] = tr_label(x['label'])

        if self.normalize:
            x['image'] = (x['image'] - self.pixel_mean.to(x['image'].device)) / self.pixel_std.to(x['image'].device)
        h, w = target_size
        padh = self.img_size - h
        padw = self.img_size - w
        x['image'] = torch.nn.functional.pad(x['image'], (0, padw, 0, padh))
        x['label'] = torch.nn.functional.pad(x['label'], (0, padw, 0, padh))
        return x

transform = torchvision.transforms.Compose(
    [DictTransform(["image", "label"], torchvision.transforms.Lambda(lambda x: x.unsqueeze(0).repeat(3, 1, 1))),
    PreprocessForModel(normalize=False)]
)

In [18]:
import os
set_track_meta(True)
_default_transform = Compose(
    [
        LoadImaged(keys=["image", "label"], ensure_channel_first=True, dtype=np.float64),
        ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True, dtype=np.float64),
        CropForegroundd(keys=["image", "label"], source_key="image", dtype=np.float64),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        EnsureTyped(keys=["image", "label"], track_meta=False, dtype=np.float64),
    ]
)
cache = CacheDataset(
    data=files, 
    transform=_default_transform, 
    cache_rate=1.0, 
    num_workers=4
)
set_track_meta(False)

for d in cache:
    file_path = d['image_meta_dict']['filename_or_obj']
    file_name = os.path.basename(file_path)
    index_of_dot = file_name.index('.')
    file_name_without_extension = file_name[:index_of_dot]
    print(file_name_without_extension)
    # print(d['image_meta_dict']['filename_or_obj']) ## raw_data\imagesTr\img0035.nii.gz
    images, labels = d['image'][0], d['label'][0]
    h = images.shape[2]
    # data_list = []
    for i in range(h):
        data = {
            "image": images[:, :, i],
            "label": labels[:, :, i],
            "h": i / h
        }
        #print(data['image'].shape)
        #print(data['label'].shape)
        data = transform(data)
        #print(image.shape)
        image = data['image'].numpy().transpose(1, 2, 0)
        image = (image*255).astype(np.uint8)
        label = data['label'][0].numpy()
        print(image.shape)
        print(label.shape)

Loading dataset:   0%|          | 0/6 [00:00<?, ?it/s]pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-06-07 17:10:43,655 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1
pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-06-07 17:10:43,658 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-06-07 17:10:43,668 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1
2023-06-07 17:10:43,670 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-06-07 17:10:44,451 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-06-07 17:10:44,512 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-06-07 17:10:44,539 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


Loading dataset:  17%|█▋        | 1/6 [00:01<00:06,  1.38s/it]pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-06-07 17:10:45,038 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-06-07 17:10:45,423 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-06-07 17:10:45,806 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-06-07 17:10:45,871 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


Loading dataset:  33%|███▎      | 2/6 [00:03<00:07,  1.84s/it]pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


2023-06-07 17:10:47,306 - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


Loading dataset: 100%|██████████| 6/6 [00:04<00:00,  1.22it/s]


img0035
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024, 1024, 3)
(1024, 1024)
(1024,

KeyboardInterrupt: 

In [22]:
import numpy as np
a = np.array([1,2])
b=np.array([3,4])
c=[a,b]
print(np.array(c).shape)

(2, 2)
