In [1]:
import os
from monai.transforms import (
    AddChanneld,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    ToTensord,
    Resized,
    AsChannelLastd,
    AsChannelFirstd,
    AsDiscrete,
    CropForeground,
    SpatialCropd,
    AsDiscreted,
    ScaleIntensityRanged,
    EnsureType,
    KeepLargestConnectedComponent,
    KeepLargestConnectedComponentd,
    LabelToContour,
    FillHolesd
)
import glob
from monai.transforms.transform import MapTransform
from monai.transforms.inverse import InvertibleTransform
from monai.data import decollate_batch
import SimpleITK as sitk
from monai.config import DtypeLike, KeysCollection
from monai.config.type_definitions import NdarrayOrTensor
from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union
import numpy as np
from monai.visualize import matshow3d, blend_images
import torch
from monai.metrics import DiceMetric
import csv
import cv2
# import cc3d
# import morphsnakes as ms
import cv2
import imageio
from collections import Counter
from skimage.morphology import disk, dilation, binary_dilation, ball

class RefineOutput(MapTransform):
    def __init__(
        self,
        keys: KeysCollection,
        dtype: DtypeLike = np.float32,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)

    def __call__(
        self, data: Mapping[Hashable, NdarrayOrTensor]
    ) -> Dict[Hashable, NdarrayOrTensor]:
        d = dict(data)
        old_mask_organ = np.where((d["label"] != 1), 0, d["label"])
        kernel = np.ones((2, 2), np.uint8)
        old_mask_organ =  np.expand_dims(np.stack([cv2.dilate(old_mask_organ[0,slice,:,:],kernel,iterations = 1) for slice in range(old_mask_organ.shape[1])]),0)
        old_mask_injury = np.where((d["label"] != 2), 0, d["label"]) 
        new_mask_injury = np.zeros_like(old_mask_injury)
        new_img = d["image"][:, :, :, :].copy()
        idx_label_organ = np.where(old_mask_organ.flatten() == 1)[0] #ids of spleen
        min_intensity = np.min(new_img[old_mask_injury!=0]) 
        idx_img = np.where((new_img.flatten() > min_intensity))[0]
        idx_img_2 = np.where((new_img.flatten() < min_intensity +30))[0]
        idx_img = np.intersect1d(idx_img, idx_img_2)
        idx_to_change = np.intersect1d(idx_img, idx_label_organ)
        np.put(new_mask_injury, idx_to_change, 1)
        old_mask_injury += new_mask_injury
        old_mask_injury = np.where((old_mask_injury == 3), 2, old_mask_injury) 
        old_mask_injury = np.where((old_mask_injury == 1), 2, old_mask_injury) 

        # closed_slices = list()
        # for slice in range(new_mask.shape[-1]):
        #     result = cv2.morphologyEx(
        #         new_mask[0, :, :, slice], cv2.MORPH_CLOSE, kernel, iterations=2
        #     )
        #     result = cv2.medianBlur(result, 3)
        #     closed_slices.append(result)

        # new_mask = np.stack(closed_slices)

        final_mask = old_mask_injury + old_mask_organ

        d["label"] = final_mask


        return d


class DilationLabel(MapTransform):
    def __init__(
        self,
        keys: KeysCollection,
        dtype: DtypeLike = np.float32,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)

    def __call__(
        self, data: Mapping[Hashable, NdarrayOrTensor]
    ) -> Dict[Hashable, NdarrayOrTensor]:
        d = dict(data)
        img = d["label"]
        old_mask_injury = np.where((img != 2), 0, img)
        if ORGAN == "Spleen":
            injury_size = np.sum(old_mask_injury)/2
            if injury_size < 4500:
                # radius = int((injury_size/1000)*2)
                old_mask_organ = np.where((img != 1), 0, img)
                final_mask_injury = dilation(old_mask_injury[0,:,:,:], footprint=ball(radius=6))
                final_mask_injury = np.expand_dims(final_mask_injury, 0).astype(np.int8)
                final_mask = old_mask_organ + final_mask_injury
                final_mask = np.where((final_mask == 3), 2, final_mask)
                d["label"] = final_mask
            else:
                old_mask_organ = np.where((img != 1), 0, img)
                final_mask_injury = dilation(old_mask_injury[0,:,:,:], footprint=ball(radius=2))
                final_mask_injury = np.expand_dims(final_mask_injury, 0).astype(np.int8)
                final_mask = old_mask_organ + final_mask_injury
                final_mask = np.where((final_mask == 3), 2, final_mask)
                d["label"] = final_mask
        if ORGAN == "Liver":
            old_mask_organ = np.where((img != 1), 0, img)
            mask = disk(2)
            new_mask_injury = list()
            for slice in range(old_mask_injury.shape[1]):
                result = dilation(old_mask_injury[0,slice,:,:], footprint=mask)
                new_mask_injury.append(result)
            final_mask_injury = np.stack(new_mask_injury)
            final_mask_injury = np.expand_dims(final_mask_injury, 0).astype(np.int8)
            final_mask = old_mask_organ + final_mask_injury
            final_mask = np.where((final_mask == 3), 2, final_mask)
            d["label"] = final_mask

        return d


