In [1]:
import tensorflow as tf
import numpy as np
import os
import tqdm
import glob

In [2]:
# Checkpoint filename template handed to ModelCheckpoint; the {epoch}/{loss}/{val_loss}
# placeholders are formatted per saved epoch.
MODEL_CP_DIR = '/content/drive/My Drive/AGCNNMulti/DenseNetComPALADouble1/model-{epoch:03d}-{loss:03f}-{val_loss:03f}.h5'
# Root directory for TensorBoard runs and the final AUC csv.
MODEL_LOG_DIR ='/content/drive/My Drive/AGCNNMulti/DenseNetComPALADouble1/logs'
# CSV file that AUCCallback appends one row of per-label AUCs to after every epoch.
MODEL_AUC_LOG_DIR = '/content/drive/My Drive/AGCNNMulti/DenseNetComPALADouble1/logs/aucperepoch.csv'

# NOTE(review): these two checkpoint paths are not referenced anywhere in the visible
# part of the notebook — confirm whether they are still needed.
FRONT_BASE_MODEL_DIR = '/content/drive/My Drive/AGCNNMerge/DenseNetComPALA/model-003-0.124592-0.139646.h5'
SIDE_BASE_MODEL_DIR = '/content/drive/My Drive/AGCNNMerge/DenseNetCom(PA)LA/model-003-0.128735-0.144051.h5'

# Suffix appended to the TensorBoard run directory name.
MODEL_ID=''

INIT_LR = 1e-4              # learning rate given to Adam
BS=32                       # batch size for the training and validation datasets
WK=5                        # `workers` argument passed to fit()/predict()
EPOCH=20                    # upper bound on training epochs (EarlyStopping may stop sooner)
SELECTED_VIEW = ['pa_la']   # view categories kept by choose_view()

AUC_BS = 128                # batch size for the test dataset used for AUC evaluation

## Data Pipeline

### Image Augmentation

In [3]:
import math
from tensorflow.keras import backend as K

def get_mat(rotation,height_zoom=1,width_zoom=1):
    """Return the 3x3 matrix `rotation_matrix @ zoom_matrix`.

    Args:
      rotation: rotation angle in degrees. It is concatenated with length-1
        tensors below, so it is expected to be a float32 tensor of shape [1].
      height_zoom: divisor applied on the first diagonal entry of the zoom matrix.
      width_zoom: divisor applied on the second diagonal entry of the zoom matrix.

    Returns:
      A float32 tensor of shape [3, 3].
    """
    radians = math.pi * rotation / 180.
    cos_r = tf.math.cos(radians)
    sin_r = tf.math.sin(radians)
    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')

    # Both matrices are written out row by row and then reshaped to 3x3.
    rotation_entries = [cos_r, sin_r, zero,
                        -sin_r, cos_r, zero,
                        zero, zero, one]
    zoom_entries = [one / height_zoom, zero, zero,
                    zero, one / width_zoom, zero,
                    zero, zero, one]

    rotation_matrix = tf.reshape(tf.concat(rotation_entries, axis=0), [3, 3])
    zoom_matrix = tf.reshape(tf.concat(zoom_entries, axis=0), [3, 3])
    return K.dot(rotation_matrix, zoom_matrix)

In [4]:
def transform(image):
    """Rotate `image` by a random angle using nearest-pixel index remapping.

    The angle is drawn as 10 * N(0, 1) and passed to get_mat(), which treats it
    as degrees. Every destination pixel is mapped through the rotation matrix
    to a source pixel, out-of-range coordinates are clipped to the border, and
    the source pixels are gathered.

    Args:
      image: image tensor indexed as [row, col, channel]; the final reshape
        assumes it is 224x224 with 3 channels — TODO confirm at call sites.

    Returns:
      Tensor of shape [224, 224, 3].
    """
    DIM = 224
    # Parity correction used in the clip bound below (0 for an even DIM).
    XDIM = DIM%2  
    rot = 10. * tf.random.normal([1],dtype='float32')
    m = get_mat(rot) 
    # Destination pixel grid in origin-centred coordinates, flattened to
    # DIM*DIM entries, with a row of ones so the 3x3 matrix can be applied.
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    # Map destination coordinates to source coordinates; the cast to int32
    # truncates to whole pixels.
    idx2 = K.dot(m,tf.cast(idx,dtype='float32'))
    idx2 = K.cast(idx2,dtype='int32')
    # Keep the source coordinates inside the image.
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)
    # Shift back from centred coordinates to array indices.
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image,tf.transpose(idx3))
    return tf.reshape(d,[DIM,DIM,3])

