[View in Colaboratory](https://colab.research.google.com/github/dkatsios/semantic_segmentation/blob/master/voc2012.ipynb)

In [1]:
import numpy as np
import os
from shutil import unpack_archive
import cv2
from matplotlib import pyplot as plt
from IPython.display import Image
import PIL
from keras.optimizers import Adam
from time import time
from google.colab import files
import keras
from keras.models import load_model
from keras.callbacks import ModelCheckpoint
import pickle

Using TensorFlow backend.


### Download VOC 2012 dataset

In [2]:
%mkdir semantic_segmentation
%cd semantic_segmentation/
!wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
unpack_archive('VOCtrainval_11-May-2012.tar', './')
%rm VOCtrainval_11-May-2012.tar
wdir = '/content/semantic_segmentation'

/content/semantic_segmentation
--2018-06-17 04:16:54--  http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
Resolving host.robots.ox.ac.uk (host.robots.ox.ac.uk)... 129.67.94.152
Connecting to host.robots.ox.ac.uk (host.robots.ox.ac.uk)|129.67.94.152|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1999639040 (1.9G) [application/x-tar]
Saving to: ‘VOCtrainval_11-May-2012.tar’

VOCtrainval_11-May-  32%[=====>              ] 614.85M  13.8MB/s    eta 1m 46s 


2018-06-17 04:19:17 (13.3 MB/s) - ‘VOCtrainval_11-May-2012.tar’ saved [1999639040/1999639040]



In [3]:
%cd {wdir}/VOCdevkit/VOC2012
!ls

/content/semantic_segmentation/VOCdevkit/VOC2012
Annotations  ImageSets	JPEGImages  SegmentationClass  SegmentationObject


In [0]:
voc_root = os.path.join(wdir, 'VOCdevkit', 'VOC2012')
# The folder paths keep a trailing separator (the '' component), exactly like the
# original string concatenation, in case the helpers append file names directly.
imgs_folder = os.path.join(voc_root, 'JPEGImages', '')
classes_folder = os.path.join(voc_root, 'SegmentationClass', '')
train_list_path = os.path.join(voc_root, 'ImageSets', 'Segmentation', 'train.txt')
val_list_path = os.path.join(voc_root, 'ImageSets', 'Segmentation', 'val.txt')

# Fail fast here instead of part-way through the slow array construction further down.
for folder in (imgs_folder, classes_folder):
    assert os.path.isdir(folder), 'missing folder: ' + folder
for list_path in (train_list_path, val_list_path):
    assert os.path.isfile(list_path), 'missing split list: ' + list_path

### Clone the repository and import the helper and model notebooks

In [5]:
%cd {wdir}
%rm -r {wdir}/semantic_segmentation/
!git clone https://github.com/dkatsios/semantic_segmentation.git
%cd {wdir}/semantic_segmentation
!ls

/content/semantic_segmentation
rm: cannot remove '/content/semantic_segmentation/semantic_segmentation/': No such file or directory
Cloning into 'semantic_segmentation'...
remote: Counting objects: 27, done.[K
remote: Compressing objects: 100% (26/26), done.[K
remote: Total 27 (delta 11), reused 0 (delta 0), pack-reused 0[K
Unpacking objects: 100% (27/27), done.
/content/semantic_segmentation/semantic_segmentation
AtrusUnet.ipynb  README.md  voc2012_helpers.ipynb  voc2012.ipynb


In [6]:
!pip install import_ipynb
import import_ipynb
from voc2012_helpers import *
from AtrusUnet import  *

Collecting import_ipynb
  Downloading https://files.pythonhosted.org/packages/63/35/495e0021bfdcc924c7cdec4e9fbb87c88dd03b9b9b22419444dc370c8a45/import-ipynb-0.1.3.tar.gz
Building wheels for collected packages: import-ipynb
  Running setup.py bdist_wheel for import-ipynb ... [?25l- done
[?25h  Stored in directory: /content/.cache/pip/wheels/b4/7b/e9/a3a6e496115dffdb4e3085d0ae39ffe8a814eacc44bbf494b5
Successfully built import-ipynb
Installing collected packages: import-ipynb
Successfully installed import-ipynb-0.1.3
importing Jupyter notebook from voc2012_helpers.ipynb
importing Jupyter notebook from AtrusUnet.ipynb


### Set parameters

In [0]:
# --- Model / input configuration (used by the array builder and by AtrousUnet) ---
img_shape = (512, 512, 3)
filters = 128
segmentation_classes = 21          # PASCAL VOC: 20 object classes + background
steps = 5
out_resized_levels = 2             # the model gets out_resized_levels + 1 outputs (one loss each)
kernel_sizes = [2, 3, 5]

# --- Training configuration (used by the generators and fit_generator) ---
batch_size = 2
epochs = 10
val_steps = 10                     # validation batches per epoch, i.e. val_steps * batch_size images

### Load data arrays and build batch generators

In [8]:
# Per-split file lists, then in-memory image/class arrays for each split
# (each split took ~80 s to build in the recorded run).
train_lists, val_lists = get_lists_from_folders(
    train_list_path, val_list_path, imgs_folder, classes_folder)
train_arrays, val_arrays = (get_imgs_classes_arrays(*split_lists, img_shape)
                            for split_lists in (train_lists, val_lists))

start constructing arrays
arrays constructed. time: 79 secs
start constructing arrays
arrays constructed. time: 78 secs


In [0]:
# Batch generators for training and validation, built from the arrays above.
train_gen, val_gen = (
    imgs_generator(*split_arrays, batch_size, out_resized_levels, segmentation_classes)
    for split_arrays in (train_arrays, val_arrays)
)
# Whole batches per epoch over the training list (a final partial batch is not counted).
steps_per_epoch = len(train_lists[0]) // batch_size

### Build model

In [10]:
# Any later rebuild used to reload saved weights must use this exact configuration,
# including use_depthwise=True.
atrous_unet = AtrousUnet(
    img_shape, filters, segmentation_classes,
    steps, out_resized_levels, kernel_sizes,
    use_depthwise=True,
)
atrous_unet.build_model()

n_params = atrous_unet.model.count_params()
print('number of parameters:', n_params)

number of parameters: 2091572


### Compile model

In [0]:
optimizer = Adam(0.0001)

# One categorical cross-entropy loss per model output; the model has
# out_resized_levels + 1 outputs ('mse' was noted as an alternative loss).
losses = ['categorical_crossentropy'] * (out_resized_levels + 1)
atrous_unet.model.compile(optimizer=optimizer, loss=losses, metrics=['categorical_accuracy'])

# Directory for the ModelCheckpoint file; unlike `!mkdir`, this does not fail on re-runs.
os.makedirs(os.path.join(wdir, 'logs'), exist_ok=True)

### Train model

In [0]:
class_weight = get_class_weight(segmentation_classes)
# ModelCheckpoint monitors val_loss by default; save_best_only keeps the best epoch's weights.
checkpointer = ModelCheckpoint(
    filepath=os.path.join(wdir, 'logs', 'weights.hdf5'),
    save_best_only=True,
    verbose=1,
)
download_weights = DownloadWeights(wdir)

In [0]:
# NOTE(review): Keras 2.x rejects a dict `class_weight` for targets with more than 2 dims;
# confirm that what get_class_weight returns works with these per-pixel targets.
train_callbacks = [checkpointer, download_weights]
history = atrous_unet.model.fit_generator(
    train_gen,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    validation_data=val_gen,
    validation_steps=val_steps,
    class_weight=class_weight,
    callbacks=train_callbacks,
    verbose=1,
)

Epoch 1/10





### Download results (weights and history)

In [0]:
# history.history maps each metric name to its list of per-epoch values.
with open('./history.pkl', 'wb') as handle:
    pickle.dump(history.history, handle, protocol=pickle.HIGHEST_PROTOCOL)
# The best weights are wherever ModelCheckpoint wrote them ({wdir}/logs/weights.hdf5),
# not in the current directory ({wdir}/semantic_segmentation), so ask the callback.
files.download(checkpointer.filepath)
files.download('./history.pkl')

### Load saved weights (optional)

In [0]:
# Set to True to rebuild the network and restore the best checkpoint
# (e.g. in a fresh session instead of re-training).
load_saved_weights = False
if load_saved_weights:
    # Must match the training configuration, including use_depthwise=True as passed in the
    # model-building cell; presumably a different setting yields layers the weights won't fit.
    atrous_unet = AtrousUnet(img_shape, filters, segmentation_classes,
                             steps, out_resized_levels, kernel_sizes, use_depthwise=True)
    atrous_unet.build_model()
    # The training ModelCheckpoint writes to {wdir}/logs/weights.hdf5.
    atrous_unet.model.load_weights(os.path.join(wdir, 'logs', 'weights.hdf5'))

### Plot results

In [0]:
# Take one validation batch and convert the model's last output and the matching
# target into displayable label images.
imgs, labels = next(val_gen)
predictions = atrous_unet.model.predict_on_batch(imgs)

# The last element of each output list is shown; presumably the full-resolution
# output - confirm against AtrousUnet's output ordering.
pred_probs = predictions[-1]
pred_labels = get_images_from_predictions(pred_probs)

real_targets = labels[-1]
real_labels = get_images_from_predictions(real_targets)

In [0]:
# One row per validation image: input, ground-truth labels, predicted labels.
panel_titles = ('image', 'ground truth', 'prediction')
for img, real_label, pred_label in zip(imgs, real_labels, pred_labels):
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for ax, panel, title in zip(axes, (img, real_label, pred_label), panel_titles):
        ax.imshow(panel)
        ax.set_title(title)
        ax.axis('off')
    plt.show()
    plt.close(fig)  # this loop creates one figure per image; release each after showing it