This Colab notebook processes the COCA dataset and trains Keras models using a number of different architectures, losses, etc.

# Initialize

In [None]:
!pip install pydicom
!pip install tensorflow_addons

import ast
import collections
import dataclasses
import datetime
import logging
import math
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import plistlib
import pydicom
import pytz
import sys
import tensorflow as tf
from matplotlib.path import Path
from tensorflow import keras
from typing import Dict, List, Optional, Text, Tuple
from tensorflow.python.client import device_lib
from tensorflow.keras import backend as K
import tensorflow_addons as tfa

# Fail fast if TensorFlow cannot see any GPU.
assert len(tf.config.experimental.list_physical_devices('GPU')) > 0

# Configure the root logger to emit everything from DEBUG upwards.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# Notebook expression: evaluates to the list of devices visible to TensorFlow.
device_lib.list_local_devices()

# Paths are blank here; presumably they are filled in per environment before
# the cells below are run.
project_base_dir = ''
parsed_lesions_path = ''
gated_dir = ''

# Seed NumPy's global RNG (np.random.shuffle is used by DataGenerator).
np.random.seed(1216)



# Constants and Classes

In [None]:
@dataclasses.dataclass
class Artery:
  """Label and display metadata for one artery class."""
  name: Text  # Artery name.
  class_index: int  # Integer label written to the `artery_class` column.
  color: Text  # Colour name used when drawing lesion polygons.

@dataclasses.dataclass
class Lesion:
  """One annotated lesion: a polygon outline plus the artery it belongs to."""
  points: List[Tuple[float, float]]  # Polygon vertices in pixel coordinates.
  artery: Text  # Artery name; a key of ARTERIES.

@dataclasses.dataclass
class Slice:
  """A single image slice together with its annotations and ground-truth mask."""
  patient_id: Text
  image_path: Text 
  patient_split: Text  # 'train', 'tune' or 'test'.
  has_calcification: bool
  lesions: List[Lesion]  # Empty when the slice has no annotated lesions.
  image: np.ndarray  # Pixel data as produced by load_single_example.
  ground_truth_mask: np.ndarray  # Mask produced by create_mask for `lesions`.

# Maps the artery name used in the XML ROI annotations to its class metadata.
# Each entry's `name` matches its dictionary key exactly.
ARTERIES = {'Right Coronary Artery': Artery(
                name='Right Coronary Artery',
                class_index=0,
                color='red',),
            'Left Anterior Descending Artery': Artery(
                name='Left Anterior Descending Artery',
                class_index=1,
                color='blue',),
            'Left Coronary Artery': Artery(
                name='Left Coronary Artery',
                class_index=2,
                color='green',),
            'Left Circumflex Artery': Artery(
                name='Left Circumflex Artery',
                class_index=3,
                color='yellow',)
            }

# Create CSV (Run Once)

## Create lesions

In [None]:
# Parse every patient's XML annotations found under gated_dir.
# NOTE(review): get_patient_id_to_lesions is defined in the next cell, so that
# cell has to be run before this one.
patient_id_to_lesions = get_patient_id_to_lesions(gated_dir) 

In [None]:
def get_lesions(plist: Dict) -> 'Dict[int, List[Lesion]]':
  """Get lesion metadata from the parsed XML dict.

    Args:
      plist: Parsed XML dictionary. Must contain an 'Images' list whose
        entries have 'ImageIndex', 'NumberOfROIs' and 'ROIs'.
    Returns:
      Dict[image_index, List[Lesion]]. Images without any usable ROI are
      omitted from the result.
    Raises:
      ValueError: If an image index that already has lesions appears again, or
        if 'NumberOfROIs' disagrees with the number of entries in 'ROIs'.
      AssertionError: If a point coordinate is outside (0, 512].
  """
  def get_point(point):
    # Points are stored as strings such as '(12.5, 30.0)'.
    parsed_point = ast.literal_eval(point)
    for coordinate in parsed_point:
      assert coordinate > 0 and coordinate <= 512, f'Invalid point {coordinate}'
    return parsed_point

  image_index_to_lesions = {}
  for image in plist['Images']:
    image_index = image['ImageIndex']
    if image_index in image_index_to_lesions:
      raise ValueError(f'Duplicate image index {image_index}')
    expected_num_rois = image['NumberOfROIs']
    found_num_rois = len(image['ROIs'])
    # An explicit check: `assert (cond, msg)` asserts a non-empty tuple and can
    # therefore never fail.
    if expected_num_rois != found_num_rois:
      raise ValueError(
          f'Expected {expected_num_rois} ROIs but found {found_num_rois}')
    for roi in image['ROIs']:
      artery = roi['Name']
      if artery not in ARTERIES:
        logging.error(f'ROI has unexpected artery {artery}')
        continue
      if not roi['Point_px']:
        # ROI without any outline points; nothing to draw.
        continue
      points = [get_point(point) for point in roi['Point_px']]
      lesion = Lesion(points=points, artery=artery)
      image_index_to_lesions.setdefault(image_index, []).append(lesion)
  return image_index_to_lesions

def get_patient_id_to_lesions(gated_dir) -> 'Dict[int, Dict[int, List[Lesion]]]':
  """Reads all XML data.
    Args:
      gated_dir: Path to gated directory. XML files are read from its
        "calcium_xml" subfolder and must be named "<patient_id>.xml".
    Returns:
      Dict[patient_id, Dict[image_index, List[Lesion]]].
  """
  xml_dir = os.path.join(gated_dir, "calcium_xml")
  patient_id_to_lesions = {}
  for root, _, files in os.walk(xml_dir):
    for fname in files:
      if not fname.endswith('.xml'):
        continue
      patient_id = int(fname.split('.xml')[0])
      # Patients above 450 are treated as healthy by the CSV-building cell and
      # have no annotations to parse.
      if patient_id > 450:
        continue
      # Join with the directory os.walk is currently visiting so files in
      # subfolders are opened from where they actually live.
      fpath = os.path.join(root, fname)
      with open(fpath, 'rb') as f:
        logging.debug(f'Opened {fname}')
        plist = plistlib.load(f)
      patient_id_to_lesions[patient_id] = get_lesions(plist)
  return patient_id_to_lesions


In [None]:
# Column name -> list of values; one value is appended per output row.
df_dict = collections.defaultdict(list)

# Root folder holding one sub-directory per patient.
images_dir = os.path.join(gated_dir, "patient")
# images_dir = "/content/drive/MyDrive/Grad School Dreams/Stanford Grad Certificate/CS230/CS230 Project/Data sample/coca_sample/Gated_release_final/patient"

def add_healthy_image(df_dict, patient_id, relative_image_path, image_index):
  """Append one lesion-free row for an image to the column dict.

  All lesion-related columns are filled with NaN.
  """
  row = {
      'patient_id': patient_id,
      'image_index': image_index,
      'lesion_points': np.nan,
      'artery': np.nan,
      'artery_class': np.nan,
      'artery_color': np.nan,
      'image_path': relative_image_path,
  }
  # Dict order is preserved, so columns are created in the order listed above.
  for column, value in row.items():
    df_dict[column].append(value)