def fill_contours_fixed(arr):
    slices = []
    for _ in range(arr.shape[0]):
        slices.append(
        np.maximum.accumulate(arr, 1) &\
            np.maximum.accumulate(arr[:, :, ::-1], 1)[:, :, ::-1] &\
            np.maximum.accumulate(arr[:, ::-1, :], 0)[:,::-1, :] &\
            np.maximum.accumulate(arr[::-1, :, :], 0)[::-1, :, :] &\
            np.maximum.accumulate(arr, 0))
    return np.stack(slices, 0)



class ActiveContour(MapTransform):
    def __init__(
        self,
        keys: KeysCollection,
        dtype: DtypeLike = np.float32,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)

    def __call__(
        self, data: Mapping[Hashable, NdarrayOrTensor]
    ) -> Dict[Hashable, NdarrayOrTensor]:
        d = dict(data)
        label = d["label"][:,:,:,:]
        img = d["image"][:,:,:,:]
        old_mask_organ = np.where((label != 1), 0, label)
        old_mask_injury = np.where((label != 2), 0, label)
        # old_mask_injury = np.expand_dims(old_mask_injury,0)
        temp_d = {"image": d["image"], "label": old_mask_injury}
        cropped_contour = CropForegroundd(keys=["image", "label"], 
                                            source_key="label",
                                            margin=10)(temp_d)
        cropped_img = cropped_contour["image"][0,:,:,:]
        cropped_label_injury = cropped_contour["label"][0,:,:,:]
        # gimg = ms.inverse_gaussian_gradient(cropped_img, alpha=1000, sigma=5.48)
        # contour_injury = LabelToContour()(cropped_label_injury)
        # label_ac = ms.morphological_geodesic_active_contour(gimg, iterations=10,
        #                                      init_level_set=cropped_label_injury,
        #                                      smoothing=1, threshold=0.31,
        #                                      balloon=1)
        label_ac = ms.morphological_chan_vese(cropped_img, 25, init_level_set=cropped_label_injury, lambda2=2)
        label_ac = np.where((label_ac == 1), 2, label_ac)
        cropped_contour["label"] = np.expand_dims(label_ac, 0)
        inv_cropped = CropForegroundd(keys=["image", "label"], source_key="label",
                                            margin=30).inverse(cropped_contour)
        label_ac = inv_cropped["label"]
        final_mask_injury = label_ac.astype(np.int8)
        final_mask = old_mask_organ + final_mask_injury
        final_mask = np.where((final_mask == 3), 2, final_mask)
        d["label"] = final_mask

        return d

def save_csv(output_path, task_name, data):
    import csv

    base_path = os.path.join(
        HOME,
        "lauraalvarez",
        "nnunet",
        "nnUNet_raw_data",
        task_name,
        OUT_FOLDER,
        GIF_FOLDER,
        output_path)

    keys = data[0].keys()
    a_file = open(base_path, "w+")
    dict_writer = csv.DictWriter(a_file, keys)
    dict_writer.writeheader()
    dict_writer.writerows(data)
    a_file.close()


def _save_gif(volume, filename, task_name="Task504_LiverTrauma"):
    volume = volume.astype(np.float64) / np.max(volume)  # normalize the data to 0 - 1
    volume = volume * 255  # Now scale by 255
    volume = volume.astype(np.uint8)
    base_path = os.path.join(
        HOME,
        "lauraalvarez",
        "nnunet",
        "nnUNet_raw_data",
        task_name,
        OUT_FOLDER,
        GIF_FOLDER)
    path_to_gif = os.path.join(base_path, f"{filename}.mp4")
    if not os.path.exists(base_path):
        print("Creating gifs directory")
        os.mkdir(base_path)
    imageio.mimsave(path_to_gif, volume, fps=5)
    return path_to_gif



In [2]:
OUT_FOLDER = "out_unet"
GIF_FOLDER = "gifs"
ORGAN = "Liver"
HOME = "/mnt/chansey/"

task_name="Task510_LiverTraumaDGX"
    #Task511_SpleenTraumaCV Task510_LiverTraumaDGX Task512_LiverSpleenTrauma
    
predictions = glob.glob(
    os.path.join(
        HOME,
        "lauraalvarez",
        "nnunet",
        "nnUNet_raw_data",
        task_name,
        OUT_FOLDER,
        "*.nii.gz",
    )
)

