In [1]:
# TODO

# 4. Change data-order from NHWC to NCHW to improve speed

# 5. IMPORTANT: For inference the symbol probably requires a training:false flag
# 6. IMPORTANT: See if tfrecords reduces IO latency
# 7. IMPORTANT: Multi-gpu wrapper????

In [2]:
import os
import sys
import time
import multiprocessing
import numpy as np
import pandas as pd
import tensorflow as tf
from nets import densenet  # Download from https://github.com/pudae/tensorflow-densenet
from tensorflow.python.framework import dtypes
from tensorflow.python.framework.ops import convert_to_tensor
from tensorflow.contrib.data import Iterator
from tensorflow.contrib.slim.python.slim.nets import resnet_v1
from PIL import Image
import random
from common.utils import download_data_chextxray, get_imgloc_labels, get_train_valid_test_split
from common.utils import compute_roc_auc
slim = tf.contrib.slim

In [3]:
tf.__version__

'1.4.0'

In [4]:
CPU_COUNT = multiprocessing.cpu_count()
print("CPUs: ", CPU_COUNT)

CPUs:  12


In [5]:
# Globals
CLASSES = 14
WIDTH = 224
HEIGHT = 224
CHANNELS = 3
LR = 0.0001  # Effective learning-rate will decrease as BATCHSIZE rises
EPOCHS = 5
BATCHSIZE = 64  # Chainer auto scales batch
IMAGENET_RGB_MEAN = np.array([123.68, 116.78, 103.94], dtype=np.float32)
IMAGENET_SCALE_FACTOR = 0.017
TOT_PATIENT_NUMBER = 30805  # From data

In [6]:
# Paths
CSV_DEST = "chestxray"
IMAGE_FOLDER = os.path.join(CSV_DEST, "images")
LABEL_FILE = os.path.join(CSV_DEST, "Data_Entry_2017.csv")
print(IMAGE_FOLDER, LABEL_FILE)
# Model checkpoint
PRETRAINED_WEIGHTS = True
CHKPOINT = 'tfdensenet/tf-densenet121.ckpt'
CHKPOINT = 'resnet_v1_50.ckpt'

chestxray/images chestxray/Data_Entry_2017.csv


In [7]:
%%time
# Download data
print("Please make sure to download")
print("https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-linux#download-and-install-azcopy")
download_data_chextxray(CSV_DEST)

Please make sure to download
https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-linux#download-and-install-azcopy
Data already exists
CPU times: user 667 ms, sys: 208 ms, total: 875 ms
Wall time: 874 ms


In [8]:
#####################################################################################################
## Data Loading

In [9]:
class XrayData():
    
    def __init__(self, img_dir, lbl_file, patient_ids, mode='inference', 
                 width=WIDTH, height=HEIGHT, batch_size=BATCHSIZE, 
                 imagenet_mean=IMAGENET_RGB_MEAN, imagenet_scaling = IMAGENET_SCALE_FACTOR,
                 shuffle=True):
        # Get data
        self.img_locs, self.labels = get_imgloc_labels(img_dir, lbl_file, patient_ids)
        self.data_size = len(self.labels)
        self.imagenet_mean = imagenet_mean
        self.imagenet_scaling = imagenet_scaling
        self.width = width
        self.height = height
        # Create dataset
        # Performance: https://www.tensorflow.org/versions/master/performance/datasets_performance
        # Following: https://stackoverflow.com/a/48096625/6772173
        data = tf.data.Dataset.from_tensor_slices((self.img_locs, self.labels))
        # Processing
        if mode == 'training':
            data = data.shuffle(self.data_size).map(self._parse_function_train,
                            num_parallel_calls=CPU_COUNT).prefetch(10*batch_size).batch(batch_size)
        else:
            data = data.map(self._parse_function_inference,
                            num_parallel_calls=CPU_COUNT).prefetch(10*batch_size).batch(batch_size)
        
        self.data = data        
        print("Loaded {} labels and {} images".format(len(self.labels), len(self.img_locs)))
        
        
    def _parse_function_train(self, filename, label):
        img_rgb, label = self._preprocess_image_labels(filename, label)
        # Super high CPU usuage bottlenecking GPU
        # Random crop
        img_rgb = tf.image.resize_images(img_rgb, [self.height+40, self.width+40])
        img_rgb = tf.random_crop(img_rgb, [self.height, self.width, 3])
        # Random flip
        img_rgb = tf.image.random_flip_left_right(img_rgb)
        # Random rotation
        rot_angle = np.random.randint(-10, 10)
        img_rgb = tf.contrib.image.rotate(img_rgb, rot_angle)
        return img_rgb, label
        
        
    def _parse_function_inference(self, filename, label):
        img_rgb, label = self._preprocess_image_labels(filename, label)
        # Resize to final dimensions
        img_rgb = tf.image.resize_images(img_rgb, [self.height, self.width])
        return img_rgb, label 
       
    
    def _preprocess_image_labels(self, filename, label):
        # load and preprocess the image
        img_decoded = tf.to_float(tf.image.decode_png(tf.read_file(filename), channels=3))
        img_centered = tf.subtract(img_decoded, self.imagenet_mean)
        img_rgb = img_centered * self.imagenet_scaling
        return img_rgb, label

