# Segment COVID-19 lesions within the lungs
Uses code from https://github.com/RiccardoBiondi/segmentation

Uses material from https://github.com/SimpleITK/SimpleITK

Input: 3D image of the segmented lungs

Output: COVID-19 segmentation mask

In [1]:
import numpy as np
import pandas as pd
import os
import SimpleITK as sitk
from myshow import myshow, myshow3d

In [2]:
# Input: lung-only volume produced by the lung-mask step.
# Output: where the COVID-19 segmentation mask is written.
input_path = ("./output_data/"
              "patient_test_5_output_image_from_mask.nii")
output_path = ("./output_data_covid_segmentation/"
               "patient_test_5_output_image_from_mask_covid_show.nii")

In [3]:
# Pre-trained cluster centroids: one 4-value feature vector per cluster.
# Insertion order matters, because it becomes the row order of the centroid
# matrix handed to the labelling step (row 3 == 'GGO').
centroids = dict([
    ('healthy lung', [1.0291475, 1.7986686, 1.3147535, 1.6199226]),
    ('lung',         [2.4449115, 2.8337748, 1.556249,  2.9394238]),
    ('Edges',        [3.4244044, 2.1809669, 4.172402,  3.652266]),
    ('GGO',          [5.1485806, 5.3843336, 2.7543516, 4.812335]),
    ('Noise',        [8.233303,  1.9194404, 6.503928,  6.670035]),
])

In [4]:
def threshold(image, upper, lower, inside=1, outside=0):
    """Binarize *image* with ``sitk.BinaryThresholdImageFilter``.

    Voxels whose value lies between ``lower`` and ``upper`` (inclusive, per
    ITK's BinaryThresholdImageFilter) become ``inside``; all others become
    ``outside``. Note that the argument order is (upper, lower).
    """
    binarizer = sitk.BinaryThresholdImageFilter()
    binarizer.SetInsideValue(inside)
    binarizer.SetOutsideValue(outside)
    binarizer.SetUpperThreshold(upper)
    binarizer.SetLowerThreshold(lower)
    return binarizer.Execute(image)

def normalize(image):
    """Return *image* standardized by ``sitk.NormalizeImageFilter``.

    Raises:
        ZeroDivisionError: if the image standard deviation is numerically
            zero (a constant image cannot be standardized).
    """
    statistics = sitk.StatisticsImageFilter()
    statistics.Execute(image)
    sigma = statistics.GetSigma()
    if not np.isclose(sigma, 0):
        return sitk.NormalizeImageFilter().Execute(image)
    raise ZeroDivisionError('Cannot normalize image with Sigma == 0')

def adaptive_histogram_equalization(image, radius):
    """Apply ITK adaptive histogram equalization with alpha = beta = 1.

    ``radius`` is forwarded unchanged to the filter's ``SetRadius``.
    """
    equalizer = sitk.AdaptiveHistogramEqualizationImageFilter()
    equalizer.SetRadius(radius)
    equalizer.SetBeta(1)
    equalizer.SetAlpha(1)
    return equalizer.Execute(image)

def median_filter(img, radius):
    """Apply a median filter with an integer neighbourhood radius.

    Args:
        img: SimpleITK image to filter.
        radius: neighbourhood radius; truncated to ``int`` before use.

    Returns:
        The median-filtered image.

    Raises:
        ValueError: if ``radius`` is smaller than one.
    """
    # Check `< 1` rather than `<= 0`: a fractional radius such as 0.5 would
    # otherwise pass the check and then be truncated to a degenerate radius 0.
    if radius < 1:
        raise ValueError('Radius must be greater or equal than one')
    median = sitk.MedianImageFilter()
    median.SetRadius(int(radius))
    return median.Execute(img)

def std_filter(image, radius):
    """Return the local standard deviation of *image* (``sitk.NoiseImageFilter``).

    Args:
        image: SimpleITK image.
        radius: neighbourhood radius; truncated to ``int`` before use, as in
            ``median_filter``.

    Returns:
        Image of per-voxel neighbourhood standard deviations.

    Raises:
        ValueError: if ``radius`` is smaller than one.
    """
    # `< 1` (not `<= 0`) so fractional radii below one are rejected too.
    if radius < 1:
        raise ValueError('Radius must be greater or equal than one')
    std = sitk.NoiseImageFilter()
    std.SetRadius(int(radius))
    return std.Execute(image)

def cast_image(image, new_pixel_type):
    """Return *image* converted to the SimpleITK pixel type *new_pixel_type*."""
    return sitk.Cast(image, new_pixel_type)

