In [1]:
import numpy as np

%tensorflow_version 2.x
import tensorflow as tf
import tensorflow_datasets as tfds

import matplotlib.pyplot as plt

In [None]:
def image_max_height_max_width(dataset):
  """Return the largest height and the largest width found in `dataset`.

  Args:
    dataset: Iterable of 2-element items `(image, label)`. Only `image.shape`
      is read; its first entry is treated as the height and its second entry
      as the width. The label is ignored.

  Returns:
    A tuple `(max_height, max_width)`. The two maxima are tracked
    independently, so they may come from different images. An empty
    `dataset` yields `(0, 0)`.
  """
  tallest, widest = 0, 0
  for image, _label in dataset:
    dims = image.shape
    tallest = max(tallest, dims[0])
    widest = max(widest, dims[1])
  return tallest, widest

In [None]:
# Input pipeline
size_of_dataset = 27558                 # Used as shuffle buffer size so the whole dataset is shuffled.
batch_size = 16
num_classes = 2                         # Depth of the one-hot encoded targets.
shuffle_seed = 42                       # Fixed seed so the train/test split is reproducible.
splitting_limit_train_test_data = 22000 # First N shuffled examples are training data, the rest test data.

# as_supervised=True yields (image, label) tuples. Otherwise one gets a dictionary.
raw_dataset = tfds.load('Malaria', split='train', as_supervised = True)

# Shuffles the entire dataset once in the beginning.
# reshuffle_each_iteration=False is essential: by default every new iteration over the dataset
# draws a fresh order, which (a) makes take()/skip() select different examples each epoch, so
# training and test data overlap, and (b) pairs images with the wrong labels whenever inputs and
# targets are read through separate pipelines and zipped together.
# Later in the training loop shuffling with the batch_size as buffer is sufficient.
raw_dataset = raw_dataset.shuffle(size_of_dataset, seed=shuffle_seed, reshuffle_each_iteration=False)

# Target size for all images: the maximum height and width found in the dataset.
img_max_height, img_max_width = image_max_height_max_width(raw_dataset)


def preprocess_example(inp, tar):
  """Turns one raw (image, label) pair into a (standardized image, one-hot target) pair.

  Image and label are processed in the same map call so they can never get out of sync.
  """
  # resize_with_pad scales the image (keeping its aspect ratio) to fit into the target size and
  # zero-pads the remainder. Zero padding is useful here, because the cells are presented in
  # front of a dark background, so it just increases the background surface.
  img = tf.image.resize_with_pad(inp, img_max_height, img_max_width)
  # Standardizes the image to a mean of zero and a standard deviation of 1.
  img = tf.image.per_image_standardization(img)
  return img, tf.one_hot(tar, num_classes)


preprocessed_dataset = raw_dataset.map(preprocess_example)

# Separate views on inputs and targets, kept under their original names.
# They iterate in the same fixed order as preprocessed_dataset.
padded_images = preprocessed_dataset.map(lambda img, tar: img)
one_hot_targets = preprocessed_dataset.map(lambda img, tar: tar)
print(one_hot_targets)


# Separates the data into training and test data.
training_examples = preprocessed_dataset.take(splitting_limit_train_test_data)     # all items up to the splitting limit
test_examples = preprocessed_dataset.skip(splitting_limit_train_test_data)         # all items from the splitting limit

training_dataset_inputs = padded_images.take(splitting_limit_train_test_data)
training_dataset_targets = one_hot_targets.take(splitting_limit_train_test_data)

test_dataset_inputs = padded_images.skip(splitting_limit_train_test_data)
test_dataset_targets = one_hot_targets.skip(splitting_limit_train_test_data)


# Batches and prefetches the training and test datasets.
# prefetch(1) keeps one batch (of batch_size elements) ready while the previous one is consumed.
training_dataset = training_examples.batch(batch_size).prefetch(1)
test_dataset = test_examples.batch(batch_size).prefetch(1)