In [10]:
train_set, valid_set, test_set = get_train_valid_test_split(TOT_PATIENT_NUMBER)

train:21563 valid:3080 test:6162


In [11]:
with tf.device('/cpu:0'):
    # Create dataset for iterator
    train_dataset = XrayData(img_dir=IMAGE_FOLDER, lbl_file=LABEL_FILE, patient_ids=train_set,  mode='training')
    valid_dataset = XrayData(img_dir=IMAGE_FOLDER, lbl_file=LABEL_FILE, patient_ids=valid_set, shuffle=False)
    test_dataset  = XrayData(img_dir=IMAGE_FOLDER, lbl_file=LABEL_FILE, patient_ids=test_set, shuffle=False)
    
    # Create an reinitializable iterator given the dataset structure
    iterator = Iterator.from_structure(train_dataset.data.output_types,
                                       train_dataset.data.output_shapes)
    next_batch = iterator.get_next()

Loaded 87306 labels and 87306 images
Loaded 7616 labels and 7616 images
Loaded 17198 labels and 17198 images


In [12]:
#####################################################################################################
## Helper Functions

In [13]:
def get_symbol(model_name, in_tensor, is_training,
               chkpoint=CHKPOINT, out_features=CLASSES):
    if model_name == 'resnet50':
        # Load variables into model (without this nothing is restored)
        tf.train.get_or_create_global_step()
        # Import symbol
        with slim.arg_scope(resnet_v1.resnet_arg_scope()):
            base_model, _ = resnet_v1.resnet_v1_50(X, num_classes=None, 
                                                   is_training=is_training)
        # Collect variables to restore from checkpoint
        variables_to_restore = slim.get_variables_to_restore()
        #print(variables_to_restore)
        init_fn = slim.assign_from_checkpoint_fn(chkpoint, variables_to_restore)   
        # Attach extra layers
        fc = tf.layers.dense(base_model, out_features, name='output')
        # Activation function will be included in loss
        sym = tf.reshape(fc, shape=[-1, out_features])
        
    elif model_name == 'densenet121':
        # Load variables into model (without this nothing is restored)
        tf.train.get_or_create_global_step()
        # Import symbol
        dense_args = densenet.densenet_arg_scope()
        print(dense_args)
        #dense_args[data_format]='NCHW'
        with slim.arg_scope(dense_args):
            logits, _ = densenet.densenet121(X, num_classes=out_features,
                                             reuse=None, is_training=is_training)
        # Collect variables to restore from checkpoint
        variables_to_restore = slim.get_variables_to_restore(
            exclude=['densenet121/logits', 'predictions'])
        #print(variables_to_restore)
        init_fn = slim.assign_from_checkpoint_fn(chkpoint, variables_to_restore)  
        # Reshape logits to (None, CLASSES) since my label is (None, CLASSES)
        sym = tf.reshape(logits, shape=[-1, CLASSES])
            
    else:
        raise ValueError("Unknown model-name")
        
    return sym, init_fn

In [14]:
def init_symbol(sym, out_tensor, lr=LR):
    loss_fn = tf.nn.sigmoid_cross_entropy_with_logits(logits=sym, labels=y)
    loss = tf.reduce_mean(loss_fn)
    optimizer = tf.train.AdamOptimizer(lr, beta1=0.9, beta2=0.999)
    #optimizer = tf.train.GradientDescentOptimizer(lr)
    training_op = optimizer.minimize(loss)
    return training_op, loss

