In [None]:
# TODO

# 0. Rewrite to use tf.estimator.Estimator()!


# 1. Move func:get_imgloc_labels into common/utilities.py
# 2. Move train/val/test split as function into common/utilities.py
# 3. Move auc_roc function into common/utilities.py
# 4. Change data-order from NHWC to NCHW to improve speed

# 5. IMPORTANT: For inference the symbol probably requires a training:false flag
# 6. IMPORTANT: See if tfrecords reduces IO latency
# 7. IMPORTANT: Multi-gpu wrapper????

In [None]:
%%bash
#wget -N http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz
#tar -xvf resnet_v1_50_2016_08_28.tar.gz
#rm resnet_v1_50_2016_08_28.tar.gz

In [None]:
import os
import sys
import time
import multiprocessing
import numpy as np
import pandas as pd
import tensorflow as tf
from nets import densenet  # Download from https://github.com/pudae/tensorflow-densenet
from tensorflow.python.framework import dtypes
from tensorflow.python.framework.ops import convert_to_tensor
from tensorflow.contrib.data import Iterator
from sklearn.metrics.ranking import roc_auc_score
from sklearn.model_selection import train_test_split
from PIL import Image
import random
from common.utils import *
slim = tf.contrib.slim

In [None]:
# Display the installed TensorFlow version (the cell's last expression is shown as output)
tf.__version__

In [None]:
# Number of CPUs reported by the OS; passed as num_parallel_calls to the
# dataset map() calls in XrayData.
CPU_COUNT = multiprocessing.cpu_count()
print("CPUs: ", CPU_COUNT)

In [None]:
# Globals
# Number of label columns / output logits (second dimension of the `y` placeholder)
CLASSES = 14
# Spatial size the images are resized/cropped to before being fed to the network
WIDTH = 224
HEIGHT = 224
CHANNELS = 3
LR = 0.0001  # Effective learning-rate will decrease as BATCHSIZE rises
EPOCHS = 5
# NOTE(review): the trailing comment mentions Chainer; presumably carried over
# from another framework's version of this notebook - confirm it applies here.
BATCHSIZE = 64  # Chainer auto scales batch
# Per-channel mean subtracted from the decoded image in XrayData._preprocess_image_labels
IMAGENET_RGB_MEAN = np.array([123.68, 116.78, 103.94], dtype=np.float32)
# Multiplier applied after the mean subtraction
# (presumably ~1/std of ImageNet pixel values - TODO confirm)
IMAGENET_SCALE_FACTOR = 0.017
# Patient IDs are assumed to run from 1 to this value when building the splits
TOT_PATIENT_NUMBER = 30805  # From data

In [None]:
# Paths
# Root folder of the dataset; images and the label CSV are resolved relative to it
CSV_DEST = "chestxray"
IMAGE_FOLDER = os.path.join(CSV_DEST, "images")
LABEL_FILE = os.path.join(CSV_DEST, "Data_Entry_2017.csv")
print(IMAGE_FOLDER, LABEL_FILE)
# Model checkpoint
# NOTE(review): presumably toggles loading of CHKPOINT - verify where it is consumed
PRETRAINED_WEIGHTS = True
# Default checkpoint path used by get_symbol() to restore the backbone weights
CHKPOINT = 'tfdensenet/tf-densenet121.ckpt'

In [None]:
%%time
# Download data
# Remind the user about the AzCopy prerequisite linked below, then fetch the data.
print("Please make sure to download")
print("https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-linux#download-and-install-azcopy")
# download_data_chextxray is not defined in this notebook; presumably it comes
# from the `common.utils` star import and downloads into CSV_DEST - verify there.
download_data_chextxray(CSV_DEST)

In [None]:
#####################################################################################################
## Data Loading

