In [3]:
import os
import tensorflow as tf
import matplotlib.pyplot as plt

In [4]:
def get_filenames(filepath):
    """Return the full paths of the regular files in ``filepath``, sorted by name.

    Full paths are needed because the result is handed to
    ``tf.data.TFRecordDataset``, which does not resolve bare file names
    against ``filepath``.

    ``os.listdir`` returns entries in arbitrary order, so the names are sorted
    to give the same record order on every run and machine (the notebook later
    splits the data positionally with ``take``/``skip``). Sub-directories are
    skipped because they are not readable as record files.

    Args:
        filepath: Directory to list.

    Returns:
        A list of ``os.path.join(filepath, name)`` strings in sorted order;
        empty if the directory contains no regular files.

    Raises:
        OSError: If ``filepath`` does not exist or cannot be listed.
    """
    full_paths = (os.path.join(filepath, name) for name in sorted(os.listdir(filepath)))
    return [path for path in full_paths if os.path.isfile(path)]

# Labeled training records, read from ./data/train_tfrecords under the
# current working directory.
train_dir = os.path.join(os.getcwd(), 'data', 'train_tfrecords')
train_dataset = tf.data.TFRecordDataset(filenames=get_filenames(train_dir))

# Unlabeled test records, read from ./data/test_tfrecords.
test_dir = os.path.join(os.getcwd(), 'data', 'test_tfrecords')
test_dataset = tf.data.TFRecordDataset(filenames=get_filenames(test_dir))

def read_labeled_tfrecord(example):
    """Parse one serialized labeled record into an ``(image, label)`` pair.

    Args:
        example: One serialized record, as yielded by
            ``tf.data.TFRecordDataset``.

    Returns:
        A tuple ``(image, label)`` where ``image`` is the decoded image tensor
        and ``label`` is a list of eleven ``tf.int32`` scalars, one per entry
        of ``LABEL_COLUMNS`` and in that order.
    """
    # Single source of truth for the label columns: the parsing schema and the
    # label vector are both derived from it so they cannot drift apart.
    LABEL_COLUMNS = (
        "ETT - Abnormal",
        "ETT - Borderline",
        "ETT - Normal",
        "NGT - Abnormal",
        "NGT - Borderline",
        "NGT - Incompletely Imaged",
        "NGT - Normal",
        "CVC - Abnormal",
        "CVC - Borderline",
        "CVC - Normal",
        "Swan Ganz Catheter Present",
    )
    LABELED_TFREC_FORMAT = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "StudyInstanceUID": tf.io.FixedLenFeature([], tf.string),
    }
    for column in LABEL_COLUMNS:
        LABELED_TFREC_FORMAT[column] = tf.io.FixedLenFeature([], tf.int64)

    example = tf.io.parse_single_example(example, LABELED_TFREC_FORMAT)

    # expand_animations=False makes decode_image return a 3-D tensor for every
    # file type. With the default (True) the result has no static shape inside
    # Dataset.map, which breaks shape-dependent ops such as tf.image.resize.
    image = tf.io.decode_image(example['image'], expand_animations=False)

    # Labels are stored as int64 in the record; downstream code gets int32.
    label = [tf.cast(example[column], tf.int32) for column in LABEL_COLUMNS]
    return image, label


def read_unlabeled_tfrecord(example):
    """Parse one serialized unlabeled record and return its decoded image.

    Args:
        example: One serialized record, as yielded by
            ``tf.data.TFRecordDataset``.

    Returns:
        The decoded image tensor. ``StudyInstanceUID`` is parsed (so it must be
        present in the record) but is not returned.
    """
    FEATURES = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "StudyInstanceUID": tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(example, FEATURES)

    # expand_animations=False makes decode_image return a 3-D tensor for every
    # file type. With the default (True) the result has no static shape inside
    # Dataset.map, which breaks shape-dependent ops such as tf.image.resize.
    image = tf.io.decode_image(example['image'], expand_animations=False)

    return image

In [5]:
# Train / validation split: the first 4000 parsed records become the
# validation set and every record after them becomes the training set.
parsed_train_dataset = train_dataset.map(read_labeled_tfrecord)
val_ds = parsed_train_dataset.take(4000)
train_ds = parsed_train_dataset.skip(4000)

# Preview the first nine training images on a 3x3 grid.
plt.figure(figsize=(10, 10))
for position, (image, _label) in enumerate(train_ds.take(9), start=1):
    picture = tf.keras.preprocessing.image.array_to_img(image)
    plt.subplot(3, 3, position)
    plt.imshow(picture, cmap='gray')
plt.show()

NameError: name 'plt' is not defined

In [None]:
# Test set: parse the unlabeled records and preview the first nine images
# on a 3x3 grid.
parsed_test_dataset = test_dataset.map(read_unlabeled_tfrecord)

plt.figure(figsize=(10, 10))
for position, image in enumerate(parsed_test_dataset.take(9), start=1):
    picture = tf.keras.preprocessing.image.array_to_img(image)
    plt.subplot(3, 3, position)
    plt.imshow(picture, cmap='gray')
plt.show()

In [None]:
def data_augment(image, label):
    """Randomly flip and rotate an image, then resize to 512x512 and scale by 1/255.

    Intended for use with ``Dataset.map``. All randomness comes from
    TensorFlow ops, so a fresh flip and angle are drawn for every element;
    Python's ``random`` module would be evaluated only once, when the mapped
    function is traced, and would apply the same angle to every image.

    Args:
        image: Image tensor. If its static rank is unknown it is assumed to be
            a single height x width x channels image.
            NOTE(review): confirm this matches what the dataset yields.
        label: Passed through unchanged.

    Returns:
        A tuple ``(image, label)`` where ``image`` is a float tensor resized
        to 512x512 and multiplied by 1/255.
    """
    # NOTE(review): `tfa` (tensorflow_addons) is not imported in the visible
    # import cell - confirm it is imported before this function is traced.

    # tf.image.resize rejects tensors without a static shape, so pin the rank
    # when it is unknown (e.g. decode_image output inside Dataset.map).
    if image.shape.rank is None:
        image.set_shape([None, None, None])

    image = tf.image.random_flip_left_right(image)

    # Rotation angle in radians, drawn per element from [-0.2, 0.2).
    angle = tf.random.uniform([], minval=-0.2, maxval=0.2)
    image = tfa.image.rotate(image, angle)

    # Plain ops replace the Keras Resizing/Rescaling layers, avoiding a new
    # model on every trace: bilinear resize (the default), then scale by 1/255.
    image = tf.image.resize(image, [512, 512])
    image = image * (1. / 255)

    return image, label

# Run the augmentation over the validation split and preview the first nine
# augmented images on a 3x3 grid as a visual sanity check.
test = val_ds.map(data_augment)

plt.figure(figsize=(10, 10))
for position, (image, _label) in enumerate(test.take(9), start=1):
    picture = tf.keras.preprocessing.image.array_to_img(image)
    plt.subplot(3, 3, position)
    plt.imshow(picture, cmap='gray')
plt.show()