<a href="https://colab.research.google.com/github/hmlewis-astro/street_network_deep_learning/blob/main/test_notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Change to GPU runtime

#### Navigate to "Runtime > Change runtime type > GPU > Save"

# Download Kaggle API credentials
#### **Note**: This is a one-time step and you don’t need to generate the credentials every time you download the dataset.
- Navigate to your Kaggle profile
- Click the "Account" tab
- Scroll down to the "API" section
- Click "Create New API Token"; a file named `kaggle.json`, which contains your username and API key, will be downloaded

# Upload Kaggle API credentials to Google Colab
#### **Note**: Uploaded files will get deleted when this runtime is recycled.
- Upload the `kaggle.json` file that you just downloaded from Kaggle
- Run the following cell

In [1]:
!pip install -q kaggle
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download balraj98/deepglobe-road-extraction-dataset

Downloading deepglobe-road-extraction-dataset.zip to /content
100% 3.79G/3.79G [01:08<00:00, 53.2MB/s]
100% 3.79G/3.79G [01:08<00:00, 59.3MB/s]


In [2]:
!unzip -q /content/deepglobe-road-extraction-dataset.zip -d /content/deep-globe

In [3]:
!rm -rf /content/deepglobe-road-extraction-dataset.zip

In [4]:
#!pip install git+https://github.com/tensorflow/examples.git

Collecting git+https://github.com/tensorflow/examples.git
  Cloning https://github.com/tensorflow/examples.git to /tmp/pip-req-build-549ywo5u
  Running command git clone -q https://github.com/tensorflow/examples.git /tmp/pip-req-build-549ywo5u
