In [33]:
import tensorflow as tf
import numpy as np
import os
import tqdm
import glob

In [34]:
# Output locations on Google Drive: checkpoint filename template (filled in by
# ModelCheckpoint), TensorBoard log root, and the CSV that AUCCallback appends to.
MODEL_CP_DIR = '/content/drive/My Drive/AGCNNMergeChest14/DenseNetComAPPA/model-{epoch:03d}-{loss:03f}-{val_loss:03f}.h5'
MODEL_LOG_DIR = '/content/drive/My Drive/AGCNNMergeChest14/DenseNetComAPPA/logs'
MODEL_AUC_LOG_DIR = MODEL_LOG_DIR + '/aucperepoch.csv'

# Pretrained single-branch checkpoints (in this notebook only referenced by
# commented-out loading code).
GLOBAL_BASE_MODEL_DIR = '/content/drive/Shared drives/CMB - corpora/Chest_x-ray_report_Jan2020/Guide_MIMIC_OGR_Baseline_BS32/DenseNetPA(LA)Baseline/model2-004-0.109718-0.144615.h5'
LOCAL_BASE_MODEL_DIR = '/content/drive/My Drive/AGCNNLocal/DenseNetPALA/model2-010-0.143988-0.155678.h5'

# Suffix appended to the TensorBoard run directory name.
MODEL_ID = ''

# Optimisation settings.
INIT_LR = 0.0001   # Adam learning rate
BS = 32            # training / validation batch size
WK = 5             # `workers` argument passed to fit / predict
EPOCH = 20         # maximum number of training epochs

# Threshold on the normalised attention map used to crop the local-branch input.
CROP_THRESHOLD = 0.7

# Batch size of the test-set pipeline used for AUC evaluation.
AUC_BS = 512

## Data Pipeline

### Image Augmentation

In [35]:
import math
from tensorflow.keras import backend as K

def get_mat(rotation,height_zoom=1,width_zoom=1):
    """Build the 3x3 homogeneous transform ``R @ Z`` used by `transform`.

    Args:
      rotation: rotation angle in degrees, as a float32 tensor of shape [1].
      height_zoom: zoom factor along the first axis (its reciprocal enters Z).
      width_zoom: zoom factor along the second axis (its reciprocal enters Z).

    Returns:
      A [3, 3] float32 tensor.
    """
    def _mat3(*entries):
        # Concatenate nine length-1 tensors row-major into a 3x3 matrix.
        return tf.reshape(tf.concat(list(entries), axis=0), [3, 3])

    theta = math.pi * rotation / 180.
    cos_t = tf.math.cos(theta)
    sin_t = tf.math.sin(theta)
    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')

    rot = _mat3(cos_t, sin_t, zero,
                -sin_t, cos_t, zero,
                zero, zero, one)
    zoom = _mat3(one / height_zoom, zero, zero,
                 zero, one / width_zoom, zero,
                 zero, zero, one)
    return K.dot(rot, zoom)

