In [None]:
# Clone the `tpu` branch of the DeepLabV3-Plus repository and make the checkout
# the working directory (presumably so the `deeplabv3plus` imports used further
# down resolve against it).
!git clone -b tpu https://github.com/lattice-ai/DeepLabV3-Plus
%cd DeepLabV3-Plus

In [None]:
# Refresh the checkout in the current working directory from its remote.
!git pull

In [None]:
import os

import tensorflow as tf

from kaggle_datasets import KaggleDatasets
from deeplabv3plus.datasets import TFRecordDataset

In [None]:
# Ask Kaggle for the GCS path of the 'human-segmentation-tfrecords' dataset and
# print it; the TFRecord glob below is built on top of this path.
GCS_PATH = KaggleDatasets().get_gcs_path('human-segmentation-tfrecords')
print(GCS_PATH)

In [None]:
# Collect every training shard (*.tfrec) stored under the dataset's GCS path.
TRAIN_TFRECORDS = tf.io.gfile.glob(os.path.join(
    GCS_PATH,
    'human-segmentation-tfrecords/human-segmentation-train/*.tfrec',
))
print('Number of TFRecord Files:', len(TRAIN_TFRECORDS))

In [None]:
# Wrap the training shards in the project's TFRecordDataset helper:
# 512 image size, flips enabled, jitter disabled.
tfrecord_dataset = TFRecordDataset(
    tfrecords=TRAIN_TFRECORDS,
    image_size=512,
    apply_flips=True,
    apply_jitter=False,
)
# Ask the helper for its summary, with visualisation switched on.
tfrecord_dataset.summary(visualize=True)

In [None]:
from deeplabv3plus.utils import get_strategy

# Project helper; its result is used via `.scope()` when the model is built,
# so it is presumably a tf.distribute strategy — confirm in deeplabv3plus.utils.
TRAIN_STRATEGY = get_strategy()

In [None]:
# Show the wrapper's `dataset` attribute as the cell output for inspection.
tfrecord_dataset.dataset

In [None]:
def _aspp_conv_block(tensor, kernel_size, dilation_rate = 1, use_bias = False):
  """Apply Conv2D(256, 'same', he_normal) -> BatchNormalization -> ReLU.

  Args:
    tensor: Keras tensor the block is applied to.
    kernel_size: Kernel size forwarded to the convolution.
    dilation_rate: Dilation rate forwarded to the convolution.
    use_bias: Whether the convolution carries a bias term.

  Returns:
    The Keras tensor produced by the final ReLU.
  """
  tensor = tf.keras.layers.Conv2D(256, kernel_size = kernel_size,
                                  dilation_rate = dilation_rate,
                                  padding = 'same',
                                  kernel_initializer = 'he_normal',
                                  use_bias = use_bias)(tensor)
  tensor = tf.keras.layers.BatchNormalization()(tensor)
  return tf.keras.layers.ReLU()(tensor)


