In [8]:
import monai
import pathlib
from itkwidgets import view, compare

In [134]:
# Deterministic preprocessing for one case: load the five volumes, drop
# singleton axes, normalise intensities, resample, pad, and finally stack the
# magnitude and velocity volumes into a single multi-channel "inputs" tensor.
ALL_KEYS = ['magnitude', 'mask', 'velx', 'vely', 'velz']
VELOCITY_KEYS = ['velx', 'vely', 'velz']
# Per-key interpolation for Spacingd, derived from the key names so it cannot
# drift out of sync with ALL_KEYS; nearest-neighbour keeps mask labels discrete.
SPACING_MODES = ['nearest' if key == 'mask' else 'bilinear' for key in ALL_KEYS]

transforms = monai.transforms.Compose([
    monai.transforms.LoadImageD(keys=ALL_KEYS),
    # Drop axis 3, then axis 2, then the last axis (order matters: indices shift).
    # NOTE(review): assumes each of these axes has size 1 in every input file — confirm.
    monai.transforms.SqueezeDimd(keys=ALL_KEYS, dim=3),
    monai.transforms.SqueezeDimd(keys=ALL_KEYS, dim=2),
    monai.transforms.SqueezeDimd(keys=ALL_KEYS, dim=-1),
    monai.transforms.AddChanneld(keys=ALL_KEYS),
    # Mean/std normalisation; each velocity component (key) is normalised on its own.
    monai.transforms.NormalizeIntensityD(keys=['magnitude']),
    monai.transforms.NormalizeIntensityD(keys=VELOCITY_KEYS, channel_wise=False),
    # Resize the last spatial axis to 64 samples; -1 keeps the size of the other axes.
    # NOTE(review): 'nearest' is also applied to the continuous magnitude/velocity
    # volumes here, unlike Spacingd below which uses bilinear for them — confirm intended.
    monai.transforms.Resized(keys=ALL_KEYS, spatial_size=[-1, -1, 64], size_mode='all', mode="nearest", allow_missing_keys=False),
    # Resample in-plane to a pixdim of 1.54; a non-positive entry keeps the original spacing.
    monai.transforms.Spacingd(keys=ALL_KEYS, pixdim=[1.54, 1.54, -1], mode=SPACING_MODES),
    # Pad (never crop) in-plane to at least 64x64; -1 leaves the last axis unchanged.
    monai.transforms.SpatialPadd(keys=ALL_KEYS, spatial_size=[64, 64, -1], allow_missing_keys=False),
    # Channel order of "inputs": magnitude, velx, vely, velz.
    monai.transforms.ConcatItemsd(keys=['magnitude', *VELOCITY_KEYS], name="inputs"),
    monai.transforms.EnsureTyped(keys=['inputs', 'mask']),
])

In [147]:
# Same preprocessing steps as the `transforms` pipeline, followed by a spatial
# flip, so both pipelines can be compared side by side on the same case.
ALL_KEYS = ['magnitude', 'mask', 'velx', 'vely', 'velz']
VELOCITY_KEYS = ['velx', 'vely', 'velz']
# Per-key interpolation for Spacingd, derived from the key names so it cannot
# drift out of sync with ALL_KEYS; nearest-neighbour keeps mask labels discrete.
SPACING_MODES = ['nearest' if key == 'mask' else 'bilinear' for key in ALL_KEYS]

