In [1]:
from keras.layers.pooling import AveragePooling3D, GlobalMaxPooling3D
from keras.layers import Input, merge, Activation, Dropout
from keras.optimizers import Adamax, Adam, Nadam, sgd
from keras.callbacks import ModelCheckpoint,ReduceLROnPlateau,LearningRateScheduler,EarlyStopping

Using TensorFlow backend.


In [4]:
def learningRateSchedule(epoch):
    """Return the learning rate to use for the given epoch index.

    Piecewise-constant, decreasing schedule:

    ========  =============
    epoch     learning rate
    ========  =============
    < 2       1e-2
    < 5       1e-3
    < 10      5e-4
    >= 10     5e-5
    ========  =============

    Args:
        epoch: Epoch index (anything comparable with ``<`` against ints).

    Returns:
        The learning rate as a float.
    """
    # (exclusive upper epoch bound, rate), checked in ascending order.
    steps = ((2, 1e-2), (5, 1e-3), (10, 5e-4))
    for upperBound, rate in steps:
        if epoch < upperBound:
            return rate
    # Past the last bound: hold the smallest rate.
    return 5e-5

In [5]:
def _regressionHead(features, name, regRate, dropRate):
    """Build one single-output regression branch on top of the shared trunk.

    The branch is ``convCNNBlock`` (second argument 65, ``dropRate=0``) ->
    ``GlobalMaxPooling3D`` -> ``BatchNormalization`` -> ``denseCNNBlock`` with
    ``outSize=1`` and a softplus activation.

    Args:
        features: Tensor produced by the shared trunk.
        name: Forwarded to ``denseCNNBlock`` as ``name``. The caller uses the
            same string as the key of the compile-time ``loss`` and
            ``loss_weights`` dicts, so this is presumably the output layer
            name -- confirm against ``denseCNNBlock``.
        regRate: Forwarded to both ``convCNNBlock`` and ``denseCNNBlock``.
        dropRate: Forwarded to ``denseCNNBlock`` only; the conv block is
            always called with ``dropRate=0``.

    Returns:
        The value returned by ``denseCNNBlock`` (used as a model output).
    """
    conv = convCNNBlock(features, 65, dropRate=0, regRate=regRate)
    pooled = GlobalMaxPooling3D()(conv)
    normed = BatchNormalization()(pooled)
    return denseCNNBlock(normed, name=name,
                         outSize=1, activation='softplus',
                         dropRate=dropRate, regRate=regRate)


def compileModelDeepBranching(inputShape, regRate, dropRate):
    """Build and compile the four-output regression network.

    Structure:

    * A shared trunk of four stages. Each stage runs ``convCNNBlock`` on the
      previous stage's merged tensor, average-pools the (already pooled) raw
      input once more, and concatenates the two on axis 1.
      NOTE(review): for the concat to be shape-compatible ``convCNNBlock``
      presumably downsamples by the same factor as the default
      ``AveragePooling3D`` -- confirm against its definition.
    * Four independent heads (Malignancy, Diameter, Lobulation, Spiculation),
      each built by ``_regressionHead`` from the final trunk tensor.

    The model is compiled with a default-constructed ``Nadam`` optimizer and
    an MSE loss per output; the Malignancy loss is weighted 3, the others 1.

    Args:
        inputShape: Shape passed to ``Input``. Pooling uses
            ``dim_ordering="th"`` and concatenation uses axis 1.
        regRate: Forwarded to every ``convCNNBlock`` / ``denseCNNBlock`` call.
        dropRate: Forwarded to the ``denseCNNBlock`` of each head; all conv
            blocks are built with ``dropRate=0``.

    Returns:
        The compiled Keras ``Model`` with outputs ordered
        [Malignancy, Diameter, Lobulation, Spiculation].
    """
    # (output name, loss weight), in the order the outputs appear on the model.
    # Single source of truth for the heads, the loss dict and the weights.
    heads = (('Malignancy', 3), ('Diameter', 1),
             ('Lobulation', 1), ('Spiculation', 1))

    x = Input(inputShape)

    # Shared trunk: `merged` feeds the next conv block, `pooledInput` is the
    # raw input pooled once per stage so it can be concatenated alongside.
    merged = x
    pooledInput = x
    for blockSize in (8, 24, 48, 64):
        conv = convCNNBlock(merged, blockSize, dropRate=0, regRate=regRate)
        pooledInput = AveragePooling3D(dim_ordering="th")(pooledInput)
        merged = merge([conv, pooledInput], mode='concat', concat_axis=1)

    # One branch per target, all reading the same trunk tensor.
    outputs = [_regressionHead(merged, headName, regRate, dropRate)
               for headName, _ in heads]

    model = Model(input=x, output=outputs)

    opt = Nadam()

    print ('Compiling model...')

    model.compile(optimizer=opt,
                  loss={headName: 'mse' for headName, _ in heads},
                  loss_weights=dict(heads))

    return model