In [5]:
def augment_img(image):
  """Apply the random training-time augmentation pipeline to one image.

  Steps, in order: resize to 264x264, random brightness shift, random
  224x224x3 crop, random horizontal flip, random rotation via transform().

  NOTE(review): the image reaching random_brightness appears to be float data
  still on the 0-255 scale (it comes from decode_jpeg followed by resize), so
  a maximum delta of 0.4 is a very small change — confirm whether a larger
  delta was intended.
  """
  resized = tf.image.resize(image, [264,264])
  brightened = tf.image.random_brightness(resized, 0.4)
  cropped = tf.image.random_crop(brightened, [224,224,3])
  flipped = tf.image.random_flip_left_right(cropped)
  return transform(flipped)

### TFRecord

In [6]:
def augment_img_wot(image):
  """Deterministic preprocessing ("without transform"): resize, then fixed crop.

  The image is resized to 264x264 and the 224x224 window whose top-left corner
  is at (20, 20) is returned.
  """
  resized = tf.image.resize(image, [264,264])
  return tf.image.crop_to_bounding_box(
      resized,
      offset_height=20,
      offset_width=20,
      target_height=224,
      target_width=224,
  )

In [7]:
def read_tfrecord(example, output_mode='both', with_transform=False):
    """Parse one serialized example into the tuple requested by `output_mode`.

    Labels are reshaped to 14 values; NaN and -1 entries are replaced by 0.
    The `select` flag is always the last element of the returned tuple.

    Args:
      example: serialized tf.train.Example.
      output_mode: selects the returned fields:
        'label'         -> (label, select)
        'view_label'    -> (label, view, select)
        'combine_label' -> ((label, label, label), select)
        'both'          -> ([front, side], label, select)
        'img_view'      -> ([front, side], view, select)
        'both_view'     -> ([front, side], label, view, select)
        anything else   -> ([front, side], select)
      with_transform: if True images go through augment_img (random),
        otherwise through augment_img_wot (deterministic).
    """
    feature_spec = {
        "image_front": tf.io.FixedLenFeature([], tf.string),
        "image_side": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.VarLenFeature(tf.float32),
        "view": tf.io.FixedLenFeature([], tf.string),
        "select": tf.io.FixedLenFeature([], tf.int64),
        "study_id": tf.io.FixedLenFeature([], tf.int64)
    }
    parsed = tf.io.parse_single_example(example, feature_spec)

    # Densify the label vector and zero out NaN and -1 entries.
    label = tf.reshape(tf.sparse.to_dense(parsed['label']), [1, 14])
    label = tf.where(tf.math.is_nan(label), tf.zeros_like(label), label)
    label = tf.where(label == -1, tf.zeros_like(label), label)
    label = label[0]

    select = parsed['select']

    # Modes that need no image decoding return early.
    if output_mode == 'label':
        return label, select
    if output_mode == 'view_label':
        return label, parsed['view'], select
    if output_mode == 'combine_label':
        return (label, label, label), select

    image_front = tf.image.decode_jpeg(parsed['image_front'], channels=3)
    image_side = tf.image.decode_jpeg(parsed['image_side'], channels=3)

    prepare = augment_img if with_transform else augment_img_wot
    image_front = prepare(image_front)
    image_side = prepare(image_side)

    if output_mode == 'both':
        return [image_front, image_side], label, select
    if output_mode == 'img_view':
        return [image_front, image_side], parsed['view'], select
    if output_mode == 'both_view':
        return [image_front, image_side], label, parsed['view'], select
    return [image_front, image_side], select

In [8]:
# Dataset options that switch off deterministic element ordering; applied to the
# shuffled training pipelines via with_options().
option_no_order = tf.data.Options()
option_no_order.experimental_deterministic = False

### Image Selection and Preprocessing

In [9]:
def select_image(*record):
  """Filter predicate: true when the last field of the record (the select flag) equals 1."""
  return record[-1] == 1

def remove_select(*record):
  """Return the record without its last field (the select flag)."""
  last = len(record) - 1
  return record[:last]

