In [1]:
# import os
import numpy as np
import SimpleITK
import matplotlib.pyplot as plt
import matplotlib.mlab as mlab
from __future__ import division
%pylab inline

Populating the interactive namespace from numpy and matplotlib


A method to display slices

In [3]:
def sitk_show(img, title=None, margin=0.0, dpi=40):
    nda = SimpleITK.GetArrayFromImage(img)
    #spacing = img.GetSpacing()
    figsize = (1 + margin) * nda.shape[0] / dpi, (1 + margin) * nda.shape[1] / dpi
    #extent = (0, nda.shape[1]*spacing[1], nda.shape[0]*spacing[0], 0)
    extent = (0, nda.shape[1], nda.shape[0], 0)
    fig = plt.figure(figsize=figsize, dpi=dpi)
    ax = fig.add_axes([margin, margin, 1 - 2*margin, 1 - 2*margin])
    plt.set_cmap("gray")
    ax.imshow(nda,extent=extent,interpolation=None)
    
    if title:
        plt.title(title)
    
    plt.show()

Adds as outlier voxels 3 standard deviations away from the norm

In [4]:
def add_outlier(seg, img, k):
    for i in range(1,k+1):
        stdev = img[seg==i].std()
        mean = img[seg==i].mean()
        sub = np.all([seg==i, np.any([img>mean+3*stdev, img<mean-3*stdev], 0)],0)
        seg[sub] = 4
        
        

Takes an image segmentation (gray matter, white matter and CSF) and initializes theta. outliers are more than 3 standard deviations away from norm

In [5]:
def init_theta(seg,img, theta, k, c):
    for i in range(0,c):
        for j in range(1,k+2):
            theta[i][2*j] = img[i][seg==(j)].mean()
            theta[i][2*j+1] = img[i][seg==(j)].var()
            
            
    
        

def init_tumor(seg, tumor, c):
    for i in range(0,c):
        tumor[i][seg==4] = 1
            

In [None]:
def T_prob(Ti, alphai):
    return alphai**np.sum(Ti) * (1-alphai)**(Ti.size-np.sum(Ti))

In [192]:
# def help_func(yi,tic, muk, stdevk, mukP1, stdevkP1):
#     return mlab.normpdf(yi,muk,stdevk)**(1-tic) * mlab.normpdf(yi,mukP1, stdevkP1)**tic

def Y_prob(yi, ti, ki, theta):
    mu=[theta[0][2*ki],theta[1][2*ki],theta[2][2*ki],theta[3][2*ki]]
    var=[theta[0][2*ki+1],theta[1][2*ki+1],theta[2][2*ki+1],theta[3][2*ki+1]]
    mukP1 = [theta[0][-2],theta[1][-2],theta[2][-2],theta[3][-2]]
    varkP1 = [theta[0][-1],theta[1][-1],theta[2][-1],theta[3][-1]]
    return mlab.normpdf(yi,mu,np.sqrt(var))**(1-tic) * mlab.normpdf(yi,mukP1, np.sqrt(varkP1))**tic
#     return np.prod(map(help_func, yi, ti, mu, stdev, mukP1, stdevkP1))

In [12]:
def get_class(seg, probMap):
    return np.argmax([np.mean(probMap[seg==1]), np.mean(probMap[seg==2]), np.mean(probMap[seg==3])])+1

In [13]:
def Qi(yi,ti, k, theta, alphai, atlas, index):
    sum = 0
    for i in range(1,k+1): #3 classes are 1,2,3
        sum += Y_prob(yi, ti, i, theta)*T_prob(ti, alphai)*atlas[i][index]
    return sum

In [193]:
def Wik(yi, theta, atlas, ti, index, k):
    mu=[theta[0][2*k],theta[1][2*k],theta[2][2*k],theta[3][2*k]]
    var=[theta[0][2*k+1],theta[1][2*k+1],theta[2][2*k+1],theta[3][2*k+1]]
    return atlas[k][index]*np.prod(mlab.normpdf, yi, mu, np.sqrt(var))**(1-ti)

In [76]:
def get_all_t(c): # returnes the 2^c configurations
    if c == 1:
        return [np.array([0]), np.array([1])]
    else:
        arr = get_all_t(c-1) 

In [None]:
def alpha_update(index):
    np.sum(map(Qi, [img[:,index]]*all_t.shape[0], all_t, [K]*all_t.shape[0], [theta]*all_t.shape[0], 
        [alpha[index]]*all_t.shape[0], [atlas]*all_t.shape[0], [index]*all_t.shape[0]) * np.true_divide(np.sum(all_t,1), ci))

