In [1]:
import os
from copy import deepcopy
from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from monai.config import DtypeLike, KeysCollection
from monai.config.type_definitions import NdarrayOrTensor
from monai.transforms import (
    AsDiscrete,
    AddChanneld,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    ScaleIntensityRanged,
    Spacingd,
    EnsureType,
    EnsureChannelFirstd,
    RandFlipd,
    RandRotated,
    ToTensord,
    Resized,
    RandSpatialCropSamplesd,
    RandRotate90d,
    RandShiftIntensityd,
    KeepLargestConnectedComponent,
    RandCropByPosNegLabeld,
    RandCropByLabelClassesd
)
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import MapTransform

In [2]:
class RemoveDicts(MapTransform, InvertibleTransform):
    """Reduce a sample dictionary to its image, label and source file path.

    After the transform runs, the sample contains exactly three entries:
    ``"image"``, ``"label"`` and ``"path"`` (taken from
    ``data["image_meta_dict"]["filename_or_obj"]``). Every other entry of the
    incoming dictionary is discarded.

    Args:
        keys: keys handed to ``MapTransform``; ``push_transform`` /
            ``pop_transform`` are called once for each of them.
        allow_missing_keys: forwarded unchanged to ``MapTransform``.
    """

    def __init__(
        self,
        keys: KeysCollection,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
        """Return a new dict holding only ``image``, ``label`` and ``path``.

        Raises:
            KeyError: if ``"image"``, ``"label"`` or
                ``"image_meta_dict"``/``"filename_or_obj"`` is missing.
        """
        d = dict(data)
        for key in self.key_iterator(d):
            self.push_transform(d, key)
        # NOTE(review): only these three entries survive. If push_transform
        # stored its trace under other dictionary keys, that trace is dropped
        # here - confirm `inverse` still works with the MONAI version in use.
        return {
            "image": d["image"],
            "label": d["label"],
            "path": d["image_meta_dict"]["filename_or_obj"],
        }

    def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
        """Pop this transform's trace entry for every key.

        Works on a deep copy so the caller's dictionary is not modified. The
        entries discarded by ``__call__`` are not restored.
        """
        d = deepcopy(dict(data))
        for key in self.key_iterator(d):
            # Remove the applied transform
            self.pop_transform(d, key)
        return d


In [3]:
from monai.visualize import matshow3d, blend_images
import imageio
def _overlay_frames(image, label):
    """Blend ``label`` over ``image`` and move the last axis of the result to the front."""
    blended = torch.from_numpy(blend_images(image, label)).permute(1, 2, 0, 3)
    return blended.squeeze().permute(3, 0, 1, 2)


def make_gif(predictions, output_dir="gifs", filename="prediction.gif"):
    """Write an animated GIF comparing predicted and true segmentations.

    For every prediction the image is blended once with the arg-maxed model
    output and once with the ground-truth label; the two overlays are stacked
    next to each other, and all predictions are stacked into one animation.

    Args:
        predictions: iterable of mappings with the entries ``'output'``,
            ``'label'`` and ``'image'``. ``'output'`` is arg-maxed over dim 1,
            ``'label'`` is indexed as a 5-D tensor whose dim-1 entries from
            index 1 onwards are summed, and element 0 of ``'image'`` is used
            (presumably the first batch item - verify against the caller).
        output_dir: directory the GIF is written to; created when missing.
        filename: file name of the GIF inside ``output_dir``.

    Returns:
        Path of the written GIF.

    Raises:
        ValueError: if ``predictions`` is empty.
    """
    predictions = list(predictions)
    if not predictions:
        raise ValueError("make_gif needs at least one prediction to render")

    volumes = []
    for prediction in predictions:
        pred = torch.argmax(prediction['output'], dim=1).detach().cpu().numpy()
        true_label = torch.sum(prediction['label'][:, 1:, :, :, :], dim=1).detach().cpu().numpy()
        image = prediction['image'][0].cpu().numpy()

        volume_label = _overlay_frames(image, true_label)
        volume_pred = _overlay_frames(image, pred)
        volumes.append(torch.hstack((volume_pred, volume_label)).numpy())

    stacked = np.hstack(volumes).astype(np.float64)
    peak = np.max(stacked)
    if peak != 0:
        # Normalise to 0 - 1; skipped for an all-zero volume, where the
        # division would fill every frame with NaN.
        stacked = stacked / peak
    frames = (255 * stacked).astype(np.uint8)

    # Build the path with os.path so it is also valid on POSIX systems, where
    # a literal "gifs\\prediction.gif" is a single odd file name, not a path.
    os.makedirs(output_dir, exist_ok=True)
    path_to_gif = os.path.join(output_dir, filename)
    imageio.mimsave(path_to_gif, frames)
    return path_to_gif

In [4]:
import matplotlib.pyplot as plt
from ipywidgets.widgets import * 
import ipywidgets as widgets
import matplotlib.pyplot as plt

# Candidate samples as (image path, label path) pairs. Both are kept so the
# active one is chosen explicitly instead of by silently overwriting variables.
_NNUNET_TASK_DIR = os.path.join("/mnt/chansey/", "lauraalvarez", "nnunet", "nnUNet_raw_data_base", "nnUNet_raw_data", "Task503_LiverSpleenTrauma")
SAMPLES = {
    "nnunet": (
        os.path.join(_NNUNET_TASK_DIR, "imagesTr", "TRMLIV_043_0000.nii.gz"),
        os.path.join(_NNUNET_TASK_DIR, "labelsTr", "TRMLIV_043.nii.gz"),
    ),
    "mha": (
        '/mnt/chansey/lauraalvarez/data/liver_spleen_nnunet/train/data/L110162.mha',
        '/mnt/chansey/lauraalvarez/data/liver_spleen_nnunet/train/mask/L110162.mha',
    ),
}

# The .mha sample is the one used by the rest of the notebook.
injure_liver, injure_liver_label = SAMPLES["mha"]

In [5]:
# Load the selected image/label pair and convert both to tensors.
# No channel-adding, re-orientation, re-spacing, resizing, intensity scaling
# or foreground cropping is applied in this pipeline.
sample_keys = ["image", "label"]
paths = dict(zip(sample_keys, (injure_liver, injure_liver_label)))
val_transforms = Compose(
    [
        LoadImaged(keys=sample_keys),
        ToTensord(keys=sample_keys),
    ]
)
injures = val_transforms(paths)

In [6]:
# Shapes of the loaded volumes, then the distinct values present in the label.
for key in ("image", "label"):
    print(injures[key].shape)
print(np.unique(injures["label"]))

torch.Size([860, 1024, 184])
torch.Size([860, 1024, 184, 6])
[0. 1.]


In [28]:
from monai.visualize import matshow3d, blend_images
import torch 

# Overlay the label on the image, then reorder the axes for slice-wise display.
# NOTE(review): the permute below needs a 4-D blended array, while the shapes
# printed earlier are 3-D (image) and 4-D with a trailing axis of 6 (label) -
# confirm the inputs match what blend_images expects.
image_np = injures["image"].cpu().numpy()
label_np = injures["label"].cpu().numpy()
blended_label_in = blend_images(image_np, label_np)
blended_final = torch.from_numpy(blended_label_in).permute(1, 2, 0, 3)


In [29]:
def dicom_animation(slice):
    """Show one slice (index along the last axis) of ``blended_final``.

    The parameter keeps the name ``slice`` because ``interact`` binds the
    slider to it by keyword.
    """
    fig, ax = plt.subplots(figsize=(18, 6))
    ax.set_title("liver no injured ")
    ax.imshow(blended_final[:, :, :, slice], cmap="bone")
    plt.show()


interact(dicom_animation, slice=(0, blended_final.shape[-1]-1))

interactive(children=(IntSlider(value=42, description='slice', max=85), Output()), _dom_classes=('widget-inter…

<function __main__.dicom_animation(slice)>