[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/joconnor-ml/osm-ai-tools/blob/master/notebooks/train_resisc45_resnet.ipynb)

This notebook trains a TensorFlow 2 ResNet-50 model on the RESISC-45 aerial imagery classification dataset. The [TF-Hub model](https://tfhub.dev/google/remote_sensing/resisc45-resnet50/1) trained on this dataset is currently only available in TF1 format, which cannot be fine-tuned in TF2.

Models pre-trained on RESISC-45 have been shown to transfer well to other remote-sensing datasets; see
https://arxiv.org/abs/1911.06721. The linked paper forms the basis of the transfer-learning
methodology used here.

The RESISC-45 dataset must be downloaded manually from [OneDrive](https://onedrive.live.com/?authkey=%21AHHNaHIlzp%5FIXjs&cid=5C5E061130630A68&id=5C5E061130630A68%21107&parId=5C5E061130630A68%21112&action=defaultclick)
(clicking the link starts the download).


In [None]:
#@title Imports, Detect GPU
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow.keras as tfk

# Detect GPU -- a GPU-enabled Colab runtime is recommended.
# tf.test.gpu_device_name() returns an empty string when no GPU is visible,
# so accept any non-empty name instead of hard-coding '/device:GPU:0'.
device_name = tf.test.gpu_device_name()
if not device_name:
    # RuntimeError, not SystemError: SystemError is reserved for internal
    # interpreter errors.
    raise RuntimeError('GPU device not found')
print(f'Found GPU at: {device_name}')

In [None]:
# NB: RESISC-45 dataset requires manual download, see https://www.tensorflow.org/datasets/catalog/resisc45
# Replace PATH/TO/resisc45.zip with the location of the downloaded archive.
# NOTE(review): the loader cell below globs "NWPU-RESISC45/*/*", so the archive
# is assumed to unpack into NWPU-RESISC45/<class_name>/<image> in the current
# directory -- confirm after extracting.
!unzip -q PATH/TO/resisc45.zip

In [None]:
# RESISC-45 class names. A name's position in this list is the integer label
# the model is trained to predict, so the order must not change.
_LABELS = [
    'airplane', 'airport', 'baseball_diamond', 'basketball_court', 'beach',
    'bridge', 'chaparral', 'church', 'circular_farmland', 'cloud',
    'commercial_area', 'dense_residential', 'desert', 'forest', 'freeway',
    'golf_course', 'ground_track_field', 'harbor', 'industrial_area',
    'intersection', 'island', 'lake', 'meadow', 'medium_residential',
    'mobile_home_park', 'mountain', 'overpass', 'palace', 'parking_lot',
    'railway', 'railway_station', 'rectangular_farmland', 'river', 'roundabout',
    'runway', 'sea_ice', 'ship', 'snowberg', 'sparse_residential', 'stadium',
    'storage_tank', 'tennis_court', 'terrace', 'thermal_power_station',
    'wetland'
]

# In-graph lookup from class-name string to its index in _LABELS, usable
# inside tf.data map functions. Any string not in _LABELS maps to -1.
_label_initializer = tf.lookup.KeyValueTensorInitializer(
    keys=_LABELS,
    values=range(len(_LABELS)),
)
table = tf.lookup.StaticHashTable(_label_initializer, default_value=-1)


In [None]:
def parse_image(filename):
    """Load one image file and return it with its integer class label.

    The class name is the file's parent directory
    (``.../<class_name>/<file>``); it is mapped to an integer through
    ``table``, so names not in ``_LABELS`` come back as -1.

    Args:
        filename: Scalar string tensor holding the image path.

    Returns:
        ``(image, label)``: ``image`` is a float32 tensor of shape
        ``(256, 256, 3)`` with values in ``[-127.5, 127.5]``; ``label`` is the
        integer class index from ``table``.
    """
    parts = tf.strings.split(filename, os.sep)
    label = parts[-2]

    image = tf.io.read_file(filename)
    # channels=3 forces RGB output: without it a greyscale JPEG would decode
    # to a single channel and break batching, and the channel dimension would
    # be statically unknown.
    image = tf.image.decode_jpeg(image, channels=3)
    # convert_image_dtype rescales uint8 to [0, 1]; after resizing, scale back
    # to [0, 255] and centre on zero.
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, [256, 256]) * 255 - 127.5
    return image, table.lookup(label)

# RESISC-45 has 31500 images (45 classes x 700); hold out 20% for validation.
_NUM_IMAGES = 31500
_NUM_TEST = 6300
_BATCH_SIZE = 128
_AUTOTUNE = tf.data.experimental.AUTOTUNE

# Shuffle the file list ONCE, with a fixed seed. By default both
# list_files(shuffle=True) and shuffle(reshuffle_each_iteration=True) produce
# a new order on every pass, which made take()/skip() draw a different,
# overlapping train/test split each epoch and leak test images into training.
ds = (
    tf.data.Dataset.list_files("NWPU-RESISC45/*/*", shuffle=False)
    .shuffle(_NUM_IMAGES, seed=42, reshuffle_each_iteration=False)
)
test_ds = (
    ds.take(_NUM_TEST)
    .map(parse_image, num_parallel_calls=_AUTOTUNE)
    .batch(_BATCH_SIZE)
    .prefetch(_AUTOTUNE)
)
# Reshuffle only the training file names every epoch. This is cheap because
# it shuffles path strings, before any images are decoded.
train_ds = (
    ds.skip(_NUM_TEST)
    .shuffle(_NUM_IMAGES - _NUM_TEST)
    .map(parse_image, num_parallel_calls=_AUTOTUNE)
    .batch(_BATCH_SIZE)
    .prefetch(_AUTOTUNE)
)

# One training batch, used by the inspection cells and to size the model input.
image_batch, label_batch = next(iter(train_ds))

In [None]:
# Distribution of every pixel value in the sample batch -- a quick sanity check
# of the scaling applied in parse_image.
pixel_values = image_batch.numpy().ravel()
plt.hist(pixel_values)

In [None]:
def plot_one(img):
    """Display one preprocessed image with matplotlib.

    Undoes the ``- 127.5`` centring applied in ``parse_image``, clips to the
    8-bit range and casts to uint8 so ``imshow`` renders it as RGB.

    Args:
        img: NumPy image array as produced by ``parse_image`` (float pixels
            centred on zero).
    """
    plt.imshow((img + 127.5).clip(0, 255).astype(np.uint8))


i = 6
label_index = int(label_batch.numpy()[i])
# -1 is the lookup table's default for unknown directory names; guard it so
# it isn't silently shown as _LABELS[-1].
label_name = _LABELS[label_index] if 0 <= label_index < len(_LABELS) else "unknown"
print(f"Label = {label_index} ({label_name})")
plot_one(image_batch[i].numpy())

In [None]:
with tf.device(device_name):
    # Training-time augmentation applied inside the model. The Random* layers
    # are only active during training; Resizing and CenterCrop always apply.
    data_augmentation = tfk.Sequential([
        tfk.layers.experimental.preprocessing.RandomFlip(),
        # NOTE(review): some newer Keras releases clip RandomContrast output
        # to [0, 255]. Inputs here are centred on zero (parse_image subtracts
        # 127.5), so negative pixels could be zeroed -- confirm against the
        # installed TF/Keras version.
        tfk.layers.experimental.preprocessing.RandomContrast(0.1),
        tfk.layers.experimental.preprocessing.Resizing(256, 256),
        tfk.layers.experimental.preprocessing.RandomTranslation(0.1, 0.1),
        # RandomRotation's factor is a FRACTION of 2*pi, not an angle in
        # radians: 0.5 samples rotations uniformly from [-pi, pi], i.e. any
        # orientation, which suits overhead imagery.
        tfk.layers.experimental.preprocessing.RandomRotation(0.5),
        tfk.layers.experimental.preprocessing.CenterCrop(224, 224),
    ])

    inputs = tfk.Input(shape=image_batch[0].shape)
    augmented_model = data_augmentation(inputs)
    # ImageNet-initialised backbone, fully fine-tuned. A (None, None, 3) input
    # shape leaves the spatial size open, so the saved backbone is not tied
    # to 224x224 inputs.
    # NOTE(review): Keras' ImageNet ResNet50 weights were trained with
    # tfk.applications.resnet50.preprocess_input (BGR, per-channel mean
    # subtraction), whereas this pipeline feeds RGB centred on 127.5 --
    # confirm the mismatch is intended.
    base_model = tfk.applications.ResNet50(
        include_top=False, input_shape=(None, None, 3), weights="imagenet"
    )
    base_model.trainable = True
    base_model_out = base_model(augmented_model)
    global_average_layer = tfk.layers.GlobalAveragePooling2D()(base_model_out)
    prediction_layer = tfk.layers.Dense(
        len(_LABELS), activation='softmax'
    )(global_average_layer)
    final_model = tfk.models.Model(inputs=inputs, outputs=prediction_layer)

# `learning_rate` is the supported keyword (`lr` is a deprecated alias). The
# value is overridden each epoch by the scheduler below.
opt = tfk.optimizers.SGD(learning_rate=1e-3, momentum=0.9)


def scheduler(epoch):
    """Return the learning rate for 0-indexed ``epoch``.

    Linear warm-up from 2e-4 to 1e-3 over epochs 0-4, then 1e-3 until
    epoch 19, 1e-4 until 34, 1e-5 until 49 and 1e-6 afterwards. With the
    35-epoch run below, only the first three phases are reached.
    """
    if epoch < 5:
        return (0.001 / 5) * (epoch + 1)
    if epoch < 20:
        return 0.001
    if epoch < 35:
        return 0.0001
    if epoch < 50:
        return 0.00001
    return 0.000001


callback = tfk.callbacks.LearningRateScheduler(scheduler)

final_model.compile(optimizer=opt, loss="sparse_categorical_crossentropy", metrics=['accuracy'])
history = final_model.fit(train_ds, validation_data=test_ds, epochs=35, callbacks=[callback])

In [None]:
# Persist only the ResNet-50 backbone (include_top=False: no augmentation
# layers, no classification head) for reuse in transfer learning.
# It is written to the local Colab workspace -- make sure to move it elsewhere!
export_path = "resisc_224px_rgb_resnet50"
base_model.save(export_path)