<a href="https://colab.research.google.com/github/detsikas/Semantic-Segmentation/blob/master/human_segmentation_unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Description

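Human segmentation with a U-Net, based on Ronneberger et al., "U-Net: Convolutional Networks for Biomedical Image Segmentation" (2015): https://arxiv.org/abs/1505.04597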

In [0]:
import tensorflow as tf
import argparse
import os
import sys
import re
import datetime
import pandas as pd

# Root folder on the mounted Google Drive: dataset archives are copied from here
# and every new run creates its timestamped output folder under it.
base_folder = '/content/drive/My Drive/colab_data'
# Let tf.data choose parallelism and prefetch buffer sizes at runtime.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# IO functions

In [0]:
# Checkpoint filenames record the epoch they were written at. That epoch is
# treated as possibly incomplete, so training resumes from it (it is re-run).
def get_last_complete_epoch(latest):
  """Return the epoch to resume training from, given the latest checkpoint path.

  Checkpoints are named like ``cp-0010.ckpt`` (template ``cp-{epoch:04d}.ckpt``
  in ``read_work_paths``); the text between the first ``-`` and the next ``.``
  is the epoch number. One is subtracted so that epoch is trained again.

  Args:
    latest: Checkpoint path as returned by ``tf.train.latest_checkpoint``,
      or None when no checkpoint was found.

  Returns:
    The epoch number encoded in the filename minus one.

  Exits:
    Via ``sys.exit(message)`` (non-zero status) when ``latest`` is None or the
    filename does not follow the expected format.
  """
  if latest is None:
    sys.exit('No checkpoint found to restore from')
  _, filename = os.path.split(latest)
  # Raw string: '\.' inside a normal string literal is an invalid escape.
  parts = re.split(r'\.|-', filename)
  try:
    return int(parts[1]) - 1
  except (IndexError, ValueError):
    # Exiting with a message gives status 1; the former sys.exit(0)
    # reported success for a fatal error.
    sys.exit('Bad checkpoint filename format: {}'.format(filename))

def create_metrics_log_file(output_path):
  """Return the first unused path of the form ``<output_path>/metrics_<n>.log``.

  Candidates are tried for n = 0, 1, 2, ... and the first one that does not
  exist yet is returned. The file itself is not created.
  """
  suffix = 0
  candidate = os.path.join(output_path, 'metrics_{}.log'.format(suffix))
  while os.path.exists(candidate):
    suffix += 1
    candidate = os.path.join(output_path, 'metrics_{}.log'.format(suffix))
  return candidate


def write_arguments_to_file(args, output_path):
  """Save ``args`` as an argparse ``@file`` at ``<output_path>/args.txt``.

  Every option is written as ``--key`` on its own line, followed by its value
  on the next line. None values are skipped. Booleans become a bare flag when
  True and are omitted when False, matching the ``store_true`` options of
  ``create_argument_parser``, so the file can be re-read with
  ``parser.parse_args(['@' + path])``.
  """
  lines = []
  for key, value in args.items():
    if value is None or value is False:
      continue
    lines.append('--' + key + '\n')
    if not isinstance(value, bool):
      lines.append(str(value) + '\n')
  with open(os.path.join(output_path, 'args.txt'), 'w') as args_file:
    args_file.writelines(lines)


def read_work_paths(args):
  """Resolve the output and checkpoint locations for this run.

  For a new run, a timestamped folder is created under ``base_folder`` and the
  arguments are saved to it. For a restored run, the existing folder and its
  latest checkpoint are used.

  Args:
    args: Run configuration dict; reads ``dataset_path`` and ``restore_from``
      (``write_arguments_to_file`` receives the whole dict for new runs).

  Returns:
    ``(output_path, checkpoint_path, last_complete_epoch, latest_checkpoint)``
    where ``checkpoint_path`` is the ``cp-{epoch:04d}.ckpt`` template passed to
    ``ModelCheckpoint``. For a new run ``last_complete_epoch`` is 0 and
    ``latest_checkpoint`` is None.

  Exits:
    Via ``sys.exit`` when the dataset path is missing, the restore folder does
    not exist, or it contains no checkpoint.
  """
  # os.path.exists(None) raises TypeError, so test for None explicitly.
  if args['dataset_path'] is None or not os.path.exists(args['dataset_path']):
    sys.exit('Dataset path does not exist')

  checkpoint_filename = 'cp-{epoch:04d}.ckpt'
  restore_from = args['restore_from']

  if restore_from is None:
    output_path = os.path.join(base_folder, datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
    checkpoint_path = os.path.join(output_path, checkpoint_filename)
    # Create output directory
    os.makedirs(output_path)
    write_arguments_to_file(args, output_path)
    return output_path, checkpoint_path, 0, None

  if not os.path.exists(restore_from):
    sys.exit('Restore path {} does not exist'.format(restore_from))

  checkpoint_path = os.path.join(restore_from, checkpoint_filename)
  checkpoint_dir = os.path.dirname(checkpoint_path)
  latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
  # latest_checkpoint returns None when the folder holds no checkpoint; stop
  # here instead of failing later on a None path.
  if latest_checkpoint is None:
    sys.exit('Cannot find checkpoints at {}'.format(checkpoint_path))

  last_complete_epoch = get_last_complete_epoch(latest_checkpoint)
  return restore_from, checkpoint_path, last_complete_epoch, latest_checkpoint


def write_configuration_to_file(args, output_path):
  """Write a human-readable summary of ``args`` to ``<output_path>/config.txt``.

  One ``Label: value`` line per setting, after a two-line header. Values are
  written with ``str.format`` as-is (None appears as ``None``).

  Raises:
    KeyError: if ``args`` lacks one of the expected keys.
  """
  fields = (
      ('Training data size', 'training_size'),
      ('Validations data size', 'validation_size'),
      ('Expansion method', 'expansion_method'),
      ('Batch size', 'batch_size'),
      ('Buffer size', 'buffer_size'),
      ('Epochs', 'epochs'),
      ('Image size', 'image_size'),
      ('Dataset path', 'dataset_path'),
      ('Restore from', 'restore_from'),
      ('Initializer', 'initializer'),
      ('Regularizer', 'regularizer'),
      ('Comments', 'comments'),
      ('TPU', 'tpu'),
  )
  config_file = os.path.join(output_path, 'config.txt')
  # The context manager closes the file even when a lookup below fails;
  # the previous open()/close() pair leaked the handle in that case.
  with open(config_file, 'w') as f:
    f.write('Configration\n')
    f.write('------------\n')
    for label, key in fields:
      f.write('{}: {}\n'.format(label, args[key]))


def create_argument_parser():
  """Build the parser used to re-read a run's saved ``args.txt``.

  ``fromfile_prefix_chars='@'`` lets ``parse_args(['@path'])`` read one
  argument per line from ``path`` (the format ``write_arguments_to_file``
  produces). ``--training_size``/``--validation_size`` default to None,
  which means the whole split is used.
  """
  parser = argparse.ArgumentParser(fromfile_prefix_chars='@')
  parser.add_argument('--dataset_path', help='Dataset path')
  parser.add_argument('--epochs', help='Training epochs (default 100)', type=int, default=100)
  # These two have no numeric default: None means "use every sample".
  parser.add_argument('--training_size', help='Training samples (default all)', type=int)
  parser.add_argument('--validation_size', help='Validation samples (default all)', type=int)
  parser.add_argument('--buffer_size', help='Random shuffling buffer (default 1000)',type=int, default=1000)
  parser.add_argument('--batch_size', help='Batch size (default 32)', type=int, default=32)
  parser.add_argument('--image_size', help='Image size (default 256)', type=int, default=256)
  parser.add_argument('--initializer', help='Kernel initializer', action='store_true')
  parser.add_argument('--regularizer', help='Kernel regularizer', action='store_true')
  parser.add_argument('--comments', help='Comments')
  parser.add_argument('--tpu', help='Using TPU', action='store_true')
  parser.add_argument('--restore_from', help='Path to restore from checkpoints')
  parser.add_argument('--expansion_method', help='Method for the expansion path (default upsampling)',
                      choices=['upsampling', 'tconv'],
                      default='upsampling')
  return parser

def print_configuration(args):
  """Print the run configuration to stdout, one ``Label: value`` line each.

  A None training or validation size is shown as ``all``, because no limit is
  then applied to that split.
  """
  rows = (
      ('Training data size', 'training_size'),
      ('Validations data size', 'validation_size'),
      ('Expansion method', 'expansion_method'),
      ('Batch size', 'batch_size'),
      ('Buffer size', 'buffer_size'),
      ('Comments', 'comments'),
      ('Tpu', 'tpu'),
      ('Initializer', 'initializer'),
      ('Regularizer', 'regularizer'),
      ('Epochs', 'epochs'),
      ('Image size', 'image_size'),
      ('Dataset path', 'dataset_path'),
      ('Restore from', 'restore_from'),
  )
  print('Configuration')
  print('-------------')
  for label, key in rows:
    value = args[key]
    if key in ('training_size', 'validation_size') and value is None:
      value = "all"
    print('{}: {}'.format(label, value))

# Input parameters

In [0]:
#@title Restore a previous execution - Other arguments ignored
RESTORE_FROM = '' #@param {type:"string"}
# An empty form field means "start a new run" and is normalized to None.
RESTORE_FROM = RESTORE_FROM or None

In [0]:
#@title Configuration
BUFFER_SIZE = 1000 #@param {type:"number"}
TRAINING_SIZE = 4000 #@param {type:"number"}
VALIDATION_SIZE =  1000 #@param {type:"number"}
BATCH_SIZE = 32 #@param {type:"number"}
EPOCHS = 100 #@param {type:"number"}
IMAGE_SIZE = 256 #@param {type:"number"}
DATASET_PATH = 'train_10000_test_2000' #@param {type:"string"}
INITIALIZER = False #@param {type:"boolean"}
REGULARIZER = True #@param {type:"boolean"}
COMMENTS = 'dropout 0.2' #@param {type:"string"}
EXPANSION_METHOD = 'upsampling' #@param ['upsampling', 'tconv']
TPU = False #@param {type:"boolean"}

# DATASET_PATH is the archive/folder name relative to the working directory
# (the next cell extracts it there). An empty field is normalized to None.
DATASET_PATH = DATASET_PATH or None

# Fetch the dataset from Google Drive

In [5]:
dataset_gz_file = DATASET_PATH+'.tar.gz'
dataset_gz_full_gdrive_path = os.path.join(base_folder, dataset_gz_file)
print(dataset_gz_file)
print(dataset_gz_full_gdrive_path)
!cp '{dataset_gz_full_gdrive_path}' .
!tar xvf '{dataset_gz_file}'

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
train_10000_test_2000/train2017/5000/annotations/000000355869.png
train_10000_test_2000/train2017/5000/annotations/000000502800.png
train_10000_test_2000/train2017/5000/annotations/000000120356.png
train_10000_test_2000/train2017/5000/annotations/000000374910.png
train_10000_test_2000/train2017/5000/annotations/000000008436.png
train_10000_test_2000/train2017/5000/annotations/000000488822.png
train_10000_test_2000/train2017/5000/annotations/000000225331.png
train_10000_test_2000/train2017/5000/annotations/000000026784.png
train_10000_test_2000/train2017/5000/annotations/000000581569.png
train_10000_test_2000/train2017/5000/annotations/000000326207.png
train_10000_test_2000/train2017/5000/annotations/000000312074.png
train_10000_test_2000/train2017/5000/annotations/000000349038.png
train_10000_test_2000/train2017/5000/annotations/000000577169.png
train_10000_test_2000/train2017/5000/annotations/000000531474.png
train_10000

# Process arguments


In [6]:
# Collect the run configuration into `args`. For a restored run, the training
# hyper-parameters come from the args.txt written when that run started (the
# form values are ignored); otherwise the form values above are used.
args = {}
args['restore_from'] = RESTORE_FROM
args['dataset_path'] = DATASET_PATH

if args['dataset_path'] is None and args['restore_from'] is None:
  sys.exit('Dataset path or restore from must be specified')


if args['restore_from'] is not None:
  args_file = os.path.join(args['restore_from'], 'args.txt')
  parser = create_argument_parser()
  args_from_file = parser.parse_args(['@'+args_file])
  # The form's dataset path may be empty when restoring; fall back to the one
  # the original run recorded instead of handing None to read_work_paths.
  if args['dataset_path'] is None:
    args['dataset_path'] = args_from_file.dataset_path
  args['buffer_size'] = args_from_file.buffer_size
  args['training_size'] = args_from_file.training_size
  args['validation_size'] = args_from_file.validation_size
  args['batch_size'] = args_from_file.batch_size
  args['epochs'] = args_from_file.epochs  # Actually the authors use 100ths of thousands
  args['image_size'] = args_from_file.image_size
  args['initializer'] = args_from_file.initializer
  args['regularizer'] = args_from_file.regularizer
  args['comments'] = args_from_file.comments
  args['expansion_method'] = args_from_file.expansion_method
  args['tpu'] = args_from_file.tpu
else:
  args['buffer_size'] = BUFFER_SIZE
  args['training_size'] = TRAINING_SIZE
  args['validation_size'] = VALIDATION_SIZE
  args['batch_size'] = BATCH_SIZE
  args['epochs'] = EPOCHS  # Actually the authors use 100ths of thousands
  args['image_size'] = IMAGE_SIZE
  args['initializer'] = INITIALIZER
  args['regularizer'] = REGULARIZER
  args['comments'] = COMMENTS
  args['expansion_method'] = EXPANSION_METHOD
  args['tpu'] = TPU

print_configuration(args)
output_path, checkpoint_path, last_complete_epoch, latest_checkpoint = read_work_paths(args)
# Only a new run gets a fresh config.txt; a restored run keeps its original one.
if args['restore_from'] is None:
  write_configuration_to_file(args, output_path)

Configuration
-------------
Training data size: 4000
Validations data size: 1000
Expansion method: upsampling
Batch size: 32
Buffer size: 1000
Comments: dropout 0.2
Tpu: False
Initializer: False
Regularizer: True
Epochs: 100
Image size: 256
Dataset path: train_10000_test_2000
Restore from: None


# Read the dataset

In [0]:
# Read images
def load_jpg_image(filename):
    """Read a JPEG file and return it as a float32 RGB tensor scaled to [0, 1].

    ``channels=3`` forces three channels. Without it a grayscale JPEG decodes
    to a single channel, which cannot be batched with RGB images or fed to the
    model's 3-channel input.
    """
    # NOTE(review): no resizing happens here; images are presumably stored at
    # the model's input size already — confirm for the dataset in use.
    image = tf.image.decode_jpeg(tf.io.read_file(filename), channels=3)
    image = tf.cast(image, tf.float32)
    return image / 255.0


def load_png_image(filename):
    """Read a PNG mask as a single-channel float32 tensor of 0s and 1s.

    Pixel values are scaled from [0, 255] to [0, 1] and rounded, so values of
    128 and above become 1 and anything lower becomes 0.
    """
    encoded = tf.io.read_file(filename)
    mask = tf.cast(tf.image.decode_png(encoded, channels=1), tf.float32)
    return tf.round(mask / 255.0)



def read_splits(dataset_path):
    """List the split sub-directories of the training and validation sets.

    Returns two ``tf.data.Dataset`` objects of paths matching
    ``<dataset_path>/train2017/*`` and ``<dataset_path>/val2017/*``. Both use
    ``list_files``' default shuffling, so the order of the splits varies.
    """
    def _list(subset):
        pattern = tf.strings.join([dataset_path, subset + '/*'], separator='/')
        return tf.data.Dataset.list_files(pattern)

    return _list('train2017'), _list('val2017')



def get_dataset_split(split_path):
    """Pair every image of one split directory with its annotation mask.

    ``<split_path>/images/*.jpg`` and ``<split_path>/annotations/*.png`` are
    listed with ``shuffle=False`` so both come back in a fixed order, and the
    two streams are zipped element by element.
    """
    # NOTE(review): the pairing presumably relies on images and masks sharing
    # basenames with exactly one mask per image — confirm for the dataset.
    def _load(subdir, pattern_tail, loader):
        pattern = tf.strings.join([split_path, subdir, pattern_tail], separator='/')
        files = tf.data.Dataset.list_files(pattern, shuffle=False)
        return files.map(loader, num_parallel_calls=AUTOTUNE)

    images = _load('images', '*.jpg', load_jpg_image)
    masks = _load('annotations', '*.png', load_png_image)
    return tf.data.Dataset.zip((images, masks))


def get_datasets(dataset_path, train_size=None, val_size=None):
    """Build the training and validation datasets of (image, mask) pairs.

    Args:
        dataset_path: Root folder containing ``train2017`` and ``val2017``,
            each made of split sub-directories (see ``read_splits``).
        train_size: If not None, keep only the first ``train_size`` examples
            of the training set.
        val_size: Same as ``train_size`` for the validation set.

    Returns:
        ``(training_dataset, validation_dataset)``, unbatched.
    """
    training_splits, validation_splits = read_splits(dataset_path)

    def _merge(splits):
        # Read up to 4 split directories concurrently; with deterministic=False
        # the element order across splits is not fixed.
        return splits.interleave(get_dataset_split,
                                 cycle_length=4,
                                 deterministic=False,
                                 num_parallel_calls=AUTOTUNE)

    training_dataset = _merge(training_splits)
    validation_dataset = _merge(validation_splits)

    if train_size is not None:
        training_dataset = training_dataset.take(train_size)
    if val_size is not None:
        validation_dataset = validation_dataset.take(val_size)

    return training_dataset, validation_dataset


# Build the input pipelines. Training examples are shuffled with the configured
# buffer (args['buffer_size'], not otherwise used by the pipeline) before
# batching, so batch composition changes between epochs.
train_dataset, val_dataset = get_datasets(args['dataset_path'], args['training_size'], args['validation_size'])
train_dataset = train_dataset.shuffle(args['buffer_size'])
train_dataset = train_dataset.batch(args['batch_size'])
val_dataset = val_dataset.batch(args['batch_size'])
train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
val_dataset = val_dataset.prefetch(buffer_size=AUTOTUNE)

# Prepare the model

In [0]:
def initializer():
  """Return a new normal(mean 0, stddev 0.01) kernel initializer."""
  # mean defaults to 0.0 in tf.random_normal_initializer.
  return tf.random_normal_initializer(stddev=0.01)


def regularizer():
  """Return a new L2 kernel regularizer with factor 1e-4."""
  l2_factor = 1e-4
  return tf.keras.regularizers.l2(l2_factor)


def downsample(input):
  """Halve the spatial size with 2x2 max pooling, then apply a ReLU."""
  pooled = tf.keras.layers.MaxPool2D(pool_size=2)(input)
  # NOTE(review): callers in create_model pass conv() outputs, which already
  # end in a ReLU, so this activation looks like a no-op.
  activation = tf.keras.layers.ReLU()
  return activation(pooled)


def conv(input, filters):
  """Apply a 3x3 'same' convolution, then Dropout(0.2), then ReLU.

  The kernel initializer and regularizer follow the global ``args`` flags
  ``initializer`` and ``regularizer``.
  """
  kernel_init = initializer() if args['initializer'] else 'glorot_uniform'
  kernel_reg = regularizer() if args['regularizer'] else None
  convolution = tf.keras.layers.Conv2D(filters=filters, kernel_size=3, padding='same',
                                       kernel_initializer=kernel_init,
                                       kernel_regularizer=kernel_reg)
  features = convolution(input)
  features = tf.keras.layers.Dropout(0.2)(features)
  return tf.keras.layers.ReLU()(features)


def conv_transpose(input, filters):
  """Double the spatial size with a 2x2, stride-2 transposed convolution.

  The convolution is followed by Dropout(0.2) and a ReLU. The kernel
  initializer and regularizer follow the global ``args`` flags.
  """
  kernel_init = initializer() if args['initializer'] else 'glorot_uniform'
  kernel_reg = regularizer() if args['regularizer'] else None
  upsampled = tf.keras.layers.Conv2DTranspose(filters=filters, kernel_size=2, strides=2,
                                              padding='valid',
                                              kernel_initializer=kernel_init,
                                              kernel_regularizer=kernel_reg)(input)
  dropped = tf.keras.layers.Dropout(0.2)(upsampled)
  return tf.keras.layers.ReLU()(dropped)


def _build_unet(args):
  """Build and compile the U-Net described by ``args`` on the current device.

  The contracting path has three 2x max-pool stages. The expanding path has
  three 2x expansion stages, each concatenated with the contracting-path
  features of the same resolution. The output is one sigmoid channel per
  pixel, trained with binary cross-entropy.
  """
  input_shape = [args['image_size'], args['image_size'], 3]
  use_upsampling = args['expansion_method'] == 'upsampling'

  def expand(x, filters):
    # Double the spatial size with parameter-free UpSampling2D or with a
    # learned transposed convolution, depending on the expansion method.
    if use_upsampling:
      return tf.keras.layers.UpSampling2D(2)(x)
    return conv_transpose(x, filters)

  # Contracting path (name suffix = spatial size for a 256x256 input)
  model_input = tf.keras.layers.Input(shape=input_shape)
  cx_256 = conv(model_input, 32)
  cx_256 = conv(cx_256, 32)
  cx_128 = downsample(cx_256)
  cx_128 = conv(cx_128, 64)
  cx_128 = conv(cx_128, 64)
  cx_64 = downsample(cx_128)
  cx_64 = conv(cx_64, 128)
  cx_64 = conv(cx_64, 128)
  cx_32 = downsample(cx_64)
  cx_32 = conv(cx_32, 256)
  cx_32 = conv(cx_32, 128)

  # Expanding path
  ex_64 = expand(cx_32, 128)
  ex_64_concat = tf.keras.layers.Concatenate()([cx_64, ex_64])
  ex_64_concat = conv(ex_64_concat, 128)
  ex_64_concat = conv(ex_64_concat, 64)
  ex_128 = expand(ex_64_concat, 64)
  ex_128_concat = tf.keras.layers.Concatenate()([cx_128, ex_128])
  ex_128_concat = conv(ex_128_concat, 64)
  ex_128_concat = conv(ex_128_concat, 32)
  ex_256 = expand(ex_128_concat, 32)
  ex_256_concat = tf.keras.layers.Concatenate()([cx_256, ex_256])
  ex_256_concat = conv(ex_256_concat, 32)

  # Mapping to a single sigmoid channel (per-pixel probability of mask = 1)
  m = tf.keras.layers.Conv2D(filters=1, kernel_size=3, activation='sigmoid', padding='same',
                             kernel_initializer=initializer() if args['initializer'] else 'glorot_uniform',
                             kernel_regularizer=regularizer() if args['regularizer'] else None)(ex_256_concat)

  model = tf.keras.Model(inputs=model_input, outputs=m)
  model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['binary_accuracy'])
  return model


def _tpu_strategy():
  """Connect to the Colab TPU (COLAB_TPU_ADDR) and return a TF 2 TPU strategy."""
  tpu_grpc_url = "grpc://" + os.environ["COLAB_TPU_ADDR"]
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
  tf.config.experimental_connect_to_cluster(resolver)
  tf.tpu.experimental.initialize_tpu_system(resolver)
  # tf.distribute.TPUStrategy is the stable name from TF 2.3 on; earlier 2.x
  # releases only provide the experimental one.
  strategy_cls = getattr(tf.distribute, 'TPUStrategy', None) or tf.distribute.experimental.TPUStrategy
  return strategy_cls(resolver)


def create_model(args):
  """Return the compiled U-Net, placed on the Colab TPU when ``args['tpu']``.

  The former TPU path used ``tensorflow.contrib``, which no longer exists in
  TF 2. Under TF 2 the model's variables must be created, and the model
  compiled, inside the strategy scope instead.
  """
  if args['tpu']:
    # NOTE(review): TPU workers usually cannot read files from the Colab VM's
    # local disk (datasets typically live on GCS); the local dataset paths
    # used by this notebook may not work on TPU — confirm.
    with _tpu_strategy().scope():
      return _build_unet(args)
  return _build_unet(args)

model = create_model(args)

# Callbacks

In [9]:
# Training callbacks:
# - save the weights to the cp-{epoch:04d}.ckpt template every 10 epochs;
# - append per-epoch metrics to metrics.log (append=True, so a resumed run
#   keeps extending the same log; create_metrics_log_file could provide a
#   separate file per run instead);
# - write TensorBoard logs, with histograms every epoch.
csv_logger_path = os.path.join(output_path, 'metrics.log')
tensorboard_log_dir = os.path.join(output_path, 'tensorboard')

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    verbose=1,
    period=10,
)
csv_logger_callback = tf.keras.callbacks.CSVLogger(csv_logger_path, append=True)
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=tensorboard_log_dir,
    histogram_freq=1,
)




In [10]:
# tf.test.gpu_device_name() returns an empty string when no GPU is available;
# report that explicitly instead of printing a blank line.
device_name = tf.test.gpu_device_name()
print(device_name if device_name else 'No GPU device found')

/device:GPU:0


# Train

In [11]:
# Check if we should resume from a previous training
if args['restore_from'] is not None:
  print('Loading model from checkpoints. Resuming from epoch: {}'.format(last_complete_epoch+1))
  model.load_weights(latest_checkpoint)

# Train
history = model.fit(train_dataset, validation_data=val_dataset, epochs=args['epochs'],
                    initial_epoch=last_complete_epoch, 
                    callbacks=[cp_callback, csv_logger_callback, tensorboard_callback])

Epoch 1/100
     21/Unknown - 11s 512ms/step - loss: 0.6101 - binary_accuracy: 0.8393

KeyboardInterrupt: ignored

# Save model and results

In [0]:
# Persist the training history as JSON and the full model under output_path.
# NOTE: history.json only covers the epochs of this fit() call; metrics.log
# (written with append=True) accumulates across resumed runs.
print('Training complete')
print('Saving metrics')
output_json_file = os.path.join(output_path, 'history.json')
pd.DataFrame.from_dict(history.history).to_json(output_json_file)
print('Saving model')
model_folder = os.path.join(output_path, 'saved_model')
# exist_ok: a resumed run writes into the same output_path, where an earlier
# run may already have created this folder.
os.makedirs(model_folder, exist_ok=True)
model_filename = os.path.join(model_folder, 'model')
model.save(model_filename)
print('Done')

# Flush Drive

In [0]:
# Write any pending changes back to Google Drive, unmount it, and release the model.
from google.colab import drive as colab_drive
colab_drive.flush_and_unmount()
del model