In [1]:
from keras.layers.pooling import AveragePooling3D, GlobalMaxPooling3D
from keras.layers import Input, merge, Activation, Dropout
from keras.optimizers import Adamax, Adam, Nadam, sgd
from keras.callbacks import ModelCheckpoint,ReduceLROnPlateau,LearningRateScheduler,EarlyStopping

Using TensorFlow backend.


In [2]:
def learningRateSchedule(epoch):
    """Step-decay learning rate for an epoch index (as passed by Keras' LearningRateScheduler).

    Epochs below 2 -> 1e-2, below 5 -> 1e-3, below 10 -> 5e-4,
    10 and later -> 5e-5.
    """
    # (exclusive upper epoch bound, rate used for epochs below that bound),
    # checked in increasing order so the first match wins.
    for upperBound, rate in ((2, 1e-2), (5, 1e-3), (10, 5e-4)):
        if epoch < upperBound:
            return rate
    return 5e-5

In [5]:
def compileModel(inputShape, dropRate, regRate, filterCounts=(8, 24, 48, 64, 65)):
    """Build and compile the 3-D CNN patch classifier.

    The network is a chain of ``convCNNBlock`` stages. After every stage except
    the last, the stage output is concatenated along axis 1 (the channel axis in
    the "th"/channels-first layout used by the pooling layers) with an
    average-pooled copy of the raw input that is pooled once more per stage.
    The last stage is global-max-pooled, batch-normalised and fed to a 2-way
    softmax ``denseCNNBlock`` named 'Nodule'.

    Args:
        inputShape: per-sample shape passed to ``Input``; presumably
            channels-first ``(channels, d1, d2, d3)`` given the "th" pooling
            layers -- TODO confirm.
        dropRate: dropout rate for the dense head only (conv stages use 0).
        regRate: regularisation rate forwarded to every conv and dense block.
        filterCounts: filter count of each conv stage, in order. The default
            reproduces the original five-stage architecture.

    Returns:
        A Keras ``Model`` compiled with Nadam, categorical cross-entropy and
        categorical accuracy.

    Raises:
        ValueError: if ``filterCounts`` is empty.
    """
    if not filterCounts:
        raise ValueError('filterCounts must contain at least one stage')

    x = Input(inputShape)

    # `features` feeds each conv stage; `pooled` is the raw input average-pooled
    # once more per stage so the two can be concatenated (assumes convCNNBlock
    # downsamples like one AveragePooling3D -- TODO confirm).
    # Layers are created in the same order as the original unrolled code.
    features = x
    pooled = x
    for nFilters in filterCounts[:-1]:
        stageOut = convCNNBlock(features, nFilters, dropRate=0, regRate=regRate)
        pooled = AveragePooling3D(dim_ordering="th")(pooled)
        features = merge([stageOut, pooled], mode='concat', concat_axis=1)

    # NOTE(review): the default last stage uses 65 filters where 64 may have
    # been intended; kept unchanged so previously saved models stay compatible.
    xLast = convCNNBlock(features, filterCounts[-1], dropRate=0, regRate=regRate)

    # Pin channels-first explicitly to match the "th" layers above. Otherwise
    # GlobalMaxPooling3D follows the global image_data_format, and under
    # channels_last (the Keras default) it would pool away the channel axis
    # instead of the spatial axes.
    xMaxPool = GlobalMaxPooling3D(data_format='channels_first')(xLast)
    xMaxPoolNorm = BatchNormalization()(xMaxPool)

    xOut = denseCNNBlock(xMaxPoolNorm, name='Nodule', outSize=2, activation='softmax',
                         dropRate=dropRate, regRate=regRate, neuronNumber=5)

    model = Model(input=x, output=xOut)

    # Nadam with Keras' default hyper-parameters.
    opt = Nadam()

    print ('Compiling model...')

    model.compile(optimizer=opt,
                  loss='categorical_crossentropy',
                  metrics=['categorical_accuracy'])

    return model

In [None]:
def randomFlips(Xbatch):
    """Randomly mirror every sample of ``Xbatch`` along batch axes 1, 2 and 3.

    For each sample, a sign in {-1, +1} is drawn per axis with
    ``np.random.choice``; -1 reverses that axis, +1 leaves it alone. The array
    is modified in place and the same object is returned.

    NOTE(review): only axes 1-3 of the batch array are touched. If batches are
    5-D channels-first ``(N, C, D, H, W)``, the channel axis gets "flipped" and
    the last spatial axis never does -- confirm the expected layout.
    """
    nSamples = Xbatch.shape[0]
    signs = np.random.choice([-1,1], size=(nSamples, 3))
    for idx, (sign0, sign1, sign2) in enumerate(signs):
        Xbatch[idx] = Xbatch[idx, ::sign0, ::sign1, ::sign2]
    return Xbatch

