# SRCNN - 2014

https://paperswithcode.com/paper/image-super-resolution-using-deep

This is a relatively simple model. Results on the satellite imagery are currently not very impressive, and it is not certain that it beats bicubic convolution when compared using PSNR and SSIM as metrics. However, I have done minimal tuning and have not yet removed the border effect that SRCNN introduces out of the box.

The primary use of the model to date has been to validate the data generation pipeline.

# Imports and setup

In [None]:
import numpy as np
import random
import pandas as pd
import os
import matplotlib.pyplot as plt
import pathlib
import rasterio
import rasterio.plot
import geopandas
import pickle

import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal

AUTOTUNE = tf.data.experimental.AUTOTUNE

# Check GPUs and switch them to on-demand memory allocation. A RuntimeError
# raised by the configuration calls (e.g. when the GPUs have already been
# initialised) is printed rather than re-raised.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Prevent TensorFlow from reserving all memory of every GPU up front.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(f"{len(gpus)} Physical GPUs, {len(logical_gpus)} Logical GPUs")
    except RuntimeError as err:
        print(err)

In [None]:
# Path to location where individual satellite images are located
DATA_PATH = 'data/toulon-laspezia'

# Paths to the tiled imagery (one sub-directory per split)
DATA_PATH_TILES = 'data/toulon-laspezia-tiles'
DATA_PATH_TILES_TRAIN = DATA_PATH_TILES + '/train'
DATA_PATH_TILES_VAL = DATA_PATH_TILES + '/val'
DATA_PATH_TILES_TEST = DATA_PATH_TILES + '/test'

# Load the metadata (geo)pandas DataFrame produced when generating tiles to disk.
# NOTE: pickle.load can execute arbitrary code; only open files written by our
# own tiling pipeline.
with open(DATA_PATH_TILES + '/metadata_tile_allocation.pickle', 'rb') as file:
    meta = pickle.load(file)

# Images per split. `.get(..., 0)` keeps an empty split from raising KeyError.
split_counts = meta['train_val_test'].value_counts()
N_IMAGES = len(meta.index)
N_IMAGES_TRAIN = split_counts.get('train', 0)
N_IMAGES_VAL = split_counts.get('val', 0)
N_IMAGES_TEST = split_counts.get('test', 0)
print('Number of satellite images - train:', N_IMAGES_TRAIN,
      ', val:', N_IMAGES_VAL, ', test:', N_IMAGES_TEST)

# Tiles per split (sum of the per-image `n_tiles` column).
N_TILES_TRAIN = meta.loc[meta['train_val_test'] == 'train', 'n_tiles'].sum()
N_TILES_VAL = meta.loc[meta['train_val_test'] == 'val', 'n_tiles'].sum()
N_TILES_TEST = meta.loc[meta['train_val_test'] == 'test', 'n_tiles'].sum()
print('Number of satellite image tiles - train:', N_TILES_TRAIN,
      ', val:', N_TILES_VAL, ', test:', N_TILES_TEST)

PAN_WIDTH, PAN_HEIGHT = (384, 384)