In [None]:
def randomFlips(Xbatch):
    """Randomly mirror each sample of a batch, in place.

    For every sample (index along axis 0) three independent steps are drawn
    from {-1, 1} with ``np.random.choice``; a step of -1 reverses the
    corresponding one of the sample's first three axes, a step of 1 leaves it
    as is. Any axes beyond the first three of a sample are not flipped.

    Args:
        Xbatch: Array whose first axis indexes samples; each sample needs at
            least three axes. It is modified in place.

    Returns:
        The same ``Xbatch`` object that was passed in.
    """
    nSamples = Xbatch.shape[0]
    # One row of three +/-1 steps per sample.
    steps = np.random.choice([-1,1],size=(nSamples,3))

    for sampleIdx in range(nSamples):
        # Slice objects with step -1 reverse an axis, step 1 keep it.
        flip = tuple(slice(None, None, step) for step in steps[sampleIdx])
        Xbatch[sampleIdx] = Xbatch[(sampleIdx,) + flip]

    return Xbatch

In [6]:
def trainRegressionModel(model, modelPath, modelPathClass,
                         validInd, testSize=0.2, batchSize=10,
                         nbEpoch=1, stepsPerEpoch=2, fp=False,
                         posFraction=0.5, validationSteps=20):
    """Train ``model`` one epoch at a time and collect the loss history.

    Training and validation batches come from ``batchGeneratorRegression``,
    which is fed from module-level arrays (hidden notebook state that must be
    defined before calling): ``xPosTrain``, ``xNegTrain``, ``ixPosTrain``,
    ``ixNegTrain``, ``xPosValid``, ``xNegValid``, ``ixPosValid``,
    ``ixNegValid`` and, only when ``fp`` is not False, ``xFP`` and ``ixFP``.

    Args:
        model: Compiled Keras model; trained in place.
        modelPath: ``filepath`` given to the ``ModelCheckpoint`` callback.
        modelPathClass: Unused; kept for call compatibility.
        validInd: Unused; kept for call compatibility.
        testSize: Unused; kept for call compatibility.
        batchSize: Forwarded to both batch generators.
        nbEpoch: Number of single-epoch ``fit_generator`` calls to make.
        stepsPerEpoch: ``steps_per_epoch`` for each call.
        fp: When equal to False the training generator is built with
            ``fp=False``; otherwise it is built with ``xFP``/``ixFP`` and
            ``fp=True``. The validation generator never receives ``fp``.
        posFraction: Forwarded to both batch generators.
        validationSteps: ``validation_steps`` for each call (default 20, the
            previously hard-coded value).

    Returns:
        ``(model, lossHist)`` where ``lossHist`` maps each Keras history key
        to the list of per-epoch values accumulated over all calls. The ten
        loss keys of the four-output regression model are always present,
        even if empty; any other key Keras reports is added on first sight.
    """
    # Arguments shared by both variants of the training generator.
    trainKwargs = dict(batchSize=batchSize, posFraction=posFraction)
    # `== False` is kept on purpose: only values equal to False (False, 0)
    # select the plain generator, exactly as before; e.g. None does not.
    if fp == False:
        trainKwargs['fp'] = False
    else:
        trainKwargs.update(xFP=xFP, ixFP=ixFP, fp=True)

    trainGenerator = batchGeneratorRegression(xPosTrain, xNegTrain,
                                              ixPosTrain, ixNegTrain,
                                              **trainKwargs)

    validGenerator = batchGeneratorRegression(xPosValid, xNegValid,
                                              ixPosValid, ixNegValid,
                                              batchSize=batchSize,
                                              posFraction=posFraction)

    ckp = ModelCheckpoint(filepath=modelPath)

    # Pre-seed the expected keys so callers can index them even when
    # nbEpoch == 0.
    lossHist = {}
    for lossType in ['Diameter_loss', 'val_loss',
                     'val_Lobulation_loss', 'Spiculation_loss',
                     'loss', 'val_Diameter_loss', 'val_Spiculation_loss',
                     'Lobulation_loss', 'val_Malignancy_loss', 'Malignancy_loss']:
        lossHist[lossType] = []

    # NOTE(review): a LearningRateScheduler(learningRateSchedule) and an
    # EarlyStopping(monitor='val_loss', min_delta=1e-3, patience=4) used to be
    # constructed here but were never passed to fit_generator, so neither had
    # any effect. They are left out rather than wired in, because enabling
    # them would change training. Confirm whether they were meant to be
    # active; note that a patience-based EarlyStopping likely cannot trigger
    # while every fit_generator call covers a single epoch -- verify.

    for epoch in range(nbEpoch):

        # initial_epoch=epoch with nb_epoch=epoch+1 asks Keras for exactly one
        # epoch per call while keeping the global epoch index correct.
        hist = model.fit_generator(trainGenerator, validation_data=validGenerator,
                                   validation_steps=validationSteps,
                                   steps_per_epoch=stepsPerEpoch,
                                   nb_epoch=epoch+1, callbacks=[ckp],
                                   initial_epoch=epoch)

        # setdefault: a history key outside the pre-seeded list (different
        # output names, extra logged metrics) must not raise KeyError after
        # an epoch of training has already been spent.
        for lossType, values in hist.history.items():
            lossHist.setdefault(lossType, []).extend(values)

    return model, lossHist