In [None]:
import numpy as np

import os

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split


import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import Sequence

from tensorflow.keras import layers
from tensorflow.keras.models import Model

from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as keras
from tensorflow.keras import callbacks

from tensorflow.keras import metrics

from scipy.stats import pearsonr


from custom_losses import binary_crossentropy_weight_balance, binary_crossentropy_weight_dict, binary_crossentropy_closeness_to_foreground


In [None]:
def gpu_memory_limit(memory_limit):
    '''Cap the amount of memory TensorFlow may allocate on the first GPU.

    A single virtual (logical) device of size `memory_limit` is configured on
    the first physical GPU that TensorFlow reports. When no GPU is visible only
    the GPU count is printed.

    Parameters
    ----------
    memory_limit : int
        Upper bound for the virtual device, in megabytes (the unit used by
        TensorFlow's virtual device configuration).

    A RuntimeError raised by TensorFlow is printed rather than propagated, so
    the notebook keeps running with the default allocation behaviour.
    '''
    # query the physical devices once and reuse the answer below
    gpus = tf.config.experimental.list_physical_devices('GPU')
    print("Num GPUs Available: ", len(gpus))
    if gpus:
        # Restrict TensorFlow to `memory_limit` MB on the first GPU only
        try:
            tf.config.experimental.set_virtual_device_configuration(
                gpus[0],
                [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=memory_limit)])
            logical_gpus = tf.config.experimental.list_logical_devices('GPU')
            print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
            print('GPU memory limit allocated.')
        except RuntimeError as e:
            # Virtual devices must be set before GPUs have been initialized
            print(e)

gpu_memory_limit(10000) # value is in MB, i.e. roughly 10GB on the first GPU

In [None]:
# Folder holding the arrays written by extract_dcm_for_wsx.ipynb
DataDir = './data/pericardial/wsx_20200221/'

# Images and masks are stored without a channel axis, so a trailing axis of
# length 1 is appended straight after loading (channels-last layout).
X = np.expand_dims(np.load(os.path.join(DataDir,'X.npy')), axis=-1)
Y = np.expand_dims(np.load(os.path.join(DataDir,'Y.npy')).astype('float'), axis=-1)
pxSize = np.load(os.path.join(DataDir,'pxSize.npy'))

# Hold out 20% of the examples. NOTE: from here on X, Y and pxSize refer to the
# training portion only; the held-out portion carries the _test suffix.
(X, X_test,
 Y, Y_test,
 pxSize, pxSize_test) = train_test_split(X, Y, pxSize, test_size=0.2, random_state=101)

# number of training / test examples
M, MTest = X.shape[0], X_test.shape[0]

In [None]:
class augmentImageSequence(Sequence):

    '''Keras Sequence producing augmented batches of matched image/mask pairs.

    Two ImageDataGenerator flows are created, one over the images and one over
    the masks, with the same generator arguments, batch size and seed; the
    intent is that both flows emit the same examples with the same random
    transform so that every augmented image keeps its matching mask.

    Parameters
    ----------
    Images, Masks : arrays indexed by example along axis 0
    dataGenArgs : dict of keyword arguments forwarded to ImageDataGenerator
    batchSize : number of examples per batch
    seed : seed handed to both flows

    NOTE: batches are drawn from the two stateful flow iterators, so the index
    passed to __getitem__ is ignored and each call simply returns the next
    batch. Consume this sequence from a single worker.
    '''

    def __init__(self,Images,Masks,dataGenArgs,batchSize=1,seed=42):

        #keep references to the raw data
        self.x,self.y = Images,Masks
        self.batch_size = batchSize

        #identically configured and seeded flows for images and masks
        self.augmentIm = ImageDataGenerator(**dataGenArgs).flow(x=Images,batch_size=batchSize,seed=seed)
        self.augmentMa = ImageDataGenerator(**dataGenArgs).flow(x=Masks, batch_size=batchSize,seed=seed)


    def __len__(self):
        '''Number of batches per pass over the data (last batch may be partial).'''
        nExamples = self.x.shape[0]
        return (nExamples + self.batch_size - 1)//self.batch_size

    def __getitem__(self,idx):
        '''Return the next (images, masks) batch; `idx` is not used.'''
        #advance both flows by exactly one batch, images first
        ims = next(self.augmentIm)
        masks = next(self.augmentMa)

        #re-binarise the masks after augmentation
        masks = (masks>0.5).astype('float')

        return ims,masks

