Skip to content

Commit

Permalink
Start using dataset_info and add NDVI
Browse files Browse the repository at this point in the history
This refactors things to pass around the dataset_info object instead of the individual components of it
to various functions which simplifies the code. We also added an NDVI channel which might improve
accuracy. This channel is baked into the arrays that are generated by the preprocess script. Unfortunately,
this increases the size of the dataset, but it made implementation simpler. The alternative is to have �"dynamic"
channels that are computed from the "base" layers stored on disk.
  • Loading branch information
lewfish committed Mar 9, 2017
1 parent abb4a5d commit df87233
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 135 deletions.
60 changes: 26 additions & 34 deletions src/model_training/data/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ def get_channel_stats(path):
return means, stds


def save_channel_stats(path):
means, stds = get_channel_stats(join(path, TRAIN))
def save_channel_stats(path, means, stds):
channel_stats = map(lambda x: {'mean': x[0], 'std': x[1]},
zip(means, stds))
channel_stats_json = json.dumps({'stats': list(channel_stats)}, indent=4)
Expand Down Expand Up @@ -113,8 +112,8 @@ def make_batch_generator(path, tile_size, batch_size, shuffle):
yield samples


def transform_batch(batch, input_inds, output_inds, output_mask_inds,
augment=False, scale_params=None, eval_mode=False):
def transform_batch(batch, dataset_info, augment=False, scale_params=None,
eval_mode=False):
batch = batch.astype(np.float32)

if augment:
Expand All @@ -131,44 +130,40 @@ def transform_batch(batch, input_inds, output_inds, output_mask_inds,

if scale_params is not None:
means, stds = scale_params
batch[:, :, :, input_inds] -= \
means[np.newaxis, np.newaxis, np.newaxis, input_inds]
batch[:, :, :, input_inds] /= \
stds[np.newaxis, np.newaxis, np.newaxis, input_inds]
batch[:, :, :, dataset_info.input_inds] -= \
means[np.newaxis, np.newaxis, np.newaxis, dataset_info.input_inds]
batch[:, :, :, dataset_info.input_inds] /= \
stds[np.newaxis, np.newaxis, np.newaxis, dataset_info.input_inds]

inputs = batch[:, :, :, input_inds]
inputs = batch[:, :, :, dataset_info.input_inds]

if eval_mode:
outputs = batch[:, :, :, output_inds]
outputs_mask = batch[:, :, :, output_mask_inds]
outputs = batch[:, :, :, dataset_info.output_inds]
outputs_mask = batch[:, :, :, dataset_info.output_mask_inds]
return inputs, outputs, outputs_mask
else:
outputs = batch[:, :, :, output_inds]
outputs = batch[:, :, :, dataset_info.output_inds]
outputs = np.squeeze(outputs, axis=3)
outputs = label_to_one_hot_batch(outputs)
return inputs, outputs


def make_split_generator(dataset, split, tile_size=(256, 256),
batch_size=32, shuffle=False, augment=False,
scale=False, include_ir=False, include_depth=False,
eval_mode=False):
dataset_info = get_dataset_info(dataset)
def make_split_generator(dataset_info, split, batch_size=32, shuffle=False,
augment=False, scale=False, eval_mode=False):
path = dataset_info.dataset_path
split_path = join(path, split)

_, input_inds, output_inds, output_mask_inds = \
dataset_info.get_channel_inds(
include_ir=include_ir, include_depth=include_depth)
scale_params = load_channel_stats(path) \
if scale else None

gen = make_batch_generator(split_path, tile_size, batch_size, shuffle)
tile_size = dataset_info.input_shape[0:2]
gen = make_batch_generator(split_path, tile_size,
batch_size, shuffle)

def transform(batch):
return transform_batch(batch, input_inds, output_inds,
output_mask_inds, augment=augment,
scale_params=scale_params, eval_mode=eval_mode)
return transform_batch(
batch, dataset_info, augment=augment, scale_params=scale_params,
eval_mode=eval_mode)
gen = map(transform, gen)

return gen
Expand All @@ -188,9 +183,9 @@ def unscale_inputs(inputs, input_inds, scale_params):
return inputs


def plot_sample(file_path, inputs, outputs, rgb_input_inds, input_inds,
def plot_sample(file_path, inputs, outputs, dataset_info,
scale_params):
inputs = unscale_inputs(inputs, input_inds, scale_params)
inputs = unscale_inputs(inputs, dataset_info.input_inds, scale_params)

fig = plt.figure()
nb_input_inds = inputs.shape[2]
Expand All @@ -210,7 +205,7 @@ def plot_image(plot_row, plot_col, im, is_rgb=False):

plot_row = 0
plot_col = 0
im = inputs[:, :, rgb_input_inds]
im = inputs[:, :, dataset_info.rgb_input_inds]
plot_image(plot_row, plot_col, im, is_rgb=True)

for channel_ind in range(nb_input_inds):
Expand Down Expand Up @@ -239,18 +234,16 @@ def viz_generator(split):
batch_size = 4

dataset_info = get_dataset_info(dataset)
dataset_info.setup(include_ir=True, include_depth=True, include_ndvi=True)
path = dataset_info.dataset_path
viz_path = join(path, split, 'gen_samples')
_makedirs(viz_path)

scale_params = load_channel_stats(path)
rgb_input_inds, input_inds, _, _ = dataset_info.get_channel_inds(
include_ir=True, include_depth=True)

gen = make_split_generator(
POTSDAM, split, tile_size=(256, 256),
batch_size=batch_size, shuffle=True, augment=True, scale=True,
include_ir=True, include_depth=True)
dataset_info, split, batch_size=batch_size, shuffle=True, augment=True,
scale=True)

for batch_ind in range(nb_batches):
inputs, outputs = next(gen)
Expand All @@ -259,8 +252,7 @@ def viz_generator(split):
viz_path, '{}_{}.pdf'.format(batch_ind, sample_ind))
plot_sample(
file_path, inputs[sample_ind, :, :, :],
outputs[sample_ind, :, :, :], rgb_input_inds, input_inds,
scale_params)
outputs[sample_ind, :, :, :], dataset_info, scale_params)