transforms_aug = monai.transforms.Compose([
    monai.transforms.LoadImageD(keys=ALL_KEYS),
    # Drop axis 3, then axis 2, then the last axis (order matters: indices shift).
    # NOTE(review): assumes each of these axes has size 1 in every input file — confirm.
    monai.transforms.SqueezeDimd(keys=ALL_KEYS, dim=3),
    monai.transforms.SqueezeDimd(keys=ALL_KEYS, dim=2),
    monai.transforms.SqueezeDimd(keys=ALL_KEYS, dim=-1),
    monai.transforms.AddChanneld(keys=ALL_KEYS),
    # Mean/std normalisation; each velocity component (key) is normalised on its own.
    monai.transforms.NormalizeIntensityD(keys=['magnitude']),
    monai.transforms.NormalizeIntensityD(keys=VELOCITY_KEYS, channel_wise=False),
    # Resize the last spatial axis to 64 samples; -1 keeps the size of the other axes.
    monai.transforms.Resized(keys=ALL_KEYS, spatial_size=[-1, -1, 64], size_mode='all', mode="nearest", allow_missing_keys=False),
    # Resample in-plane to a pixdim of 1.54; a non-positive entry keeps the original spacing.
    monai.transforms.Spacingd(keys=ALL_KEYS, pixdim=[1.54, 1.54, -1], mode=SPACING_MODES),
    # Pad (never crop) in-plane to at least 64x64; -1 leaves the last axis unchanged.
    monai.transforms.SpatialPadd(keys=ALL_KEYS, spatial_size=[64, 64, -1], allow_missing_keys=False),
    # prob=1: the flip is applied to every sample (not random), mirroring along
    # spatial axes 0 AND 1 together — useful for visually checking the flip.
    # TODO(review): only the voxel layout is mirrored; velx/vely/velz keep their
    # sign. A physically consistent flip presumably also negates the velocity
    # component along each flipped axis — confirm which component maps to array
    # axes 0 and 1 before using this as a training augmentation.
    monai.transforms.RandFlipd(keys=ALL_KEYS, spatial_axis=[0, 1], prob=1),
    # Channel order of "inputs": magnitude, velx, vely, velz.
    monai.transforms.ConcatItemsd(keys=['magnitude', *VELOCITY_KEYS], name="inputs"),
    monai.transforms.EnsureTyped(keys=['inputs', 'mask']),
])

In [148]:
def find_cases(root: pathlib.Path) -> list:
    """Collect one dict of file paths per case found anywhere below ``root``.

    Every ``*_velx.nii.gz`` file defines a case; the magnitude, mask, vely and
    velz volumes are expected next to it under the same filename prefix (their
    existence is not checked here). Cases are returned sorted by path so that
    positional indexing such as ``cases[0]`` picks the same case on every
    machine and run (``Path.glob`` order is filesystem-dependent).

    Returns:
        list of dicts with keys 'magnitude', 'mask', 'velx', 'vely', 'velz'
        mapping to path strings.
    """
    suffix = '_velx.nii.gz'
    found = []
    for velx_path in sorted(root.glob('**/*' + suffix)):
        # Strip only the trailing suffix; str.replace would also rewrite any
        # parent directory whose name happens to contain the suffix.
        prefix = str(velx_path)[:-len(suffix)]
        found.append({
            'magnitude': prefix + '_magnitude.nii.gz',
            'mask': prefix + '_mask.nii.gz',
            'velx': str(velx_path),
            'vely': prefix + '_vely.nii.gz',
            'velz': prefix + '_velz.nii.gz',
        })
    return found


base_directory = pathlib.Path()
cases = find_cases(base_directory)

In [149]:
# Two datasets over the same case list, transformed on access: plain
# preprocessing (`data_check`) and the augmented pipeline (`data_set`).
data_check = monai.data.Dataset(cases, transforms)
data_set = monai.data.Dataset(cases, transforms_aug)

In [150]:
# Pull the same case through both pipelines for a side-by-side comparison.
n = 0  # index into `cases`
data, data_flip = data_check[n], data_set[n]

In [151]:
# Channel 0 of "inputs" is the magnitude volume after NormalizeIntensityD
# (mean/std normalisation), so a fixed window like 4000..1700 (which was also
# inverted) is meaningless here; let the viewer choose the display range.
view(image=data["inputs"][0], label_image=data["mask"][0], rotate=True, gradient_opacity=0.9)


Viewer(geometries=[], gradient_opacity=0.9, interpolation=False, point_sets=[], rendered_image=<itk.itkImagePy…

In [152]:
# Same display as for the unflipped case: channel 0 is the normalised magnitude,
# so no fixed vmin/vmax (the former 4000..1700 window was inverted and far
# outside the normalised range); let the viewer choose the display range.
view(image=data_flip["inputs"][0], label_image=data_flip["mask"][0], rotate=True, gradient_opacity=0.9)


Viewer(geometries=[], gradient_opacity=0.9, interpolation=False, point_sets=[], rendered_image=<itk.itkImagePy…