## ChestX-ray14 Dataset

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
from tensorflow.keras import layers

from src.cxr14 import CXR14

# Load the train/val/test splits of the `cx_r14` builder as (image, label)
# tuples, together with the builder's DatasetInfo.
# NOTE(review): `CXR14` is never referenced directly; importing it presumably
# registers the `cx_r14` builder with TFDS — confirm before removing the import.
(ds_train, ds_val, ds_test), ds_info = tfds.load(
    name='cx_r14',
    split=['train', 'val', 'test'],
    with_info=True,
    as_supervised=True,
    # Shuffles the order of the underlying input files only; it is not an
    # example-level shuffle.
    shuffle_files=True,
)

2021-11-21 02:47:54.313348: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:923] could not open file to read NUMA node: /sys/bus/pci/devices/0000:2d:00.0/numa_node
Your kernel may have been built without NUMA support.
2021-11-21 02:47:54.339483: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:923] could not open file to read NUMA node: /sys/bus/pci/devices/0000:2d:00.0/numa_node
Your kernel may have been built without NUMA support.
2021-11-21 02:47:54.339774: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:923] could not open file to read NUMA node: /sys/bus/pci/devices/0000:2d:00.0/numa_node
Your kernel may have been built without NUMA support.
2021-11-21 02:47:54.340346: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow wi

In [3]:
# Print the full DatasetInfo summary, then its `metadata` attribute, each on
# its own line.
print(ds_info, ds_info.metadata, sep='\n')

tfds.core.DatasetInfo(
    name='cx_r14',
    full_name='cx_r14/1.1.0',
    description="""
    "ChestX-ray dataset comprises 112,120 frontal-view X-ray images of 30,805 unique patients with 
    the text-mined fourteen disease image labels (where each image can have multi-labels), mined 
    from the associated radiological reports using natural language processing. Fourteen common 
    thoracic pathologies include Atelectasis, Consolidation, Infiltration, Pneumothorax, Edema, 
    Emphysema, Fibrosis, Effusion, Pneumonia, Pleural_thickening, Cardiomegaly, Nodule, Mass and 
    Hernia, which is an extension of the 8 common disease patterns listed in our CVPR2017 paper. 
    Note that original radiology reports (associated with these chest x-ray studies) are not 
    meant to be publicly shared for many reasons. The text-mined disease labels are expected to 
    have accuracy >90%."
    """,
    homepage='https://nihcc.app.box.com/v/ChestXray-NIHCC',
    data_path='/home/tmarkmann/tens

### Build a Simple Input Pipeline

In [None]:
def preproc_img(image, label):
  """Resize an image to 224x224 and scale its pixel values by 1/255.

  The label is returned unchanged, so the function can be mapped directly over
  an `as_supervised` (image, label) dataset.

  NOTE(review): the 1/255 scaling assumes 8-bit input in [0, 255] — confirm
  against the `cx_r14` image feature.
  """
  resized = tf.image.resize(image, [224, 224])
  scaled = tf.cast(resized, tf.float32) / 255.
  return scaled, label

# Training input pipeline: preprocess -> shuffle -> batch -> prefetch.
# NOTE: this cell rebinds `ds_train`, so run it exactly once after loading;
# re-running it would preprocess and batch the already-batched dataset again.
ds_train = (
    ds_train
    .map(preproc_img, num_parallel_calls=tf.data.AUTOTUNE)
    # `shuffle_files=True` in `tfds.load` only shuffles file order, so an
    # example-level shuffle is still needed to mix examples within each file.
    # The buffer holds 1000 preprocessed images; lower it if memory is tight.
    .shuffle(buffer_size=1000)
    .batch(8)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Evaluation input pipeline: same preprocessing and batch size as training but
# no shuffling. Batches are cached once a full pass completes, so later passes
# skip re-running the preprocessing.
# NOTE(review): `cache()` without a filename keeps every preprocessed float32
# batch in memory; for a large split this may exceed available RAM — consider
# `cache(filename)` if that happens.
ds_test = (
    ds_test
    .map(preproc_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(8)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

### Benchmark

In [None]:
# Time a full pass over the training pipeline (num_iter=None iterates the whole
# dataset). `batch_size` has to match the `.batch(8)` used when building
# `ds_train` so throughput is reported per example rather than per batch.
tfds.benchmark(ds_train, num_iter=None, batch_size=8)

### Visualization

In [None]:
import matplotlib.pyplot as plt
import numpy as np


def show(image, label):
  """Display one preprocessed image with its label vector as the title.

  Args:
    image: a single (unbatched) image, e.g. one element of `ds_train` after
      `.unbatch()`; expected as (height, width) or (height, width, channels).
    label: the matching label tensor/array; rendered via `np.array2string`.
  """
  pixels = np.asarray(image)
  # matplotlib's imshow rejects (H, W, 1) arrays, so squeeze single-channel
  # images to (H, W) and draw them in grayscale instead of a false colormap.
  cmap = None
  if pixels.ndim == 3 and pixels.shape[-1] == 1:
    pixels = pixels[..., 0]
    cmap = 'gray'
  fig, ax = plt.subplots()
  ax.imshow(pixels, cmap=cmap)
  ax.set_title(np.array2string(np.asarray(label), separator=','))
  ax.axis('off')
  # Render now so figures created in the loop below do not accumulate.
  plt.show()


# Show every example of the first training batch.
for image, label in ds_train.take(1).unbatch():
  show(image, label)

## Train

In [None]:
# Small three-stage CNN baseline. The last layer emits 14 independent sigmoid
# probabilities, one per disease label, because an image can carry several
# labels at once.
# NOTE(review): input_shape assumes 3-channel 224x224 images — confirm that the
# cx_r14 images are not single-channel.
model = tf.keras.models.Sequential([
  layers.Conv2D(16, 3, padding='same', activation='relu', input_shape=(224, 224, 3)),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(64, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(14, activation='sigmoid')
])

# Multi-label objective: BinaryCrossentropy scores each of the 14 sigmoid
# outputs as its own yes/no decision. CategoricalCrossentropy would be wrong
# here: it assumes mutually exclusive classes, renormalizes the outputs to sum
# to 1, and gives zero loss for an all-zero label vector.
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
    # ROC AUC computed per label and averaged across the 14 labels.
    metrics=[tf.keras.metrics.AUC(curve='ROC', multi_label=True, num_labels=14, from_logits=False)],
)

model.summary()

In [None]:
# Validate on the held-out 'val' split so `ds_test` stays untouched for a final,
# unbiased evaluation after training. The validation pipeline uses the same
# preprocessing and batch size as training, without shuffling, and is cached
# so epochs after the first skip the preprocessing.
ds_val_batched = (
    ds_val
    .map(preproc_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(8)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

# Keep the History object so per-epoch loss/AUC curves can be inspected later.
history = model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_val_batched,
)