In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.metrics import roc_curve, auc

from models.ResNet50 import ResNet50
from utils.MyTimer import MyTimer
from utils.Reader import AdvancedReader
%matplotlib inline


# Onset level handed to the training reader, and the image folders used for
# training and for validation/testing.
onset = -1
train_source = '/gpfs01/berens/user/mayhan/kaggle_dr_data/train_JF_BG_512/'
test_source = '/gpfs01/berens/user/mayhan/kaggle_dr_data/test_JF_BG_512/'

# Single place for every tunable of the run. The entries are read back by key
# when the ResNet50 object is constructed, built and trained further down.
network_config = {
    # --- ResNet50(...) constructor ---
    'instance_shape': [512, 512, 3],
    'num_classes': 5,
    # --- model.build(...) ---
    'conv_depths': [3, 4, 6, 3],
    'num_filters': [[64, 64, 256], [128, 128, 512], [256, 256, 1024], [512, 512, 2048]],
    'fc_depths': [1024],
    '_lambda': 0.0001,
    'lr': 0.005,
    'decay_steps': 50000,
    'decay_rate': 0.9,
    'data_aug': True,
    'batch_renorm': True,
    'preactivation': False,
    # --- model.train(...) / model.inference(...) ---
    'max_iter': 2000000,
    'oversampling_limit': 0.2,
    # ResNet50: Max batch sizes allowed by BatchNorm and BatchReNorm are 14 and 8, respectively.
    'batch_size': 8,
    'val_step': 20000,
    'quick_dirty_val': False,
}

# Build, initialise and train the network inside its own graph.
with tf.Graph().as_default():
    # Reader over the labelled training images.
    drTrain = AdvancedReader(
        source=train_source,
        file_type='.jpeg',
        csv_file='/gpfs01/berens/user/mayhan/kaggle_dr_data/trainLabels.csv',
        onset_level=onset,
        mode='train',
    )

    model = ResNet50(
        instance_shape=network_config['instance_shape'],
        num_classes=network_config['num_classes'],
        name='KaggleDR_ResNet50',
    )

    # Each build() keyword paired with the config entry that supplies it.
    build_kwargs = {
        'conv_stack_depths': network_config['conv_depths'],
        'num_filters': network_config['num_filters'],
        'fc_depths': network_config['fc_depths'],
        '_lambda': network_config['_lambda'],
        'learning_rate': network_config['lr'],
        'decay_steps': network_config['decay_steps'],
        'decay_rate': network_config['decay_rate'],
        'data_aug': network_config['data_aug'],
        'use_batch_renorm': network_config['batch_renorm'],
        'preactivation': network_config['preactivation'],
    }
    model.build(**build_kwargs)
    model.initialize()

    # The whole training run is wrapped in the timer; validation data comes
    # from the test image folder.
    with MyTimer('bazinga'):
        model.train(
            tr_reader=drTrain,
            max_iter=network_config['max_iter'],
            batch_size=network_config['batch_size'],
            normalize=True,
            oversampling_threshold=network_config['oversampling_limit'],
            val_step=network_config['val_step'],
            quick_dirty_val=network_config['quick_dirty_val'],
            val_source=test_source,
        )

    model.finalize()
    
# X-entropy across iterations: per-minibatch loss and its running average.
minibatch_losses = model.diagnostics['losses']
average_losses = model.diagnostics['avg_losses']

fig, ax = plt.subplots()
ax.plot(range(len(minibatch_losses)), minibatch_losses,
        color='b', linestyle='-', label='Minibatch loss')
ax.plot(range(len(average_losses)), average_losses,
        color='c', linestyle='-.', label='Avg. loss', linewidth=2)
ax.set_xlabel("iteration")
ax.set_ylabel("Avg. cross entropy")
ax.set_title("ResNet50")
ax.legend()
plt.show()

# Validation performance across iterations: one curve per onset level, with
# the i-th recorded value placed at iteration i * val_step.
fig, ax = plt.subplots()
for diag_key, line_color, line_style, line_label in (('val_roc1', 'r', '-', 'onset 1'),
                                                     ('val_roc2', 'g', '-.', 'onset 2')):
    y = model.diagnostics[diag_key]
    x = np.multiply(network_config['val_step'], list(range(len(y))))
    ax.plot(x, y, color=line_color, linestyle=line_style, label=line_label)