In [15]:
def init_uninitialized(sess):
    global_vars = tf.global_variables()
    is_not_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars])
    not_initialized_vars = [v for (v, f) in zip(global_vars, is_not_initialized) if not f]
    if len(not_initialized_vars):
        #print("Initialising: ", not_initialized_vars)
        sess.run(tf.variables_initializer(not_initialized_vars))

In [16]:
%%time
# Place-holders
sess = tf.Session()
X = tf.placeholder(tf.float32, shape=[None, WIDTH, HEIGHT, CHANNELS])
y = tf.placeholder(tf.float32, shape=[None, CLASSES])
training = tf.placeholder(tf.bool) 
# Create symbol (for training and inference)
sym, init_fn = get_symbol(model_name='resnet50', in_tensor=X, is_training=training)

CPU times: user 3.33 s, sys: 827 ms, total: 4.15 s
Wall time: 4.15 s


In [17]:
# Create training operation
model, loss = init_symbol(sym=sym, out_tensor=y)
# Create iterator
training_init_op = iterator.make_initializer(train_dataset.data)
train_batches_per_epoch = int(np.floor(train_dataset.data_size/BATCHSIZE))

In [18]:
%%time
# Restoring parameters from tfdensenet/tf-densenet121.ckpt
init_fn(sess)
# Initialise uninitialised vars (FC layer & Adam)
init_uninitialized(sess)

INFO:tensorflow:Restoring parameters from resnet_v1_50.ckpt
CPU times: user 2.82 s, sys: 862 ms, total: 3.68 s
Wall time: 2.89 s


In [19]:
%%time
#1hr3min
for epoch in range(EPOCHS):
    
    print("Epoch number: {}".format(epoch+1))
    # Logging
    epoch_loss = []
    stime = time.time()
    # Initialize iterator with the training dataset
    sess.run(training_init_op)
    for step in range(train_batches_per_epoch):
        
        # get next batch of data
        img_batch, label_batch = sess.run(next_batch)
        # And run the training op
        _, loss_tr = sess.run([model, loss], feed_dict={X: img_batch, y: label_batch, training: True})
        epoch_loss.append(loss_tr)
        
    etime = time.time()
    print("Average loss: {}".format(np.mean(epoch_loss)))
    # 7min20s for chainer
    print("Epoch time: {0:.0f} seconds".format(etime-stime))
    print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

Epoch number: 1
Average loss: 0.1613440364599228
Epoch time: 604 seconds
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch number: 2
Average loss: 0.1491870880126953
Epoch time: 602 seconds
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch number: 3
Average loss: 0.1451054960489273
Epoch time: 601 seconds
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch number: 4
Average loss: 0.14188571274280548
Epoch time: 599 seconds
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch number: 5
Average loss: 0.13917018473148346
Epoch time: 600 seconds
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
CPU times: user 8h 27min 57s, sys: 10min 40s, total: 8h 38min 38s
Wall time: 50min 6s


In [22]:
%%time
# Test
testing_init_op = iterator.make_initializer(test_dataset.data)
sess.run(testing_init_op)

test_batches_per_epoch = int(np.floor(test_dataset.data_size/BATCHSIZE))
pred = tf.sigmoid(sym)
y_guess = []
y_truth = []

for step in range(test_batches_per_epoch):
    img_batch, label_batch = sess.run(next_batch)
    output = sess.run(pred, feed_dict={X: img_batch, training: True})
    y_guess.append(output)
    y_truth.append(label_batch)
        
# Concatenate
y_guess = np.concatenate(y_guess, axis=0)
y_truth = np.concatenate(y_truth, axis=0)

CPU times: user 16min 55s, sys: 10.4 s, total: 17min 5s
Wall time: 1min 32s


In [23]:
print("Test AUC: {0:.4f}".format(compute_roc_auc(y_truth, y_guess, classes=CLASSES)))
# 0.5211 (training:false)
# 0.6408 (training:true)??

Full AUC [0.8067662388310473, 0.8516482589811255, 0.7805391403572336, 0.8824562675781004, 0.8702802591190884, 0.8878719432811005, 0.6949038142574767, 0.8819361474680872, 0.6130602438290237, 0.834636699950543, 0.725231232059445, 0.7848622615438152, 0.7481382714648435, 0.8461683326513139]
Test AUC: 0.8006
