In [1]:
import tensorflow as tf
import apache_beam as beam
from typing import Dict, List, Tuple
import numpy as np

import distinctipy

In [32]:
# # To achieve good shuffling, each label should be in its own file
# # The function will take one sample from each label (i.e. file) and then shuffle them
# def load_dataset(filenames: List[str], deserialization_function, number_of_files):
#     parallel_reads = len(filenames)
    
#     files = tf.data.Dataset.from_tensor_slices(filenames)

#     ds = files.interleave(lambda x: tf.data.TFRecordDataset(x), 
#                           cycle_length=len(filenames), 
#                           block_length=1,
#                           num_parallel_calls=tf.data.experimental.AUTOTUNE,
#                           deterministic=False)

#     ds = ds.flat_map(lambda x : deserialization_function(x, number_of_files))  # parse the record
#     ds = ds.shuffle(parallel_reads, reshuffle_each_iteration=True)

#     return ds

In [33]:
# palette = [
# '000000', # 0 No Data
# 'a6cee3', # 1 Water
# '1f78b4', # 2 Opaque Clouds
# 'b2df8a', # 3 Trees and Shrubs
# '33a02c', # 4 Built surface
# 'fb9a99', # 5 Bridges and dams
# 'e31a1c', # 6 Grass
# 'fdbf6f', # 7 Plant/Ground Mix
# 'ff7f00', # 8 Crops (other than Palm Plantations)
# 'cab2d6', # 9 Palm Plantations
# '6a3d9a', # 10 Flooded Vegetation
# 'ffff99', # 11 Bare Ground and Sand
# 'b15928', # 12 Snow and Ice
# '000000'  # 13 Unknown
# ]

# Land-cover classes -> hex display color. Dict order presumably equals the
# integer label value (0..13), per the numbered palette comment above -- confirm.
# NOTE(review): "No data" and "Unknown" share the color 000000.
CLASSIFICATIONS = dict([
  ("No data", '000000'),
  ("Water", 'a6cee3'),
  ("Opaque Clouds", '1f78b4'),
  ("Trees and Shrubs", 'b2df8a'),
  ("Built surface", '33a02c'),
  ("Bridges and dams", 'fb9a99'),
  ("Grass", 'e31a1c'),
  ("Plant/Ground Mix", 'fdbf6f'),
  ("Crops (other than Palm Plantations)", 'ff7f00'),
  ("Palm Plantations", 'cab2d6'),
  ("Flooded Vegetation", '6a3d9a'),
  ("Bare Ground and Sand", 'ffff99'),
  ("Snow and Ice", 'b15928'),
  ("Unknown", '000000'),
])

NUM_CLASSES = len(CLASSIFICATIONS)

# Source tiles: 30 cm WV3 (WorldView-3) imagery, 1024x1024 px, each labelled
# twice. Records hold PATCH_SIZE x PATCH_SIZE rasters (presumably patches cut
# from those tiles -- confirm).
SCALE = 0.3  # 30 cm imagery (presumably metres per pixel)
PATCH_SIZE = 512

INPUT_BANDS = ['R', 'G', 'B']
LABELS_NAMES = ['label_1', 'label_2']
FEATURES = [*INPUT_BANDS, *LABELS_NAMES]

IMG_SIZE = [PATCH_SIZE, PATCH_SIZE, len(INPUT_BANDS)]

# Every feature (each band and each label) is parsed as a dense float32 raster
# of KERNEL_SHAPE from the serialized Examples.
KERNEL_SHAPE = [PATCH_SIZE, PATCH_SIZE]
COLUMNS = [
  tf.io.FixedLenFeature(shape=KERNEL_SHAPE, dtype=tf.float32) for _ in FEATURES
]
FEATURES_DICT = {name: column for name, column in zip(FEATURES, COLUMNS)}

def parse_tfrecord(example_proto):
  """Deserialize one serialized Example using the FEATURES_DICT schema.

  Args:
    example_proto: a serialized Example (one element of a TFRecordDataset).
  Returns:
    A dict mapping each feature name in FEATURES_DICT to its parsed tensor.
  """
  feature_spec = FEATURES_DICT
  parsed = tf.io.parse_single_example(serialized=example_proto, features=feature_spec)
  return parsed


def to_tuple(tensor, band_names=("R", "G", "B")):
  """Split a stacked [height, width, channel] tensor into per-band slices.

  Despite its name, this returns a dict, not an (inputs, outputs) tuple, and
  any channels beyond ``len(band_names)`` (e.g. a trailing label channel) are
  dropped.

  Args:
    tensor: an array/tensor indexed as [height, width, channel]; channel ``i``
      is taken to be band ``band_names[i]``.
    band_names: names for the leading channels, in channel order. Defaults to
      ("R", "G", "B").
  Returns:
    A dict mapping each band name to ``tensor[:, :, i]``.
  """
  return {name: tensor[:, :, i] for i, name in enumerate(band_names)}


