# CNN Inference for Cell Cycle State Classification

### Welcome!

This notebook lets you take the convolutional neural network (CNN) that you trained in the previous notebook and use it for inference on previously unseen single-cell image patches. Follow the step-by-step instructions to test the network.


### Important Notes:

1. You are using the virtual environment of [Google Colab](https://colab.research.google.com/notebooks/intro.ipynb "Google Colaboratory"). To test the neural network, you must first **upload images that were not used during training** into the `test` folder. Please follow the running instructions after executing the first cell of this notebook.

2. If using Google Colab: you will need to be signed in with a Google account. Your session will time out if you do not interact with it. Although the documentation states that the runtime should last 90 minutes if you close the browser or 12 hours if you keep it open, in our experience it can disconnect after 60 minutes even with the browser open. See this [StackOverflow](https://stackoverflow.com/questions/54057011/google-colab-session-timeout "Google Colab Session Timeout") discussion, where others have reported even shorter periods before the runtime disconnects when they do not interact with the session. Additionally, remember that your access to Colab resources is limited to a maximum of 12 hours per session; if you exceed this limit, Google may temporarily suspend your access to Colab.


### Running Instructions:

1. Execute the first cell containing code below, which will install the CellX library & create a local test directory in the environment of the virtual machine. The executed first cell will print ```Building wheel for cellx (setup.py) ... done```. (Note: This virtual environment is different from the one created for the Training notebook, which is why we need to re-install the external `cellx` library etc.)

2. Click on the ``` 📁``` folder icon in the left-side dashboard of the Colab notebook. This opens the default `content` directory, which contains the subdirectories `sample_data` (default) and `test`. Drag your saved model (the `.h5` file) into the `content` folder and your annotated zip file(s) into the `test` folder. The notebook expects the model file to be named `model.h5`; if yours has a different name, change `model_name` in the "Load the Model" section.

3. You can now run the entire notebook by clicking on ```Runtime``` > ```Run all``` in the upper main dashboard.

---

**Happy testing!**

*Your [CellX](http://lowe.cs.ucl.ac.uk/cellx.html "Lowe Lab @ UCL") team*


### Install the CellX library & create subdirectories in the virtual machine:

In [None]:
# if using colab, install the cellx library and create the test data folder
# (GitHub no longer serves the git:// protocol, so install over https)

if 'google.colab' in str(get_ipython()):
    !pip install -q git+https://github.com/quantumjot/cellx.git
    !mkdir test

### Import libraries and CellX toolkit:

In [None]:
import os
import zipfile
import numpy as np
import matplotlib.pyplot as plt

from datetime import datetime
from scipy.special import softmax
from skimage.transform import resize
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support
from umap import UMAP

In [None]:
import tensorflow.keras as K
import tensorflow as tf

In [None]:
from cellx.core import load_model
from cellx.layers import Encoder2D
from cellx.tools.confusion import plot_confusion_matrix
from cellx.tools.io import read_annotations
from cellx.tools.projection import ManifoldProjection2D

### Define the test data path:

In [None]:
TEST_PATH = "./test"

### Import test dataset from the zip files:

In [None]:
test_images, test_labels, states = read_annotations(TEST_PATH)

### Load the Model:

The `load_model` function from the CellX library lets us load models without having to specify the custom CellX layers that were used to build them.

In [None]:
model_name = 'model'
model = load_model(f'{model_name}.h5')

In [None]:
model.summary()

### Normalize the images in the test dataset:

In [None]:
# image normalization function
def normalize_image_array(img):
    img_mean = np.mean(img)
    img_stddev = max(np.std(img), 1.0/np.size(img))
    img = np.subtract(img,img_mean)
    img = np.divide(img,img_stddev)
    # clip to 4 standard deviations
    img = np.clip(img, -4, 4)
    return img
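To see what `normalize_image_array` does, the sketch below applies the same steps to two small synthetic images (made-up data, not patches from the test set): an image with a single very bright pixel, whose z-score is clipped to 4, and a constant image, where the `1.0/np.size(img)` floor on the standard deviation avoids a division by zero.

```python
import numpy as np

def normalize_image_array(img):
    # same steps as the notebook function: subtract mean, divide by std, clip
    img_mean = np.mean(img)
    # floor the std at 1/N so constant images do not cause a division by zero
    img_stddev = max(np.std(img), 1.0 / np.size(img))
    img = (img - img_mean) / img_stddev
    # clip to 4 standard deviations
    return np.clip(img, -4, 4)

# synthetic 10x10 image: all zeros except one very bright pixel
outlier_img = np.zeros((10, 10))
outlier_img[0, 0] = 1000.0
normalized = normalize_image_array(outlier_img)
print(normalized.max())  # the bright pixel's z-score (about 9.95) is clipped to 4
print(normalized.min())  # background pixels: (0 - 10) / std, about -0.1005

# a constant image normalizes to all zeros instead of NaNs
constant_img = np.full((10, 10), 7.0)
print(np.unique(normalize_image_array(constant_img)))
```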

In [None]:
test_images = [normalize_image_array(image) for image in test_images]
test_images_array = np.array(test_images)[...,np.newaxis] # convert to numpy array for model prediction
test_labels_array = np.array(test_labels)

## Run the Model on the test images:

In [None]:
test_predictions = model.predict(test_images_array)

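The `softmax` function transforms `test_predictions` into an array of scores for each class, for each instance in the testing set. Across classes, the scores sum to one, and the class with the highest score is the model's 'prediction'. Applying softmax here assumes the model's final layer outputs raw, unnormalized scores (logits); if the model already ends in a softmax activation, a second softmax would flatten the scores toward a uniform distribution, although the predicted class would not change.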

In [None]:
test_predictions = softmax(test_predictions,axis=1)
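The sketch below illustrates the softmax step on a small, made-up array of raw scores (logits) for 3 images and 5 classes; it is not output from the model. Each row of the result sums to one, and because softmax preserves the ordering within a row, the predicted class (`argmax`) is the same before and after the transformation. The manual version writes out the formula $\mathrm{softmax}(z)_i = e^{z_i} / \sum_j e^{z_j}$, subtracting the row maximum first for numerical stability.

```python
import numpy as np
from scipy.special import softmax

# made-up raw scores (logits): 3 images x 5 classes
logits = np.array([
    [2.0, 0.5, 0.1, -1.0, 0.0],
    [0.1, 0.2, 3.0, 0.0, -0.5],
    [1.0, 1.0, 1.0, 1.0, 1.0],
])

# scipy's softmax along the class axis, as in the notebook
scores = softmax(logits, axis=1)

# the same computation written out; subtracting the row max avoids overflow
shifted = logits - logits.max(axis=1, keepdims=True)
manual_scores = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)

print(scores.sum(axis=1))         # each row sums to 1
print(np.argmax(scores, axis=1))  # predicted class per image
```

A row of equal logits (the third image) gives equal scores of 0.2 per class, and `np.argmax` then returns the first index.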

### Show predictions on the test images:

Display the first N images of the testing set, each labelled with the class predicted by the model.

In [None]:
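def show_testing_predictions(
    num_examples,  # number of testing examples to show
    images,        # test images, shape (N, height, width, 1)
    predictions,   # softmax scores, shape (N, number of classes)
    class_names,   # class names, ordered by class index
):
    num_to_show = min(images.shape[0], num_examples)
    num_rows = int(np.ceil(num_to_show / 5))  # 5 images per row
    plt.figure(figsize=(10, 3 * num_rows))
    plt.suptitle('Predictions', fontsize=25, x=0.5, y=0.95)
    for image_num in range(num_to_show):
        plt.subplot(num_rows, 5, image_num + 1)
        plt.imshow(images[image_num, :, :, 0])
        plt.title('Image {}'.format(image_num + 1))
        plt.yticks([])
        plt.xticks([])
        # label each image with the class that received the highest score
        plt.xlabel(class_names[np.argmax(predictions[image_num])])
    plt.show()

In [None]:
show_testing_predictions(20, test_images_array, test_predictions, list(states))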

### Calculate evaluation metrics:

We will next calculate the "precision", "recall" and "F1 score" metrics for each class, as well as the "confusion matrix" for the CNN's performance on the testing set. The three metrics are calculated using the number of "false positive", "true positive" and "false negative" predictions for each class.
- The "precision" of class X is calculated by $$precision(X) = \frac{No.\;of\;true\;positives}{No.\;of\;true\;positives+No.\;of\;false\;positives}$$
- The "recall" of class X is calculated by $$recall(X) = \frac{No.\;of\;true\;positives}{No.\;of\;true\;positives+No.\;of\;false\;negatives}$$
- The "F1 score" of class X is calculated by $$F1(X) = 2\cdot\frac{precision(X)\cdot recall(X)}{precision(X)+recall(X)}$$
<br>
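The formulas above can be checked by hand. The sketch below uses hypothetical labels for 3 classes (not results from this notebook), counts the true positives, false positives and false negatives for each class with boolean masks, and compares the per-class results with scikit-learn's `precision_recall_fscore_support`, which returns one value per class by default (`average=None`). Returning 0 when a denominator is zero mirrors scikit-learn's default handling of undefined metrics.

```python
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

# hypothetical ground-truth and predicted class indices for 3 classes
y_true = np.array([0, 0, 1, 1, 2, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0, 2])

def per_class_metrics(y_true, y_pred, cls):
    """Precision, recall and F1 for one class, from TP/FP/FN counts."""
    tp = np.sum((y_pred == cls) & (y_true == cls))
    fp = np.sum((y_pred == cls) & (y_true != cls))
    fn = np.sum((y_pred != cls) & (y_true == cls))
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall > 0 else 0.0)
    return precision, recall, f1

# rows: classes 0, 1, 2; columns: precision, recall, F1
manual = np.array([per_class_metrics(y_true, y_pred, c) for c in range(3)])
sk_precision, sk_recall, sk_f1, _ = precision_recall_fscore_support(y_true, y_pred)
print(manual)
print(sk_precision, sk_recall, sk_f1)
```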

The "confusion matrix" is a table that visually represents the performance of a network on a testing set. The number shown in row A and column B is the number of testing examples of ground-truth class A that have been predicted as belonging to class B by the network.

Further reading on confusion matrices: https://towardsdatascience.com/understanding-confusion-matrix-a9ad42dcfd62
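To make the row/column convention concrete, the sketch below builds a confusion matrix with scikit-learn's `confusion_matrix` (the same function used in the next cell) from hypothetical labels, reads off the per-class true positive, false positive and false negative counts, and computes the overall accuracy as the fraction of examples on the diagonal. The labels are made up for illustration.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# hypothetical ground-truth and predicted class indices for 3 classes
y_true = np.array([0, 0, 1, 1, 2, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0, 2])

cm = confusion_matrix(y_true, y_pred)
# cm[a, b] = number of examples of ground-truth class a predicted as class b
print(cm)

tp = np.diag(cm)          # correctly predicted examples of each class
fp = cm.sum(axis=0) - tp  # predicted as the class, but truly another class
fn = cm.sum(axis=1) - tp  # truly the class, but predicted as another class
accuracy = np.trace(cm) / cm.sum()
print("TP:", tp, "FP:", fp, "FN:", fn, "accuracy:", accuracy)
```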

In [None]:
loss,accuracy = model.evaluate(test_images_array, test_labels_array)

test_confusion_matrix = confusion_matrix(test_labels,np.argmax(test_predictions,axis=1))
test_confusion_matrix_plot = plot_confusion_matrix(test_confusion_matrix,list(states))
test_confusion_matrix_plot.show()

print('Testing Accuracy = ',accuracy)
print('Testing Loss = ',loss)

precision,recall,fscore,support = precision_recall_fscore_support(test_labels,np.argmax(test_predictions,axis=1))
print('Testing Precision = ',precision)
print('Testing Recall = ',recall)
print('Testing F1 Score = ',fscore)

### Dimensionality reduction with UMAP:

Running the cell below shows that the model output is a 2D array:
* the 1st dimension corresponds to the number of test images used 
* the 2nd dimension corresponds to the number of possible classes pre-defined in our model

In [None]:
test_predictions.shape

We can use UMAP to visualise the network's classification performance by embedding the predictions from 5D space (one dimension per class) into 2D space, while attempting to preserve the data's inherent structure and underlying relationships.

We first define our parameters of choice. In this simple example, we only modify the following parameters:
* `n_neighbors` - the number of neighbours determines the size of the local neighbourhood that UMAP considers when creating the embedding: low values emphasise local structure, high values emphasise global structure
* `n_epochs` - the number of epochs determines how many rounds of optimisation the UMAP embedding goes through (similar to training a CNN); more epochs generally make the 2D embedding reflect the original data structure more faithfully
* `random_state` - UMAP is a stochastic algorithm, so we set a random seed to make the results reproducible across runs (setting it also disables UMAP's parallelism, so fitting may be slower). Try removing this parameter: you should see slightly different UMAP embeddings from one run to the next

Feel free to adjust the parameters and see how the image projection below changes! You can read up on the most important parameters [here](https://umap-learn.readthedocs.io/en/latest/parameters.html#) or go through the full list of parameters [here](https://umap-learn.readthedocs.io/en/latest/api.html).

If you're interested in reading about how UMAP works, [see here](https://umap-learn.readthedocs.io/en/latest/basic_usage.html).

In [None]:
# UMAP parameters
nbs = 5
eps = 50
rnd = 0

We then create a UMAP model with the defined parameters. Because `verbose=True` is set, the model's configuration, including the parameters modified above, will be printed when it is fitted to the data.

Note: `verbose=True` makes UMAP print progress messages while it runs.

In [None]:
mapper = UMAP(n_neighbors=nbs, n_epochs=eps, random_state=rnd, verbose=True)

Fit the UMAP model to the data.

In [None]:
mapper.fit(test_predictions)

### 2D image patch projection of model embedding:

By projecting the test images corresponding to the test predictions on top of the UMAP embedding, we can visually assess whether single-cell patches of the same class correctly cluster together in 2D space.

In [None]:
# convert single-channel test images to rgb three-channel images
print(f"shape of test images: {test_images_array.shape}")
rgb_images = np.concatenate([test_images_array]*3, axis=-1)
print(f"shape of rgb test images: {rgb_images.shape}")
# min-max scale image values to the 0-1 range, stretch to 0-255 & convert to 8-bit
rgb_images = ((rgb_images-np.min(rgb_images))/(np.ptp(rgb_images)) * 255).astype(np.uint8)
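The conversion above can be traced on a tiny made-up example: a single 2x2 image with values between -4 and 4 (the clipping range used by `normalize_image_array`). Stacking the single channel three times gives identical R, G and B channels, and min-max scaling maps the minimum to 0 and the maximum to 255. Because `astype(np.uint8)` truncates fractional values, a pixel exactly halfway between the extremes becomes 127.

```python
import numpy as np

# one made-up single-channel 2x2 image, shape (N, H, W, 1)
images = np.array([[-4.0, 0.0], [2.0, 4.0]]).reshape(1, 2, 2, 1)

# repeat the single channel three times -> shape (N, H, W, 3)
rgb = np.concatenate([images] * 3, axis=-1)

# min-max scale to 0-1, stretch to 0-255, then truncate to 8-bit integers
rgb_uint8 = ((rgb - np.min(rgb)) / np.ptp(rgb) * 255).astype(np.uint8)
print(rgb_uint8[0, :, :, 0])
```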

Create the grid of image patches corresponding to the UMAP embedding. This is essentially a 2D histogram: embedding points that fall in the same grid cell are binned together, and the average of their corresponding images is calculated before being overlaid.

In [None]:
projection = ManifoldProjection2D(rgb_images)
img_grid, heatmap, delimiters = projection(mapper.embedding_, components=(0,1))
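The binning idea described above can be sketched in plain NumPy. This is a simplified, hypothetical re-implementation written for illustration only: `average_images_on_grid` is not part of CellX, and `ManifoldProjection2D` may bin, pad or normalise differently. The sketch divides the bounding box of the 2D embedding into a `bins` x `bins` grid, averages the images whose points fall into each cell, and tiles the averages into one large image, leaving empty cells black. Grid row 0 holds the lowest embedding values, so the result should be displayed with `origin="lower"`, as in the plotting cell below.

```python
import numpy as np

def average_images_on_grid(embedding, images, bins=10):
    """Hypothetical helper: bin 2D points into a grid and average their images.

    embedding: (N, 2) array of 2D coordinates
    images:    (N, H, W) or (N, H, W, C) array of image patches
    Returns the tiled grid image and the number of points per grid cell.
    """
    h, w = images.shape[1:3]
    x, y = embedding[:, 0], embedding[:, 1]
    x_edges = np.linspace(x.min(), x.max(), bins + 1)
    y_edges = np.linspace(y.min(), y.max(), bins + 1)
    # digitize returns 1..bins+1; shift to 0-based and put the maximum in the last bin
    xi = np.clip(np.digitize(x, x_edges) - 1, 0, bins - 1)
    yi = np.clip(np.digitize(y, y_edges) - 1, 0, bins - 1)

    grid = np.zeros((bins * h, bins * w) + images.shape[3:], dtype=float)
    counts = np.zeros((bins, bins), dtype=int)
    for row in range(bins):      # row index follows the y coordinate
        for col in range(bins):  # column index follows the x coordinate
            in_cell = (yi == row) & (xi == col)
            counts[row, col] = in_cell.sum()
            if counts[row, col] > 0:
                grid[row * h:(row + 1) * h, col * w:(col + 1) * w] = images[in_cell].mean(axis=0)
    return grid, counts

# usage on random data (not the notebook's embedding)
rng = np.random.default_rng(0)
embedding = rng.normal(size=(100, 2))
patches = rng.random((100, 8, 8, 3))
grid, counts = average_images_on_grid(embedding, patches, bins=5)
print(grid.shape, counts.sum())
```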

Create a figure to show the image projection.
You can uncomment the last line to save the projection as a `.png` file; it will appear in the Files tab (if you don't see it, press the Refresh button at the top of the tab).
To download the file, click the "..." button to the right of the `.png` file.
Reminder: files saved during a Colab session will be lost once the session ends!

In [None]:
fig, ax = plt.subplots(figsize=(12, 12))

im = plt.imshow(img_grid,
                origin="lower",
#                 extent=delimiters, 
                cmap="gray",)

plt.tight_layout()
plt.colorbar()

# (optional) uncomment the below line to save the UMAP image patch projection
# fig.savefig(f"umap_{mapper.n_neighbors}nbs_rnd{mapper.random_state}.png")