In [2]:
import tensorflow as tf
import numpy as np
import datetime
import os
import glob

In [34]:
# CSV file that receives one row of per-label test AUCs after every epoch.
# Despite the _DIR suffix this is a file path, not a directory.
MAIN_AUC_FILE_DIR = '/content/drive/My Drive/DenseNetTransferPALA/logs/aucperepoch.csv'

# Saved single-view baseline checkpoints (.h5 files) loaded by create_main_model().
FRONT_BASE_DIR = '/content/drive/Shared drives/CMB - corpora/Chest_x-ray_report_Jan2020/Guide_MIMIC_OGR_Baseline_BS32/DenseNetPA(LA)Baseline/model-001-0.141840-0.143799.h5'
SIDE_BASE_DIR = '/content/drive/Shared drives/CMB - corpora/Chest_x-ray_report_Jan2020/Guide_MIMIC_OGR_Baseline_BS32/DenseNet(PA)LABaseline/model-004-0.115540-0.148441.h5'

# MODEL_CP_DIR is used as a *prefix* for checkpoint file names (see set_callback),
# MODEL_LOG_DIR is the TensorBoard log directory.
# NOTE(review): MODEL_LOG_DIR contains a doubled '/' after "My Drive" - presumably a typo; confirm.
MODEL_CP_DIR = '/content/drive/My Drive/DenseNetTransferPALA/'
MODEL_LOG_DIR = '/content/drive/My Drive//DenseNetTransferPALA/logs'

# NOTE(review): MODEL_TYPES is not referenced anywhere in the visible part of the notebook.
MODEL_TYPES = ['full', 'frontal', 'combined', 'side']

# Training hyper-parameters.
INIT_LR = 1e-4  # initial Adam learning rate
BS=32  # batch size of the training / validation pipelines
WK=5  # value passed as `workers` to fit() / predict()
EPOCH=50  # maximum number of epochs passed to fit()

# Batch size of the test pipelines that feed the AUC evaluation.
AUC_BS = 512

## **Data Pipeline**

### Image Augmentation

In [3]:
import math
from tensorflow.keras import backend as K

def get_mat(rotation,height_zoom=1,width_zoom=1):
    """Return the 3x3 matrix `rotation_matrix @ zoom_matrix`.

    Args:
        rotation: rotation angle in degrees (converted to radians here).
            The concat/reshape below needs it to hold exactly one element;
            the caller in this notebook passes a shape-[1] float32 tensor.
        height_zoom: divisor applied to the first diagonal entry of the zoom matrix.
        width_zoom: divisor applied to the second diagonal entry of the zoom matrix.

    Returns:
        A [3, 3] float32 tensor.
    """
    radians = math.pi * rotation / 180.
    cos_r = tf.math.cos(radians)
    sin_r = tf.math.sin(radians)
    one = tf.constant([1],dtype='float32')
    zero = tf.constant([0],dtype='float32')

    # Both matrices are listed row by row and then folded into shape [3, 3].
    rotation_entries = [
        cos_r, sin_r, zero,
        -sin_r, cos_r, zero,
        zero, zero, one,
    ]
    zoom_entries = [
        one/height_zoom, zero, zero,
        zero, one/width_zoom, zero,
        zero, zero, one,
    ]
    rotation_matrix = tf.reshape(tf.concat(rotation_entries, axis=0), [3,3])
    zoom_matrix = tf.reshape(tf.concat(zoom_entries, axis=0), [3,3])
    return K.dot(rotation_matrix, zoom_matrix)

