In [3]:
from keras.layers.pooling import AveragePooling3D, GlobalMaxPooling3D
from keras.layers import Input, merge, Activation, Dropout
from keras.optimizers import Adamax, Adam, Nadam, sgd
from keras.callbacks import ModelCheckpoint,ReduceLROnPlateau,LearningRateScheduler,EarlyStopping

Using TensorFlow backend.


In [4]:
def learningRateSchedule(epoch):
    """Return the learning rate for a (0-based) epoch index.

    Step schedule used by ``LearningRateScheduler``:
      epochs 0-1  -> 1e-2
      epochs 2-4  -> 1e-3
      epochs 5-9  -> 5e-4
      epoch 10+   -> 5e-5
    """
    # (exclusive upper bound on epoch, rate) pairs, checked in ascending order.
    for upperBound, rate in ((2, 1e-2), (5, 1e-3), (10, 5e-4)):
        if epoch < upperBound:
            return rate
    return 5e-5

In [5]:
def compileModel(inputShape, dropRate):
    """Build and compile the shared-trunk, four-output 3D CNN.

    Trunk: four ``convCNNBlock`` stages (8, 24, 48, 64 filters). After each
    stage the conv output is concatenated on axis 1 with the raw input that
    has been average-pooled once more per stage. A final 65-filter
    ``convCNNBlock`` is globally max-pooled, batch-normalised and shared by
    four single-unit softplus heads built with ``denseCNNBlock``.

    Parameters
    ----------
    inputShape : tuple
        Shape given to ``Input``. Axis 1 is used as the concat axis and the
        pooling layers use ``dim_ordering="th"``; presumably channels-first
        input — confirm against the data loader.
    dropRate : float
        Dropout rate forwarded to every conv and dense block.

    Returns
    -------
    Compiled Keras ``Model`` with outputs
    [Malignancy, Diameter, Lobulation, Spiculation], each trained with MSE
    at weight 1.
    """
    inputs = Input(inputShape)

    # Trunk: conv stage + progressively pooled copy of the input, concatenated.
    # NOTE(review): presumably convCNNBlock halves spatial size so the two
    # concat operands line up — verify against its definition.
    features, pyramid = inputs, inputs
    for nFilters in (8, 24, 48, 64):
        conv = convCNNBlock(features, nFilters, dropRate=dropRate)
        pyramid = AveragePooling3D(dim_ordering="th")(pyramid)
        features = merge([conv, pyramid], mode='concat', concat_axis=1)

    top = convCNNBlock(features, 65, dropRate=dropRate)
    pooled = GlobalMaxPooling3D()(top)
    pooledNorm = BatchNormalization()(pooled)

    heads = ['Malignancy', 'Diameter', 'Lobulation', 'Spiculation']
    outputs = [denseCNNBlock(pooledNorm, name=head, outSize=1,
                             activation='softplus', dropRate=dropRate)
               for head in heads]

    model = Model(input=inputs, output=outputs)

    # NOTE(review): Keras SGD defaults momentum to 0, in which case
    # nesterov=True has no effect — confirm whether momentum was intended.
    opt = sgd(0.01, nesterov=True)

    print ('Compiling model...')

    model.compile(optimizer=opt,
                  loss={head: 'mse' for head in heads},
                  loss_weights={head: 1 for head in heads})

    return model

In [5]:
def compileModelDeepBranching(inputShape, dropRate):
    """Build and compile the 3D CNN with a separate deep branch per output.

    Shares the same four-stage trunk as ``compileModel``: ``convCNNBlock``
    stages with 8, 24, 48 and 64 filters, each concatenated on axis 1 with a
    progressively average-pooled copy of the input. After the trunk, each
    target gets its own branch:
    ``convCNNBlock(65)`` -> ``GlobalMaxPooling3D`` -> ``BatchNormalization``
    -> single-unit softplus ``denseCNNBlock``.

    Parameters
    ----------
    inputShape : tuple
        Shape given to ``Input``. Axis 1 is the concat axis and pooling uses
        ``dim_ordering="th"``; presumably channels-first — confirm.
    dropRate : float
        Dropout rate forwarded to every conv and dense block. This value is
        honored; the function previously overwrote it with 1e-2, so callers
        relying on that must now pass 1e-2 explicitly.

    Returns
    -------
    Compiled Keras ``Model`` with outputs
    [Malignancy, Diameter, Lobulation, Spiculation], each trained with MSE
    at weight 1.
    """
    inputs = Input(inputShape)

    # Shared trunk: conv stage + progressively pooled copy of the input.
    features, pyramid = inputs, inputs
    for nFilters in (8, 24, 48, 64):
        conv = convCNNBlock(features, nFilters, dropRate=dropRate)
        pyramid = AveragePooling3D(dim_ordering="th")(pyramid)
        features = merge([conv, pyramid], mode='concat', concat_axis=1)

    # One independent conv -> pool -> norm -> dense branch per target,
    # built in the same order as the output list.
    heads = ['Malignancy', 'Diameter', 'Lobulation', 'Spiculation']
    outputs = []
    for head in heads:
        branch = convCNNBlock(features, 65, dropRate=dropRate)
        branch = GlobalMaxPooling3D()(branch)
        branch = BatchNormalization()(branch)
        outputs.append(denseCNNBlock(branch, name=head, outSize=1,
                                     activation='softplus',
                                     dropRate=dropRate))

    model = Model(input=inputs, output=outputs)

    # NOTE(review): Keras SGD defaults momentum to 0, in which case
    # nesterov=True has no effect — confirm whether momentum was intended.
    opt = sgd(0.01, nesterov=True)

    print ('Compiling model...')

    model.compile(optimizer=opt,
                  loss={head: 'mse' for head in heads},
                  loss_weights={head: 1 for head in heads})

    return model