# Walk <images_dir>/<patient_id>/<subdir>/<image files> and emit one row per
# lesion, or a single NaN-filled row for images without lesions.
# Running count of image files visited across all patients.
image_count = 0
for patient_dir in os.listdir(images_dir):
  patient_dir_path = os.path.join(images_dir, patient_dir)
  if not os.path.isdir(patient_dir_path):
    continue
  # Patient directories are named by their integer patient id.
  patient_id = int(patient_dir)
  print(f'patient {patient_dir}')
  healthy_patient = False
  if patient_id > 450:
    # Patients 451 and above are known to be healthy patients.
    healthy_patient = True
  elif patient_id not in patient_id_to_lesions:
    logging.error(f'Patient {patient_id} unexpectedly has no metadata. Skipping.')
    continue
  else:
    # Patient has metadata of the form Dict[image_index, List[Lesion]]
    image_index_to_lesions = patient_id_to_lesions[patient_id]
  for subdir in os.listdir(patient_dir_path):
    subdir_path = os.path.join(patient_dir_path, subdir)
    if not os.path.isdir(subdir_path):
      continue
    # image_index is the file's position in the reverse-sorted directory
    # listing; presumably this matches the XML ImageIndex — TODO confirm.
    for image_index, image_name in enumerate(sorted(os.listdir(subdir_path), reverse=True)):
      image_count += 1
      full_image_path = os.path.join(subdir_path, image_name)
      # Keep only the part of the path that follows 'Gated_release_final'.
      relative_image_path = full_image_path.split('Gated_release_final')[1]
      # healthy_patient is tested first, so image_index_to_lesions is never
      # read for healthy patients.
      if healthy_patient or image_index not in image_index_to_lesions:
        add_healthy_image(df_dict, patient_id, relative_image_path, image_index)
      else:
        # One row per lesion: an image with several lesions appears several times.
        image_lesions = image_index_to_lesions[image_index]
        for lesion in image_lesions:
          df_dict['patient_id'].append(patient_id)
          df_dict['image_index'].append(image_index)
          df_dict['lesion_points'].append(lesion.points)
          df_dict['artery'].append(lesion.artery)
          df_dict['artery_class'].append(ARTERIES[lesion.artery].class_index)
          df_dict['artery_color'].append(ARTERIES[lesion.artery].color)
          df_dict['image_path'].append(relative_image_path)

# Materialize the accumulated columns as a DataFrame.
df = pd.DataFrame(df_dict)

In [None]:
# pickle is not among the notebook's top-level imports, so import it here.
import pickle

train_dev_fname = '/content/cs230-Coronary-Calcium-Scoring-/dataset/gated_train_dev_pids.dump'
test_fname = '/content/cs230-Coronary-Calcium-Scoring-/dataset/gated_test_pids.dump'
# NOTE: unpickling can execute arbitrary code; only load dumps from a trusted
# source.
# The train/dev dump holds a pair: (train patient ids, dev patient ids).
with open(train_dev_fname, 'rb') as f:
  train_patient_ids, dev_patient_ids = pickle.load(f)

with open(test_fname, 'rb') as f:
  test_patient_ids = pickle.load(f)

# patient id -> split name. The dev set is labelled 'tune' in the DataFrame.
patient_split_map = {}
for pid in train_patient_ids:
  patient_split_map[int(pid)] = 'train'
for pid in dev_patient_ids:
  patient_split_map[int(pid)] = 'tune'
for pid in test_patient_ids:
  patient_split_map[int(pid)] = 'test'

# Patients missing from all three id lists get NaN as their split.
df['patient_split'] = df.patient_id.map(patient_split_map)
# Lesion rows store a list of points; healthy rows store NaN (a float).
df['has_calcification'] = df.lesion_points.apply(lambda x: isinstance(x, list))


In [None]:
# Fraction of lesion rows per artery (NaN rows are excluded by value_counts).
df.artery.value_counts(normalize=True)

INFO:numexpr.utils:NumExpr defaulting to 2 threads.


Left Anterior Descending Artery    0.370996
Right Coronary Artery              0.370352
Left Circumflex Artery             0.213906
Left Coronary Artery               0.044745
Name: artery, dtype: float64

# Read Parsed Lesions (Run Each Time)

In [None]:
def read_parsed_lesions(parsed_lesions_path):
  """Load the parsed-lesions CSV into a shuffled DataFrame.

  Args:
    parsed_lesions_path: Path to the CSV file.
  Returns:
    DataFrame without the rows of the excluded patients, with the
    `lesion_points` column parsed back from its string form where possible,
    and with rows shuffled using a fixed seed.
  """
  # Patients dropped from the dataset (reason not recorded in this notebook).
  BAD_PATIENT_IDS = {78, 120, 146, 3336, 3309}

  def maybe_literal_eval(x):
    # Lesion rows hold the repr of a list of points; healthy rows hold NaN,
    # which literal_eval rejects, so the value is returned unchanged.
    try:
      return ast.literal_eval(x)
    except (ValueError, SyntaxError, TypeError):
      return x

  df = pd.read_csv(parsed_lesions_path)
  bad_row_indices = df.loc[df.patient_id.isin(BAD_PATIENT_IDS)].index
  df = df.drop(index=bad_row_indices)
  df['lesion_points'] = df.lesion_points.apply(maybe_literal_eval)
  # Fixed seed keeps the shuffled row order reproducible.
  df = df.sample(frac=1, random_state=1216)
  return df

# Load the lesion table and collect the distinct image paths of each split.
df = read_parsed_lesions(parsed_lesions_path)
train_image_paths = list(df.loc[df.patient_split == 'train'].image_path.unique())
tune_image_paths = list(df.loc[df.patient_split == 'tune'].image_path.unique())
test_image_paths = list(df.loc[df.patient_split == 'test'].image_path.unique())

# Visualize lesions

In [None]:
def plot_image_with_overlay(image_slice: Slice, prediction):
  """Draw a 1x5 row of panels comparing ground truth and predicted masks.

  Panels, left to right: the image, the ground-truth mask, the image with the
  ground-truth mask overlaid, the predicted mask, and the image with the
  predicted mask overlaid.

  Args:
    image_slice: Slice providing `image` and `ground_truth_mask`.
    prediction: Predicted mask to display.
  """
  image = image_slice.image
  ground_truth = image_slice.ground_truth_mask
  # (title, base picture followed by optional half-transparent overlays)
  panels = (
      ('Original image', (image,)),
      ('Ground truth mask', (ground_truth,)),
      ('Image with ground truth mask', (image, ground_truth)),
      ('Predicted mask', (prediction,)),
      ('Image with predicted mask', (image, prediction)),
  )

  fig = plt.figure(figsize=(15, 15))
  for position, (title, stack) in enumerate(panels, start=1):
    ax = fig.add_subplot(1, 5, position)
    ax.set_title(title)
    base, *overlays = stack
    ax.imshow(base, cmap='gray', interpolation=None)
    for overlay in overlays:
      ax.imshow(overlay, cmap='gray', alpha=0.5, interpolation=None)

def plot_image_slice(image_slice: Slice, prediction):
  """Plot a slice with its ground truth and `prediction`; delegates to plot_image_with_overlay."""
  plot_image_with_overlay(image_slice, prediction=prediction)
  


## Plot ground truth