In [None]:
def _loadAndSplitCategory(category, label, testSize, randomState=None):
    """Load one patch category and split it into training and validation parts.

    Args:
        category: name passed to ``loadCategoryClass`` (e.g. 'true', 'random').
        label: class value assigned to every patch of this category.
        testSize: forwarded to ``train_test_split`` as ``test_size``.
        randomState: forwarded to ``train_test_split`` as ``random_state``
            (None keeps the split non-deterministic).

    Returns:
        ``(xTrain, xValid, yTrain, yValid, indTrain, indValid)`` where ``y*``
        are float vectors filled with ``label`` and ``ind*`` are row indices
        into the array returned by ``loadCategoryClass``.
    """
    patches = loadCategoryClass(category)
    xTrain, xValid, indTrain, indValid = train_test_split(
        patches, np.arange(patches.shape[0]),
        test_size=testSize, random_state=randomState)
    # This helper's reference to the full array is dropped on return, before
    # the caller loads the next category.
    yTrain = np.full(xTrain.shape[0], float(label))
    yValid = np.full(xValid.shape[0], float(label))
    return xTrain, xValid, yTrain, yValid, indTrain, indValid


def trainModelClass(model, modelPath, testSize=0.2, batchSize=10, nbEpoch=1, stepsPerEpoch=2, fp=False,
                    validationSteps=10, extraCallbacks=None, randomState=None):
    """Train ``model`` on balanced batches of 'true' (label 1) and 'random' (label 0) patches.

    Each category is loaded and split into train/validation parts. Batches are
    drawn by ``batchGeneratorClass`` with ``posFraction=.5``, and the model is
    fitted one epoch per ``fit_generator`` call, with a ``ModelCheckpoint``
    writing to ``modelPath``.

    Args:
        model: compiled Keras model (e.g. from ``compileModel``).
        modelPath: file path given to ``ModelCheckpoint``.
        testSize: validation fraction/size forwarded to ``train_test_split``.
        batchSize: batch size for both generators.
        nbEpoch: number of epochs to train.
        stepsPerEpoch: training batches per epoch.
        fp: not used by this function; kept for interface compatibility.
        validationSteps: validation batches evaluated per epoch.
        extraCallbacks: optional iterable of Keras callbacks appended after the
            checkpoint, e.g. ``[LearningRateScheduler(learningRateSchedule)]``
            to apply the step schedule. NOTE(review): because each epoch is a
            separate ``fit_generator`` call, callbacks that keep state across
            epochs (such as EarlyStopping) are likely re-initialised every epoch.
        randomState: optional ``random_state`` for reproducible splits.

    Returns:
        ``(model, lossHist, indPosValid, indNegValid)``. ``lossHist`` maps each
        history key (at least 'loss', 'val_loss', 'categorical_accuracy',
        'val_categorical_accuracy') to its per-epoch values across all epochs.
        ``indPosValid``/``indNegValid`` are the row indices of the validation
        patches within each loaded category.
    """
    print('Loading positive patches')
    xPosTrain, xPosValid, yPosTrain, yPosValid, indPosTrain, indPosValid = \
        _loadAndSplitCategory('true', 1, testSize, randomState)

    print('Loading negative patches')
    xNegTrain, xNegValid, yNegTrain, yNegValid, indNegTrain, indNegValid = \
        _loadAndSplitCategory('random', 0, testSize, randomState)

    trainGenerator = batchGeneratorClass(xPosTrain, xNegTrain,
                                         yPosTrain, yNegTrain,
                                         batchSize=batchSize,
                                         posFraction=.5)

    validGenerator = batchGeneratorClass(xPosValid, xNegValid,
                                         yPosValid, yNegValid,
                                         batchSize=batchSize,
                                         posFraction=.5)

    # The original built a LearningRateScheduler and an EarlyStopping here but
    # never passed them to fit_generator; only the checkpoint was active.
    # Extra callbacks are now opt-in via `extraCallbacks`.
    callbacks = [ModelCheckpoint(filepath=modelPath)]
    if extraCallbacks:
        callbacks.extend(extraCallbacks)

    lossHist = {'loss': [], 'val_loss': [], 'val_categorical_accuracy': [], 'categorical_accuracy': []}

    for epoch in range(nbEpoch):
        # One epoch per call: `epochs` is the index to stop at, and
        # `initial_epoch` keeps Keras' epoch counter continuous across calls.
        hist = model.fit_generator(trainGenerator, validation_data=validGenerator,
                                   validation_steps=validationSteps,
                                   steps_per_epoch=stepsPerEpoch,
                                   epochs=epoch + 1, callbacks=callbacks,
                                   initial_epoch=epoch)
        # setdefault: tolerate history keys beyond the four pre-declared ones.
        for key, values in hist.history.items():
            lossHist.setdefault(key, []).extend(values)

    return model, lossHist, indPosValid, indNegValid