# MS tiles are SR_FACTOR times smaller than PAN tiles in each dimension.
SR_FACTOR = 4
MS_WIDTH, MS_HEIGHT = (PAN_WIDTH // SR_FACTOR, PAN_HEIGHT // SR_FACTOR)

# Should be derived automatically, but added here as a quick fix
MS_BANDS = 8

BATCH_SIZE = 16
EPOCHS = 2

# TensorFlow tile generator from disk

Uses the `tf.data` API to construct a `Dataset` pipeline that reads and preprocesses tiles from disk.

Follows the best practices from https://www.tensorflow.org/guide/data, including parallel (multithreaded) mapping, prefetching, shuffling, batching and caching.

`rasterio` is used to read GeoTIFFs. The `decode_geotiff()` function is wrapped in `tf.py_function()` so that this plain Python function can be called from within the TensorFlow graph built by `Dataset.map` (the wrapped function itself still runs eagerly in Python).

In [None]:
def decode_geotiff(image_path):
    """Read a GeoTIFF tile and return its pixels in channels-last order.

    Meant to be called through ``tf.py_function``: ``image_path`` is an eager
    string tensor whose bytes hold the file path.

    Returns:
        The raster read by ``rasterio``, moved from channels-first to
        channels-last layout.
    """
    path = pathlib.Path(image_path.numpy().decode())
    with rasterio.open(path) as dataset:
        bands_first = dataset.read()
    # rasterio reads channels first; the model works channels last.
    return rasterio.plot.reshape_as_image(bands_first)

def preprocess_images(img, ms_or_pan):
    """Convert a decoded tile to float32 and give it a static (h, w, C) shape.

    Args:
        img: tile decoded via ``tf.py_function``. When ``Tout`` is passed as a
            list (as in ``process_path``), the result is wrapped in an extra
            leading dimension, which the reshape below removes.
        ms_or_pan: ``'ms'`` for a multispectral tile (MS_HEIGHT x MS_WIDTH)
            or ``'pan'`` for a panchromatic tile (PAN_HEIGHT x PAN_WIDTH).

    Returns:
        A float32 tensor of shape (h, w, channels). Integer pixel values are
        rescaled by ``tf.image.convert_image_dtype``.

    Raises:
        ValueError: if ``ms_or_pan`` is neither ``'ms'`` nor ``'pan'``.
    """
    if ms_or_pan == 'ms':
        h, w = MS_HEIGHT, MS_WIDTH
    elif ms_or_pan == 'pan':
        h, w = PAN_HEIGHT, PAN_WIDTH
    else:
        # Without this, `h`/`w` would be unbound and fail with a confusing error.
        raise ValueError(
            f"ms_or_pan must be 'ms' or 'pan', got {ms_or_pan!r}")

    img = tf.image.convert_image_dtype(img, tf.float32)
    img = tf.reshape(img, [h, w, -1])  # To avoid issue with extra dimension
    return img

def upsample_images(ms_img, pan_img):
    """Bicubically upsample the MS tile to the PAN tile size.

    SRCNN expects its input at the target resolution, so the multispectral
    tile is resized from (MS_HEIGHT, MS_WIDTH) to (PAN_HEIGHT, PAN_WIDTH).
    The PAN tile is passed through unchanged.

    Returns:
        ``(upsampled_ms_img, pan_img)``.
    """
    # Pin the static shape (checked at runtime) so resize sees known dimensions.
    ms_img = tf.ensure_shape(ms_img, [MS_HEIGHT, MS_WIDTH, MS_BANDS])
    # tf.image.resize defaults to bilinear; this step is used and reported as
    # bicubic upsampling, so request bicubic explicitly.
    ms_img = tf.image.resize(ms_img, [PAN_HEIGHT, PAN_WIDTH],
                             method=tf.image.ResizeMethod.BICUBIC)
    return ms_img, pan_img

def process_path(ms_tile_path):
    """Load an (MS, PAN) tile pair given the file path of the MS tile.

    The PAN tile path is derived by replacing the ``ms`` directory component
    of ``ms_tile_path`` with ``pan``. Both tiles are decoded as int16 with
    ``decode_geotiff`` (through ``tf.py_function``) and then converted by
    ``preprocess_images``.

    Returns:
        ``(ms_img, pan_img)`` float32 tensors.
    """
    ms_img = tf.py_function(decode_geotiff, [ms_tile_path], [tf.int16])

    # Match the `ms` directory between either kind of separator so pairing
    # works with Windows ("\") as well as POSIX ("/") paths.
    pan_tile_path = tf.strings.regex_replace(
        ms_tile_path, r'([\\/])ms([\\/])', r'\1pan\2')
    pan_img = tf.py_function(decode_geotiff, [pan_tile_path], [tf.int16])

    ms_img = preprocess_images(ms_img, 'ms')
    pan_img = preprocess_images(pan_img, 'pan')

    return ms_img, pan_img

# https://www.tensorflow.org/tutorials/load_data/images
def prepare_for_training(ds, batch_size, cache=True, shuffle_buffer_size=100):
    """Cache, shuffle, repeat forever, batch and prefetch a dataset.

    Args:
        ds: an unbatched ``tf.data.Dataset``.
        batch_size: number of elements per batch.
        cache: falsy to skip caching; a string to cache to that file (for data
            that does not fit in memory); any other truthy value to cache in
            memory.
        shuffle_buffer_size: size of the shuffle buffer.

    Returns:
        The transformed dataset. It repeats indefinitely, so consumers must
        bound iteration themselves (e.g. via ``steps_per_epoch``).
    """
    if cache:
        ds = ds.cache(cache) if isinstance(cache, str) else ds.cache()

    # Shuffling after the cache keeps the element order from being frozen by
    # it; prefetching lets batches be prepared while the model trains.
    return (
        ds.shuffle(buffer_size=shuffle_buffer_size)
        .repeat()
        .batch(batch_size)
        .prefetch(buffer_size=AUTOTUNE)
    )

In [None]:
def dataset_from_tif_tiles(tiles_path, batch_size, upsampling=False,
                           cache=True, shuffle_buffer_size=1000):
    """Build a batched (MS, PAN) tile dataset for one split directory.

    Args:
        tiles_path: split directory; MS tiles are found with the glob pattern
            ``*/ms*.tif`` relative to it.
        batch_size: number of tile pairs per batch.
        upsampling: if True, resize every MS tile to the PAN size with
            ``upsample_images`` (required for SRCNN).
        cache: forwarded to ``prepare_for_training``.
        shuffle_buffer_size: forwarded to ``prepare_for_training``.

    Returns:
        The dataset produced by ``prepare_for_training``.
    """
    pattern = str(pathlib.Path(tiles_path) / '*/ms*.tif')
    pipeline = tf.data.Dataset.list_files(pattern).map(
        process_path, num_parallel_calls=AUTOTUNE)

    if upsampling:
        pipeline = pipeline.map(upsample_images, num_parallel_calls=AUTOTUNE)

    return prepare_for_training(pipeline, batch_size, cache,
                                shuffle_buffer_size)

# SRCNN needs upsampling: every split is built with MS tiles resized to PAN size.
ds_train, ds_val, ds_test = (
    dataset_from_tif_tiles(split_path, BATCH_SIZE, upsampling=True)
    for split_path in (DATA_PATH_TILES_TRAIN, DATA_PATH_TILES_VAL,
                       DATA_PATH_TILES_TEST)
)

In [None]:
def show_batch(image_batch):
    """Plot up to 8 MS/PAN tile pairs from one batch on a 4x4 grid.

    Args:
        image_batch: an ``(ms, pan)`` pair of batched tensors, e.g.
            ``next(iter(ds_train))``. MS is shown as band 2 only, PAN as its
            first channel, both in grayscale.
    """
    ms = image_batch[0].numpy()
    pan = image_batch[1].numpy()
    print('ms batch shape', ms.shape)
    print('pan batch shape', pan.shape)

    # 8 pairs fill the 4x4 grid; never index past the end of the batch.
    n_pairs = min(8, ms.shape[0])

    plt.figure(figsize=(15, 15))
    for k in range(n_pairs):
        # Sample k uses two adjacent grid cells: MS left, PAN right.
        ax_ms = plt.subplot(4, 4, 2 * k + 1, label='ms')
        ax_ms.imshow(ms[k, :, :, 2], cmap='gray')  # Just showing channel 2
        ax_ms.set_title(f'MS #{k} (band 2)')

        ax_pan = plt.subplot(4, 4, 2 * k + 2, label='pan')
        ax_pan.imshow(pan[k, :, :, 0], cmap='gray')
        ax_pan.set_title(f'PAN #{k}')

show_batch(next(iter(ds_train)))

# Build model

In [None]:
def build_srcnn(channels_in, channels_out, learning_rate=0.0003):
    """Build and compile an SRCNN-style Keras model.

    Three convolutions: 9x9 (128 filters, ReLU), 1x1 (64 filters, ReLU) and
    5x5 (``channels_out`` filters, linear). ``padding='same'`` keeps the
    spatial size, so the input must already be at the target resolution.

    Args:
        channels_in: number of input bands.
        channels_out: number of output bands.
        learning_rate: Adam learning rate (default matches the previous
            hard-coded value).

    Returns:
        A ``Sequential`` model compiled with Adam and mean-squared-error loss.
    """
    def _init():
        # Small-stddev Gaussian kernel init; a fresh instance per layer.
        return RandomNormal(mean=0.0, stddev=0.001, seed=None)

    srcnn = Sequential()

    srcnn.add(Conv2D(filters=128, kernel_size=(9, 9),
                     kernel_initializer=_init(),
                     bias_initializer='zeros',
                     activation='relu', padding='same', use_bias=True,
                     input_shape=(None, None, channels_in)))

    srcnn.add(Conv2D(filters=64, kernel_size=(1, 1),
                     kernel_initializer=_init(),
                     bias_initializer='zeros',
                     activation='relu', padding='same', use_bias=True))

    srcnn.add(Conv2D(filters=channels_out, kernel_size=(5, 5),
                     kernel_initializer=_init(),
                     bias_initializer='zeros',
                     activation='linear', padding='same', use_bias=True))

    # `learning_rate` replaces the deprecated `lr` keyword, which newer Keras
    # releases reject.
    adam = Adam(learning_rate=learning_rate)

    srcnn.compile(optimizer=adam, loss='mean_squared_error',
                  metrics=['mean_squared_error'])

    return srcnn

srcnn = build_srcnn(channels_in = MS_BANDS, channels_out = 1)
srcnn.summary()

# Train model

In [None]:
# The datasets repeat indefinitely, so step counts per epoch and per
# validation run must be given explicitly.
history = srcnn.fit(
    ds_train,
    epochs=EPOCHS,
    steps_per_epoch=100,
    validation_data=ds_val,
    validation_steps=50,
)

In [None]:
#srcnn.load_weights('models/model2.h5')

In [None]:
# Training vs. validation loss per epoch.
for series in ('loss', 'val_loss'):
    plt.plot(history.history[series])
plt.title('Model loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(['Train', 'Validation'], loc='upper left')
plt.show()

In [None]:
def _as_single_channel_image(img):
    """Return *img* as a float32 tensor with a trailing channel axis of size 1."""
    return tf.convert_to_tensor(np.expand_dims(img, -1), dtype=tf.float32)


def psnr(img1, img2):
    """Peak signal-to-noise ratio between two images, assuming values in [0, 1].

    Inputs are array-likes without a channel axis (presumably single (H, W)
    images). A channel axis of size 1 is appended and both are cast to
    float32, so inputs of different float dtypes can be compared (as ``ssim``
    already allowed).

    Returns:
        The result of ``tf.image.psnr`` with ``max_val=1.0``.
    """
    return tf.image.psnr(_as_single_channel_image(img1),
                         _as_single_channel_image(img2), max_val=1.0)


def ssim(img1, img2):
    """Structural similarity between two images, assuming values in [0, 1].

    Inputs are handled exactly as in ``psnr``.

    Returns:
        The result of ``tf.image.ssim`` with ``max_val=1.0``.
    """
    return tf.image.ssim(_as_single_channel_image(img1),
                         _as_single_channel_image(img2), max_val=1.0)

In [None]:
def plot_comparisons(model, ds, n_comparisons):
    """Plot upsampled-MS baseline, model output and PAN ground truth.

    Takes the first batch of ``ds``, averages the MS bands into a
    single-band baseline, predicts with ``model`` on the MS input, and draws
    one 2x2 figure per sample. The PSNR / SSIM against the PAN tile are
    shown in the panel titles.

    Args:
        model: Keras model mapping an MS batch to an output whose first
            channel is compared with PAN.
        ds: dataset yielding ``(ms, pan)`` batches.
        n_comparisons: number of samples to plot; capped at the batch size.
    """
    ms_batch, pan_batch = next(iter(ds))

    # Band-averaged MS as a single-channel baseline comparable with PAN.
    ms = tf.math.reduce_mean(ms_batch, axis=-1).numpy()
    pan = pan_batch.numpy()[:, :, :, 0]
    # Feed only the model input rather than relying on Keras to unpack an
    # (x, y) tuple.
    sr = model.predict(ms_batch)[:, :, :, 0]

    print(ms.shape)
    print(pan.shape)
    print(sr.shape)

    cmap = 'gray'
    n_comparisons = min(n_comparisons, ms.shape[0])

    for i in range(n_comparisons):
        fig, axs = plt.subplots(2, 2, constrained_layout=True, figsize=(20, 20))
        fig.suptitle(f'Comparison {i}')

        for ax, label, img in ((axs[0, 0], 'MS Bicubic Upsampling', ms[i]),
                               (axs[0, 1], 'SRCNN', sr[i])):
            score_psnr = psnr(img, pan[i]).numpy()
            score_ssim = ssim(img, pan[i]).numpy()
            ax.set_title(f'{label} PSNR: {score_psnr:.2f} / SSIM: {score_ssim:.4f}')
            ax.imshow(img, cmap=cmap)

        axs[1, 0].set_title('PAN (Ground Truth)')
        axs[1, 0].imshow(pan[i], cmap=cmap)
        # The fourth panel is unused.
        axs[1, 1].axis('off')

In [None]:
# One comparison figure per sample of the first validation batch.
n_comparisons = 16
plot_comparisons(srcnn, ds_val, n_comparisons=n_comparisons)