In [4]:
def transform(image):
    """Randomly rotate an image by re-sampling its pixels through a rotation matrix.

    The angle (in degrees) is drawn as 10 * a standard-normal sample. For every
    destination pixel the source coordinate is computed with the matrix from
    get_mat(), cast to int32 (no interpolation) and clipped to the image bounds.

    Args:
        image: tensor indexed as image[row, col] with 3 values per pixel; the
            index arithmetic assumes it is 224x224 (DIM is hard-coded).

    Returns:
        Tensor of shape [224, 224, 3].
    """
    DIM = 224
    # Parity correction used in the clip bound below; 0 for the even DIM used here.
    XDIM = DIM%2  
    rot = 10. * tf.random.normal([1],dtype='float32')
    m = get_mat(rot) 
    # Destination grid in centred coordinates: x runs DIM//2 .. -DIM//2+1 (each value
    # repeated DIM times), y runs -DIM//2 .. DIM//2-1 (tiled DIM times), z is the
    # homogeneous coordinate.
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    # Map every destination coordinate through the matrix to find its source coordinate.
    idx2 = K.dot(m,tf.cast(idx,dtype='float32'))
    idx2 = K.cast(idx2,dtype='int32')
    # Clip so that the array indices computed on the next line stay inside 0 .. DIM-1.
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)
    # Convert centred coordinates back to (row, col) array indices.
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image,tf.transpose(idx3))
    return tf.reshape(d,[DIM,DIM,3])

In [5]:
def augment_img(image):
  """Training-time augmentation: resize, random brightness, random crop, random flip, random rotation.

  The image is resized to 264x264, its brightness is shifted by a random delta
  (max_delta 0.4), a random 224x224x3 window is cropped, it is randomly mirrored
  left/right and finally passed through transform().
  """
  resized = tf.image.resize(image, [264,264])
  brightened = tf.image.random_brightness(resized, 0.4)
  cropped = tf.image.random_crop(brightened, [224,224,3])
  flipped = tf.image.random_flip_left_right(cropped)
  return transform(flipped)

### TFRecord

In [6]:
def augment_img_wot(image):
  """Deterministic resize + crop ("without transform").

  Resizes to 264x264 and cuts the 224x224 window whose top-left corner is at
  (20, 20), i.e. the centre of the resized image.
  """
  resized = tf.image.resize(image, [264,264])
  return tf.image.crop_to_bounding_box(resized, 20,20, 224,224)

In [7]:
def read_tfrecord(example, output_mode='both', with_transform=False):
    """Parse one serialized TFRecord example into the tuple requested by `output_mode`.

    Labels are reshaped to 14 values; NaN and -1 entries are replaced by 0.
    The last element of every returned tuple is the `select` flag.

    Args:
        example: serialized example (scalar string tensor).
        output_mode: one of
            'label'          -> (label, select)
            'view_label'     -> (label, view, select)
            'combine_label'  -> ((label, label, label), select)
            'both'           -> ([front, side], label, select)
            'img_view'       -> ([front, side], view, select)
            'both_view'      -> ([front, side], label, view, select)
            anything else    -> ([front, side], select)
        with_transform: if True the images go through augment_img (random
            augmentation), otherwise through augment_img_wot.
    """
    feature_spec = {
        "image_front": tf.io.FixedLenFeature([], tf.string),
        "image_side": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.VarLenFeature(tf.float32),
        "view": tf.io.FixedLenFeature([], tf.string),
        "select": tf.io.FixedLenFeature([], tf.int64),
        "study_id": tf.io.FixedLenFeature([], tf.int64)
    }
    parsed = tf.io.parse_single_example(example, feature_spec)

    # Densify the label, then zero out NaN and -1 entries.
    label = tf.reshape(tf.sparse.to_dense(parsed['label']), [1, 14])
    label = tf.where(tf.math.is_nan(label), tf.zeros_like(label), label)
    label = tf.where(label == -1, tf.zeros_like(label), label)
    label = label[0]
    select = parsed['select']

    # Label-only modes return before any JPEG decoding happens.
    if output_mode == 'label':
        return label, select
    if output_mode == 'view_label':
        return label, parsed['view'], select
    if output_mode == 'combine_label':
        return (label, label, label), select

    image_front = tf.image.decode_jpeg(parsed['image_front'], channels=3)
    image_side = tf.image.decode_jpeg(parsed['image_side'], channels=3)

    resize_fn = augment_img if with_transform else augment_img_wot
    image_front = resize_fn(image_front)
    image_side = resize_fn(image_side)

    # Kept as a Python list (not a tuple) exactly like the original return values.
    images = [image_front, image_side]
    if output_mode == 'both':
        return images, label, select
    if output_mode == 'img_view':
        return images, parsed['view'], select
    if output_mode == 'both_view':
        return images, label, parsed['view'], select
    return images, select