if __name__ == '__main__':
Expand Down
34 changes: 24 additions & 10 deletions src/model_training/data/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
POTSDAM, TRAIN, VALIDATION, seed, get_dataset_info)
from .utils import (
_makedirs, load_tiff, load_image, rgb_to_label_batch, save_image,
rgb_to_mask)
from .generators import save_channel_stats
rgb_to_mask, compute_ndvi)
from .generators import save_channel_stats, get_channel_stats

np.random.seed(seed)

Expand All @@ -36,6 +36,13 @@ def process_data(file_indices, raw_rgbir_input_path, raw_depth_input_path,
join(raw_depth_input_path, depth_file_name))
depth_input_im = np.expand_dims(depth_input_im, axis=2)

red = rgbir_input_im[:, :, 0]
ir = rgbir_input_im[:, :, 3]
ndvi_im = compute_ndvi(red, ir)
# NDVI ranges from [-1.0, 1.0]. We need to make this value fit into a
# uint8 so we scale it.
ndvi_im = (ndvi_im + 1) * 127

output_im = load_tiff(join(raw_output_path, output_file_name))
output_im = np.expand_dims(output_im, axis=0)
output_im = rgb_to_label_batch(output_im)
Expand All @@ -48,8 +55,8 @@ def process_data(file_indices, raw_rgbir_input_path, raw_depth_input_path,
output_mask_im = np.expand_dims(output_mask_im, axis=2)

concat_im = np.concatenate(
[rgbir_input_im, depth_input_im, output_im, output_mask_im],
axis=2)
[rgbir_input_im, depth_input_im, ndvi_im, output_im,
output_mask_im], axis=2)

proc_file_name = '{}_{}'.format(index1, index2)
save_image(join(proc_data_path, proc_file_name), concat_im)
Expand Down Expand Up @@ -78,20 +85,27 @@ def get_file_names(index1, index2):
return (rgbir_file_name, depth_file_name, output_file_name,
output_mask_file_name)

train_file_indices, validation_file_inds = dataset_info.get_file_inds()

train_path = join(proc_data_path, TRAIN)
process_data(
train_file_indices, raw_rgbir_input_path, raw_depth_input_path,
dataset_info.train_inds, raw_rgbir_input_path, raw_depth_input_path,
raw_output_path, raw_output_mask_path, train_path,
get_file_names)

save_channel_stats(proc_data_path)
means, stds = get_channel_stats(train_path)
# The NDVI values are in [-1,1] by definition, but we store them as uint8s
# in [0, 255]. So, we use a hard coded scaling for this channel to make the
# values go back to [-1, 1], since they are more easily interpreted that way
# and fall into the range we want for the neural network.
ndvi_ind = dataset_info.ndvi_ind
means[ndvi_ind] = 1.0
stds[ndvi_ind] = 127.0
save_channel_stats(proc_data_path, means, stds)

validation_path = join(proc_data_path, VALIDATION)
process_data(
validation_file_inds, raw_rgbir_input_path, raw_depth_input_path,
raw_output_path, raw_output_mask_path, validation_path, get_file_names)
dataset_info.validation_inds, raw_rgbir_input_path,
raw_depth_input_path, raw_output_path, raw_output_mask_path,
validation_path, get_file_names)