U-Net architecture.

In [None]:
def _conv_block(x, filters, dropoutRate):
    '''Apply two (3x3 ReLU convolution -> dropout) stages with `filters` channels.'''
    for _ in range(2):
        x = layers.Conv2D(filters, 3, activation = 'relu', padding = 'same', kernel_initializer = 'glorot_normal')(x)
        x = layers.Dropout(rate=dropoutRate)(x)
    return x


def _up_block(x, skip, filters, dropoutRate):
    '''Upsample x by 2, apply a 2x2 convolution, concatenate with `skip`, then a conv block.'''
    upsampled = layers.UpSampling2D(size = (2,2))(x)
    upsampled = layers.Conv2D(filters, 2, activation = 'relu', padding = 'same', kernel_initializer = 'glorot_normal')(upsampled)
    merged = layers.concatenate([skip,upsampled], axis = 3)
    return _conv_block(merged, filters, dropoutRate)


def unet(pretrained_weights = None,input_size = (256,256,1),dropoutRate = 0):
    '''Build a U-Net with a single-channel sigmoid output.

    The encoder has four levels (64, 128, 256, 512 filters), each followed by
    2x2 max pooling, then a 1024-filter bottleneck. The decoder mirrors the
    encoder, concatenating the matching encoder output at every level, and is
    followed by a 2-filter 3x3 convolution and a 1x1 sigmoid convolution.

    Parameters
    ----------
    pretrained_weights : optional path passed to model.load_weights
    input_size : shape of one input example, channels last. Four poolings are
        undone by four x2 upsamplings, so height and width need to be multiples
        of 16 for the skip concatenations to line up.
    dropoutRate : rate used by every Dropout layer (0 disables dropout)

    Returns
    -------
    The (uncompiled) Keras Model.
    '''
    inputs = layers.Input(input_size)

    #contracting path: remember each level's output for the skip connections
    skips = []
    x = inputs
    for filters in (64,128,256,512):
        x = _conv_block(x, filters, dropoutRate)
        skips.append(x)
        x = layers.MaxPooling2D(pool_size=(2, 2))(x)

    #bottleneck
    x = _conv_block(x, 1024, dropoutRate)

    #expanding path: deepest skip connection is used first
    for filters,skip in zip((512,256,128,64), reversed(skips)):
        x = _up_block(x, skip, filters, dropoutRate)

    #output head
    x = layers.Conv2D(2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'glorot_normal')(x)
    x = layers.Dropout(rate=dropoutRate)(x)
    outputs = layers.Conv2D(1, 1, activation = 'sigmoid')(x)

    model = Model(inputs = inputs, outputs = outputs)
    #model.summary()

    if pretrained_weights:
        model.load_weights(pretrained_weights)

    return model

In [None]:
#augmentation settings, shared by the image flow and the mask flow
dataGenArgs = {
    'rotation_range': 5,
    'width_shift_range': 0.05,
    'height_shift_range': 0.05,
    'shear_range': 0, #0.05,
    'zoom_range': 0.05,
    'horizontal_flip': False, #flipping must stay disabled for these images
    'vertical_flip': False,
    'fill_mode': 'nearest',
    'data_format': 'channels_last',
    'featurewise_center': False,
    'featurewise_std_normalization': False,
    'zca_whitening': False,
}