ax.set_xlabel("iteration")
ax.set_ylabel("ROC-AUC on validation set")
ax.set_title("ResNet50")
ax.legend()
plt.show()

diagnostics2save = model.diagnostics

# Now, save the diagnostic results
RESULTS_DIR = './RESULTS/'

# open(..., 'wb') does not create missing directories. Create the folder up
# front so the diagnostics of a finished training run are not lost to an
# IOError/FileNotFoundError on a fresh checkout.
if not os.path.isdir(RESULTS_DIR):
    os.makedirs(RESULTS_DIR)

# The file-name key records the normalisation variant and the residual-block
# layout selected in network_config; RESULTS_DIR and key are reused by the
# result-saving code of the later cells.
norm_tag = 'BatchReNorm_' if network_config['batch_renorm'] else 'BatchNorm_'
block_tag = 'Preactivation' if network_config['preactivation'] else 'Original'
key = 'ResNet50_' + norm_tag + block_tag

result_file_name = RESULTS_DIR + key + '_DIAG.pkl'
with open(result_file_name, 'wb') as filehandler:
    pickle.dump(diagnostics2save, filehandler)

conv1
	[5, 5], 64 /2
conv2/1
	[1, 1], 64 /1
	[3, 3], 64 /1
	[1, 1], 256 /1
conv2/2
	[1, 1], 64 /1
	[3, 3], 64 /1
	[1, 1], 256 /1
conv2/3
	[1, 1], 64 /1
	[3, 3], 64 /1
	[1, 1], 256 /1
conv3/1
	[1, 1], 128 /2
	[3, 3], 128 /1
	[1, 1], 512 /1
conv3/2
	[1, 1], 128 /1
	[3, 3], 128 /1
	[1, 1], 512 /1
conv3/3
	[1, 1], 128 /1
	[3, 3], 128 /1
	[1, 1], 512 /1
conv3/4
	[1, 1], 128 /1
	[3, 3], 128 /1
	[1, 1], 512 /1
conv4/1
	[1, 1], 256 /2
	[3, 3], 256 /1
	[1, 1], 1024 /1
conv4/2
	[1, 1], 256 /1
	[3, 3], 256 /1
	[1, 1], 1024 /1
conv4/3
	[1, 1], 256 /1
	[3, 3], 256 /1
	[1, 1], 1024 /1
conv4/4
	[1, 1], 256 /1
	[3, 3], 256 /1
	[1, 1], 1024 /1
conv4/5
	[1, 1], 256 /1
	[3, 3], 256 /1
	[1, 1], 1024 /1
conv4/6
	[1, 1], 256 /1
	[3, 3], 256 /1
	[1, 1], 1024 /1
conv5/1
	[1, 1], 512 /2
	[3, 3], 512 /1
	[1, 1], 2048 /1
conv5/2
	[1, 1], 512 /1
	[3, 3], 512 /1
	[1, 1], 2048 /1
conv5/3
	[1, 1], 512 /1
	[3, 3], 512 /1
	[1, 1], 2048 /1
fc1
	[4096, 1024]
logits
	[1024, 5]
Training KaggleDR_ResNet50 ...
Iter 0/200000

In [None]:
from models.ResNet50 import ResNet50
from utils.Reader import AdvancedReader

from sklearn.metrics import roc_curve, auc

import numpy as np
import tensorflow as tf

import pickle