In [None]:
def plot_ground_truth(image, lesions, patient_id):
  """Plot an image, its lesion polygons, and the rasterized lesion mask.

  Args:
    image: 2-D image array.
    lesions: Lesions to draw; may be empty.
    patient_id: Used only in the panel titles.
  """
  title = f'Slice for patient {patient_id}'
  fig = plt.figure(figsize=(15, 15))
  ax = fig.add_subplot(1, 3, 1)
  ax.set_title(f'Original: {title}')
  # plt.imshow draws on the current axes, i.e. the subplot added just above.
  plt.imshow(image, cmap='gray', interpolation=None)
  ax = fig.add_subplot(1, 3, 2)
  ax.set_title(f'MPL: {title}')
  # Polygons are drawn directly from the annotation points, coloured per artery.
  for lesion in lesions:
    ax.add_patch(patches.Polygon(lesion.points, closed=True,
                                 color=ARTERIES[lesion.artery].color))
  ax.imshow(image, cmap='gray', interpolation=None)
  ax = fig.add_subplot(1, 3, 3)
  ax.set_title(f'Mask: {title}')
  mask = create_mask(image, lesions)
  if not lesions:
    # Sanity check: without lesions the mask must be all zeros.
    assert np.sum(mask) < 0.000001, f'Bad mask'
  plt.imshow(mask, cmap='gray', interpolation=None)

def sample_images(df, split, num_positives, num_negatives, shuffle=False):
  """Pick image paths with and without calcification from one split.

  Args:
    df: Lesion DataFrame with `patient_split`, `has_calcification` and
      `image_path` columns.
    split: Split name to select.
    num_positives: Maximum number of calcified rows to take.
    num_negatives: Maximum number of non-calcified rows to take.
    shuffle: If True, rows are shuffled (fixed seed) before selection.
  Returns:
    Series of image paths: positives first, then negatives.
  """
  frame = df.sample(frac=1, random_state=1216) if shuffle else df
  in_split = frame.patient_split == split
  calcified = frame.has_calcification == True
  healthy = frame.has_calcification == False
  positives = frame.loc[in_split & calcified].image_path[0:num_positives]
  negatives = frame.loc[in_split & healthy].image_path[0:num_negatives]
  return pd.concat([positives, negatives])


# Model

## Define generator

In [None]:
def create_mask(image: np.ndarray, lesions: 'List[Lesion]'):
  """Rasterize lesion polygons into a binary mask shaped like `image`.

  Args:
    image: 2-D array; only its (height, width) shape is used.
    lesions: Lesions whose polygons should be filled. May be empty.
  Returns:
    float32 array of shape (height, width) holding 1.0 for pixels inside at
    least one lesion polygon and 0.0 elsewhere. Pixels covered by several
    overlapping polygons are still 1.0, so the mask stays a valid binary
    target.
  """
  height, width = image.shape
  out_mask = np.zeros((height, width), dtype=bool)
  if not lesions:
    return out_mask.astype(np.float32)
  # The (row, column) pixel grid is the same for every lesion; build it once.
  rows, cols = np.mgrid[:height, :width]
  coordinates = np.hstack((rows.reshape(-1, 1), cols.reshape(-1, 1)))
  for lesion in lesions:
    # Swap each point's two coordinates so they line up with the
    # (row, column) order of `coordinates`.
    polygon = Path([(p[1], p[0]) for p in lesion.points])
    out_mask |= polygon.contains_points(coordinates).reshape(height, width)
  return out_mask.astype(np.float32)

def normalize_image(image: np.ndarray):
  """Normalize pixels to the range [0, 1].

  Args:
    image: Numeric array of any shape.
  Returns:
    float32 array of the same shape, min-max scaled to [0, 1]. A constant
    image (max == min) maps to all zeros instead of dividing by zero.
  """
  # Work in float64 so narrow integer dtypes cannot overflow in `max - min`.
  image = np.asarray(image, dtype=np.float64)
  image_min = np.min(image)
  image_range = np.max(image) - image_min
  if image_range == 0:
    return np.zeros(image.shape, dtype=np.float32)
  return ((image - image_min) / image_range).astype(np.float32)

def load_single_example(image_path, image_df, gated_dir): 
  """Load one DICOM image and build its Slice, including the lesion mask.

  Args:
    image_path: Image path as stored in the DataFrame; its first character is
      dropped before joining with `gated_dir`.
    image_df: All DataFrame rows for this image (one per lesion, or a single
      row without lesion points).
    gated_dir: Root directory that `image_path` is relative to.
  Returns:
    Slice with the normalized image, its lesions and the ground-truth mask.
  """
  assert image_df.image_path.nunique() == 1
  dicom_path = os.path.join(gated_dir, image_path[1:])
  image = normalize_image(pydicom.dcmread(dicom_path).pixel_array)
  # Rows without lesions store NaN (a float) in lesion_points; skip those.
  lesions = [
      Lesion(artery=artery, points=points)
      for artery, points in zip(image_df.artery, image_df.lesion_points)
      if not isinstance(points, float)
  ]
  mask = create_mask(image, lesions)
  # Every row of one image must agree on the per-image metadata.
  for column in ('patient_id', 'patient_split', 'has_calcification'):
    assert image_df[column].nunique() == 1
  first_row = image_df.iloc[0]
  return Slice(
      patient_id=first_row.patient_id,
      image_path=image_path,
      patient_split=first_row.patient_split,
      has_calcification=first_row.has_calcification,
      lesions=lesions,
      image=image,
      ground_truth_mask=mask,
  )

class DataGenerator(tf.keras.utils.Sequence):
  """Keras Sequence that yields (images, masks) batches loaded from disk.

  Args:
    df: Lesion DataFrame (one row per lesion or per lesion-free image).
    name: Label used in the printed progress messages.
    image_paths: Image paths this generator iterates over. The list is copied,
      so the caller's list is never reordered.
    batch_size: Number of images per batch.
    gated_dir: Root directory the image paths are relative to.
    positive_upsample_factor: If truthy, every path with calcification is
      added this many extra times.
    add_dim: If True, images get a trailing channel axis.
  """
  def __init__(self,
               df: pd.DataFrame,
               name: Text,
               image_paths: List[Text],
               batch_size: int,
               gated_dir: Text,
               positive_upsample_factor,
               add_dim: bool = False):
    super().__init__()
    self.df = df
    self.name = name
    self.batch_size = batch_size
    self.gated_dir = gated_dir
    self.add_dim = add_dim
    # Work on a private copy: the list is shuffled in place here and after
    # every epoch, which must not reorder the caller's list.
    image_paths = list(image_paths)
    if positive_upsample_factor:
      original_length = len(image_paths)
      positive_image_paths = set(df.loc[(df.image_path.isin(image_paths)) & (df.has_calcification == True)].image_path)
      upsampled_positive_image_paths = list(positive_image_paths) * positive_upsample_factor
      image_paths = image_paths + upsampled_positive_image_paths
      additional_images = len(image_paths)-original_length
      print(f'Added {additional_images} images from {len(positive_image_paths)} original positives')
    else:
      print(f'No upsampling for {name}')
    np.random.shuffle(image_paths)
    self.image_paths = image_paths

  def __len__(self):
    """Number of batches per epoch (the last batch may be smaller)."""
    length = math.ceil(len(self.image_paths) / self.batch_size)
    print(f'For {self.name} datagen length is {length}')
    return length

  def _prepare_batch(self, df, gated_dir, image_paths=None):
    """Load the images described by `df` and stack them into arrays.

    Args:
      df: Metadata rows for the images of this batch.
      gated_dir: Root directory the image paths are relative to.
      image_paths: Optional paths to load, in order and including repeats.
        Paths without rows in `df` are skipped. When omitted, each distinct
        path in `df` is loaded once, in sorted order.
    Returns:
      (X, Y): stacked images and stacked ground-truth masks.
    """
    grouped = {path: rows for path, rows in df.groupby('image_path')}
    if image_paths is None:
      image_paths = list(grouped)
    X_list = []
    Y_list = []
    for image_path in image_paths:
      if image_path not in grouped:
        continue
      image_slice = load_single_example(image_path, grouped[image_path], gated_dir)
      if self.add_dim:
        X_list.append(np.expand_dims(image_slice.image, axis=2))
      else:
        X_list.append(image_slice.image)
      Y_list.append(image_slice.ground_truth_mask)
    X = np.array(X_list)
    Y = np.array(Y_list)
    return X, Y

  def __getitem__(self, idx):
    """Return batch number `idx` as an (X, Y) tuple."""
    batch_image_paths = self.image_paths[idx * self.batch_size: (idx+1) * self.batch_size]
    batch_metadata_df = self.df[self.df.image_path.isin(batch_image_paths)]
    # Pass the path list along so an upsampled image that occurs more than
    # once in this batch is emitted that many times.
    X, Y = self._prepare_batch(batch_metadata_df, self.gated_dir,
                               image_paths=batch_image_paths)
    return X, Y

  def on_epoch_end(self):
    """Reshuffle the image order for the next epoch."""
    np.random.shuffle(self.image_paths)



