In [1]:
import tensorflow as tf
tf.enable_eager_execution()
print(tf.__version__)

import pprint

1.14.0


In [2]:
FOLDER = 'data/landsat/'
PREFIX = 'data_patches_'

# Specify inputs (Landsat bands) to the model and the response variable.
opticalBands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7']
thermalBands = ['B10', 'B11']
BANDS = opticalBands + thermalBands
RESPONSE = 'avg_rad'
FEATURES = BANDS + [RESPONSE]

# Specify the size and shape of patches expected by the model.
KERNEL_SIZE = 333
KERNEL_SHAPE = [KERNEL_SIZE, KERNEL_SIZE]
COLUMNS = [
    tf.io.FixedLenFeature(shape=KERNEL_SHAPE, dtype=tf.float32) for k in FEATURES
]
FEATURES_DICT = dict(zip(FEATURES, COLUMNS))

# Sizes of the training and evaluation datasets.
TRAIN_SIZE = 20000

# Specify model training parameters.
BATCH_SIZE = 24
EPOCHS = 2
BUFFER_SIZE = 1000
OPTIMIZER = 'Adam'
LOSS = 'MeanSquaredError'
METRICS = ['RootMeanSquaredError']

In [3]:
def parse_tfrecord(example_proto):
    """The parsing function.
    Read a serialized example into the structure defined by FEATURES_DICT.
    Args:
    example_proto: a serialized Example.
    Returns: 
    A dictionary of tensors, keyed by feature name.
    """
    return tf.io.parse_single_example(example_proto, FEATURES_DICT)


def to_tensor(inputs):
    """Function to convert a dictionary of tensors to a tuple of (inputs, outputs).
    Turn the tensors returned by parse_tfrecord into a stack in HWC shape.
    Args:
    inputs: A dictionary of tensors, keyed by feature name.
    Returns: 
    A dtuple of (inputs, outputs).
    """
    inputsList = [inputs.get(key) for key in FEATURES]
    stacked = tf.stack(inputsList, axis=0)
    # Convert from CHW to HWC
    stacked = tf.transpose(stacked, [1, 2, 0])
    return stacked
    
def to_tuple(input_tensor):
    landsat = input_tensor[:,:,:len(BANDS)]
    light = tf.log(tf.add(input_tensor[:,:,len(BANDS):], 2))
    light = tf.divide(tf.subtract(light, tf.reduce_min(light)),
                      tf.subtract(tf.reduce_max(light), tf.reduce_min(light)))
    return landsat, light

def aug_rotate(x):
    # Rotate 0, 90, 180, 270 degrees
    k = tf.random_uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)
    return tf.image.rot90(x, k)

def aug_flip(x):
    x = tf.image.random_flip_left_right(x)
    return x

def aug_crop(x):
    x = tf.random_crop(x, size=(256, 256, 10))
    return x


def get_dataset(pattern):
    """Function to read, parse and format to tuple a set of input tfrecord files.
    Get all the files matching the pattern, parse and convert to tuple.
    Args:
    pattern: A file pattern to match in a Cloud Storage bucket.
    Returns: 
    A tf.data.Dataset
    """
    glob = tf.gfile.Glob(pattern)
    dataset = tf.data.TFRecordDataset(glob, compression_type='GZIP')
    dataset = dataset.map(parse_tfrecord, num_parallel_calls=5)
    dataset = dataset.map(to_tensor, num_parallel_calls=5)
    for f in [aug_crop, aug_flip, aug_rotate]:
        dataset = dataset.map(f)
    dataset = dataset.map(to_tuple, num_parallel_calls=5)
    return dataset

In [6]:
def get_training_dataset():
    """Get the preprocessed training dataset
    Returns: 
    A tf.data.Dataset of training data.
    """
    glob = FOLDER + PREFIX + '*'
    dataset = get_dataset(glob)
    dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return dataset

training = get_training_dataset()

print(iter(training.take(1)).next())