# Rebuild the network in a fresh default graph, restore its saved variables and
# run plain inference on every data split.
tf.reset_default_graph()
with tf.Session() as sess:
    model = ResNet50(
        instance_shape=network_config['instance_shape'],
        num_classes=network_config['num_classes'],
        name='KaggleDR_ResNet50',
    )
    model.build(
        conv_stack_depths=network_config['conv_depths'],
        num_filters=network_config['num_filters'],
        fc_depths=network_config['fc_depths'],
        _lambda=network_config['_lambda'],
        learning_rate=network_config['lr'],
        decay_steps=network_config['decay_steps'],
        decay_rate=network_config['decay_rate'],
        data_aug=network_config['data_aug'],
        use_batch_renorm=network_config['batch_renorm'],
        preactivation=network_config['preactivation'],
    )

    # No init. or training. Just restore the variables for model.
    model.saver.restore(sess, model.model_path)
    model.session = sess

    # Pieces shared by the four evaluation passes below.
    banner = '=======================================================\nEvaluating the performance on '
    solution_csv = '/gpfs01/berens/user/mayhan/kaggle_dr_data/retinopathy_solution.csv'
    inference_kwargs = dict(batch_size=network_config['batch_size'],
                            quick_dirty=network_config['quick_dirty_val'])

    print(banner + 'TRAINING set')
    # valtest to read all from the training set
    dr = AdvancedReader(source=train_source,
                        csv_file='/gpfs01/berens/user/mayhan/kaggle_dr_data/trainLabels.csv',
                        mode='valtest')
    (labels_1hot_tr, predictions_1hot_tr,
     roc_auc_tr_onset1, roc_auc_tr_onset2) = model.inference(dr, **inference_kwargs)

    print(banner + 'VALIDATION set')
    dr = AdvancedReader(source=test_source, csv_file=solution_csv, mode='val')
    (labels_1hot_val, predictions_1hot_val,
     roc_auc_val_onset1, roc_auc_val_onset2) = model.inference(dr, **inference_kwargs)

    print(banner + 'TEST set')
    dr = AdvancedReader(source=test_source, csv_file=solution_csv, mode='test')
    (labels_1hot_te, predictions_1hot_te,
     roc_auc_te_onset1, roc_auc_te_onset2) = model.inference(dr, **inference_kwargs)

    print(banner + 'VAL and TEST sets combined')
    dr = AdvancedReader(source=test_source, csv_file=solution_csv, mode='valtest')
    (labels_1hot_valte, predictions_1hot_valte,
     roc_auc_valte_onset1, roc_auc_valte_onset2) = model.inference(dr, **inference_kwargs)
        
#### Now, save the results
# Labels and predictions of every split, gathered under one dictionary and
# pickled next to the training diagnostics.
result = {
    'train_labels_1hot': labels_1hot_tr,
    'val_labels_1hot': labels_1hot_val,
    'test_labels_1hot': labels_1hot_te,
    'valtest_labels_1hot': labels_1hot_valte,
    'train_pred_1hot': predictions_1hot_tr,
    'val_pred_1hot': predictions_1hot_val,
    'test_pred_1hot': predictions_1hot_te,
    'valtest_pred_1hot': predictions_1hot_valte,
}

result_file_name = RESULTS_DIR + key + '.pkl'
with open(result_file_name, 'wb') as filehandler:
    pickle.dump(result, filehandler)
    

In [None]:
##################################################
### Test-time data augmentation for predictive uncertainty estimation
##################################################

from models.ResNet50 import ResNet50
from utils.Reader import AdvancedReader
from utils.DataAugmentation import data_augmentation

from sklearn.metrics import roc_curve, auc

import numpy as np
import tensorflow as tf

import pickle


# Onset level and image folders for the test-time augmentation experiment.
onset = -1
train_source, test_source = (
    '/gpfs01/berens/user/mayhan/kaggle_dr_data/train_JF_BG_512/',
    '/gpfs01/berens/user/mayhan/kaggle_dr_data/test_JF_BG_512/',
)


def feed_dict(model, x_batch, y_batch, _iter=1, max_iter=1):
    """Assemble the feed dictionary for one non-training run of ``model``.

    The clipping values ``rmin``, ``rmax`` and ``dmax`` are derived from the
    progress ``_iter / max_iter``; with the default arguments the progress
    is 1.0.

    Args:
        model: object whose ``inputs``, ``labels``, ``is_training``, ``rmin``,
            ``rmax`` and ``dmax`` attributes serve as the dictionary keys.
        x_batch: value fed for ``model.inputs``.
        y_batch: value fed for ``model.labels``.
        _iter: current iteration number.
        max_iter: total number of iterations; must be non-zero.

    Returns:
        dict: the six keys above mapped to their values. ``is_training`` is
        always ``False`` and each clipping value is wrapped in a one-element
        list.
    """
    fraction_done = float(_iter) / float(max_iter)

    # During the first 5% the limits stay neutral (BatchNorm alone);
    # afterwards the clipping values grow with the progress.
    warmup = fraction_done < 0.05
    r_ceiling = 1. if warmup else np.exp(2. * fraction_done)  # 1.5
    d_ceiling = 0. if warmup else np.exp(2.5 * fraction_done) - 1  # 2.

    # The lower bound mirrors the upper one, except in the last 5% where it
    # is dropped to zero.
    if fraction_done > 0.95:
        r_floor = 0.
    else:
        r_floor = 1. / r_ceiling

    return {model.inputs: x_batch,
            model.labels: y_batch,
            model.is_training: False,
            model.rmin: [r_floor],
            model.rmax: [r_ceiling],
            model.dmax: [d_ceiling]}