images = [x.replace(OUT_FOLDER,"imagesTs") for x in predictions]
images = [x.replace(".nii.gz","_0000.nii.gz") for x in images]
true_labels = [x.replace(OUT_FOLDER,"labelsTs") for x in predictions]
# true_labels = [x.replace("_0000.nii.gz", ".nii.gz") for x in true_labels]




done = ["TNI_000_0000", "TNI_002_0000", "TNI_004_0000", "TSpLi_001_0000", "TSpLi_003_0000", "TSpLi_005_0000", 
       "TSpLi_008_0000", "TSpLi_010_0000", "TSpLi_011_0000", "TSpLi_013_0000", "TSpLi_015_0000"]

done = [x + ".nii.gz" for x in done]

data_dicts_test = [
    {"image": image_name, "label": label_name, "tLabel": true_name}
    for image_name, label_name, true_name in zip(images, predictions, true_labels) if os.path.basename(image_name) not in done
]



In [3]:
data_dicts_test

[{'image': '/mnt/chansey/lauraalvarez/nnunet/nnUNet_raw_data/Task510_LiverTraumaDGX/imagesTs/TLIV_002_0000.nii.gz',
  'label': '/mnt/chansey/lauraalvarez/nnunet/nnUNet_raw_data/Task510_LiverTraumaDGX/out_unet/TLIV_002.nii.gz',
  'tLabel': '/mnt/chansey/lauraalvarez/nnunet/nnUNet_raw_data/Task510_LiverTraumaDGX/labelsTs/TLIV_002.nii.gz'},
 {'image': '/mnt/chansey/lauraalvarez/nnunet/nnUNet_raw_data/Task510_LiverTraumaDGX/imagesTs/TLIV_000_0000.nii.gz',
  'label': '/mnt/chansey/lauraalvarez/nnunet/nnUNet_raw_data/Task510_LiverTraumaDGX/out_unet/TLIV_000.nii.gz',
  'tLabel': '/mnt/chansey/lauraalvarez/nnunet/nnUNet_raw_data/Task510_LiverTraumaDGX/labelsTs/TLIV_000.nii.gz'},
 {'image': '/mnt/chansey/lauraalvarez/nnunet/nnUNet_raw_data/Task510_LiverTraumaDGX/imagesTs/TLIV_006_0000.nii.gz',
  'label': '/mnt/chansey/lauraalvarez/nnunet/nnUNet_raw_data/Task510_LiverTraumaDGX/out_unet/TLIV_006.nii.gz',
  'tLabel': '/mnt/chansey/lauraalvarez/nnunet/nnUNet_raw_data/Task510_LiverTraumaDGX/labelsTs

In [4]:
# CASE 006 EXCLUDED, TOO BIG FOR DIAG TO MANAGE, REDUCE SIZE TO GET SCORES EL 5 Y 9 TAMBIEN

In [5]:
a = Compose(
        [
            LoadImaged(keys=["image", "label", "tLabel"]),
            AddChanneld(keys=["label", "image", "tLabel"]),
            CropForegroundd(keys=["image","tLabel", "label"], source_key="image"),
            
#             DilationLabel(keys=["label"]),
            KeepLargestConnectedComponentd(keys=["label"], applied_labels=[1,2], is_onehot=False, independent=True),
            # ActiveContour(keys=["label", "image"]),
            # FillHolesd(keys=["label"]),

            ScaleIntensityRanged(
                keys=["image"],
                a_min=-175,
                a_max=250,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
        ]
    )(data_dicts_test[1])

In [6]:
print(a["image"].shape)
print(a["label"].shape)
print(a["tLabel"].shape)


(1, 506, 473, 1345)
(1, 506, 473, 1345)
(1, 506, 473, 1345)


In [8]:

# NOTE: POr alguna razon aqui uno de los scanneres me sale metric 0 cuando con el mismo codigo
# en el run_metrics me sale 0.67, habra que debugear, ignorando for now.
csv_list = []
for data in data_dicts_test:
    print(f"Infering for \n\t image:{data['image']}, \n\t label: {data['label']}, \n\t true label: {data['tLabel']}")
    normal_plot = Compose(
        [
            LoadImaged(keys=["image", "label", "tLabel"]),
            AsChannelFirstd(keys=["image", "label", "tLabel"]),
            AddChanneld(keys=["label", "image", "tLabel"]),
            CropForegroundd(keys=["image","tLabel",  "label"], source_key="image"),
            
#             DilationLabel(keys=["label"]),
            KeepLargestConnectedComponentd(keys=["label"], applied_labels=[1,2], is_onehot=False, independent=True),
            # ActiveContour(keys=["label", "image"]),
            # FillHolesd(keys=["label"]),

            ScaleIntensityRanged(
                keys=["image"],
                a_min=-175,
                a_max=250,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
        ]
    )
    basename = os.path.basename(data["image"])
    injures = normal_plot(data)
    post_pred = Compose([AsDiscrete(to_onehot=3)])
    post_label = Compose([AsDiscrete(to_onehot=3)])
    outputs = torch.Tensor(np.expand_dims(post_pred(torch.Tensor(injures["label"])), 0))
    labels = torch.Tensor(np.expand_dims(post_label(torch.Tensor(injures["tLabel"])), 0))
    dice_metric = DiceMetric(include_background=False, reduction="mean_batch")
    print(outputs.shape, labels.shape)
    dice_metric(y_pred=outputs, y=labels)
#     dice_liver, dice_spleen, dice_liver_injury, dice_spleen_injury = dice_metric.aggregate()
    dice_liver, dice_liver_injury = dice_metric.aggregate()
    

    dict_data = {
        "image": basename,
        "dice_liver": dice_liver.numpy(),
#         "dice_spleen": dice_spleen.numpy(),
        "dice_liver_injury": dice_liver_injury.numpy(),
#         "dice_spleen_injury": dice_spleen_injury.numpy(),
    }
    print(dict_data)
    csv_list.append(dict_data)
    save_gif = True
    if save_gif == True:
        post_plotting = Compose([EnsureType(), AsDiscrete(argmax=False)])
        injures["label"] = post_plotting(injures["label"])
        inj = dict(injures)
        inj = Resized(keys=["image", "label", "tLabel"], spatial_size=(512, 512, 512))(
            inj
        )

        blended_label_in = blend_images(inj["image"], inj["label"], 0.5)
        blended_final = blended_label_in.permute(1, 2, 0, 3)

        blended_true_label = blend_images(inj["image"], inj["tLabel"], 0.5)
        blended_true_label = torch.from_numpy(blended_true_label).permute(1, 2, 0, 3)

        volume = torch.hstack(
            (
                torch.from_numpy(inj["image"]).permute(1, 2, 0, 3).repeat(1, 1, 3, 1),
                blended_final,
                blended_true_label,
            )
        )
        volume = volume.permute(0, 1, 3, 2)

        volume_path = _save_gif(volume.numpy(), f"{basename}", task_name)
        # _save_gif(blended_true_label.numpy().transpose(0, 1, 3, 2), f"{basename}_True", task_name)
        # _save_gif(blended_final.numpy().transpose(0, 1, 3, 2), f"{basename}_Pred", task_name)

        print(f"Saved {volume_path}")
    save_csv("summary.csv", task_name, csv_list)

Infering for 
	 image:/mnt/chansey/lauraalvarez/nnunet/nnUNet_raw_data/Task510_LiverTraumaDGX/imagesTs/TLIV_002_0000.nii.gz, 
	 label: /mnt/chansey/lauraalvarez/nnunet/nnUNet_raw_data/Task510_LiverTraumaDGX/out_unet/TLIV_002.nii.gz, 
	 true label: /mnt/chansey/lauraalvarez/nnunet/nnUNet_raw_data/Task510_LiverTraumaDGX/labelsTs/TLIV_002.nii.gz
torch.Size([1, 3, 1569, 495, 416]) torch.Size([1, 3, 1569, 495, 416])
{'image': 'TLIV_002_0000.nii.gz', 'dice_liver': array(0.13479035, dtype=float32), 'dice_liver_injury': array(0., dtype=float32)}


ImportError: To use the imageio ffmpeg plugin you need to 'pip install imageio-ffmpeg'

In [9]:
import os
# os.environ["IMAGEIO_FFMPEG_EXE"] = "/usr/bin/ffmpeg"

In [10]:
!pip install --upgrade pip --user

Collecting pip
  Downloading pip-22.2.2-py3-none-any.whl (2.0 MB)
[K     |████████████████████████████████| 2.0 MB 12.7 MB/s eta 0:00:01
[?25hInstalling collected packages: pip
Successfully installed pip-22.2.2


In [11]:
!pip install imageio-ffmpeg --user

Collecting imageio-ffmpeg
  Downloading imageio_ffmpeg-0.4.7-py3-none-manylinux2010_x86_64.whl (26.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m26.9/26.9 MB[0m [31m40.0 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: imageio-ffmpeg
Successfully installed imageio-ffmpeg-0.4.7


In [None]:
!pip install numpy==1.16 --user

Collecting numpy==1.16
  Downloading numpy-1.16.0.zip (5.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.1/5.1 MB[0m [31m50.8 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
[?25hBuilding wheels for collected packages: numpy
  Building wheel for numpy (setup.py) ... [?25l/

In [None]:
!pip install ffmpeg

In [None]:
!pip install opencv-python

In [None]:
!pip install scikit-image

In [None]:
!pip install torchvision -U