# Image segmentation using UNET


https://www.tensorflow.org/tutorials/images/segmentation

The Oxford-IIIT Pet Dataset

https://www.robots.ox.ac.uk/~vgg/data/pets/



In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, concatenate, Conv2DTranspose
from tensorflow.keras.models import Model
import os

In [2]:
print(tf.version.VERSION)
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

2.3.0


SystemError: GPU device not found

In [3]:
def unet(input_shape, num_classes):
    inputs = Input(input_shape)

    # Encoder
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool3)
    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv4)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(pool4)
    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv5)
    drop5 = Dropout(0.5)(conv5)

    # Decoder
    up6 = Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(drop5)
    merge6 = concatenate([drop4, up6], axis=3)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(merge6)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv6)

    up7 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv6)
    merge7 = concatenate([conv3, up7], axis=3)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(merge7)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv7)

    up8 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv7)
    merge8 = concatenate([conv2, up8], axis=3)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(merge8)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv8)
    
    up9 = Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(conv8)

    merge9 = concatenate([conv1, up9], axis=3)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(merge9)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv9)

    outputs = Conv2D(num_classes, (1, 1), activation='softmax')(conv9)

    model = Model(inputs=inputs, outputs=outputs)

    return model

In [4]:

# Define the file paths of the dataset
data_dir = '../data/oxford-iiit-pet/'
image_dir = os.path.join(data_dir, 'images')
annotations_file = os.path.join(data_dir, 'annotations', 'list.txt')

# Read the file paths and labels from the annotations file
with open(annotations_file, 'r') as f:
    lines = f.readlines()

filepaths = []
labels = []

for line in lines:
    line = line.strip()
    if line.startswith("#"):
        pass
    else:
        line = line.split()
        filepaths.append(os.path.join(image_dir, line[0] + '.jpg'))
        labels.append(os.path.join(data_dir, 'annotations', 'trimaps', line[0] + '.png'))
    
# Create a dataset from the file paths and labels
dataset = tf.data.Dataset.from_tensor_slices((filepaths, labels))

# Define a function to load and preprocess each image and label
def load_and_preprocess_image(filepath, label):
    # Load the image
    image = tf.io.read_file(filepath)
    # Decode the image
    image = tf.image.decode_jpeg(image, channels=3)
    # Resize the image to a fixed size
    image = tf.image.resize(image, [128, 128])
    # Normalize the pixel values to the range [0, 1]
    image = tf.cast(image, tf.float32) / 255.0
    # Load the label
    label = tf.io.read_file(label)
    # Decode the label
    label = tf.image.decode_png(label, channels=1)
    # Resize the label to a fixed size
    label = tf.image.resize(label, [128, 128], method='nearest')
    # Convert the label to a binary mask
    label = tf.where(label == 0, 0, 1)
    # Return the preprocessed image and label
    return image, label

# Apply the preprocessing function to each image and label in the dataset
dataset = dataset.map(load_and_preprocess_image)

# Shuffle and batch the dataset
batch_size = 32
dataset = dataset.shuffle(buffer_size=1000).batch(batch_size)

# Iterate over the dataset
#for batch in dataset:
#    images, labels = batch
    # Train your model with the batch of images and labels
    #print(batch)


In [5]:
input_shape = (128, 128, 3)
num_classes = 3

model = unet(input_shape, num_classes)

# Compile the model with appropriate loss and optimizer for your task
#model.compile(loss='binary_crossentropy', optimizer='adam')

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

#model.summary()

In [None]:
model.fit(dataset, epochs=10)