<a href="https://colab.research.google.com/github/bnoushin7/self-supervised-medical-imaging/blob/main/data_utils_colab_final_backup.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!pip install -r requirements.txt
!pip install monai==1.2.0
#!pip install nibabel

In [None]:
!unzip covid_19.zip

In [None]:
import argparse
import os
import sys

import monai
import numpy as np
from monai.data import CacheDataset, DataLoader, Dataset, DistributedSampler, SmartCacheDataset, load_decathlon_datalist
from monai.transforms import (
    AddChanneld,
    AsChannelFirstd,
    Compose,
    CropForegroundd,
    EnsureChannelFirstd,
    LoadImaged,
    NormalizeIntensityd,
    Orientationd,
    RandCropByPosNegLabeld,
    RandSpatialCropSamplesd,
    ScaleIntensityRanged,
    Spacingd,
    SpatialPadd,
    ToTensord,
)
print(monai.__version__)

1.2.0


Tensor layout of every batch: (batch_size, channel, x, y, z)


In [None]:

def get_args(argv=None):
    """Build the experiment configuration from command-line style arguments.

    Args:
        argv: Optional explicit list of argument strings, e.g. ``["--batch_size", "4"]``.
            When ``None``, options are read from ``sys.argv`` for script runs. Inside a
            Jupyter/IPython/Colab kernel, defaults are used, because the kernel's
            ``sys.argv`` holds launcher flags rather than user options.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="Parse command line arguments.")
    parser.add_argument("--a_min", default=-1000.0, type=float, help="a_min in ScaleIntensityRanged")
    parser.add_argument("--a_max", default=1000.0, type=float, help="a_max in ScaleIntensityRanged")
    parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged")
    parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged")
    parser.add_argument("--space_x", default=1.5, type=float, help="spacing in x direction")
    parser.add_argument("--space_y", default=1.5, type=float, help="spacing in y direction")
    parser.add_argument("--space_z", default=2.0, type=float, help="spacing in z direction")
    parser.add_argument("--roi_x", default=96, type=int, help="roi size in x direction")
    parser.add_argument("--roi_y", default=96, type=int, help="roi size in y direction")
    parser.add_argument("--roi_z", default=96, type=int, help="roi size in z direction")
    parser.add_argument("--batch_size", default=2, type=int, help="number of volumes per batch")
    parser.add_argument(
        "--sw_batch_size",
        default=2,
        type=int,
        help="number of random ROI crops drawn per volume (num_samples of RandSpatialCropSamplesd)",
    )
    parser.add_argument("--smartcache_dataset", action="store_true", help="use monai smartcache Dataset")
    parser.add_argument("--cache_dataset", action="store_true", help="use monai cache Dataset")
    parser.add_argument("--mask_patch_size", default=32, type=int, help="side length (voxels) of one masked block")
    parser.add_argument("--model_patch_size", default=32, type=int, help="side length (voxels) of one model patch")
    parser.add_argument("--mask_ratio", default=0.2, type=float, help="fraction of mask blocks that are masked")
    # Jupyter kernels are launched as `<launcher> -f <connection-file>.json`; accept and hide that flag.
    parser.add_argument('-f', '--file', help=argparse.SUPPRESS)

    if argv is not None:
        return parser.parse_args(argv)
    # Colab's launcher script path does not contain "ipykernel", so also check the loaded modules.
    if 'ipykernel' in sys.argv[0] or 'ipykernel' in sys.modules:
        return parser.parse_args(args=[])
    return parser.parse_args()
# Parse the configuration once. test_get_loader() below reads this module-level `args`
# and passes it to get_loader().
args = get_args()
print(args)


Namespace(a_min=-1000, a_max=1000, b_min=0.0, b_max=1.0, space_x=1.5, space_y=1.5, space_z=2.0, roi_x=96, roi_y=96, roi_z=96, batch_size=2, sw_batch_size=2, smartcache_dataset=False, cache_dataset=False, file='/root/.local/share/jupyter/runtime/kernel-a4b13a38-97be-45af-9434-7553971a94fe.json')


In [None]:
class MaskGenerator:
    """Sample random 3D block masks for masked-image-modeling pretraining.

    A cubic volume of side ``input_size`` is divided into
    ``(input_size // mask_patch_size) ** 3`` mask tokens. On each call, a random
    ``ceil(token_count * mask_ratio)`` of them are set to 1. The mask is returned at
    model-patch resolution: every token is repeated ``mask_patch_size // model_patch_size``
    times along each axis. The result is an int array of shape
    ``(input_size // model_patch_size,) * 3``.

    Args:
        input_size: side length (voxels) of the cubic input volume.
        mask_patch_size: side length (voxels) of one masked block; must divide ``input_size``.
        model_patch_size: side length (voxels) of one model patch; must divide ``mask_patch_size``.
        mask_ratio: fraction of mask tokens to mask, in ``[0, 1]``.

    Raises:
        ValueError: if a size is not positive or ``mask_ratio`` lies outside ``[0, 1]``.
        AssertionError: if the divisibility constraints above are violated.
    """

    def __init__(self, input_size=256, mask_patch_size=64, model_patch_size=32, mask_ratio=0.2):
        if min(input_size, mask_patch_size, model_patch_size) <= 0:
            raise ValueError(
                "input_size, mask_patch_size and model_patch_size must be positive, got "
                f"{input_size}, {mask_patch_size}, {model_patch_size}"
            )
        # A negative ratio would turn the slice in __call__ into `[:-k]`, masking almost
        # every token instead of none.
        if not 0.0 <= mask_ratio <= 1.0:
            raise ValueError(f"mask_ratio must be in [0, 1], got {mask_ratio}")

        self.input_size = input_size
        self.mask_patch_size = mask_patch_size
        self.model_patch_size = model_patch_size
        self.mask_ratio = mask_ratio

        assert self.input_size % self.mask_patch_size == 0, (
            f"input_size ({input_size}) must be divisible by mask_patch_size ({mask_patch_size})"
        )
        assert self.mask_patch_size % self.model_patch_size == 0, (
            f"mask_patch_size ({mask_patch_size}) must be divisible by model_patch_size ({model_patch_size})"
        )

        # Number of mask tokens along each axis.
        self.rand_size = self.input_size // self.mask_patch_size
        # Model patches per mask token along each axis.
        self.scale = self.mask_patch_size // self.model_patch_size

        self.token_count = self.rand_size ** 3
        self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))

    def __call__(self):
        """Return a new random 0/1 int mask of shape ``(rand_size * scale,) * 3``."""
        mask_idx = np.random.permutation(self.token_count)[:self.mask_count]
        mask = np.zeros(self.token_count, dtype=int)
        mask[mask_idx] = 1

        mask = mask.reshape((self.rand_size, self.rand_size, self.rand_size))
        # Upsample token resolution to model-patch resolution along every axis.
        for axis in range(3):
            mask = mask.repeat(self.scale, axis=axis)

        return mask

In [None]:
from monai.transforms import Transform

class Generate3DMask(Transform):
    """Dictionary transform that attaches a freshly sampled mask under the ``"mask"`` key.

    Args:
        mask_generator: zero-argument callable (e.g. a ``MaskGenerator``). It is invoked
            once per sample and must return an array-like exposing ``.shape``.
    """

    def __init__(self, mask_generator):
        self.mask_generator = mask_generator

    def __call__(self, data):
        """Store a new mask in ``data['mask']`` (the dict is updated in place) and return ``data``."""
        source = data['image']
        sampled_mask = self.mask_generator()

        # The image may carry an extra leading (channel) axis the mask lacks; in that case
        # prepend exactly one singleton axis so the ranks line up.
        needs_leading_axis = len(sampled_mask.shape) < len(source.shape)
        data['mask'] = np.expand_dims(sampled_mask, axis=0) if needs_leading_axis else sampled_mask
        return data



In [None]:
def _build_transforms(args, mask_generator):
    """Return the preprocessing pipeline shared by the training and validation loaders."""
    roi_size = [args.roi_x, args.roi_y, args.roi_z]
    return Compose(
        [
            LoadImaged(keys=["image"]),
            # Insert a leading channel axis; the loaded CT volumes have none.
            # This replaces AddChanneld, which is deprecated and removed in MONAI 1.3.
            EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
            Orientationd(keys=["image"], axcodes="RAS"),
            ScaleIntensityRanged(
                keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
            ),
            SpatialPadd(keys="image", spatial_size=roi_size),
            CropForegroundd(keys=["image"], source_key="image", k_divisible=roi_size),
            RandSpatialCropSamplesd(
                keys=["image"],
                roi_size=roi_size,
                num_samples=args.sw_batch_size,
                random_center=True,
                random_size=False,
            ),
            Generate3DMask(mask_generator),  # adds data["mask"] to every crop
            ToTensord(keys=["image", "mask"]),
        ]
    )


def get_loader(args):
    """Build the COVID-19 CT training and validation DataLoaders for masked pretraining.

    Both pipelines run these steps in order:
      1. load each volume and add a channel axis;
      2. reorient to RAS;
      3. rescale intensities from [a_min, a_max] to [b_min, b_max] with clipping;
      4. pad to at least the ROI size and crop to the foreground;
      5. draw ``args.sw_batch_size`` random ROI crops per volume;
      6. attach a random block mask under the ``"mask"`` key.

    Args:
        args: namespace from ``get_args()``. The optional attributes ``data_dir``
            (default "/content/covid_19"), ``num_workers`` (default 1),
            ``mask_patch_size`` (default 32), ``model_patch_size`` (default 32) and
            ``mask_ratio`` (default 0.2) are used when present.

    Returns:
        ``(train_loader, val_loader)``. The training loader is shuffled and drops its
        last incomplete batch. The validation loader keeps every volume.

    Raises:
        ValueError: if the ROI is not cubic, since ``MaskGenerator`` only supports cubes.
    """
    dir_path = getattr(args, "data_dir", "/content/covid_19")
    jsonlist_covid = os.path.join(dir_path, "dataset_TCIAcovid19_0.json")
    datadir_covid = os.path.join(dir_path, "CT-Covid-19-August2020/")
    num_workers = getattr(args, "num_workers", 1)

    datalist_covid = load_decathlon_datalist(jsonlist_covid, False, "training", base_dir=datadir_covid)
    vallist_covid = load_decathlon_datalist(jsonlist_covid, False, "validation", base_dir=datadir_covid)

    roi_size = (args.roi_x, args.roi_y, args.roi_z)
    if len(set(roi_size)) != 1:
        raise ValueError(f"MaskGenerator supports cubic ROIs only, got roi_size={roi_size}")
    # Size the mask for the cropped ROI, not MaskGenerator's 256-voxel default. Otherwise the
    # mask grid (x model_patch_size) does not cover the crop: with roi=96 and
    # model_patch_size=32 the mask is 3x3x3 and 3 * 32 == 96 per axis. The mask_patch_size
    # default is 32 because the old default of 64 does not divide 96.
    mask_generator = MaskGenerator(
        input_size=args.roi_x,
        mask_patch_size=getattr(args, "mask_patch_size", 32),
        model_patch_size=getattr(args, "model_patch_size", 32),
        mask_ratio=getattr(args, "mask_ratio", 0.2),
    )

    train_transforms = _build_transforms(args, mask_generator)
    # Validation uses the same mask step, so a masked-reconstruction loss can be evaluated on it.
    val_transforms = _build_transforms(args, mask_generator)

    if args.cache_dataset:
        print("Using MONAI Cache Dataset")
        train_ds = CacheDataset(data=datalist_covid, transform=train_transforms, cache_rate=0.5, num_workers=num_workers)
    elif args.smartcache_dataset:
        print("Using MONAI SmartCache Dataset")
        # NOTE(review): cached items are only replaced if the training loop calls
        # train_ds.start() / update_cache() / shutdown(). Confirm the caller does.
        train_ds = SmartCacheDataset(
            data=datalist_covid,
            transform=train_transforms,
            replace_rate=1.0,
            cache_num=2 * args.batch_size * args.sw_batch_size,
        )
    else:
        print("Using generic dataset")
        train_ds = Dataset(data=datalist_covid, transform=train_transforms)

    # Shuffle so each epoch visits the training volumes in a different order.
    train_loader = DataLoader(
        train_ds, batch_size=args.batch_size, num_workers=num_workers, shuffle=True, drop_last=True
    )

    val_ds = Dataset(data=vallist_covid, transform=val_transforms)
    # Keep the final partial batch: validation should cover every volume.
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, num_workers=num_workers, shuffle=False, drop_last=False)

    return train_loader, val_loader



def test_get_loader(num_batches=1):
    """Smoke-test ``get_loader`` on the module-level ``args``.

    Only ``num_batches`` batches are drawn from each loader, rather than the whole
    training set, so the check stays quick. For each batch it prints the tensor
    shapes and asserts that the image crops have the configured ROI size.

    Args:
        num_batches: number of batches to inspect per loader.
    """
    train_loader, val_loader = get_loader(args)
    expected_roi = (args.roi_x, args.roi_y, args.roi_z)

    for name, loader in (("train", train_loader), ("val", val_loader)):
        # zip() stops as soon as range() is exhausted, so at most `num_batches` batches are loaded.
        for _, batch in zip(range(num_batches), loader):
            print(f" {name} batch['image'] size is:  {batch['image'].size()}")
            if "mask" in batch:
                print(f" {name} batch['mask']  size is:  {batch['mask'].size()}")
            assert tuple(batch["image"].shape[2:]) == expected_roi, (
                f"{name} image crop {tuple(batch['image'].shape[2:])} != roi {expected_roi}"
            )


test_get_loader()




Using generic dataset
 batch['image'] size is:  torch.Size([4, 1, 96, 96, 96])
 batch['mask']  size is:  torch.Size([4, 1, 8, 8, 8])


torch.Size([1, 512, 512, 64])

AddChanneld
torch.Size([1, 1, 512, 512, 64])


Orientationd
torch.Size([1, 1, 512, 512, 64])


ScaleIntensityRanged
torch.Size([1, 1, 512, 512, 64])



SpatialPadd
torch.Size([1, 1, 512, 512, 96])


CropForegroundd
torch.Size([1, 1, 512, 512, 96])


RandSpatialCropSamplesd
torch.Size([2, 1, 96, 96, 96])