if __name__ == '__main__':
Expand Down
75 changes: 36 additions & 39 deletions src/model_training/data/settings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from os.path import join

RGBIR_INPUT = 'rgbir_input'
DEPTH_INPUT = 'depth_input'
TRAIN = 'train'
VALIDATION = 'validation'
POTSDAM = 'potsdam'
Expand All @@ -28,18 +26,10 @@
'Car',
'Clutter'
]
nb_labels = len(label_keys)

data_path = '/opt/data/'
datasets_path = join(data_path, 'datasets')
results_path = join(data_path, 'results')
raw_potsdam_path = join(datasets_path, POTSDAM)

tile_size = 256
target_size = (tile_size, tile_size)

big_tile_size = 2000
big_target_size = (big_tile_size, big_tile_size)

seed = 1

Expand All @@ -48,45 +38,52 @@ class PotsdamInfo():
def __init__(self):
self.dataset_path = join(datasets_path, 'processed_potsdam')
self.raw_dataset_path = join(datasets_path, POTSDAM)
self.small_tile_size = 256
self.big_tile_size = 2000
self.nb_labels = len(label_keys)

def get_channel_inds(self, include_ir=False, include_depth=False):
rgb_input_inds = [0, 1, 2]
input_inds = [0, 1, 2]
if include_ir:
input_inds.append(3)
if include_depth:
input_inds.append(4)

output_inds = [5]
output_mask_inds = [6]
return rgb_input_inds, input_inds, output_inds, output_mask_inds

def get_input_shape(self, include_ir=False, include_depth=False,
use_big_tiles=False):
nb_channels = 3
if include_ir:
nb_channels += 1
if include_depth:
nb_channels += 1

if use_big_tiles:
return (big_tile_size, big_tile_size, nb_channels)
else:
return (tile_size, tile_size, nb_channels)

def get_file_inds(self):
# Split used in https://arxiv.org/abs/1606.02585
training_inds = [
self.train_inds = [
(2, 10), (3, 10), (3, 11), (3, 12), (4, 11), (4, 12), (5, 10),
(5, 12), (6, 10), (6, 11), (6, 12), (6, 8), (6, 9), (7, 11),
(7, 12), (7, 7), (7, 9)
]
validation_inds = [
self.validation_inds = [
(2, 11), (2, 12), (4, 10), (5, 11), (6, 7), (7, 10), (7, 8)
]

return training_inds, validation_inds
self.setup(include_ir=False, include_depth=False,
include_ndvi=False, use_big_tiles=False)

def setup(self, include_ir=False, include_depth=False,
include_ndvi=False, use_big_tiles=False):
self.include_ir = include_ir
self.include_depth = include_depth
self.include_ndvi = include_ndvi

self.red_ind = 0
self.green_ind = 1
self.blue_ind = 2
self.ir_ind = 3
self.depth_ind = 4
self.ndvi_ind = 5

self.rgb_input_inds = [self.red_ind, self.green_ind, self.blue_ind]
self.input_inds = list(self.rgb_input_inds)
if include_ir:
self.input_inds.append(self.ir_ind)
if include_depth:
self.input_inds.append(self.depth_ind)
if include_ndvi:
self.input_inds.append(self.ndvi_ind)

self.output_inds = [6]
self.output_mask_inds = [7]

self.nb_channels = len(self.input_inds)
self.input_shape = (self.small_tile_size, self.small_tile_size, self.nb_channels)
if use_big_tiles:
self.input_shape = (self.big_tile_size, self.big_tile_size, self.nb_channels)

def get_dataset_info(dataset):
if dataset == POTSDAM:
Expand Down
22 changes: 21 additions & 1 deletion src/model_training/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import numpy as np
import rasterio

from .settings import label_keys, nb_labels
from .settings import label_keys

nb_labels = len(label_keys)


def _makedirs(path):
Expand Down Expand Up @@ -78,3 +80,21 @@ def one_hot_to_label_batch(one_hot_batch):

def one_hot_to_rgb_batch(one_hot_batch):
return label_to_rgb_batch(one_hot_to_label_batch(one_hot_batch))


def safe_divide(a, b):
"""
Avoid divide by zero
http://stackoverflow.com/questions/26248654/numpy-return-0-with-divide-by-zero
"""
with np.errstate(divide='ignore', invalid='ignore'):
c = np.true_divide(a,b)
c[c == np.inf] = 0
c = np.nan_to_num(c)
return c


def compute_ndvi(red, ir):
ndvi = safe_divide((ir - red), (ir + red))
ndvi = np.expand_dims(ndvi, axis=3)
return ndvi

0 comments on commit df87233

Please sign in to comment.