https://www.tensorflow.org/tutorials/images/segmentation

In [1]:
import tensorflow as tf
import pix2pix
import tensorflow_datasets as tfds
from pathlib import Path
from IPython.display import clear_output
import matplotlib.pyplot as plt
import os

In [2]:
def get_label(file_path):
    parts = tf.strings.split(file_path, os.path.sep)
    file = tf.strings.split(parts[-1],'.')[-2]
    fpath = "C:/Users/harri/tensorflow_datasets/oxford_iiit_pet/annotations/trimaps/"+"._"+file+".png"
    image = tf.io.read_file(fpath)
    image = tf.image.decode_png(image, channels=1)
    #image = tf.image.convert_image_dtype(image, tf.float32)
    return image

def decode_image(img):
    img = tf.image.decode_jpeg(img, channels=3)
    #img = tf.image.convert_image_dtype(img, tf.float32)
    return img

def process_path(path):
    """
    Input: file_path of a sample image
    Output: image in 3x64x64 float32 Tensor and one hot tensor
    """
    out = {}
    image = tf.io.read_file(path)
    out['image'] = decode_image(image)
    out['segmentation_mask'] = get_label(path)
    return out


dir_path = Path("C:/Users/harri/tensorflow_datasets/oxford_iiit_pet/images/")
list_files = tf.data.Dataset.list_files(str(dir_path/'*'))
dataset = list_files.map(process_path)

In [3]:
def normalize(input_image, input_mask):
    input_image = tf.cast(input_image, tf.float32) / 255.0
    input_mask -= 1
    return input_image, input_mask

In [4]:
@tf.function
def load_image_train(datapoint):
    input_image = tf.image.resize(datapoint['image'], (128, 128))
    input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))

    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.flip_left_right(input_image)
        input_mask = tf.image.flip_left_right(input_mask)

    input_image, input_mask = normalize(input_image, input_mask)

    return input_image, input_mask

def load_image_test(datapoint):
    input_image = tf.image.resize(datapoint['image'], (128, 128))
    input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))
    input_image, input_mask = normalize(input_image, input_mask)
    return input_image, input_mask

In [5]:
def load_image_test(datapoint):
    input_image = tf.image.resize(datapoint['image'], (128, 128))
    input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))
    input_image, input_mask = normalize(input_image, input_mask)
    return input_image, input_mask

In [6]:
TRAIN_LENGTH = 5000
BATCH_SIZE = 64
BUFFER_SIZE = 1000
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

In [7]:
train = dataset.map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)

train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

In [8]:
def display(display_list):
    plt.figure(figsize=(15, 15))

    title = ['Input Image', 'True Mask', 'Predicted Mask']

    for i in range(len(display_list)):
        plt.subplot(1, len(display_list), i+1)
        plt.title(title[i])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
        plt.axis('off')
    plt.show()

In [10]:
OUTPUT_CHANNELS = 3

In [11]:
base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)

# Use the activations of these layers
layer_names = [
    'block_1_expand_relu',   # 64x64
    'block_3_expand_relu',   # 32x32
    'block_6_expand_relu',   # 16x16
    'block_13_expand_relu',  # 8x8
    'block_16_project',      # 4x4
]
layers = [base_model.get_layer(name).output for name in layer_names]

# Create the feature extraction model
down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)

down_stack.trainable = False

Downloading data from https://github.com/JonathanCMitchell/mobilenet_v2_keras/releases/download/v1.1/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_128_no_top.h5


In [12]:
up_stack = [
    pix2pix.upsample(512, 3),  # 4x4 -> 8x8
    pix2pix.upsample(256, 3),  # 8x8 -> 16x16
    pix2pix.upsample(128, 3),  # 16x16 -> 32x32
    pix2pix.upsample(64, 3),   # 32x32 -> 64x64
]

In [13]:
def unet_model(output_channels):
    inputs = tf.keras.layers.Input(shape=[128, 128, 3])
    x = inputs

    # Downsampling through the model
    skips = down_stack(x)
    x = skips[-1]
    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        concat = tf.keras.layers.Concatenate()
        x = concat([x, skip])

    # This is the last layer of the model
    last = tf.keras.layers.Conv2DTranspose(
      output_channels, 3, strides=2,
      padding='same')  #64x64 -> 128x128

    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)

In [14]:
model = unet_model(OUTPUT_CHANNELS)
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

In [15]:
def create_mask(pred_mask):
    pred_mask = tf.argmax(pred_mask, axis=-1)
    pred_mask = pred_mask[..., tf.newaxis]
    return pred_mask[0]

In [16]:
def show_predictions(dataset=None, num=1):
    if dataset:
        for image, mask in dataset.take(num):
            pred_mask = model.predict(image)
            display([image[0], mask[0], create_mask(pred_mask)])
    else:
        display([sample_image, sample_mask,
             create_mask(model.predict(sample_image[tf.newaxis, ...]))])

In [22]:
EPOCHS = 20
VAL_SUBSPLITS = 5

model_history = model.fit(train_dataset, epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH)

Train for 78 steps
Epoch 1/20
 1/78 [..............................] - ETA: 7s

InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument:  Expected image (JPEG, PNG, or GIF), got unknown format starting with '\000\005\026\007\000\002\000\000Mac OS X'
	 [[{{node DecodePng}}]]
	 [[IteratorGetNext]]
	 [[IteratorGetNext/_2]]
  (1) Invalid argument:  Expected image (JPEG, PNG, or GIF), got unknown format starting with '\000\005\026\007\000\002\000\000Mac OS X'
	 [[{{node DecodePng}}]]
	 [[IteratorGetNext]]
0 successful operations.
0 derived errors ignored. [Op:__inference_distributed_function_14746]

Function call stack:
distributed_function -> distributed_function


In [None]:
loss = model_history.history['loss']
val_loss = model_history.history['val_loss']

epochs = range(EPOCHS)

plt.figure()
plt.plot(epochs, loss, 'r', label='Training loss')
plt.plot(epochs, val_loss, 'bo', label='Validation loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.ylim([0, 1])
plt.legend()
plt.show()

In [None]:
show_predictions(test_dataset, 3)