In [10]:
def select_view_cat(view, *record):
  """Filter predicate on the record's last field (its view category).

  Args:
    view: a single category, or a list of accepted categories.
    *record: record fields; the last one is compared against `view`.

  Returns:
    The result of the equality comparison for a single category; for a list,
    whether any entry compared equal (False for an empty list).
  """
  observed = record[-1]
  if not isinstance(view, list):
    return observed == view
  found = False
  for candidate in view:
    if observed == candidate:
      found = True
  return found

def remove_view(*record):
  """Return the record without its last field (the view category)."""
  last = len(record) - 1
  return record[:last]

In [11]:
def select_view(view, *record):
  """Pick one image of the (front, side) pair out of a record.

  Args:
    view: 'frontal' selects element 0 of the pair, 'side' selects element 1,
      any other value keeps the pair as is.
    *record: only the first positional argument is used; it is the record
      itself, i.e. a sequence whose element 0 is the image pair and whose
      element 1 (if present) is returned alongside the image.

  Returns:
    The selected image alone when the record has exactly one field,
    otherwise the tuple (image, record[1]).
  """
  fields = record[0]
  pair = fields[0]
  if view == 'frontal':
    img = pair[0]
  elif view == 'side':
    img = pair[1]
  else:
    img = pair
  return img if len(fields) == 1 else (img, fields[1])

In [12]:
def preprocess_image(image):
  """Cast to float32, scale by 1/255 and standardise with per-channel mean/std.

  The mean and std constants are broadcast against the last (channel) axis.
  """
  image_net_mean = np.array([0.485, 0.456, 0.406])
  image_net_std = np.array([0.229, 0.224, 0.225])
  scaled = tf.cast(image, tf.float32) / 255.
  centred = tf.math.subtract(scaled, image_net_mean)
  return tf.math.divide(centred, image_net_std)
def preprocess_image_dataset(image, label=None, replicate_label=False):
  """Dataset-level wrapper around preprocess_image().

  Args:
    image: image tensor. When the normalised tensor has rank 5 it is split
      along axis 1 into a 2-tuple (index 0, index 1).
    label: optional label tensor; when None only the image(s) are returned.
    replicate_label: if True the label is returned three times as a tuple.

  Returns:
    The image(s) alone, or an (image(s), label) pair.
  """
  normalised = preprocess_image(image)
  has_view_axis = len(normalised.shape) ==5
  inputs = (normalised[:,0], normalised[:,1]) if has_view_axis else normalised
  if label is None:
    return inputs
  targets = (label, label, label) if replicate_label else label
  return inputs, targets

In [13]:
def choose_view(ds, view):
  """Keep records whose last field matches `view`, then drop that field."""
  def _has_requested_view(*record):
    return select_view_cat(view, *record)

  matching = ds.filter(_has_requested_view)
  return matching.map(remove_view, num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [14]:
def choose_selected(ds):
  """Keep records whose select flag (last field) is 1, then drop the flag."""
  flagged = ds.filter(select_image)
  return flagged.map(remove_select, num_parallel_calls=tf.data.experimental.AUTOTUNE)

### Data Loader

In [15]:
# Training records: read all shards in parallel, shuffle the serialized records
# (buffer of 300000), then parse them with random augmentation enabled.
filenames = tf.io.gfile.glob('/content/drive/My Drive/AG-CNN/tfrecord_train_mul_nr/*.tfrec')
train_dsr = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.experimental.AUTOTUNE)
train_dsrl = train_dsr.shuffle(300000).with_options(option_no_order)
train_dsrl = train_dsrl.map(lambda record : read_tfrecord(record, 'both', True), num_parallel_calls=tf.data.experimental.AUTOTUNE)
#  select only images that were randomly selected
train_dsrs = choose_selected(train_dsrl)

In [16]:
# Validation records: parsed with read_tfrecord defaults
# (output_mode='both', no random augmentation) and not shuffled.
filenames = tf.io.gfile.glob('/content/drive/My Drive/AG-CNN/tfrecord_val_mul_nr/*.tfrec')
val_dsr = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.experimental.AUTOTUNE)
val_dsrl = val_dsr.map(read_tfrecord, num_parallel_calls=tf.data.experimental.AUTOTUNE)

#  select only images that were randomly selected
val_dsrs = choose_selected(val_dsrl)