In [None]:
class XrayData():
    """Input pipeline that turns the X-ray files of a set of patients into batches.

    The constructor resolves image paths and labels through ``get_imgloc_labels``
    and exposes a batched ``tf.data.Dataset`` as ``self.data``.

    Attributes:
        img_locs: Image file paths, one per example.
        labels: Label rows aligned with ``img_locs`` (as returned by
            ``get_imgloc_labels``).
        data_size: Number of examples, i.e. ``len(self.labels)``.
        data: Batched dataset yielding ``(image, label)`` pairs.
    """

    def __init__(self, img_dir, lbl_file, patient_ids, mode='inference', 
                 width=WIDTH, height=HEIGHT, batch_size=BATCHSIZE, 
                 imagenet_mean=IMAGENET_RGB_MEAN, imagenet_scaling = IMAGENET_SCALE_FACTOR,
                 shuffle=True):
        """Build the dataset.

        Args:
            img_dir: Folder containing the image files.
            lbl_file: Path of the label CSV.
            patient_ids: Patient IDs whose images should be included.
            mode: ``'training'`` enables shuffling and augmentation; any other
                value builds the deterministic resize-only pipeline.
            width: Output image width.
            height: Output image height.
            batch_size: Number of examples per batch.
            imagenet_mean: Per-channel mean subtracted from the decoded image.
            imagenet_scaling: Factor the centred image is multiplied by.
            shuffle: Only consulted when ``mode == 'training'``; set to False to
                keep the file order while still augmenting. Non-training
                pipelines are never shuffled so that predictions stay aligned
                with ``self.labels``.
        """
        # Get data
        self.img_locs, self.labels = get_imgloc_labels(img_dir, lbl_file, patient_ids)
        self.data_size = len(self.labels)
        self.imagenet_mean = imagenet_mean
        self.imagenet_scaling = imagenet_scaling
        self.width = width
        self.height = height
        # Create dataset
        # Performance: https://www.tensorflow.org/versions/master/performance/datasets_performance
        # Following: https://stackoverflow.com/a/48096625/6772173
        data = tf.data.Dataset.from_tensor_slices((self.img_locs, self.labels))
        # Processing
        if mode == 'training':
            if shuffle:
                # Shuffle the (cheap) path/label pairs before any image is decoded
                data = data.shuffle(self.data_size)
            parse_fn = self._parse_function_train
        else:
            parse_fn = self._parse_function_inference
        data = data.map(parse_fn, num_parallel_calls=CPU_COUNT)
        # Batch first and prefetch last so that whole batches are prepared in the
        # background; 10 batches is the same buffer size as before
        # (10 * batch_size single elements).
        data = data.batch(batch_size).prefetch(10)

        self.data = data
        print("Loaded {} labels and {} images".format(len(self.labels), len(self.img_locs)))


    def _parse_function_train(self, filename, label):
        """Load one image and apply the training augmentations (crop, flip, rotate)."""
        img_rgb, label = self._preprocess_image_labels(filename, label)
        # Super high CPU usage bottlenecking GPU
        # Random crop: upscale by 40 px in each dimension, then cut out the target size
        img_rgb = tf.image.resize_images(img_rgb, [self.height+40, self.width+40])
        img_rgb = tf.random_crop(img_rgb, [self.height, self.width, 3])
        # Random flip
        img_rgb = tf.image.random_flip_left_right(img_rgb)
        # Random rotation of up to +/-10 degrees. The angle is drawn with a TF op
        # so that every image gets its own angle (a NumPy draw here would be
        # evaluated only once, when the map function is traced), and it is
        # converted to radians, the unit tf.contrib.image.rotate expects.
        rot_angle = tf.random_uniform([], minval=-10.0, maxval=10.0) * (np.pi / 180.0)
        img_rgb = tf.contrib.image.rotate(img_rgb, rot_angle)
        return img_rgb, label


    def _parse_function_inference(self, filename, label):
        """Load one image and resize it to the final dimensions (no augmentation)."""
        img_rgb, label = self._preprocess_image_labels(filename, label)
        # Resize to final dimensions
        img_rgb = tf.image.resize_images(img_rgb, [self.height, self.width])
        return img_rgb, label


    def _preprocess_image_labels(self, filename, label):
        """Decode a PNG file to a 3-channel float image, centre and scale it."""
        # load and preprocess the image
        img_decoded = tf.to_float(tf.image.decode_png(tf.read_file(filename), channels=3))
        img_centered = tf.subtract(img_decoded, self.imagenet_mean)
        img_rgb = img_centered * self.imagenet_scaling
        return img_rgb, label
    
    
