In [1]:
%%bash
#wget -N http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz
#tar -xvf resnet_v1_50_2016_08_28.tar.gz
#rm resnet_v1_50_2016_08_28.tar.gz

In [2]:
import os
import sys
import time
import multiprocessing
import numpy as np
import pandas as pd

import tensorflow as tf
from nets import densenet  # Download from https://github.com/pudae/tensorflow-densenet
from tensorflow.contrib.slim.nets import resnet_v1
slim = tf.contrib.slim

from tensorflow.contrib.data import Dataset
from tensorflow.python.framework import dtypes
from tensorflow.python.framework.ops import convert_to_tensor
from tensorflow.contrib.data import Iterator

from sklearn.model_selection import train_test_split
from PIL import Image
import random
#import cv2
from common.utils import *

In [3]:
# Globals
# Number of label columns per image (second dimension of the label placeholder)
CLASSES = 14
# Input image size and channel count used for the network input placeholder
WIDTH = 224
HEIGHT = 224
CHANNELS = 3
LR = 0.0001  # Effective learning-rate will decrease as BATCHSIZE rises
EPOCHS = 5
#BATCHSIZE = 64*NUM_GPUS
# NOTE(review): the trailing comment below mentions Chainer; in this notebook
# BATCHSIZE is used directly as the dataset batch size -- confirm and reword.
BATCHSIZE = 64  # Chainer auto scales batch
# Per-channel statistics; the values lie in [0, 1]
# NOTE(review): presumably intended for pixels scaled to 0-1 -- confirm against usage
IMAGENET_RGB_MEAN =  np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_RGB_SD =  np.array([0.229, 0.224, 0.225], dtype=np.float32)
# Patient IDs are assumed to run from 1 to this value when building the splits
TOT_PATIENT_NUMBER = 30805  # From data

In [4]:
# Paths
# Root folder holding the label CSV and the images sub-folder
CSV_DEST = "chestxray"
IMAGE_FOLDER = os.path.join(CSV_DEST, "images")
LABEL_FILE = os.path.join(CSV_DEST, "Data_Entry_2017.csv")
print(IMAGE_FOLDER, LABEL_FILE)
# Model checkpoint
# When True, the session cell restores weights from CHKPOINT before training
PRETRAINED_WEIGHTS = True
#CHKPOINT = 'resnet_v1_50.ckpt' 
# Checkpoint file used as the default `chkpoint` of get_symbol
CHKPOINT = 'tfdensenet/tf-densenet121.ckpt'

chestxray/images chestxray/Data_Entry_2017.csv


In [5]:
%%time
# Download data
# The helper below is not defined in this notebook; presumably it comes from the
# star import of common.utils and relies on AzCopy (see the printed link) -- TODO confirm.
print("Please make sure to download")
print("https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-linux#download-and-install-azcopy")
download_data_chextxray(CSV_DEST)

Please make sure to download
https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-linux#download-and-install-azcopy
Data already exists
CPU times: user 813 ms, sys: 332 ms, total: 1.15 s
Wall time: 1.15 s


In [6]:
#####################################################################################################
## Data Loading

In [7]:
class XrayData():
    """Batched `Dataset` of chest X-ray images and their multi-hot labels.

    Reads the label CSV, keeps only the rows belonging to `patient_ids`,
    and exposes:

    * ``img_locs``  -- array of full image paths
    * ``labels``    -- float32 array, one column per finding (excluding
      'No Finding'), one row per image
    * ``data_size`` -- number of samples
    * ``data``      -- the mapped / (optionally shuffled) / batched dataset

    Args:
        img_dir: folder containing the image files.
        lbl_file: path of the CSV with 'Image Index', 'Patient ID' and
            'Finding Labels' columns.
        patient_ids: collection of patient IDs whose images are kept.
        mode: 'training' or 'inference'. Both currently use the same
            parsing function (no augmentation is applied).
        width, height: size the images are resized to.
        batch_size: number of samples per batch.
        shuffle: whether to shuffle with a buffer of `buffer_size` elements.
        buffer_size: shuffle buffer size (small values only shuffle locally).

    Raises:
        ValueError: if `mode` is not 'training' or 'inference'.
    """

    def __init__(self, img_dir, lbl_file, patient_ids, mode='inference', 
                 width=WIDTH, height=HEIGHT, batch_size=BATCHSIZE,
                 shuffle=True, buffer_size=20):

        # Following: https://github.com/kratzert/finetune_alexnet_with_tensorflow/blob/master/datagenerator.py

        # Fail fast on a bad mode, before any file is read
        if mode not in ('training', 'inference'):
            raise ValueError("Invalid mode '%s'." % (mode))

        # Target size used by the parsing function
        self.width = width
        self.height = height

        # Read labels-csv
        df = pd.read_csv(lbl_file)

        # Build the label columns on the unfiltered data so every subset
        # (train/valid/test) gets the same columns in the same order
        df_label = df['Finding Labels'].str.get_dummies(sep='|')

        # Filter images and labels with the same patient mask
        keep = df['Patient ID'].isin(patient_ids)
        df = df[keep]
        # Remove the unnecessary 'No Finding' column (it is the all-zero row)
        df_label = df_label[keep].drop(['No Finding'], axis=1)

        # List of images (full-path)
        self.img_locs = df['Image Index'].map(
            lambda im: os.path.join(img_dir, im)).values
        # Multi-hot encoded labels (float32 for BCE loss)
        self.labels = df_label.values.astype(np.float32)

        # Number of samples in dataset
        self.data_size = len(self.labels)

        # Create dataset
        data = Dataset.from_tensor_slices((self.img_locs, self.labels))

        # Training and inference share one parsing function for now
        data = data.map(self._parse_function_train, num_threads=8,
                        output_buffer_size=20*batch_size)

        # shuffle the first `buffer_size` elements of the dataset
        if shuffle:
            data = data.shuffle(buffer_size=buffer_size)

        # create a new dataset with batches of images
        data = data.batch(batch_size)

        self.data = data
        print("Loaded {} labels and {} images".format(len(self.labels), len(self.img_locs)))


    def _parse_function_train(self, filename, label):
        """Load one PNG, resize it to (height, width) and return it in BGR order."""

        # load and preprocess the image
        img_string = tf.read_file(filename)
        img_decoded = tf.image.decode_png(img_string, channels=3)
        # resize_images expects [new_height, new_width]
        img_resized = tf.image.resize_images(img_decoded, [self.height, self.width])
        # NOTE(review): IMAGENET_RGB_MEAN holds 0-1 scale values while the
        # decoded pixels are on a 0-255 scale, so this barely shifts the data.
        # Confirm the preprocessing the pretrained checkpoint expects before
        # changing it.
        img_centered = tf.subtract(img_resized, IMAGENET_RGB_MEAN)

        # RGB -> BGR
        img_bgr = img_centered[:, :, ::-1]

        return img_bgr, label    

