Add ResNet-based model
* Also, add learning rate as RunOption parameter
* Remove fcn.py, since fcn_resnet.py supersedes it.
lewfish committed Feb 21, 2017
1 parent 0a7c714 commit abb7b94
Showing 8 changed files with 207 additions and 101 deletions.
6 changes: 3 additions & 3 deletions src/model_training/eval_run.py
@@ -149,10 +149,10 @@ def eval_run(options):
    plot_predictions(model, data_path, run_path, options.nb_prediction_images,
                     options.include_depth)

    print('Plotting graphs...')
    plot_graphs(model, run_path)

    print('Computing scores...')
    compute_scores(
        model, data_path, run_path, options.batch_size, options.nb_val_samples,
        options.include_depth)

    print('Plotting graphs...')
    plot_graphs(model, run_path)
85 changes: 0 additions & 85 deletions src/model_training/models/fcn.py

This file was deleted.

53 changes: 53 additions & 0 deletions src/model_training/models/fcn_resnet.py
@@ -0,0 +1,53 @@
"""
ResNet based FCN.
"""
from keras.models import Model
from keras.layers import (Input,
Activation,
Convolution2D,
Reshape,
Lambda,
merge)

from .resnet import ResNet


def make_fcn_resnet(input_shape, nb_labels):
input_shape = tuple(input_shape)
nb_rows, nb_cols, _ = input_shape
nb_labels = nb_labels

input_tensor = Input(shape=input_shape)
model = ResNet(input_tensor=input_tensor)

x = model.output

x64 = model.get_layer('activation_10').output
x32 = model.get_layer('activation_22').output
x16 = model.get_layer('activation_37').output

def resize_bilinear(images):
# Workaround for
# https://github.com/fchollet/keras/issues/4609
import tensorflow as tf
nb_rows = 512
nb_cols = 512
return tf.image.resize_bilinear(images, [nb_rows, nb_cols])

c64 = Convolution2D(nb_labels, 1, 1)(x64)
c32 = Convolution2D(nb_labels, 1, 1)(x32)
c16 = Convolution2D(nb_labels, 1, 1)(x16)

b64 = Lambda(resize_bilinear)(c64)
b32 = Lambda(resize_bilinear)(c32)
b16 = Lambda(resize_bilinear)(c16)

x = merge([b64, b32, b16], mode='sum')

x = Reshape((nb_rows * nb_cols, nb_labels))(x)
x = Activation('softmax')(x)
x = Reshape((nb_rows, nb_cols, nb_labels))(x)

model = Model(input=input_tensor, output=x)

return model
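
A minimal usage sketch for the new make_fcn_resnet (not part of this commit), assuming src/ is on the Python path and a 512x512 RGB input to match the hard-coded resize target in resize_bilinear:

    # Hypothetical usage; module path and input size are assumptions.
    from model_training.models.fcn_resnet import make_fcn_resnet

    # 6 labels as in options.json; 512x512 matches the resize_bilinear target above.
    model = make_fcn_resnet(input_shape=(512, 512, 3), nb_labels=6)
    model.summary()  # output shape should be (None, 512, 512, 6): per-pixel class probabilities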
2 changes: 1 addition & 1 deletion src/model_training/models/fcn_vgg_skip.py
@@ -38,7 +38,7 @@ def make_fcn_vgg_skip(input_shape, nb_labels):
    c32 = Convolution2D(128, 1, 1, border_mode='same', activation='relu')(x32)
    l32 = Convolution2D(nb_labels, 1, 1, border_mode='same')(c32)

    c16 = Convolution2D(512, 1, 1, border_mode='same', activation='relu')(x16)
    c16 = Convolution2D(256, 1, 1, border_mode='same', activation='relu')(x16)
    l16 = Convolution2D(nb_labels, 1, 1, border_mode='same')(c16)

    def resize_bilinear(images):
131 changes: 131 additions & 0 deletions src/model_training/models/resnet.py
@@ -0,0 +1,131 @@
# -*- coding: utf-8 -*-
# flake8: noqa
'''ResNet50 model for Keras.
# Reference:
- [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)
Adapted from code contributed by BigMoyan.
Adapted from code from
https://github.com/fchollet/deep-learning-models/blob/master/resnet50.py
'''
from __future__ import print_function

import numpy as np
import warnings

from keras.layers import merge, Input
from keras.layers import Dense, Activation, Flatten
from keras.layers import Convolution2D, MaxPooling2D, ZeroPadding2D, AveragePooling2D
from keras.layers import BatchNormalization
from keras.models import Model
from keras.preprocessing import image
import keras.backend as K
from keras.utils.layer_utils import convert_all_kernels_in_model
from keras.utils.data_utils import get_file


def identity_block(input_tensor, kernel_size, filters, stage, block):
    '''The identity_block is the block that has no conv layer at shortcut
    # Arguments
        input_tensor: input tensor
        kernel_size: default 3, the kernel size of middle conv layer at main path
        filters: list of integers, the nb_filters of 3 conv layers at main path
        stage: integer, current stage label, used for generating layer names
        block: 'a','b'..., current block label, used for generating layer names
    '''
    nb_filter1, nb_filter2, nb_filter3 = filters
    if K.image_dim_ordering() == 'tf':
        bn_axis = 3
    else:
        bn_axis = 1
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    x = Convolution2D(nb_filter1, 1, 1, name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = Convolution2D(nb_filter2, kernel_size, kernel_size,
                      border_mode='same', name=conv_name_base + '2b')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)

    x = Convolution2D(nb_filter3, 1, 1, name=conv_name_base + '2c')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)

    x = merge([x, input_tensor], mode='sum')
    x = Activation('relu')(x)
    return x


def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
    '''conv_block is the block that has a conv layer at shortcut
    # Arguments
        input_tensor: input tensor
        kernel_size: default 3, the kernel size of middle conv layer at main path
        filters: list of integers, the nb_filters of 3 conv layers at main path
        stage: integer, current stage label, used for generating layer names
        block: 'a','b'..., current block label, used for generating layer names
    Note that from stage 3, the first conv layer at main path is with subsample=(2,2)
    And the shortcut should have subsample=(2,2) as well
    '''
    nb_filter1, nb_filter2, nb_filter3 = filters
    if K.image_dim_ordering() == 'tf':
        bn_axis = 3
    else:
        bn_axis = 1
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    x = Convolution2D(nb_filter1, 1, 1, subsample=strides,
                      name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = Convolution2D(nb_filter2, kernel_size, kernel_size, border_mode='same',
                      name=conv_name_base + '2b')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)

    x = Convolution2D(nb_filter3, 1, 1, name=conv_name_base + '2c')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)

    shortcut = Convolution2D(nb_filter3, 1, 1, subsample=strides,
                             name=conv_name_base + '1')(input_tensor)
    shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut)

    x = merge([x, shortcut], mode='sum')
    x = Activation('relu')(x)
    return x


def ResNet(input_tensor=None):
    img_input = input_tensor
    if K.image_dim_ordering() == 'tf':
        bn_axis = 3
    else:
        bn_axis = 1

    x = ZeroPadding2D((3, 3))(img_input)
    x = Convolution2D(64, 7, 7, subsample=(2, 2), name='conv1')(x)
    x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2))(x)

    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')

    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')

    x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')

    model = Model(img_input, x)

    return model
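
This ResNet is truncated after stage 4 (no stage 5 or classification head), and make_fcn_resnet taps its intermediate activations by auto-generated layer name. Those names depend on Keras's global layer counter, so a quick sanity check (not part of this commit, assuming src/ is on the Python path and this is the first model built in the process) is:

    # Verify the skip-connection layer names that make_fcn_resnet relies on.
    from keras.layers import Input
    from model_training.models.resnet import ResNet

    model = ResNet(input_tensor=Input(shape=(512, 512, 3)))
    for name in ('activation_10', 'activation_22', 'activation_37'):
        print(name, model.get_layer(name).output_shape)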
9 changes: 7 additions & 2 deletions src/model_training/run.py
@@ -4,9 +4,11 @@
"""
import uuid
import json
from os.path import join
from os.path import join, isfile
import argparse

from keras.models import load_model

from .data.preprocess import _makedirs, results_path
from .train import make_model, train_model
from .eval_run import eval_run
@@ -21,7 +23,8 @@ def __init__(self, git_commit=None, model_type=None, input_shape=None,
                 nb_labels=None, run_name=None, batch_size=None,
                 samples_per_epoch=None, nb_epoch=None, nb_val_samples=None,
                 nb_prediction_images=None, patience=None, cooldown=None,
                 include_depth=False, kernel_size=None, dataset=None):
                 include_depth=False, kernel_size=None, dataset=None,
                 lr=0.001):
        # Run `git rev-parse head` to get this.
        self.git_commit = git_commit
        self.model_type = model_type
@@ -41,6 +44,8 @@ def __init__(self, git_commit=None, model_type=None, input_shape=None,
        self.dataset = dataset

        self.lr = lr

        self.include_depth = include_depth
        if self.input_shape[2] == 4:
            self.include_depth = True
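
The new isfile and load_model imports suggest a run can now resume from a previously saved model. That logic sits outside the shown hunks, so the following is only a sketch of the usual pattern; the checkpoint filename is hypothetical:

    # Hypothetical resume pattern; 'model.h5' and run_path are assumptions, not from the diff.
    model_path = join(run_path, 'model.h5')
    if isfile(model_path):
        model = load_model(model_path)
    else:
        model = make_model(options)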
9 changes: 5 additions & 4 deletions src/model_training/train.py
@@ -2,10 +2,11 @@
Functions for training a model given a RunOptions object.
"""
import numpy as np
from os.path import join
from os.path import join, isfile

from keras.callbacks import (ModelCheckpoint, CSVLogger,
                             ReduceLROnPlateau)
from keras.optimizers import Adam

from .data.generators import make_input_output_generators
from .data.preprocess import get_dataset_path, results_path
@@ -24,12 +25,12 @@ def make_model(options):
    elif model_type == 'fcn_vgg':
        from .models.fcn_vgg import make_fcn_vgg
        model = make_fcn_vgg(options.input_shape, options.nb_labels)
    elif model_type == 'fcn':
        from .models.fcn import make_fcn
        model = make_fcn(options.input_shape, options.nb_labels)
    elif model_type == 'fcn_vgg_skip':
        from .models.fcn_vgg_skip import make_fcn_vgg_skip
        model = make_fcn_vgg_skip(options.input_shape, options.nb_labels)
    elif model_type == 'fcn_resnet':
        from .models.fcn_resnet import make_fcn_resnet
        model = make_fcn_resnet(options.input_shape, options.nb_labels)

    return model
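
The Adam import and the new lr option presumably feed the optimizer when the model is compiled. The compile call is outside the shown hunks, so this is only a sketch of how that wiring typically looks in Keras 1.x:

    # Hypothetical compile step; the loss matches the softmax-over-labels output
    # of the FCN models, but the exact call is not shown in this diff.
    def compile_model(model, options):
        model.compile(optimizer=Adam(lr=options.lr),
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])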
13 changes: 7 additions & 6 deletions src/options.json
@@ -1,5 +1,5 @@
{
    "batch_size": 7,
    "batch_size": 2,
    "cooldown": null,
    "dataset": "potsdam",
    "git_commit": null,
@@ -13,12 +13,13 @@
        10,
        10
    ],
    "lr": 0.0001,
    "model_type": "conv_logistic",
    "nb_epoch": 10,
    "nb_epoch": 6,
    "nb_labels": 6,
    "nb_prediction_images": 8,
    "nb_val_samples": 512,
    "patience": 3,
    "run_name": "conv_logistic/potstdam_test",
    "samples_per_epoch": 512
    "nb_val_samples": 8,
    "patience": 5,
    "run_name": "conv_logistic/restart_test",
    "samples_per_epoch": 16
}
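
Since options.json now carries an lr key, the config presumably maps directly onto RunOptions keyword arguments, as the signature above suggests. A sketch of loading it (not part of this commit; file and module paths are assumptions):

    # Hypothetical loading snippet.
    import json
    from model_training.run import RunOptions

    with open('src/options.json') as options_file:
        options = RunOptions(**json.load(options_file))
    print(options.lr, options.batch_size)  # 0.0001 2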