In [17]:
# View-filtered train/validation datasets; these are the ones passed to model.fit().
# NOTE(review): this rebinds train_dsrl with a shuffle buffer of 30000, whereas the
# earlier training cell uses 300000 — confirm which is intended.
train_dsrl = train_dsr.shuffle(30000).with_options(option_no_order)

# NOTE(review): read_tfrecord is called without with_transform here, so the training
# images go through augment_img_wot (no random augmentation) — confirm this is deliberate.
train_dsrs_view = train_dsrl.map(lambda record : read_tfrecord(record, 'both_view'), num_parallel_calls=tf.data.experimental.AUTOTUNE)
val_dsrs_view = val_dsr.map(lambda record : read_tfrecord(record, 'both_view'), num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Drop unselected records (and the select flag).
train_dsrs_view = choose_selected(train_dsrs_view)
val_dsrs_view = choose_selected(val_dsrs_view)

# Keep only SELECTED_VIEW records (and drop the view field), batch, normalise, prefetch.
train_dsrs_view = choose_view(train_dsrs_view, SELECTED_VIEW).batch(BS).map(preprocess_image_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE).prefetch(tf.data.experimental.AUTOTUNE)
val_dsrs_view = choose_view(val_dsrs_view, SELECTED_VIEW).batch(BS).map(preprocess_image_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE).prefetch(tf.data.experimental.AUTOTUNE)

In [18]:
# Test records, parsed three ways: images+labels, images only, labels only.
filenames = tf.io.gfile.glob('/content/drive/My Drive/AG-CNN/tfrecord_test_mul_nr/*.tfrec')
test_dsr = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.experimental.AUTOTUNE)
test_dsrl = test_dsr.map(read_tfrecord, num_parallel_calls=tf.data.experimental.AUTOTUNE)

# 'img' is not one of the modes read_tfrecord names explicitly; it falls through to the
# final branch, which returns the image pair and the select flag.
test_dsrs_img = test_dsr.map(lambda record : read_tfrecord(record, 'img'), num_parallel_calls=tf.data.experimental.AUTOTUNE)
test_dsrs_lab = test_dsr.map(lambda record : read_tfrecord(record, 'label'), num_parallel_calls=tf.data.experimental.AUTOTUNE)

#  select only images that were randomly selected
test_dsrs = choose_selected(test_dsrl)
test_dsrs_img = choose_selected(test_dsrs_img)
test_dsrs_lab = choose_selected(test_dsrs_lab)

In [19]:
# View-filtered test datasets used for AUC evaluation.
# NOTE(review): filenames/test_dsr repeat the glob and dataset of the previous test cell.
filenames = tf.io.gfile.glob('/content/drive/My Drive/AG-CNN/tfrecord_test_mul_nr/*.tfrec')
test_dsr = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.experimental.AUTOTUNE)

# Images (with labels and view) and a labels-only stream, parsed from the same records.
test_dsrs_img_view = test_dsr.map(lambda record : read_tfrecord(record, 'both_view'), num_parallel_calls=tf.data.experimental.AUTOTUNE)
test_dsrs_lab_view = test_dsr.map(lambda record : read_tfrecord(record, 'view_label'), num_parallel_calls=tf.data.experimental.AUTOTUNE)

test_dsrs_img_view = choose_selected(test_dsrs_img_view)
test_dsrs_lab_view = choose_selected(test_dsrs_lab_view)

# The image stream is batched with AUC_BS and normalised; the label stream stays
# unbatched, so each element is a 1-tuple holding one label vector.
test_dsrs_img_view = choose_view(test_dsrs_img_view, SELECTED_VIEW).batch(AUC_BS).map(preprocess_image_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE).prefetch(tf.data.experimental.AUTOTUNE)
test_dsrs_lab_view = choose_view(test_dsrs_lab_view, SELECTED_VIEW).prefetch(tf.data.experimental.AUTOTUNE)