def CopyInformation(self, srcImage):
    """Copy the image meta-data (origin, spacing, direction) of *srcImage* onto *self*.

    *self* is expected to be a SimpleITK image. The call delegates to its public
    ``CopyInformation`` method instead of the private SWIG module
    ``_SimpleITK``, which is not imported in this file and would raise
    NameError.

    Returns:
        Whatever the delegated method returns (``None`` for SimpleITK images).
    """
    return self.CopyInformation(srcImage)

def adjust_gamma(image, gamma=1.0, image_type='HU'):
    """Apply power-law (gamma) correction: ``out = image ** (1 / gamma)``.

    The image is cast to float32 and raised to ``1 / gamma``. Voxels outside
    ``bounding_values[image_type]`` are then replaced with that range's upper
    bound, and the result is cast to ``image_types[image_type]``. Both lookup
    tables are module-level globals that must exist before this is called.

    Args:
        image: SimpleITK image.
        gamma: non-zero gamma value.
        image_type: one of 'HU', 'uint8', 'uint16'.

    Returns:
        The gamma-corrected image in the pixel type for ``image_type``.

    Raises:
        ValueError: if ``gamma`` is zero or ``image_type`` is unsupported.
    """
    if gamma == 0:
        raise ValueError('gamma value cannot be zero')
    if image_type not in ('HU', 'uint8', 'uint16'):
        # Format the offending value, not the builtin `type`.
        raise ValueError('image type {} not supported'.format(image_type))
    inv_gamma = 1.0 / gamma
    as_float = cast_image(image, sitk.sitkFloat32)
    out = sitk.PowImageFilter().Execute(as_float, inv_gamma)
    # sitk.Threshold replaces every voxel outside [low, high] with its third
    # argument, so both under- and over-range voxels become `high`.
    # NOTE(review): under-range voxels also saturate to the *upper* bound.
    # Confirm this is intended.
    low, high = bounding_values[image_type]
    out = sitk.Threshold(out, low, high, high)
    return sitk.Cast(out, image_types[image_type])

def imlabeling(image, centroids, weight=None):
    """Assign each voxel to its nearest centroid (Euclidean distance).

    Args:
        image: array of shape (..., C); the last axis holds the C features of
            each voxel.
        centroids: array of shape (K, C), one row per cluster.
        weight: optional array of shape ``image.shape[:-1]``. When given, only
            voxels where ``weight != 0`` are labelled; the rest keep 0. The
            caller's array is not modified.

    Returns:
        Array of shape ``image.shape[:-1]`` holding the index (0..K-1) of the
        closest centroid. With ``weight``, the result has ``weight``'s dtype,
        and a voxel assigned cluster 0 is indistinguishable from a masked one.

    Raises:
        ValueError: if the feature counts or the weight shape do not match.
    """
    if centroids.shape[1] != image.shape[-1]:
        raise ValueError(
            "Number of image channel doesn't match the number of centroids "
            'features : {} != {}'.format(image.shape[-1], centroids.shape[1]))
    if weight is None:
        # Reduce over the last (feature) axis so images of any rank work,
        # not only 3-D volumes.
        distances = np.asarray(
            [np.linalg.norm(image - c, axis=-1) for c in centroids])
        return np.argmin(distances, axis=0)
    if weight.shape != image.shape[:-1]:
        raise ValueError("Weight shape doesn't match image one : {} != {}".format(
            weight.shape, image.shape[:-1]))
    inside = weight != 0
    # image[inside] has shape (N, C): only the voxels to be labelled.
    distances = np.asarray(
        [np.linalg.norm(image[inside] - c, axis=1) for c in centroids])
    labels = weight.copy()  # don't overwrite the caller's mask
    labels[inside] = np.argmin(distances, axis=0)
    return labels

def shift_and_crop(image):
    """Add 1000 to every voxel, then set everything outside [0, 2048] to 0.

    Presumably maps HU values so that air (-1000) lands at 0. TODO confirm.
    """
    return sitk.Threshold(sitk.ShiftScale(image, 1000, 1.0), 0, 2048, 0)

def remove_vessels(image, sigma=2., thr=8):
    """Overwrite vessel-like voxels of *image* with -1000.

    The image is Gaussian-smoothed with ``sigma`` and a line-like objectness
    map is computed from it. Voxels whose objectness lies in [thr, 4000] are
    treated as vessels and replaced in the original (unsmoothed) image.
    """
    vessel_map = vesselness(gauss_smooth(image, sigma))
    # inside=0 / outside=1: the mask is 0 on vessels and 1 everywhere else.
    keep_mask = threshold(vessel_map, 4000, thr, 0, 1)
    # apply_mask replaces voxels whose mask value is 0 with the outside value.
    return apply_mask(image, keep_mask, outside_value=-1000)

