# Training a CNN for cell cycle state classification

In [None]:
# if using colab, install cellx library and make log and data folders
if 'google.colab' in str(get_ipython()):
    !pip install -q git+git://github.com/quantumjot/cellx.git
    !mkdir logs
    !mkdir train
    !mkdir test

## Import libraries and set up hyper-parameters

In [None]:
import os
import zipfile
import numpy as np
from datetime import datetime
from skimage.io import imread
from skimage.transform import resize

In [None]:
import tensorflow.keras as K
import tensorflow as tf

In [None]:
from cellx.layers import Encoder2D
from cellx.tools.dataset import build_dataset
from cellx.tools.dataset import write_dataset
from cellx.augmentation.utils import append_conditional_augmentation, augmentation_label_handler
from cellx.callbacks import tensorboard_confusion_matrix_callback

In [None]:
# Folders holding the annotation archives for each split, and the TFRecord
# files that are generated from them.
TRAIN_PATH = "./train"
TEST_PATH = "./test"
TRAIN_FILE = os.path.join(TRAIN_PATH, 'CNN_train.tfrecord')
TEST_FILE = os.path.join(TEST_PATH, 'CNN_test.tfrecord')
# Class names; the position in this list is the numeric label written to disk.
LABELS = ["Interphase", "Prometaphase", "Metaphase", "Anaphase", "Apoptosis"]
# Training batch size.
BATCH_SIZE = 64
# Shuffle buffer size (also used to derive the number of steps per epoch).
BUFFER_SIZE = 20_000
# Number of epochs passed to model.fit.
TRAINING_EPOCHS = 100

In [None]:
# Load the TensorBoard notebook extension (provides the %tensorboard magic).
%load_ext tensorboard
# Every run logs into its own timestamped sub-folder of LOG_ROOT.
LOG_ROOT = './logs'
LOG_DIR = os.path.join(LOG_ROOT, datetime.now().strftime("%Y%m%d-%H%M%S"))

## Load the training/testing data and generate TFRecord files

In [None]:
def create_tf_record(
    root,
    filename,
    labels=LABELS,
    image_size=(64, 64),
):
    """Build a TFRecord file from the annotated image patches found in `root`.

    Every zip archive in `root` named ``annotation_*.zip`` is scanned for
    ``.tif`` members. A member is assigned to a class when its name starts
    with the capitalized class name; its numeric label is the index of that
    class in `labels`. Each patch is resized to `image_size`, a trailing
    channel axis is added, and images and labels are written with
    `write_dataset`.

    Args:
        root: directory containing the ``annotation_*.zip`` archives.
        filename: path of the TFRecord file to write.
        labels: ordered class names; the position gives the numeric label.
        image_size: (rows, columns) that every patch is resized to.

    Raises:
        ValueError: if no matching patches are found under `root`.
    """

    _images = []
    _labels = []

    # find the zip files (sorted, so the record order does not depend on the
    # order in which the file system happens to list the directory):
    zipfiles = sorted(
        os.path.join(root, f)
        for f in os.listdir(root)
        if f.startswith("annotation_") and f.endswith(".zip")
    )

    for zfn in zipfiles:
        print(f"Loading file: {zfn}")
        with zipfile.ZipFile(zfn, 'r') as zip_data:
            files = zip_data.namelist()

            for numeric_label, label in enumerate(labels):
                prefix = label.capitalize()
                patch_files = [f for f in files if f.endswith(".tif") and f.startswith(prefix)]

                for patch_file in patch_files:
                    # close every archive member as soon as it has been decoded
                    with zip_data.open(patch_file) as patch:
                        image = imread(patch)
                    _images.append(resize(image, image_size, preserve_range=True))

                _labels += [numeric_label] * len(patch_files)

    if not _images:
        # fail with an actionable message instead of an opaque np.stack error
        raise ValueError(
            f"No '.tif' patches matching labels {list(labels)} were found in "
            f"'annotation_*.zip' archives under {root!r}"
        )

    # stack into (n_images, rows, columns) and append a channel axis
    images_arr = np.stack(_images, axis=0)[..., np.newaxis]
    labels_arr = np.asarray(_labels)

    print(f"Total images: {images_arr.shape[0]}")
    # NOTE(review): the float output of `resize` is truncated to uint8 here;
    # this assumes the source patches are 8-bit - TODO confirm.
    write_dataset(filename, images_arr.astype(np.uint8), labels=labels_arr.astype(np.int64))

In [None]:
# Convert the annotation archives of each split into a single TFRecord file.
create_tf_record(TRAIN_PATH, TRAIN_FILE)
create_tf_record(TEST_PATH, TEST_FILE)

## Create a simple CNN for classification

In [None]:
# Single-channel 64x64 input -> Encoder2D -> flatten -> dense classifier head.
img = K.layers.Input(shape=(64, 64, 1))
x = Encoder2D(layers=[8, 16, 32, 64, 128])(img)
x = K.layers.Flatten()(x)
x = K.layers.Dense(256, activation="relu")(x)
x = K.layers.Dropout(0.2)(x)
# One output unit per class, tied to LABELS so the head cannot drift out of
# sync with the label list. The linear activation means these are raw logits.
logits = K.layers.Dense(len(LABELS), activation="linear")(x)

In [None]:
# Wrap the layer graph into a Keras model mapping the input image to logits.
model = K.Model(inputs=img, outputs=logits)

In [None]:
# Print a summary of the model architecture.
model.summary()

## Set up some augmentations to be used while training