Building wheels for collected packages: tensorflow-examples
  Building wheel for tensorflow-examples (setup.py) ... [?25l[?25hdone
  Created wheel for tensorflow-examples: filename=tensorflow_examples-079eae91b01d7666471c9e01dadd031e2c2a00f2_-py3-none-any.whl size=271371 sha256=b2b42a5a14e4404956e05b2f726a941c3ec5e168fb2403037cc38570c1e050f3
  Stored in directory: /tmp/pip-ephem-wheel-cache-bc6mrfi1/wheels/eb/19/50/2a4363c831fa12b400af86325a6f26ade5d2cdc5b406d552ca
Failed to build tensorflow-examples
Installing collected packages: tensorflow-examples
    Running setup.py install for tensorflow-examples ... [?25l[?25hdone
[33m  DEPRECATION: tensorflow-examples was installed using the legacy 'setup.py install' method, because a wheel could not

# Import packages and libraries

In [5]:
import os
import glob
import random
from tqdm import tqdm

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import cv2

import tensorflow as tf

from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import load_img, img_to_array

from tensorflow_examples.models.pix2pix import pix2pix

tqdm.pandas()


In [6]:
# Abort early unless TensorFlow reports the first GPU device.
device_name = tf.test.gpu_device_name()
if not device_name == '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

Found GPU at: /device:GPU:0


# Get class dictionary

In [7]:
# Read the class table and split it into class names and their RGB triplets.
class_dict_path = "/content/deep-globe/class_dict.csv"
class_dict = pd.read_csv(class_dict_path)
class_names = class_dict['name'].to_numpy().tolist()
class_rgb_values = class_dict.loc[:, ['r', 'g', 'b']].to_numpy().tolist()
class_dict

Unnamed: 0,name,r,g,b
0,road,255,255,255
1,background,0,0,0


# Get metadata

In [8]:
metadata_path = "/content/deep-globe/metadata.csv"
metadata = pd.read_csv(metadata_path)
metadata.head()

Unnamed: 0,image_id,split,sat_image_path,mask_path
0,100034,train,train/100034_sat.jpg,train/100034_mask.png
1,100081,train,train/100081_sat.jpg,train/100081_mask.png
2,100129,train,train/100129_sat.jpg,train/100129_mask.png
3,100703,train,train/100703_sat.jpg,train/100703_mask.png
4,100712,train,train/100712_sat.jpg,train/100712_mask.png


### Get training/validation data (i.e., images with available road masks)


In [9]:
# Keep the rows of the "train" split and drop the now-constant split column.
metadata_train = metadata.loc[metadata['split'] == 'train'].drop(columns=['split'])
metadata_train.head()

Unnamed: 0,image_id,sat_image_path,mask_path
0,100034,train/100034_sat.jpg,train/100034_mask.png
1,100081,train/100081_sat.jpg,train/100081_mask.png
2,100129,train/100129_sat.jpg,train/100129_mask.png
3,100703,train/100703_sat.jpg,train/100703_mask.png
4,100712,train/100712_sat.jpg,train/100712_mask.png


In [10]:
# Shuffle the rows with a fixed seed so later positional splits of this frame
# are the same on every run (42 matches the seed used for the path shuffles).
metadata_train = metadata_train.sample(frac=1, random_state=42).reset_index(drop=True)

In [11]:
metadata_train.shape

(6226, 3)

### Get test data (i.e., images without available road masks)

Combine the datasets that Kaggle labels "validation" and "test": the "validation" set has no road masks, so it cannot actually be used for validation.

In [12]:
# Rows from the "valid" and "test" splits; drop the split and mask_path columns.
metadata_test = metadata[metadata['split'].isin(['valid', 'test'])]
metadata_test = metadata_test.drop(columns=['split', 'mask_path'])
metadata_test.head()

Unnamed: 0,image_id,sat_image_path
6226,100794,valid/100794_sat.jpg
6227,100905,valid/100905_sat.jpg
6228,102867,valid/102867_sat.jpg
6229,10417,valid/10417_sat.jpg
6230,106553,valid/106553_sat.jpg


In [13]:
# Shuffle the rows with a fixed seed so the order is reproducible across runs.
metadata_test = metadata_test.sample(frac=1, random_state=42).reset_index(drop=True)

In [14]:
metadata_test.shape

(2344, 2)

In [15]:
data_path = "/content/deep-globe/"


In [16]:
# Join every stored path onto the dataset root (data_path) for both columns.
metadata_train["sat_image_path"] = metadata_train["sat_image_path"].map(
    lambda rel_path: os.path.join(data_path, rel_path))
metadata_train["mask_path"] = metadata_train["mask_path"].map(
    lambda rel_path: os.path.join(data_path, rel_path))


In [17]:
# Join every stored satellite-image path onto the dataset root (data_path).
metadata_test["sat_image_path"] = metadata_test["sat_image_path"].map(
    lambda rel_path: os.path.join(data_path, rel_path))


In [18]:
metadata_train.head()

Unnamed: 0,image_id,sat_image_path,mask_path
0,820820,/content/deep-globe/train/820820_sat.jpg,/content/deep-globe/train/820820_mask.png
1,279620,/content/deep-globe/train/279620_sat.jpg,/content/deep-globe/train/279620_mask.png
2,876221,/content/deep-globe/train/876221_sat.jpg,/content/deep-globe/train/876221_mask.png
3,880871,/content/deep-globe/train/880871_sat.jpg,/content/deep-globe/train/880871_mask.png
4,873793,/content/deep-globe/train/873793_sat.jpg,/content/deep-globe/train/873793_mask.png


In [19]:
class SatDatClass(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that yields batches of satellite images and road masks.

    Parameters
    ----------
    batch_size : int
        Number of samples per batch.
    img_size : tuple of int
        ``(height, width)`` handed to ``load_img`` as ``target_size``.
    input_img_paths : sequence of str
        Paths of the satellite images.
    target_img_paths : sequence of str
        Paths of the road masks, in the same order as ``input_img_paths``.
    """

    def __init__(self, batch_size, img_size, input_img_paths, target_img_paths):
        self.batch_size = batch_size
        self.img_size = img_size
        self.input_img_paths = input_img_paths
        self.target_img_paths = target_img_paths

    def __len__(self):
        """Return the number of full batches (a trailing partial batch is dropped)."""
        return len(self.target_img_paths) // self.batch_size

    def __getitem__(self, idx):
        """Return the tuple ``(x, y)`` for batch number ``idx``.

        ``x`` is float32 with shape ``(batch_size, H, W, 3)``, pixel values
        divided by 255. ``y`` is uint8 with shape ``(batch_size, H, W)`` holding
        1 where the first mask channel is brighter than 127 and 0 elsewhere.
        """
        i = idx * self.batch_size
        batch_input_img_paths = self.input_img_paths[i : i + self.batch_size]
        batch_target_img_paths = self.target_img_paths[i : i + self.batch_size]

        x = np.zeros((self.batch_size,) + self.img_size + (3,), dtype="float32")
        for j, path in enumerate(batch_input_img_paths):
            img = load_img(path, target_size=self.img_size)
            x[j] = img_to_array(img).astype(int) / 255

        y = np.zeros((self.batch_size,) + self.img_size, dtype="uint8")
        for j, path in enumerate(batch_target_img_paths):
            target = np.array(load_img(path, target_size=self.img_size))
            # Binarise with an explicit threshold. Dividing by 255 and storing
            # the result in a uint8 array truncated every value below 255 to 0,
            # so near-white road pixels were silently labelled as background.
            y[j] = target[:, :, 0] > 127

        return x, y

In [30]:
class SatDatClass(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that yields batches of satellite images and road masks.

    Parameters
    ----------
    batch_size : int
        Number of samples per batch.
    img_size : tuple of int
        ``(height, width)`` handed to ``load_img`` as ``target_size``.
    input_img_paths : sequence of str
        Paths of the satellite images.
    target_img_paths : sequence of str
        Paths of the road masks, in the same order as ``input_img_paths``.
    """

    def __init__(self, batch_size, img_size, input_img_paths, target_img_paths):
        self.batch_size = batch_size
        self.img_size = img_size
        self.input_img_paths = input_img_paths
        self.target_img_paths = target_img_paths

    def __len__(self):
        """Return the number of full batches (a trailing partial batch is dropped)."""
        return len(self.target_img_paths) // self.batch_size

    def __getitem__(self, idx):
        """Return the tuple ``(x, y)`` for batch number ``idx``.

        ``x`` is float32 with shape ``(batch_size, H, W, 3)`` and holds the
        pixel values as loaded (no rescaling is applied here). ``y`` is uint8
        with shape ``(batch_size, H, W, 1)`` holding 1 where the grayscale mask
        is brighter than 127 and 0 elsewhere.
        """
        i = idx * self.batch_size
        batch_input_img_paths = self.input_img_paths[i : i + self.batch_size]
        batch_target_img_paths = self.target_img_paths[i : i + self.batch_size]

        x = np.zeros((self.batch_size,) + self.img_size + (3,), dtype="float32")
        for j, path in enumerate(batch_input_img_paths):
            img = load_img(path, target_size=self.img_size)
            x[j] = img

        y = np.zeros((self.batch_size,) + self.img_size + (1,), dtype="uint8")
        for j, path in enumerate(batch_target_img_paths):
            mask = load_img(path, target_size=self.img_size, color_mode="grayscale")
            # The masks are black/white images, not 1-based class ids.
            # Subtracting one from them in uint8 wrapped 0 -> 255 and
            # 255 -> 254, which are not valid labels for a 2-class sparse
            # cross-entropy loss. Threshold to {0, 1} instead.
            y[j] = np.expand_dims(np.asarray(mask) > 127, 2)
        return x, y

In [31]:
train_dict = {'img' : [], 'mask' : []}

def load_data(load_dict=None, input_img_paths=None, target_img_paths=None, image_size=(128, 128)):
    """Load and resize satellite-image / road-mask pairs into ``load_dict``.

    Parameters
    ----------
    load_dict : dict, optional
        Dictionary with lists under the keys ``'img'`` and ``'mask'``; loaded
        arrays are appended to them. A fresh dictionary is created when omitted.
    input_img_paths : str
        Directory that is listed for file names and read for ``<id>_sat.jpg``.
    target_img_paths : str
        Directory read for ``<id>_mask.png``.
    image_size : tuple of int
        Size passed to ``cv2.resize`` for both the image and the mask.

    Returns
    -------
    dict
        ``load_dict`` with one ``'img'`` entry and one ``'mask'`` entry (first
        channel of the resized mask) appended per pair that could be read.
    """
    if load_dict is None:
        load_dict = {'img': [], 'mask': []}

    # Image ids are the part of each file name before the first underscore.
    # dict.fromkeys de-duplicates them in listing order without the quadratic
    # "name not in list" scan.
    image_ids = list(dict.fromkeys(
        name.split('_')[0] for name in os.listdir(input_img_paths)))

    # Iterate over the ids themselves. Indexing the id list with positions taken
    # from the (longer) file listing raised IndexError, which the former bare
    # ``except`` silently swallowed.
    for image_id in image_ids:
        try:
            img = plt.imread(os.path.join(input_img_paths, image_id + '_sat.jpg'))
            target = plt.imread(os.path.join(target_img_paths, image_id + '_mask.png'))
        except OSError:
            # Missing or unreadable file: skip this pair, but let programming
            # errors propagate instead of hiding them.
            continue

        img = cv2.resize(img, image_size)
        target = cv2.resize(target, image_size)

        load_dict['img'].append(img)
        load_dict['mask'].append(target[:,:,0])

    return load_dict

In [21]:
def Conv2DBlock(inputs, previous_block_activation, num_filters, kernel_size=3, batch_norm=True):
    """Run ``inputs`` through two ReLU -> SeparableConv2D stages.

    Both convolutions use ``num_filters`` filters, a square kernel of side
    ``kernel_size`` and ``padding='same'``. When ``batch_norm`` is true each
    convolution is followed by a BatchNormalization layer.
    ``previous_block_activation`` is accepted but not used here.

    Returns the output tensor of the second stage.
    """
    square_kernel = (kernel_size, kernel_size)

    x = inputs
    for _ in range(2):
        x = layers.Activation('relu')(x)
        x = layers.SeparableConv2D(filters=num_filters,
                                   kernel_size=square_kernel,
                                   padding='same')(x)
        if batch_norm:
            x = layers.BatchNormalization()(x)

    return x


In [22]:
def Conv2DTransposeBlock(inputs, previous_block_activation, num_filters, kernel_size=3, batch_norm=True):
    """Run ``inputs`` through two ReLU -> Conv2DTranspose stages.

    Both transposed convolutions use ``num_filters`` filters, a square kernel
    of side ``kernel_size`` and ``padding='same'``. When ``batch_norm`` is true
    each one is followed by a BatchNormalization layer.
    ``previous_block_activation`` is accepted but not used here.

    Returns the output tensor of the second stage.
    """
    square_kernel = (kernel_size, kernel_size)

    x = inputs
    for _ in range(2):
        x = layers.Activation('relu')(x)
        x = layers.Conv2DTranspose(filters=num_filters,
                                   kernel_size=square_kernel,
                                   padding='same')(x)
        if batch_norm:
            x = layers.BatchNormalization()(x)

    return x

In [36]:
def get_unet_model(img_size, num_classes):
    """Build the residual encoder/decoder segmentation model.

    Parameters
    ----------
    img_size : tuple of int
        Spatial size of the input; the model input shape is ``img_size + (3,)``.
    num_classes : int
        Number of filters in the final softmax convolution (one per class).

    Returns
    -------
    tf.keras.Model
        Model mapping the image input to a per-pixel softmax over the classes.
    """
    inputs = tf.keras.Input(shape=img_size + (3,))

    # --- Encoder -----------------------------------------------------------
    # Entry block: strided convolution halves the resolution.
    x = layers.Conv2D(32, 3, strides=2, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    skip = x  # carried into the next block's residual branch

    # Three identical blocks that differ only in feature depth.
    for depth in (64, 128, 256):
        for _ in range(2):
            x = layers.Activation("relu")(x)
            x = layers.SeparableConv2D(depth, 3, padding="same")(x)
            x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Strided 1x1 convolution brings the residual to the pooled shape.
        shortcut = layers.Conv2D(depth, 1, strides=2, padding="same")(skip)
        x = layers.add([x, shortcut])
        skip = x

    # --- Decoder -----------------------------------------------------------
    for depth in (256, 128, 64, 32):
        for _ in range(2):
            x = layers.Activation("relu")(x)
            x = layers.Conv2DTranspose(depth, 3, padding="same")(x)
            x = layers.BatchNormalization()(x)

        x = layers.UpSampling2D(2)(x)

        # Upsample the residual, then match its depth with a 1x1 convolution.
        shortcut = layers.UpSampling2D(2)(skip)
        shortcut = layers.Conv2D(depth, 1, padding="same")(shortcut)
        x = layers.add([x, shortcut])
        skip = x

    # Per-pixel classification layer.
    outputs = layers.Conv2D(num_classes, 3, activation="softmax", padding="same")(x)

    return tf.keras.Model(inputs, outputs)


In [37]:
tf.keras.backend.clear_session()


In [40]:
img_size = (256,256)
num_classes = 2

#inputs = layers.Input(shape=img_size+(3,))
unet = get_unet_model(img_size, num_classes)

In [41]:
unet.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 128, 128, 32) 896         input_2[0][0]                    
__________________________________________________________________________________________________
batch_normalization_15 (BatchNo (None, 128, 128, 32) 128         conv2d_9[0][0]                   
__________________________________________________________________________________________________
activation_15 (Activation)      (None, 128, 128, 32) 0           batch_normalization_15[0][0]     
____________________________________________________________________________________________

In [42]:
batch_size = 32

# Each column is sorted on its own; pairing image i with mask i by position
# relies on the two orderings lining up.
input_img_paths = list(metadata_train["sat_image_path"])
input_img_paths.sort()
target_img_paths = list(metadata_train["mask_path"])
target_img_paths.sort()

print('Number of training/validation samples:', len(input_img_paths))

# Preview the first ten image/mask pairs.
for input_path, target_path in zip(input_img_paths[:10], target_img_paths[:10]):
    print('Satellite image:', input_path, '|', 'Road mask:', target_path)

Number of training/validation samples: 6226
Satellite image: /content/deep-globe/train/100034_sat.jpg | Road mask: /content/deep-globe/train/100034_mask.png
Satellite image: /content/deep-globe/train/100081_sat.jpg | Road mask: /content/deep-globe/train/100081_mask.png
Satellite image: /content/deep-globe/train/100129_sat.jpg | Road mask: /content/deep-globe/train/100129_mask.png
Satellite image: /content/deep-globe/train/100703_sat.jpg | Road mask: /content/deep-globe/train/100703_mask.png
Satellite image: /content/deep-globe/train/100712_sat.jpg | Road mask: /content/deep-globe/train/100712_mask.png
Satellite image: /content/deep-globe/train/100773_sat.jpg | Road mask: /content/deep-globe/train/100773_mask.png
Satellite image: /content/deep-globe/train/100841_sat.jpg | Road mask: /content/deep-globe/train/100841_mask.png
Satellite image: /content/deep-globe/train/100867_sat.jpg | Road mask: /content/deep-globe/train/100867_mask.png
Satellite image: /content/deep-globe/train/100892_sa

In [44]:
# Hold out the last 20% of the shuffled paths for validation.
val_samples = int(0.2 * len(input_img_paths))

# Two generators seeded identically apply the same permutation to both
# equal-length lists, so image/mask pairs stay aligned after shuffling.
random.Random(42).shuffle(input_img_paths)
random.Random(42).shuffle(target_img_paths)

train_input_img_paths, val_input_img_paths = (
    input_img_paths[:-val_samples], input_img_paths[-val_samples:])
train_target_img_paths, val_target_img_paths = (
    target_img_paths[:-val_samples], target_img_paths[-val_samples:])

# One data Sequence per split.
train_gen = SatDatClass(batch_size=batch_size, img_size=img_size,
                        input_img_paths=train_input_img_paths,
                        target_img_paths=train_target_img_paths)
val_gen = SatDatClass(batch_size=batch_size, img_size=img_size,
                      input_img_paths=val_input_img_paths,
                      target_img_paths=val_target_img_paths)


In [None]:
# Integer class labels per pixel -> sparse categorical cross-entropy.
unet.compile(loss="sparse_categorical_crossentropy", optimizer="rmsprop", metrics=['accuracy'])

# Early stopping, learning-rate halving and best-model checkpointing, each
# left on its default monitored quantity.
callbacks = [
    tf.keras.callbacks.EarlyStopping(patience=8, restore_best_weights=True, verbose=1),
    tf.keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=3, verbose=1),
    tf.keras.callbacks.ModelCheckpoint(filepath="/content/satellite_segmentation.h5",
                                       save_best_only=True),
]

epochs = 10
unet.fit(x=train_gen, validation_data=val_gen, epochs=epochs, callbacks=callbacks)


Epoch 1/10

In [None]:
# MobileNetV2 without its classification head, used as the encoder.
base_model = tf.keras.applications.MobileNetV2(include_top=False, input_shape=[128, 128, 3])

# Layers whose activations are tapped, from highest to lowest resolution.
layer_names = [
    'block_1_expand_relu',   # 64x64
    'block_3_expand_relu',   # 32x32
    'block_6_expand_relu',   # 16x16
    'block_13_expand_relu',  # 8x8
    'block_16_project',      # 4x4
]
base_model_outputs = [layer.output for layer in map(base_model.get_layer, layer_names)]

# Feature-extraction model exposing all tapped activations; its weights are frozen.
down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)
down_stack.trainable = False


In [None]:
up_stack = [
    pix2pix.upsample(512, 3),  # 4x4 -> 8x8
    pix2pix.upsample(256, 3),  # 8x8 -> 16x16
    pix2pix.upsample(128, 3),  # 16x16 -> 32x32
    pix2pix.upsample(64, 3),   # 32x32 -> 64x64
]

def unet_model(output_channels:int):
  """Assemble the encoder/decoder model from ``down_stack`` and ``up_stack``.

  ``output_channels`` is the number of filters of the final transposed
  convolution, i.e. the number of channels of the returned model's output.
  The input shape is fixed to ``[128, 128, 3]``.
  """
  inputs = tf.keras.layers.Input(shape=[128, 128, 3])

  # Encoder: the last output is the bottleneck, the others are skip features.
  *skip_features, x = down_stack(inputs)

  # Decoder: upsample, then concatenate the skip feature of matching depth,
  # walking the skips from deepest to shallowest.
  for up, skip in zip(up_stack, reversed(skip_features)):
    x = up(x)
    merge = tf.keras.layers.Concatenate()
    x = merge([x, skip])

  # Final stride-2 transposed convolution: 64x64 -> 128x128
  x = tf.keras.layers.Conv2DTranspose(
      filters=output_channels, kernel_size=3, strides=2,
      padding='same')(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

In [None]:
OUTPUT_CLASSES = 2

model = unet_model(output_channels=OUTPUT_CLASSES)
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

In [None]:
tf.keras.utils.plot_model(model, show_shapes=True)


In [None]:
def display(display_list):
  """Show the arrays in ``display_list`` side by side in one row.

  Panels are titled, in order, 'Input Image', 'True Mask' and
  'Predicted Mask'; each array is converted with ``array_to_img`` first.
  """
  title = ['Input Image', 'True Mask', 'Predicted Mask']
  n_panels = len(display_list)

  plt.figure(figsize=(15, 15))
  for i, panel in enumerate(display_list):
    plt.subplot(1, n_panels, i+1)
    plt.title(title[i])
    plt.imshow(tf.keras.preprocessing.image.array_to_img(panel), cmap='Greys_r')
    plt.axis('off')
  plt.show()

def create_mask(pred_mask):
  """Reduce the last axis of ``pred_mask`` to its argmax, kept as a size-1 axis."""
  class_ids = tf.argmax(pred_mask, axis=-1)
  return class_ids[..., tf.newaxis]

def show_predictions(dataset=None, num=1):
  """Display an input image, its true mask and the mask predicted by ``model``.

  Parameters
  ----------
  dataset : indexable, optional
      Object whose ``dataset[num]`` returns an ``(images, masks)`` batch; the
      first sample of that batch is shown.
  num : int
      Index of the batch to take from ``dataset``.
  """
  # Test for presence, not truthiness: a dataset whose length is 0 is falsy
  # and used to fall through to the branch below.
  if dataset is not None:
    image, mask = dataset[num]
    sample = image[0]
    # Add the batch axis instead of reshaping to a hard-coded 128x128: that
    # reshape silently scrambled samples of any other size into extra images.
    pred_mask = create_mask(model.predict(sample[np.newaxis, ...]))
    true_mask = np.asarray(mask[0]).reshape(sample.shape[:2] + (1,))
    # Repeat the single mask channel three times so it is drawn as RGB.
    display([sample, np.repeat(true_mask, 3, axis=2), np.repeat(pred_mask[0], 3, axis=2)])
  else:
    # NOTE(review): sample_image / sample_mask are not defined in the visible
    # part of this notebook - confirm they exist before calling without a dataset.
    display([sample_image, sample_mask,
             create_mask(model.predict(sample_image[tf.newaxis, ...]))])
    

In [None]:
x,y = train_gen.__getitem__(0)
x.shape, x[0].shape, x[0].reshape(-1, 128, 128, 3).shape, y.shape, y[0].shape

In [None]:
BATCH_SIZE = 16
IMG_SIZE = (128, 128)

# Each column is sorted on its own; pairing by position relies on the two
# orderings lining up.
input_img_paths = list(metadata_train["sat_image_path"])
input_img_paths.sort()
target_img_paths = list(metadata_train["mask_path"])
target_img_paths.sort()

# Hold out the last 20% of the shuffled paths for validation. The identical
# seed applies the same permutation to both equal-length lists.
val_samples = int(0.2 * len(input_img_paths))
random.Random(42).shuffle(input_img_paths)
random.Random(42).shuffle(target_img_paths)

train_input_img_paths, val_input_img_paths = (
    input_img_paths[:-val_samples], input_img_paths[-val_samples:])
train_target_img_paths, val_target_img_paths = (
    target_img_paths[:-val_samples], target_img_paths[-val_samples:])

# One data Sequence per split.
train_gen = SatDatClass(batch_size=BATCH_SIZE, img_size=IMG_SIZE,
                        input_img_paths=train_input_img_paths,
                        target_img_paths=train_target_img_paths)
val_gen = SatDatClass(batch_size=BATCH_SIZE, img_size=IMG_SIZE,
                      input_img_paths=val_input_img_paths,
                      target_img_paths=val_target_img_paths)

In [None]:
show_predictions(train_gen)


In [None]:
# NOTE(review): `info`, `test_batches` and `DisplayCallback` are not defined in
# the visible part of this notebook, and `train_batches` / `STEPS_PER_EPOCH` are
# only assigned in cells further down - confirm this cell is still meant to run.
# The following cell trains the same `model` from `train_gen` / `val_gen`.
EPOCHS = 1
VAL_SUBSPLITS = 5
VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS

model_history = model.fit(train_batches, epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=test_batches,
                          callbacks=[DisplayCallback()])

In [None]:
# Early stopping, learning-rate halving and best-model checkpointing, each
# left on its default monitored quantity.
callbacks = [
    tf.keras.callbacks.EarlyStopping(patience=8, restore_best_weights=True, verbose=1),
    tf.keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=3, verbose=1),
    tf.keras.callbacks.ModelCheckpoint(filepath="/content/satellite_segmentation.h5",
                                       save_best_only=True),
]

# Train the model, validate at the end of each epoch
EPOCHS = 5
model.fit(x=train_gen, validation_data=val_gen,
          epochs=EPOCHS, callbacks=callbacks)

In [None]:
#val_preds = model.predict(val_gen)
show_predictions(val_gen)

In [None]:
def normalize(input_image, input_mask):
  """Cast the image to float32 and divide it by 255; the mask passes through unchanged."""
  scaled_image = tf.cast(input_image, tf.float32) / 255.0
  return scaled_image, input_mask

In [None]:
def load_image(datapoint, img_size=IMG_SIZE):
  """Read, resize and normalize one satellite image / road mask pair.

  ``datapoint`` is indexed with ``'sat_image_path'`` and ``'mask_path'``.
  Both files are read with ``plt.imread`` and resized to ``img_size`` before
  being passed through ``normalize``. The default ``img_size`` is whatever
  ``IMG_SIZE`` was when this function was defined, not when it is called.

  Returns the ``(image, mask)`` tuple produced by ``normalize``.
  """
  sat_resized = tf.image.resize(plt.imread(datapoint['sat_image_path']), img_size)
  mask_resized = tf.image.resize(plt.imread(datapoint['mask_path']), img_size)
  return normalize(sat_resized, mask_resized)

In [None]:
VAL_FRAC = 0.2
BATCH_SIZE = 32
BUFFER_SIZE = 1000
IMG_SIZE = (256, 256)

# Split into training and validation set
val_samples = int(VAL_FRAC * len(metadata_train))

# TRAIN_LENGTH is not assigned in any visible cell, so the original
# STEPS_PER_EPOCH line raised NameError; derive it from the split instead.
TRAIN_LENGTH = len(metadata_train) - val_samples
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

# Pass img_size explicitly: load_image's default was bound when the function
# was defined, i.e. before IMG_SIZE was reassigned in this cell.
train_images = metadata_train.iloc[:-val_samples]
train_images = train_images.progress_apply(
    lambda row: load_image(row, img_size=IMG_SIZE), axis=1)

val_images = metadata_train.iloc[-val_samples:]
val_images = val_images.progress_apply(
    lambda row: load_image(row, img_size=IMG_SIZE), axis=1)


In [None]:
class Augment(tf.keras.layers.Layer):
  """Layer applying a horizontal random flip to an image batch and its labels.

  ``seed`` is forwarded to both flip layers.
  """

  def __init__(self, seed=42):
    super().__init__()
    # `preprocessing` was never imported in this notebook (NameError); use the
    # RandomFlip layer from tf.keras.layers instead.
    # both use the same seed, so they'll make the same random changes.
    self.augment_inputs = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)
    self.augment_labels = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)

  def call(self, inputs, labels):
    """Return ``(inputs, labels)`` after passing each through its flip layer."""
    inputs = self.augment_inputs(inputs)
    labels = self.augment_labels(labels)
    return inputs, labels
    

In [None]:
train_images[0]

In [None]:
train_images_test = pd.DataFrame([train_images[:][0], train_images[:][1]],columns=['sat_image', 'mask'])
train_images_test.head(1)
#train_images_test = tf.convert_to_tensor(train_images_test)

In [None]:
# NOTE(review): train_images / val_images are assigned from
# DataFrame.progress_apply(...) in the cell above, so they appear to be pandas
# objects rather than tf.data.Dataset instances; .batch/.repeat/.map/.prefetch
# look like tf.data methods - confirm the intended conversion before running.
train_batches = (
    train_images
    #.cache()
    #.shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .repeat()
    .map(Augment())
    .prefetch(buffer_size=tf.data.AUTOTUNE))

val_batches = val_images.batch(BATCH_SIZE)

# Sample satellite images and road masks

In [None]:
# Draw three random row labels (with replacement) and plot each image beside its mask.
img_idx = np.random.choice(len(metadata_train), size=3)

for i in img_idx:
  # OpenCV reads BGR; convert to RGB for matplotlib.
  sat_img = cv2.imread(metadata_train.loc[i, 'sat_image_path'])
  sat_img = cv2.cvtColor(sat_img, cv2.COLOR_BGR2RGB)
  sat_mask = cv2.imread(metadata_train.loc[i, 'mask_path'])
  sat_mask = cv2.cvtColor(sat_mask, cv2.COLOR_BGR2RGB) / 255.

  fig, ax = plt.subplots(1, 2, figsize=(10,5))
  fig.suptitle('Image ID: {}'.format(metadata_train.loc[i, 'image_id']),
               fontsize=14)

  ax[0].imshow(sat_img)
  im = ax[1].imshow(sat_mask, cmap='Greys', vmin=0.0, vmax=1.0)

  # Reserve the right margin for a colorbar tied to the mask panel.
  fig.subplots_adjust(right=0.8)
  cbar_ax = fig.add_axes([0.85, 0.20, 0.05, 0.60])
  fig.colorbar(im, cax=cbar_ax)
  

In [None]:
img_dim, img_dim, img_depth = sat_mask.shape 
img_dim, img_dim, img_depth


In [None]:
# Run configuration for the cells below.
img_size = (256, 256)
num_classes = 1
batch_size = 32

# Each column is sorted on its own; pairing image i with mask i by position
# relies on the two orderings lining up.
input_img_paths = list(metadata_train["sat_image_path"])
input_img_paths.sort()
target_img_paths = list(metadata_train["mask_path"])
target_img_paths.sort()

print('Number of training/validation samples:', len(input_img_paths))

# Preview the first ten image/mask pairs.
for input_path, target_path in zip(input_img_paths[:10], target_img_paths[:10]):
    print('Satellite image:', input_path, '|', 'Road mask:', target_path)


# Load and vectorize batches of data

In [None]:
train_dict = {'img' : [], 'mask' : []}

def load_data(load_dict=None, input_img_paths=None, target_img_paths=None, image_size=(128, 128)):
    """Load and resize satellite-image / road-mask pairs into ``load_dict``.

    Parameters
    ----------
    load_dict : dict, optional
        Dictionary with lists under the keys ``'img'`` and ``'mask'``; loaded
        arrays are appended to them. A fresh dictionary is created when omitted.
    input_img_paths : str
        Directory that is listed for file names and read for ``<id>_sat.jpg``.
    target_img_paths : str
        Directory read for ``<id>_mask.png``.
    image_size : tuple of int
        Size passed to ``cv2.resize`` for both the image and the mask.

    Returns
    -------
    dict
        ``load_dict`` with one ``'img'`` entry and one ``'mask'`` entry (first
        channel of the resized mask) appended per pair that could be read.
    """
    if load_dict is None:
        load_dict = {'img': [], 'mask': []}

    # Image ids are the part of each file name before the first underscore.
    # dict.fromkeys de-duplicates them in listing order without the quadratic
    # "name not in list" scan.
    image_ids = list(dict.fromkeys(
        name.split('_')[0] for name in os.listdir(input_img_paths)))

    # Iterate over the ids themselves. Indexing the id list with positions taken
    # from the (longer) file listing raised IndexError, which the former bare
    # ``except`` silently swallowed.
    for image_id in image_ids:
        try:
            img = plt.imread(os.path.join(input_img_paths, image_id + '_sat.jpg'))
            target = plt.imread(os.path.join(target_img_paths, image_id + '_mask.png'))
        except OSError:
            # Missing or unreadable file: skip this pair, but let programming
            # errors propagate instead of hiding them.
            continue

        img = cv2.resize(img, image_size)
        target = cv2.resize(target, image_size)

        load_dict['img'].append(img)
        load_dict['mask'].append(target[:,:,0])

    return load_dict


In [None]:
class SatDatClass(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that yields batches of satellite images and road masks.

    Files are read with ``plt.imread`` and resized with ``cv2.resize``.

    Parameters
    ----------
    batch_size : int
        Number of samples per batch.
    img_size : tuple of int
        ``(height, width)`` of the arrays placed in each batch.
    input_img_paths : sequence of str
        Paths of the satellite images.
    target_img_paths : sequence of str
        Paths of the road masks, in the same order as ``input_img_paths``.
    """

    def __init__(self, batch_size, img_size, input_img_paths, target_img_paths):
        self.batch_size = batch_size
        self.img_size = img_size
        self.input_img_paths = input_img_paths
        self.target_img_paths = target_img_paths

    def __len__(self):
        """Return the number of full batches (a trailing partial batch is dropped)."""
        return len(self.target_img_paths) // self.batch_size

    def __getitem__(self, idx):
        """Return the tuple ``(x, y)`` for batch number ``idx``.

        ``x`` is float32 with shape ``(batch_size, H, W, 3)``, pixel values
        divided by 255. ``y`` is uint8 with shape ``(batch_size, H, W)`` holding
        1 where the first mask channel is above the mid-range threshold and 0
        elsewhere.
        """
        i = idx * self.batch_size
        batch_input_img_paths = self.input_img_paths[i : i + self.batch_size]
        batch_target_img_paths = self.target_img_paths[i : i + self.batch_size]

        # cv2.resize takes its size as (width, height); img_size is (height,
        # width), so reverse it to stay correct for non-square sizes.
        dsize = tuple(self.img_size[::-1])

        x = np.zeros((self.batch_size,) + self.img_size + (3,), dtype="float32")
        for j, path in enumerate(batch_input_img_paths):
            img = cv2.resize(plt.imread(path), dsize)
            x[j] = img_to_array(img).astype(int) / 255

        y = np.zeros((self.batch_size,) + self.img_size, dtype="uint8")
        for j, path in enumerate(batch_target_img_paths):
            target = cv2.resize(plt.imread(path), dsize)
            road = target[:, :, 0]
            # Threshold explicitly rather than letting the uint8 assignment
            # truncate: float masks in [0, 1] kept only pixels equal to 1.0.
            cutoff = 0.5 if np.issubdtype(road.dtype, np.floating) else 127
            y[j] = road > cutoff

        return x, y


# U-Net Xception-style model

In [None]:
def Conv2DBlock(inputs, previous_block_activation, num_filters, kernel_size=3, batch_norm=True):
    """Run ``inputs`` through two ReLU -> SeparableConv2D stages.

    Both convolutions use ``num_filters`` filters, a square kernel of side
    ``kernel_size`` and ``padding='same'``. When ``batch_norm`` is true each
    convolution is followed by a BatchNormalization layer.
    ``previous_block_activation`` is accepted but not used here.

    Returns the output tensor of the second stage.
    """
    square_kernel = (kernel_size, kernel_size)

    # Start from ``inputs``: the previous body applied the first activation to
    # a not-yet-assigned ``x`` and then convolved ``inputs`` directly.
    x = inputs
    for _ in range(2):
        x = layers.Activation('relu')(x)
        x = layers.SeparableConv2D(filters=num_filters,
                                   kernel_size=square_kernel,
                                   padding='same')(x)
        # Honour the ``batch_norm`` parameter (the undefined name
        # ``doBatchNorm`` was referenced before).
        if batch_norm:
            x = layers.BatchNormalization()(x)

    return x


In [None]:
def Conv2DTransposeBlock(inputs, previous_block_activation, num_filters, kernel_size=3, batch_norm=True):
    """Run ``inputs`` through two ReLU -> Conv2DTranspose stages.

    Both transposed convolutions use ``num_filters`` filters, a square kernel
    of side ``kernel_size`` and ``padding='same'``. When ``batch_norm`` is true
    each one is followed by a BatchNormalization layer.
    ``previous_block_activation`` is accepted but not used here.

    Returns the output tensor of the second stage.
    """
    square_kernel = (kernel_size, kernel_size)

    # Start from ``inputs`` and chain the stages: the previous body used ``x``
    # before assigning it and fed ``inputs`` to both convolutions, which threw
    # away the result of the first stage.
    x = inputs
    for _ in range(2):
        x = layers.Activation('relu')(x)
        x = layers.Conv2DTranspose(filters=num_filters,
                                   kernel_size=square_kernel,
                                   padding='same')(x)
        # Honour the ``batch_norm`` parameter (the undefined name
        # ``doBatchNorm`` was referenced before).
        if batch_norm:
            x = layers.BatchNormalization()(x)

    return x

In [None]:
def get_unet_model(inputs, previous_block_activation, num_filters=16, dropout=0.1, batch_norm=True):
    """Build a U-Net style encoder/decoder with residual projections in the encoder.

    Encoder: four stages of Conv2DBlock -> stride-2 max pooling, each summed
    with a stride-2 1x1 projection of the previous stage's activation, then a
    bottleneck Conv2DBlock with ``num_filters*16`` filters.
    Decoder: four stages of stride-2 Conv2DTranspose -> concatenate with the
    matching encoder Conv2DBlock output -> Dropout -> Conv2DBlock.
    Head: 1x1 convolution with a single sigmoid channel.

    Args:
        inputs: Model input tensor; also the input of the first encoder block.
        previous_block_activation: Tensor projected (1x1 conv, stride 2) and
            added to the first pooled activation. It has to share the spatial
            size of ``inputs`` for that addition to be shape-compatible
            (passing ``inputs`` itself satisfies this).
        num_filters: Filter count of the first stage; later stages use
            multiples of it (x2, x4, x8, x16).
        dropout: Dropout rate applied after each decoder concatenation.
        batch_norm: Forwarded to every Conv2DBlock call.

    Returns:
        A ``keras.Model`` mapping ``inputs`` to a one-channel sigmoid map.

    Note:
        Assumes Conv2DBlock preserves spatial size and that height/width are
        divisible by 16, so decoder tensors line up with the encoder skips --
        TODO confirm against Conv2DBlock and ``img_size``.
    """
    # ----- encoder path -----
    skip_connections = []
    x = inputs
    for multiplier in (1, 2, 4, 8):
        filters = num_filters * multiplier
        conv = Conv2DBlock(x, previous_block_activation, filters,
                           kernel_size=3, batch_norm=batch_norm)
        skip_connections.append(conv)
        pooled = layers.MaxPooling2D(3, strides=2, padding="same")(conv)
        # Project the residual to this stage's channel count; the projection
        # must use the same filter count as `pooled` or the add fails.
        residual = layers.Conv2D(filters, 1, strides=2, padding="same")(previous_block_activation)
        x = layers.add([pooled, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    # ----- bottleneck -----
    x = Conv2DBlock(x, previous_block_activation, num_filters * 16,
                    kernel_size=3, batch_norm=batch_norm)

    # ----- decoder path -----
    for multiplier, skip in zip((8, 4, 2, 1), reversed(skip_connections)):
        filters = num_filters * multiplier
        x = layers.Conv2DTranspose(filters, (3, 3), strides=(2, 2), padding='same')(x)
        x = layers.concatenate([x, skip])
        x = layers.Dropout(dropout)(x)
        x = Conv2DBlock(x, x, filters, kernel_size=3, batch_norm=batch_norm)

    # Per-pixel road probability.
    output = layers.Conv2D(1, (1, 1), activation='sigmoid')(x)
    return keras.Model(inputs=[inputs], outputs=[output])


In [None]:
# Build the U-Net variant: create the model input, run it through a
# stride-2 entry convolution, and hand both tensors to get_unet_model.
inputs = layers.Input(shape=img_size+(3,))

# Entry block
x = layers.Conv2D(32, 3, strides=2, padding="same")(inputs)
x = layers.BatchNormalization()(x)
x = layers.Activation("relu")(x)

previous_block_activation = x  # Set aside residual

# NOTE(review): get_unet_model pools the block built on `inputs` once (stride 2)
# and adds a stride-2 projection of `previous_block_activation`. Here that
# tensor is already half the resolution of `inputs` (strides=2 entry conv), so
# the first residual add looks shape-incompatible -- confirm the intended wiring.
unet = get_unet_model(inputs, previous_block_activation, num_filters=16)


In [None]:
def get_model(img_size, num_classes):
    """Build the residual encoder/decoder segmentation network.

    Args:
        img_size: (height, width) tuple; a 3-channel axis is appended.
        num_classes: Number of channels of the final softmax layer.

    Returns:
        A ``keras.Model`` from the image input to the per-pixel softmax output.
    """

    def relu_conv_bn_pair(tensor, conv_layer, depth):
        # Two back-to-back ReLU -> 3x3 conv -> BatchNorm stages.
        for _ in range(2):
            tensor = layers.Activation("relu")(tensor)
            tensor = conv_layer(depth, 3, padding="same")(tensor)
            tensor = layers.BatchNormalization()(tensor)
        return tensor

    inputs = keras.Input(shape=img_size + (3,))
    print(inputs.shape)

    # --- Stem: stride-2 convolution halves the resolution ---
    features = layers.Conv2D(32, 3, strides=2, padding="same")(inputs)
    features = layers.BatchNormalization()(features)
    features = layers.Activation("relu")(features)

    shortcut_source = features  # kept for the first residual projection

    # --- Downsampling half: separable convs + max pooling + residual ---
    for depth in (64, 128, 256):
        features = relu_conv_bn_pair(features, layers.SeparableConv2D, depth)
        features = layers.MaxPooling2D(3, strides=2, padding="same")(features)

        # 1x1 stride-2 projection so the shortcut matches the pooled tensor.
        shortcut = layers.Conv2D(depth, 1, strides=2, padding="same")(shortcut_source)
        features = layers.add([features, shortcut])
        shortcut_source = features

    # --- Upsampling half: transposed convs + 2x upsampling + residual ---
    for depth in (256, 128, 64, 32):
        features = relu_conv_bn_pair(features, layers.Conv2DTranspose, depth)
        features = layers.UpSampling2D(2)(features)

        # Upsample the shortcut, then match channels with a 1x1 convolution.
        shortcut = layers.UpSampling2D(2)(shortcut_source)
        shortcut = layers.Conv2D(depth, 1, padding="same")(shortcut)
        features = layers.add([features, shortcut])
        shortcut_source = features

    # --- Head: per-pixel classification ---
    outputs = layers.Conv2D(num_classes, 3, activation="softmax",
                            padding="same")(features)

    return keras.Model(inputs, outputs)


In [None]:
# Free up RAM in case the model definition cells were run multiple times:
# clear_session() discards the state Keras has accumulated from earlier builds.
keras.backend.clear_session()

In [None]:
# Build the segmentation network for the configured image size and class
# count, then print its layer-by-layer summary.
model = get_model(img_size, num_classes)
model.summary()


In [None]:
# Split into training and validation set (80/20).
assert len(input_img_paths) == len(target_img_paths), \
    "image and mask path lists must have the same length"

# Shuffle image/mask pairs together on a copy so each image stays aligned with
# its mask and re-running this cell reproduces the same split. The seed is
# unchanged, so the resulting order matches shuffling each list with
# Random(42) separately.
paired_paths = list(zip(input_img_paths, target_img_paths))
random.Random(42).shuffle(paired_paths)

val_samples = int(0.2 * len(paired_paths))
# Use an explicit split index: slicing with `-val_samples` breaks when
# val_samples == 0 (`[:-0]` is empty and `[-0:]` is the whole list).
split_idx = len(paired_paths) - val_samples

train_input_img_paths = [img_path for img_path, _ in paired_paths[:split_idx]]
train_target_img_paths = [msk_path for _, msk_path in paired_paths[:split_idx]]
val_input_img_paths = [img_path for img_path, _ in paired_paths[split_idx:]]
val_target_img_paths = [msk_path for _, msk_path in paired_paths[split_idx:]]

# Instantiate data Sequences for each split
train_gen = SatDatClass(batch_size, img_size,
                        train_input_img_paths, train_target_img_paths)
val_gen = SatDatClass(batch_size, img_size,
                      val_input_img_paths, val_target_img_paths)

In [None]:
# Number of batches in the validation generator. Call len() rather than
# referencing the `__len__` attribute, which only displays the bound method.
len(val_gen)

In [None]:
# Sanity-check the first training batch: print the input value range and plot
# one image next to its mask.
sample_idx = 4  # position inside the batch to display

x, y = train_gen[0]
print(x.min(), x.max())
fig, ax = plt.subplots(1, 2, figsize=(10,5))

# The batch comes from the shuffled training paths, so row 0 of
# `metadata_train` is not the image shown; label the figure by what is
# actually plotted.
fig.suptitle('Training batch 0, sample {}'.format(sample_idx),
               fontsize=14)
ax[0].imshow(np.array(x[sample_idx]))
ax[1].imshow(np.array(y[sample_idx]), cmap='Greys');

In [None]:
# Compile with Adam and binary cross-entropy, tracking accuracy.
# NOTE(review): get_model ends in a softmax over `num_classes` channels while
# the loss here is binary cross-entropy -- confirm that `num_classes` and the
# mask encoding are consistent with this pairing.
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

# Stop early when validation stalls, halve the learning rate on plateaus, and
# keep only the best checkpoint on disk.
callbacks = [
    keras.callbacks.EarlyStopping(
        patience=8, verbose=1, restore_best_weights=True),
    keras.callbacks.ReduceLROnPlateau(
        factor=0.5, patience=3, verbose=1),
    keras.callbacks.ModelCheckpoint(
        "/content/satellite_segmentation.h5", save_best_only=True),
]

# Train the model, validate at the end of each epoch
epochs = 1
model.fit(
    train_gen,
    epochs=epochs,
    validation_data=val_gen,
    callbacks=callbacks,
)

In [None]:
# Run the trained model over the validation generator built in the split
# cell; `val_preds` holds the stacked per-batch predictions.
val_preds = model.predict(val_gen)


In [None]:
# Inspect a single validation prediction: its shape and its value range.
i = 18

val_preds[i].shape, val_preds[i].min(), val_preds[i].max()

In [None]:
# Compare one validation image, its ground-truth mask, and the model's
# prediction side by side.
i = 18

sat_img = cv2.cvtColor(cv2.imread(val_input_img_paths[i]),
                         cv2.COLOR_BGR2RGB)
sat_mask = cv2.cvtColor(cv2.imread(val_target_img_paths[i]),
                          cv2.COLOR_BGR2RGB)[:,:,0]

fig, ax = plt.subplots(1, 3, figsize=(15,5))

# The validation paths were shuffled, so row `i` of `metadata_train` is not
# this image; title the figure with the file that is actually displayed.
fig.suptitle('Validation image: {}'.format(val_input_img_paths[i]),
               fontsize=14)
ax[0].set_title('Satellite image')
ax[0].imshow(sat_img)
ax[1].set_title('Ground-truth mask')
ax[1].imshow(sat_mask, cmap='Greys')

# First output channel of the prediction, restored to (H, W, 1) so it can be
# converted to an image.
mask = val_preds[i][:, :, 0]
mask = np.expand_dims(mask, axis=-1)
ax[2].set_title('Predicted mask (channel 0)')
ax[2].imshow(keras.preprocessing.image.array_to_img(mask), cmap='Greys');


In [None]:
# Value range of output channel 0 across every validation prediction.
# `val_preds` is (batch, height, width, channels); the ellipsis keeps the
# first three axes and selects channel 0. `val_preds[:][:, :, 0]` would
# instead pick column 0 along the width axis.
x = val_preds[..., 0]

x.min(), x.max()

In [None]:
# Load the first metadata row's satellite image and grayscale mask at the
# model's input size, then report their shapes and value ranges.
img = load_img(metadata_train['sat_image_path'][0], target_size=img_size)
msk = load_img(metadata_train['mask_path'][0],
               color_mode="grayscale",
               target_size=img_size)

# Mask with an added trailing channel axis, scaled from 0-255 to 0-1.
y = np.expand_dims(msk, 2) / 255.

img_arr = np.array(img)
msk_arr = np.array(msk)
(img_arr.shape, msk_arr.shape, msk_arr.min(), msk_arr.max(),
 y.shape, y.min(), y.max())

In [None]:
# Show the loaded satellite image (left) beside its mask (right).
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
sat_ax, mask_ax = ax

fig.suptitle('Image ID: {}'.format(metadata_train['image_id'][0]), fontsize=14)
sat_ax.imshow(img)
mask_ax.imshow(np.array(msk), cmap='Greys');