def AtrousSpatialPyramidPooling(model_input):
  """Build the ASPP head on top of a feature map.

  Five branches are computed from `model_input` and concatenated along the
  channel axis: a pooled branch (average pooling over the full spatial extent,
  then conv block, then bilinear upsampling back to the input size), one 1x1
  conv block, and three 3x3 conv blocks with dilation rates 6, 12 and 18.
  The concatenation is projected to 256 channels by a final 1x1 conv block.

  Args:
    model_input: Keras tensor whose axes -3 and -2 are read as the spatial
      dimensions. They are used as a pool size and in integer division, so
      they must be statically known.

  Returns:
    Keras tensor with 256 channels and the spatial size of `model_input`.
  """
  dims = tf.keras.backend.int_shape(model_input)
  height, width = dims[-3], dims[-2]

  # Pooled branch: the pool window spans the whole feature map.
  pooled = tf.keras.layers.AveragePooling2D(
      pool_size = (height, width))(model_input)
  # This conv keeps its bias term (the Keras default), unlike the other
  # branches, so the layer's weights stay as they were before the refactor.
  pooled = _aspp_conv_block(pooled, kernel_size = 1, use_bias = True)
  # Stretch the pooled map back to the spatial size of the input feature map.
  out_pool = tf.keras.layers.UpSampling2D(
      size = (height // pooled.shape[1], width // pooled.shape[2]),
      interpolation = 'bilinear')(pooled)

  # Parallel branches, created in this order to keep auto-generated layer
  # names stable.
  out_1 = _aspp_conv_block(model_input, kernel_size = 1, dilation_rate = 1)
  out_6 = _aspp_conv_block(model_input, kernel_size = 3, dilation_rate = 6)
  out_12 = _aspp_conv_block(model_input, kernel_size = 3, dilation_rate = 12)
  out_18 = _aspp_conv_block(model_input, kernel_size = 3, dilation_rate = 18)

  merged = tf.keras.layers.Concatenate(axis = -1)([out_pool, out_1,
                                                   out_6, out_12,
                                                   out_18])

  # Fuse the five branches back down to 256 channels.
  model_output = _aspp_conv_block(merged, kernel_size = 1, dilation_rate = 1)
  return model_output

In [None]:
def _decoder_conv_block(tensor, filters, kernel_size):
  """Apply Conv2D (no bias, he_normal, 'same') -> BatchNormalization -> ReLU.

  The convolution itself has no activation: ReLU is applied exactly once,
  after batch normalisation.
  """
  tensor = tf.keras.layers.Conv2D(
      filters, kernel_size = kernel_size, padding = 'same',
      kernel_initializer = tf.keras.initializers.he_normal(),
      use_bias = False)(tensor)
  tensor = tf.keras.layers.BatchNormalization()(tensor)
  return tf.keras.layers.ReLU()(tensor)


def DeeplabV3Plus(nclasses = 20, image_size = 512):
  """Build a DeepLabV3+ segmentation model on an ImageNet ResNet50 backbone.

  The encoder feature map 'conv4_block6_2_relu' goes through
  AtrousSpatialPyramidPooling and is upsampled to a quarter of the input
  resolution. There it is concatenated with a 48-channel projection of
  'conv2_block3_2_relu'. Two 3x3 conv blocks refine the result, which is
  upsampled to the input resolution and mapped to `nclasses` channels by a
  1x1 convolution.

  Args:
    nclasses: Number of output channels (one per class).
    image_size: Side length of the square RGB input. Must be at least 32 and a
      multiple of 16, so that the integer upsampling factors computed below
      line the encoder and low-level feature maps up exactly.

  Returns:
    An uncompiled tf.keras.Model. The last layer has no activation, so the
    output holds raw per-pixel class scores (logits), not probabilities.

  Raises:
    ValueError: If `image_size` is smaller than 32 or not a multiple of 16.
  """
  # Validate before touching TensorFlow so a bad size fails fast and clearly
  # instead of surfacing later as a shape mismatch in Concatenate.
  if image_size < 32 or image_size % 16 != 0:
    raise ValueError(
        'image_size must be a multiple of 16 and at least 32, got %r'
        % (image_size,))

  model_input = tf.keras.Input(shape=(image_size, image_size, 3))
  resnet50 = tf.keras.applications.ResNet50(weights = 'imagenet',
                                            include_top = False,
                                            input_tensor = model_input)

  # Encoder branch: deep features -> ASPP -> upsample to image_size / 4.
  layer = resnet50.get_layer('conv4_block6_2_relu').output
  layer = AtrousSpatialPyramidPooling(layer)
  input_a = tf.keras.layers.UpSampling2D(
      size = (image_size // 4 // layer.shape[1],
              image_size // 4 // layer.shape[2]),
      interpolation = 'bilinear')(layer)

  # Low-level branch: shallow features reduced to 48 channels.
  input_b = resnet50.get_layer('conv2_block3_2_relu').output
  input_b = _decoder_conv_block(input_b, 48, (1,1))

  layer = tf.keras.layers.Concatenate(axis = -1)([input_a, input_b])

  # Decoder refinement: two 3x3 conv blocks.
  layer = _decoder_conv_block(layer, 256, 3)
  layer = _decoder_conv_block(layer, 256, 3)

  # Back to full input resolution, then per-pixel class scores.
  layer = tf.keras.layers.UpSampling2D(
      size = (image_size // layer.shape[1],
              image_size // layer.shape[2]),
      interpolation = 'bilinear')(layer)
  model_output = tf.keras.layers.Conv2D(nclasses, kernel_size = (1,1),
                                        padding = 'same')(layer)
  return tf.keras.Model(inputs = model_input, outputs = model_output)

In [None]:
# Build and compile the model inside the distribution strategy's scope.
with TRAIN_STRATEGY.scope():
    MODEL = DeeplabV3Plus()

    MODEL.compile(
        optimizer=tf.keras.optimizers.Adam(
            learning_rate=0.0001),
        # DeeplabV3Plus ends in a Conv2D without activation, i.e. it emits raw
        # logits, so the loss must apply the softmax itself. With the default
        # from_logits=False the scores would be misread as probabilities.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
    )

In [None]:
# Print the layer-by-layer summary of the compiled model.
MODEL.summary()

In [None]:
# Take only the first element of the batch_size=1 dataset (presumably one
# image/mask pair — confirm against TFRecordDataset) ...
x, y = next(iter(
    tfrecord_dataset.configured_dataset(batch_size=1)
))

# ... and fit on that single example for 100 epochs.
MODEL.fit(x=x, y=y, epochs=100)