In [None]:
def randomFlips(Xbatch):
    """Randomly flip each sample of ``Xbatch`` along its first three axes.

    For every sample ``i`` a direction in {-1, +1} is drawn per axis from
    ``np.random``, and ``Xbatch[i]`` is overwritten with
    ``Xbatch[i, ::d0, ::d1, ::d2]``. The array is modified in place and the
    same object is returned.

    NOTE(review): the flips apply to axes 1-3 of the batch. If the data is
    channels-first (N, C, D, H, W), this flips the channel axis and leaves
    the last spatial axis untouched — confirm the intended layout.
    """
    directions = np.random.choice([-1, 1], size=(Xbatch.shape[0], 3))
    for idx, (d0, d1, d2) in enumerate(directions):
        Xbatch[idx] = Xbatch[idx, ::d0, ::d1, ::d2]
    return Xbatch

In [6]:
def _loadSplit(category, testSize, transform=None):
    """Load one patch category and split it into (xTrain, xValid, ixTrain, ixValid)."""
    x, ix = loadCategory(category)
    if transform is not None:
        x = transform(x)
    return train_test_split(x, ix, test_size=testSize)


def trainModel(model, modelPath, testSize=0.2, batchSize=10, nbEpoch = 1, stepsPerEpoch = 2, fp=False, validationSteps=20):
    """Load patch data, build batch generators and train ``model``.

    Loads the 'true' (positive) and 'random' (negative) categories, plus
    'false' (false-positive) ones when ``fp`` is truthy, via ``loadCategory``.
    Each category is split into train/validation with ``train_test_split``.
    Training uses ``batchGenerator`` with ``posFraction=.5`` and the callbacks
    ModelCheckpoint(modelPath), LearningRateScheduler(learningRateSchedule)
    and EarlyStopping on ``val_loss``.

    Parameters
    ----------
    model : compiled Keras model to train.
    modelPath : str
        Checkpoint path passed to ``ModelCheckpoint``.
    testSize : float
        Validation fraction for every category split.
    batchSize : int
        Batch size passed to ``batchGenerator``.
    nbEpoch : int
        Maximum number of epochs (EarlyStopping may stop earlier).
    stepsPerEpoch : int
        Training batches per epoch.
    fp : bool
        Also load and mix in false-positive patches.
    validationSteps : int
        Validation batches per epoch (previously hard-coded to 20).

    Returns
    -------
    (model, lossHist)
        ``lossHist`` maps each loss name to a list with one value per
        completed epoch.
    """
    print ('Loading positive patches')
    # NOTE(review): flips are drawn once here, before the split, so they act
    # as a fixed re-orientation of the positive set (validation included),
    # not per-epoch augmentation — confirm this is intended.
    xPosTrain, xPosValid, ixPosTrain, ixPosValid = _loadSplit('true', testSize, transform=randomFlips)

    print ('Loading negative patches')
    xNegTrain, xNegValid, ixNegTrain, ixNegValid = _loadSplit('random', testSize)

    trainExtra, validExtra = {}, {}
    if fp:
        print ('Loading false positive patches')
        xFPTrain, xFPValid, ixFPTrain, ixFPValid = _loadSplit('false', testSize)
        trainExtra = {'xFP': xFPTrain, 'ixFP': ixFPTrain}
        validExtra = {'xFP': xFPValid, 'ixFP': ixFPValid}

    trainGenerator = batchGenerator(xPosTrain, xNegTrain,
                                    ixPosTrain, ixNegTrain,
                                    batchSize=batchSize,
                                    posFraction=.5,
                                    **trainExtra)

    validGenerator = batchGenerator(xPosValid, xNegValid,
                                    ixPosValid, ixNegValid,
                                    batchSize=batchSize,
                                    posFraction=.5,
                                    **validExtra)

    ckp = ModelCheckpoint(filepath=modelPath)
    lr = LearningRateScheduler(learningRateSchedule)
    es = EarlyStopping(monitor='val_loss', min_delta=1e-3, patience=4, verbose=0, mode='auto')

    # Pre-seed the expected keys so callers always find them, even if empty.
    lossHist = {lossType: [] for lossType in (
        'Diameter_loss', 'val_loss',
        'val_Lobulation_loss', 'Spiculation_loss',
        'loss', 'val_Diameter_loss', 'val_Spiculation_loss',
        'Lobulation_loss', 'val_Malignancy_loss', 'Malignancy_loss')}

    if nbEpoch > 0:
        # A single fit call over all epochs: Keras resets EarlyStopping's
        # patience counter at the start of every fit, so the previous
        # one-epoch-per-call loop could never trigger early stopping.
        hist = model.fit_generator(trainGenerator, validation_data=validGenerator,
                                   validation_steps=validationSteps,
                                   steps_per_epoch=stepsPerEpoch,
                                   epochs=nbEpoch, callbacks=[ckp, lr, es])

        # setdefault: tolerate extra history keys (e.g. 'lr' in some Keras
        # versions) instead of raising KeyError.
        for lossType, values in hist.history.items():
            lossHist.setdefault(lossType, []).extend(values)

    return model, lossHist