# Image Segmentation

### 다음과 같은 일반적인 workflow로 진행
1. Visualize data/perform some exploratory data analysis
2. Set up data pipeline and preprocessing
3. Build model
4. Train model
5. Evaluate model
6. Repeat

## Project 설명

### Task
* GIANA dataset으로 위내시경 이미지에서 용종을 segmentation 해보자.
* 데이터 불러오기를 제외한 딥러닝 트레이닝 과정을 직접 구현해보는 것이 목표 입니다.
* This code is borrowed from [TensorFlow tutorials/Image Segmentation](https://github.com/tensorflow/models/blob/master/samples/outreach/blogs/segmentation_blogpost/image_segmentation.ipynb) which is made of `tf.keras.layers` and `tf.enable_eager_execution()`.
* See the original [tutorial](https://github.com/tensorflow/models/blob/master/samples/outreach/blogs/segmentation_blogpost/image_segmentation.ipynb) for a detailed description.  

### Dataset
* This notebook uses the dataset below instead of the [carvana-image-masking-challenge dataset](https://www.kaggle.com/c/carvana-image-masking-challenge/rules) (a Kaggle competition dataset) used in the TensorFlow tutorial.
  * carvana-image-masking-challenge dataset: Too large dataset (14GB)
* [Gastrointestinal Image ANAlys Challenges (GIANA)](https://giana.grand-challenge.org) Dataset (345MB)
  * Train data: 300 images with RGB channels (bmp format)
  * Train labels: 300 images with 1 channel (bmp format)
  * Image size: 574 x 500
* Training시 **image size는 256**으로 resize

### Baseline code
* Dataset: train, test로 split
* Input data shape: (`batch_size`, 256, 256, 3)
* Output data shape: (`batch_size`, 256, 256, 1)
* Architecture: 
  * 간단한 Encoder-Decoder 구조
  * U-Net 구조
  * [`tf.keras.layers`](https://www.tensorflow.org/api_docs/python/tf/keras/layers) 사용
* Training
  * `tf.data.Dataset` 사용
  * `model.fit()` 사용 for weight update
* Evaluation
  * MeanIOU: Image Segmentation에서 많이 쓰이는 evaluation measure
  * tf.version 1.13 API: [`tf.metrics.mean_iou`](https://www.tensorflow.org/api_docs/python/tf/metrics/mean_iou)
    * `tf.enable_eager_execution()`이 작동하지 않음
    * 따라서 예전 방식대로 `tf.Session()`을 이용하여 작성하거나 아래와 같이 2.0 version으로 작성하여야 함
  * tf.version 2.0 API: [`tf.keras.metrics.MeanIoU`](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/metrics/MeanIoU)

### Try some techniques
* Change model architectures (Custom model)
  * Try another models (Unet 모델)
* Various regularization methods

## Import modules

### Import colab modules for Google Colab (if necessary)

In [None]:
# if necessary

# from google.colab import auth
# auth.authenticate_user()

# from google.colab import drive
# drive.mount('/content/gdrive')

In [None]:
# Set to True when running on Google Colab; later cells then read the
# dataset from, and write checkpoints to, the mounted Google Drive.
use_colab = False
assert use_colab in (True, False)

### Import base modules

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import time
import shutil
import functools

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib as mpl
mpl.rcParams['axes.grid'] = False
mpl.rcParams['figure.figsize'] = (12,12)

from sklearn.model_selection import train_test_split
import matplotlib.image as mpimg
import pandas as pd
from PIL import Image
from IPython.display import clear_output

import tensorflow as tf
tf.enable_eager_execution()

from tensorflow.python.keras import layers
from tensorflow.python.keras import losses
from tensorflow.python.keras import models

os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [None]:
# When False, a later cell restores the latest checkpoint from checkpoint_dir.
is_train = True

# Architecture to build: the plain encoder-decoder ('ed_model') or 'u-net'.
model_name = 'ed_model'
assert model_name in ('ed_model', 'u-net')

## 데이터 수집 및 Visualize

### Download data

이 프로젝트는 [Giana Dataset](https://giana.grand-challenge.org/Dates/)을 이용하여 진행한다.

In [None]:
# The GIANA dataset cannot be downloaded directly from the challenge website,
# so a zipped copy hosted on Dropbox is fetched the first time this cell runs.
if use_colab:
  DATASET_PATH='./gdrive/My Drive/datasets/sd_train'
else:
  DATASET_PATH='../../datasets/sd_train'

if not os.path.isdir(DATASET_PATH):
  import urllib.request
  import zipfile

  os.makedirs(DATASET_PATH)
  file_path = os.path.join(DATASET_PATH, 'sd_train.zip')

  # The mere existence of DATASET_PATH is what marks the data as "already
  # downloaded" on later runs, so the directory is removed again if the
  # download or the extraction does not complete (including on interrupt).
  finished = False
  try:
    # Stream the archive straight to its final location instead of holding
    # the whole file in memory; `with` closes both handles even on error.
    with urllib.request.urlopen(url='https://www.dropbox.com/s/1a11bw6zrm6bb77/sd_train.zip?dl=1') as response, \
         open(file_path, "wb") as f:
      shutil.copyfileobj(response, f)
    print('Data has been downloaded')

    with zipfile.ZipFile(file_path, 'r') as zip_ref:
      zip_ref.extractall(DATASET_PATH)
    print('Data has been extracted.')
    finished = True
  finally:
    if not finished:
      shutil.rmtree(DATASET_PATH, ignore_errors=True)

else:
  print('Data has already been downloaded and extracted.')

### Split dataset into train data and test data

In [None]:
# Root of the extracted archive, with one sub-directory for the input images
# and one for the label masks.
dataset_dir = os.path.join(DATASET_PATH, 'sd_train')
img_dir, label_dir = (os.path.join(dataset_dir, sub_dir)
                      for sub_dir in ("train", "train_labels"))

In [None]:
# Images and masks are paired by position later on, so both listings are
# sorted (assumes an image and its mask sort to the same index -- TODO confirm).
x_train_filenames = sorted(os.path.join(img_dir, name) for name in os.listdir(img_dir))
y_train_filenames = sorted(os.path.join(label_dir, name) for name in os.listdir(label_dir))

In [None]:
# Hold out 20% of the (image, mask) pairs as a test set; the fixed seed keeps
# the split reproducible between runs.
(x_train_filenames, x_test_filenames,
 y_train_filenames, y_test_filenames) = train_test_split(
    x_train_filenames, y_train_filenames, test_size=0.2, random_state=219)

In [None]:
# Split sizes; reused later for sampling and for steps-per-epoch arithmetic.
num_train_examples, num_test_examples = len(x_train_filenames), len(x_test_filenames)

for message, count in (("Number of training examples: {}", num_train_examples),
                       ("Number of test examples: {}", num_test_examples)):
  print(message.format(count))

### Visualize

데이터 셋에서 5장 (`display_num`)의 이미지를 살펴보자.

In [None]:
display_num = 5

# Draw distinct indices so the same example is never shown twice
# (np.random.choice samples with replacement unless told otherwise).
r_choices = np.random.choice(num_train_examples,
                             min(display_num, num_train_examples),
                             replace=False)

plt.figure(figsize=(10, 15))
# One row per example: input image on the left, its mask on the right.
for row, img_num in enumerate(r_choices):
  x_pathname = x_train_filenames[img_num]
  y_pathname = y_train_filenames[img_num]

  plt.subplot(display_num, 2, 2 * row + 1)
  plt.imshow(Image.open(x_pathname))
  plt.title("Original Image")

  plt.subplot(display_num, 2, 2 * row + 2)
  plt.imshow(Image.open(y_pathname))
  plt.title("Masked Image")

plt.suptitle("Examples of Images and their Masks")
plt.show()

## Data pipeline and preprocessing 만들기

### Set up hyper-parameters

Hyper-parameter를 셋팅해보자. 이미지 사이즈, 배치 사이즈 등 training parameter들을 셋팅해보자.

In [None]:
# Training configuration shared by the cells below.
image_size = 256                          # images are resized to image_size x image_size
img_shape = (image_size, image_size, 3)
batch_size = 8
max_epochs = 10
print_steps = 10                          # log every `print_steps` training steps
save_epochs = 1                           # checkpoint every `save_epochs` epochs

# Checkpoints go to Google Drive on Colab and to a local folder otherwise.
if not use_colab:
  checkpoint_dir = train_dir = 'train/exp1'
else:
  checkpoint_dir = train_dir ='./gdrive/My Drive/train_ckpt/segmentation/exp1'
  if not os.path.isdir(train_dir):
    os.makedirs(train_dir)

### Build our input pipeline with `tf.data`

Input data pipeline을 만들기 가장 좋은 방법은 [**tf.data**](https://www.tensorflow.org/guide/datasets) (링크 참조) 를 사용하는 것이다. `tf.data` API 를 잘 읽어보자.


#### Our input pipeline will consist of the following steps:

TensorFlow segmentation tutorial input pipeline 참고 하였음.


>1. Read the bytes of the file in from the filename - for both the image and the label. Recall that our labels are actually images with each pixel annotated as car or background (1, 0). 
>2. Decode the bytes into an image format
>3. Apply image transformations: (optional, according to input parameters)
>  * `resize` - Resize our images to a standard size (as determined by eda or computation/memory restrictions)
>    * The reason why this is optional is that U-Net is a fully convolutional network (e.g. with no fully connected units) and is thus not dependent on the input size. However, if you choose to not resize the images, you must use a batch size of 1, since you cannot batch variable image size together
>    * Alternatively, you could also bucket your images together and resize them per mini-batch to avoid resizing images as much, as resizing may affect your performance through interpolation, etc.
>  * `hue_delta` - Adjusts the hue of an RGB image by a random factor. This is only applied to the actual image (not our label image). The `hue_delta` must be in the interval `[0, 0.5]` 
>  * `horizontal_flip` - flip the image horizontally along the central axis with a 0.5 probability. This transformation must be applied to both the label and the actual image. 
>  * `width_shift_range` and `height_shift_range` are ranges (as a fraction of total width or height) within which to randomly translate the image either horizontally or vertically. This transformation must be applied to both the label and the actual image. 
>  * `rescale` - rescale the image by a certain factor, e.g. 1/ 255.
>4. Shuffle the data, repeat the data (so we can iterate over it multiple times across epochs), batch the data, then prefetch a batch (for efficiency).

#### Why do we do these image transformations?

Data augmentation은 딥러닝을 이용한 이미지 처리분야 (classification, detection, segmentation 등) 에서 널리 쓰이는 테크닉이다. 자세한 내용은 아래 TensorFlow 공식 예제 링크로 대체한다.

> This is known as **data augmentation**. Data augmentation "increases" the amount of training data by augmenting them via a number of random transformations. During training time, our model would never see twice the exact same picture. This helps prevent [overfitting](https://developers.google.com/machine-learning/glossary/#overfitting) and helps the model generalize better to unseen data.

#### Processing each pathname

In [None]:
def _process_pathnames(fname, label_path):
  """Load one (image, mask) pair from disk and prepare it for the model.

  Mapped over every (image path, mask path) pair of the dataset.

  Args:
    fname: path of the RGB input image (bmp).
    label_path: path of the single-channel mask image (bmp).

  Returns:
    (img, label_img): float32 tensors resized to `image_size` x `image_size`
    and scaled by 1/255.
  """
  img_str = tf.read_file(fname)
  img = tf.image.decode_bmp(img_str, channels=3)

  label_img_str = tf.read_file(label_path)
  label_img = tf.image.decode_bmp(label_img_str, channels=1)

  resize = [image_size, image_size]
  img = tf.image.resize_images(img, resize)
  # Nearest-neighbour keeps the mask values exactly as stored; the default
  # bilinear interpolation would blend foreground and background into
  # fractional label values along the object boundary.
  label_img = tf.image.resize_images(
      label_img, resize, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  scale = 1 / 255.
  img = tf.cast(img, dtype=tf.float32) * scale
  label_img = tf.cast(label_img, dtype=tf.float32) * scale

  return img, label_img

#### Shifting the image

In [None]:
def shift_img(output_img, label_img, width_shift_range, height_shift_range):
  """Randomly translate an image and its mask by the same offset.

  Args:
    output_img: input image tensor.
    label_img: mask tensor; it receives exactly the same translation.
    width_shift_range: maximum horizontal shift as a fraction of the image
      width (`img_shape[1]`); a falsy value disables horizontal shifting.
    height_shift_range: maximum vertical shift as a fraction of the image
      height (`img_shape[0]`); a falsy value disables vertical shifting.

  Returns:
    (output_img, label_img), translated when at least one range is set and
    returned untouched otherwise.
  """
  if not (width_shift_range or height_shift_range):
    return output_img, label_img

  # Imported here because the notebook never imports tf.contrib at the top;
  # the previous `tfcontrib` reference raised NameError on first use.
  from tensorflow.contrib import image as contrib_image

  if width_shift_range:
    width_shift_range = tf.random_uniform([],
                                          -width_shift_range * img_shape[1],
                                          width_shift_range * img_shape[1])
  if height_shift_range:
    height_shift_range = tf.random_uniform([],
                                           -height_shift_range * img_shape[0],
                                           height_shift_range * img_shape[0])

  # The same sampled offset is applied to both tensors so they stay aligned.
  translation = [width_shift_range, height_shift_range]
  output_img = contrib_image.translate(output_img, translation)
  label_img = contrib_image.translate(label_img, translation)
  return output_img, label_img

#### Flipping the image randomly

In [None]:
def flip_img(horizontal_flip, tr_img, label_img):
  """Mirror an image and its mask left-to-right with probability 0.5.

  When `horizontal_flip` is falsy both tensors are returned unchanged.
  Otherwise a single random draw decides whether both are flipped, so the
  image and its mask always stay aligned.
  """
  if not horizontal_flip:
    return tr_img, label_img

  def _flipped():
    return tf.image.flip_left_right(tr_img), tf.image.flip_left_right(label_img)

  def _unchanged():
    return tr_img, label_img

  flip_prob = tf.random_uniform([], 0.0, 1.0)
  tr_img, label_img = tf.cond(tf.less(flip_prob, 0.5), _flipped, _unchanged)
  return tr_img, label_img

#### Assembling our transformations into our augment function

In [None]:
def _augment(img,
             label_img,
             resize=None,  # Resize the image to some size e.g. [256, 256]
             scale=1,  # Scale image e.g. 1 / 255.
             hue_delta=0,  # Adjust the hue of an RGB image by random factor
             horizontal_flip=False,  # Random left right flip,
             width_shift_range=0,  # Randomly translate the image horizontally
             height_shift_range=0):  # Randomly translate the image vertically 
  """Apply the configured augmentations to one (image, mask) pair.

  Steps, in order: optional resize of both tensors, optional random hue
  change (image only), optional random horizontal flip and random shift
  (both tensors, identically), then a float32 cast and multiplication by
  `scale`.

  Returns:
    (img, label_img) after augmentation.
  """
  if resize is not None:
    # Resize both images
    label_img = tf.image.resize_images(label_img, resize)
    img = tf.image.resize_images(img, resize)

  # Colour jitter must not touch the mask.
  if hue_delta:
    img = tf.image.random_hue(img, hue_delta)

  # Geometric transforms are applied to image and mask together.
  img, label_img = flip_img(horizontal_flip, img, label_img)
  img, label_img = shift_img(img, label_img, width_shift_range, height_shift_range)

  label_img = tf.cast(label_img, dtype=tf.float32) * scale
  img = tf.cast(img, dtype=tf.float32) * scale
  return img, label_img

In [None]:
def get_baseline_dataset(filenames,
                         labels,
                         preproc_fn=functools.partial(_augment),
                         threads=5,
                         batch_size=batch_size,
                         is_train=True):
  """Build a batched `tf.data.Dataset` of (image, mask) pairs.

  Args:
    filenames: image paths.
    labels: mask paths, paired with `filenames` by position.
    preproc_fn: augmentation function mapped over each decoded pair; only
      applied when `is_train` is True.
    threads: `num_parallel_calls` used for the map steps.
    batch_size: number of examples per batch.
    is_train: when True the dataset is augmented and shuffled.

  Returns:
    The batched dataset.
  """
  num_x = len(filenames)

  # Pair each image path with its mask path, then load and decode in parallel.
  path_pairs = tf.data.Dataset.from_tensor_slices((filenames, labels))
  dataset = path_pairs.map(_process_pathnames, num_parallel_calls=threads)

  if is_train:
    augmented = dataset.map(preproc_fn, num_parallel_calls=threads)
    dataset = augmented.shuffle(num_x * 10)

  return dataset.batch(batch_size)

### Set up train and test datasets
Note that we apply image augmentation to our training dataset but not our validation dataset.

In [None]:
# Training pipeline: augmented and shuffled (is_train defaults to True).
train_dataset = get_baseline_dataset(filenames=x_train_filenames,
                                     labels=y_train_filenames)
# Test pipeline: decoded and batched only.
test_dataset = get_baseline_dataset(filenames=x_test_filenames,
                                    labels=y_test_filenames,
                                    is_train=False)

In [None]:
# Bare expression: the notebook displays the dataset object for inspection.
train_dataset

### Plot some train data

In [None]:
# Show the first example of one training batch next to its mask.
for images, labels in train_dataset.take(1):
  plt.figure(figsize=(10, 10))
  img = images[0]

  # Left panel: input image. Right panel: the single mask channel.
  for position, panel in enumerate((img, labels[0, :, :, 0]), start=1):
    plt.subplot(1, 2, position)
    plt.imshow(panel)
  plt.show()

## Build the model

해당 프로젝트는 두 개의 네트워크를 만들어보는 것이 목표이다.
* Encoder-Decoder 스타일의 네트워크
* [U-Net](https://arxiv.org/abs/1505.04597)

### Encoder-Decoder architecture

#### Encoder
* 다음과 같은 구조로 Encoder로 만들어보자.
* `input data`의 shape이 다음과 같이 되도록 네트워크를 구성해보자
  * inputs = [batch_size, 256, 256, 3]
  * conv1 = [batch_size, 128, 128, 32]
  * conv2 = [batch_size, 64, 64, 64]
  * conv3 = [batch_size, 32, 32, 128]
  * outputs = [batch_size, 16, 16, 256]
* Convolution - Normalization - Activation 등의 조합을 다양하게 생각해보자.
* Pooling을 쓸지 Convolution with stride=2 로 할지 잘 생각해보자.
* `tf.keras.Sequential()`을 이용하여 만들어보자.
  
#### Decoder
* Encoder의 mirror 형태로 만들어보자.
* `input data`의 shape이 다음과 같이 되도록 네트워크를 구성해보자
  * inputs = encoder의 outputs = [batch_size, 16, 16, 256]
  * conv_transpose1 = [batch_size, 32, 32, 128]
  * conv_transpose2 = [batch_size, 64, 64, 64]
  * conv_transpose3 = [batch_size, 128, 128, 32]
  * outputs = [batch_size, 256, 256, 1]
* `tf.keras.Sequential()`을 이용하여 만들어보자.

In [None]:
# Empty Sequential container; the encoder layers are added in the next cell.
if model_name == 'ed_model':
  encoder = tf.keras.Sequential(name='encoder')

In [None]:
if model_name == 'ed_model':
  # inputs: [batch_size, 256, 256, 3]
  # conv1: [batch_size, 128, 128, 32]
  # conv2: [batch_size, 64, 64, 64]
  # conv3: [batch_size, 32, 32, 128]
  # outputs: [batch_size, 16, 16, 256]
  #
  # Every stage is Conv(stride 2) -> BatchNorm -> ReLU: the strided
  # convolution halves the spatial size while the filter count doubles.
  for num_filters in (32, 64, 128, 256):
    encoder.add(layers.Conv2D(num_filters, 3, strides=2, padding='same'))
    encoder.add(layers.BatchNormalization())
    encoder.add(layers.Activation('relu'))

In [None]:
# Check that the encoder was built correctly: feed a random batch through it
# and print the bottleneck shape.
if model_name == 'ed_model':
  bottleneck = encoder(tf.random.normal([3, 256, 256, 3]))
  print(bottleneck.shape)

In [None]:
# Empty Sequential container; the decoder layers are added in the next cell.
if model_name == 'ed_model':
  decoder = tf.keras.Sequential(name='decoder')

In [None]:
if model_name == 'ed_model':
  # inputs: [batch_size, 16, 16, 256]
  # conv_transpose1: [batch_size, 32, 32, 128]
  # conv_transpose2: [batch_size, 64, 64, 64]
  # conv_transpose3: [batch_size, 128, 128, 32]
  # outputs: [batch_size, 256, 256, 1]
  #
  # Mirror of the encoder: each stage is a stride-2 transposed convolution
  # (doubles the spatial size) followed by BatchNorm and ReLU.
  for num_filters in (128, 64, 32):
    decoder.add(layers.Conv2DTranspose(num_filters, 3, strides=2, padding='same'))
    decoder.add(layers.BatchNormalization())
    decoder.add(layers.Activation('relu'))

  # Final upsampling to one channel; the sigmoid yields per-pixel
  # probabilities, which is what binary cross entropy (called without
  # from_logits) and the rounding in the IoU evaluation operate on.
  decoder.add(layers.Conv2DTranspose(1, 3, strides=2, padding='same',
                                     activation='sigmoid'))

In [None]:
# Check that the decoder was built correctly: decode the bottleneck produced
# by the encoder check and print the output shape.
if model_name == 'ed_model':
  predictions = decoder(bottleneck)
  print(predictions.shape)

#### Create an encoder-decoder model

In [None]:
# Chain encoder and decoder into one end-to-end model.
if model_name == 'ed_model':
  ed_model = tf.keras.Sequential([encoder, decoder])

### U-Net architecture

<img src='https://user-images.githubusercontent.com/11681225/58005153-fd934300-7b1f-11e9-9ad8-a0e9186e751c.png' width="800">

아래는 U-Net 만들 때 참고하면 좋은 TensorFlow tutorial 설명이다.

>We'll build the U-Net model. U-Net is especially good with segmentation tasks because it can localize well to provide high resolution segmentation masks. In addition, it works well with small datasets and is relatively robust against overfitting as the training data is in terms of the number of patches within an image, which is much larger than the number of training images itself. Unlike the original model, we will add batch normalization to each of our blocks. 

>The Unet is built with an encoder portion and a decoder portion. The encoder portion is composed of a linear stack of [`Conv`](https://developers.google.com/machine-learning/glossary/#convolution), `BatchNorm`, and [`Relu`](https://developers.google.com/machine-learning/glossary/#ReLU) operations followed by a [`MaxPool`](https://developers.google.com/machine-learning/glossary/#pooling). Each `MaxPool` will reduce the spatial resolution of our feature map by a factor of 2. We keep track of the outputs of each block as we feed these high resolution feature maps with the decoder portion. The Decoder portion is comprised of UpSampling2D, Conv, BatchNorm, and Relus. Note that we concatenate the feature map of the same size on the decoder side. Finally, we add a final Conv operation that performs a convolution along the channels for each individual pixel (kernel size of (1, 1)) that outputs our final segmentation mask in grayscale. 

#### The `tf.keras` Functional API

U-Net은 Encoder-Decoder 구조와는 달리 해당 레이어의 outputs이 바로 다음 레이어의 inputs이 되지 않는다. 이럴때는 `tf.keras.Sequential()`을 쓸 수가 없다. Sequential 구조가 아닌 네트워크를 만들 때 쓸 수 있는 API 가 바로 `tf.keras` functional API 이다. 자세한 설명은 다음 [문서](https://keras.io/getting-started/functional-api-guide/)를 참고 하면 좋다.

In [None]:
if model_name == 'u-net':
  class Conv(tf.keras.Model):
    """Basic U-Net unit: 'same'-padded convolution -> batch norm -> ReLU."""

    def __init__(self, num_filters, kernel_size):
      super(Conv, self).__init__()
      self.conv = layers.Conv2D(num_filters, kernel_size, padding='same')
      self.bn = layers.BatchNormalization()

    def call(self, inputs, training=True):
      # `training` is forwarded to the batch-normalization layer.
      normalized = self.bn(self.conv(inputs), training=training)
      return tf.nn.relu(normalized)

In [None]:
if model_name == 'u-net':
  class ConvBlock(tf.keras.Model):
    """Two stacked `Conv` units (3x3 conv -> batch norm -> ReLU)."""

    def __init__(self, num_filters):
      super(ConvBlock, self).__init__()
      self.conv1 = Conv(num_filters, 3)
      self.conv2 = Conv(num_filters, 3)

    def call(self, inputs, training=True):
      x = self.conv1(inputs, training=training)
      x = self.conv2(x, training=training)

      return x


  class EncoderBlock(tf.keras.Model):
    """`ConvBlock` followed by a 2x2 max-pool that halves the resolution."""

    def __init__(self, num_filters):
      super(EncoderBlock, self).__init__()
      self.conv_block = ConvBlock(num_filters)
      self.pool = layers.MaxPooling2D((2, 2), strides=(2, 2))

    def call(self, inputs, training=True):
      """Return `(pooled, features)`.

      `features` is the full-resolution output of the conv block; it is
      returned as well so the decoder can concatenate it as a skip
      connection. `pooled` is the down-sampled tensor for the next stage.
      """
      features = self.conv_block(inputs, training=training)
      pooled = self.pool(features)

      return pooled, features


  class DecoderBlock(tf.keras.Model):
    """Upsample by 2, merge with the skip connection, then refine."""

    def __init__(self, num_filters):
      super(DecoderBlock, self).__init__()
      self.conv_transpose = layers.Conv2DTranspose(num_filters, (2, 2),
                                                   strides=(2, 2),
                                                   padding='same')
      self.bn = layers.BatchNormalization()
      self.conv_block = ConvBlock(num_filters)

    def call(self, input_tensor, concat_tensor, training=True):
      """Decode `input_tensor` using `concat_tensor` as the skip connection.

      `concat_tensor` must have the spatial size of `input_tensor` after the
      2x upsampling, i.e. it comes from the encoder stage of the same
      resolution.
      """
      x = self.conv_transpose(input_tensor)
      x = self.bn(x, training=training)
      x = tf.nn.relu(x)
      # Concatenate along the channel axis (channels-last layout).
      x = tf.concat([x, concat_tensor], axis=-1)
      x = self.conv_block(x, training=training)

      return x

In [None]:
if model_name == 'u-net':
  class UNet(tf.keras.Model):
    """U-Net: four encoder stages, a centre block and four decoder stages.

    Relies on the block contracts declared in the previous cell:
    `EncoderBlock(...)(x, training=...)` returns `(pooled, features)` and
    `DecoderBlock(...)(x, skip, training=...)` returns the upsampled tensor.
    For a [batch, 256, 256, 3] input the centre block then works at 16x16
    and the output is [batch, 256, 256, 1].
    """

    def __init__(self):
      super(UNet, self).__init__()
      self.encoder_block1 = EncoderBlock(32)
      self.encoder_block2 = EncoderBlock(64)
      self.encoder_block3 = EncoderBlock(128)
      self.encoder_block4 = EncoderBlock(256)

      self.center = ConvBlock(512)

      self.decoder_block4 = DecoderBlock(256)
      self.decoder_block3 = DecoderBlock(128)
      self.decoder_block2 = DecoderBlock(64)
      self.decoder_block1 = DecoderBlock(32)

      # 1x1 convolution across channels; the sigmoid gives the per-pixel
      # foreground probability expected by the losses and the IoU metric.
      self.output_conv = layers.Conv2D(1, 1, activation='sigmoid')

    def call(self, inputs, training=True):
      # Contracting path: keep every full-resolution feature map as a skip.
      pool1, skip1 = self.encoder_block1(inputs, training=training)
      pool2, skip2 = self.encoder_block2(pool1, training=training)
      pool3, skip3 = self.encoder_block3(pool2, training=training)
      pool4, skip4 = self.encoder_block4(pool3, training=training)

      center = self.center(pool4, training=training)

      # Expanding path: each stage consumes the skip of matching resolution.
      x = self.decoder_block4(center, skip4, training=training)
      x = self.decoder_block3(x, skip3, training=training)
      x = self.decoder_block2(x, skip2, training=training)
      x = self.decoder_block1(x, skip1, training=training)

      return self.output_conv(x)

#### Create a U-Net model

In [None]:
# Instantiate the U-Net when that architecture is selected.
if model_name == 'u-net':
  unet_model = UNet()

### Defining custom metrics and loss functions

우리가 사용할 loss function은 다음과 같다.
* binary cross entropy
* dice_loss

[논문](http://campar.in.tum.de/pub/milletari2016Vnet/milletari2016Vnet.pdf)에 나온 Dice coefficient 수식

$$D = \frac{2 \sum_{i}^{N} p_{i}g_{i}}{\sum_{i}^{N} p_{i}^{2} + \sum_{i}^{N} g_{i}^{2}}$$

하지만 구현은 이렇게 할 것임.

$$D = \frac{2 \sum_{i}^{N} p_{i}g_{i} + \varepsilon}{\sum_{i}^{N} p_{i} + \sum_{i}^{N} g_{i} + \varepsilon}$$

Dice loss의 자세한 설명은 아래를 참고하자.

>Defining loss and metric functions are simple with Keras. Simply define a function that takes both the True labels for a given example and the Predicted labels for the same given example.

>Dice loss is a metric that measures overlap. More info on optimizing for Dice coefficient (our dice loss) can be found in the [paper](http://campar.in.tum.de/pub/milletari2016Vnet/milletari2016Vnet.pdf), where it was introduced.

>We use dice loss here because it performs better at class imbalanced problems by design. In addition, maximizing the dice coefficient and IoU metrics are the actual objectives and goals of our segmentation task. Using cross entropy is more of a proxy which is easier to maximize. Instead, we maximize our objective directly.

In [None]:
def dice_coeff(y_true, y_pred):
  """Smoothed Dice coefficient between a target mask and a prediction.

  Computes (2 * sum(p * g) + smooth) / (sum(p) + sum(g) + smooth) over all
  elements of the two tensors (the whole batch is flattened together).

  Args:
    y_true: ground-truth mask tensor.
    y_pred: predicted mask tensor of the same shape.

  Returns:
    Scalar tensor; 1 means perfect overlap.
  """
  # The smoothing term keeps the ratio defined when both masks are empty.
  smooth = 1.
  # Flatten
  y_true_f = tf.reshape(y_true, [-1])
  y_pred_f = tf.reshape(y_pred, [-1])
  intersection = tf.reduce_sum(y_true_f * y_pred_f)
  score = (2. * intersection + smooth) / (
      tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
  return score

In [None]:
def dice_loss(y_true, y_pred):
  """Dice loss: one minus the Dice coefficient of the two masks."""
  return 1 - dice_coeff(y_true, y_pred)

Here, we'll use a specialized loss function that combines binary cross entropy and our dice loss. This is based on [individuals who competed within this competition obtaining better results empirically](https://www.kaggle.com/c/carvana-image-masking-challenge/discussion/40199). Try out your own custom losses to measure performance (e.g. bce + log(dice_loss), only bce, etc.)!

In [None]:
def bce_dice_loss(y_true, y_pred):
  """Combined loss: mean binary cross entropy plus Dice loss."""
  bce_term = tf.reduce_mean(losses.binary_crossentropy(y_true, y_pred))
  dice_term = dice_loss(y_true, y_pred)
  return bce_term + dice_term

In [None]:
# Adam optimizer with its default hyper-parameters (no arguments passed).
optimizer = tf.train.AdamOptimizer()

### Select a model

In [None]:
# Bind `model` to the architecture chosen by `model_name`; every later cell
# works on `model` only.
if model_name == 'ed_model':
  print('select the Encoder-Decoder model')
  model = ed_model
elif model_name == 'u-net':
  print('select the U-Net model')
  model = unet_model

### Compile

In [None]:
# Configure the model for Keras training: combined BCE + Dice loss, with the
# Dice loss also reported as a metric.
model.compile(optimizer=optimizer, loss=bce_dice_loss, metrics=[dice_loss])
# Push one random batch through the model and print the output shape
# (presumably also so the model is built before model.summary() -- confirm).
predictions = model(tf.random.normal([batch_size, 256, 256, 3]))
print(predictions.shape)

In [None]:
# Print the layer/parameter overview of the selected model.
model.summary()

### Checkpoints (Object-based saving)

In [None]:
# Make sure the checkpoint directory exists before anything is written to it.
if not tf.gfile.Exists(checkpoint_dir):
  tf.gfile.MakeDirs(checkpoint_dir)
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

# The optimizer is tracked only when training; otherwise only the model is.
checkpoint = (tf.train.Checkpoint(optimizer=optimizer, model=model)
              if is_train
              else tf.train.Checkpoint(model=model))

## Train your model

In [None]:
## Define print function
def print_images():
  """Plot input, ground-truth mask and predicted mask for one test example.

  Takes the first batch of `test_dataset`, runs `model` on it in inference
  mode and shows the first example of that batch in three side-by-side
  panels.
  """
  for test_images, test_labels in test_dataset.take(1):
    predictions = model(test_images, training=False)

    panels = (
        ("Input image", test_images[0,: , :, :]),
        ("Actual Mask", test_labels[0, :, :, 0]),
        ("Predicted Mask", predictions[0, :, :, 0]),
    )

    plt.figure(figsize=(10, 20))
    for position, (title, picture) in enumerate(panels, start=1):
      plt.subplot(1, 3, position)
      plt.imshow(picture)
      plt.title(title)
    plt.show()

### Training 첫번째 방법 `model.fit()` 함수 이용

In [None]:
# Train with the built-in Keras loop. The dataset is repeated so it cannot
# run dry between epochs, and steps_per_epoch tells Keras where one epoch
# (one pass over the training examples) ends.
history = model.fit(train_dataset.repeat(),
                    epochs=max_epochs,
                    steps_per_epoch=max(1, num_train_examples // batch_size))

In [None]:
# Show one test example (input, actual mask, predicted mask) after training.
print_images()

In [None]:
# Save the objects tracked by `checkpoint` under checkpoint_prefix.
checkpoint.save(file_prefix = checkpoint_prefix)

### Training 두 번째 방법 `tf.GradientTape()`을 이용하여 직접 구현하기

In [None]:
%%time
print('Start Training.')
train_dataset = train_dataset.repeat(1)
num_batches_per_epoch = num_train_examples // batch_size
global_step = 0
# save loss values for plot
loss_history = []

for epoch in range(max_epochs):
  
  for step, (images, labels) in enumerate(train_dataset):
    start_time = time.time()
    
    with tf.GradientTape() as tape:
      predictions = model(images, training=True)
      loss = bce_dice_loss(labels, predictions)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    

    if global_step % print_steps == 0:
      clear_output(wait=True)
      epochs = epoch + step / float(num_batches_per_epoch)
      duration = time.time() - start_time
      examples_per_sec = batch_size  / float(duration)
      print("Epochs: {:.2f} global_step: {} loss: {:.3f} ({:.2f} examples/sec; {:.3f} sec/batch)".format(
                epochs, global_step, loss, examples_per_sec, duration))

      loss_history.append([epochs, loss])

      # print sample image
      print_images()

  # saving (checkpoint) the model periodically
  if (epoch+1) % save_epochs == 0:
    checkpoint.save(file_prefix = checkpoint_prefix)

print('Training Done.')

### Plot the loss

In [None]:
# Column 0 holds the (fractional) epoch, column 1 the recorded loss.
loss_history = np.asarray(loss_history)
epoch_axis, loss_axis = loss_history[:, 0], loss_history[:, 1]
plt.plot(epoch_axis, loss_axis)
plt.show()

## Restore the latest checkpoint

In [None]:
# Outside of training, load the most recent checkpoint found in checkpoint_dir.
if not is_train:
  latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
  status = checkpoint.restore(latest_ckpt)

## Evaluate the test dataset

In [None]:
def mean_iou(y_true, y_pred):
  """Mean foreground IoU over a batch of masks.

  Both tensors are flattened per example and rounded to {0, 1}, so `y_pred`
  is expected to hold per-pixel probabilities. The IoU is computed for each
  example separately and then averaged over the batch.

  Args:
    y_true: ground-truth masks, one per batch entry.
    y_pred: predicted masks of the same shape.

  Returns:
    Scalar float32 tensor with the batch-mean IoU.
  """
  # Flatten
  y_true_f = tf.keras.layers.Flatten()(y_true)
  y_pred_f = tf.keras.layers.Flatten()(y_pred)

  # Binarise both masks at 0.5.
  y_true_f = tf.cast(tf.round(y_true_f), tf.int32)
  y_pred_f = tf.cast(tf.round(y_pred_f), tf.int32)

  # Per-example pixel counts: |A and B|, and |A or B| = |A| + |B| - |A and B|.
  intersection = tf.reduce_sum(y_true_f * y_pred_f, axis=1)
  union = tf.reduce_sum(y_true_f + y_pred_f, axis=1) - intersection

  intersection = tf.cast(intersection, tf.float32)
  union = tf.cast(union, tf.float32)

  # An example with no foreground in either mask has an empty union; count
  # it as a perfect match instead of producing 0/0 = NaN.
  iou = tf.where(tf.equal(union, 0.),
                 tf.ones_like(union),
                 intersection / tf.maximum(union, 1.))

  return tf.reduce_mean(iou)

In [None]:
# Running average of the per-batch IoU values.
mean = tf.keras.metrics.Mean("mean_iou")

# Only the first three test batches are evaluated.
for images, labels in test_dataset.take(3):
  mean(mean_iou(labels, model(images, training=False)))

print("mean_iou: {}".format(mean.result().numpy()))