## Define model

### U-Net

In [None]:
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dense, Flatten, Dropout, BatchNormalization, Activation, Conv2DTranspose
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
import tensorflow.keras.layers as layers

def get_unet(dropout=None, batch_norm=False):
  """Build a U-Net for 512x512 single-channel input with a sigmoid output map.

  Args:
    dropout: Dropout rate applied after each activation inside every
      convolution block; falsy disables dropout.
    batch_norm: If True, BatchNormalization is inserted before each ReLU.
  Returns:
    An uncompiled Keras Model mapping (512, 512, 1) to (512, 512, 1).
  """

  def conv2d_block(input_tensor, n_filters, kernel_size=3):
    # TODO: revisit default kernel size.
    """Function to add 2 convolutional layers."""
    # first layer
    x = Conv2D(filters=n_filters, kernel_size=(kernel_size, kernel_size),
                kernel_initializer='he_normal', padding='same')(input_tensor)
    if batch_norm:
        x = BatchNormalization()(x)
    x = Activation('relu')(x)
    if dropout:
        x = Dropout(dropout)(x)

    # second layer
    x = Conv2D(filters=n_filters, kernel_size=(kernel_size, kernel_size),
                kernel_initializer='he_normal', padding='same')(x)
    if batch_norm:
        x = BatchNormalization()(x)
    x = Activation('relu')(x)
    if dropout:
        x = Dropout(dropout)(x)
    return x

  inputs = Input(shape=(512, 512, 1))

  # Encoder path
  # conv block 64 -> 512x512x64 -> maxpool -> 256x256x64
  c1 = conv2d_block(inputs, 64)
  p1 = MaxPooling2D(pool_size=(2,2), strides=(2,2))(c1)

  # conv block 128 -> 256x256x128 -> maxpool -> 128x128x128
  c2 = conv2d_block(p1, 128)
  p2 = MaxPooling2D(pool_size=(2,2), strides=(2,2))(c2)

  # conv block 256 -> 128x128x256 -> maxpool -> 64x64x256
  c3 = conv2d_block(p2, 256)
  p3 = MaxPooling2D(pool_size=(2,2), strides=(2,2))(c3)

  # conv block 512 -> 64x64x512 -> maxpool -> 32x32x512
  c4 = conv2d_block(p3, 512)
  p4 = MaxPooling2D(pool_size=(2, 2), strides=(2,2))(c4)

  # Bottleneck: conv block 1024 -> 32x32x1024 (no pooling)
  c5 = conv2d_block(p4, 1024)
  p5 = c5

  # Decoder path. Each stage upsamples by 2, concatenates the encoder output
  # of the same resolution, then applies a conv block.
  # 64x64
  u4 = Conv2DTranspose(512, (3, 3), strides=(2, 2), padding='same')(p5)
  u4 = layers.concatenate([u4, c4])
  u4 = conv2d_block(u4, 512)

  # 128x128
  u3 = Conv2DTranspose(256, (3, 3), strides=(2, 2), padding='same')(u4)
  u3 = layers.concatenate([u3, c3])
  u3 = conv2d_block(u3, 256)

  # 256x256
  u2 = Conv2DTranspose(128, (3, 3), strides=(2, 2), padding='same')(u3)
  u2 = layers.concatenate([u2, c2])
  u2 = conv2d_block(u2, 128)

  # 512x512
  u1 = Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same')(u2)
  u1 = layers.concatenate([u1, c1])
  u1 = conv2d_block(u1, 64)
  # 1x1 convolution with sigmoid: one value in [0, 1] per pixel.
  outputs = Conv2D(1, kernel_size=(1, 1), activation="sigmoid")(u1)

  model = Model(inputs=inputs, outputs=outputs)
  print(type(model))
  return model


### Attention U-Net

In [None]:
from keras.layers import Activation, add, multiply, Lambda
from keras.layers import AveragePooling2D, average, UpSampling2D, Dropout
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose
from keras.initializers import glorot_normal

# From https://github.com/nabsabraham/focal-tversky-unet

# K.set_image_data_format('channels_last')  # TF dimension ordering in this code
# Kernel initializer name passed to the attention U-Net layers below.
kinit = 'glorot_normal'

def expend_as(tensor, rep,name):
	"""Repeat `tensor` `rep` times along axis 3 via a Lambda layer named 'psi_up' + name."""
	my_repeat = Lambda(lambda x, repnum: K.repeat_elements(x, repnum, axis=3), arguments={'repnum': rep},  name='psi_up'+name)(tensor)
	return my_repeat

def UnetConv2D(input, outdim, is_batchnorm, name):
	"""Apply two 3x3 same-padded convolutions, each followed by optional batch norm and ReLU.

	Args:
		input: Input tensor.
		outdim: Number of filters of both convolutions.
		is_batchnorm: If True, BatchNormalization precedes each ReLU.
		name: Prefix for the layer names ('<name>_1', '<name>_1_bn', ...).
	"""
	x = Conv2D(outdim, (3, 3), strides=(1, 1), kernel_initializer=kinit, padding="same", name=name+'_1')(input)
	if is_batchnorm:
		x =BatchNormalization(name=name + '_1_bn')(x)
	x = Activation('relu',name=name + '_1_act')(x)

	x = Conv2D(outdim, (3, 3), strides=(1, 1), kernel_initializer=kinit, padding="same", name=name+'_2')(x)
	if is_batchnorm:
		x = BatchNormalization(name=name + '_2_bn')(x)
	x = Activation('relu', name=name + '_2_act')(x)
	return x

def UnetGatingSignal(input, is_batchnorm, name):
  """Build a gating signal: 1x1 convolution (channel count unchanged), optional batch norm, ReLU."""
  shape = K.int_shape(input)
  # shape[3] is the channel count of `input`; the 1x1 conv keeps it as is.
  x = Conv2D(shape[3] * 1, (1, 1), strides=(1, 1), padding="same",  kernel_initializer=kinit, name=name + '_conv')(input)
  if is_batchnorm:
      x = BatchNormalization(name=name + '_bn')(x)
  x = Activation('relu', name = name + '_act')(x)
  return x