(<tf.Tensor: id=247, shape=(24, 256, 256, 9), dtype=float32, numpy=
array([[[[0.03565   , 0.0559    , 0.10605   , ..., 0.2218    ,
          0.296     , 0.2505    ],
         [0.0422    , 0.0633    , 0.1081    , ..., 0.2468    ,
          0.2895    , 0.2475    ],
         [0.048     , 0.0701    , 0.1163    , ..., 0.2836    ,
          0.2905    , 0.2485    ],
         ...,
         [0.0543    , 0.0715    , 0.1119    , ..., 0.0582    ,
          0.2635    , 0.2215    ],
         [0.054     , 0.0685    , 0.1144    , ..., 0.0618    ,
          0.2655    , 0.2225    ],
         [0.0514    , 0.0676    , 0.1125    , ..., 0.1284    ,
          0.2705    , 0.2265    ]],

        [[0.0352    , 0.05835   , 0.1091    , ..., 0.2352    ,
          0.2845    , 0.2395    ],
         [0.042     , 0.063     , 0.11375   , ..., 0.2526    ,
          0.2845    , 0.2405    ],
         [0.0487    , 0.07      , 0.1203    , ..., 0.2857    ,
          0.2895    , 0.2485    ],
         ...,
         [0.0535    

In [7]:
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))

Num GPUs Available:  1


In [8]:
from tensorflow.python.keras import layers
from tensorflow.python.keras import losses
from tensorflow.python.keras import models
from tensorflow.python.keras import metrics
from tensorflow.python.keras import optimizers

def conv_block(input_tensor, num_filters):
	encoder = layers.Conv2D(num_filters, (3, 3), padding='same')(input_tensor)
	encoder = layers.BatchNormalization()(encoder)
	encoder = layers.Activation('relu')(encoder)
	encoder = layers.Conv2D(num_filters, (3, 3), padding='same')(encoder)
	encoder = layers.BatchNormalization()(encoder)
	encoder = layers.Activation('relu')(encoder)
	return encoder

def encoder_block(input_tensor, num_filters):
	encoder = conv_block(input_tensor, num_filters)
	encoder_pool = layers.MaxPooling2D((2, 2), strides=(2, 2))(encoder)
	return encoder_pool, encoder

def decoder_block(input_tensor, concat_tensor, num_filters):
	decoder = layers.Conv2DTranspose(num_filters, (2, 2), strides=(2, 2), padding='same')(input_tensor)
	decoder = layers.concatenate([concat_tensor, decoder], axis=-1)
	decoder = layers.BatchNormalization()(decoder)
	decoder = layers.Activation('relu')(decoder)
	decoder = layers.Conv2D(num_filters, (3, 3), padding='same')(decoder)
	decoder = layers.BatchNormalization()(decoder)
	decoder = layers.Activation('relu')(decoder)
	decoder = layers.Conv2D(num_filters, (3, 3), padding='same')(decoder)
	decoder = layers.BatchNormalization()(decoder)
	decoder = layers.Activation('relu')(decoder)
	return decoder

def get_model():
	inputs = layers.Input(shape=[None, None, len(BANDS)]) # 256
	encoder0_pool, encoder0 = encoder_block(inputs, 32) # 128
	encoder1_pool, encoder1 = encoder_block(encoder0_pool, 64) # 64
	encoder2_pool, encoder2 = encoder_block(encoder1_pool, 128) # 32
	encoder3_pool, encoder3 = encoder_block(encoder2_pool, 256) # 16
	encoder4_pool, encoder4 = encoder_block(encoder3_pool, 512) # 8
	center = conv_block(encoder4_pool, 1024) # center
	decoder4 = decoder_block(center, encoder4, 512) # 16
	decoder3 = decoder_block(decoder4, encoder3, 256) # 32
	decoder2 = decoder_block(decoder3, encoder2, 128) # 64
	decoder1 = decoder_block(decoder2, encoder1, 64) # 128
	decoder0 = decoder_block(decoder1, encoder0, 32) # 256
	outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(decoder0)

	model = models.Model(inputs=[inputs], outputs=[outputs])

	model.compile(
		optimizer=optimizers.get(OPTIMIZER), 
		loss=losses.get(LOSS),
		metrics=[metrics.get(metric) for metric in METRICS]
    )
	return model

In [9]:
m = get_model()

m.fit(
    x=training, 
    epochs=EPOCHS, 
    steps_per_epoch=int(TRAIN_SIZE / BATCH_SIZE)
)

Epoch 1/2
Epoch 2/2


<tensorflow.python.keras.callbacks.History at 0x298c18c9308>

In [10]:
m.save_weights('checkpoints/unet_trained')