In [None]:
!pip install ipython ipykernel --upgrade -q
!pip install layer-sdk --upgrade -q

[K     |████████████████████████████████| 793 kB 5.1 MB/s 
[K     |████████████████████████████████| 131 kB 65.4 MB/s 
[K     |████████████████████████████████| 381 kB 53.3 MB/s 
[K     |████████████████████████████████| 428 kB 49.9 MB/s 
[K     |████████████████████████████████| 130 kB 43.1 MB/s 
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
jupyter-console 5.2.0 requires prompt-toolkit<2.0.0,>=1.0.0, but you have prompt-toolkit 3.0.29 which is incompatible.
google-colab 1.0.0 requires ipykernel~=4.10, but you have ipykernel 6.13.0 which is incompatible.
google-colab 1.0.0 requires ipython~=5.5.0, but you have ipython 7.32.0 which is incompatible.
google-colab 1.0.0 requires tornado~=5.1.0; python_version >= "3.0", but you have tornado 6.1 which is incompatible.[0m
[K     |████████████████████████████████| 471 kB 5.1 MB/s 
[K     |██████████████

In [None]:
pip install git+https://github.com/tensorflow/examples.git

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

from IPython.display import clear_output
import matplotlib.pyplot as plt

In [None]:
def normalize(input_image, input_mask):
  """Scale the image to float32 in [0, 1] and shift mask labels down by one.

  Args:
    input_image: image tensor; assumes pixel values in [0, 255] — TODO confirm.
    input_mask: segmentation mask tensor.

  Returns:
    ``(float32(input_image) / 255, input_mask - 1)``. The pet masks are
    presumably labelled 1..N, so the shift yields 0-based class ids — confirm.
  """
  scaled_image = tf.cast(input_image, tf.float32) / 255.0
  zero_based_mask = input_mask - 1
  return scaled_image, zero_based_mask

def load_image(datapoint):
  """Resize one dataset example to 128x128 and normalize it.

  Args:
    datapoint: mapping with 'image' and 'segmentation_mask' tensors, as yielded
      by ``tfds.load('oxford_iiit_pet:3.*.*')``.

  Returns:
    ``(input_image, input_mask)`` as produced by ``normalize``.
  """
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  # Masks hold discrete class ids. Nearest-neighbour resizing keeps every pixel
  # on an original id (and keeps the mask's integer dtype), whereas the default
  # bilinear method blends neighbouring labels into fractional values at edges.
  input_mask = tf.image.resize(
      datapoint['segmentation_mask'], (128, 128),
      method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  input_image, input_mask = normalize(input_image, input_mask)
  return input_image, input_mask

class Augment(tf.keras.layers.Layer):
  """Random horizontal flip applied to an (image, mask) pair.

  Images and masks go through two separate ``RandomFlip`` layers built with the
  same seed. The intent is that both make identical flip decisions, so a mask is
  flipped exactly when its image is. This relies on the two layers drawing their
  random numbers in step.
  """

  def __init__(self, seed=42):
    super().__init__()
    flip_kwargs = dict(mode="horizontal", seed=seed)
    self.augment_inputs = tf.keras.layers.RandomFlip(**flip_kwargs)
    self.augment_labels = tf.keras.layers.RandomFlip(**flip_kwargs)

  def call(self, inputs, labels):
    """Return ``(flipped inputs, flipped labels)``; inputs are processed first."""
    return self.augment_inputs(inputs), self.augment_labels(labels)

In [None]:
import layer
from layer.decorators import model

# Connect to Layer and select the project used by the rest of the notebook.
# NOTE(review): login() presumably prompts for credentials interactively —
# confirm, and keep any tokens out of the notebook source/outputs.
# The @model(...) registrations and layer.get_model(...) lookups below
# presumably resolve against this 'image-segmentation' project.
layer.login()
layer.init('image-segmentation')

In [None]:
@model('base-model')
def build_base_model():
  """Create the MobileNetV2 backbone for 128x128x3 inputs.

  ``include_top=False`` drops the classification head so that
  ``build_down_stack_model`` (which loads this as 'base-model') can tap the
  intermediate feature maps. Weights are left at the Keras default
  (presumably ImageNet-pretrained — confirm).
  """
  backbone = tf.keras.applications.MobileNetV2(
      input_shape=[128, 128, 3],
      include_top=False,
  )
  return backbone


build_base_model()

In [None]:
# Create the feature extraction model
@model('down-stack-model')
def build_down_stack_model():
  """Wrap the stored 'base-model' backbone as a frozen multi-output encoder.

  Returns:
    A ``tf.keras.Model`` with the backbone's input. Its outputs are the
    activations of the five tapped layers below, ordered from the highest to
    the lowest spatial resolution. ``trainable`` is set to False.
  """
  backbone = layer.get_model('base-model').get_train()

  # Intermediate activations used as U-Net skip connections
  # (spatial sizes shown for a 128x128 input).
  tapped_layer_names = (
      'block_1_expand_relu',   # 64x64
      'block_3_expand_relu',   # 32x32
      'block_6_expand_relu',   # 16x16
      'block_13_expand_relu',  # 8x8
      'block_16_project',      # 4x4
  )
  tapped_outputs = []
  for layer_name in tapped_layer_names:
    tapped_outputs.append(backbone.get_layer(layer_name).output)

  encoder = tf.keras.Model(inputs=backbone.input, outputs=tapped_outputs)
  # Freeze the encoder's weights.
  encoder.trainable = False
  return encoder


build_down_stack_model()

In [None]:
import io
from PIL import Image

def make_display_figure(display_list):
  """Draw the tensors in ``display_list`` side by side in one pyplot figure.

  Panels are titled, in order, 'Input Image', 'True Mask', 'Predicted Mask'.
  Passing more than three entries raises IndexError. The figure is created with
  ``plt.figure`` (so it becomes pyplot's current figure) and is returned still
  open.
  """
  panel_titles = ['Input Image', 'True Mask', 'Predicted Mask']
  n_panels = len(display_list)
  figure = plt.figure(figsize=(15, 15))
  for idx, panel in enumerate(display_list):
    ax = figure.add_subplot(1, n_panels, idx + 1)
    ax.set_title(panel_titles[idx])
    ax.imshow(tf.keras.utils.array_to_img(panel))
    ax.set_axis_off()
  return figure

def make_display_image(display_list):
    """Render ``display_list`` side by side and return it as a PIL image.

    The figure from ``make_display_figure`` is rasterised to PNG in memory and
    decoded with PIL. The figure is then closed, even if rendering fails, so
    repeated calls (e.g. from ``get_figures``) do not accumulate open figures.
    """
    fig = make_display_figure(display_list)
    buf = io.BytesIO()
    try:
        # Explicit format: don't depend on the rcParams 'savefig.format' setting.
        fig.savefig(buf, format='png')
        buf.seek(0)
        img = Image.open(buf)
        # Image.open is lazy; decode now rather than on first use.
        img.load()
    finally:
        # Close this specific figure, not whatever pyplot considers "current".
        plt.close(fig)
    return img

def display(display_list):
  """Build the side-by-side figure for ``display_list`` and call its ``show()``.

  NOTE(review): this name presumably shadows IPython's built-in ``display`` in
  the notebook namespace — consider a more specific name.
  """
  figure = make_display_figure(display_list)
  figure.show()

def create_mask(pred_mask):
  """Turn batched per-pixel class scores into the first example's label mask.

  Takes the arg-max over the last (class) axis and restores a trailing
  length-1 channel axis, so the result can be drawn like an image. Returns
  element 0 of the batch.
  """
  class_ids = tf.argmax(pred_mask, axis=-1)
  return tf.expand_dims(class_ids, axis=-1)[0]

def get_figures(dataset, num=1, model_train=None):
  """Render one comparison image per batch for the first ``num`` batches.

  Each image shows the first example's input and true mask. When
  ``model_train`` is given, a third panel shows the model's predicted mask for
  that same example.

  Args:
    dataset: batched ``tf.data.Dataset`` yielding ``(images, masks)`` pairs.
    num: number of batches to take (one image per batch).
    model_train: optional Keras model used to predict masks.

  Returns:
    A list of PIL images, in dataset order.
  """
  figures = []
  for image, mask in dataset.take(num):
    display_list = [image[0], mask[0]]
    if model_train is not None:
      # Only example 0 is displayed, so predict just that example instead of
      # the whole batch; the [:1] slice keeps the batch axis that
      # create_mask indexes away.
      pred_mask = model_train.predict(image[:1])
      display_list.append(create_mask(pred_mask))
    figures.append(make_display_image(display_list))
  return figures

In [None]:
import PIL
import io



class LayerCallback(tf.keras.callbacks.Callback):
  """Keras callback that forwards each epoch's metrics to Layer."""

  def on_epoch_end(self, epoch, logs=None):
    """Log ``logs`` (the epoch's metric dict) to Layer at step ``epoch``.

    ``logs`` is declared optional. When it is None or empty there is nothing to
    record, so skip the call instead of passing None to ``layer.log``.
    """
    if not logs:
      return
    layer.log(logs, step=epoch)


def _log_figures(key_prefix, dataset, num=4, model_train=None):
  """Log ``get_figures`` output to Layer under keys ``f'{key_prefix}_{ix}'``."""
  figures = get_figures(dataset, num=num, model_train=model_train)
  for ix, fig in enumerate(figures):
    layer.log({f'{key_prefix}_{ix}': fig})


def _build_unet(inputs, down_stack, output_classes):
  """Attach a pix2pix upsampling decoder with skip connections to ``down_stack``.

  Args:
    inputs: the Keras ``Input`` tensor fed to the encoder.
    down_stack: encoder returning skip activations ordered from the highest to
      the lowest resolution (as built by ``build_down_stack_model``).
    output_classes: number of per-pixel logit channels in the output.

  Returns:
    An uncompiled ``tf.keras.Model`` mapping ``inputs`` to per-pixel logits.
  """
  skips = down_stack(inputs)
  x = skips[-1]
  # Remaining skips are consumed from low to high resolution while upsampling.
  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  up_stack = [
      pix2pix.upsample(512, 3),  # 4x4 -> 8x8
      pix2pix.upsample(256, 3),  # 8x8 -> 16x16
      pix2pix.upsample(128, 3),  # 16x16 -> 32x32
      pix2pix.upsample(64, 3),   # 32x32 -> 64x64
  ]
  for up, skip in zip(up_stack, skips):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  # This is the last layer of the model: one logit channel per class.
  last = tf.keras.layers.Conv2DTranspose(
      filters=output_classes, kernel_size=3, strides=2,
      padding='same')  # 64x64 -> 128x128
  return tf.keras.Model(inputs=inputs, outputs=last(x))


@model('model')
def build_model():
  """Train a U-Net-style segmentation model on Oxford-IIIT Pet, logging to Layer.

  The steps are:
    1. Log the hyperparameters.
    2. Log 4 sample input/mask figures from the augmented training batches.
    3. Build the model from the stored 'down-stack-model' encoder and log
       initial predictions.
    4. Train for ``EPOCHS`` epochs, logging per-epoch metrics through
       ``LayerCallback``.
    5. Log the final predictions.

  Returns:
    The compiled and trained ``tf.keras.Model``.
  """
  dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

  OUTPUT_CLASSES = 3
  TRAIN_LENGTH = info.splits['train'].num_examples
  BATCH_SIZE = 64
  BUFFER_SIZE = 1000
  STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE
  EPOCHS = 30
  VAL_SUBSPLITS = 5
  VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS

  layer.log({
    'TRAIN_LENGTH': TRAIN_LENGTH,
    'BATCH_SIZE': BATCH_SIZE,
    'BUFFER_SIZE': BUFFER_SIZE,
    'OUTPUT_CLASSES': OUTPUT_CLASSES,
    'EPOCHS': EPOCHS,
    'STEPS_PER_EPOCH': STEPS_PER_EPOCH,
    'VAL_SUBSPLITS': VAL_SUBSPLITS,
    'VALIDATION_STEPS': VALIDATION_STEPS,
  })

  train_images = dataset['train'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
  test_images = dataset['test'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)

  train_batches = (
      train_images
      .cache()
      .shuffle(BUFFER_SIZE)
      .batch(BATCH_SIZE)
      .repeat()
      .map(Augment())
      .prefetch(buffer_size=tf.data.AUTOTUNE))

  test_batches = test_images.batch(BATCH_SIZE)

  _log_figures('sample', train_batches)

  inputs = tf.keras.layers.Input(shape=[128, 128, 3])
  # Downsampling through the encoder registered by build_down_stack_model.
  down_stack = layer.get_model('down-stack-model').get_train()
  segmentation_model = _build_unet(inputs, down_stack, OUTPUT_CLASSES)

  segmentation_model.compile(
      optimizer='adam',
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=['accuracy'])

  # NOTE(review): predictions are visualised on augmented *training* batches;
  # consider test_batches for a held-out view.
  _log_figures('prediction_initial', train_batches, model_train=segmentation_model)

  segmentation_model.fit(
      train_batches,
      epochs=EPOCHS,
      steps_per_epoch=STEPS_PER_EPOCH,
      validation_steps=VALIDATION_STEPS,
      validation_data=test_batches,
      callbacks=[LayerCallback()])

  _log_figures('prediction_final', train_batches, model_train=segmentation_model)

  return segmentation_model


build_model()