# 3D Image Classification from CT Scans

**Author:** [Hasib Zunair](https://twitter.com/hasibzunair)<br>
**Date created:** 2020/09/23<br>
**Last modified:** 2020/09/23<br>
**Description:** Train a 3D convolutional neural network to predict presence of pneumonia.

## Introduction

This example will show the steps needed to build a 3D convolutional neural network (CNN)
to predict the presence of viral pneumonia in computed tomography (CT) scans. 2D CNNs are
commonly used to process RGB images (3 channels). A 3D CNN is simply the 3D
equivalent: it takes as input a 3D volume or a sequence of 2D frames (e.g. slices in a CT scan).
3D CNNs are a powerful model for learning representations for volumetric data.
In this notebook, a previously trained model is downloaded and used to make predictions on CT volumes.

## References

- [A survey on Deep Learning Advances on Different 3D Data Representations](https://arxiv.org/pdf/1808.01462.pdf)
- [VoxNet: A 3D Convolutional Neural Network for Real-Time Object Recognition](https://www.ri.cmu.edu/pub_files/2015/9/voxnet_maturana_scherer_iros15.pdf)
- [FusionNet: 3D Object Classification Using Multiple Data Representations](http://3ddl.cs.princeton.edu/2016/papers/Hegde_Zadeh.pdf)
- [Uniformizing Techniques to Process CT scans with 3D CNNs for Tuberculosis Prediction](https://arxiv.org/abs/2007.13224)

## Setup

In [None]:
!pip install pip --upgrade -q
!pip install nibabel scipy matplotlib tensorflow ipywidgets -q

In [None]:
import os
import zipfile
import numpy as np
import tensorflow as tf
import logging
from tensorflow import keras
from tensorflow.keras import layers

In [None]:
# Configure the root logger once for the whole notebook so that the
# INFO-level messages emitted by predict() below are shown.
logging.basicConfig(level="INFO")

In [None]:
#
# Load the trained model from storage.
#
# `requests` is imported in the prediction cell that actually uses it.
url = "https://koz.s3.amazonaws.com/models/3d_image_classification.h5"
model_file = '3d_image_classification.h5'

# get_file returns the local path of the downloaded file; load from that
# path rather than recomputing it, so the two can never diverge.
filename = keras.utils.get_file(os.path.join(os.getcwd(), model_file), url)

model = keras.models.load_model(filename)

In [None]:
#
# Load volume data from storage.
#
url = "https://koz.s3.amazonaws.com/data/ct-data.zip"
# Use the path get_file reports it saved to, instead of re-deriving a
# relative file name that only matches while the working directory is unchanged.
filename = keras.utils.get_file(os.path.join(os.getcwd(), "ct-data.zip"), url)

# Unzip the archive into the working directory; the cells below read the
# volumes from ./data/volume<N>.nii.gz.
with zipfile.ZipFile(filename, "r") as z_fp:
    z_fp.extractall("./")

In [None]:
import nibabel as nib
from scipy import ndimage

def read_nifti_file(filepath):
    """Load a NIfTI image from disk and return its voxel data.

    Args:
        filepath: Path of the NIfTI file to open (e.g. ``volume0.nii.gz``).

    Returns:
        The image data as produced by nibabel's ``get_fdata()``.
    """
    image = nib.load(filepath)
    return image.get_fdata()


## Make predictions on a single CT scan

In [None]:
from ipywidgets import interact, interactive, fixed
import matplotlib as mpl
import requests


In [None]:
def predict(filename,
            url='http://mymodel-mygroup-odh.apps.ocp.a122.sandbox1172.opentlc.com/api/v1.0/predictions',
            timeout=5):
    """Score one CT volume with the local model and with the REST model server.

    The volume is first scored by the in-memory Keras ``model``. It is then
    posted to ``url`` as ``{"data": {"ndarray": <nested list>}}``. Both results
    are logged; nothing is returned.

    Args:
        filename: Path of the NIfTI volume to score.
        url: Prediction endpoint of the REST model server.
        timeout: Seconds to wait for the REST call before giving up.

    Returns:
        None. REST failures are logged, never raised.
    """
    volume = read_nifti_file(filename)

    # Local prediction: add a leading batch axis for model.predict.
    prediction = model.predict(np.expand_dims(volume, axis=0))[0]
    logging.info(f'Local Prediction {filename} = {prediction}')

    #
    # Prediction via REST.
    #
    logging.info(f'Serializing and predicting volume {filename} via REST')
    payload = {"data": {"ndarray": volume.tolist()}}
    try:
        r = requests.post(url, json=payload, timeout=timeout)
        logging.debug(f'response: {r}')
        r.raise_for_status()
        rest_prediction = np.asarray(r.json()['data']['ndarray'], dtype=float)
    except requests.exceptions.Timeout:
        logging.info('REST endpoint timed out!')
        return None
    except requests.exceptions.RequestException as err:
        # Connection errors, HTTP error statuses, undecodable JSON, ...
        logging.warning(f'REST prediction for {filename} failed: {err}')
        return None
    except (KeyError, TypeError, ValueError) as err:
        # The body was JSON but did not have the expected data/ndarray shape.
        logging.warning(f'Unexpected REST response for {filename}: {err}')
        return None

    # The server answers with a (possibly nested) list. A scalar ':.3f' format
    # spec raises TypeError on a list, so round element-wise instead.
    logging.info(f'Volume {filename} prediction: {np.round(rest_prediction, 3)}')
    return None

In [None]:
#
# Load a volume so default dimensions are known for interaction widgets.
#
# This must run before slice_image is defined: its default `slice` argument
# reads global_v.shape at definition time. A `global` statement is not needed
# here because module-level assignments already bind the notebook global.
study = 0
filename = f'./data/volume{study}.nii.gz'
global_v = read_nifti_file(filename)

In [None]:
def slice_image(slice = global_v.shape[2] / 2, cmap='none'):
    """Display one slice (along the third axis) of the loaded volume ``global_v``.

    Args:
        slice: Index along the third axis of ``global_v``. The default is the
            middle of the volume loaded when this cell ran. It is a float, so the
            value is truncated to an int before indexing.
        cmap: Matplotlib colormap name. ``'none'`` is not a registered colormap
            (``imshow`` would raise), so it is treated as "use matplotlib's
            default colormap".

    Returns:
        The image artist returned by ``pyplot.imshow``.
    """
    # `import matplotlib` alone does not load the pyplot submodule, so
    # `mpl.pyplot` only works if something else imported it; import explicitly.
    import matplotlib.pyplot as plt

    colormap = None if cmap == 'none' else cmap
    # Fix the color range to the whole volume so brightness is comparable
    # from one slice to the next.
    return plt.imshow(global_v[:, :, int(slice)], cmap=colormap,
                      vmin=global_v.min(), vmax=global_v.max())

In [None]:
def slicer():
    """Show slice_image with a slider over the slices of ``global_v`` and a colormap picker."""
    # global_v is only read here, so no `global` declaration is required.
    colormaps = ['gray', 'bone', 'hot', 'magma', 'gnuplot2', 'pink']
    last_slice = global_v.shape[2] - 1
    interact(slice_image, slice=(0, last_slice, 1), cmap=colormaps)

In [None]:

def set_volume(study = None):
    """Load a study, predict on it, and show the slice viewer.

    Args:
        study: Number of the volume to load from
            ``./data/volume{study}.nii.gz``. ``None`` clears the currently
            loaded volume (sets ``global_v`` to ``None``) and does nothing else.
    """
    # Declared up front: both branches below rebind the notebook global.
    global global_v
    logging.debug(f'set_volume = {study}')
    if study is None:
        global_v = None
        return

    filename = f'./data/volume{study}.nii.gz'
    logging.debug(f'Loading {filename}')
    global_v = read_nifti_file(filename)
    logging.debug(f'Calling slicer with {filename}, mean = {global_v.mean()}')
    predict(filename)
    slicer()



In [None]:
# Offer studies 0-5 in a picker; the trailing `;` keeps the notebook from
# echoing the call's return value below the widget.
interact(set_volume, study=list(range(6)));