def flatten_patches(inputs):
  """Keep the image bands and the first label of a parsed example.

  Each tile is labelled twice (LABELS_NAMES), but only ``LABELS_NAMES[0]`` is
  kept; the second label is dropped. The result is a single example, not a
  dataset of one example per label.

  Args:
    inputs: A dictionary of tensors, keyed by feature name (e.g. the output
      of ``parse_tfrecord``).
  Returns:
    A dict with one entry per name in INPUT_BANDS, followed by an entry for
    ``LABELS_NAMES[0]``. Values are looked up with ``inputs.get``, so a
    missing key maps to None.
  """
  kept_keys = [*INPUT_BANDS, LABELS_NAMES[0]]
  return {key: inputs.get(key) for key in kept_keys}

def preprocess(
    values: Dict[str, tf.Tensor],
    label_name: str = "label_1",
) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
  """Split a parsed example into model inputs and a one-hot target.

  Args:
    values: dict of parsed tensors keyed by feature name; must contain every
      name in INPUT_BANDS and ``label_name``.
    label_name: which label raster to train on. Defaults to "label_1".
  Returns:
    ``(inputs, outputs)``: ``inputs`` maps each band name to its tensor, and
    ``outputs`` is the label one-hot encoded with depth NUM_CLASSES (a new
    trailing axis of size NUM_CLASSES).
  """
  inputs = {name: values[name] for name in INPUT_BANDS}

  # Labels are parsed as float32. Cast to int32 rather than uint8: a
  # float->uint8 conversion is not well-defined for negative or >255 values,
  # whereas any int32 index outside [0, NUM_CLASSES) simply yields an
  # all-zero one-hot vector from tf.one_hot.
  class_ids = tf.cast(values[label_name], tf.int32)
  outputs = tf.one_hot(class_ids, NUM_CLASSES)
  return inputs, outputs