In [None]:
@augmentation_label_handler
def normalize(img):
    """Standardize an image and clip it to +/- 4 standard deviations.

    The image is standardized with `tf.image.per_image_standardization`,
    clipped to the range [-4, 4] and then checked for NaN/Inf values.

    Args:
        img: image tensor to normalize.

    Returns:
        The standardized, clipped image tensor.
    """
    img = tf.image.per_image_standardization(img)
    # clip to 4 standard deviations
    img = tf.clip_by_value(img, -4., 4.)
    # check_numerics returns its input; keep the returned tensor so the check
    # is part of the data path and cannot be dropped as an unused op.
    img = tf.debugging.check_numerics(img, "Image contains NaN")
    return img

In [None]:
@augmentation_label_handler
def augment(img):
    """Apply random boundary cropping, rotation and flips to one image.

    With a probability of 5% a band of columns (0 to 29 wide) at the start of
    axis 1 is set to zero. The image is then rotated by a random multiple of
    90 degrees and randomly flipped left/right and up/down.

    Args:
        img: rank-3 image tensor; the crop mask broadcasts over axes 0 and 2.

    Returns:
        The augmented image tensor.
    """
    boundary_augmentation = True
    if boundary_augmentation:
        # this will randomly simulate the cropping that occurs at the edge of
        # an image volume

        # Draw the band width with TensorFlow ops: Dataset.map traces this
        # function once, so a NumPy random draw would be baked into the graph
        # as a constant instead of being re-sampled for every image.
        width = tf.random.uniform(shape=(), minval=0, maxval=30, dtype=tf.int32)
        columns = tf.range(tf.shape(img)[1])
        # 0 for the first `width` columns, 1 elsewhere; shaped (1, columns, 1)
        vignette = tf.cast(columns >= width, img.dtype)[tf.newaxis, :, tf.newaxis]

        img = tf.cond(pred=tf.random.uniform(shape=()) < 0.05,
                true_fn=lambda: tf.multiply(img, vignette),
                false_fn=lambda: img)

    # do some data augmentation
    # maxval is exclusive: 4 gives k in {0, 1, 2, 3}, i.e. all four rotations
    k = tf.random.uniform(maxval=4, shape=(), dtype=tf.int32)
    img = tf.image.rot90(img, k=k)

    img = tf.image.random_flip_left_right(img)
    img = tf.image.random_flip_up_down(img)
    return img

In [None]:
@augmentation_label_handler
def random_contrast(x):
    """Adjust the contrast of `x` by a random factor between 0.3 and 1.0."""
    lower_factor, upper_factor = 0.3, 1.0
    return tf.image.random_contrast(x, lower=lower_factor, upper=upper_factor)

@augmentation_label_handler
def random_brightness(x):
    """Shift the brightness of `x` by a random delta in [-0.3, 0.3)."""
    # tf.image.random_brightness takes (image, max_delta, seed). Unlike
    # random_contrast there is no (lower, upper) pair, so a third positional
    # argument would be interpreted as the random seed.
    return tf.image.random_brightness(x, max_delta=0.3)

## Build the training dataset, with random augmentations

In [None]:
# Read the training records (with their labels) back from the TFRecord file.
dataset = build_dataset(TRAIN_FILE, read_label=True)

In [None]:
# Geometric augmentation first, then the conditional contrast/brightness
# augmentations. This cell rebinds `dataset` from itself, so re-run the cell
# that builds `dataset` before running this one again.
dataset = append_conditional_augmentation(
    dataset.map(augment),
    [random_contrast, random_brightness],
)

# Normalize, shuffle, repeat indefinitely, then batch and prefetch.
dataset = (
    dataset
    .map(normalize)
    .shuffle(buffer_size=BUFFER_SIZE, reshuffle_each_iteration=True)
    .repeat()
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(1)
)

## Build the test dataset, without augmentations

In [None]:
# Read the test records and apply only the normalization step.
test_dataset = build_dataset(TEST_FILE, read_label=True).map(normalize)

# take(-1) keeps every element; iterate over them as NumPy values.
test_dataset = test_dataset.take(-1).as_numpy_iterator()

# Split the stream of elements into a tuple of images and a tuple of labels.
test_images, test_labels = zip(*test_dataset)

## Set up TensorBoard callbacks to monitor training

In [None]:
# Keras TensorBoard callback writing to this run's log folder.
tensorboard_callback = K.callbacks.TensorBoard(log_dir=LOG_DIR)
# cellx confusion-matrix callback: receives the model, the test images (as one
# array), the test labels, the same log folder and the class names.
confusion_matrix_callback = tensorboard_confusion_matrix_callback(
    model, 
    np.asarray(test_images), 
    test_labels,
    LOG_DIR,
    class_names=LABELS,
    is_binary=False
)

## Set up the loss function

In [None]:
# The model's final layer is linear (raw logits) and the labels are stored as
# integer class indices, hence sparse categorical cross-entropy on logits.
loss = K.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer="adam", loss=loss, metrics=['accuracy'])

## Finally, train the model and evaluate performance using TensorBoard

In [None]:
# Open TensorBoard inline, pointed at the root log folder.
%tensorboard --logdir $LOG_ROOT --host localhost

In [None]:
# The training dataset repeats indefinitely, so steps_per_epoch sets the
# length of an epoch.
# NOTE(review): the epoch length is derived from BUFFER_SIZE (the shuffle
# buffer size), not from the number of training examples - confirm intended.
model.fit(
    dataset, 
    steps_per_epoch=BUFFER_SIZE//BATCH_SIZE, 
    epochs=TRAINING_EPOCHS, 
    callbacks=[tensorboard_callback, confusion_matrix_callback],
)