def get_imgloc_labels(img_dir, lbl_file, patient_ids):
    """Build image paths and the multi-hot label matrix for the given patients.

    TODO: move this to the common/utilities file.

    Args:
        img_dir: Folder that is prefixed to every 'Image Index' entry.
        lbl_file: CSV file with (at least) the columns 'Image Index',
            'Finding Labels' ('|'-separated findings) and 'Patient ID'.
        patient_ids: Collection of patient IDs to keep.

    Returns:
        Tuple ``(img_locs, labels)``: ``img_locs`` is an array of image paths
        and ``labels`` a float32 array with one row per image and one indicator
        column per finding (the 'No Finding' column is removed).

    Raises:
        KeyError: If an expected column, or the 'No Finding' label, is missing.
    """
    # Read labels-csv
    df = pd.read_csv(lbl_file)
    # One mask for both images and labels keeps the two outputs row-aligned
    selected = df['Patient ID'].isin(patient_ids)
    # Indicator columns are derived from the unfiltered data so that every
    # split gets the same set (and order) of label columns
    df_label = df['Finding Labels'].str.get_dummies(sep='|')
    # Keep the requested patients and remove the unnecessary 'No Finding' column;
    # drop() returns a new frame, so no in-place edit of a filtered slice
    df_label = df_label[selected].drop(['No Finding'], axis=1)

    # List of images (full-path)
    img_locs = df.loc[selected, 'Image Index'].map(
        lambda im: os.path.join(img_dir, im)).values
    # Multi-hot labels, cast to float32 as required by the BCE loss
    labels = df_label.values.astype(np.float32)
    return img_locs, labels

In [None]:
# Patient-level split into training / validation / test (70% / 10% / 20%).
# No shuffling: the ID range is cut in order, first 70% -> train, and the
# remaining 30% is divided 1:2 into validation and test.
all_patient_ids = range(1, TOT_PATIENT_NUMBER + 1)
train_set, other_set = train_test_split(
    all_patient_ids, train_size=0.7, test_size=0.3, shuffle=False)
valid_set, test_set = train_test_split(
    other_set, train_size=1/3, test_size=2/3, shuffle=False)
split_sizes = [len(ids) for ids in (train_set, valid_set, test_set)]
print("train:{} valid:{} test:{}".format(*split_sizes))

In [None]:
# Build the three input pipelines and the shared iterator on the CPU device
with tf.device('/cpu:0'):
    # Create dataset for iterator
    # Only the training set uses mode='training'; the other two fall back to the
    # default 'inference' mode of XrayData.
    train_dataset = XrayData(img_dir=IMAGE_FOLDER, lbl_file=LABEL_FILE, patient_ids=train_set,  mode='training')
    valid_dataset = XrayData(img_dir=IMAGE_FOLDER, lbl_file=LABEL_FILE, patient_ids=valid_set, shuffle=False)
    test_dataset  = XrayData(img_dir=IMAGE_FOLDER, lbl_file=LABEL_FILE, patient_ids=test_set, shuffle=False)
    
    # Create an reinitializable iterator given the dataset structure
    # (it is bound to a concrete dataset later via iterator.make_initializer)
    iterator = Iterator.from_structure(train_dataset.data.output_types,
                                       train_dataset.data.output_shapes)
    # Tensor pair (images, labels) that yields the next batch each time it is run
    next_batch = iterator.get_next()

In [None]:
#####################################################################################################
## Helper Functions