def get_dataset(glob):
  """Build an unbatched dataset of ``(inputs, one_hot_label)`` pairs.

  Args:
    glob: GZIP-compressed TFRecord filename(s): a path, a list of paths, or a
      string tensor of paths (anything tf.data.TFRecordDataset accepts).
      Despite the name, no wildcard expansion happens here.
  Returns:
    A tf.data.Dataset whose elements are the ``(inputs, outputs)`` tuples
    produced by ``preprocess``, in file/record order.
  """
  dataset = tf.data.TFRecordDataset(glob, compression_type='GZIP')
  # Let tf.data size the thread pools instead of hard-coding 5 parser threads
  # and running preprocessing single-threaded. Parallel map is deterministic
  # by default, so element order is unchanged.
  dataset = dataset.map(parse_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
  dataset = dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
  return dataset


def get_datasets(pattern, train_fraction=0.8, batch_size=16, seed=None):
  """Split the files matching ``pattern`` into batched train / test datasets.

  The split is by file: matched files are shuffled, the first
  ``int(train_fraction * n_files)`` go to training and the rest to testing.

  Args:
    pattern: glob pattern passed to tf.io.gfile.glob (local or gs:// paths).
    train_fraction: fraction of files used for training. Defaults to 0.8.
    batch_size: batch size for both datasets. Defaults to 16.
    seed: optional int seed for the file shuffle. Pass one to get the same
      train/test split on every run; None (default) draws a fresh random split.
  Returns:
    ``(training, testing)``: batched, prefetched tf.data.Datasets.
  Raises:
    ValueError: if no file matches ``pattern``.
  """
  import random  # local import: the notebook's import cell does not load it

  # Sort before shuffling so a given seed yields the same split regardless of
  # the order in which the filesystem lists matches.
  filenames = sorted(tf.io.gfile.glob(pattern))
  size = len(filenames)
  print(f"size: {size}")
  if size == 0:
    raise ValueError(f"No files match pattern: {pattern}")

  random.Random(seed).shuffle(filenames)
  train_size = int(train_fraction * size)
  # Explicit string dtype so an empty split still builds a (empty) dataset.
  train_files = tf.constant(filenames[:train_size], dtype=tf.string)
  test_files = tf.constant(filenames[train_size:], dtype=tf.string)

  training = get_dataset(train_files).batch(batch_size).prefetch(tf.data.AUTOTUNE)
  testing = get_dataset(test_files).batch(batch_size).prefetch(tf.data.AUTOTUNE)
  return training, testing

In [34]:
# GZIP-compressed TFRecord files on GCS; get_datasets splits the matched files
# (by file) into batched training / testing datasets used by the cells below.
pattern = "gs://ivanmkc-palm-data-2/high-res-patches/labels_*.tfrecord.gz"

training_dataset, testing_dataset = get_datasets(pattern)

size: 2014


In [35]:
# len(list(testing_dataset.as_numpy_iterator()))

In [36]:
# for item in training_dataset.take(1).as_numpy_iterator():
#     print(item)

# items = list(testing_dataset.take(2).as_numpy_iterator())
# item = items[0]

# image = np.stack([item['R'], item['G'], item['B']], 2)
# mask = np.stack([item['label_1']], 2)

# np.unique(mask)

In [37]:
# import matplotlib.pyplot as plt

# def display_image_mask(image, mask):
#   plt.figure(figsize=(15, 15))

#   title = ['Input Image', 'True Mask']

#   display_list = [image, mask]

#   for i in range(len(display_list)):
#     plt.subplot(1, len(display_list), i+1)
#     plt.title(title[i])
#     plt.imshow(tf.keras.utils.array_to_img(display_list[i]))
#     plt.axis('off')
#   plt.show()

In [38]:
# display_image_mask(image, mask)

In [39]:
# for name, value in item.items():
#     print(f"{name}: {value.dtype.name} {value.shape}")

In [40]:
# np.unique(item['label_2'])

In [41]:
import tensorflow as tf

# Exploration only: one symbolic Keras input per band with a (height, width)
# per-example shape. The model-wrapper cell further down redefines
# `input_layers` with an explicit trailing channel axis.
input_layers = dict(
    (band, tf.keras.Input(shape=(None, None), name=band)) for band in INPUT_BANDS
)
input_layers

{'R': <KerasTensor: shape=(None, None, None) dtype=float32 (created by layer 'R')>,
 'G': <KerasTensor: shape=(None, None, None) dtype=float32 (created by layer 'G')>,
 'B': <KerasTensor: shape=(None, None, None) dtype=float32 (created by layer 'B')>}

In [42]:
# Learn per-channel mean/variance from the training images: each batch's bands
# are stacked on the last axis in INPUT_BANDS order; labels are ignored.
def _stack_bands(band_tensors, _labels):
    return tf.stack([band_tensors[band] for band in INPUT_BANDS], axis=-1)

normalization = tf.keras.layers.Normalization(name="Normalize")
normalization.adapt(training_dataset.map(_stack_bands))

# Fully convolutional network: a 5x5 conv and a 5x5 transposed conv (both with
# Keras' default 'valid' padding, so the output keeps the input's spatial
# size), then a per-pixel softmax over the land-cover classes.
fcn_layers = [
    tf.keras.Input(shape=(None, None, len(INPUT_BANDS)), name="Inputs"),
    normalization,
    tf.keras.layers.Conv2D(filters=32, kernel_size=5, activation="relu", name="Conv2D"),
    tf.keras.layers.Conv2DTranspose(filters=16, kernel_size=5, activation="relu", name="Deconv2D"),
    tf.keras.layers.Dense(len(CLASSIFICATIONS), activation="softmax", name="LandCover"),
]
fcn_model = tf.keras.Sequential(fcn_layers, name="FullyConvolutionalNetwork")

fcn_model.summary()
tf.keras.utils.plot_model(fcn_model, show_shapes=True)

Model: "FullyConvolutionalNetwork"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
Normalize (Normalization)    (None, None, None, 3)     7         
_________________________________________________________________
Conv2D (Conv2D)              (None, None, None, 32)    2432      
_________________________________________________________________
Deconv2D (Conv2DTranspose)   (None, None, None, 16)    12816     
_________________________________________________________________
LandCover (Dense)            (None, None, None, 14)    238       
Total params: 15,493
Trainable params: 15,486
Non-trainable params: 7
_________________________________________________________________
('You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) ', 'for plot_model/model_to_dot to work.')


In [43]:
# Define the input dictionary layers: one (height, width, 1) input per band.
input_layers = {
    name: tf.keras.Input(shape=(None, None, 1), name=name)
    for name in INPUT_BANDS
}

# Model wrapper that takes an input dictionary and feeds it to the FCN.
# Concatenate an explicit list in INPUT_BANDS order so the channel order is
# guaranteed to match the order the Normalization layer was adapted on
# (tf.stack over INPUT_BANDS), rather than relying on dict iteration order and
# passing Keras a dict_values view where it expects a list.
inputs = tf.keras.layers.concatenate(
    [input_layers[name] for name in INPUT_BANDS], name="Stack"
)
model = tf.keras.Model(input_layers, fcn_model(inputs), name="land_cover_classifier")

tf.keras.utils.plot_model(model)

('You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) ', 'for plot_model/model_to_dot to work.')


In [44]:
# Per-pixel categorical cross-entropy against the one-hot targets produced by
# preprocess; Keras resolves the "accuracy" string from the chosen loss.
model.compile(
    loss="categorical_crossentropy",
    metrics=["accuracy"],
    optimizer="adam",
)

In [45]:
# Train the model. Shuffling is applied to already-batched data, so this only
# reorders whole batches within a buffer of 10.
# An epoch took over 2 hours in the last run and was interrupted before the
# final save, losing all progress; save weights after every epoch so an
# interrupted run can be resumed with model.load_weights(...).
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath="checkpoints/epoch_{epoch:02d}.weights.h5",
    save_weights_only=True,
)
history = model.fit(
    training_dataset.shuffle(10),
    validation_data=testing_dataset,
    epochs=15,
    callbacks=[checkpoint_callback],
)

# Save the final model as files.
model.save("model")

Epoch 1/15


2022-05-13 08:55:29.702979: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:175] Filling up shuffle buffer (this may take a while): 9 of 10
2022-05-13 08:55:30.006484: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:228] Shuffle buffer filled.


   1145/Unknown - 7698s 7s/step - loss: 1.9997 - accuracy: 0.3472

KeyboardInterrupt: 