# 3D-CNN for nodule detection

In [1]:
import os
import numpy as np
from keras.models import Model, Sequential, load_model
from keras.layers import Input, Dense, Dropout, Activation, Flatten, merge
from keras.optimizers import SGD, Adam
from keras.layers import Convolution2D, MaxPooling2D, Convolution3D, MaxPooling3D,UpSampling2D
from keras.layers import Conv2D
from keras.layers.merge import concatenate

from keras import backend as K
from glob import glob


Using TensorFlow backend.


In [9]:
def square_loss(y_true, y_pred):
    """Squared-error loss for the 3-value regression output.

    Keras calls a loss with symbolic backend tensors, so the reduction has to
    be expressed with backend ops (``K``); NumPy reductions such as ``np.sum``
    do not build a graph operation on those tensors.

    Args:
        y_true: backend tensor of target values, last axis = output values.
        y_pred: backend tensor of predictions, same shape as ``y_true``.

    Returns:
        A backend tensor with one loss value per sample: the sum of squared
        differences over the last axis, divided by 3.  Keras averages these
        per-sample values over the batch.
    """
    diff = y_true - y_pred
    # The constant 3 is kept from the original formula; it equals the width
    # of the final Dense(3) layer this loss is compiled with.
    return K.sum(K.square(diff), axis=-1) / 3
    

def detection(inputs_shape, kernel_size=3, pool_size=2):
    """Build, summarise and compile the 2-D CNN regressor.

    The network is seven identical stages (``same``-padded ReLU convolution,
    dropout 0.3, max pooling) followed by three Dense layers of 128, 64 and 3
    units.  The layer summary is printed as a side effect.

    Args:
        inputs_shape: shape of one sample without the batch axis, e.g.
            ``[512, 512, 128]``.
        kernel_size: convolution kernel size used by every conv layer
            (an int or a tuple, as accepted by ``Conv2D``).  Default 3.
        pool_size: pooling window used by every max-pooling layer
            (an int or a tuple, as accepted by ``MaxPooling2D``).  Default 2.

    Returns:
        The compiled Keras ``Model`` (Adam, lr=2e-3, ``square_loss``).
    """
    inputs = Input(inputs_shape, name='inputs')  # (batch,) + inputs_shape

    # Number of filters for each of the seven conv/dropout/pool stages.
    stage_filters = (24, 24, 16, 16, 16, 16, 16)

    features = inputs
    for n_filters in stage_filters:
        features = Conv2D(n_filters, kernel_size,
                          activation='relu', padding='same')(features)
        features = Dropout(0.3)(features)
        features = MaxPooling2D(pool_size=pool_size)(features)

    flt = Flatten()(features)
    # NOTE(review): the Dense layers have no activation, so the head is a
    # purely linear map; confirm this is intended before changing it, since
    # saved weights are loaded into this exact architecture.
    ds1 = Dense(128)(flt)
    ds2 = Dense(64)(ds1)
    ds3 = Dense(3)(ds2)

    model = Model(inputs=inputs, outputs=ds3)
    model.summary()
    optimizer = Adam(lr=2e-3)
    model.compile(optimizer=optimizer, loss=square_loss)

    return model

