In [16]:
from __future__ import absolute_import, division, print_function, unicode_literals


def main_fun(args, ctx):
  from tensorflow_examples.models.pix2pix import pix2pix
  import tensorflow_datasets as tfds
  import tensorflow as tf

  print("TensorFlow version: ", tf.__version__)
  strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

  dataset, info = tfds.load('oxford_iiit_pet:3.2.0', with_info=True)

  def normalize(input_image, input_mask):
    input_image = tf.cast(input_image, tf.float32)/128.0 - 1
    input_mask -= 1
    return input_image, input_mask

  @tf.function
  def load_image_train(datapoint):
    input_image = tf.image.resize(datapoint['image'], (128, 128))
    input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))

    if tf.random.uniform(()) > 0.5:
      input_image = tf.image.flip_left_right(input_image)
      input_mask = tf.image.flip_left_right(input_mask)

    input_image, input_mask = normalize(input_image, input_mask)

    return input_image, input_mask

  def load_image_test(datapoint):
    input_image = tf.image.resize(datapoint['image'], (128, 128))
    input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))
    input_image, input_mask = normalize(input_image, input_mask)
    return input_image, input_mask

  TRAIN_LENGTH = info.splits['train'].num_examples
  BATCH_SIZE = args.batch_size
  BUFFER_SIZE = args.buffer_size
  STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

  train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  test = dataset['test'].map(load_image_test)

  train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
  train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
  test_dataset = test.batch(BATCH_SIZE)

  OUTPUT_CHANNELS = 3

  with strategy.scope():
    base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)

    # Use the activations of these layers
    layer_names = [
        'block_1_expand_relu',   # 64x64
        'block_3_expand_relu',   # 32x32
        'block_6_expand_relu',   # 16x16
        'block_13_expand_relu',  # 8x8
        'block_16_project',      # 4x4
    ]
    layers = [base_model.get_layer(name).output for name in layer_names]

    # Create the feature extraction model
    down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)

    down_stack.trainable = False

    up_stack = [
        pix2pix.upsample(512, 3),  # 4x4 -> 8x8
        pix2pix.upsample(256, 3),  # 8x8 -> 16x16
        pix2pix.upsample(128, 3),  # 16x16 -> 32x32
        pix2pix.upsample(64, 3),   # 32x32 -> 64x64
    ]

    def unet_model(output_channels):

      # This is the last layer of the model
      last = tf.keras.layers.Conv2DTranspose(
          output_channels, 3, strides=2,
          padding='same', activation='softmax')  # 64x64 -> 128x128

      inputs = tf.keras.layers.Input(shape=[128, 128, 3])
      x = inputs

      # Downsampling through the model
      skips = down_stack(x)
      x = skips[-1]
      skips = reversed(skips[:-1])

      # Upsampling and establishing the skip connections
      for up, skip in zip(up_stack, skips):
        x = up(x)
        concat = tf.keras.layers.Concatenate()
        x = concat([x, skip])

      x = last(x)

      return tf.keras.Model(inputs=inputs, outputs=x)

    model = unet_model(OUTPUT_CHANNELS)
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

# Training only (since we're using command-line)
# def create_mask(pred_mask):
#   pred_mask = tf.argmax(pred_mask, axis=-1)
#   pred_mask = pred_mask[..., tf.newaxis]
#   return pred_mask[0]
#
#
# def show_predictions(dataset=None, num=1):
#   if dataset:
#     for image, mask in dataset.take(num):
#       pred_mask = model.predict(image)
#       display([image[0], mask[0], create_mask(pred_mask)])
#   else:
#     display([sample_image, sample_mask,
#              create_mask(model.predict(sample_image[tf.newaxis, ...]))])
#
#
# class DisplayCallback(tf.keras.callbacks.Callback):
#   def on_epoch_end(self, epoch, logs=None):
#     clear_output(wait=True)
#     show_predictions()
#     print ('\nSample Prediction after epoch {}\n'.format(epoch+1))
#

  EPOCHS = args.epochs
  VAL_SUBSPLITS = 5
  VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS

  tf.io.gfile.makedirs(args.model_dir)
  filepath = args.model_dir + "/weights-{epoch:04d}"
  ckpt_callback = tf.keras.callbacks.ModelCheckpoint(filepath=filepath, verbose=1, save_weights_only=True)

  model_history = model.fit(train_dataset, epochs=EPOCHS,
                            steps_per_epoch=STEPS_PER_EPOCH,
                            callbacks=[ckpt_callback],
                            validation_steps=VALIDATION_STEPS,
                            validation_data=test_dataset)

  if tf.__version__ == '2.0.0':
    # Workaround for: https://github.com/tensorflow/tensorflow/issues/30251
    # Save model locally as h5py and reload it w/o distribution strategy
    if ctx.job_name == 'chief':
      model.save(args.model_dir + ".h5")
      new_model = tf.keras.models.load_model(args.model_dir + ".h5")
      tf.keras.experimental.export_saved_model(new_model, args.export_dir)
  else:
    model.save(args.export_dir, save_format='tf')

In [18]:
from pyspark.context import SparkContext
from pyspark.conf import SparkConf
from tensorflowonspark import TFCluster

sc = SparkContext('local' ,'segmentation')

cluster = TFCluster.run(sc, main_fun, num_ps=0, tf_args=None, num_executors=1, input_mode=TFCluster.InputMode.TENSORFLOW, master_node='chief')
#cluster.shutdown(grace_secs=30)
sc.stop()

2020-07-08 13:50:13,890 INFO (MainThread-23140) Reserving TFSparkNodes 
2020-07-08 13:50:13,891 INFO (MainThread-23140) cluster_template: {'chief': [0]}
2020-07-08 13:50:13,892 INFO (MainThread-23140) Reservation server binding to port 0
2020-07-08 13:50:13,892 INFO (MainThread-23140) listening for reservations at ('192.168.0.248', 49322)
2020-07-08 13:50:13,893 INFO (MainThread-23140) Starting TensorFlow on executors
2020-07-08 13:50:13,899 INFO (MainThread-23140) Waiting for TFSparkNodes to start
2020-07-08 13:50:13,899 INFO (MainThread-23140) waiting for 1 reservations
2020-07-08 13:50:13,913 ERROR (Thread-18-23140) Exception in TF background thread


SystemExit: 1

In [15]:
sc.stop()