In [None]:
def get_symbol(model_name, in_tensor, 
               reuse=True, is_training=True, chkpoint=CHKPOINT, out_features=CLASSES):
    """Build the network on `in_tensor` and a function restoring pretrained weights.

    TODO: convert the model selection to a dictionary lookup.

    Args:
        model_name: 'densenet121' or 'resnet50'.
        in_tensor: Image batch tensor the network is built on.
        reuse: Variable-sharing mode, forwarded to the densenet121 builder only.
            ``True`` is mapped to ``tf.AUTO_REUSE`` so that the first call can
            create the variables and later calls share them.
        is_training: Forwarded to the network builder.
        chkpoint: Checkpoint path the backbone variables are restored from.
        out_features: Number of output logits.

    Returns:
        Tuple ``(sym, init_fn)``: ``sym`` are the logits reshaped to
        ``[-1, out_features]`` (the activation is applied in the loss), and
        ``init_fn(sess)`` restores the backbone variables from `chkpoint`.

    Raises:
        ValueError: If `model_name` is not a known model.
    """
    if model_name not in ('resnet50', 'densenet121'):
        raise ValueError("Unknown model-name")

    # Load variables into model (without this nothing is restored)
    tf.train.get_or_create_global_step()
    # A literal reuse=True cannot create variables, so the very first call
    # would fail; AUTO_REUSE creates them once and shares them afterwards.
    scope_reuse = tf.AUTO_REUSE if reuse is True else reuse

    if model_name == 'resnet50':
        # Imported here because the notebook's import cell only brings in densenet
        from nets import resnet_v1
        # Import symbol
        with slim.arg_scope(resnet_v1.resnet_arg_scope()):
            base_model, _ = resnet_v1.resnet_v1_50(in_tensor, num_classes=None, 
                                                   is_training=is_training)
        # Collect variables to restore from checkpoint (before the new head is
        # attached, so that the head is not looked up in the checkpoint)
        variables_to_restore = slim.get_variables_to_restore()
        init_fn = slim.assign_from_checkpoint_fn(chkpoint, variables_to_restore)   
        # Attach extra layers
        fc = tf.layers.dense(base_model, out_features, name='output')
        # Activation function will be included in loss
        sym = tf.reshape(fc, shape=[-1, out_features])

    else:  # 'densenet121'
        # Import symbol
        dense_args = densenet.densenet_arg_scope()
        #dense_args[data_format]='NCHW'
        with slim.arg_scope(dense_args):
            logits, _ = densenet.densenet121(in_tensor, num_classes=out_features, 
                                             is_training=is_training, reuse=scope_reuse)
        # Collect variables to restore from checkpoint; the classifier layer is
        # excluded because it is re-created with `out_features` outputs
        variables_to_restore = slim.get_variables_to_restore(
            exclude=['densenet121/logits', 'predictions'])
        init_fn = slim.assign_from_checkpoint_fn(chkpoint, variables_to_restore)  
        # Reshape logits to (None, out_features) to match the label shape
        sym = tf.reshape(logits, shape=[-1, out_features])

    return sym, init_fn

In [None]:
def init_symbol(sym, out_tensor, lr=LR, multi_gpu=True):
    """Attach a sigmoid cross-entropy loss and an Adam training op to the logits.

    Args:
        sym: Logits tensor.
        out_tensor: Label tensor with the same shape as `sym`.
        lr: Learning rate for Adam.
        multi_gpu: Currently unused; kept for interface compatibility.

    Returns:
        Tuple ``(training_op, loss)`` where ``loss`` is the mean cross-entropy.
    """
    # Use the tensor passed by the caller rather than a notebook-level global
    loss_fn = tf.nn.sigmoid_cross_entropy_with_logits(logits=sym, labels=out_tensor)
    loss = tf.reduce_mean(loss_fn)
    optimizer = tf.train.AdamOptimizer(lr, beta1=0.9, beta2=0.999)
    # Batch-norm layers register their moving-average updates in UPDATE_OPS;
    # they only run if the training op depends on them. Without this the moving
    # statistics never change during training, which hurts is_training=False
    # inference.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        training_op = optimizer.minimize(loss)
    return training_op, loss

In [None]:
def init_uninitialized(sess):
    """Initialise, in `sess`, every global variable that has no value yet.

    Variables that are already initialised (e.g. restored from a checkpoint)
    are left untouched.
    """
    all_vars = tf.global_variables()
    # One boolean per variable: True when it already holds a value in this session
    init_flags = sess.run([tf.is_variable_initialized(v) for v in all_vars])
    pending = [var for var, ready in zip(all_vars, init_flags) if not ready]
    if not pending:
        return
    sess.run(tf.variables_initializer(pending))

