# Training a CNN for cell cycle state classification

In [1]:
if 'google.colab' in str(get_ipython()):
    !pip install -q git+git://github.com/quantumjot/cellx.git

In [1]:
import os
import numpy as np

In [2]:
import tensorflow.keras as K
import tensorflow as tf

In [3]:
from cellx.layers import Encoder2D
from cellx.tools.dataset import build_dataset
from cellx.tools.dataset import write_dataset
from cellx.augmentation.utils import append_conditional_augmentation, augmentation_label_handler
from cellx.callbacks import tensorboard_confusion_matrix_callback

In [4]:
# --- Data locations -----------------------------------------------------------
TRAIN_PATH = "/media/quantumjot/Data/TrainingData/CNN_training_MDCK_v2_merged"
TEST_PATH = "/media/quantumjot/Data/TrainingData/CNN_test_MDCK_v2"
# TFRecord holding the training images and labels.
FILE = os.path.join(TRAIN_PATH, "CNN_cellx.tfrecord")

# --- Classes ------------------------------------------------------------------
# The position of each name in this list is its numeric class label.
LABELS = [
    "interphase",
    "prometaphase",
    "metaphase",
    "anaphase",
    "apoptosis",
]

# --- Training hyper-parameters ------------------------------------------------
BATCH_SIZE = 64
BUFFER_SIZE = 20_000  # shuffle buffer; also used to derive steps per epoch
TRAINING_EPOCHS = 100

In [5]:
%load_ext tensorboard
from datetime import datetime
# Root directory shared by all runs; TensorBoard is pointed at this folder.
LOG_ROOT = '/media/quantumjot/DataIII/Training/logs'
# Each run logs to its own sub-directory named after the start time (YYYYmmdd-HHMMSS).
LOG_DIR = os.path.join(LOG_ROOT, f"{datetime.now():%Y%m%d-%H%M%S}")

In [6]:
from skimage.io import imread
from skimage.transform import resize

def create_tf_record(
    filename,
    root,
    labels=LABELS
):
    """Write a TFRecord of 64x64 single-channel images gathered from class folders.

    Each entry of `labels` names a sub-directory of `root`. Every ``.png`` file
    in that directory is read, reduced to channel index 1, resized to 64x64 and
    tagged with the position of its label in `labels`.

    Args:
        filename: Output path handed to `write_dataset`.
        root: Directory containing one sub-directory per label.
        labels: Ordered class names; the index in this sequence is the numeric label.

    Raises:
        ValueError: If no ``.png`` file is found in any of the label directories.
    """
    _images = []
    _labels = []

    for numeric_label, label in enumerate(labels):
        _path = os.path.join(root, label)

        # os.listdir returns entries in arbitrary order; sort so the record
        # contents are reproducible between runs and machines.
        files = sorted(f for f in os.listdir(_path) if f.endswith(".png"))

        # Read and resize in a single pass so the full-size images are never
        # all held in memory at the same time.
        # NOTE(review): `[..., 1]` assumes multi-channel PNGs (keeps channel 1) — confirm.
        images_resized = [
            resize(imread(os.path.join(_path, f))[..., 1], (64, 64), preserve_range=True)
            for f in files
        ]

        _images.extend(images_resized)
        _labels.extend([numeric_label] * len(images_resized))

        print(f"{label}: {len(images_resized)}")

    if not _images:
        # np.stack would fail on an empty list anyway; fail with a clearer message.
        raise ValueError(
            f"No .png images found under {root!r} for labels {list(labels)!r}"
        )

    # Stack to (N, 64, 64) and append a trailing channel axis -> (N, 64, 64, 1).
    images_arr = np.stack(_images, axis=0)[..., np.newaxis]
    labels_arr = np.asarray(_labels, dtype=np.int64)

    print(f"Total images: {images_arr.shape[0]}")
    # `resize` returns floats; the uint8 cast truncates the fractional part.
    write_dataset(filename, images_arr.astype(np.uint8), labels=labels_arr)
        

In [7]:
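# One-off step: uncomment to (re)generate the training TFRecord from the PNG folders.
# Left disabled so that re-running the notebook reuses the existing record at FILE.
# create_tf_record(FILE, TRAIN_PATH)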

In [8]:
# Classifier graph: 64x64 single-channel image -> convolutional encoder ->
# dense head -> one raw logit per class.
img = K.layers.Input(shape=(64, 64, 1))
x = Encoder2D(layers=[8, 16, 32, 64, 128])(img)
x = K.layers.Flatten()(x)
x = K.layers.Dense(256, activation="swish")(x)
x = K.layers.Dropout(0.2)(x)
# Size the output layer from LABELS so it cannot drift out of sync with the
# class list. The activation stays linear because the outputs are used as logits.
logits = K.layers.Dense(len(LABELS), activation="linear")(x)

In [9]:
# Wrap the layer graph in a Keras Model mapping the input image to the class logits.
model = K.Model(inputs=img, outputs=logits)

In [10]:
# Print the per-layer output shapes and parameter counts.
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 64, 64, 1)]       0         
_________________________________________________________________
encoder2d (Encoder2D)        (None, 4, 4, 128)         99232     
_________________________________________________________________
flatten (Flatten)            (None, 2048)              0         
_________________________________________________________________
dense (Dense)                (None, 256)               524544    
_________________________________________________________________
dropout (Dropout)            (None, 256)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 5)                 1285      
Total params: 625,061
Trainable params: 624,565
Non-trainable params: 496
_____________________________________________________

In [11]:
@augmentation_label_handler
def normalize(img):
    """Standardize an image and clip extreme values.

    Args:
        img: Image tensor. Label handling is presumably done by the
            `augmentation_label_handler` decorator — confirm against cellx.

    Returns:
        The per-image standardized tensor, clipped to [-4, 4].
    """
    img = tf.image.per_image_standardization(img)
    # clip to 4 standard deviations
    img = tf.clip_by_value(img, -4., 4.)
    # Use the tensor returned by check_numerics: an op whose output is discarded
    # is not on the data path of a traced function and may never be executed.
    img = tf.debugging.check_numerics(img, "Image contains NaN")
    return img

In [12]:
@augmentation_label_handler
def augment(img):
    """Randomly blank an edge strip, rotate and flip an image.

    With probability 0.05 a strip of random width (0-29 pixels) along the
    second-to-last axis is zeroed; the image is then rotated by a random
    multiple of 90 degrees and randomly flipped in both directions.

    Args:
        img: Image tensor with a trailing channel axis, e.g. (64, 64, 1).

    Returns:
        The augmented image tensor, same shape and dtype as the input.
    """
    boundary_augmentation = True
    if boundary_augmentation:
        # this will randomly simulate the cropping that occurs at the edge of
        # an image volume
        #
        # The strip width has to come from a TensorFlow random op. Python/NumPy
        # code runs only once, when tf.data traces this function, so a NumPy
        # random draw would freeze a single width for the whole training run.
        width = tf.random.uniform(shape=(), minval=0, maxval=30, dtype=tf.int32)
        columns = tf.range(tf.shape(img)[-2])
        # 0 for the first `width` columns, 1 elsewhere; broadcast over rows/channels.
        vignette = tf.reshape(tf.cast(columns >= width, img.dtype), (1, -1, 1))

        img = tf.cond(pred=tf.random.uniform(shape=()) < 0.05,
                      true_fn=lambda: tf.multiply(img, vignette),
                      false_fn=lambda: img)

    # do some data augmentation
    # maxval is exclusive: 4 is needed so that k covers 0, 1, 2 and 3 quarter turns.
    k = tf.random.uniform(minval=0, maxval=4, shape=(), dtype=tf.int32)
    img = tf.image.rot90(img, k=k)

    img = tf.image.random_flip_left_right(img)
    img = tf.image.random_flip_up_down(img)
    return img

In [13]:
@augmentation_label_handler
def random_contrast(x):
    """Adjust the contrast of `x` by a random factor between 0.3 and 1.0."""
    lower, upper = 0.3, 1.0
    return tf.image.random_contrast(x, lower=lower, upper=upper)

@augmentation_label_handler
def random_brightness(x):
    """Shift the brightness of `x` by a random delta drawn from [-0.3, 0.3)."""
    # tf.image.random_brightness takes (image, max_delta, seed) and has no
    # upper-bound argument: a third positional value would be consumed as the
    # RNG seed. Pass max_delta by keyword and leave the op unseeded.
    return tf.image.random_brightness(x, max_delta=0.3)

In [14]:
# Open the training TFRecord with cellx's build_dataset, asking it to read labels too.
# NOTE(review): presumably yields (image, label) pairs — confirm against cellx.
dataset = build_dataset(FILE, read_label=True)

In [15]:
# Geometric augmentation first, then the conditional intensity augmentations.
dataset = append_conditional_augmentation(
    dataset.map(augment),
    [random_contrast, random_brightness],
)

# Standardize, shuffle, repeat indefinitely and batch for training.
dataset = (
    dataset
    .map(normalize)
    .shuffle(buffer_size=BUFFER_SIZE, reshuffle_each_iteration=True)
    .repeat()
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(1)
)

In [16]:
# Load the held-out test record, normalize it (no augmentation) and expose it
# as an iterator of NumPy values. take(-1) keeps every element.
test_dataset = (
    build_dataset(os.path.join(TEST_PATH, "CNN_cellx_test.tfrecord"), read_label=True)
    .map(normalize)
    .take(-1)
    .as_numpy_iterator()
)

# Drain the iterator and split the (image, label) pairs into two parallel tuples.
test_images, test_labels = zip(*test_dataset)

In [17]:
# Standard Keras TensorBoard logging into this run's log directory.
tensorboard_callback = K.callbacks.TensorBoard(log_dir=LOG_DIR)

# cellx confusion-matrix callback, given the model, the test images/labels and
# the same log directory.
confusion_matrix_callback = tensorboard_confusion_matrix_callback(
    model,
    np.asarray(test_images),
    test_labels,
    LOG_DIR,
    class_names=LABELS,
    is_binary=False,
)

In [18]:
# Show TensorBoard inline, pointed at the root log folder so every run is listed.
# IPython expands $LOG_ROOT to the value of the Python variable.
%tensorboard --logdir $LOG_ROOT --host localhost

Reusing TensorBoard on port 6006 (pid 22371), started 0:46:25 ago. (Use '!kill 22371' to kill it.)

In [19]:
# Labels are integer class indices and the network outputs raw logits, so use
# sparse categorical cross-entropy with from_logits=True.
loss = K.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(
    optimizer="adam",
    loss=loss,
    metrics=["accuracy"],
)

In [20]:
# Train the model. `dataset` repeats indefinitely, so the length of an epoch is
# set by steps_per_epoch rather than by the data running out.
# NOTE(review): steps_per_epoch is derived from the shuffle buffer size, not from
# the number of training images — confirm this is the intended epoch length.
model.fit(
    dataset, 
    steps_per_epoch=BUFFER_SIZE//BATCH_SIZE, 
    epochs=TRAINING_EPOCHS, 
    callbacks=[tensorboard_callback, confusion_matrix_callback],
)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

<tensorflow.python.keras.callbacks.History at 0x7f6ce052ddd0>