#stop once the training loss has not improved for 10 epochs and roll back to the best weights
earlyStop = callbacks.EarlyStopping(
    monitor='loss',
    mode='min',
    min_delta=0,
    patience=10,
    restore_best_weights=True,
)

#multiply the learning rate by 0.3 when the training loss plateaus
reduceLR = callbacks.ReduceLROnPlateau(
    monitor='loss',
    factor=0.3,
    patience=5,
    cooldown=5,
    verbose=1,
)

CALLBACKS = [earlyStop, reduceLR]

OPT = Adam(
    learning_rate = 3e-4,
    beta_1 = 0.9,
    beta_2 = 0.999,
    amsgrad = False,
)

#foreground weight: number of label pixels divided by the sum of the labels,
#computed over everything currently held in Y
MULTIPLIER = Y.size/Y.sum()

#other hyperparameters
BATCHSIZE = 16 #THIS MATTERS A LOT
DROPOUTRATE = 0
WEIGHT_DICT = {0.0: 1.0, 1.0: MULTIPLIER}

#Spatial smoothing for weights
SIGMA = 20



Instantiate and train the model.

In [None]:
keras.clear_session()

tf.random.set_seed(101) #FIXME!!! this is not sufficient to guarantee deterministic behaviour during fitting.


class _ThresholdedMeanIoU(metrics.MeanIoU):
    '''MeanIoU that binarises the sigmoid output at 0.5 before scoring.

    MeanIoU treats y_pred as integer class ids; handed raw probabilities it
    truncates them, so almost every pixel would be counted as background.
    '''

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_pred = tf.cast(y_pred >= 0.5, y_pred.dtype)
        return super().update_state(y_true, y_pred, sample_weight)


model = unet(input_size=X.shape[1:],dropoutRate=DROPOUTRATE)

model.compile(optimizer = OPT,
#               loss = 'binary_crossentropy',
#               loss = binary_crossentropy_weight_balance,
              loss = binary_crossentropy_closeness_to_foreground(sigma=SIGMA),
              #explicit name keeps the 'mean_io_u' / 'val_mean_io_u' history keys
              metrics = ['accuracy',_ThresholdedMeanIoU(num_classes=2,name='mean_io_u')],
             )

#augmentImageSequence ignores the batch index and draws from stateful, seeded
#iterators, so it is consumed from the main thread only: worker processes would
#each replay their own copy of the iterators, and worker threads would share
#the iterators and numpy's global random state.
fitHistory = model.fit(augmentImageSequence(X,Y,dataGenArgs,batchSize=BATCHSIZE),
                       epochs = 100,#think about me...
                       steps_per_epoch= M//BATCHSIZE, #number of full batches in the training set
                       validation_data=(X_test,Y_test),
                       callbacks=CALLBACKS,
                       verbose=1,
                      )

Let's have a look at how fitting has proceeded.

In [None]:
#training curves: loss on top, mean IoU below, one point per epoch
fig, (axLoss, axIoU) = plt.subplots(2, 1, figsize = (15,10))

axLoss.plot(fitHistory.history['loss'],label = 'train')
axLoss.plot(fitHistory.history['val_loss'],label = 'dev')
axLoss.set_ylabel('loss')
axLoss.legend()
axLoss.set_xticks([]) #epoch axis is labelled on the lower panel only

axIoU.plot(fitHistory.history['mean_io_u'],label = 'train')
axIoU.plot(fitHistory.history['val_mean_io_u'],label = 'dev')
axIoU.set_ylabel('mean iou')
axIoU.legend() #the curves are labelled, so show which is which here as well

axIoU.set_xlabel('epoch #');

Let's have a look at the distribution of IoU (rather than just the mean).