In [None]:
def compute_roc_auc(data_gt, data_pd, full=True, classes=CLASSES):
    """Return the mean of the per-class ROC-AUC scores (and print each of them).

    TODO: push to util.

    Args:
        data_gt: Ground-truth array, indexed as ``[:, class]``.
        data_pd: Prediction array, indexed as ``[:, class]``.
        full: Unused.
        classes: Number of leading columns to score.
    """
    per_class_auc = [
        roc_auc_score(data_gt[:, cls], data_pd[:, cls]) for cls in range(classes)
    ]
    print("Full AUC", per_class_auc)
    return np.mean(per_class_auc)

In [None]:
%%time
# Place-holders
sess = tf.Session()
X = tf.placeholder(tf.float32, shape=[None, WIDTH, HEIGHT, CHANNELS])
y = tf.placeholder(tf.float32, shape=[None, CLASSES])
# Create symbol
sym, init_fn = get_symbol(model_name='densenet121', in_tensor=X)

In [None]:
# Training op and scalar loss built on top of the logits
model, loss = init_symbol(sym=sym, out_tensor=y)
# Op that (re)binds the shared iterator to the training dataset
training_init_op = iterator.make_initializer(train_dataset.data)
# Only complete batches are counted: integer division rounds down
train_batches_per_epoch = train_dataset.data_size // BATCHSIZE

In [None]:
%%time
# Restoring parameters from tfdensenet/tf-densenet121.ckpt
init_fn(sess)
# Initialise uninitialised vars (FC layer & Adam)
init_uninitialized(sess)

In [None]:
"""
Epoch number: 1
Average loss: 0.1680990606546402
Epoch time: 768 seconds
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch number: 2
Average loss: 0.1511792093515396
Epoch time: 760 seconds
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
"""

In [None]:
%%time
# Training loop: EPOCHS passes over the training set, printing the mean batch
# loss and the wall-clock time of every epoch.
for epoch in range(EPOCHS):
    
    print("Epoch number: {}".format(epoch+1))
    # Logging
    epoch_loss = []
    stime = time.time()
    # Initialize iterator with the training dataset
    sess.run(training_init_op)
    
    # train_batches_per_epoch is rounded down, so a final partial batch is not used
    for step in range(train_batches_per_epoch):
        
        # get next batch of data
        # (the batch is fetched as NumPy arrays and fed back through placeholders)
        img_batch, label_batch = sess.run(next_batch)
        # And run the training op
        _, loss_tr = sess.run([model, loss], feed_dict={X: img_batch, y: label_batch})
        epoch_loss.append(loss_tr)
        
    etime = time.time()
    print("Average loss: {}".format(np.mean(epoch_loss)))
    # 7min20s for chainer
    print("Epoch time: {0:.0f} seconds".format(etime-stime))
    print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

In [None]:
# Create graph for testing
# Second network build on the same placeholder with is_training=False; `reuse`
# is left at get_symbol's default (True). The returned restore function is
# discarded here.
sym_test, _ = get_symbol(model_name='densenet121', in_tensor=X, is_training=False)

In [None]:
%%time
# Test
testing_init_op = iterator.make_initializer(test_dataset.data)
sess.run(testing_init_op)

test_batches_per_epoch = int(np.floor(test_dataset.data_size/BATCHSIZE))
pred = tf.sigmoid(sym_test)
y_guess = []

for step in range(test_batches_per_epoch):
    # get next batch of data
    img_batch, _ = sess.run(next_batch)
    output = sess.run(pred, feed_dict={X: img_batch})
    y_guess.append(output)
        
# Concatenate
y_guess = np.concatenate(y_guess, axis=0)

In [None]:
# Ground-truth labels of the test set, in file order (the test pipeline is not shuffled)
y_truth = test_dataset.labels
# Keep only as many rows as predictions were actually collected, so that both
# arrays line up even if the prediction loop did not cover the final batch
y_truth = y_truth[:len(y_guess)]  # Iterator only returns complete batches

In [None]:
# Mean per-class ROC-AUC on the test set (compute_roc_auc also prints each class score)
print("Test AUC: {0:.4f}".format(compute_roc_auc(y_truth, y_guess)))
# Results noted from earlier runs (NOTE(review): exact run settings are not
# recorded here - confirm before relying on these numbers):
# 0.7755 if training:False
# x if training-flag omitted -> no effect
# Test AUC: 0.6500