In [8]:
# Patient-level split, order preserved: 70% train / 10% valid / 20% test
train_set, other_set = train_test_split(
    range(1,TOT_PATIENT_NUMBER+1),
    train_size=0.7, test_size=0.3, shuffle=False)
# The remaining 30% is divided 1:2 into validation and test
valid_set, test_set = train_test_split(
    other_set, train_size=1/3, test_size=2/3, shuffle=False)
print("train:{} valid:{} test:{}".format(
    *map(len, (train_set, valid_set, test_set))))

train:21563 valid:3080 test:6162


In [9]:
with tf.device('/cpu:0'):
    # One dataset per split; only the training split is shuffled
    train_dataset = XrayData(
        img_dir=IMAGE_FOLDER, lbl_file=LABEL_FILE,
        patient_ids=train_set, mode='training')
    valid_dataset = XrayData(
        img_dir=IMAGE_FOLDER, lbl_file=LABEL_FILE,
        patient_ids=valid_set, shuffle=False)
    test_dataset = XrayData(
        img_dir=IMAGE_FOLDER, lbl_file=LABEL_FILE,
        patient_ids=test_set, shuffle=False)

    # Reinitializable iterator built from the training dataset's structure
    iterator = Iterator.from_structure(
        train_dataset.data.output_types,
        train_dataset.data.output_shapes)
    next_batch = iterator.get_next()

Instructions for updating:
Use `tf.data.Dataset.from_tensor_slices()`.
Instructions for updating:
Replace `num_threads=T` with `num_parallel_calls=T`. Replace `output_buffer_size=N` with `ds.prefetch(N)` on the returned dataset.
Instructions for updating:
Replace `num_threads=T` with `num_parallel_calls=T`. Replace `output_buffer_size=N` with `ds.prefetch(N)` on the returned dataset.
Loaded 87306 labels and 87306 images
Loaded 7616 labels and 7616 images
Loaded 17198 labels and 17198 images


In [10]:
#####################################################################################################
## Helper Functions

In [11]:
def get_symbol(model_name, in_tensor, chkpoint=CHKPOINT, out_features=CLASSES):
    """Build the network on `in_tensor` and return its logits plus a restore function.

    Args:
        model_name: 'resnet50' or 'densenet121'.
        in_tensor: image batch tensor the network is built on.
        chkpoint: checkpoint path the backbone variables are restored from.
        out_features: number of output logits per sample.

    Returns:
        (sym, init_fn): `sym` are the raw logits reshaped to
        (-1, out_features); no activation is applied because it is part of
        the loss. `init_fn(sess)` restores the collected variables from
        `chkpoint`.

    Raises:
        ValueError: if `model_name` is not recognised.
    """
    if model_name == 'resnet50':
        # Load variables into model (without this nothing is restored)
        tf.train.get_or_create_global_step()
        # Import symbol
        with slim.arg_scope(resnet_v1.resnet_arg_scope()):
            base_model, _ = resnet_v1.resnet_v1_50(in_tensor, None, is_training=True)
        # Collect variables to restore from checkpoint; this happens before
        # the new output layer is created so that layer is not restored
        variables_to_restore = slim.get_variables_to_restore()
        init_fn = slim.assign_from_checkpoint_fn(chkpoint, variables_to_restore)
        # Attach extra layers
        fc = tf.layers.dense(base_model, out_features, name='output')
        # Activation function will be included in loss
        sym = tf.reshape(fc, shape=[-1, out_features])

    elif model_name == 'densenet121':
        # Load variables into model (without this nothing is restored)
        tf.train.get_or_create_global_step()
        # Import symbol
        dense_args = densenet.densenet_arg_scope()
        #print(dense_args)  # Add NCHW later
        with slim.arg_scope(dense_args):
            logits, _ = densenet.densenet121(
                in_tensor, num_classes=out_features, is_training=True, reuse=None)
        # Collect variables to restore from checkpoint, leaving out the
        # classification head whose size differs from the checkpoint's
        variables_to_restore = slim.get_variables_to_restore(
            exclude=['densenet121/logits', 'predictions'])
        init_fn = slim.assign_from_checkpoint_fn(chkpoint, variables_to_restore)
        # Reshape logits to (None, out_features) to match the label shape
        sym = tf.reshape(logits, shape=[-1, out_features])

    else:
        raise ValueError("Unknown model-name")

    return sym, init_fn