def AttnGatingBlock(x, g, inter_shape, name, dropout_block):
  """Attention gate: re-weight the feature map `x` using the gating signal `g`.

  Args:
    x: Feature map to be gated.
    g: Gating signal.
    inter_shape: Number of filters of the intermediate convolutions.
    name: Suffix for the layer names.
    dropout_block: Dropout rate applied before the 'psi' convolution; falsy
      disables it.
  Returns:
    Tensor with the spatial size and channel count of `x`.
  """
  shape_x = K.int_shape(x)  # 32
  shape_g = K.int_shape(g)  # 16

  # Stride-2 convolution halves the spatial size of x.
  theta_x = Conv2D(inter_shape, (2, 2), strides=(2, 2), padding='same', name='xl'+name)(x)  # 16
  shape_theta_x = K.int_shape(theta_x)

  phi_g = Conv2D(inter_shape, (1, 1), padding='same')(g)
  # Upsample g by the integer ratio between theta_x's and g's spatial sizes so
  # the two can be added.
  upsample_g = Conv2DTranspose(inter_shape, (3, 3),strides=(shape_theta_x[1] // shape_g[1], shape_theta_x[2] // shape_g[2]),padding='same', name='g_up'+name)(phi_g)  # 16

  # Despite its name, concat_xg is an element-wise sum.
  concat_xg = add([upsample_g, theta_x])
  act_xg = Activation('relu')(concat_xg)
  if dropout_block:
    act_xg = Dropout(dropout_block, name='drop_psi'+name)(act_xg)
  # Single-channel attention coefficients in [0, 1].
  psi = Conv2D(1, (1, 1), padding='same', name='psi'+name)(act_xg)
  sigmoid_xg = Activation('sigmoid')(psi)
  shape_sigmoid = K.int_shape(sigmoid_xg)
  # Bring the coefficients back to x's spatial size.
  upsample_psi = UpSampling2D(size=(shape_x[1] // shape_sigmoid[1], shape_x[2] // shape_sigmoid[2]))(sigmoid_xg)  # 32

  # Repeat the single channel to x's channel count, then gate x element-wise.
  upsample_psi = expend_as(upsample_psi, shape_x[3],  name)
  y = multiply([upsample_psi, x], name='q_attn'+name)

  result = Conv2D(shape_x[3], (1, 1), padding='same',name='q_attn_conv'+name)(y)
  result_bn = BatchNormalization(name='q_attn_bn'+name)(result)
  return result_bn

def get_attn_unet(dropout_rate, dropout_block, input_size=(512, 512, 1)):   
    """Build an attention-gated U-Net with a single-channel sigmoid output.

    Args:
      dropout_rate: Dropout rate applied after the conv3 and conv4 encoder
        blocks; falsy disables it.
      dropout_block: Dropout rate used inside each attention gating block;
        falsy disables it.
      input_size: Input tensor shape without the batch axis.
    Returns:
      An uncompiled Keras Model.
    """
    # New: adds dropout to the attention gating block before the linear layer.
    inputs = Input(shape=input_size)
    # Encoder: four conv blocks, each followed by 2x2 max pooling.
    conv1 = UnetConv2D(inputs, 32, is_batchnorm=True, name='conv1')
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    
    conv2 = UnetConv2D(pool1, 32, is_batchnorm=True, name='conv2')
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = UnetConv2D(pool2, 64, is_batchnorm=True, name='conv3')
    if dropout_rate:
      conv3 = Dropout(dropout_rate,name='drop_conv3')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = UnetConv2D(pool3, 64, is_batchnorm=True, name='conv4')
    if dropout_rate:
      conv4 = Dropout(dropout_rate, name='drop_conv4')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
    
    center = UnetConv2D(pool4, 128, is_batchnorm=True, name='center')
    
    # Decoder: each stage gates the encoder features with a signal from the
    # decoder and concatenates them with the upsampled decoder features.
    g1 = UnetGatingSignal(center, is_batchnorm=True, name='g1')
    attn1 = AttnGatingBlock(conv4, g1, 128, '_1', dropout_block)
    up1 = concatenate([Conv2DTranspose(32, (3,3), strides=(2,2), padding='same', activation='relu', kernel_initializer=kinit)(center), attn1], name='up1')
    
    g2 = UnetGatingSignal(up1, is_batchnorm=True, name='g2')
    attn2 = AttnGatingBlock(conv3, g2, 64, '_2', dropout_block)
    up2 = concatenate([Conv2DTranspose(64, (3,3), strides=(2,2), padding='same', activation='relu', kernel_initializer=kinit)(up1), attn2], name='up2')

    # NOTE(review): g3 is derived from up1 rather than up2, unlike the g1/g2
    # pattern above. Confirm whether this is intended before changing it:
    # altering the graph would make existing checkpoints incompatible.
    g3 = UnetGatingSignal(up1, is_batchnorm=True, name='g3')
    attn3 = AttnGatingBlock(conv2, g3, 32, '_3', dropout_block)
    up3 = concatenate([Conv2DTranspose(32, (3,3), strides=(2,2), padding='same', activation='relu', kernel_initializer=kinit)(up2), attn3], name='up3')

    # The last stage concatenates conv1 directly, without an attention gate.
    up4 = concatenate([Conv2DTranspose(32, (3,3), strides=(2,2), padding='same', activation='relu', kernel_initializer=kinit)(up3), conv1], name='up4')
    out = Conv2D(1, (1, 1), activation='sigmoid',  kernel_initializer=kinit, name='final')(up4)
    
    model = Model(inputs=[inputs], outputs=[out])
    return model

### Reza-Net

In [None]:
# From https://github.com/rezazad68/BCDU-Net
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, UpSampling2D, Reshape, core, Dropout
from keras.layers import ConvLSTM2D
    
def get_reza_net(input_size = (512, 512, 1)):
    """Build the BCDU-Net style model: a U-Net whose skip connections are merged by ConvLSTM2D.

    Args:
      input_size: Input tensor shape without the batch axis. The first entry
        is used for both height and width in the Reshape layers, so the input
        must be square.
    Returns:
      An uncompiled Keras Model with a single-channel sigmoid output.
    """
    N = input_size[0]
    inputs = Input(input_size) 
    # Encoder: three double-conv stages, each followed by 2x2 max pooling.
    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)
    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)
  
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)
    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)
    # drop3 feeds only the skip connection below; pooling uses conv3 directly.
    drop3 = Dropout(0.5)(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    # Bottleneck: three double-conv blocks (D1-D3); D3 takes the concatenated
    # outputs of D1 and D2.
    # D1
    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)     
    conv4_1 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)
    drop4_1 = Dropout(0.5)(conv4_1)
    # D2
    conv4_2 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(drop4_1)     
    conv4_2 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4_2)
    conv4_2 = Dropout(0.5)(conv4_2)
    # D3
    merge_dense = concatenate([conv4_2,drop4_1], axis = 3)
    conv4_3 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge_dense)     
    conv4_3 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4_3)
    drop4_3 = Dropout(0.5)(conv4_3)
    
    # Decoder stage 1 (N/4 x N/4): upsample, then merge with the encoder skip.
    up6 = Conv2DTranspose(256, kernel_size=2, strides=2, padding='same',kernel_initializer = 'he_normal')(drop4_3)
    up6 = BatchNormalization(axis=3)(up6)
    up6 = Activation('relu')(up6)

    # Skip and upsampled tensors each get a length-1 time axis and are stacked
    # into a 2-step sequence that the ConvLSTM2D reduces to a single map.
    x1 = Reshape(target_shape=(1, np.int32(N/4), np.int32(N/4), 256))(drop3)
    x2 = Reshape(target_shape=(1, np.int32(N/4), np.int32(N/4), 256))(up6)
    merge6  = concatenate([x1,x2], axis = 1) 
    merge6 = ConvLSTM2D(filters = 128, kernel_size=(3, 3), padding='same', return_sequences = False, go_backwards = True,kernel_initializer = 'he_normal' )(merge6)
            
    conv6 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)
    conv6 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)

    # Decoder stage 2 (N/2 x N/2).
    up7 = Conv2DTranspose(128, kernel_size=2, strides=2, padding='same',kernel_initializer = 'he_normal')(conv6)
    up7 = BatchNormalization(axis=3)(up7)
    up7 = Activation('relu')(up7)

    x1 = Reshape(target_shape=(1, np.int32(N/2), np.int32(N/2), 128))(conv2)
    x2 = Reshape(target_shape=(1, np.int32(N/2), np.int32(N/2), 128))(up7)
    merge7  = concatenate([x1,x2], axis = 1) 
    merge7 = ConvLSTM2D(filters = 64, kernel_size=(3, 3), padding='same', return_sequences = False, go_backwards = True,kernel_initializer = 'he_normal' )(merge7)
        
    conv7 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)
    conv7 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)

    # Decoder stage 3 (N x N).
    up8 = Conv2DTranspose(64, kernel_size=2, strides=2, padding='same',kernel_initializer = 'he_normal')(conv7)
    up8 = BatchNormalization(axis=3)(up8)
    up8 = Activation('relu')(up8)    

    x1 = Reshape(target_shape=(1, N, N, 64))(conv1)
    x2 = Reshape(target_shape=(1, N, N, 64))(up8)
    merge8  = concatenate([x1,x2], axis = 1) 
    merge8 = ConvLSTM2D(filters = 32, kernel_size=(3, 3), padding='same', return_sequences = False, go_backwards = True,kernel_initializer = 'he_normal' )(merge8)    
    
    conv8 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)
    conv8 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)
    conv8 = Conv2D(2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)
    # 1x1 convolution with sigmoid: one value in [0, 1] per pixel.
    conv9 = Conv2D(1, 1, activation = 'sigmoid')(conv8)

    model = Model(inputs, conv9)
    return model