def inference_with_test_time_data_aug(reader, model, test_input, test_input_aug, T=32, k=-1):
    """Run inference with test-time data augmentation and print summary metrics.

    Each example drawn from ``reader`` is passed T times through the
    ``test_input_aug`` op; the T results form one batch that is evaluated by
    ``model``. Accuracy and the ROC-AUC for onset levels 1 and 2 are printed,
    both computed from the median over the T predictions of every example.

    Depends on the module-level ``feed_dict`` helper and on ``roc_curve`` /
    ``auc`` being in scope.

    Args:
        reader: data source exposing ``exhausted_test_cases`` and
            ``next_batch(batch_size, normalize, shuffle)``; the latter must
            return a 3-tuple whose first two items are the example and its
            label.
        model: model exposing ``session``, ``predictions_1hot`` and
            ``labels_1hot``, plus whatever ``feed_dict`` reads from it.
        test_input: key under which the original example is fed to the
            augmentation op.
        test_input_aug: op evaluated once per augmented copy.
        T: number of augmented copies per example.
        k: early-stop threshold for a quick evaluation. It is decremented after
            every example and, on reaching 0, ``reader.exhausted_test_cases``
            is set to True. The default -1 disables the early stop.

    Returns:
        tuple: ``(labels_1hot, predictions_all)`` -- the first row of
        ``model.labels_1hot`` for every example, and the stacked
        ``model.predictions_1hot`` outputs (one T-row block per example).
    """
    labels_all = []
    predictions_all = []

    num_processed = 0

    while not reader.exhausted_test_cases:
        org_ex, label, _ = reader.next_batch(batch_size=1, normalize=True, shuffle=False)

        # Evaluate the augmentation op T times on the same original example.
        feed_img = {test_input: org_ex}
        images = [np.squeeze(model.session.run([test_input_aug], feed_dict=feed_img))
                  for _ in range(T)]
        repeated_labels = [label] * T

        # NOTE(review): the 512x512x3 image shape is hard-coded here and has to
        # match the shape of the placeholder created by the caller.
        x_batch = np.reshape(np.asarray(images, dtype=np.float32), [-1, 512, 512, 3])
        y_batch = np.reshape(np.asarray(repeated_labels, dtype=np.float32), [-1, 1])

        predictions, batch_labels = model.session.run(
            [model.predictions_1hot, model.labels_1hot],
            feed_dict=feed_dict(model, x_batch, y_batch))
        # All T rows were fed the same label, so the first row is enough.
        labels_all.append(batch_labels[0])
        predictions_all.append(predictions)

        if k != -1:
            k = k - 1
            if k == 0:
                # Stop via the reader that was passed in. The original code
                # wrote to a global named `dr`, which only worked when the
                # caller's reader happened to be bound to that global name.
                reader.exhausted_test_cases = True
            print('k = %d' % k)
        num_processed = num_processed + 1
        if num_processed % 1000 == 0:
            print('kkk = %d' % num_processed)
    print('kkk = %d' % num_processed)

    # Convert from a list of M items of size TxC to an array of dims MxTxC
    # (C = number of classes). For labels_1hot: MxC.
    labels_1hot = np.asarray(labels_all)
    predictions_all = np.asarray(predictions_all)

    # Use the median of the T predictions for the final class membership: MxC.
    predictions_1hot_median = np.median(predictions_all, axis=1)

    true_classes = np.argmax(labels_1hot, axis=1)
    correct = np.equal(true_classes, np.argmax(predictions_1hot_median, axis=1))
    acc = np.mean(np.asarray(correct, dtype=np.float32))
    print('Accuracy : %.5f' % acc)

    # Binary "class >= onset_level" task for each onset level: sum the
    # probabilities of the classes at or above the onset, then take the median
    # over the T augmented copies.
    for onset_level in (1, 2):
        labels_bin = np.greater_equal(true_classes, onset_level)
        pred_bin = np.sum(predictions_all[:, :, onset_level:], axis=2)  # MxT
        pred_bin_median = np.median(pred_bin, axis=1)  # M
        fpr, tpr, _ = roc_curve(labels_bin, np.squeeze(pred_bin_median))
        roc_auc = auc(fpr, tpr)
        print('Onset level = %d\t ROC-AUC: %.5f' % (onset_level, roc_auc))

    return labels_1hot, predictions_all
    
    