In [12]:
def init_symbol(sym, out_tensor, lr=LR):
    """Create the loss and the Adam training op for the logits `sym`.

    Args:
        sym: logits tensor of shape (batch, classes).
        out_tensor: label tensor with the same shape as `sym`.
        lr: learning rate passed to Adam.

    Returns:
        (training_op, loss): the op that performs one optimisation step and
        the scalar mean sigmoid cross-entropy loss.
    """
    # Use the labels passed in rather than a global placeholder
    loss_fn = tf.nn.sigmoid_cross_entropy_with_logits(logits=sym, labels=out_tensor)
    loss = tf.reduce_mean(loss_fn)
    optimizer = tf.train.AdamOptimizer(lr, beta1=0.9, beta2=0.999)
    # Batch-norm moving statistics are updated through ops registered in the
    # UPDATE_OPS collection; make the train step depend on them so they run.
    # With an empty collection this is a no-op.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        training_op = optimizer.minimize(loss)
    return training_op, loss

In [13]:
def init_uninitialized(sess):
    """Run the initializer for every global variable not yet initialised in `sess`."""
    all_vars = tf.global_variables()
    # One boolean per variable: True when it already holds a value
    init_flags = sess.run([tf.is_variable_initialized(v) for v in all_vars])
    pending = [v for v, ready in zip(all_vars, init_flags) if not ready]
    if not pending:
        return
    sess.run(tf.variables_initializer(pending))

In [14]:
%%time
# Graph inputs: image batch and label batch
X = tf.placeholder(dtype=tf.float32, shape=[None, WIDTH, HEIGHT, CHANNELS])
y = tf.placeholder(dtype=tf.float32, shape=[None, CLASSES])

# Build the network and the function that restores its weights
sym, init_fn = get_symbol(model_name='densenet121', in_tensor=X)

# Training op and loss
model, loss = init_symbol(sym=sym, out_tensor=y)
# Op that points the shared iterator at the training data
training_init_op = iterator.make_initializer(train_dataset.data)
# Number of whole batches in one pass over the training data
train_batches_per_epoch = train_dataset.data_size // BATCHSIZE

CPU times: user 10.2 s, sys: 132 ms, total: 10.3 s
Wall time: 10.3 s


In [15]:
%%time
# Launch session and load model from checkpoint
# The session is kept open as a notebook global and reused by the training loop
sess = tf.Session()

# Temp
# Restore first, so the step below only touches what the checkpoint did not cover
if PRETRAINED_WEIGHTS:
    print("Loading pre-trained weights")
    init_fn(sess)  # Load from checkpoint

# Initialise uninitialised vars (FC layer & Adam)
init_uninitialized(sess)

Loading pre-trained weights
INFO:tensorflow:Restoring parameters from tfdensenet/tf-densenet121.ckpt
CPU times: user 5.21 s, sys: 1.47 s, total: 6.67 s
Wall time: 6.26 s


In [16]:
#Epoch number: 1
#Average loss: 0.1651265025138855
#Epoch time: 510 seconds
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#Epoch number: 2
#Average loss: 0.14739550650119781
#Epoch time: 511 seconds

In [None]:
# Train for EPOCHS passes, printing the mean batch loss and wall time per epoch
for epoch in range(EPOCHS):
    
    print("Epoch number: {}".format(epoch+1))
    # Logging
    epoch_loss = []
    stime = time.time()
    # Initialize iterator with the training dataset
    sess.run(training_init_op)
    
    for step in range(train_batches_per_epoch):
        
        # get next batch of data
        # (the batch is fetched as arrays, then fed back through the placeholders)
        img_batch, label_batch = sess.run(next_batch)
        # And run the training op
        _, loss_val = sess.run([model, loss], feed_dict={X: img_batch, y: label_batch})
        epoch_loss.append(loss_val)
        
    etime = time.time()
    print("Average loss: {}".format(np.mean(epoch_loss)))
    # 7min20s for chainer
    print("Epoch time: {0:.0f} seconds".format(etime-stime))
    print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

Epoch number: 1
Average loss: 0.17311111092567444
Epoch time: 702 seconds
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch number: 2