In [8]:
# tf.data options with deterministic ordering switched off; applied to the
# shuffled training pipelines via .with_options(option_no_order).
option_no_order = tf.data.Options()
option_no_order.experimental_deterministic = False

### Data loader

In [9]:
def select_image(*record):
  """Filter predicate: true when the last field of the record (the `select` flag) equals 1."""
  return record[-1] == 1

def remove_select(*record):
  """Return the record as a tuple without its last field (the `select` flag)."""
  without_flag = record[:-1]
  return without_flag

In [10]:
def select_view_cat(view, *record):
  """Filter predicate: true when the last field of the record equals `view`."""
  return record[-1] == view

def remove_view(*record):
  """Return the record as a tuple without its last field (the view category)."""
  without_view = record[:-1]
  return without_view

In [11]:
def select_view(view, *record):
  """Pick the requested image(s) out of a record.

  Only the first positional item after `view` is used; it must be a sequence
  whose element 0 holds the image pair indexed as [frontal, side] and whose
  optional element 1 is the label.

  Args:
    view: 'frontal' -> first image, 'side' -> second image, anything else ->
      the whole image pair.

  Returns:
    The selected image(s) alone when the record has a single element,
    otherwise the tuple (image(s), label).
  """
  sample = record[0]
  images = sample[0]
  if view == 'frontal':
    chosen = images[0]
  elif view == 'side':
    chosen = images[1]
  else:
    chosen = images
  return chosen if len(sample) == 1 else (chosen, sample[1])

In [12]:
def preprocess_image(image):
  """Cast to float32, divide by 255 and normalise the last axis with fixed mean/std constants.

  The three mean/std values are broadcast against the last (channel) axis.
  """
  image_net_mean = np.array([0.485, 0.456, 0.406])
  image_net_std = np.array([0.229, 0.224, 0.225])
  scaled = tf.cast(image, tf.float32) / 255.
  centred = tf.math.subtract(scaled, image_net_mean)
  return tf.math.divide(centred, image_net_std)
def preprocess_image_dataset(image, label=None, replicate_label=False):
  """Normalise a batch of images and shape it (and its label) for the model.

  A rank-5 batch is split along axis 1 into a 2-tuple (index 0, index 1) so
  it can feed a two-input model.

  Args:
    image: batched image tensor.
    label: optional label batch; when omitted only the image(s) are returned.
    replicate_label: when truthy the label is returned three times as a tuple.
  """
  normalized = preprocess_image(image)
  if len(normalized.shape) == 5:
    normalized = (normalized[:,0], normalized[:,1])
  if label is None:
    return normalized
  targets = (label, label, label) if replicate_label else label
  return normalized, targets

