### Imports

In [1]:
import os
from monai.transforms import (
    AddChanneld,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    ToTensord,
    SaveImaged,
    Spacingd,
    EnsureTyped,
    AsChannelLastd,
    AsChannelFirstd,
    AsDiscreted,
    EnsureChannelFirstd,
    ScaleIntensityRanged,
    FillHolesd,
    RandCropByLabelClassesd,
    Resized, RandFlipd, RandRotate90d,
)
from monai.transforms.transform import MapTransform
from monai.transforms.inverse import InvertibleTransform
from monai.config import DtypeLike, KeysCollection
from monai.config.type_definitions import NdarrayOrTensor
from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union
import numpy as np
from monai.transforms.intensity.array import (
    ScaleIntensityRangePercentiles,
)
import matplotlib.pyplot as plt
from ipywidgets.widgets import * 
import ipywidgets as widgets
import matplotlib.pyplot as plt
import glob 
import torch
import os
# import cv2

### Transformations

In [2]:
class RemoveDicts(MapTransform, InvertibleTransform):

    def __init__(
        self,
        keys: KeysCollection,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
        d = dict(data)
        for key in self.key_iterator(d):
            self.push_transform(d, key)
        # print(d["image_meta_dict"]["filename_or_obj"])
        a = {"image": d["image"], "label": d["label"], "path": d["image_meta_dict"]["filename_or_obj"]}
        # print(a["path"])
        d = a
        return d

    def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
        d = deepcopy(dict(data))
        for key in self.key_iterator(d):
            d[key] = d[key]
            # Remove the applied transform
            self.pop_transform(d, key)
        return d

In [3]:
class ScaleIntensityRangePercentilesd(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`monai.transforms.ScaleIntensityRangePercentiles`.

    Args:
        keys: keys of the corresponding items to be transformed.
            See also: monai.transforms.MapTransform
        lower: lower percentile.
        upper: upper percentile.
        b_min: intensity target range min.
        b_max: intensity target range max.
        clip: whether to perform clip after scaling.
        relative: whether to scale to the corresponding percentiles of [b_min, b_max]
        channel_wise: if True, compute intensity percentile and normalize every channel separately.
            default to False.
        dtype: output data type, if None, same as input image. defaults to float32.
        allow_missing_keys: don't raise exception if key is missing.
    """

    backend = ScaleIntensityRangePercentiles.backend

    def __init__(
        self,
        keys: KeysCollection,
        lower: float,
        upper: float,
        b_min: Optional[float],
        b_max: Optional[float],
        clip: bool = False,
        relative: bool = False,
        channel_wise: bool = False,
        dtype: DtypeLike = np.float32,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative, channel_wise, dtype)

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
        d = dict(data)
        for key in self.key_iterator(d):
            d[key] = self.scaler(d[key])
        return d

In [4]:
class NNUnetScaleIntensity(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`monai.transforms.ScaleIntensityRange`.

    Args:
        keys: keys of the corresponding items to be transformed.
            See also: monai.transforms.MapTransform
        a_min: intensity original range min.
        a_max: intensity original range max.
        b_min: intensity target range min.
        b_max: intensity target range max.
        clip: whether to perform clip after scaling.
        dtype: output data type, if None, same as input image. defaults to float32.
        allow_missing_keys: don't raise exception if key is missing.
    """
    def _compute_stats(self, volume, mask):
        volume = volume.copy()
        mask = np.greater(mask, 0) # get only non-zero positive pixels/labels
        volume = volume * mask
        volume = np.ma.masked_equal(volume,0).compressed()
        if len(volume) == 0:
            return np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan
        median = np.median(volume)
        mean = np.mean(volume)
        std = np.std(volume)
        mn = np.min(volume)
        mx = np.max(volume)
        percentile_99_5 = np.percentile(volume, 99.5)
        percentile_00_5 = np.percentile(volume, 00.5)
        print(median, mean, std, mn, mx, percentile_99_5, percentile_00_5)
        return median, mean, std, mn, mx, percentile_99_5, percentile_00_5

    def __init__(
        self,
        keys: KeysCollection,
        dtype: DtypeLike = np.float32,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
        d = dict(data)
        for key in self.key_iterator(d):
            stats = self._compute_stats(d[key], d['label'])
            d[key] = np.clip(d[key], stats[6], stats[5])
            d[key] = (d[key] - stats[1]) / stats[2]
        return d

In [5]:
class ClosePreprocessing(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`monai.transforms.ScaleIntensityRange`.

    Args:
        keys: keys of the corresponding items to be transformed.
            See also: monai.transforms.MapTransform
        a_min: intensity original range min.
        a_max: intensity original range max.
        b_min: intensity target range min.
        b_max: intensity target range max.
        clip: whether to perform clip after scaling.
        dtype: output data type, if None, same as input image. defaults to float32.
        allow_missing_keys: don't raise exception if key is missing.
    """

    def __init__(
        self,
        keys: KeysCollection,
        kernel_size: int = 10,
        dtype: DtypeLike = np.float32,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.kernel = np.ones((kernel_size,kernel_size),np.uint8)

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
        d = dict(data)
        closed_slices = list()
        for slice in range(d["label"].shape[-1]):
            result = cv2.morphologyEx(d["label"][0, :, :, slice], cv2.MORPH_CLOSE, self.kernel)
            closed_slices.append(result)

        d["label"] = torch.Tensor(np.stack(closed_slices)).permute(1, 2, 0).unsqueeze(0)
        return d

In [54]:
from PIL import Image
class WriteToPNG(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`monai.transforms.ScaleIntensityRange`.

    Args:
        keys: keys of the corresponding items to be transformed.
            See also: monai.transforms.MapTransform
        a_min: intensity original range min.
        a_max: intensity original range max.
        b_min: intensity target range min.
        b_max: intensity target range max.
        clip: whether to perform clip after scaling.
        dtype: output data type, if None, same as input image. defaults to float32.
        allow_missing_keys: don't raise exception if key is missing.
    """

    def __init__(
        self,
        keys: KeysCollection,
        output_dir: str,
        mode:str,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.output_dir = output_dir
        self.mode = mode

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
        d = dict(data)
        for key in self.key_iterator(d):
            if isinstance(d[key], torch.Tensor):
                d[key] = d[key].detach().cpu().numpy()
            for slice in range(d[key].shape[-1]):
                # print("type: " + d[key].dtype)
                filename = os.path.basename(d["image_meta_dict"]["filename_or_obj"]).split(".")[0] + f"_{slice}.png"
                if key == "image":
                    if self.mode == "train":
                        save_dir = os.path.join(self.output_dir, 'imagesTr', filename)
                    else:
                        save_dir = os.path.join(self.output_dir, 'imagesTs', filename)
                else:
                    if self.mode == "train":
                        save_dir = os.path.join(self.output_dir, 'labelsTr', filename)
                    else:
                        save_dir = os.path.join(self.output_dir, 'labelsTs', filename)
                if not os.path.exists(os.path.dirname(save_dir)):
                    print(f"Creating directory: {os.path.dirname(save_dir)}")
                    os.makedirs(os.path.dirname(save_dir))
                print(f"Saving to {save_dir}")
                plt.imsave(save_dir, d[key][0, :, :, slice], cmap="gray")
                # img = Image.fromarray(d[key][0, :, :, slice].astype(np.float32))
                # img.save(save_dir)

        return d

### Dataset

In [55]:
train_images = sorted( glob.glob( os.path.join( "/mnt/chansey/lauraalvarez/","data", "vascular_injuries", "nii", "imagesTr", "*.nii.gz") ) )
train_labels = sorted( glob.glob( os.path.join( "/mnt/chansey/lauraalvarez/","data", "vascular_injuries", "nii", "labelsTr", "*.nii.gz") ) )
data_dicts = [ {"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels) ]
test_images = sorted( glob.glob( os.path.join( "/mnt/chansey/lauraalvarez/","data", "vascular_injuries", "nii", "imagesTs", "*.nii.gz" ) ) )
test_labels = sorted( glob.glob( os.path.join( "/mnt/chansey/lauraalvarez/","data", "vascular_injuries", "nii", "labelsTs", "*.nii.gz") ) )
data_dicts_test = [{"image": image_name, "label": label_name} for image_name, label_name in zip(test_images, test_labels) ]

In [56]:
# transforms_bsl = Compose([ LoadImaged(keys=["image", "label"]), AddChanneld(keys=["image"]), ToTensord(keys=["image", "label"]),])
transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        # RemoveDicts(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        # AsChannelFirstd(keys=["label"]),
        # AsDiscreted(keys=["label"], argmax=True),
        # Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Spacingd( keys=["image", "label"], pixdim=(1.5, 1.5, 1), mode=("bilinear", "nearest"),),
        CropForegroundd(keys=["image", "label"], source_key="label"),
        NNUnetScaleIntensity(keys=["image"]),
        ClosePreprocessing(keys=["label"]),
        WriteToPNG(keys=["image", "label"], output_dir="/mnt/chansey/lauraalvarez/data/vascular_injuries/png/", mode="test"),
        ToTensord(keys=["image", "label"]),
    ]
)

# injure_org = transforms_bsl(data_dicts)
error_cases = list()
for data_dict in data_dicts_test:
    try:
        data_dict = transforms(data_dict)
    except Exception as e:
            error_cases.append(data_dict)
print(f"{len(error_cases)} error cases")
print(error_cases)
# injure_crop = transforms(data_dicts)
# print(injure_crop["image"].shape, injure_crop["label"].shape)

0 error cases
[]


In [57]:
blended_true_label = blend_images(injure_crop["image"], injure_crop["label"], alpha=0.9)
blended_final_true_label_closed = blended_true_label.permute(1,2,0,3)
print(blended_final_true_label_closed.shape)

NameError: name 'injure_crop' is not defined

In [None]:
from monai.visualize import matshow3d, blend_images
import torch 

def dicom_animation(slice):
    plt.figure(figsize=(18, 6))
    plt.title(f"liver no injured ")
    plt.imshow(blended_final_true_label_closed[:, :, :, slice], cmap="bone")
    plt.show()

interact(dicom_animation, slice=(0, blended_final_true_label_closed.shape[-1]-1))

interactive(children=(IntSlider(value=88, description='slice', max=177), Output()), _dom_classes=('widget-inte…

<function __main__.dicom_animation(slice)>

In [None]:
from monai.visualize import matshow3d, blend_images
import torch 

def dicom_animation(slice):
    plt.figure(figsize=(18, 6))
    plt.title(f"liver no injured ")
    plt.imshow(blended_final_true_label[:, :, :, slice], cmap="bone")
    plt.show()

interact(dicom_animation, slice=(0, blended_final_true_label.shape[-1]-1))

interactive(children=(IntSlider(value=88, description='slice', max=177), Output()), _dom_classes=('widget-inte…

<function __main__.dicom_animation(slice)>

## Load png example

In [6]:
class RemoveAlpha(MapTransform):

    def __init__(
        self,
        keys: KeysCollection,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
        d = dict(data)
        for key in self.key_iterator(d):
            if key == "label":
                d[key] = d[key][...,:1]
            else:
                d[key] = d[key][...,:3]
        return d

In [7]:
class KeepOnlyClass(MapTransform, InvertibleTransform):

    def __init__(
        self,
        keys: KeysCollection,
        class_to_keep: int,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.class_to_keep = class_to_keep

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
        d = dict(data)
        for key in self.key_iterator(d):
            self.push_transform(d, key)
            # d[key] = np.where((d[key] != self.class_to_keep), 0, d[key])
            # d[key] = np.where((d[key] == self.class_to_keep), 1, d[key])
            d[key] = np.where(d[key] == 255, 1, 0)
            values = d[key]
            n_values = np.max(values) + 1
            d[key]= np.squeeze(np.eye(n_values)[values])
            print(np.unique(d[key][:,:,0]))
            print(np.unique(d[key][:,:,1]))
            print(d[key].shape)
            # print(np.unique(d[key][:,:,0]))
            # print(np.unique(d[key][:,:,1]))

        return d

In [8]:
class ToGrayScale(MapTransform):
    def __init__(
        self,
        keys: KeysCollection,
        normalize: bool = False,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.normalize = normalize

    def __call__(
        self, data: Mapping[Hashable, NdarrayOrTensor]
    ) -> Dict[Hashable, NdarrayOrTensor]:
        d = dict(data)
        for key in self.key_iterator(d):
            d[key] = d[key][..., :1]
            if self.normalize:
                d[key] = d[key] / 255
            print(d[key].shape)

        return d


In [10]:
transforms =  Compose(
            [
                LoadImaged(keys=["image", "label"], reader='pilreader'),
                RemoveAlpha(keys=["image", "label"]),
                KeepOnlyClass(keys=["label"], class_to_keep=255),
                ToGrayScale(keys=["image"], normalize=True),
                EnsureChannelFirstd(keys=["image", "label"]),
                Resized(keys=["image", "label"], spatial_size=(256,256)),
                # RandSpatialCropd(keys=["image", "label"], roi_size=self.train_img_size,random_size=True),
                RandFlipd( 
                    keys=["image", "label"],
                    spatial_axis=[0],
                    prob=0.10,
                ),
                RandFlipd(
                    keys=["image", "label"],
                    spatial_axis=[0],
                    prob=0.10,
                ),
                RandFlipd(
                    keys=["image", "label"],
                    spatial_axis=[1],
                    prob=0.10,
                ),
                RandRotate90d(
                    keys=["image", "label"],
                    prob=0.10,
                    max_k=3,
                ),
                ToTensord(keys=["image", "label"]),
            ]
        )

train_images = sorted( glob.glob( os.path.join( "U:\\lauraalvarez","data", "vascular_injuries", "png", "imagesTr", "VI_L110016_157.png") ) )
train_labels = sorted( glob.glob( os.path.join( "U:\\lauraalvarez","data", "vascular_injuries", "png", "labelsTr", "VI_L110016_157.png") ) )
data_dicts = [ {"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels) ]
result = transforms(data_dicts)

[0. 1.]
[0. 1.]
(142, 106, 2)
(142, 106, 1)


In [162]:
print(result[0][0]["image"].shape)
print(result[0][0]["label"].shape)


(1, 96, 96)
(2, 96, 96)


In [166]:
from monai.visualize import matshow3d, blend_images
import torch 


def dicom_animation(slice):
    f, axarr = plt.subplots(2, 3, figsize=(15, 6))
    axarr[0,0].imshow(result[0][0]["image"][0,:,:], cmap="bone")
    axarr[0,1].imshow(result[0][0]["image"][0,:,:], cmap="bone")
    axarr[0,2].imshow(result[0][0]["image"][0,:,:], cmap="bone")
    axarr[1,0].imshow(result[0][0]["label"][slice,:,:], cmap="bone")
    axarr[1,1].imshow(result[0][0]["label"][slice,:,:], cmap="bone")
    axarr[1,2].imshow(result[0][0]["label"][slice,:,:], cmap="bone")


interact(dicom_animation, slice=(0, 1))

interactive(children=(IntSlider(value=0, description='slice', max=1), Output()), _dom_classes=('widget-interac…

<function __main__.dicom_animation(slice)>