# Instantiate the detector for samples of shape (512, 512, 128); detection()
# also prints the layer summary and compiles the model.
model = detection([512,512,128])

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
inputs (InputLayer)          (None, 512, 512, 128)     0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 512, 512, 24)      27672     
_________________________________________________________________
dropout_8 (Dropout)          (None, 512, 512, 24)      0         
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 256, 256, 24)      0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 256, 256, 24)      5208      
_________________________________________________________________
dropout_9 (Dropout)          (None, 256, 256, 24)      0         
_________________________________________________________________
max_pooling2d_9 (MaxPooling2 (None, 128, 128, 24)      0         
__________

# Load in data for training

In [11]:
# Training configuration: module-level globals read by load_and_train().
input_path = '/media/izm/Normal/FJJ/'
slices_dir = input_path + 'nodule-slcs/'
label_name = slices_dir + 'nodule-area.txt'
save_name = 'detection'
load_batch = 12  # number of samples loaded from disk per fit() call
batch_size = 4   # mini-batch size handed to fit()

# Start from previously saved weights.
model.load_weights("model/detection-model/detection200")
def load_and_train(dect_model):
    """Run one pass over the label file, fitting the model chunk by chunk.

    Reads the globals ``label_name``, ``input_path``, ``load_batch``,
    ``batch_size`` and ``save_name``.  The label file is consumed
    ``load_batch`` lines at a time; for each chunk the matching
    ``zclip-<name>.npy`` arrays are loaded, the model is fitted for one epoch,
    and after the last chunk the weights are saved to ``"model/" + save_name``.

    Each label line is whitespace separated.  Field 0 is the sample name used
    in the ``.npy`` file name; fields 1-4 are parsed as floats.

    Args:
        dect_model: compiled Keras model to train in place.

    Returns:
        The same ``dect_model`` object, after training.
    """
    # Close the label file as soon as it is read, and drop blank lines so a
    # trailing newline cannot crash the field parsing below.
    with open(label_name) as lb_file:
        lines = [line for line in lb_file if line.strip()]

    for start in range(0, len(lines), load_batch):
        data = []
        lb = []
        for line in lines[start:start + load_batch]:
            objs = line.split()
            print(objs[0])
            # Move the stored array's first axis to the end so it becomes the
            # last (channel) axis of the model input.
            array = np.load(input_path+'nodule-slcs/zclip-'+objs[0]+'.npy').transpose(1,2,0)
            data.append(array)
            # NOTE(review): the verification cell further down builds its
            # label as [objs[3], objs[2], ...], i.e. with the first two
            # entries swapped relative to this one; confirm which order is
            # correct.
            # z-value relative to the clip: subtract index * 128.
            lb.append([float(objs[2]), float(objs[3]), float(objs[4])-float(objs[1])*128])
        data = np.array(data)
        lb = np.array(lb)
        # `epochs` is the Keras 2 name of the former `nb_epoch` argument.
        dect_model.fit(x=data, y=lb, batch_size=batch_size, epochs=1, shuffle=True)

    dect_model.save_weights("model/"+save_name)
    # Same text as the original two-argument print statement.
    print("model finish trianing and saved: " + " " + save_name)
    return dect_model

# Ten consecutive passes over the label file; load_and_train() saves the
# weights at the end of every pass.
for step in range(10):
    model = load_and_train(model)
    print("Now ----------------------------step " + " " + str(step))
             
            

LKDS-00001.mhd-1
LKDS-00003.mhd-1
LKDS-00003.mhd-2
LKDS-00004.mhd-1
LKDS-00005.mhd-1
LKDS-00007.mhd-1
LKDS-00011.mhd-1
LKDS-00013.mhd-1
LKDS-00015.mhd-1
LKDS-00016.mhd-1
LKDS-00016.mhd-2
LKDS-00016.mhd-3
Epoch 1/1
LKDS-00019.mhd-1
LKDS-00019.mhd-2
LKDS-00020.mhd-1
LKDS-00020.mhd-2
LKDS-00020.mhd-3
LKDS-00020.mhd-4
LKDS-00020.mhd-5
LKDS-00020.mhd-6
LKDS-00020.mhd-7
LKDS-00020.mhd-8
LKDS-00020.mhd-9
LKDS-00020.mhd-10
Epoch 1/1
LKDS-00020.mhd-11
LKDS-00021.mhd-1
LKDS-00023.mhd-1
LKDS-00023.mhd-2
LKDS-00025.mhd-1
LKDS-00026.mhd-2
LKDS-00028.mhd-1
LKDS-00029.mhd-1
LKDS-00030.mhd-1
LKDS-00034.mhd-1
LKDS-00035.mhd-1
LKDS-00035.mhd-2
Epoch 1/1
LKDS-00036.mhd-1
LKDS-00038.mhd-1
LKDS-00039.mhd-1
LKDS-00039.mhd-2
LKDS-00040.mhd-1
LKDS-00041.mhd-1
LKDS-00042.mhd-1
LKDS-00043.mhd-1
LKDS-00047.mhd-1
LKDS-00047.mhd-2
LKDS-00050.mhd-1
LKDS-00051.mhd-1
Epoch 1/1
LKDS-00053.mhd-1
LKDS-00054.mhd-1
LKDS-00054.mhd-2
LKDS-00054.mhd-3
LKDS-00054.mhd-4
LKDS-00054.mhd-5
LKDS-00058.mhd-1
LKDS-00061.mhd-1
LKDS-0

KeyboardInterrupt: 

# Use the model above to do prediction

In [39]:
# Sanity check for one sample: compare the centroid of the voxels equal to 1
# in the stored volume with the label parsed from the label file.
input_path = '/media/izm/Normal/FJJ/'
label_name = input_path+'nodule-slcs/'+'nodule-area.txt'
sample_idx = 59  # line of the label file to inspect

# Read the label file and close it right away.
with open(label_name) as lb_file:
    lines = lb_file.readlines()
objs = lines[sample_idx].split()
print(objs[0])

# Move the stored array's first axis to the end.
data = np.load(input_path+'nodule-slcs/zclip-'+objs[0]+'.npy').transpose(1,2,0)

# NOTE(review): load_and_train() builds its label as [objs[2], objs[3], ...],
# i.e. with the first two entries swapped relative to this cell; confirm
# which order is correct.
# z-value relative to the clip: subtract index * 128.  Values are truncated
# to int.
lb = np.array([float(objs[3]), float(objs[2]), float(objs[4])-float(objs[1])*128], int)

# Indices of all voxels equal to 1, in row-major (i, j, k) order.  This
# replaces a pure-Python triple loop over every voxel and no longer assumes a
# fixed 512 x 512 x 128 volume.
voxels = np.argwhere(data == 1)
cnt = len(voxels)
for i, j, k in voxels:
    print((int(i), int(j), int(k)))

if cnt == 0:
    raise ValueError("no voxel equal to 1 in zclip-" + objs[0] + ".npy; centroid is undefined")

# Centroid = mean index along each axis.
x, y, z = [float(total) / cnt for total in voxels.sum(axis=0)]

#model.load_weights("model/detection-model/detection")
#prd = model.predict(data)
print("%s %s" % ((x, y, z), lb))

LKDS-00064.mhd-2
(293, 219, 50)
(293, 219, 51)
(294, 219, 50)
(294, 219, 51)
(294, 219, 52)
(294, 219, 53)
(294, 219, 55)
(294, 220, 50)
(294, 220, 51)
(294, 220, 52)
(294, 220, 53)
(294, 220, 54)
(294, 220, 55)
(294, 220, 56)
(294, 220, 57)
(294, 220, 58)
(294, 221, 55)
(294, 221, 56)
(294, 221, 57)
(294, 221, 58)
(294, 222, 55)
(294, 222, 56)
(294, 222, 57)
(295, 219, 50)
(295, 219, 51)
(295, 219, 52)
(295, 219, 53)
(295, 219, 54)
(295, 219, 55)
(295, 219, 56)
(295, 219, 57)
(295, 219, 58)
(295, 220, 51)
(295, 220, 52)
(295, 220, 53)
(295, 220, 54)
(295, 220, 55)
(295, 220, 56)
(295, 220, 57)
(295, 220, 58)
(295, 220, 59)
(295, 221, 54)
(295, 221, 55)
(295, 221, 56)
(295, 221, 57)
(295, 221, 58)
(295, 221, 59)
(295, 222, 54)
(295, 222, 55)
(295, 222, 56)
(295, 222, 57)
(295, 222, 58)
(295, 222, 59)
(295, 223, 55)
(295, 223, 56)
(295, 223, 57)
(295, 223, 58)
(295, 224, 56)
(295, 224, 57)
(296, 219, 53)
(296, 219, 54)
(296, 219, 55)
(296, 219, 56)
(296, 219, 57)
(296, 219, 58)
(296, 21

In [13]:
import matplotlib.pyplot as plt
from skimage.measure import label,regionprops, perimeter
from skimage import measure, feature
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

def plot_3d(image, threshold=-1000):
    """Show the iso-surface of a 3-D array as a translucent mesh.

    Args:
        image: 3-D array to render.
        threshold: iso-value passed to ``measure.marching_cubes``.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    # Reverse the axis order and flip the last axis so the scan is upright
    # (head of the patient at the top, facing the camera).
    volume = image.transpose(2, 1, 0)[:, :, ::-1]

    # Four values are unpacked from marching_cubes; only the vertices and the
    # triangle index array are used.
    vertices, triangles, _unused_a, _unused_b = measure.marching_cubes(volume, threshold)

    figure = plt.figure(figsize=(10, 10))
    axes = figure.add_subplot(111, projection='3d')

    # Indexing vertices by the face array yields one triangle per face.
    surface = Poly3DCollection(vertices[triangles], alpha=0.1)
    surface.set_facecolor([0.5, 0.5, 1])
    axes.add_collection3d(surface)

    # Axis limits span the full extent of the reoriented volume.
    axes.set_xlim(0, volume.shape[0])
    axes.set_ylim(0, volume.shape[1])
    axes.set_zlim(0, volume.shape[2])
    plt.show()