## Define custom callbacks

In [None]:
def get_time():
    """Return the current timezone-aware time in the US/Pacific zone."""
    pacific = pytz.timezone('US/Pacific')
    return datetime.datetime.now(pacific)

def get_time_str():
    """Return the current US/Pacific time formatted as 'MM-DD-HH:MM'."""
    now = get_time()
    return now.strftime("%m-%d-%H:%M")

def get_str_from_time(time):
    """Format `time` as 'MM-DD-HH:MM' (month-day-hour:minute, no year)."""
    stamp_format = "%m-%d-%H:%M"
    return time.strftime(stamp_format)

def get_duration_sec(first: datetime.datetime, second: datetime.datetime):
  """Return the whole number of seconds elapsed from `first` to `second`.

  Uses total_seconds() so durations of a day or more are not truncated to
  their seconds-within-the-day component.
  """
  return int((second - first).total_seconds())

class SaveMetricsCallback(keras.callbacks.Callback):
  """Appends training metrics to CSV files.

  Epoch-level metrics go to '<prefix>.epochs.csv'; the metrics of every
  `save_frequency_steps`-th batch go to '<prefix>.steps.csv'.

  Args:
    metrics_file_prefix: Path prefix of the two CSV files.
    save_frequency_steps: Record batch metrics when the batch number is a
      multiple of this value.
  """
  def __init__(self, metrics_file_prefix, save_frequency_steps):
    super().__init__()
    self.epoch_path = f'{metrics_file_prefix}.epochs.csv'
    self.step_path = f'{metrics_file_prefix}.steps.csv'
    self.save_frequency_steps = save_frequency_steps
    logging.info(f'Writing metrics to\n{self.epoch_path}\n{self.step_path}')

  def process(self, unit_name, unit_amount, logs, metrics_path):
    """Append one row (time, unit counter, metrics) to `metrics_path`.

    Args:
      unit_name: Name of the counter column ('epoch' or 'step').
      unit_amount: Value of the counter.
      logs: Metric name -> value mapping supplied by Keras; may be None.
      metrics_path: CSV file to append to.
    """
    pd_logs = collections.defaultdict(list)
    pd_logs['time'].append(get_time_str())
    pd_logs[unit_name].append(unit_amount)
    for metric, val in (logs or {}).items():
      pd_logs[metric].append(val)
    logs_df = pd.DataFrame(pd_logs)
    # Write the header only when starting a new file. Keying it on
    # `unit_amount == 0` would repeat the header at batch 0 of every epoch.
    write_header = (not os.path.exists(metrics_path)
                    or os.path.getsize(metrics_path) == 0)
    with open(metrics_path, 'a') as f:
      logs_df.to_csv(f, header=write_header)

  def on_epoch_end(self, epoch, logs=None):
    logging.info(f'Logging epoch {epoch}')
    self.process('epoch', epoch, logs=logs,
            metrics_path=self.epoch_path)

  def on_train_batch_end(self, batch, logs=None):
    if batch % self.save_frequency_steps == 0:
      time_start = get_time()
      self.process('step', batch, logs=logs,
              metrics_path=self.step_path)
      logging.info(f'Took {get_duration_sec(time_start, get_time())} seconds to record batch metrics')

class SaveCheckpointCallback(keras.callbacks.Callback):
  """Saves the model whenever the batch number is a multiple of `save_frequency_steps`.

  Args:
    save_frequency_steps: Save when `batch % save_frequency_steps == 0`.
    checkpoint_dir: Directory the checkpoints are written to; created if it
      does not exist.
  """
  def __init__(self, save_frequency_steps, checkpoint_dir):
    super().__init__()
    self.save_frequency_steps = save_frequency_steps
    self.checkpoint_dir = checkpoint_dir
    # Create the directory passed to this callback, not a same-looking global.
    os.makedirs(checkpoint_dir, exist_ok=True)
    logging.info(f'Saving checkpoints to {self.checkpoint_dir}')

  def on_train_batch_end(self, batch, logs=None):
    if batch % self.save_frequency_steps == 0:
      start_time = get_time()
      # File name: '<MM-DD-HH:MM>_step_<batch>.hf5'.
      out_path = os.path.join(self.checkpoint_dir, f'{get_str_from_time(start_time)}_step_{batch}.hf5')
      tf.keras.models.save_model(self.model, filepath=out_path)
      logging.info(f'Took {get_duration_sec(start_time, get_time())} seconds to save checkpoint')


## Define loss

