In [None]:
import cv2
import numpy as np
import SimpleITK as sitk
import registration_gui as rgui
import matplotlib.pyplot as plt
from os.path import join
%matplotlib inline

In [None]:
# Tuning parameters for heart-wall segmentation and masking.
arg = dict(
    # intensity threshold tried first when binarising a block (also used by get_vol_mask)
    wall_threshold_initial=165,
    # accepted range for the number of wall pixels found in a block
    wall_lw_limit=500,
    wall_up_limit=1800,
    # fallback thresholds used when the first pixel count is below / above that range
    wall_threshold_lower=20,
    wall_threshold_higher=200,
    # overlap-ratio cut-off between applying the circle mask (above) and the bottom bar (at or below)
    mask_overlap_threshold=0.75 / 2,
    # the lowest wall row must reach at least this index for the bottom-bar mask
    mask_y_threshold=80,
)

In [None]:
# Interactive viewer: step through a volume one slice at a time.
def show_volume(vol):
    """Display every 2-D slice ``vol[:, :, k]`` (k along the last axis) in an OpenCV window.

    Waits for a key press between slices (``cv2.waitKey(0)``) and closes all
    OpenCV windows once every slice has been shown.
    """
    n_frames = vol.shape[-1]
    frame = 0
    while frame < n_frames:
        cv2.imshow('image frame', vol[:, :, frame])
        cv2.waitKey(0)
        frame += 1
    cv2.destroyAllWindows()

In [None]:
image_path = join('..', '..', 'data', 'TrainSet', '1027.jpg')
image = cv2.imread(image_path)
# cv2.imread does not raise on a missing or unreadable file; it returns None,
# which would otherwise only surface later as a confusing AttributeError.
if image is None:
    raise FileNotFoundError('Could not read image: {0}'.format(image_path))
print('Image shape: ', image.shape)


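### Processing the SPECT MPI image
1. Cut out the part of the image that contains the image blocks.
2. Extract the red channel and convert BGR to grayscale.
3. Crop the blocks, blank their number labels, and stack them into a 3D volume.
4. Divide each 3D volume into two sets of volumes according to status: stress and rest.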

In [None]:
def _image_to3d(image):
    block_size = 90
    vol = np.empty(shape=(block_size,block_size,0), dtype=np.uint8)
    for block_y in range(0, block_size*8, block_size):
        for block_x in range(0, block_size*10, block_size):
            idx_y = np.arange(block_y, block_y + block_size)
            idx_x = np.arange(block_x, block_x + block_size)
            cropped_block = image[idx_y[:,np.newaxis], idx_x, None]
            vol = np.concatenate((vol, cropped_block), axis=2)

    return vol


def _shield_number_labels(vol):
    for k in range(vol.shape[-1]):
        vol[:14, :17, k] = 0
        vol[0, :] = 0
        vol[:, 0] = 0

    return vol


def _stress_rest_divide(vol):
    stress = [vol[:,:,np.concatenate([np.arange(0,10), np.arange(20,30)], axis=0)], # SA view
        vol[:,:,np.arange(40,50)], # HLA view
        vol[:,:,np.arange(60,70)]] # VLA view

    rest = [vol[:,:,np.concatenate([np.arange(10,20), np.arange(30,40)], axis=0)],
        vol[:,:,np.arange(50,60)],
        vol[:,:,np.arange(70,80)]]

    return stress, rest


def image_preproc_basic(image):
    """Turn a SPECT MPI image into stress/rest block volumes.

    Keeps the SA, HLA and VLA row bands and a fixed column range, takes the
    red channel and a grayscale version of that crop, cuts each into a volume
    of 90x90 blocks with the corner number labels blanked, and divides each
    volume into stress and rest parts.

    Args:
        image: BGR colour image as returned by ``cv2.imread``.

    Returns:
        (stress_red, rest_red, stress_gray, rest_gray), each a list of three
        volumes in SA, HLA, VLA order (as built by `_stress_rest_divide`).
    """
    # Row bands of the SA, HLA and VLA views (stacked top to bottom) and the
    # shared column range.
    # NOTE(review): the bands are 361, 181 and 181 rows long -- one row more
    # than a whole number of 90-px blocks each -- while `_image_to3d` reads
    # only the first 720 rows, so the HLA blocks are offset by one row and the
    # VLA blocks by two relative to their band starts. Confirm against the
    # image layout.
    band_rows = np.concatenate(
        [np.arange(50, 411), np.arange(483, 664), np.arange(707, 888)], axis=0)
    band_cols = np.arange(69, 970)
    cropped = image[band_rows[:, np.newaxis], band_cols]

    planes = (
        cropped[:, :, -1],                          # red: last channel of a BGR image
        cv2.cvtColor(cropped, cv2.COLOR_BGR2GRAY),  # grayscale
    )
    (stress_red, rest_red), (stress_gray, rest_gray) = (
        _stress_rest_divide(_shield_number_labels(_image_to3d(plane)))
        for plane in planes
    )
    return stress_red, rest_red, stress_gray, rest_gray


    
(stress_red, rest_red,
 stress_gray, rest_gray) = image_preproc_basic(image)
# To inspect a volume interactively: show_volume(stress_gray[0])

### 3D Registration using Rigid-body transform
Register the rest volume to the stress volume for each view (SA, HLA, VLA).

Registration includes the following steps:
1. Initialization: centered transform (geometric centers aligned)
2. Similarity metric: mean squares
3. Optimizer: regular step gradient descent
4. Interpolator: linear

In [None]:
def registration_estimate(fixed_nda, moving_nda):
    """Estimate the rigid (Euler 3-D) transform that maps `moving_nda` onto `fixed_nda`.

    Both arrays are cast to float32 SimpleITK images. The transform is
    initialised by aligning the image geometry centres and then refined with a
    mean-squares metric, linear interpolation and regular-step gradient
    descent. Progress is plotted through the `registration_gui` callbacks, and
    the final metric and stop condition are printed.

    Args:
        fixed_nda: reference volume (numpy array).
        moving_nda: volume to align to `fixed_nda`.

    Returns:
        (moving_resampled, final_transform, final_metric):
        - moving_resampled: `moving_nda` resampled onto the fixed grid by `imregister`.
        - final_transform: the estimated transform.
        - final_metric: the metric value at the end of optimisation.
    """
    # NOTE(review): GetImageFromArray interprets numpy axes in reverse
    # (z, y, x) order, so the last numpy axis (the block index) becomes the
    # SimpleITK x axis -- confirm this is the intended geometry.
    fixed_image = sitk.Cast(sitk.GetImageFromArray(fixed_nda), sitk.sitkFloat32)
    moving_image = sitk.Cast(sitk.GetImageFromArray(moving_nda), sitk.sitkFloat32)

    # Setup the initial transform and registration parameters
    initial_transform = sitk.CenteredTransformInitializer(fixed_image,
                                                          moving_image,
                                                          sitk.Euler3DTransform(),
                                                          sitk.CenteredTransformInitializerFilter.GEOMETRY)
    registration_method = sitk.ImageRegistrationMethod()
    registration_method.SetMetricAsMeanSquares()
    registration_method.SetInterpolator(sitk.sitkLinear)
    registration_method.SetOptimizerAsRegularStepGradientDescent(learningRate=1.0,
                                                                 minStep=1e-5,
                                                                 relaxationFactor=0.5,
                                                                 gradientMagnitudeTolerance=1e-4,
                                                                 numberOfIterations=100)
    registration_method.SetOptimizerScalesFromPhysicalShift()
    # inPlace=False leaves `initial_transform` itself unmodified.
    registration_method.SetInitialTransform(initial_transform, inPlace=False)

    # Connect all of the observers so that we can perform plotting during registration
    registration_method.AddCommand(sitk.sitkStartEvent, rgui.start_plot)
    registration_method.AddCommand(sitk.sitkEndEvent, rgui.end_plot)
    registration_method.AddCommand(sitk.sitkMultiResolutionIterationEvent, rgui.update_multires_iterations)
    registration_method.AddCommand(sitk.sitkIterationEvent, lambda: rgui.plot_values(registration_method))

    # Execute the registration-estimation
    final_transform = registration_method.Execute(fixed_image, moving_image)
    final_metric = registration_method.GetMetricValue()
    print('Final metric value: {0}'.format(final_metric))
    print('Optimizer\'s stopping condition, {0}'.format(registration_method.GetOptimizerStopConditionDescription()))

    # Resampling is shared with `imregister` instead of being duplicated here,
    # so both code paths convert the result to uint8 the same way.
    moving_image_resampled = imregister(fixed_nda, moving_nda, final_transform)

    return moving_image_resampled, final_transform, final_metric


def imregister(fixed_nda, moving_nda, final_transform):
    """Resample `moving_nda` onto the grid of `fixed_nda` using `final_transform`.

    Both arrays are cast to float32 SimpleITK images. The moving image is
    resampled with linear interpolation, and points that map outside it get 0.

    Args:
        fixed_nda: reference volume that defines the output grid.
        moving_nda: volume to be transformed.
        final_transform: SimpleITK transform, e.g. from `registration_estimate`.

    Returns:
        uint8 numpy array on the fixed grid. Interpolated intensities are
        rounded to the nearest integer and clipped to [0, 255].
    """
    fixed_image = sitk.Cast(sitk.GetImageFromArray(fixed_nda), sitk.sitkFloat32)
    moving_image = sitk.Cast(sitk.GetImageFromArray(moving_nda), sitk.sitkFloat32)

    # resample the moving image
    moving_image_resampled = sitk.Resample(moving_image, fixed_image, final_transform, sitk.sitkLinear, 0.0, moving_image.GetPixelID())
    resampled_nda = sitk.GetArrayFromImage(moving_image_resampled)

    # Linear interpolation yields fractional values. A bare astype(np.uint8)
    # truncates toward zero, which biases intensities down (254.9999 -> 254),
    # so round first and clip to the uint8 range.
    return np.clip(np.rint(resampled_nda), 0, 255).astype(np.uint8)


In [None]:
# For each view, estimate the rigid transform on the grayscale pair and apply
# that same transform to both the grayscale and the red-channel rest volumes.
# NOTE(review): rest_gray_registered / rest_red_registered are not read by the
# centroid and masking cells that follow (they use rest_red / rest_gray) --
# confirm whether the registered volumes were meant to be used there.
rest_gray_registered = []
rest_red_registered = []
for fixed_gray, moving_gray, fixed_red, moving_red in zip(stress_gray, rest_gray, stress_red, rest_red):
    _, final_transform, _ = registration_estimate(fixed_gray, moving_gray)
    rest_gray_registered += [imregister(fixed_gray, moving_gray, final_transform)]
    rest_red_registered += [imregister(fixed_red, moving_red, final_transform)]

# To inspect a result interactively: show_volume(rest_gray_registered[0])

### Masking

1. Compute the heart-wall centroid of each 2D block (the SA, HLA and VLA views are processed in turn).

In [None]:
def _count_wall_size(block, threshold):
    block_binarized = block > threshold
    block_binarized = block_binarized.astype(np.uint8)
    num_pixel = np.sum(block_binarized, axis=(0,1))

    return num_pixel


def _get_heart_wall(block, arg):
    threshold = arg['wall_threshold_initial']
    # tuning the threshold value
    initial_num_pixel = _count_wall_size(block, threshold)
    if initial_num_pixel < arg['wall_lw_limit']:
        threshold = arg['wall_threshold_lower']
    elif initial_num_pixel > arg['wall_up_limit']:
        threshold = arg['wall_threshold_higher']

    new_num_pixel = _count_wall_size(block, threshold)
    if new_num_pixel<arg['wall_lw_limit'] or new_num_pixel>arg['wall_up_limit']:
        wall = None
    else:
        wall = block > threshold
    
    return wall


def compute_cetroids(vol, arg):
    """Compute the heart-wall centroid of every block in a volume.

    Each block ``vol[:, :, k]`` is segmented with `_get_heart_wall`. Its
    centroid is the mean (row, column) index of the wall pixels. Blocks whose
    wall cannot be segmented get the mean of all valid centroids instead.

    Args:
        vol: 3-D array with blocks stacked along the last axis.
        arg: parameter dict forwarded unchanged to `_get_heart_wall`.

    Returns:
        List with one float array ``[row, col]`` per block. If no block has a
        valid wall, every entry is ``[nan, nan]``. An empty volume gives ``[]``.
    """
    centroids = []
    for block_idx in range(vol.shape[-1]):
        wall = _get_heart_wall(vol[:, :, block_idx], arg)
        if wall is None or not np.any(wall):
            centroids.append(np.array([np.nan, np.nan]))
            continue
        # Order is [row, col]: consumers treat centroid[0] as the vertical
        # coordinate (compared with the block height and with row indices).
        # np.nonzero yields (row, col) index arrays directly. np.meshgrid with
        # its default 'xy' indexing would put the column index first and fail
        # on non-square blocks.
        rows, cols = np.nonzero(wall)
        centroids.append(np.array([rows.mean(), cols.mean()]))

    if not centroids:
        return centroids

    centroids_nda = np.asarray(centroids, dtype=float)
    valid = ~np.isnan(centroids_nda).any(axis=1)
    if valid.any():
        global_centroid = centroids_nda[valid].mean(axis=0)
    else:
        # Nothing to average; avoid np.mean on an empty array (RuntimeWarning).
        global_centroid = np.array([np.nan, np.nan])

    return [c if ok else global_centroid.copy() for c, ok in zip(centroids, valid)]


In [None]:
# Heart-wall centroids for each view (SA, HLA, VLA), computed on the red channel.
# NOTE(review): the rest centroids come from `rest_red`, not from
# `rest_red_registered` -- confirm the unregistered volume is intended here.
centroids_stress = [compute_cetroids(view, arg) for view in stress_red]
centroids_rest = [compute_cetroids(view, arg) for view in rest_red]

print(centroids_stress[0])
print(centroids_stress[1])



In [None]:
def _draw_circle_mask(block_binarized, centroid):
    mask = np.zeros(shape=block_binarized.shape, dtype=bool)
    y = np.arange(0, block_binarized.shape[0])
    x = np.arange(0, block_binarized.shape[1])
    xv, yv = np.meshgrid(x, y)
    xc, yc = centroid[1], centroid[0]

    wall_area = np.sum(block_binarized, axis=(0,1)) 
    radius_square = wall_area/np.pi
    inferior_wall_grid = np.logical_or((xv-xc)**2+(yv-yc)**2<=radius_square, yv<yc)
    mask[inferior_wall_grid] = True
    
    return mask



def get_vol_mask(vol, centroids, arg):
    """Build a boolean mask volume containing one 2-D mask per block of `vol`.

    Each block is binarised with ``block > arg['wall_threshold_initial']`` and
    compared with the circle mask from `_draw_circle_mask` around that block's
    centroid:

    * overlap ratio > ``arg['mask_overlap_threshold']`` and ``centroid[0]``
      greater than half the block height: the circle mask is used;
    * overlap ratio <= ``arg['mask_overlap_threshold']`` and the lowest row
      containing wall pixels >= ``arg['mask_y_threshold']``: a bar covering
      the bottom ``arg.get('mask_bar_height', 15)`` rows is used;
    * otherwise, including blocks with no wall pixels: the block stays unmasked.

    Args:
        vol: 3-D array with blocks stacked along the last axis.
        centroids: per-block centroids, indexed as ``[row, col]``.
        arg: parameter dict (keys listed above).

    Returns:
        Boolean array with the same shape as `vol`.
    """
    bar_height = arg.get('mask_bar_height', 15)
    vol_mask = np.zeros(vol.shape, dtype=bool)
    # The comparison already yields a bool array. The former
    # `.astype(np.bool)` raises AttributeError on NumPy 1.24-1.26, where the
    # deprecated `np.bool` alias was removed.
    vol_binarized = vol > arg['wall_threshold_initial']

    for i in range(vol_binarized.shape[-1]):
        block_binarized = vol_binarized[:, :, i]
        wall_area = np.sum(block_binarized)
        if wall_area == 0:
            continue  # no wall pixels: leave this block unmasked

        centroid = centroids[i]
        block_mask = _draw_circle_mask(block_binarized, centroid)
        overlap_ratio = np.sum(np.logical_and(block_mask, block_binarized)) / wall_area
        # index of the lowest (largest-index) row that contains a wall pixel
        lowest_y = np.flatnonzero(block_binarized.any(axis=1))[-1]

        if (overlap_ratio > arg['mask_overlap_threshold']) and (centroid[0] > block_binarized.shape[0] / 2):
            vol_mask[:, :, i] = block_mask
        elif (overlap_ratio <= arg['mask_overlap_threshold']) and lowest_y >= arg['mask_y_threshold']:
            # bottom bar; max() keeps a zero height from selecting every row
            n_rows = block_binarized.shape[0]
            vol_mask[max(n_rows - bar_height, 0):, :, i] = True

    return vol_mask


In [None]:
# Per view: union of the stress and rest masks, as uint8 so it can scale the
# grayscale volumes directly.
# NOTE(review): masks and masked outputs use the unregistered `rest_red` /
# `rest_gray`, not `rest_*_registered` -- confirm this is intended.
masks = [
    np.logical_or(get_vol_mask(vol_s, cent_s, arg),
                  get_vol_mask(vol_r, cent_r, arg)).astype(np.uint8)
    for vol_s, vol_r, cent_s, cent_r in zip(stress_red, rest_red, centroids_stress, centroids_rest)
]

stress_masked = [gray * mask for gray, mask in zip(stress_gray, masks)]
rest_masked = [gray * mask for gray, mask in zip(rest_gray, masks)]

# To inspect a result interactively: show_volume(rest_masked[0].astype(np.uint8))

In [None]:
# Join all views along the block axis, then pair stress and rest on a new last
# axis: output[..., 0] is stress and output[..., 1] is rest.
stress_final = np.concatenate(stress_masked, axis=2)
rest_final = np.concatenate(rest_masked, axis=2)
output = np.stack([stress_final, rest_final], axis=-1)
print(output.shape)