In [None]:
def iou(yTrue,yPred):
    '''Intersection-over-union score of two masks.

    Both inputs are binarised with a threshold of 0.5 (values >= 0.5 count as
    foreground), so probabilities as well as 0/1 labels are accepted.

    Returns intersection/union of the two foreground regions. When neither
    mask contains any foreground the union is empty; the masks agree
    perfectly, so 1.0 is returned instead of dividing by zero.
    '''

    yTrue = np.asarray(yTrue)>=0.5
    yPred = np.asarray(yPred)>=0.5

    union = np.sum(np.logical_or(yTrue,yPred))

    if union == 0:
        return 1.0

    intersection = np.sum(np.logical_and(yTrue,yPred))

    return intersection/union

In [None]:
# Sigmoid outputs of the trained network for the held-out examples...
predTest = model.predict(X_test)

# ...and for the examples in X (the training portion once the train/test split cell has run)
predTrain = model.predict(X)


In [None]:
#loop over the example axis, calculating IoU for each image separately.
#Every split is scored against its own labels and over its own length.
iouBins = np.arange(0,1.05,0.05)
iouTrain = [iou(Y[m,:,:,:], predTrain[m,:,:]) for m in range(M)]
iouTest = [iou(Y_test[m,:,:,:], predTest[m,:,:]) for m in range(MTest)]

plt.hist(iouTrain, bins = iouBins, density=True, alpha=0.5, label = 'Train')
plt.hist(iouTest,  bins = iouBins, density=True, alpha=0.5, label = 'Test')

plt.xlabel('iou')
plt.ylabel('probability density')

plt.legend()

 

How well do predicted **areas** of fat match? That is what the project is all about

In [None]:

#per-example areas: sum over height, width and channel, scaled by pxSize.
#NOTE(review): this treats pxSize as the area of one pixel (one value per example) - confirm.
#NOTE(review): predicted areas sum the raw sigmoid outputs, not a 0.5-thresholded mask - confirm this is intended.
areasPredTrain = np.sum(predTrain,axis=(1,2,3)) * pxSize
areasTrueTrain = np.sum(Y,axis=(1,2,3)) * pxSize

areasPredTest = np.sum(predTest,axis=(1,2,3)) * pxSize_test
areasTrueTest = np.sum(Y_test,axis=(1,2,3)) * pxSize_test

plt.scatter(areasTrueTrain,areasPredTrain,label = 'train')
plt.scatter(areasTrueTest,areasPredTest,label = 'test')

#pearsonr returns the correlation coefficient r; square it to report R^2
r,p = pearsonr(areasTrueTest,areasPredTest)

plt.title('for test set, R^2 = ' + str(r**2) + ', p = ' + str(p))

plt.xlabel('human area (mm^2)')

plt.ylabel('machine area (mm^2)')

plt.legend()

A few examples of the training set segmentations:

In [None]:
#number of training examples to display
negs = 15

#random example indices (drawn with replacement)
sample = np.random.randint(low=0,high=X.shape[0],size=negs)

plt.figure(figsize = (15,5*negs))

#one row per example; columns are image, human mask, network output
for ind,eg in enumerate(sample):
    for col,stack in enumerate((X,Y,predTrain), start=1):
        plt.subplot(negs,3,ind*3+col)
        plt.imshow(np.squeeze(stack[eg,:]),vmin=0,vmax=1)

Examples from the test set:

In [None]:
plt.figure(figsize = (15,5*MTest))

#column heading and data source for each of the three panels in a row
panels = (('image',X_test),   #original image
          ('human',Y_test),   #human segmentation
          ('machine',predTest)) #automated segmentation

#loop over rows
for ind in range(MTest): #FIXME when I have more data, should take a random sample from test set rather than the whole thing
    for col,(heading,stack) in enumerate(panels, start=1):
        plt.subplot(MTest,3,ind*3+col)
        plt.imshow(np.squeeze(stack[ind,:]),vmin=0,vmax=1)
        plt.xticks([])
        plt.yticks([])
        #headings only on the first row
        if ind == 0:
            plt.title(heading)
    

In [None]:
import skimage