In [13]:
def choose_view(ds, view):
  """Keep the records whose last field equals `view`, then drop that field."""
  matching = ds.filter(lambda *record: select_view_cat(view, *record))
  return matching.map(remove_view, num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [14]:
def choose_selected(ds):
  """Keep the records whose trailing `select` flag equals 1, then drop the flag."""
  selected = ds.filter(select_image)
  return selected.map(remove_select, num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [15]:
# Training shards, read in parallel.
filenames = tf.io.gfile.glob('/content/drive/My Drive/AG-CNN/tfrecord_train_mul_nr/*.tfrec')
train_dsr = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.experimental.AUTOTUNE)
# Shuffle the still-serialized records (buffer of 300000) and allow non-deterministic order.
train_dsrl = train_dsr.shuffle(300000).with_options(option_no_order)
# Parse in 'both' mode with random augmentation -> ([front, side], label, select).
train_dsrl = train_dsrl.map(lambda record : read_tfrecord(record, 'both', True), num_parallel_calls=tf.data.experimental.AUTOTUNE)
# Keep only records whose `select` flag is 1, then drop the flag.
train_dsrs = choose_selected(train_dsrl)

In [16]:
# Validation shards, read in parallel.
filenames = tf.io.gfile.glob('/content/drive/My Drive/AG-CNN/tfrecord_val_mul_nr/*.tfrec')
val_dsr = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.experimental.AUTOTUNE)
# read_tfrecord defaults apply here: output_mode='both', with_transform=False (no random augmentation).
val_dsrl = val_dsr.map(read_tfrecord, num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Keep only records whose `select` flag is 1, then drop the flag.
val_dsrs = choose_selected(val_dsrl)

In [17]:
# Re-binds train_dsrl to a freshly shuffled view of the raw training records.
train_dsrl = train_dsr.shuffle(300000).with_options(option_no_order)

# 'both_view' mode -> ([front, side], label, view, select).
# NOTE(review): with_transform is left at its default (False), so this training
# pipeline gets no random augmentation - confirm that this is intended.
train_dsrs_view = train_dsrl.map(lambda record : read_tfrecord(record, 'both_view'), num_parallel_calls=tf.data.experimental.AUTOTUNE)
val_dsrs_view = val_dsr.map(lambda record : read_tfrecord(record, 'both_view'), num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Drop records whose `select` flag is not 1 (and the flag itself).
train_dsrs_view = choose_selected(train_dsrs_view)
val_dsrs_view = choose_selected(val_dsrs_view)

# Restrict to the 'pa_la' view category, batch, normalise and prefetch.
# These are the datasets passed to main_model.fit().
train_dsrs_view = choose_view(train_dsrs_view, 'pa_la').batch(BS).map(preprocess_image_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE).prefetch(tf.data.experimental.AUTOTUNE)
val_dsrs_view = choose_view(val_dsrs_view, 'pa_la').batch(BS).map(preprocess_image_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE).prefetch(tf.data.experimental.AUTOTUNE)

In [18]:
# Test shards, read in parallel (no shuffling).
filenames = tf.io.gfile.glob('/content/drive/My Drive/AG-CNN/tfrecord_test_mul_nr/*.tfrec')
test_dsr = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.experimental.AUTOTUNE)
# Default mode 'both' -> ([front, side], label, select).
test_dsrl = test_dsr.map(read_tfrecord, num_parallel_calls=tf.data.experimental.AUTOTUNE)

# 'img' is not an explicitly handled mode; it falls through to read_tfrecord's
# final branch and yields ([front, side], select).
test_dsrs_img = test_dsr.map(lambda record : read_tfrecord(record, 'img'), num_parallel_calls=tf.data.experimental.AUTOTUNE)
# Label only, and label replicated three times.
test_dsrs_lab = test_dsr.map(lambda record : read_tfrecord(record, 'label'), num_parallel_calls=tf.data.experimental.AUTOTUNE)
test_dsrs_lab3 = test_dsr.map(lambda record : read_tfrecord(record, 'combine_label'), num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Keep only records whose `select` flag is 1, then drop the flag.
test_dsrs = choose_selected(test_dsrl)
test_dsrs_img = choose_selected(test_dsrs_img)
test_dsrs_lab = choose_selected(test_dsrs_lab)
test_dsrs_lab3 = choose_selected(test_dsrs_lab3)

In [19]:
# NOTE(review): these two lines repeat the test-shard loading of the previous cell.
filenames = tf.io.gfile.glob('/content/drive/My Drive/AG-CNN/tfrecord_test_mul_nr/*.tfrec')
test_dsr = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.experimental.AUTOTUNE)

# Image side: ([front, side], label, view, select); label side: (label, view, select).
test_dsrs_img_view = test_dsr.map(lambda record : read_tfrecord(record, 'both_view'), num_parallel_calls=tf.data.experimental.AUTOTUNE)
test_dsrs_lab_view = test_dsr.map(lambda record : read_tfrecord(record, 'view_label'), num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Drop records whose `select` flag is not 1 (and the flag itself).
test_dsrs_img_view = choose_selected(test_dsrs_img_view)
test_dsrs_lab_view = choose_selected(test_dsrs_lab_view)

# Restrict both to the 'pa_la' view. The image pipeline is batched with AUC_BS and
# normalised; the label pipeline stays unbatched and yields 1-tuples (label,).
test_dsrs_img_view = choose_view(test_dsrs_img_view, 'pa_la').batch(AUC_BS).map(preprocess_image_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE).prefetch(tf.data.experimental.AUTOTUNE)
test_dsrs_lab_view = choose_view(test_dsrs_lab_view, 'pa_la').prefetch(tf.data.experimental.AUTOTUNE)

In [20]:
# Batched (image pair, label) records from the augmented training set.
combine_train_ds = train_dsrs.map(lambda *record : select_view('both', record)).batch(BS)

# Two inputs (frontal, side), one label.
combine_train_dsr = combine_train_ds.map(preprocess_image_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE).prefetch(tf.data.experimental.AUTOTUNE)
# Two inputs (frontal, side), label replicated three times.
combine_full_train_dsr = combine_train_ds.map(lambda img, lab : preprocess_image_dataset(img, lab, True), num_parallel_calls=tf.data.experimental.AUTOTUNE).prefetch(tf.data.experimental.AUTOTUNE)
# Single-view variants: frontal image only / side image only, each with its label.
frontal_train_dsr = train_dsrs.map(lambda *record : select_view('frontal',record)).batch(BS).map(preprocess_image_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE).prefetch(tf.data.experimental.AUTOTUNE)
side_train_dsr = train_dsrs.map(lambda *record : select_view( 'side',record)).batch(BS).map(preprocess_image_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE).prefetch(tf.data.experimental.AUTOTUNE)

In [21]:
# Batched (image pair, label) records from the validation set.
combine_val_ds = val_dsrs.map(lambda *record : select_view('both', record)).batch(BS)

# Two inputs with one label, then two inputs with the label replicated three times.
combine_val_dsr = combine_val_ds.map(preprocess_image_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE).prefetch(tf.data.experimental.AUTOTUNE)
combine_full_val_dsr = combine_val_ds.map(lambda img, lab : preprocess_image_dataset(img, lab, True), num_parallel_calls=tf.data.experimental.AUTOTUNE).prefetch(tf.data.experimental.AUTOTUNE)
# Single-view variants: frontal image only / side image only, each with its label.
frontal_val_dsr = val_dsrs.map(lambda *record : select_view('frontal',record)).batch(BS).map(preprocess_image_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE).prefetch(tf.data.experimental.AUTOTUNE)
side_val_dsr = val_dsrs.map(lambda *record : select_view( 'side', record)).batch(BS).map(preprocess_image_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE).prefetch(tf.data.experimental.AUTOTUNE)

In [22]:
# Test pipelines batched with AUC_BS. Despite the `_img` suffix each element still
# carries its label, because test_dsrs records are (images, label).
combine_test_dsr_img = test_dsrs.map(lambda *record : select_view( 'both', record)).batch(AUC_BS).map(preprocess_image_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE).prefetch(tf.data.experimental.AUTOTUNE)
frontal_test_dsr_img = test_dsrs.map(lambda *record : select_view( 'frontal',  record)).batch(AUC_BS).map(preprocess_image_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE).prefetch(tf.data.experimental.AUTOTUNE)
side_test_dsr_img = test_dsrs.map(lambda *record : select_view('side', record)).batch(AUC_BS).map(preprocess_image_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE).prefetch(tf.data.experimental.AUTOTUNE)

# Unbatched label-only pipelines (triple label / single label).
test_dsr_lab3 = test_dsrs_lab3.prefetch(tf.data.experimental.AUTOTUNE)
test_dsr_lab = test_dsrs_lab.prefetch(tf.data.experimental.AUTOTUNE)
# Labels extracted from the (image, label) records of a specific view.
frontal_test_dsr_lab = test_dsrs.map(lambda *record : select_view( 'frontal', record)).map(lambda img, lab: lab, num_parallel_calls=tf.data.experimental.AUTOTUNE).prefetch(tf.data.experimental.AUTOTUNE)
side_test_dsr_lab  = test_dsrs.map(lambda *record : select_view('side',record)).map(lambda img, lab: lab, num_parallel_calls=tf.data.experimental.AUTOTUNE).prefetch(tf.data.experimental.AUTOTUNE)

### Data counter

In [23]:
# Label/view-only pass over the training shards, used to count positives per view.
filenames = tf.io.gfile.glob('/content/drive/My Drive/AG-CNN/tfrecord_train_mul_nr/*.tfrec')
dsr = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.experimental.AUTOTUNE).with_options(option_no_order)
# 'view_label' mode -> (label, view, select); no image decoding takes place.
dsr = dsr.map(lambda record : read_tfrecord(record, "view_label"), num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Keep only records whose `select` flag is 1, then drop the flag -> (label, view).
dsrs = dsr.filter(select_image).map(remove_select, num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Materialises the whole filtered dataset as a Python list.
# NOTE(review): the saved output of this cell is a KeyboardInterrupt, so the
# cells below that read ds_set may not have run against a complete list.
ds_set = list(dsrs)

KeyboardInterrupt: ignored

In [None]:
# One bucket of label vectors per view category; filled by the loop in the next cell.
ap_ll = []
ap_la = []
pa_ll = []
pa_la = []

In [None]:
# Append every label vector to the bucket of its view category; records whose
# view is none of the four known categories are skipped.
labels_by_view = {'ap_ll': ap_ll, 'ap_la': ap_la, 'pa_ll': pa_ll, 'pa_la': pa_la}
for lab, view in ds_set:
  view = view.numpy().decode('UTF-8')
  if view in labels_by_view:
    labels_by_view[view].append(lab.numpy())

In [None]:
# Number of selected training records with view 'ap_la'.
len(ap_la)

In [None]:
# Number of selected training records with view 'ap_ll'.
len(ap_ll)

In [None]:
# Number of selected training records with view 'pa_la'.
len(pa_la)

In [None]:
# Number of selected training records with view 'pa_ll'.
len(pa_ll)

In [None]:
# View bucket analysed by the following cells.
dataset = pa_la

In [None]:
# Stack the per-record label vectors into one array with a row per record.
dataset = np.stack(dataset)
# read_tfrecord returns rank-1 labels (it takes label[0]), so the stacked array
# is already 2-D and an unconditional dataset[:,0,:] raises IndexError.
# Only drop the middle axis when the labels really carry one.
if dataset.ndim == 3:
  dataset = dataset[:,0,:]

In [None]:
# Number of entries equal to 1 in each of the 14 label columns.
counts = [np.count_nonzero(dataset[:,i] == 1) for i in range(14)]

In [None]:
# Display the per-label positive counts for the chosen view.
counts

[3036, 2640, 533, 1140, 277, 674, 971, 4852, 24677, 2513, 257, 2043, 266, 956]

## **Model creation**

### Model

In [24]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Dropout, Flatten, GlobalAveragePooling2D, concatenate, Lambda, Reshape
from tensorflow.keras.optimizers import Adam

In [25]:
def fix_model(model):
  """Freeze `model` in place by setting `trainable = False` on each of its layers."""
  for frozen_layer in model.layers:
    setattr(frozen_layer, 'trainable', False)

In [26]:
def create_main_model():
  """Build the two-input model on top of the frozen frontal and side baselines.

  Both checkpoints (FRONT_BASE_DIR, SIDE_BASE_DIR) are loaded and all their
  layers are frozen. The outputs of each model's second-to-last layer are
  concatenated and fed to a new 14-unit sigmoid layer named 'combined_output',
  which is the only part left trainable.

  Returns:
    The compiled Keras model (binary cross-entropy, Adam at INIT_LR,
    multi-label AUC metric) taking [frontal input, side input].
  """
  front_model = tf.keras.models.load_model(FRONT_BASE_DIR)
  side_model = tf.keras.models.load_model(SIDE_BASE_DIR)

  for base_model in (front_model, side_model):
    fix_model(base_model)

  # Replace the last five characters of every side-model layer name by "side".
  # Presumably this keeps layer names unique once both models share one graph
  # - TODO confirm against the checkpoint's layer names.
  for side_layer in side_model.layers:
    side_layer._name = side_layer.name[:-5] + "side"

  merged_features = concatenate(
      [front_model.layers[-2].output, side_model.layers[-2].output], axis=-1)
  combined_output = Dense(14, activation="sigmoid", name='combined_output')(merged_features)

  model = Model(inputs=[front_model.input, side_model.input], outputs=combined_output)
  model.compile(loss="binary_crossentropy",
                optimizer=Adam(learning_rate=INIT_LR),
                metrics=[tf.keras.metrics.AUC(multi_label=True)])
  return model

In [35]:
# Loads both baseline checkpoints and builds the compiled two-input model.
main_model = create_main_model()

In [None]:
# Print the layer table of the combined model.
main_model.summary()

### Callbacks

In [29]:
from sklearn.metrics import roc_auc_score
import csv

class AUCCallback(tf.keras.callbacks.Callback):
    """Keras callback that appends per-label ROC AUCs to a CSV file after every epoch.

    Args:
        log_dir: path of the CSV file to append to (a file, despite the name).
        img: input passed unchanged to `model.predict` at the end of each epoch.
        lab: array of ground-truth labels indexed as lab[:, j] for the 14 labels,
            in the same order as the predictions.
    """
    def __init__(self, log_dir, img, lab):
      # Let the Keras base class initialise its own attributes first.
      super().__init__()
      self.img_set = img
      self.lab_set = lab
      self.log_dir = log_dir

    def on_epoch_end(self, epoch, logs=None):
      """Predict on the stored inputs and append one CSV row of 14 AUC values.

      `logs` defaults to None instead of a shared mutable dict; it is not read here.
      """
      pred = self.model.predict(self.img_set, use_multiprocessing=True, workers=WK, verbose=1)
      rocs = np.zeros(14)
      for j in range(14):
        rocs[j] = roc_auc_score(self.lab_set[:,j], pred[:,j])

      # newline='' is what the csv module expects for files handed to csv.writer.
      with open(self.log_dir, 'a', newline='') as fp:
        writer = csv.writer(fp, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
        writer.writerow(rocs)
      return

In [30]:
def create_auc_logger(logdir, auc_test_img, auc_test_lab):
  """Materialise the label dataset and wrap everything in an AUCCallback.

  Args:
    logdir: CSV path handed to AUCCallback.
    auc_test_img: inputs the callback predicts on.
    auc_test_lab: iterable of labels. After stacking, index 0 of the second
      axis is taken, so every element must carry a leading singleton axis
      (e.g. the 1-tuples produced by choose_view on a label pipeline).
  """
  stacked_labels = np.array(tf.stack(list(auc_test_lab)))
  return AUCCallback(logdir, auc_test_img, stacked_labels[:,0,:])

In [31]:
from tensorflow.keras.callbacks import ReduceLROnPlateau,EarlyStopping, ModelCheckpoint, TensorBoard, LearningRateScheduler

def set_callback(tblogdir, checkpoint_address=None, auc_log_dir=None, auc_test_img=None, auc_test_lab=None):
  """Assemble the list of Keras callbacks used for training.

  Always included: ReduceLROnPlateau on val_loss and a TensorBoard logger.

  Args:
    tblogdir: TensorBoard log directory.
    checkpoint_address: optional *prefix* for checkpoint files; the file name
      pattern is appended to it verbatim (no path separator is inserted).
      When given, EarlyStopping and ModelCheckpoint are added.
    auc_log_dir: optional CSV path; when given an AUC logger built from
      `auc_test_img` / `auc_test_lab` is appended.
    auc_test_img: inputs for the AUC logger.
    auc_test_lab: labels for the AUC logger.

  Returns:
    List of callback instances.
  """
  callbacks = [
      ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=1, min_lr=1e-6),
      TensorBoard(tblogdir, histogram_freq=1),
  ]

  if checkpoint_address:
    checkpoint_pattern = checkpoint_address + 'model-{epoch:03d}-{loss:03f}-{val_loss:03f}.h5'
    callbacks.append(EarlyStopping(patience=3, verbose=1))
    callbacks.append(ModelCheckpoint(checkpoint_pattern, save_best_only=True, monitor='val_loss'))

  if auc_log_dir:
    callbacks.append(create_auc_logger(auc_log_dir, auc_test_img, auc_test_lab))
  return callbacks

## **Training Process**

### Training

In [36]:
# TensorBoard + LR schedule + early stopping/checkpointing + per-epoch AUC logging
# on the 'pa_la' test pipelines. Building the AUC logger iterates the whole
# label dataset once.
callbacks = set_callback(MODEL_LOG_DIR, MODEL_CP_DIR, MAIN_AUC_FILE_DIR,test_dsrs_img_view,test_dsrs_lab_view)

In [37]:
# Train the combined head on the 'pa_la' pipelines for at most EPOCH epochs;
# the callbacks may stop the run earlier (the saved output shows early stopping at epoch 4).
history = main_model.fit(
        train_dsrs_view,
        epochs=EPOCH, 
        validation_data= val_dsrs_view,
        callbacks=callbacks,
        use_multiprocessing=True,
        workers=WK,
        # initial_epoch=8
    )

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 00004: early stopping


In [None]:
# Load a previously saved checkpoint for stand-alone evaluation.
# NOTE(review): this path is outside MODEL_CP_DIR, so it is presumably from a different run - confirm.
model = tf.keras.models.load_model("/content/drive/My Drive/OGR+/DenseNetFix/model_full/fullmodel-007-0.175719-0.182741.h5")

In [None]:
# Keep the second of the three values returned by separate_model().
# NOTE(review): separate_model is not defined anywhere in the visible notebook;
# on a fresh kernel this cell raises NameError unless it is defined elsewhere.
_, combine_model, _ = separate_model(model)

In [None]:
# Materialise the 'pa_la' test labels. The dataset yields 1-tuples (label,), so the
# stacked array has a singleton middle axis that is removed by the [:,0,:] index.
test_label = list(test_dsrs_lab_view)
test_label = np.array(tf.stack(test_label))
test_label = test_label[:,0,:]

In [None]:
# Predict on the 'pa_la' test images and compute one ROC AUC per label column.
pred = combine_model.predict(test_dsrs_img_view, verbose=1)
rocs = np.zeros(14)
for j in range(14):
  rocs[j] = roc_auc_score(test_label[:,j], pred[:,j])



In [None]:
# Append the 14 AUC values as one CSV row.
# newline='' is what the csv module expects for files handed to csv.writer.
with open('/content/drive/My Drive/test.csv', 'a', newline='') as fp:
  writer = csv.writer(fp, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
  writer.writerow(rocs)