# Now, reset the graph and create a new session in which to run the model
tf.reset_default_graph()
with tf.Session() as sess:
    model = ResNet50(
        instance_shape=network_config['instance_shape'],
        num_classes=network_config['num_classes'],
        name='KaggleDR_ResNet50',
    )
    model.build(
        conv_stack_depths=network_config['conv_depths'],
        num_filters=network_config['num_filters'],
        fc_depths=network_config['fc_depths'],
        _lambda=network_config['_lambda'],
        learning_rate=network_config['lr'],
        decay_steps=network_config['decay_steps'],
        decay_rate=network_config['decay_rate'],
        data_aug=network_config['data_aug'],
        use_batch_renorm=network_config['batch_renorm'],
        preactivation=network_config['preactivation'],
    )

    # No init. or training. Just restore the variables for model.
    model.saver.restore(sess, model.model_path)
    model.session = sess

    # Now, set up the data augmentation components of the graph to be used during test time.
    # These components do not belong to the model.
    test_input = tf.placeholder(dtype=tf.float32, shape=[None, 512, 512, 3], name='test_input')
    test_input_aug = data_augmentation(test_input)
    # end of test-time augmentation subnetwork

    T = 4  # number of MC samples
    k = 10  # early stop threshold for inference (for a quick evaluation)

    # Pieces shared by the four evaluation passes below.
    banner = '=======================================================\nEvaluating the performance on '
    solution_csv = '/gpfs01/berens/user/mayhan/kaggle_dr_data/retinopathy_solution.csv'
    ttaug_kwargs = dict(model=model,
                        test_input=test_input,
                        test_input_aug=test_input_aug,
                        T=T, k=k)

    print(banner + 'TRAINING set')
    # valtest to read all from the training set
    dr = AdvancedReader(source=train_source,
                        csv_file='/gpfs01/berens/user/mayhan/kaggle_dr_data/trainLabels.csv',
                        mode='valtest')
    labels_1hot_tr, predictions_1hot_tr_ttaug = inference_with_test_time_data_aug(
        reader=dr, **ttaug_kwargs)

    print(banner + 'VALIDATION set')
    dr = AdvancedReader(source=test_source, csv_file=solution_csv, mode='val')
    labels_1hot_val, predictions_1hot_val_ttaug = inference_with_test_time_data_aug(
        reader=dr, **ttaug_kwargs)

    print(banner + 'TEST set')
    dr = AdvancedReader(source=test_source, csv_file=solution_csv, mode='test')
    labels_1hot_te, predictions_1hot_te_ttaug = inference_with_test_time_data_aug(
        reader=dr, **ttaug_kwargs)

    print(banner + 'VAL and TEST sets combined')
    dr = AdvancedReader(source=test_source, csv_file=solution_csv, mode='valtest')
    labels_1hot_valte, predictions_1hot_valte_ttaug = inference_with_test_time_data_aug(
        reader=dr, **ttaug_kwargs)
  
    
#### Now, save the results
# Labels and test-time-augmentation predictions of every split, gathered under
# one dictionary and pickled with the '_TTAUG' suffix.
result_ttaug = {
    'train_labels_1hot': labels_1hot_tr,
    'val_labels_1hot': labels_1hot_val,
    'test_labels_1hot': labels_1hot_te,
    'valtest_labels_1hot': labels_1hot_valte,
    'train_pred_1hot': predictions_1hot_tr_ttaug,
    'val_pred_1hot': predictions_1hot_val_ttaug,
    'test_pred_1hot': predictions_1hot_te_ttaug,
    'valtest_pred_1hot': predictions_1hot_valte_ttaug,
}

result_file_name = RESULTS_DIR + key + '_TTAUG.pkl'
with open(result_file_name, 'wb') as filehandler:
    pickle.dump(result_ttaug, filehandler)