In [None]:
def dice_coef(y_true, y_pred, smooth=1e-8):
  """Dice coefficient over all elements, using squared terms in the denominator.

  Computes (2 * sum(t * p) + smooth) / (sum(t^2) + sum(p^2) + smooth) on the
  flattened tensors; `smooth` keeps the ratio defined when both are all zero.
  """
  truth = K.flatten(y_true)
  pred = K.flatten(y_pred)
  overlap = K.sum(truth * pred)
  squared_mass = K.sum(K.square(truth)) + K.sum(K.square(pred))
  return (2 * overlap + smooth) / (squared_mass + smooth)

def dice_loss(y_true, y_pred):
  """Dice loss: one minus the Dice coefficient of `y_true` and `y_pred`."""
  coefficient = dice_coef(y_true, y_pred)
  return 1.0 - coefficient

def focal_loss(y_true, y_pred):
  """Sigmoid focal cross-entropy (alpha=0.55, gamma=2) on the flattened tensors, summed."""
  loss_fn = tfa.losses.SigmoidFocalCrossEntropy(
      reduction=tf.keras.losses.Reduction.NONE,
      gamma=2.,
      alpha=0.55)
  unreduced = loss_fn(K.flatten(y_true), K.flatten(y_pred))
  return K.sum(unreduced)


## Actually train

#### Resume saved model

In [None]:
# Short model name -> checkpoint path relative to '<project_base_dir>/checkpoints'.
SAVED_MODELS = {
  'upsample_dice_resume': '2021-11-26-01:58.upsample_dice_resume/08.hf5',
  'focal': '2021-11-27-02:37.final_unet_real_focal/11-28-01:31_step_11500.hf5',
  'focal_attention': '2021-11-29-00:09.attn_gate_focal/11-29-08:57_step_0.hf5',
}

# Custom losses/metrics Keras needs by name to deserialize the saved models.
CUSTOM_OBJECTS = {'dice_loss': dice_loss, 'dice_coef': dice_coef, 'focal_loss': focal_loss}

def load_model(fname, project_base_dir, custom_objects=CUSTOM_OBJECTS):
  """Load a saved Keras model from '<project_base_dir>/checkpoints/<fname>'.

  Args:
    fname: Checkpoint path relative to the checkpoints directory.
    project_base_dir: Project root directory.
    custom_objects: Name -> object mapping handed to Keras for deserialization.
  Returns:
    The loaded Keras model.
  """
  model_path = os.path.join(project_base_dir, 'checkpoints', fname)
  return keras.models.load_model(model_path, custom_objects=custom_objects)

# Load the three previously trained checkpoints listed in SAVED_MODELS.
focal_vanilla_model = load_model(
    SAVED_MODELS['focal'],
    project_base_dir
)

upsample_dice_model = load_model(
    SAVED_MODELS['upsample_dice_resume'],
    project_base_dir
)

focal_attention_model = load_model(
    SAVED_MODELS['focal_attention'],
    project_base_dir
)

In [None]:
# Training hyper-parameters; the `#@param` annotations render as Colab form fields.
model_prefix = 'attn_gate_focal' #@param{type:'string'}
epochs =  30#@param{type:'number'}
train_batch_size =   8#@param{type:'number'}
tune_batch_size =   8#@param{type:'number'}
# Timestamp (US/Pacific) used to give every run its own output directories.
current_time = datetime.datetime.now(pytz.timezone('US/Pacific')).strftime("%Y-%m-%d-%H:%M")
train_positive_upsample_factor =  7#@param{type:'number'}
metric_save_frequency = 2500 #@param{type:'number'}
checkpoint_save_frequency =  2500#@param{type:'number'}
dropout_rate = 0.2 #@param{type:'number'}
# Originally 0.0001
learning_rate = 0.001 #@param{type:'number'} 
add_dim = True #@param{type:'boolean'}
# NOTE(review): resume_training and init_checkpoint_path are not read anywhere
# in this cell.
resume_training = False #@param{type:'boolean'}

init_checkpoint_path = os.path.join(project_base_dir, 'checkpoints/2021-11-26-01:58.upsample_dice_resume/06.hf5')

# Output locations: '<base>/experiment_metrics/<time>.<prefix>/<time>.<prefix>'
# for metrics and '<base>/checkpoints/<time>.<prefix>' for checkpoints.
metrics_file_prefix = os.path.join(project_base_dir, 'experiment_metrics', f'{current_time}.{model_prefix}', f'{current_time}.{model_prefix}') 
checkpoints_dir = os.path.join(project_base_dir, 'checkpoints', f'{current_time}.{model_prefix}')
checkpoint_path = os.path.join(checkpoints_dir, f'{model_prefix}_{current_time}.h5')

if not os.path.exists(os.path.dirname(metrics_file_prefix)):
  os.makedirs(os.path.dirname(metrics_file_prefix))
# Refuse to reuse an existing checkpoint directory so runs never overwrite
# each other.
if not os.path.exists(checkpoints_dir):
  os.mkdir(checkpoints_dir)
else:
  raise ValueError(f'Checkpoint dir {checkpoints_dir} already exists')

# The same rate is used for the encoder dropout and the attention-gate dropout.
model = get_attn_unet(dropout_rate=dropout_rate, dropout_block=dropout_rate)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
              loss=focal_loss,
              metrics=['accuracy', 
                        dice_coef,
                      ])
print(f'Starting at {current_time}')
print(f'Writing checkpoint to {checkpoint_path}')

def lr_scheduler(epoch, lr):
  """Return the learning rate to use for the given epoch.

  A fixed two-step schedule: 1e-3 for epochs 0 through 4, then 0.0005 for
  every later epoch.

  Args:
    epoch: Zero-based epoch index.
    lr: Current learning rate. Accepted for the scheduler callback signature
      but ignored; the schedule is fixed.

  Returns:
    The learning rate for `epoch`.
  """
  high_rate_epochs = 5
  return 1e-3 if epoch < high_rate_epochs else 0.0005

# Training batches; positive_upsample_factor comes from the form field above.
train_generator = DataGenerator(
    name='train',
    df=df,
    image_paths=train_image_paths,
    batch_size=train_batch_size,
    gated_dir=gated_dir,
    positive_upsample_factor=train_positive_upsample_factor,
    add_dim=add_dim)

# Validation batches from the tune split, with an upsample factor of 0.
tune_generator = DataGenerator(
    name='tune',
    df=df,
    image_paths=tune_image_paths,
    batch_size=tune_batch_size,
    gated_dir=gated_dir,
    positive_upsample_factor=0,
    add_dim=add_dim)

# Periodic metric dumps, periodic checkpoints, and the per-epoch LR schedule.
training_callbacks = [
    SaveMetricsCallback(metrics_file_prefix,
                        save_frequency_steps=metric_save_frequency),
    SaveCheckpointCallback(save_frequency_steps=checkpoint_save_frequency,
                           checkpoint_dir=checkpoints_dir),
    tf.keras.callbacks.LearningRateScheduler(lr_scheduler),
]

history = model.fit(
    x=train_generator,
    epochs=epochs,
    validation_data=tune_generator,
    callbacks=training_callbacks,
    verbose=1,
)

## Evaluate model

In [None]:
# Score the focal-attention model on the test image paths. The result is a
# dict because return_dict=True.
test_generator = DataGenerator(
    name='test',
    df=df,
    image_paths=test_image_paths,
    batch_size=32,
    gated_dir=gated_dir,
    positive_upsample_factor=0,
    add_dim=True)