In [None]:
# Single-view training datasets: keep only the frontal (or side) image of each pair,
# then batch and normalise. The whole record tuple is passed to select_view as one argument.
frontal_train_dsr = train_dsrs.map(lambda *record : select_view('frontal',record)).batch(BS).map(preprocess_image_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE)
side_train_dsr = train_dsrs.map(lambda *record : select_view( 'side',record)).batch(BS).map(preprocess_image_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Single-view validation datasets, built the same way as the single-view training ones.
frontal_val_dsr = val_dsrs.map(lambda *record : select_view('frontal',record)).batch(BS).map(preprocess_image_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE)
side_val_dsr = val_dsrs.map(lambda *record : select_view( 'side', record)).batch(BS).map(preprocess_image_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Single-view test image datasets, batched with AUC_BS and normalised.
frontal_test_dsr_img = test_dsrs_img.map(lambda *record : select_view( 'frontal',  record)).batch(AUC_BS).map(preprocess_image_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE)
side_test_dsr_img = test_dsrs_img.map(lambda *record : select_view('side', record)).batch(AUC_BS).map(preprocess_image_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Unbatched label stream for the whole test set.
test_dsr_lab = test_dsrs_lab.prefetch(tf.data.experimental.AUTOTUNE)
# labels for specific view
# (select_view yields (image, label); the second map keeps only the label)
frontal_test_dsr_lab = test_dsrs.map(lambda *record : select_view( 'frontal', record)).map(lambda img, lab: lab, num_parallel_calls=tf.data.experimental.AUTOTUNE).prefetch(tf.data.experimental.AUTOTUNE)
side_test_dsr_lab  = test_dsrs.map(lambda *record : select_view('side',record)).map(lambda img, lab: lab, num_parallel_calls=tf.data.experimental.AUTOTUNE).prefetch(tf.data.experimental.AUTOTUNE)

## Model Construction

### Crop function

In [20]:
def find_last_layer(model):
  """Return the name of the last layer whose output shape has four dimensions.

  Layers are scanned from the end of `model.layers` towards the start.

  Raises:
    ValueError: if no layer has a 4-dimensional output shape.
  """
  not_found = object()
  name = next(
      (layer.name for layer in reversed(model.layers)
       if len(layer.output_shape) == 4),
      not_found,
  )
  if name is not_found:
    raise ValueError("Could not find the last convolution layer.")
  return name

In [21]:
@tf.function
def get_focus_area(imgs, last_conv_output, threshold = 0.7):
  """Compute a binary attention mask for every image in the batch.

  The channel-wise maximum of the absolute activations gives one heat map per
  image. Each heat map is min-max normalised using that image's own spatial
  extrema and thresholded.

  The extrema are taken over the spatial axes (1 and 2). Reducing over axis 0
  instead would mix statistics between the images of a batch, make the mask of
  an image depend on its batch mates, and divide by zero for a batch of one.

  Args:
    imgs: input images; unused, kept so the signature stays unchanged.
    last_conv_output: feature maps of shape [batch, height, width, channels].
    threshold: normalised activation at or above which a location is selected.

  Returns:
    float32 tensor of shape [batch, height, width] containing 0./1.
  """
  heatmap = tf.math.reduce_max(tf.math.abs(last_conv_output), axis=-1)
  # keepdims=True so the per-image extrema broadcast against [batch, h, w].
  heat_min = tf.math.reduce_min(heatmap, axis=[1, 2], keepdims=True)
  heat_max = tf.math.reduce_max(heatmap, axis=[1, 2], keepdims=True)
  # A perfectly flat heat map has max == min; divide_no_nan returns 0 there
  # instead of NaN, which leaves the mask empty for that image.
  heat_norm = tf.math.divide_no_nan(heatmap - heat_min, heat_max - heat_min)
  return tf.cast(heat_norm >= threshold, tf.float32)

In [22]:
@tf.function
def generate_box(masked):
  """Return the normalised bounding box of the selected cells of one mask.

  Args:
    masked: 2-D mask of shape [height, width]; cells equal to 1 are selected.

  Returns:
    float32 tensor [ymin, xmin, ymax, xmax] in normalised coordinates, the box
    order tf.image.crop_and_resize expects. If no cell is selected the whole
    image, [0, 0, 1, 1], is returned.

  Notes:
    * tf.where returns (row, column) indices, so column 0 of `positions` is
      the y coordinate and column 1 the x coordinate.
    * Indices are normalised by (size - 1), which is 6 for a 7x7 mask.
    * The empty-mask check uses the dynamic shape, because the static first
      dimension of tf.where's result is unknown while tracing.
  """
  positions = tf.cast(tf.where(masked == 1), dtype=tf.float32)
  # Largest valid index per axis; clamped to 1 so a 1-wide mask cannot divide by zero.
  max_index = tf.cast(tf.math.maximum(tf.shape(masked)[:2] - 1, 1), tf.float32)

  def _full_image_box():
    return tf.constant([0., 0., 1., 1.], dtype=tf.float32)

  def _bounding_box():
    rows = positions[:, 0] / max_index[0]
    cols = positions[:, 1] / max_index[1]
    ymin = tf.math.reduce_min(rows)
    ymax = tf.math.reduce_max(rows)
    xmin = tf.math.reduce_min(cols)
    xmax = tf.math.reduce_max(cols)
    # A box that is degenerate along one axis is widened by 0.3, capped at 1.
    ymax = tf.where(ymin == ymax, tf.math.minimum(ymin + 0.3, 1.), ymax)
    xmax = tf.where(xmin == xmax, tf.math.minimum(xmin + 0.3, 1.), xmax)
    return tf.stack([ymin, xmin, ymax, xmax])

  return tf.cond(tf.shape(positions)[0] > 0, _bounding_box, _full_image_box)

In [23]:
@tf.function
def generate_box_batch(data):
  """Crop every image of a batch to the bounding box of its attention mask.

  Args:
    data: pair (imgs, last_conv_output); imgs is the image batch and
      last_conv_output the matching feature maps.

  Returns:
    The result of crop_imgs(): the cropped images resized to 224x224.
  """
  imgs, last_conv_output = data
  masks = get_focus_area(imgs, last_conv_output)

  batch_size = tf.shape(imgs)[0]
  boxes = tf.TensorArray(tf.float32, size=batch_size)
  for i in tf.range(batch_size):
    # TensorArray.write returns the updated array; the result has to be
    # rebound, otherwise the write is lost when running as a graph.
    boxes = boxes.write(i, generate_box(masks[i]))
  boxes = boxes.stack()
  return crop_imgs(imgs, boxes, batch_size)

In [24]:
@tf.function
def crop_imgs(imgs, boxes, bs=BS, labels=None):
  """Crop image i to boxes[i] and resize the crop to 224x224 (nearest neighbour).

  Args:
    imgs: batch of images.
    boxes: one box per image, as expected by tf.image.crop_and_resize.
    bs: number of boxes; box i is applied to image i for i in range(bs).
    labels: optional labels passed through unchanged.

  Returns:
    The cropped images, or (cropped images, labels) when labels are given.
  """
  box_to_image = tf.range(bs)
  cropped_images = tf.image.crop_and_resize(
      imgs, boxes, box_to_image, (224,224), method='nearest')
  return cropped_images if labels is None else (cropped_images, labels)

### Model

In [25]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Dropout, Flatten, GlobalAveragePooling2D, Lambda, Concatenate, GlobalMaxPooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications.densenet import DenseNet121

In [26]:
def set_suffix(model, suffix):
  """Append `suffix` to the name of every layer in `model`.

  The new name is written to the layer's private `_name` attribute.
  """
  for layer in model.layers:
    renamed = layer.name + suffix
    setattr(layer, '_name', renamed)

In [27]:
# Three separate DenseNet121 backbones without the classification head, each with its
# own 224x224x3 input: one per global branch (front, side) and one for the local branch.
# `weights` is left at its default; the cell output shows the pretrained no-top weights
# being downloaded.
global_model_front = DenseNet121(include_top=False,
	input_tensor=Input(shape=(224, 224, 3)))
global_model_side = DenseNet121(include_top=False,
	input_tensor=Input(shape=(224, 224, 3)))
local_model = DenseNet121(include_top=False,
	input_tensor=Input(shape=(224, 224, 3)))

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/densenet/densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5


In [28]:
# Give each backbone's layers a distinct suffix, presumably so that layer names do not
# collide once the three networks are combined into one model — TODO confirm.
set_suffix(global_model_front, "_front")
set_suffix(global_model_side, "_side")
set_suffix(local_model, "_local")

In [None]:
# layer = ZeroPadding2D()(local_model.layers[0].output)
# for i in range(2, len(local_model.layers)):
#   current_layer = local_model.layers[i]
#   if current_layer.name[-2:] == 'bn':
#     layer = LayerNormalization()(layer)
#   else:
#     print(layer)
#     layer = current_layer(layer)

In [None]:
# local_model.summary()

In [29]:
def createMultiDoubleModel(global_model_front, global_model_side, local_model):
  """Assemble and compile the two-view model with an attention-cropped local branch.

  Args:
    global_model_front: backbone applied to the frontal image.
    global_model_side: backbone applied to the side image.
    local_model: backbone applied to the attention-cropped image.

  Returns:
    A compiled Model taking [front input, side input] and producing 14
    sigmoid outputs (binary cross-entropy loss, Adam with INIT_LR,
    multi-label AUC metric).
  """
  # Output of the last 4-D layer of each global backbone.
  global_last_conv_layer_front = find_last_layer(global_model_front)
  global_last_conv_layer_output_front=global_model_front.get_layer(global_last_conv_layer_front).output

  global_last_conv_layer_side = find_last_layer(global_model_side)
  global_last_conv_layer_output_side=global_model_side.get_layer(global_last_conv_layer_side).output
  
  # Crop each input image to the region selected from its own feature maps.
  focus_layer_front = Lambda(generate_box_batch)([global_model_front.layers[0].output,global_last_conv_layer_output_front])
  focus_layer_side = Lambda(generate_box_batch)([global_model_side.layers[0].output,global_last_conv_layer_output_side])

  # Build a 3-channel input for the local backbone: the first two channels of the
  # frontal crop plus the first channel of the side crop.
  focus_layer_front = Lambda(lambda x: x[:,:,:,:2])(focus_layer_front)  
  focus_layer_side = Lambda(lambda x: x[:,:,:,0:1])(focus_layer_side)  

  focus_layer = Concatenate(axis=-1)([focus_layer_front, focus_layer_side])

  local_last_conv_layer = find_last_layer(local_model)
  local_last_conv_layer_output=local_model.get_layer(local_last_conv_layer).output

  # Wrap the local backbone up to its last 4-D layer so it can be called on the crops.
  tmp_local_model = Model(inputs=local_model.input, outputs=local_last_conv_layer_output)

  # The local branch is max-pooled, the two global branches are average-pooled.
  local_branch = tmp_local_model(focus_layer)
  local_branch = GlobalMaxPooling2D()(local_branch)
  global_branch_front = GlobalAveragePooling2D()(global_last_conv_layer_output_front)
  global_branch_side = GlobalAveragePooling2D()(global_last_conv_layer_output_side)

  # Concatenate all pooled features and classify with one sigmoid unit per label.
  merge_branch = Concatenate(axis=-1, name='concatenate_branches')([global_branch_front, global_branch_side, local_branch])
  merge_branch = Dense(14, activation="sigmoid")(merge_branch)

  model = Model(inputs=[global_model_front.input, global_model_side.input], outputs=merge_branch)

  optimizer = Adam(learning_rate=INIT_LR)
  model.compile(loss="binary_crossentropy", optimizer=optimizer,
	  metrics=[tf.keras.metrics.AUC(multi_label=True)])
  return model

In [30]:
# Build and compile the combined two-view model from the three renamed backbones.
model = createMultiDoubleModel(global_model_front, global_model_side, local_model)

In [31]:
# Print the layer table of the assembled model.
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1_front (InputLayer)      [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
input_2_side (InputLayer)       [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
zero_padding2d_front (ZeroPaddi (None, 230, 230, 3)  0           input_1_front[0][0]              
__________________________________________________________________________________________________
zero_padding2d_2_side (ZeroPadd (None, 230, 230, 3)  0           input_2_side[0][0]               
____________________________________________________________________________________________

### Callbacks

In [32]:
from tensorflow.keras.callbacks import ReduceLROnPlateau,EarlyStopping, ModelCheckpoint, TensorBoard, LearningRateScheduler
import datetime
from sklearn.metrics import roc_auc_score
import csv

In [33]:
class AUCCallback(tf.keras.callbacks.Callback):
    """Keras callback that appends per-label ROC AUC scores to a CSV after each epoch.

    Args:
      log_dir: path of the CSV file; one row is appended per epoch.
      img: model input used for prediction (passed to `model.predict`).
      lab: array of ground-truth labels indexed as [sample, label], in the
        same sample order as `img`.
    """
    def __init__(self, log_dir, img, lab):
      # Let the base class set up the state Keras expects on every callback.
      super().__init__()
      self.img_set = img
      self.lab_set = lab
      self.log_dir = log_dir

    def on_epoch_end(self, epoch, logs=None):
      """Predict on the stored inputs and append one row of AUCs to the CSV.

      `logs` is accepted for API compatibility and not used. roc_auc_score
      raises ValueError when a label column contains a single class only;
      that error is propagated.
      """
      pred = self.model.predict(self.img_set, verbose=1, workers=WK, use_multiprocessing=True)
      # One score per model output (14 for the model built in this notebook).
      num_labels = pred.shape[1]
      rocs = np.zeros(num_labels)
      for j in range(num_labels):
        rocs[j] = roc_auc_score(self.lab_set[:,j], pred[:,j])

      # newline='' is what the csv module asks for on files it writes to.
      with open(self.log_dir, 'a', newline='') as fp:
        writer = csv.writer(fp, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
        writer.writerow(rocs)

In [34]:
def create_auc_logger(logdir, auc_test_img, auc_test_lab):
  """Materialise the label dataset and build an AUCCallback around it.

  Args:
    logdir: CSV path handed to AUCCallback.
    auc_test_img: model input handed to AUCCallback.
    auc_test_lab: iterable of label elements; after stacking, index 0 of
      axis 1 is taken, so every element is expected to carry one leading
      singleton entry.
  """
  stacked = tf.stack(list(auc_test_lab))
  label_matrix = np.array(stacked)[:,0,:]
  return AUCCallback(logdir, auc_test_img, label_matrix)

## Training Process

In [35]:
# TensorBoard run directory: <MODEL_LOG_DIR>/<YYYYMMDD>-model<MODEL_ID>.
# A second run on the same day with the same MODEL_ID reuses the directory.
tblogdir = os.path.join(MODEL_LOG_DIR, datetime.datetime.now().strftime("%Y%m%d") + '-model' + MODEL_ID)
tensorboard_callback = TensorBoard(tblogdir, histogram_freq=1)

In [36]:
# Per-epoch AUC logger evaluated on the view-filtered test set. Building it iterates the
# whole label dataset once to materialise the labels.
auc_callback = create_auc_logger(MODEL_AUC_LOG_DIR, test_dsrs_img_view, test_dsrs_lab_view)

In [37]:
# Training callbacks:
#  - cut the learning rate by 10x after one epoch without val_loss improvement (floor 1e-6)
#  - stop after three epochs without improvement of EarlyStopping's default monitor
#  - save a checkpoint whenever val_loss improves
#  - TensorBoard logging and per-epoch test AUC logging
callbacks = [
             ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=1, min_lr=1e-6),
             EarlyStopping(patience=3, verbose=1),
             ModelCheckpoint(MODEL_CP_DIR, save_best_only=True, monitor='val_loss'),
             tensorboard_callback,
             auc_callback
  ]

In [38]:
# Train on the view-filtered datasets for at most EPOCH epochs.
# NOTE(review): workers/use_multiprocessing are documented for generator/Sequence
# inputs; with tf.data datasets they likely have no effect — confirm.
history = model.fit(train_dsrs_view, 
                    epochs=EPOCH, 
                    callbacks=callbacks,
                    validation_data=val_dsrs_view,
                    use_multiprocessing=True,
                    workers=WK)

Epoch 1/20
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 00006: early stopping


In [None]:
# Materialise the test labels as an array indexed [sample, label]; these are the same
# steps create_auc_logger performs.
test_label = list(test_dsrs_lab_view)
test_label = np.array(tf.stack(test_label))
test_label = test_label[:,0,:]

In [None]:
# Final evaluation: predict on the view-filtered test set and compute one ROC AUC per label.
pred = model.predict(test_dsrs_img_view, verbose=1, workers=WK, use_multiprocessing=True)
rocs = np.zeros(14)
for j in range(14):
  rocs[j] = roc_auc_score(test_label[:,j], pred[:,j])
        
# Append the scores as one row to <MODEL_LOG_DIR>/APauc.csv (the format argument is
# an empty string, so the placeholder contributes nothing to the file name).
with open(MODEL_LOG_DIR + "/APauc{}.csv".format(''), 'a') as fp:
  writer = csv.writer(fp, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
  writer.writerow(rocs)