In [166]:
def mu_update_numerator(index):
    np.sum(map(Qi, [img[:,index]]*all_t.shape[0], all_t, [K]*all_t.shape[0], [theta]*all_t.shape[0], 
        [alpha[index]]*all_t.shape[0], [atlas]*all_t.shape[0], [index]*all_t.shape[0])*
           map(Wik, [img[:,index]]*all_t.shape[0],[theta]*all_t.shape[0], [atlas]*all_t.shape[0],
                all_t, [index]*all_t.shape[0], [ki]*all_t.shape[0])*(1-all_t[:,ci])*img[ci,index])

def update_denominator(index): # mu and stdev update denominators are the same
    np.sum(map(Qi, [img[:,index]]*all_t.shape[0], all_t, [K]*all_t.shape[0], [theta]*all_t.shape[0], 
        [alpha[index]]*all_t.shape[0], [atlas]*all_t.shape[0], [index]*all_t.shape[0])*
           map(Wik, [img[:,index]]*all_t.shape[0],[theta]*all_t.shape[0], [atlas]*all_t.shape[0],
                all_t, [index]*all_t.shape[0], [ki]*all_t.shape[0])*(1-all_t[:,ci]))

def stddev_update_numerator(index):
    np.sum(map(Qi, [img[:,index]]*all_t.shape[0], all_t, [K]*all_t.shape[0], [theta]*all_t.shape[0], 
        [alpha[index]]*all_t.shape[0], [atlas]*all_t.shape[0], [index]*all_t.shape[0])*
           map(Wik, [img[:,index]]*all_t.shape[0],[theta]*all_t.shape[0], [atlas]*all_t.shape[0],
                all_t, [index]*all_t.shape[0], [ki]*all_t.shape[0])*(1-all_t[:,ci])*(img[ci,index]-theta[ci][2*ki])**2)

def tumor_mu_update_numerator(index):
    np.sum(map(Qi, [img[:,index]]*all_t.shape[0], all_t, [K]*all_t.shape[0], [theta]*all_t.shape[0], 
        [alpha[index]]*all_t.shape[0], [atlas]*all_t.shape[0], [index]*all_t.shape[0])*all_t[:,ci]*img[ci,index])

def tumor_update_denominator(index):
    np.sum(map(Qi, [img[:,index]]*all_t.shape[0], all_t, [K]*all_t.shape[0], [theta]*all_t.shape[0], 
        [alpha[index]]*all_t.shape[0], [atlas]*all_t.shape[0], [index]*all_t.shape[0])*all_t[:,ci])

def tumor_stddev_update_numerator(index):
    np.sum(map(Qi, [img[:,index]]*all_t.shape[0], all_t, [K]*all_t.shape[0], [theta]*all_t.shape[0], 
        [alpha[index]]*all_t.shape[0], [atlas]*all_t.shape[0], [index]*all_t.shape[0])*all_t[:,ci]*(img[ci,index]-theta[ci][2*ki])**2)

In [167]:
def get_T_for_C(index):
    np.sum(map(Qi, [img[:,index]]*all_t.shape[0], all_t, [K]*all_t.shape[0], [theta]*all_t.shape[0], 
        [alpha[index]]*all_t.shape[0], [atlas]*all_t.shape[0], [index]*all_t.shape[0])*all_t[:,ci])

Initialization of variables, and some loading that had to be done in order

In [2]:
print "Loading data and initializing variables"
dir_path = 'BRATS-2/Image_Data/HG/0001/'
brainT1_path = dir_path + 'VSD.Brain.XX.O.MR_T1/VSD.Brain.XX.O.MR_T1.685.mha'
brainT1c_path =dir_path + 'VSD.Brain.XX.O.MR_T1c/VSD.Brain.XX.O.MR_T1c.686.mha'
brainT2_path = dir_path + 'VSD.Brain.XX.O.MR_T2/VSD.Brain.XX.O.MR_T2.687.mha'
brainFLAIR_path = dir_path + 'VSD.Brain.XX.O.MR_Flair/VSD.Brain.XX.O.MR_Flair.684.mha'
brainT1_img = SimpleITK.ReadImage(brainT1_path)
brainT1c_img = SimpleITK.ReadImage(brainT1c_path)
brainT2_img = SimpleITK.ReadImage(brainT2_path)
brainFlair_img = SimpleITK.ReadImage(brainFLAIR_path)
# truth_path = 'BRATS-2/Image_Data/HG/0001/VSD.Brain_3more.XX.XX.OT/VSD.Brain_3more.XX.XX.OT.6560.mha'
# truth_img = SimpleITK.ReadImage(truth_path)
initial_seg_path = 'BRATS-2/Image_Data/HG/0001/seg.nii.gz'
initial_seg = SimpleITK.ReadImage(initial_seg_path)
atlas_path = 'BRATS-2/Image_Data/HG/0001/Atlas/'