def gauss_smooth(image, sigma=1.):
    """Smooth *image* with a recursive Gaussian of standard deviation *sigma*."""
    smoother = sitk.SmoothingRecursiveGaussianImageFilter()
    smoother.SetSigma(sigma)
    smoothed = smoother.Execute(image)
    return smoothed

def vesselness(image):
    """Return ITK's objectness measure for 1-D (line/tube-like) structures."""
    objectness = sitk.ObjectnessMeasureImageFilter()
    objectness.SetObjectDimension(1)  # 1 => line-like objects such as vessels
    result = objectness.Execute(image)
    return result

def apply_mask(image, mask, masking_value=0, outside_value=-1500):
    """Return *image* with voxels where ``mask == masking_value`` set to *outside_value*."""
    masker = sitk.MaskImageFilter()
    masker.SetOutsideValue(outside_value)
    masker.SetMaskingValue(masking_value)
    return masker.Execute(image, mask)

def _multichannel_features(volume):
    """Return (region weight array, per-voxel feature stack of shape (..., 4)).

    The weight is 1 for voxels with values in [1, 4000] and 0 otherwise. The
    feature channels are stacked in this order: equalized, median, gamma, std.
    """
    weight = sitk.GetArrayFromImage(threshold(image=volume, upper=4000, lower=1))
    channels = [
        normalize(adaptive_histogram_equalization(image=volume, radius=5)),
        normalize(median_filter(img=volume, radius=3)),
        normalize(adjust_gamma(image=volume, gamma=1.5)),
        normalize(std_filter(image=volume, radius=3)),
    ]
    features = np.stack([sitk.GetArrayFromImage(c) for c in channels], axis=-1)
    return weight, features


def main(volume, centroids, target_label=3):
    """Extract one cluster (GGO by default) from a lung volume as a binary mask.

    Args:
        volume: SimpleITK image. Only voxels with values in [1, 4000] are
            labelled.
        centroids: (K, 4) array of cluster centres. Presumably the feature
            order matches the stacked channels (equalized, median, gamma, std).
            TODO confirm.
        target_label: index of the centroid row to extract. The default 3 is
            the position of 'GGO' in the notebook's centroid dict.

    Returns:
        uint8 SimpleITK image (1 = target cluster, 0 otherwise) with the
        geometry of *volume*, median-filtered with radius 3.
    """
    weight, features = _multichannel_features(volume)
    labels = imlabeling(image=features, centroids=centroids, weight=weight)
    mask = sitk.GetImageFromArray((labels == target_label).astype(np.uint8))
    mask.CopyInformation(volume)
    # Median filtering of the binary mask suppresses isolated speckle voxels.
    return median_filter(img=mask, radius=3)


In [5]:
# Valid output range per image type. adjust_gamma uses it to saturate values
# before the final cast, so every upper bound must be representable in the
# target pixel type: uint16 max is 2**16 - 1, and 2**16 is out of range.
bounding_values = {'uint8': [0, 255],
                   'uint16': [0, 2**16 - 1],
                   'HU': [0, 2**12]}
# Output pixel type per supported image type (HU results are stored as uint16).
image_types = dict(uint8=sitk.sitkUInt8,
                   uint16=sitk.sitkUInt16,
                   HU=sitk.sitkUInt16)

In [6]:
# Input from lungmask: read the lung-only volume from input_path.
reader = sitk.ImageFileReader()
reader.SetFileName(input_path)
volume = reader.Execute()

In [7]:
# Blank out vessel-like structures (set to -1000) with the default sigma/thr.
volume = remove_vessels(volume)

In [8]:
# Shift intensities by +1000 and zero everything outside [0, 2048].
volume = shift_and_crop(image=volume)

In [9]:
# Stack the centroid vectors into a (K, 4) array in dict insertion order, so
# row 3 is 'GGO'. Bind the array to a new name so `centroids` stays a dict and
# this cell can be re-run without failing on `.values()`.
centroid_matrix = np.asarray(list(centroids.values()))
labels = main(volume, centroid_matrix)

In [10]:
# Interactive preview of the segmentation mask (myshow3d comes from the local
# `myshow` helper module; it renders a z-slice slider widget).
myshow3d(labels)

interactive(children=(IntSlider(value=57, description='z', max=115), Output()), _dom_classes=('widget-interact…

In [11]:
# Create the output folder first: ITK image writers do not create missing
# directories, so writing into a fresh checkout would fail.
os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
sitk.WriteImage(labels, output_path)