In [36]:
def transform(image):
    """Rotate a 224x224x3 image by a random angle (10 * N(0, 1) degrees).

    Each output pixel's centred coordinate is mapped through the `get_mat`
    matrix to a source coordinate. That coordinate is cast to int32, clipped
    to the image bounds, and the pixel is gathered from `image`.
    """
    size = 224
    odd = size % 2
    angle = 10. * tf.random.normal([1], dtype='float32')
    matrix = get_mat(angle)

    # Destination grid in centred homogeneous (row, col, 1) coordinates.
    rows = tf.repeat(tf.range(size // 2, -size // 2, -1), size)
    cols = tf.tile(tf.range(-size // 2, size // 2), [size])
    ones = tf.ones([size * size], dtype='int32')
    grid = tf.cast(tf.stack([rows, cols, ones]), dtype='float32')

    # Source coordinates: cast to int32, then clamp inside the image.
    src = K.clip(K.cast(K.dot(matrix, grid), dtype='int32'), -size // 2 + odd + 1, size // 2)

    # Centred coordinates back to array indices (row, col).
    gather_idx = tf.transpose(tf.stack([size // 2 - src[0], size // 2 - 1 + src[1]]))
    return tf.reshape(tf.gather_nd(image, gather_idx), [size, size, 3])

In [37]:
def augment_img(image, max_delta=0.4):
  """Random training-time augmentation for one decoded image.

  The image is resized to 264x264 and its brightness shifted. A random 224x224
  crop is then taken, flipped left-right at random, and rotated by `transform`.

  Args:
    image: HxWx3 image with pixel values on the 0-255 scale (from
      `tf.image.decode_jpeg`; `tf.image.resize` keeps that scale).
    max_delta: maximum brightness shift as a fraction of the full pixel range,
      i.e. in the [0, 1) units `tf.image.random_brightness` documents.

  Returns:
    A 224x224x3 float32 tensor, still on the 0-255 scale.
  """
  image = tf.image.resize(image, [264,264])
  # Pixels are still 0-255 here, so passing 0.4 directly would shift them by
  # under 0.2% of their range (a no-op). Scale the delta to the pixel range and
  # clip so the image stays a valid 0-255 image before normalisation.
  image = tf.image.random_brightness(image, max_delta * 255.)
  image = tf.clip_by_value(image, 0., 255.)
  image = tf.image.random_crop(image, [224,224,3])
  image = tf.image.random_flip_left_right(image)
  image = transform(image)
  return image

### TFRecord

In [38]:
def augment_img_wot(image):
  """Deterministic counterpart of `augment_img`: resize to 264x264, keep the central 224x224."""
  resized = tf.image.resize(image, [264, 264])
  margin = (264 - 224) // 2  # 20 px on every side -> centred crop
  return tf.image.crop_to_bounding_box(resized, margin, margin, 224, 224)

In [39]:
def read_tfrecord(example, output_mode='both', with_transform=False):
    """Parse one serialized tf.train.Example into image and/or label tensors.

    Args:
      example: serialized record with 'image' (encoded JPEG bytes), 'label'
        (float list, reshaped to 14 values) and 'image_index' features.
      output_mode: 'label' -> label only (the image is never decoded);
        'img' -> image only; any other value -> (image, label).
      with_transform: use the random `augment_img` when True, otherwise the
        deterministic `augment_img_wot`.
    """
    schema = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.VarLenFeature(tf.float32),
        "image_index": tf.io.FixedLenFeature([], tf.string)
    }
    parsed = tf.io.parse_single_example(example, schema)
    # Length-14 dense label vector. Stored values are passed through unchanged
    # (no remapping of NaN or -1 entries).
    label = tf.reshape(tf.sparse.to_dense(parsed['label']), [14])

    if output_mode == 'label':
        return label

    augment = augment_img if with_transform else augment_img_wot
    image = augment(tf.image.decode_jpeg(parsed['image'], channels=3))

    if output_mode == 'img':
        return image
    return image, label

In [40]:
# tf.data options that allow elements to be produced out of order for throughput.
# In this notebook they are attached only to the (already shuffled) training pipeline.
option_no_order = tf.data.Options()
if hasattr(option_no_order, 'deterministic'):
  # Newer TF name; `experimental_deterministic` is deprecated in its favour.
  option_no_order.deterministic = False
else:
  option_no_order.experimental_deterministic = False

### Image Selection and Preprocessing

In [41]:
def preprocess_image(image):
  """Scale pixels from 0-255 to [0, 1], then standardise per channel with ImageNet mean/std."""
  channel_mean = np.array([0.485, 0.456, 0.406])
  channel_std = np.array([0.229, 0.224, 0.225])
  unit_range = tf.cast(image, tf.float32) / 255.
  return tf.math.divide(tf.math.subtract(unit_range, channel_mean), channel_std)
  
def preprocess_image_dataset(image, label=None, replicate_label=False):
  """Dataset `map` helper: normalise images and pass labels through.

  A rank-5 image batch is split along axis 1 into a pair of rank-4 tensors.
  With `replicate_label` the label is returned as a 3-tuple of itself.
  Without a label, only the (possibly split) image is returned.
  """
  processed = preprocess_image(image)
  if len(processed.shape) == 5:
    processed = (processed[:, 0], processed[:, 1])
  if label is None:
    return processed
  targets = (label, label, label) if replicate_label else label
  return processed, targets

### Data Loader

In [42]:
# Training split: PA then AP records are concatenated, shuffled at record level,
# decoded with random augmentation, batched and normalised.
filenames = tf.io.gfile.glob('/content/drive/Shared drives/CMB - corpora/Chest_x-ray_report_Jan2020/Chest Xray14 Tfrecord/train_pa/*.tfrec')
train_dsr = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.experimental.AUTOTUNE)

filenames = tf.io.gfile.glob('/content/drive/Shared drives/CMB - corpora/Chest_x-ray_report_Jan2020/Chest Xray14 Tfrecord/train_ap/*.tfrec')
train_dsr_ap = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.experimental.AUTOTUNE)
train_dsr = train_dsr.concatenate(train_dsr_ap)

train_dsrl = (
    train_dsr
    .shuffle(300000)
    .with_options(option_no_order)
    .map(lambda record : read_tfrecord(record, 'both', True), num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(BS)
    .map(preprocess_image_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [43]:
# Validation split: PA then AP records, deterministic resize/centre-crop, batched and normalised.
filenames = tf.io.gfile.glob('/content/drive/Shared drives/CMB - corpora/Chest_x-ray_report_Jan2020/Chest Xray14 Tfrecord/val_pa/*.tfrec')
val_dsr = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.experimental.AUTOTUNE)

filenames = tf.io.gfile.glob('/content/drive/Shared drives/CMB - corpora/Chest_x-ray_report_Jan2020/Chest Xray14 Tfrecord/val_ap/*.tfrec')
val_dsr_ap = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.experimental.AUTOTUNE)
val_dsr = val_dsr.concatenate(val_dsr_ap)

val_dsrl = (
    val_dsr
    .map(read_tfrecord, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(BS)
    .map(preprocess_image_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [44]:
# Test split: PA then AP records. Images and labels are produced by two separate
# pipelines over the same records; both keep tf.data's default ordering.
filenames = tf.io.gfile.glob('/content/drive/Shared drives/CMB - corpora/Chest_x-ray_report_Jan2020/Chest Xray14 Tfrecord/test_pa/*.tfrec')
test_dsr = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.experimental.AUTOTUNE)

filenames = tf.io.gfile.glob('/content/drive/Shared drives/CMB - corpora/Chest_x-ray_report_Jan2020/Chest Xray14 Tfrecord/test_ap/*.tfrec')
test_dsr_ap = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.experimental.AUTOTUNE)
test_dsr = test_dsr.concatenate(test_dsr_ap)

test_dsrs_img = (
    test_dsr
    .map(lambda record : read_tfrecord(record, 'img'), num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(AUC_BS)
    .map(preprocess_image_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# Unbatched label vectors, one per record.
test_dsrs_lab = test_dsr.map(lambda record : read_tfrecord(record, 'label'),
                             num_parallel_calls=tf.data.experimental.AUTOTUNE)

## Model Construction

### Crop function

In [45]:
def find_last_layer(model):
  """Return the name of the last layer in `model.layers` whose output shape has length 4.

  The layer type is not checked: any layer with a rank-4 output qualifies.

  Raises:
    ValueError: if no layer has a rank-4 output.
  """
  not_found = object()
  match = next((layer.name for layer in model.layers[::-1]
                if len(layer.output_shape) == 4), not_found)
  if match is not_found:
    raise ValueError("Could not find the last convolution layer.")
  return match

In [46]:
@tf.function
def get_focus_area(imgs, last_conv_output, threshold = CROP_THRESHOLD):
  """Binary attention mask per image from a 4-D feature map.

  The heatmap is the channel-wise max of |activation| (reduced over the last
  axis). It is min-max normalised over each image's own spatial positions and
  thresholded.

  Args:
    imgs: unused; kept for interface compatibility.
    last_conv_output: rank-4 feature map, assumed (batch, height, width, channels)
      given the reduction over the last axis -- TODO confirm channels-last.
    threshold: cells whose normalised value is >= threshold are marked 1.

  Returns:
    float32 mask of shape (batch, height, width) holding 0. / 1.
  """
  heatmap = tf.math.reduce_max(tf.math.abs(last_conv_output), axis=-1)
  # Normalise per image over the spatial axes. Reducing over axis=0 (as before)
  # normalised across the batch instead: masks depended on which images shared a
  # batch, and a batch of one gave 0/0 = NaN.
  heat_min = tf.math.reduce_min(heatmap, axis=[1, 2], keepdims=True)
  heat_max = tf.math.reduce_max(heatmap, axis=[1, 2], keepdims=True)
  # A constant heatmap normalises to 0 instead of NaN.
  heat_norm = tf.math.divide_no_nan(heatmap - heat_min, heat_max - heat_min)
  return tf.cast(heat_norm >= threshold, tf.float32)

In [47]:
@tf.function
def generate_box(masked):
  """Normalised bounding box around the 1-cells of a 2-D mask.

  Args:
    masked: 2-D (height, width) float mask; cells equal to 1 are attended.

  Returns:
    float32 tensor [ymin, xmin, ymax, xmax] with coordinates in [0, 1], in the
    order `tf.image.crop_and_resize` expects. The last row/column index maps to
    1.0 (for a 7x7 mask this is /6, as before). An empty mask yields the full
    box [0, 0, 1, 1]. A single-row or single-column extent is widened to 0.3 of
    the image; near the far edge it is shifted back so it stays inside [0, 1].
  """
  positions = tf.cast(tf.where(tf.equal(masked, 1.)), dtype=tf.float32)  # (k, 2): [row, col]
  mask_shape = tf.shape(masked)
  row_scale = tf.cast(tf.maximum(mask_shape[0] - 1, 1), tf.float32)
  col_scale = tf.cast(tf.maximum(mask_shape[1] - 1, 1), tf.float32)

  def _widen(lo, hi):
    # Degenerate extent [lo, lo] -> [lo, lo + 0.3], or [0.7, 1] if that overflows 1.
    degenerate = tf.equal(lo, hi)
    overflow = lo + 0.3 > 1.
    new_lo = tf.where(overflow, 0.7, lo)
    new_hi = tf.where(overflow, 1., lo + 0.3)
    return tf.where(degenerate, new_lo, lo), tf.where(degenerate, new_hi, hi)

  def _tight_box():
    # Axis 0 of the mask is y (rows), axis 1 is x (columns). The previous code
    # swapped them, so crop_and_resize received a transposed box.
    ymin, ymax = _widen(tf.math.reduce_min(positions[:, 0]) / row_scale,
                        tf.math.reduce_max(positions[:, 0]) / row_scale)
    xmin, xmax = _widen(tf.math.reduce_min(positions[:, 1]) / col_scale,
                        tf.math.reduce_max(positions[:, 1]) / col_scale)
    return tf.stack([ymin, xmin, ymax, xmax])

  def _full_box():
    return tf.constant([0., 0., 1., 1.], dtype=tf.float32)

  # Under tf.function the number of matches has no static value, so the
  # emptiness test must be a graph-level cond (a Python `if` on .shape is dead).
  return tf.cond(tf.equal(tf.shape(positions)[0], 0), _full_box, _tight_box)

In [48]:
@tf.function
def generate_box_batch(data):
  """Crop every image in a batch to the area highlighted by its feature map.

  Args:
    data: pair (imgs, last_conv_output); imgs is the image batch and
      last_conv_output the matching feature maps passed to `get_focus_area`.

  Returns:
    The batch of crops produced by `crop_imgs` (one box per image).
  """
  imgs, last_conv_output = data
  masks = get_focus_area(imgs, last_conv_output)

  batch_size = tf.shape(imgs)[0]
  boxes = tf.TensorArray(tf.float32, size=batch_size, element_shape=[4])
  for i in tf.range(batch_size):
    # TensorArray.write returns the updated array. Inside tf.function the result
    # must be reassigned, otherwise the write is dropped.
    boxes = boxes.write(i, generate_box(masks[i]))
  boxes = boxes.stack()
  return crop_imgs(imgs, boxes, batch_size)

In [49]:
@tf.function
def crop_imgs(imgs, boxes, bs=BS, labels=None):
  """Crop box i from image i (for i < bs) and resize to 224x224 with nearest-neighbour.

  Returns the crops, or (crops, labels) when `labels` is given.
  """
  crops = tf.image.crop_and_resize(imgs, boxes, tf.range(bs), (224, 224), method='nearest')
  return crops if labels is None else (crops, labels)

### Model

In [50]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Dropout, Flatten, GlobalAveragePooling2D, Lambda, Concatenate
from tensorflow.keras.optimizers import Adam

In [19]:
# global_model = tf.keras.models.load_model(GLOBAL_BASE_MODEL_DIR)
# local_model = tf.keras.models.load_model(LOCAL_BASE_MODEL_DIR)

In [51]:
from tensorflow.keras.applications.densenet import DenseNet121

# Two independent ImageNet-initialised DenseNet121 trunks (no classifier head):
# one for the global image, one for the attention crop. Built in that order.
global_model, local_model = [
    DenseNet121(weights="imagenet", include_top=False,
                input_tensor=Input(shape=(224, 224, 3)))
    for _ in range(2)
]

In [52]:
def freeze_model(model):
  """Set `trainable = False` on every layer of `model` (mutates the layers in place)."""
  layer_list = list(model.layers)
  for lyr in layer_list:
    lyr.trainable = False

In [53]:
def createMergeModel(global_model, local_model):
  """Build and compile the two-branch (global + attention-cropped local) classifier.

  The global branch is `global_model` up to its last rank-4 layer. Its input
  and that layer's output feed `generate_box_batch`, which crops the input.
  The crops go through `local_model` truncated at its last rank-4 layer. Both
  branches are global-average-pooled, concatenated, and mapped to 14 sigmoid
  outputs. The model is compiled with binary cross-entropy, Adam(INIT_LR) and
  a multi-label AUC metric.
  """
  global_features = global_model.get_layer(find_last_layer(global_model)).output

  cropped = Lambda(generate_box_batch)([global_model.layers[0].output, global_features])

  local_features = local_model.get_layer(find_last_layer(local_model)).output
  local_trunk = Model(inputs=local_model.input, outputs=local_features)

  # Layers are created in the same order as before so auto-generated names match.
  local_vec = local_trunk(cropped)
  local_vec = GlobalAveragePooling2D()(local_vec)
  global_vec = GlobalAveragePooling2D()(global_features)

  merged = Concatenate(axis=-1)([global_vec, local_vec])
  probs = Dense(14, activation="sigmoid")(merged)

  model = Model(inputs=global_model.input, outputs=probs)
  model.compile(loss="binary_crossentropy",
                optimizer=Adam(learning_rate=INIT_LR),
                metrics=[tf.keras.metrics.AUC(multi_label=True)])
  return model

In [23]:
# freeze_model(global_model)
# freeze_model(local_model)

In [54]:
model = createMergeModel(global_model=global_model, local_model=local_model)

In [None]:
# Print the layer-by-layer summary of the merged model.
model.summary()

### Callbacks

In [55]:
from tensorflow.keras.callbacks import ReduceLROnPlateau,EarlyStopping, ModelCheckpoint, TensorBoard, LearningRateScheduler
import datetime
from sklearn.metrics import roc_auc_score
import csv

In [56]:
class AUCCallback(tf.keras.callbacks.Callback):
    """After every epoch, append per-class ROC AUCs on a fixed evaluation set to a CSV.

    Args:
      log_dir: path of the CSV file rows are appended to (a file path, despite the name).
      img: inputs passed to `model.predict`.
      lab: ground-truth label array indexed as `lab[:, j]` for class j, aligned
        row-for-row with the predictions on `img`.
    """

    def __init__(self, log_dir, img, lab):
        # Initialise the Keras base class state before adding our own attributes.
        super().__init__()
        self.img_set = img
        self.lab_set = lab
        self.log_dir = log_dir

    def on_epoch_end(self, epoch, logs=None):
        """Predict on the evaluation set and append one CSV row of per-class AUCs.

        A class whose labels contain a single value has no defined AUC
        (roc_auc_score raises ValueError). It is written as NaN rather than
        aborting training.
        """
        pred = self.model.predict(self.img_set, verbose=1, workers=WK, use_multiprocessing=True)
        n_classes = pred.shape[1]
        rocs = np.full(n_classes, np.nan)
        for j in range(n_classes):
            try:
                rocs[j] = roc_auc_score(self.lab_set[:, j], pred[:, j])
            except ValueError:
                pass  # undefined AUC for this class; keep NaN

        with open(self.log_dir, 'a', newline='') as fp:
            writer = csv.writer(fp, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
            writer.writerow(rocs)

In [57]:
def create_auc_logger(logdir, auc_test_img, auc_test_lab):
  """Materialise the (unbatched) label dataset into one array and wrap it in an AUCCallback."""
  stacked = tf.stack([lab for lab in auc_test_lab])
  return AUCCallback(logdir, auc_test_img, np.array(stacked))

## Training Process

In [58]:
# One TensorBoard run directory per day (YYYYMMDD) plus the optional MODEL_ID suffix.
tblogdir = os.path.join(
    MODEL_LOG_DIR,
    ''.join([datetime.datetime.now().strftime("%Y%m%d"), '-model', MODEL_ID]))
tensorboard_callback = TensorBoard(log_dir=tblogdir, histogram_freq=1)

In [59]:
auc_callback = create_auc_logger(logdir=MODEL_AUC_LOG_DIR,
                                 auc_test_img=test_dsrs_img,
                                 auc_test_lab=test_dsrs_lab)

In [60]:
callbacks = [
    # Cut LR tenfold after one epoch without val_loss improvement (floor 1e-6).
    ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=1, min_lr=1e-6),
    # Stop after three epochs without val_loss improvement ('val_loss' is the default monitor).
    EarlyStopping(monitor='val_loss', patience=3, verbose=1),
    # Save weights only, and only when val_loss improves.
    ModelCheckpoint(MODEL_CP_DIR, monitor='val_loss', save_best_only=True, save_weights_only=True),
    tensorboard_callback,
    auc_callback,
]

In [None]:
# model.load_weights('/content/drive/My Drive/AGCNNMerge/DenseNet/model-007-0.176370-0.181600.h5')

In [61]:
# Per the Keras docs, `workers` / `use_multiprocessing` only apply to generator or
# Sequence inputs; they are passed through unchanged here.
history = model.fit(
    train_dsrl,
    validation_data=val_dsrl,
    epochs=EPOCH,
    callbacks=callbacks,
    workers=WK,
    use_multiprocessing=True,
)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 00011: early stopping


In [None]:
# Ground-truth labels for the AP-view test records (`test_dsr_ap`). The previous
# version read an undefined `test_dsrs_lab_view`, which only worked via hidden kernel
# state. NOTE(review): "AP-only" is inferred from the "APauc" CSV name used when
# these labels are scored -- confirm this is the intended subset.
test_dsrs_lab_view = test_dsr_ap.map(lambda record : read_tfrecord(record, 'label'),
                                     num_parallel_calls=tf.data.experimental.AUTOTUNE)
test_label = np.array(tf.stack(list(test_dsrs_lab_view)))
# read_tfrecord yields length-14 vectors, so the stack is (N, 14). The old
# `[:, 0, :]` only works for (N, 1, 14); reshaping accepts either layout.
test_label = test_label.reshape(-1, 14)

In [None]:
# AP-view test images from `test_dsr_ap`, preprocessed like `test_dsrs_img`. The
# previous version read an undefined `test_dsrs_img_view` (hidden kernel state).
# `test_label` must be built from the same records in the same (default,
# deterministic) tf.data order for rows to line up.
test_dsrs_img_view = (
    test_dsr_ap
    .map(lambda record : read_tfrecord(record, 'img'), num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(AUC_BS)
    .map(preprocess_image_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

pred = model.predict(test_dsrs_img_view, verbose=1, workers=WK, use_multiprocessing=True)
assert len(pred) == len(test_label), "predictions and labels are not aligned"

# Per-class ROC AUC. A finding with a single label value in this subset has no
# defined AUC (roc_auc_score raises ValueError) and is recorded as NaN.
rocs = np.full(14, np.nan)
for j in range(14):
  try:
    rocs[j] = roc_auc_score(test_label[:, j], pred[:, j])
  except ValueError:
    pass

with open(os.path.join(MODEL_LOG_DIR, "APauc{}.csv".format(MODEL_ID)), 'a', newline='') as fp:
  writer = csv.writer(fp, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
  writer.writerow(rocs)