In [18]:
K = 3 #3 tissue classes
C = 4 # my data has 4 channels
N = brainT1_img.GetHeight()*brainT1_img.GetWidth()*brainT1_img.GetDepth()
theta = np.empty([C, 2*(K+2)], double) # first 2 vals in a row are empty to match index's. class one is 2*1 and 2*1+1

img = np.empty([C,N])
img[0] = SimpleITK.GetArrayFromImage(brainT1_img).flatten()
img[1] = SimpleITK.GetArrayFromImage(brainT1c_img).flatten()
img[2] = SimpleITK.GetArrayFromImage(brainT2_img).flatten()
img[3] = SimpleITK.GetArrayFromImage(brainFlair_img).flatten()
initial_seg = SimpleITK.GetArrayFromImage(initial_seg).flatten()

atlas = np.empty((4,N), double) # atlas[0] is empty to make index's match with segmentation.
csf = SimpleITK.GetArrayFromImage(SimpleITK.ReadImage(atlas_path + "CSF_warped.nii.gz")).flatten()
gm = SimpleITK.GetArrayFromImage(SimpleITK.ReadImage(atlas_path + "GM_warped.nii.gz")).flatten()
wm = SimpleITK.GetArrayFromImage(SimpleITK.ReadImage(atlas_path + "WM_warped.nii.gz")).flatten()
atlas[get_class(initial_seg, csf)] = csf
# print "csf class: " + str(get_class(initial_seg, csf))
atlas[get_class(initial_seg, gm)] = gm
# print "gm class: " + str(get_class(initial_seg, gm))
atlas[get_class(initial_seg, wm)] = wm
# print "wm class: " + str(get_class(initial_seg, wm))

add_outlier(initial_seg, img[0], K)
alpha = np.full(N, 0.3)
alpha[initial_seg==4] = 0.7

init_theta(initial_seg, img, theta, K, C)
all_t = get_all_t(C)

#### EM, i.e. main segmentation

Considering only non-zero( in all channels) voxels

In [124]:
indexs = np.where(~np.any(img,axis=0))[0] # index's of non-zero intensity vectors yi
for j in range(0,16): #for now do 15 iterations of EM
    print "EM iteration " + str(j)
    print "\tPerforming alpha step"
    alpha[indexs] = map(alpha_update, indexs) # alpha update
    for ci in range(0,C):
        print"\tPerforming theta update for normal tissues"
        for ki in range (1,K+2): #update mean and var for normal tissues
            theta[ci][2*ki] = np.sum(map(mu_update_numerator, indexs)) / np.sum(map(mu, update_denominator, indexs))
            theta[ci][2*ki+1] = np.sum(map(stddev_update_numerator, indexs)) / np.sum(map(update_denominator, indexs))
        ki = 4  #update mean and var for tumor tissues
        print "\tPerforming theta update for tumor tissue"
        theta[ci][2*ki] = np.sum(map(tumor_mu_update_numerator, index)) / np.sum(map(tumor_update_denominator, index))
        theta[ci][2*ki+1] = np.sum(map(tumor_stddev_update_numerator, index)) / np.sum(map(tumor_update_denominator, index))

IndentationError: unexpected indent (<ipython-input-124-3e8625d97561>, line 2)

Get final segmentations

In [None]:
print "Computing final segmentations"
final_tumor = np.zeros((C,N), double)
for c in range(0,C):
    final_tumor[c][index] = map(get_T_for_C, index) 

Save the segmentations

In [None]:
print "Saving images"
t1_seg = SimpleITK.GetImageFromArray(np.reshape(final_tumor[0],(176,216,160)))
t1c_seg = SimpleITK.GetImageFromArray(np.reshape(final_tumor[1],(176,216,160)))
t2_seg = SimpleITK.GetImageFromArray(np.reshape(final_tumor[2],(176,216,160)))
flair_seg = SimpleITK.GetImageFromArray(np.reshape(final_tumor[3],(176,216,160)))
SimpleITK.WriteImage(t1_seg, dir_path + 'my_T1_seg.nii.gz')
SimpleITK.WriteImage(t1c_seg, dir_path + 'my_T1c_seg.nii.gz')
SimpleITK.WriteImage(t2_seg, dir_path + 'my_T2_seg.nii.gz')
SimpleITK.WriteImage(flair_seg, dir_path + 'my_Flair_seg.nii.gz')