[View in Colaboratory](https://colab.research.google.com/github/dkatsios/semantic_segmentation/blob/master/isbi.ipynb)

In [0]:
# Detect whether the notebook runs inside Google Colab. Colab-only steps
# (dataset/repo cloning, the `files` helper import) are gated on this flag,
# so a hard-coded value would break local runs.
try:
    import google.colab  # noqa: F401  (probed only for availability)
    colab = True
except ImportError:
    colab = False

In [0]:
import os
import pickle
import subprocess
from shutil import unpack_archive
from time import time

import cv2
import keras
import numpy as np
import PIL
from IPython.display import Image
from keras.callbacks import ModelCheckpoint
from keras.models import load_model
from keras.optimizers import Adam
from matplotlib import pyplot as plt
if colab:
  from google.colab import files

### Locate (or download) the ISBI dataset and set the working directory

In [0]:
# `wdir` is the folder holding the ISBI membrane data; later cells read
# `wdir + 'train/image/'` and `wdir + 'train/label/'`, so it must end with '/'.
if colab:
  # On Colab the data comes from the zhixuhao/unet repository. Clone it only
  # when the data folder is missing so re-running this cell is cheap.
  unet_repo = '/content/semantic_segmentation/unet'
  wdir = unet_repo + '/data/membrane/'
  if not os.path.isdir(wdir):
    os.makedirs(os.path.dirname(unet_repo), exist_ok=True)
    # check=True surfaces a failed clone (e.g. network error, non-empty target)
    subprocess.run(['git', 'clone', 'https://github.com/zhixuhao/unet.git', unet_repo],
                   check=True)
else:
  # Local checkout; set the ISBI_WDIR environment variable (with a trailing '/')
  # to point at a different location without editing the notebook.
  wdir = os.environ.get('ISBI_WDIR', 'E:/Data_Files/Workspaces/PyCharm/semantic_segmentation/')

In [0]:
%cd /content/semantic_segmentation/unet/data/membrane/train
!ls

### Import the helper and model notebooks

In [0]:
if colab:
    %cd {wdir}
    %rm -r {wdir}/semantic_segmentation/
    !git clone https://github.com/dkatsios/semantic_segmentation.git
    %cd {wdir}/semantic_segmentation
else:
    %cd {wdir}
%ls

In [0]:
!pip install import_ipynb
import import_ipynb
from voc2012_helpers import *
from AtrusUnet_3 import  *

### Set parameters

In [0]:
# --- Architecture ---
img_shape = (512, 512, 3)        # input image shape
filters = 1024                   # previously tried: 128
segmentation_classes = 1
steps = 4
out_resized_levels = 0
kernel_sizes = [2, 3, 5, 7]
dilation_rates = [1, 2, 3]       # NOTE(review): not passed to AtrousUnet in this notebook

# --- Architecture switches ---
use_max_pooling, use_depthwise = True, True
pre_resized, feed_resized = False, False

# --- Output / task switches ---
classify, only_labels, single_class = False, False, False

# --- Initialisation and regularisation ---
use_regularizers = False
kernel_init_he = True

# --- Training schedule ---
batch_size, test_batch_size = 1, 4
epochs = 40
val_steps = 10

### Build generators

In [0]:
# Validation is a hold-out of the first `val_count` image/label pairs from the
# same `train/` folder; training uses the remainder. The folder is listed once
# so both splits are guaranteed to come from the same ordering.
val_count = 5
train_folder = wdir + 'train/'
imgs_folder, classes_folder = train_folder + 'image/', train_folder + 'label/'
imgs_list, classes_list = get_lists(imgs_folder, classes_folder)
assert len(imgs_list) > val_count, (
    'need more than %d images in %s to leave a non-empty training split' % (val_count, imgs_folder))

train_arrays = get_isbi_imgs_classes_arrays(imgs_list[val_count:], classes_list[val_count:], img_shape)
val_arrays = get_isbi_imgs_classes_arrays(imgs_list[:val_count], classes_list[:val_count], img_shape)

In [0]:
# Batches per training epoch (assumes axis 0 of train_arrays[0] counts samples).
steps_per_epoch = train_arrays[0].shape[0] // batch_size

# Keyword options shared by every generator below.
common_gen_kwargs = dict(pre_resized=pre_resized, classify=classify,
                         steps=steps, single_class=single_class)

train_gen = isbi_imgs_generator(*train_arrays, batch_size, out_resized_levels,
                                segmentation_classes, **common_gen_kwargs)
val_gen = isbi_imgs_generator(*val_arrays, batch_size, out_resized_levels,
                              segmentation_classes, **common_gen_kwargs)
# Same validation data in larger batches; consumed by the plotting cells at the end.
test_gen = isbi_imgs_generator(*val_arrays, test_batch_size, out_resized_levels,
                               segmentation_classes, **common_gen_kwargs)

### Build model

In [0]:
# Keyword switches forwarded to the AtrousUnet constructor.
# NOTE(review): `dilation_rates` from the parameter cell is not forwarded here;
# confirm whether AtrousUnet is meant to receive it.
unet_options = dict(
    use_max_pooling=use_max_pooling,
    use_depthwise=use_depthwise,
    pre_resized=pre_resized,
    classify=classify,
    single_class=single_class,
    only_labels=only_labels,
    use_regularizers=use_regularizers,
    kernel_init_he=kernel_init_he,
    feed_resized=feed_resized,
)
atrous_unet = AtrousUnet(img_shape, filters, segmentation_classes, steps,
                         out_resized_levels, kernel_sizes, **unet_options)

atrous_unet.build_model()
n_params = atrous_unet.model.count_params()
print('number of parameters:', n_params)

### Compile model

In [0]:
# Adam with learning rate 1e-4; loss, metrics and per-output loss weights come
# from the helper, configured with the same switches used to build the model.
optimizer = Adam(1e-4)
loss, metrics, loss_weights = get_loss_metrics_weights(out_resized_levels, use_dice=False, classify=classify,
                                                       segmentation_classes=segmentation_classes,
                                                       only_labels=only_labels, single_class=single_class,
                                                       loss_factor=10.)
# Optional extra callback appended to `callbacks` when not None; currently
# disabled. Set to `Metrics()` to enable it.
label_metrics = None

atrous_unet.model.compile(optimizer=optimizer, loss=loss,
                          metrics=metrics, loss_weights=loss_weights)

# Checkpoints and the training history are written under `logs/`. Create it for
# local runs as well, and tolerate re-running the cell when it already exists.
os.makedirs(wdir + 'logs/', exist_ok=True)

### Train model

In [0]:
# ModelCheckpoint keeps only the best weights so far (it monitors `val_loss` by default).
weights_path = wdir + 'logs/weights.hdf5'
class_weight = None
checkpointer = ModelCheckpoint(filepath=weights_path, verbose=1, save_best_only=True)

# The label-metrics callback is attached only when it was configured.
optional_callbacks = [label_metrics] if label_metrics is not None else []
callbacks = [checkpointer] + optional_callbacks

In [0]:
# atrous_unet.model.summary()

In [0]:
# Train from the generators, validating on `val_steps` batches each epoch.
history = atrous_unet.model.fit_generator(
    train_gen,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    verbose=1,
    validation_data=val_gen,
    validation_steps=val_steps,
    class_weight=class_weight,
    callbacks=callbacks,
)

# Persist the recorded metric history so it can be inspected without retraining.
history_path = wdir + 'logs/history.pkl'
with open(history_path, 'wb') as history_file:
    pickle.dump(history.history, history_file, protocol=pickle.HIGHEST_PROTOCOL)

### Download results (weights and history)

In [0]:
# files.download(wdir+'logs/weights.hdf5')
# files.download(wdir+'logs/history.pkl')

### Load weights

In [0]:
# atrous_unet = AtrousUnet(img_shape, filters, segmentation_classes, steps, out_resized_levels,
#                          kernel_sizes,use_max_pooling=use_max_pooling, use_depthwise=use_depthwise)

# atrous_unet.build_model()
# atrous_unet.model.load_weights(wdir + 'logs/weights.hdf5')

### Plot results

In [0]:
# imgs, labels = val_gen.__next__()
# labels = labels['labels']
# predictions = atrous_unet.model.predict_on_batch(imgs)
# print(list(zip(*np.where(labels))))
# print(list(zip(*np.where(predictions > .3)))) 
# plt.imshow((imgs[0] + 1) * 127.5)

In [0]:
# Draw one batch from the test generator and predict on it.
imgs, labels = next(test_gen)
predictions = atrous_unet.model.predict_on_batch(imgs)

# Map pixels back to [0, 255] for display (assumes the generator scales inputs
# to [-1, 1], the inverse of this formula; TODO confirm). matplotlib's imshow
# clips float RGB data to [0, 1], so cast to uint8 instead of keeping floats.
imgs = np.clip((imgs + 1) * 127.5, 0, 255).astype(np.uint8)

In [0]:
# A list means the model has several outputs; keep only the last one.
if isinstance(predictions, list):
    predictions = predictions[-1]

# Binarise the predicted probabilities and keep channel 0 of both the
# prediction and the ground truth so they can be shown as 2-D masks.
threshold = .5
binary_predictions = (predictions > threshold).astype(np.float32)
pred_labels = binary_predictions[:, :, :, 0]
real_labels = labels['prediction'][:, :, :, 0]

In [0]:
# Show input image, ground-truth mask and predicted mask side by side per sample.
panel_titles = ('image', 'ground truth', 'prediction')
for img, real_label, pred_label in zip(imgs, real_labels, pred_labels):
  fig, axes = plt.subplots(1, 3, figsize=(12, 4))
  axes[0].imshow(img)
  axes[1].imshow(real_label, cmap='gray')
  # pred_label holds only 0/1; pin the colour range so an all-foreground
  # prediction is not rendered black by auto-normalisation.
  axes[2].imshow(pred_label, cmap='gray', vmin=0, vmax=1)
  for ax, title in zip(axes, panel_titles):
    ax.set_title(title)
    ax.axis('off')
  plt.show()
  plt.close(fig)  # release the figure; many samples would otherwise accumulate