In [1]:
import absl.logging
import numpy as np
import tensorflow as tf

from sklearn.utils import class_weight
from project_scripts.widgets import explore_dataset_widget, plot_data, dataset_movie_widget
from project_scripts.data_loading import load_dataset, slices_to_textures, dataset_to_embeddings
from project_scripts.neural_networks import B_frame_CNN, texture_CNN, ensemble_MLP

# Keep only ERROR-level log messages from absl and TensorFlow. This hides harmless
# warnings such as "WARNING:absl:Found untraced functions such as
# _jit_compiled_convolution_op ...", which do not affect training or inference
# (they also appear in the official TensorFlow tutorials).
absl.logging.set_verbosity('ERROR')
tf.get_logger().setLevel('ERROR')

In [2]:
# Interactive dataset explorer. The rendered widget exposes integer sliders (e.g. `slice_idx`);
# `plot_data` is presumably the plotting callback the sliders drive — TODO confirm.
explore_dataset_widget(plot_data)

interactive(children=(IntSlider(value=0, description='slice_idx', max=6), IntSlider(value=0, description='fram…

# Load Dataset

<img src='manuscript/figure_1.png' width="600"/>

Figure 1. Caption

In [6]:
# Load the labelled training slices and derive the two alternative model inputs:
# texture patches (fed to texture_CNN) and embedding vectors (fed to ensemble_MLP).
# NOTE(review): `max_val`/`min_val` are not used in the cells shown here; presumably they are
# normalisation bounds needed to preprocess test data identically — confirm.
# NOTE(review): this runs before the CNNs below are trained, so `dataset_to_embeddings`
# presumably relies on previously saved CNN weights — confirm.
X_train, Y_train, max_val, min_val = load_dataset('training_data')
X_train_textures = slices_to_textures(X_train)
X_train_embeddings = dataset_to_embeddings(X_train, X_train_textures)

# Every downstream `fit` pairs these arrays index-by-index with Y_train, so fail fast
# if any of them does not describe the same number of samples.
assert len(X_train) == len(Y_train), 'slices and labels are misaligned'
assert len(X_train_textures) == len(Y_train), 'textures and labels are misaligned'
assert len(X_train_embeddings) == len(Y_train), 'embeddings and labels are misaligned'

LOADING: directory=(training_data) | label=([0, 1]) | total_bframes=(0)
LOADING: directory=(training_data\cancer) | label=([0, 1]) | total_bframes=(0)
LOADING: directory=(training_data\cancer\3-21-2017-s1) | label=([1, 0]) | total_bframes=(52)
LOADING: directory=(training_data\cancer\3-21-2017-s2) | label=([1, 0]) | total_bframes=(26)
LOADING: directory=(training_data\cancer\6-24-2019-s3) | label=([1, 0]) | total_bframes=(137)
LOADING: directory=(training_data\cancer\9-11-2018-s2) | label=([1, 0]) | total_bframes=(146)
LOADING: directory=(training_data\cancer\9-28-2020-Tumor) | label=([1, 0]) | total_bframes=(31)
LOADING: directory=(training_data\non_cancer) | label=([0, 1]) | total_bframes=(0)
LOADING: directory=(training_data\non_cancer\10-5-2020-C1-NormalWhiteMatter) | label=([0, 1]) | total_bframes=(42)
LOADING: directory=(training_data\non_cancer\4-24-2018-s2) | label=([0, 1]) | total_bframes=(156)
LOADING: directory=(training_data\non_cancer\8-24-2017-C2-s1) | label=([0, 1]) | to

In [8]:
# Summarise the array shapes, then show an animated preview of the textures with their labels.
# X_train has no channel axis (one is added with np.expand_dims before training the
# B-frame CNN), so its shape is labelled (num_slices, slice_height, slice_width).
print('TRAINING DATA')
print(f'(num_slices, slice_height, slice_width): {X_train.shape}')
print(f'(num_textures, texture_height, texture_width, texture_channels): {X_train_textures.shape}')
print(f'(num_embeddings, len_embedding): {X_train_embeddings.shape}')
print(f'(num_labels, #_classes): {Y_train.shape} \n')

dataset_movie_widget(X_train_textures, Y_train) # the widget will render when this cell is run

TRAINING DATA
(num_slices, slice_height, slice_width, slice_channels): (5831, 200, 100)
(num_textures, texture_height, texture_width, texture_channels): (5831, 100, 100, 1)
(num_embeddings, len_embedding): (5831, 128)
(num_labels, #_classes): (5831, 2) 



interactive(children=(Play(value=0, description='idx', interval=500, max=5830), IntSlider(value=0, description…

# Neural Network Training
- `B_frame_CNN` and `texture_CNN` can be trained independently of one another.
- Training `ensemble_MLP` requires a pre-trained `B_frame_CNN` and `texture_CNN`, which convert the slices and the textures, respectively, into embeddings.

In [7]:
# Build the B-frame CNN, print its layer summary, and configure it for binary
# classification with label-smoothed cross-entropy and plain SGD.
my_bframe_CNN = B_frame_CNN(3, 'relu', 'same')
my_bframe_CNN.model().summary()
my_bframe_CNN.compile(
    loss=tf.keras.losses.BinaryCrossentropy(label_smoothing=0.1),
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
    metrics=['accuracy'],
)

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 200, 100, 1)]     0         
                                                                 
 conv2d_4 (Conv2D)           (None, 100, 100, 32)      320       
                                                                 
 conv2d_5 (Conv2D)           (None, 50, 50, 64)        18496     
                                                                 
 max_pooling2d_2 (MaxPooling  (None, 24, 24, 64)       0         
 2D)                                                             
                                                                 
 gradmaps (Conv2D)           (None, 12, 12, 128)       73856     
                                                                 
 max_pooling2d_3 (MaxPooling  (None, 5, 5, 128)        0         
 2D)                                                       

In [8]:
# Train the B-frame CNN on the raw slices.
#
# Keras' `validation_split` always takes the *last* 20% of the arrays, selected *before*
# the per-epoch `shuffle`. The LOADING log printed by `load_dataset` lists the cancer
# folders before the non_cancer folders, so without an explicit permutation the validation
# set would most likely contain only non-cancer slices. `val_accuracy` drives both
# checkpointing and early stopping, so it would then reward predicting a single class.
# A fixed seed keeps the split reproducible and identical for every model in this notebook.
# NOTE(review): slices from the same sample directory can still land on both sides of a
# random split, so val_accuracy is likely optimistic; a split grouped by sample is stricter.
SPLIT_SEED = 0
shuffle_idx = np.random.default_rng(SPLIT_SEED).permutation(len(Y_train))

# Forward slashes work on Windows and POSIX alike; backslashes are literal filename
# characters outside Windows.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath = 'saved_models/models_history/bframe_cnn/epoch_{epoch:02d}-val_acc_{val_accuracy:.2f}.tf',
    monitor = 'val_accuracy',
    save_best_only = True,  # only save the model when val_accuracy improves
    mode = 'max'            # higher val_accuracy is better
)
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=10) # stops training after 'patience' epochs of no improvement
log_csv = tf.keras.callbacks.CSVLogger('saved_models/models_history/logs/bframe_cnn_log.csv', separator=',', append=False) # save training and validation curves

# 'balanced' weights are n_samples / (n_classes * count(class)), countering class imbalance.
y_integers = np.argmax(Y_train, axis=1)
class_weights = class_weight.compute_class_weight(
    class_weight = 'balanced',
    classes = np.unique(y_integers),
    y = y_integers
)

history = my_bframe_CNN.fit(
    np.expand_dims(X_train, axis=3)[shuffle_idx],  # add the channel axis the CNN input expects
    Y_train[shuffle_idx],
    batch_size = 8,
    shuffle = True,
    epochs = 30,
    validation_split = 0.2,
    class_weight = dict(enumerate(class_weights)),
    callbacks = [early_stop, log_csv, checkpoint]
)

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30


In [33]:
# Build the texture CNN (it takes 100x100 single-channel texture patches), print its layer
# summary, and configure it with label-smoothed cross-entropy and SGD.
my_texture_CNN = texture_CNN(3, 'relu', 'same')
my_texture_CNN.model().summary()
my_texture_CNN.compile(
    loss=tf.keras.losses.BinaryCrossentropy(label_smoothing=0.1),
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
    metrics=['accuracy'],
)

Model: "model_6"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_7 (InputLayer)        [(None, 100, 100, 1)]     0         
                                                                 
 conv2d_20 (Conv2D)          (None, 50, 50, 32)        320       
                                                                 
 max_pooling2d_12 (MaxPoolin  (None, 24, 24, 32)       0         
 g2D)                                                            
                                                                 
 gradmaps (Conv2D)           (None, 12, 12, 32)        9248      
                                                                 
 max_pooling2d_13 (MaxPoolin  (None, 5, 5, 32)         0         
 g2D)                                                            
                                                                 
 conv2d_21 (Conv2D)          (None, 3, 3, 32)          9248

In [34]:
# Train the texture CNN on the texture patches.
#
# Keras' `validation_split` takes the *last* 20% of the arrays before any shuffling, and the
# LOADING log shows cancer folders are loaded before non_cancer ones. Without a permutation,
# the validation set (which drives checkpointing and early stopping) would most likely be
# single-class. The same seed as the other training cells gives the same split.
# NOTE(review): slices from one sample can fall on both sides of a random split, so
# val_accuracy is likely optimistic.
SPLIT_SEED = 0
shuffle_idx = np.random.default_rng(SPLIT_SEED).permutation(len(Y_train))

# Forward slashes are portable across Windows and POSIX.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath = 'saved_models/models_history/texture_cnn/epoch_{epoch:02d}-val_acc_{val_accuracy:.2f}.tf',
    monitor = 'val_accuracy',
    save_best_only = True, # only save model if val_accuracy improves
    mode = 'max' # higher val_accuracy is better
)
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=10) # stops training after 'patience' epochs of no improvement
log_csv = tf.keras.callbacks.CSVLogger('saved_models/models_history/logs/texture_cnn_log.csv', separator=',', append=False) # save training and validation curves

# Compute balanced class weights here instead of reusing `class_weights` from the B-frame
# cell, so this cell also works on a fresh kernel where that cell was skipped.
y_integers = np.argmax(Y_train, axis=1)
class_weights = class_weight.compute_class_weight(
    class_weight = 'balanced',
    classes = np.unique(y_integers),
    y = y_integers
)

history = my_texture_CNN.fit(
    X_train_textures[shuffle_idx],
    Y_train[shuffle_idx],
    batch_size = 8,
    shuffle = True,
    epochs = 30,
    validation_split = 0.2,
    class_weight = dict(enumerate(class_weights)),
    callbacks = [early_stop, log_csv, checkpoint]
)

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30


In [13]:
# Build the ensemble MLP head, print its layer summary, and configure it with
# un-smoothed binary cross-entropy and SGD.
my_ensemble_MLP = ensemble_MLP()
my_ensemble_MLP.model().summary()
my_ensemble_MLP.compile(
    loss=tf.keras.losses.BinaryCrossentropy(label_smoothing=0),
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
    metrics=['accuracy'],
)

Model: "model_9"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_6 (InputLayer)        [(None, 1, 128)]          0         
                                                                 
 dense_6 (Dense)             (None, 1, 64)             8256      
                                                                 
 dropout_7 (Dropout)         (None, 1, 64)             0         
                                                                 
 dense_7 (Dense)             (None, 1, 2)              130       
                                                                 
Total params: 8,386
Trainable params: 8,386
Non-trainable params: 0
_________________________________________________________________


In [14]:
# Train the ensemble MLP on the per-slice embeddings produced by `dataset_to_embeddings`.
#
# Keras' `validation_split` takes the *last* 20% of the arrays before any shuffling, and the
# LOADING log shows cancer folders are loaded before non_cancer ones. Without a permutation,
# the validation set (which drives checkpointing and early stopping) would most likely be
# single-class. The same seed as the CNN training cells gives the same split.
# NOTE(review): if the embedding CNNs were trained on these same slices, the MLP's
# validation metrics are not an independent held-out estimate.
# NOTE(review): the model summary reports an input of (None, 1, 128) while
# X_train_embeddings is (num_embeddings, 128); confirm this is reconciled intentionally
# (shape-mismatch warnings are silenced at the top of the notebook).
SPLIT_SEED = 0
shuffle_idx = np.random.default_rng(SPLIT_SEED).permutation(len(Y_train))

# Forward slashes are portable across Windows and POSIX.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath = 'saved_models/models_history/ensemble_mlp/epoch_{epoch:02d}-val_acc_{val_accuracy:.2f}.tf',
    monitor = 'val_accuracy',
    save_best_only = True,  # only save the model when val_accuracy improves
    mode = 'max'            # higher val_accuracy is better
)
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=10)  # stop after 'patience' epochs of no improvement
log_csv = tf.keras.callbacks.CSVLogger('saved_models/models_history/logs/ensemble_mlp_log.csv', separator=',', append=False)  # save training and validation curves

# The MLP sees the same imbalanced labels as the CNNs, so weight classes the same way they do.
y_integers = np.argmax(Y_train, axis=1)
class_weights = class_weight.compute_class_weight(
    class_weight = 'balanced',
    classes = np.unique(y_integers),
    y = y_integers
)

history = my_ensemble_MLP.fit(
    X_train_embeddings[shuffle_idx],
    Y_train[shuffle_idx],
    batch_size = 8,
    shuffle = True,
    epochs = 30,
    validation_split = 0.2,
    class_weight = dict(enumerate(class_weights)),
    callbacks = [early_stop, log_csv, checkpoint]
)

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
