# CNN Training for Cell Cycle State Classification

### Welcome!

This notebook allows you to train a convolutional neural network (CNN) using your annotated single-cell image patches. 
Follow the step-wise instructions to proceed with the network training and evaluation. 


### Important Notes:

1. If using the virtual environment of [Google Colab](https://colab.research.google.com/notebooks/intro.ipynb "Google Colaboratory"): You will need to be signed in with a Google account. Your session will time out if you do not interact with it. Although the documentation states that the runtime lasts 90 minutes after you close the browser or 12 hours if you keep it open, in our experience it can disconnect after 60 minutes even with the browser open. See this [StackOverflow](https://stackoverflow.com/questions/54057011/google-colab-session-timeout "Google Colab Session Timeout") discussion, where users report different experiences with Colab's session timeout. Also remember that access to Colab resources is limited to a maximum of 12 hours per session. If you exceed this limit, Google may temporarily suspend your access to Colab.

2. To be able to train the neural network on your own data, you must first **import your annotated data** into the folders to source from. Please follow the running instructions below.


### Running Instructions:

1. Execute the first cell containing code below. This will install our [CellX](https://github.com/quantumjot/cellx) library & create local directories in the environment of the virtual machine. The executed first cell will print ```Building wheel for cellx (setup.py) ... done```. 

2. Click on the ``` 📁``` folder icon located on the left-side dashboard of the Colab notebook. You should now see 4 subfolders in this directory: `sample_data` (default), `logs`, `train` and `validation`. The three newly created folders (`logs`, `train` and `validation`) should be empty. **Manually drag your 'annotation_XXX.zip' files into the `train` and `validation` folders**.

3. Before training, the notebook builds the training & validation sets and applies image augmentations to the training set. The model is then trained with the default parameters defined in the **"Set up CNN training hyperparameters:"** section, but you are welcome to modify the values if you wish.

4. You can now run the entire notebook by clicking on ```Runtime``` > ```Run all``` in the upper main dashboard. Note: Re-running the first cell is safe. If the `logs`, `train` and `validation` folders already exist they are left as they are, and the subsequent cells still run.

5. During training, you can actively visualise what the network is doing via [TensorBoard](https://www.tensorflow.org/tensorboard/get_started "TensorFlow || Tensorboard"), a tool for providing the measurements and visualisations needed during the machine learning workflow. It enables tracking experiment metrics like loss and accuracy, visualising the model graph, projecting embeddings to a lower dimensional space, and much more.

---

**Happy training!**

*Your [CellX](http://lowe.cs.ucl.ac.uk/cellx.html "Lowe Lab @ UCL") team*


### Install the CellX library & create subdirectories in the virtual machine:

In [None]:
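# if using colab, install the cellx library and make the log and data folders

if 'google.colab' in str(get_ipython()):
    try:
        import cellx
    except ImportError:
        # GitHub no longer serves the unauthenticated git:// protocol, so install over https
        !pip install -q git+https://github.com/quantumjot/cellx.git
    # -p makes this safe to re-run: existing folders are left untouched
    !mkdir -p logs train validation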
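
The folder setup can also be written in plain Python, which is safe to re-run and works outside Colab. This is a minimal sketch, assuming the same three folder names used in this notebook; the `ensure_dirs` helper is a hypothetical name and is not part of CellX.

```python
import os

def ensure_dirs(names=("logs", "train", "validation"), root="."):
    """Create each folder under `root` if it does not exist yet."""
    created = []
    for name in names:
        path = os.path.join(root, name)
        # exist_ok=True means an existing folder is not an error
        os.makedirs(path, exist_ok=True)
        created.append(path)
    return created

if __name__ == "__main__":
    print(ensure_dirs())
```

Calling the helper a second time returns the same paths and leaves any files already inside the folders untouched.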

### Import libraries and CellX toolkit:

In [None]:
import os
import zipfile
import numpy as np
import matplotlib.pyplot as plt

from datetime import datetime
from skimage.transform import resize

In [None]:
import tensorflow.keras as K
import tensorflow as tf

In [None]:
from cellx.layers import Encoder2D
from cellx.tools.dataset import build_dataset
from cellx.tools.dataset import write_dataset
from cellx.augmentation.utils import append_conditional_augmentation, augmentation_label_handler
from cellx.callbacks import tensorboard_confusion_matrix_callback
from cellx.tools.io import read_annotations

### Define paths & class labels:

In [None]:
TRAIN_PATH = "./train"
VAL_PATH = "./validation"
TRAIN_FILE = os.path.join(TRAIN_PATH, 'CNN_train.tfrecord')
VAL_FILE = os.path.join(VAL_PATH, 'CNN_validation.tfrecord')

### Set up CNN training hyperparameters:

In [None]:
BATCH_SIZE = 64
BUFFER_SIZE = 20_000
TRAINING_EPOCHS = 100
BOUNDARY_AUGMENTATION = True
INPUT_SHAPE = (64, 64, 1)

### Load the TensorBoard extension for real-time visualisation of CNN training:

In [None]:
%load_ext tensorboard
LOG_ROOT = './logs'
LOG_DIR = os.path.join(LOG_ROOT, datetime.now().strftime("%Y%m%d-%H%M%S"))

### Generate TensorFlow Record (TFRecord) files:

In [None]:
def create_tf_record(
    root: str,
    filename: str,
    use_flagged: bool = False
):
    """Write the annotations found in `root` to a TFRecord file.

    Returns the state names ordered by their integer label.
    """

    # load the annotations
    _images, _labels, _states = read_annotations(root, use_flagged=use_flagged)
    images_arr = np.stack(_images, axis=0)[..., np.newaxis]
    labels_arr = np.stack(_labels, axis=0)

    # write the tf dataset
    write_dataset(filename, images_arr.astype(np.uint8), labels=labels_arr.astype(np.int64))

    # state names, sorted by their integer label
    states = [k for k, v in sorted(_states.items(), key=lambda item: item[1])]

    # print some stats
    stats = {k: _labels.count(v) for k, v in _states.items()}
    print(f"Exported '{filename}' containing:")
    if not use_flagged:
        print(" - Excluding flagged files")
    for k, v in stats.items():
        print(f" - [{_states[k]}] {k}: {v}")
    print(f" - Total images: {images_arr.shape[0]}")
    return states
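
The `np.stack(...)[..., np.newaxis]` step in the function above turns a list of 2D patches into one 4D array with a trailing channel axis. The sketch below uses made-up blank patches to show the resulting shapes; the 64x64 patch size is an assumption chosen to match `INPUT_SHAPE`, and real data would come from `read_annotations`.

```python
import numpy as np

# Three hypothetical single-channel patches with integer class labels
patches = [np.zeros((64, 64), dtype=np.uint8) for _ in range(3)]
labels = [0, 2, 1]

# Stack along a new leading (batch) axis, then append a channel axis
images_arr = np.stack(patches, axis=0)[..., np.newaxis]
labels_arr = np.stack(labels, axis=0)

print(images_arr.shape)  # (3, 64, 64, 1)
print(labels_arr.shape)  # (3,)
```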

### IMPORTANT!

**Prior to calling the function to create the TFRecord files:**

You need to manually drag the annotation_XXX.zip files into the newly created folders (if you haven't yet done so from following the running instructions at the top of the notebook). If you are working in the Google Colab environment, click on the folder icon on the left-side dashboard, which should now contain the `logs`, `train` and `validation` directories. They should be empty until you drag your annotation files into them.
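
Before building the records, it can help to confirm that the archives landed in the right folders. This is a minimal sketch that only lists `.zip` files; the `list_annotation_zips` helper is a hypothetical name, and it does not check the contents of the archives.

```python
import glob
import os

def list_annotation_zips(folder):
    """Return the sorted .zip files found directly inside `folder`."""
    return sorted(glob.glob(os.path.join(folder, "*.zip")))

if __name__ == "__main__":
    for folder in ("./train", "./validation"):
        found = list_annotation_zips(folder)
        status = "OK" if found else "EMPTY, add annotation zips first"
        print(f"{folder}: {len(found)} zip file(s) [{status}]")
```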

Once the files have been imported, run the following cell:

In [None]:
LABELS = create_tf_record(TRAIN_PATH, TRAIN_FILE, use_flagged=not BOUNDARY_AUGMENTATION)
_ = create_tf_record(VAL_PATH, VAL_FILE, use_flagged=not BOUNDARY_AUGMENTATION)

### Create a simple CNN for classification:

In [None]:
img = K.layers.Input(shape=INPUT_SHAPE)
x = Encoder2D(layers=[8, 16, 32, 64, 128])(img)
x = K.layers.Flatten()(x)
x = K.layers.Dense(256, activation="relu")(x)
x = K.layers.Dropout(0.2)(x)
logits = K.layers.Dense(len(LABELS), activation="linear")(x)

In [None]:
model = K.Model(inputs=img, outputs=logits)

In [None]:
model.summary()

### Set up some augmentations to be used while training:

In [None]:
@augmentation_label_handler
def normalize(img):
    img = tf.image.per_image_standardization(img)
    # clip to 4 standard deviations
    img = tf.clip_by_value(img, -4., 4.)
    # keep the returned tensor so the check stays part of the graph
    img = tf.debugging.check_numerics(img, "Image contains NaN or Inf")
    return img
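
To make the normalisation step concrete, here is a NumPy stand-in for what `tf.image.per_image_standardization` followed by clipping computes: subtract the image mean, divide by the standard deviation (floored at `1/sqrt(N)` for `N` pixels), then clip. The `standardize_and_clip` name and the random test patch are illustrative assumptions; the training pipeline itself uses the TensorFlow ops above.

```python
import numpy as np

def standardize_and_clip(img, clip=4.0):
    """NumPy stand-in for per-image standardization followed by clipping."""
    img = img.astype(np.float32)
    # The std is floored at 1/sqrt(N) so a constant image does not divide by zero
    adjusted_std = max(float(img.std()), 1.0 / np.sqrt(img.size))
    out = (img - img.mean()) / adjusted_std
    return np.clip(out, -clip, clip)

rng = np.random.default_rng(0)
patch = rng.integers(0, 256, size=(64, 64, 1)).astype(np.uint8)
out = standardize_and_clip(patch)
print(round(float(out.mean()), 3), float(out.min()), float(out.max()))
```

After this step every patch has approximately zero mean and unit variance, so the network sees inputs on a comparable scale regardless of the original image brightness.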

In [None]:
@augmentation_label_handler
def augment(img):
    if BOUNDARY_AUGMENTATION:
        # this will randomly simulate the cropping that occurs at the edge of
        # an image volume

        # draw the width with TensorFlow ops: inside `dataset.map` a NumPy
        # random call runs only once, when the function is traced
        width = tf.random.uniform(
            shape=(), maxval=INPUT_SHAPE[0]//2, dtype=tf.int32
        )
        columns = tf.range(INPUT_SHAPE[1])
        vignette = tf.cast(columns >= width, img.dtype)[tf.newaxis, :, tf.newaxis]

        img = tf.cond(pred=tf.random.uniform(shape=())<0.05,
                true_fn=lambda: tf.multiply(img, vignette),
                false_fn=lambda: img)

    # do some data augmentation
    # maxval is exclusive, so 4 allows rotations of 0, 90, 180 and 270 degrees
    k = tf.random.uniform(maxval=4, shape=(), dtype=tf.int32)
    img = tf.image.rot90(img, k=k)

    img = tf.image.random_flip_left_right(img)
    img = tf.image.random_flip_up_down(img)
    return img
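
The boundary augmentation multiplies the patch by a mask whose first `width` columns (along axis 1) are zero, imitating a cell that is partly cut off at the image edge. The NumPy sketch below shows the effect on a constant patch; the `boundary_mask` helper and the chosen `width=10` are illustrative assumptions.

```python
import numpy as np

def boundary_mask(shape, width):
    """Mask of ones with the first `width` columns (axis 1) set to zero."""
    mask = np.ones(shape, dtype=np.float32)
    mask[:, :width, ...] = 0
    return mask

patch = np.full((64, 64, 1), 5.0, dtype=np.float32)
cropped = patch * boundary_mask(patch.shape, width=10)

# First 12 values of the top row: ten zeros, then the original intensity
print(cropped[0, :12, 0])
```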

In [None]:
@augmentation_label_handler
def random_contrast(x):
    return tf.image.random_contrast(x, 0.3, 1.0)

@augmentation_label_handler
def random_brightness(x):
    # the second argument is `max_delta`; a third positional value would be taken as the seed
    return tf.image.random_brightness(x, max_delta=0.3)

### Build the training dataset, with random augmentations:

In [None]:
dataset = build_dataset(TRAIN_FILE, read_label=True)

In [None]:
dataset = dataset.map(augment)
dataset = append_conditional_augmentation(dataset, [random_contrast, random_brightness])
dataset = dataset.map(normalize)
dataset = dataset.shuffle(buffer_size=BUFFER_SIZE, reshuffle_each_iteration=True)
dataset = dataset.repeat()
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
dataset = dataset.prefetch(1)

### Build the validation dataset, without augmentations:

In [None]:
validation_dataset = build_dataset(VAL_FILE, read_label=True)
validation_dataset = validation_dataset.map(normalize)
validation_dataset = validation_dataset.take(-1).as_numpy_iterator()

validation_images, validation_labels = zip(*list(validation_dataset))

### Set up TensorBoard callbacks to monitor training:

In [None]:
tensorboard_callback = K.callbacks.TensorBoard(log_dir=LOG_DIR)
confusion_matrix_callback = tensorboard_confusion_matrix_callback(
    model, 
    np.asarray(validation_images), 
    validation_labels,
    LOG_DIR,
    class_names=LABELS,
    is_binary=False
)

### Set up the loss function:

In [None]:
loss = K.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer="adam", loss=loss, metrics=['accuracy'])
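
Because the final `Dense` layer uses a linear activation, the loss is configured with `from_logits=True` and applies the softmax itself. As a rough illustration of the per-sample quantity involved, the pure-Python sketch below computes `-log softmax(logits)[label]`; the function name is hypothetical, and Keras additionally averages this value over the batch.

```python
import math

def sparse_ce_from_logits(logits, label):
    """Cross-entropy for one sample: -log softmax(logits)[label]."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

# Two equal logits: the model is undecided, so the loss is ln(2)
print(sparse_ce_from_logits([0.0, 0.0], 0))
# A confident, correct prediction gives a loss close to zero
print(sparse_ce_from_logits([10.0, 0.0, 0.0], 0))
```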

### Finally, train the model and evaluate performance using TensorBoard:

Running the next cell opens the TensorBoard GUI, which displays the model's performance metrics after each epoch.
You can monitor the model performance while progressing through training epochs by switching to the `SCALARS` (accuracy & loss) or `IMAGES` (confusion matrix) tabs in the upper menu bar. 

Note: If you only see the `GRAPHS` tab and at least one epoch has already completed, open the dropdown on the right side (it will probably show `INACTIVE`), choose `SCALARS` from the list, then press the refresh button to the right of the dropdown.
Press the refresh button again at any later point to see the most up-to-date performance metrics.

In [None]:
%tensorboard --logdir $LOG_ROOT --host localhost

In [None]:
model.fit(
    dataset, 
    steps_per_epoch=BUFFER_SIZE//BATCH_SIZE, 
    epochs=TRAINING_EPOCHS, 
    callbacks=[tensorboard_callback, confusion_matrix_callback],
)
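
Since the training dataset repeats indefinitely, an "epoch" here is defined by `steps_per_epoch` and not by one pass over the data. The arithmetic below works through the default hyperparameters from this notebook; if you change `BATCH_SIZE`, `BUFFER_SIZE` or `TRAINING_EPOCHS`, the numbers change accordingly.

```python
BATCH_SIZE = 64
BUFFER_SIZE = 20_000
TRAINING_EPOCHS = 100

steps_per_epoch = BUFFER_SIZE // BATCH_SIZE       # integer division
images_per_epoch = steps_per_epoch * BATCH_SIZE   # augmented patches seen per epoch
total_steps = steps_per_epoch * TRAINING_EPOCHS   # gradient updates over the whole run

print(steps_per_epoch)   # 312
print(images_per_epoch)  # 19968
print(total_steps)       # 31200
```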

### Save the Model:

**Do not terminate this notebook before saving the model.** 
To export the saved model to your local machine, press the '...' button next to the `model.h5` file on the left-side dashboard.

In [None]:
model_name = 'model'
model.save('{}.h5'.format(model_name))
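
When you later use the model for prediction, remember that it outputs raw logits, so a softmax is needed to obtain class probabilities. The pure-Python sketch below shows that conversion for a single sample; the class names and logit values are made up for illustration, and in practice the names would come from `LABELS`.

```python
import math

def logits_to_prediction(logits, class_names):
    """Softmax over raw logits, then return (best label, probabilities)."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return class_names[best], probs

# Hypothetical class names and logits, for illustration only
names = ["state_a", "state_b", "state_c"]
label, probs = logits_to_prediction([2.0, 0.5, -1.0], names)
print(label, [round(p, 3) for p in probs])
```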