test_results = focal_attention_model.evaluate(
    x=test_generator, verbose=1, return_dict=True)

## Evaluate on single example

In [None]:
# Image path of the single slice to inspect.
target_image = '/patient/401/Pro_Gated_Calcium_Score_(CS)_3.0_Qr36_2_BestDiast_72_%/IM-7084-0013.dcm'
# Rows of df whose image_path matches that slice.
image_df = df.loc[df.image_path == target_image]
example_slice = load_single_example(
    image_path=target_image, image_df=image_df, gated_dir=gated_dir)


In [None]:
def prepare_input_image(image, expand_dims):
  """Add a leading batch axis and one more singleton axis to `image`.

  Args:
    image: Array-like image.
    expand_dims: If true, the second singleton axis is inserted at position 3;
      otherwise it is inserted at position 2.

  Returns:
    The image with two extra singleton axes.
  """
  # NOTE(review): get_prediction() adds no second axis when expand_dims is
  # false, whereas this helper inserts one at position 2. Confirm that the
  # difference is intended.
  extra_axis = 3 if expand_dims else 2
  batched = np.expand_dims(image, axis=0)
  return np.expand_dims(batched, axis=extra_axis)

# Run the example slice through the model in inference mode and drop all
# singleton axes from the output.
# NOTE(review): `saved_reza` is not defined in this part of the notebook;
# confirm it is loaded in an earlier cell.
predictions = np.squeeze(
    saved_reza(
        prepare_input_image(example_slice.image, True),
        training=False,
    ).numpy()
)

## Evaluate on batch

In [None]:
def get_prediction(saved_model, image, expand_dims=False):
  """Run `saved_model` on a single image and return its squeezed output.

  Args:
    saved_model: Callable accepting `(batch, training=False)` and returning an
      object with a `.numpy()` method.
    image: Array-like image; a leading batch axis is added before the call.
    expand_dims: If true, an extra singleton axis is inserted at position 3 of
      the batched image.

  Returns:
    The model output as a numpy array with all singleton axes removed. The
    values are returned as-is; no thresholding is applied.
  """
  batch = np.expand_dims(image, axis=0)
  batch = np.expand_dims(batch, axis=3) if expand_dims else batch
  model_output = saved_model(batch, training=False)
  return np.squeeze(model_output.numpy())


def eval_batch(saved_model, expand_dims, evaluation_image_paths, df):
  """Predict and plot every image in `evaluation_image_paths`.

  For each path this prints the path, loads the slice, runs the model on it,
  prints the summed prediction next to the summed ground-truth mask, and plots
  the slice together with the prediction.

  Args:
    saved_model: Model forwarded to `get_prediction`.
    expand_dims: Forwarded to `get_prediction`.
    evaluation_image_paths: Iterable of image paths to evaluate.
    df: DataFrame with an `image_path` column, used to select each image's
      rows.
  """
  for image_path in evaluation_image_paths:
    print(image_path)
    rows_for_image = df.loc[df.image_path == image_path]
    # gated_dir is read from the notebook's global scope.
    image_slice = load_single_example(image_path, rows_for_image, gated_dir)
    prediction = get_prediction(
        saved_model, image_slice.image, expand_dims)
    print(f'prediction sum {np.sum(prediction)} and gt sum {np.sum(image_slice.ground_truth_mask)}')
    plot_image_slice(image_slice, prediction)
  

In [None]:
# Sample 5 positive and 2 negative tune images (unshuffled) and evaluate them
# without the extra singleton axis.
# NOTE(review): `saved_focal_unet` is not defined in this part of the notebook;
# confirm it is loaded in an earlier cell.
sampled_image_paths = sample_images(
    df, 'tune', num_positives=5, num_negatives=2, shuffle=False)
eval_batch(
    saved_model=saved_focal_unet,
    expand_dims=False,
    evaluation_image_paths=sampled_image_paths,
    df=df)

In [None]:
# Sample 5 positive and 2 negative tune images (unshuffled) and evaluate them
# with the extra singleton axis.
# NOTE(review): `saved_attn_net` is not defined in this part of the notebook;
# confirm it is loaded in an earlier cell.
sampled_image_paths = sample_images(
    df, 'tune', num_positives=5, num_negatives=2, shuffle=False)
eval_batch(
    saved_model=saved_attn_net,
    expand_dims=True,
    evaluation_image_paths=sampled_image_paths,
    df=df)

## Compute tune set metrics

In [None]:
# Leading image paths per (split, has_calcification) combination. The
# tune-negative list keeps up to 128 paths; the other three keep up to 33.
tune_positives = list(
    df.loc[(df.patient_split == 'tune') & (df.has_calcification == True), 'image_path'])[:33]
tune_negatives = list(
    df.loc[(df.patient_split == 'tune') & (df.has_calcification == False), 'image_path'])[:128]
train_positives = list(
    df.loc[(df.patient_split == 'train') & (df.has_calcification == True), 'image_path'])[:33]
train_negatives = list(
    df.loc[(df.patient_split == 'train') & (df.has_calcification == False), 'image_path'])[:33]

# Agatston Bucketing

In [None]:
# Location of the per-volume Agatston score table for the test set.
volume_scores_path = os.path.join(project_base_dir, 'agatston', 'FINAL_test_set_agststons_volume.csv')
# Read from the path built above. The original passed `volume_scores` itself,
# which is undefined on a fresh kernel.
# NOTE(review): `read_csv` is not defined in this part of the notebook; confirm
# a helper of that name exists, or switch to `pd.read_csv`.
volume_scores = read_csv(volume_scores_path)

In [None]:
# Mean of the image_gt_agatston column across all rows.
volume_scores['image_gt_agatston'].mean()

197.78481012658227

In [None]:
# Side-by-side histograms of the image_gt_agatston and predicted_agatston
# columns.
fig = plt.figure()
fig.suptitle('Distribution of Volume-Level Agatston Scores')

# Left panel: ground truth.
ax = fig.add_subplot(121)
ax.set_title('Ground Truth')
ax.hist(list(volume_scores['image_gt_agatston']))

# Right panel: predictions.
ax = fig.add_subplot(122)
ax.set_title('Predicted')
ax.hist(list(volume_scores['predicted_agatston']))

In [None]:
def bucket(score):
  """Map an Agatston score to an ordinal bucket index.

  Buckets are contiguous, so every non-negative score (integer or fractional)
  lands in exactly one of them:

    0: score <= 0
    1: 0 < score <= 10
    2: 10 < score < 100
    3: 100 <= score < 400
    4: 400 <= score < 1000
    5: score >= 1000

  The previous closed integer ranges left gaps (301-399, and fractional values
  such as 10.5) that fell through to bucket 5.

  Args:
    score: Numeric Agatston score. Negative values are treated as 0. NaN
      compares false against every threshold and therefore returns 5.

  Returns:
    Integer bucket index in [0, 5].
  """
  if score <= 0:
    return 0
  if score <= 10:
    return 1
  if score < 100:
    return 2
  if score < 400:
    return 3
  if score < 1000:
    return 4
  return 5

# Add one bucket column for the ground-truth scores and one for the predicted
# scores, so the two can be compared per row.
volume_scores['gt_bucket'] = volume_scores['image_gt_agatston'].apply(bucket)
volume_scores['predicted_bucket'] = volume_scores['predicted_agatston'].apply(bucket)
