In [1]:
import cv2
import numpy as np
import SimpleITK as sitk
import registration_gui as rgui
import matplotlib.pyplot as plt
from os.path import join
%matplotlib inline

In [2]:
# Load the sample SPECT MPI image. os.path.join builds a portable path: the previous
# literal '..\data\TrainSet\\1003.jpg' relied on invalid string escapes ('\d', '\T')
# and on Windows-only path separators.
image_path = join('..', 'data', 'TrainSet', '1003.jpg')
image = cv2.imread(image_path)
# cv2.imread does not raise on a missing or unreadable file; it returns None, which
# would otherwise only surface later as a confusing indexing error.
if image is None:
    raise FileNotFoundError(f'Could not read image: {image_path}')

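### Processing the SPECT MPI image
1. Cut out the image blocks of each view (SA, HLA, VLA).
2. Convert to grayscale (OpenCV loads the image in BGR channel order).
3. Crop the blocks and stack them into a 3D volume.
4. Split the volume by acquisition status into two series of per-view 3D volumes: stress and rest.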

In [None]:
# Cut out image blocks: pixel rows of each view (SA, HLA, VLA) and the shared pixel columns.
SA_rows = np.arange(50,411)
HLA_rows = np.arange(483,664)
VLA_rows = np.arange(707,888)
rows = np.concatenate([SA_rows, HLA_rows, VLA_rows], axis=0)
cols = np.arange(69,970)
cc_image = image[rows[:,np.newaxis], cols]

# Convert to grayscale (cv2.imread returns channels in BGR order).
cc_gray = cv2.cvtColor(cc_image, cv2.COLOR_BGR2GRAY)


def tile_into_blocks(gray_view, block_size):
    """Split a 2D grayscale view into square blocks stacked along a new last axis.

    Blocks are taken row-major (left to right, then top to bottom), starting at the
    view's own top-left pixel. Trailing rows/columns that do not fill a whole block
    are dropped.

    Parameters
    ----------
    gray_view : np.ndarray
        2D array of shape (H, W).
    block_size : int
        Side length of each square block, in pixels.

    Returns
    -------
    np.ndarray
        Array of shape (block_size, block_size, (H // block_size) * (W // block_size));
        frame index ``r * (W // block_size) + c`` holds block (r, c).

    Raises
    ------
    ValueError
        If ``gray_view`` is not 2D.
    """
    if gray_view.ndim != 2:
        raise ValueError(f'expected a 2D view, got shape {gray_view.shape}')
    n_block_rows = gray_view.shape[0] // block_size
    n_block_cols = gray_view.shape[1] // block_size
    cropped = gray_view[:n_block_rows * block_size, :n_block_cols * block_size]
    # (R*b, C*b) -> (R, b, C, b) -> (b, b, R, C) -> (b, b, R*C)
    blocks = cropped.reshape(n_block_rows, block_size, n_block_cols, block_size)
    return blocks.transpose(1, 3, 0, 2).reshape(block_size, block_size, n_block_rows * n_block_cols)


# Crop the blocks and stack them into a 3D volume.
# Each view is tiled on its own. The views are 361, 181 and 181 rows tall (not multiples of
# block_size), so tiling the concatenated strip on a single 90-px grid made every HLA block
# include the last SA row and every VLA block the last two HLA rows.
block_size = 90
view_grays = np.split(cc_gray, np.cumsum([len(SA_rows), len(HLA_rows)]), axis=0)
cc_gray_3d = np.concatenate([tile_into_blocks(view, block_size) for view in view_grays], axis=2)
# 4 SA + 2 HLA + 2 VLA block rows, 10 blocks per row -> 80 frames (indexed as below).
assert cc_gray_3d.shape == (block_size, block_size, 80), cc_gray_3d.shape

# Split into two series of 3D volumes by status: stress and rest.
# Frame layout: SA 0-19 / 20-39, HLA 40-49 / 50-59, VLA 60-69 / 70-79 (stress / rest).
stress_series = [
    cc_gray_3d[:,:,np.arange(0,20)], # SA view
    cc_gray_3d[:,:,np.arange(40,50)], # HLA view
    cc_gray_3d[:,:,np.arange(60,70)] # VLA view
    ]

rest_series = [
    cc_gray_3d[:,:,np.arange(20,40)], # SA view
    cc_gray_3d[:,:,np.arange(50,60)], # HLA view
    cc_gray_3d[:,:,np.arange(70,80)] # VLA view
]


### 3D registration with a rigid-body (Euler 3D) transform
1. Initialization: centered transform initializer (aligns the geometric centers of the two volumes).
2. Similarity metric: mean squares.
3. Optimizer: gradient descent.
4. Interpolator: linear.

In [None]:
def _to_float32_image(volume):
    """Wrap a NumPy volume as a float32 SimpleITK image.

    Note: GetImageFromArray reverses the axis order (NumPy index [k, j, i] becomes image
    index (i, j, k)), so the last NumPy axis -- the frame axis of the series volumes --
    becomes the image's x axis.
    """
    as_image = sitk.GetImageFromArray(volume)
    return sitk.Cast(as_image, sitk.sitkFloat32)


def _report_geometry(prefix, img):
    """Print the origin, spacing and direction of ``img`` after ``prefix``."""
    print(prefix, img.GetOrigin(), 'Spacing:', img.GetSpacing(), 'Direction:', img.GetDirection())


# HLA view (index 1): stress is the fixed image, rest is the moving image.
fixed_image = _to_float32_image(stress_series[1])
_report_geometry('Fixed image info, Origin:', fixed_image)
moving_image = _to_float32_image(rest_series[1])
_report_geometry('Moving image info, Origin:', moving_image)

In [None]:
def multires_registration(fixed_image, moving_image, initial_transform):
    """Register ``moving_image`` onto ``fixed_image`` while live-plotting progress.

    Setup: mean-squares metric, linear interpolation, gradient descent (learning
    rate 1.0 estimated once, at most 100 iterations) with optimizer scales derived
    from physical shift. ``initial_transform`` is passed with ``inPlace=False``, so
    the caller's transform object is not modified.

    No shrink factors or smoothing sigmas are configured here, so despite the name
    this presumably runs as a single resolution level (SimpleITK default) -- confirm.

    Returns
    -------
    tuple
        ``(final_transform, final_metric_value)``.
    """
    reg = sitk.ImageRegistrationMethod()

    # Similarity metric and interpolator.
    reg.SetMetricAsMeanSquares()
    reg.SetInterpolator(sitk.sitkLinear)

    # Optimizer and its per-parameter scaling.
    reg.SetOptimizerAsGradientDescent(
        learningRate=1.0,
        numberOfIterations=100,
        estimateLearningRate=reg.Once,
    )
    reg.SetOptimizerScalesFromPhysicalShift()

    # Optimize a copy of the starting transform.
    reg.SetInitialTransform(initial_transform, inPlace=False)

    # Observers that drive the registration_gui plots.
    observers = (
        (sitk.sitkStartEvent, rgui.start_plot),
        (sitk.sitkEndEvent, rgui.end_plot),
        (sitk.sitkMultiResolutionIterationEvent, rgui.update_multires_iterations),
        (sitk.sitkIterationEvent, lambda: rgui.plot_values(reg)),
    )
    for event, callback in observers:
        reg.AddCommand(event, callback)

    result_transform = reg.Execute(fixed_image, moving_image)
    print('Final metric value: {0}'.format(reg.GetMetricValue()))
    print('Optimizer\'s stopping condition, {0}'.format(reg.GetOptimizerStopConditionDescription()))
    return result_transform, reg.GetMetricValue()

In [None]:
# Rigid (Euler 3D) starting transform that aligns the geometric centers of both volumes.
rigid_model = sitk.Euler3DTransform()
init_mode = sitk.CenteredTransformInitializerFilter.GEOMETRY
initial_transform = sitk.CenteredTransformInitializer(fixed_image, moving_image, rigid_model, init_mode)

# The second return value (final metric value) is not needed here.
final_transform, _ = multires_registration(fixed_image, moving_image, initial_transform)

In [None]:
# Resample the moving volume onto the fixed image's grid with the estimated transform;
# points that map outside the moving image get 0.0.
moving_resampled = sitk.Resample(moving_image, fixed_image, final_transform, sitk.sitkLinear, 0.0, moving_image.GetPixelID())
# Resampled values are float32; round (instead of truncating) and clip before converting
# back to 8-bit.
nda = np.clip(np.rint(sitk.GetArrayFromImage(moving_resampled)), 0, 255).astype(np.uint8)

# Show fixed / moving / registered frames side by side inline. cv2.imshow + cv2.waitKey(0)
# opened a native window and blocked the kernel on every frame (and fails on headless servers).
panels = [
    ('fixed', sitk.GetArrayFromImage(fixed_image)),
    ('moving', sitk.GetArrayFromImage(moving_image)),
    ('registered', nda),
]
n_frames = nda.shape[-1]
fig, axes = plt.subplots(len(panels), n_frames,
                         figsize=(1.6 * n_frames, 1.6 * len(panels)), squeeze=False)
for row, (label, volume) in enumerate(panels):
    for frame in range(n_frames):
        ax = axes[row, frame]
        ax.imshow(volume[:, :, frame], cmap='gray', vmin=0, vmax=255)
        ax.set_xticks([])
        ax.set_yticks([])
        if row == 0:
            ax.set_title(f'frame {frame}')
    axes[row, 0].set_ylabel(label)
fig.tight_layout()
plt.show()