# Experiment 05: Data augmentation

In this experiment we will study how data augmentation (flips and 90 degree rotations of the training tiles) affects the training process and end results for different band configurations.

### Experiment variations:

- E05-4: The 4 MS WorldView-2 bands that are also available in GeoEye-1 (MS1 array RGB+NIR1)
 - In this variation we will also validate performance on the GeoEye-1 validation set
- E05-3: Only the 3 RGB bands from WorldView-2
 - Also validated on the GeoEye-1 validation set
 
### The notebook is divided into the following main sections:
1. Imports and configuration parameters
2. Tile generation (sampling of tiles from the satellite images)
3. Tile input pipelines (`tf.data.Dataset` objects reading tiles from disk)
4. Building of models
5. Pretraining with L1 loss
6. Build the full ESRGAN model
7. GAN-training with L1 + Percep + GAN loss
8. Inspection of results

Training history is logged with TensorBoard.

## 1. Imports and configuration parameters

In [1]:
from modules.helpers import *
from modules.tile_generator import *
from modules.matlab_metrics import *
from modules.image_utils import *
from modules.tile_input_pipeline import *
from modules.models import *
from modules.evaluation import *

from modules.logging import *
from modules.train import *

import time

# Check the available GPUs and enable dynamic (on-demand) GPU memory use.
gpus = tf.config.experimental.list_physical_devices('GPU')
if len(gpus) > 0:
    try:
        # Memory growth prevents TensorFlow from allocating all memory of all GPUs:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        # Report how many physical and logical GPUs are visible.
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as err:
        # Best effort: report the problem and continue with the default memory behaviour.
        print(err)

1 Physical GPUs, 1 Logical GPUs


In [2]:
### MAIN SETTINGS ###############################################################################################
# Names of all experiment variations (the suffix is the number of MS bands used).
EXPERIMENT_NAMES = ['e05-8', 'e05-6', 'e05-4', 'e05-3']

# Select experiment variation to be run in THIS notebook:
EXPERIMENT = EXPERIMENT_NAMES[2]

# Turn on and off certain time consuming processes in the notebook:
GENERATE_TILES = False   # Should only be done once. All variations read tiles from the same directory
TILE_DENSITY_MAPS = False  # Loops through all tiles and compute density maps of where tiles have been sampled
CALCULATE_STATS = False  # Loops through all tiles and calculate mean and sd. Used for scaling
PRE_BUILD = True          # Build the generator model in its pretraining version
PRETRAIN = False          # Step 1 of the training process: pretraining
GAN_BUILD = True          # Build the full ESRGAN model (also needed for the evaluation further down)
GAN_TRAIN = False         # Step 2 of the training process: GAN training
PRE_EVALUATE_LAST = False     # Evaluate the final pretraining weights
GAN_EVALUATE_LAST = True      # Evaluate the final GAN-training weights
PRE_EVALUATE_HISTORY = False  # NOTE(review): in this notebook the two history flags only switch metric settings
GAN_EVALUATE_HISTORY = False

# Load metadata dataframe "meta" from repository root. 
# This dataframe keeps track of images and is used and updated throughout the notebook
meta = load_meta_pickle_csv('.', 'metadata_df', from_pickle=True)
#################################################################################################################

### PATHS #######################################################################################################
# Directory of the satellite images and directory of the tiles sampled from them.
DATA_PATH = 'data/toulon-laspezia'
DATA_PATH_TILES = 'data/toulon-laspezia-tiles/e05'
# One tile sub-directory per partition: <DATA_PATH_TILES>/train, /val and /test.
DATA_PATH_TILES_P = {partition: DATA_PATH_TILES + '/' + partition
                     for partition in ('train', 'val', 'test')}
# Path to tensorboard logs and model checkpoint saves; each variation gets its own sub-directory.
LOGS_DIR = 'logs/'
LOGS_EXP_DIR = ''.join((LOGS_DIR, EXPERIMENT))
#################################################################################################################

### TILE GENERATION #############################################################################################
# Sensors and areas that tiles are generated from; the metadata dataframe is subset with these two lists.
SENSORS_GENERATE = ['WV02', 'GE01']
AREAS_GENERATE = ['La_Spezia', 'Toulon']
meta = subset_by_areas_sensor(meta, areas=AREAS_GENERATE, sensors=SENSORS_GENERATE)
print('Sensors to generate tiles from:', SENSORS_GENERATE)
print('Areas to generate tiles from:', AREAS_GENERATE)

# Image counts, in total and per partition (train/val/test):
N_IMAGES_TOTAL = count_images(meta)
N_IMAGES = {partition: count_images_in_partition(meta, partition)
            for partition in ('train', 'val', 'test')}
# Sanity check: the two ways of counting must add up to the same number.
assert sum(N_IMAGES.values()) == N_IMAGES_TOTAL
print('Number of images in partitions', N_IMAGES)
print('Total number of images:', N_IMAGES_TOTAL)

# Tile density per partition, passed on to the tile allocation further down.
TILES_PER_M2 = dict.fromkeys(('train', 'val', 'test'), 2.0)

# Sea and cloud filtering of generated tiles.
# Every generated tile can be sent through a sea/cloud classifier. This helps when the images contain a lot of
# sea and clouds and the number of such monotone, less meaningful tiles should be reduced. The classifier was
# trained on 2500 labeled tiles of various sizes, where only tiles COMPLETELY covered by sea and/or clouds were
# labelled "cloud/sea". Validation accuracy around 0.95
CLOUD_SEA_REMOVAL = True
CLOUD_SEA_WEIGHTS_PATH = 'models/cloud-sea-classifier/cloudsea-effb0-augm-bicubic-pan-0.0005--200-0.127841.h5'
# Inference-time cutoff: tiles with a (quasi-)probability above it are classified as cloud and/or sea.
CLOUD_SEA_PRED_CUTOFF = 0.95
# Proportion of the cloud/sea tiles that is nevertheless let through the filter.
CLOUD_SEA_KEEP_RATE = 0.10

# GE01 images vary slightly in resolution (0.5 +-0.05 m per pixel) while WV02 is fixed at 0.5m.
# Setting this to True will resize to as close as possible to 0.5m.
# Not used in this notebook, but the function is ready for use in module tile_generator.py
RESIZE_TO_PIXEL_SIZE = False
if RESIZE_TO_PIXEL_SIZE:
    # These three settings only exist when resizing is switched on.
    NEW_PIXEL_SIZE_PAN = 0.5
    RESIZE_RESAMPLING_METHOD = 'nearest'  # 'nearest', 'bicubic', 'bilinear'
    RESIZE_DIR = DATA_PATH + '-resized'

# Data augmentation: flips (both up/down and left/right) and 90 degree rotations
AUGMENT_FLIP, AUGMENT_ROTATE = True, True
#################################################################################################################

### SENSORS AND AREA EXPERIMENT SELECTION #######################################################################
# Sensors per experiment variation and partition. Training uses a single sensor (a string);
# validation and test use a list of sensors.
SENSORS_EXP = {
    'e05-8': dict(train='WV02', val=['WV02'], test=['WV02']),
    'e05-6': dict(train='WV02', val=['WV02'], test=['WV02']),
    'e05-4': dict(train='WV02', val=['WV02', 'GE01'], test=['WV02', 'GE01']),
    'e05-3': dict(train='WV02', val=['WV02', 'GE01'], test=['WV02', 'GE01']),
}
SENSORS = SENSORS_EXP[EXPERIMENT]

# Areas per experiment variation and partition: every variation uses all generated areas everywhere.
AREAS_EXP = {variation: {partition: AREAS_GENERATE for partition in ('train', 'val', 'test')}
             for variation in ('e05-8', 'e05-6', 'e05-4', 'e05-3')}
AREAS = AREAS_EXP[EXPERIMENT]
#################################################################################################################

### TILE DIMENSIONS #############################################################################################
# Val and test tiles are larger than train tiles; this is needed for sensible Ma, NIQE and PI calculation.
SR_FACTOR = 4
MS_SIZE = dict(train=32, val=128, test=128)
# A PAN (HR) tile is SR_FACTOR times larger than the corresponding MS (LR) tile.
PAN_SIZE = {partition: ms_side * SR_FACTOR for partition, ms_side in MS_SIZE.items()}
print('MS (LR) tile size:', MS_SIZE)
print('PAN (HR) tile size:', PAN_SIZE)
print('SR factor:', SR_FACTOR)
#################################################################################################################

### BAND (CHANNEL) CONFIGURATIONS ###############################################################################
# This section only decides WHICH bands each variation uses.
# The selection of bands itself is done in the tile input pipeline.

# Selecting bands from the 8 bands of WV02 (band name -> band index):
WV02_FULL_BAND_CONFIG = get_sensor_bands('WV02', meta)
WV02_EXP_BAND_CONFIGS = {
    # 8 (all) bands
    'e05-8': WV02_FULL_BAND_CONFIG,
    # 6 bands (BGYR+RE+NIR)
    'e05-6': {band: idx for band, idx in WV02_FULL_BAND_CONFIG.items()
              if band not in ('Coastal', 'NIR2')},
    # 4 bands (BGR+NIR)
    'e05-4': {band: idx for band, idx in WV02_FULL_BAND_CONFIG.items()
              if band in ('Blue', 'Green', 'Red', 'NIR')},
    # 3 bands (BGR)
    'e05-3': {band: idx for band, idx in WV02_FULL_BAND_CONFIG.items()
              if band in ('Blue', 'Green', 'Red')},
}
MS_BANDS_WV02_CONFIG = WV02_EXP_BAND_CONFIGS[EXPERIMENT]
# 'all' means the e05-8 tiles are not passed through a band selection function (no reason to).
# The other variations need a list of the indices of the bands to be selected.
MS_BANDS_WV02_IDXS = 'all' if EXPERIMENT == 'e05-8' else list(MS_BANDS_WV02_CONFIG.values())

# The number of MS bands in this experiment variation
N_MS_BANDS = len(MS_BANDS_WV02_CONFIG)

# Selecting bands from the 4 bands of GE01 (band name -> band index):
GE01_FULL_BAND_CONFIG = get_sensor_bands('GE01', meta)
GE01_EXP_BAND_CONFIGS = {
    # not enough bands in GE01 for the 8 and 6 band variations
    'e05-8': {None: None},
    'e05-6': {None: None},
    # 4 (all) bands (BGR+NIR)
    'e05-4': GE01_FULL_BAND_CONFIG,
    # 3 bands (BGR)
    'e05-3': {band: idx for band, idx in GE01_FULL_BAND_CONFIG.items() if band != 'NIR'},
}
MS_BANDS_GE01_CONFIG = GE01_EXP_BAND_CONFIGS[EXPERIMENT]
# e05-4 uses every GE01 band, so no band selection is needed there either.
MS_BANDS_GE01_IDXS = 'all' if EXPERIMENT == 'e05-4' else list(MS_BANDS_GE01_CONFIG.values())
print('MS (LR) Band Config WV02:', MS_BANDS_WV02_CONFIG)
print('MS (LR) Band Config GE01:', MS_BANDS_GE01_CONFIG)

# Only 1 panchromatic band
N_PAN_BANDS = 1
#################################################################################################################

### MODEL PARAMETERS ############################################################################################
# Batch size per partition.
BATCH_SIZE = dict(train=16, val=8, test=8)
print('Batch sizes:', BATCH_SIZE)

# RRDB Generator Model parameters:
# - N_BLOCKS: deeper means potential to capture more complex relationships, at the cost of training time
# - N_FILTERS: baseline setting that is not tinkered with in this repository
N_BLOCKS, N_FILTERS = 16, 64
#################################################################################################################

### PRETRAINING SETTINGS ########################################################################################
# Length of the pretraining: number of epochs, and train/val steps per epoch.
PRE_EPOCHS = 400
PRE_TRAIN_STEPS = 1000
PRE_VAL_STEPS = 0
print('Pretraining - Total steps:', PRE_EPOCHS * PRE_TRAIN_STEPS)

# Number of batches to save every epoch in TensorBoard (train and val)
TRAIN_N_BATCHES_SAVE = VAL_N_BATCHES_SAVE = 1

# Optimizer settings. Loss and betas follow the official implementation; the learning rate was tuned and
# found stable for this particular experiment (the official value is 0.0002).
PRETRAIN_LOSS = 'l1'
PRETRAIN_LR = 5e-5
PRETRAIN_BETA_1, PRETRAIN_BETA_2 = 0.9, 0.999
# Note: The official implementation also uses a stepwise learning rate scheduler.
# It is left out here: "squeezing" out the last bit of performance is not central to the experiment and a
# scheduler complicates comparisons between experiment variations
#################################################################################################################

### GAN TRAINING SETTINGS #######################################################################################
# Length of the GAN training: number of epochs, and train/val steps per epoch.
GAN_EPOCHS = 400
GAN_TRAIN_STEPS = 1000
GAN_VAL_STEPS = 0
# Proportion of val batches that will go through ma and niqe metric calculation.
# The calculation is very time consuming (alternative setting kept for reference: 0.04)
MA_NIQE_PROPORTION = 1
print('GAN training - Total steps:', GAN_EPOCHS * GAN_TRAIN_STEPS)

# Weights for each loss in the composite loss function (all three are the official values)
G_LOSS_PIXEL_W, G_LOSS_PERCEP_W, G_LOSS_GENERATOR_W = 0.01, 1.0, 0.005

# Optimizer settings. The official learning rates are 1e-4 for both G and D; the betas are official.
GAN_G_LR = GAN_D_LR = 2e-5
G_BETA_1 = D_BETA_1 = 0.9
G_BETA_2 = D_BETA_2 = 0.999
# Note: The official implementation also uses a stepwise learning rate scheduler.
# It is left out here: "squeezing" out the last bit of performance is not central to the experiment and a
# scheduler complicates comparisons between experiment variations

# Path to the pretraining weights that is the starting point of GAN training.
# Only variations with an entry in the dictionary can be selected; others raise KeyError at the lookup.
PRETRAIN_WEIGHTS_DIRS = {'e05-4': LOGS_EXP_DIR + '/models/e05-4-pre_20210319-092729/'}
PRETRAIN_WEIGHTS_DIR = PRETRAIN_WEIGHTS_DIRS[EXPERIMENT]
PRETRAIN_WEIGHTS_PATH = '{}{}-pre-400.h5'.format(PRETRAIN_WEIGHTS_DIR, EXPERIMENT)

# Path to the gan-training weights that will be loaded for evaluation (same KeyError caveat as above).
GAN_WEIGHTS_DIRS = {'e05-4': LOGS_EXP_DIR + '/models/e05-4-gan_20210321-101445/'}
GAN_WEIGHTS_DIR = GAN_WEIGHTS_DIRS[EXPERIMENT]
GAN_WEIGHTS_PATH = '{}{}-gan-G-399.h5'.format(GAN_WEIGHTS_DIR, EXPERIMENT)
#################################################################################################################

### MATLAB METRICS ##############################################################################################
# Calculate Ma, NIQE and Perceptual Index (PI) metrics on the validation set(s) during GAN training.
# PI was the metric used in the PIRM2018 competition https://github.com/roimehrez/PIRM2018
METRIC_MA = False
METRIC_NIQE = False
# PI is only switched on when both Ma and NIQE are switched on.
METRIC_PI = METRIC_MA and METRIC_NIQE

# The number of pixels to be shaved off the border of the tile before calculating Ma/NIQE/PI
# (ignore border effects). Official value, as used in the PIRM2018 evaluation.
SHAVE_WIDTH = 4
# Ma/NIQE/PI calculation is done with official matlab repositories through MATLAB Engine API for Python;
# this is the path to those repositories.
MATLAB_PATH = 'modules/matlab'
#################################################################################################################

### EVALUATION ##################################################################################################
# Evaluating the last epoch switches on all MATLAB metrics (Ma, NIQE and therefore also PI).
if PRE_EVALUATE_LAST or GAN_EVALUATE_LAST:
    METRIC_MA, METRIC_NIQE, METRIC_PI = True, True, True


EVAL_STEPS_PER_EPOCH = 'all'
EVAL_N_EPOCHS = 400
#EVAL_SENSOR = 'WV02'
EVAL_PER_IMAGE = True

# Evaluating the training history keeps NIQE only; without Ma there is no PI.
# This takes precedence over the last-epoch settings above when both kinds of flags are set.
if PRE_EVALUATE_HISTORY or GAN_EVALUATE_HISTORY:
    METRIC_MA, METRIC_NIQE, METRIC_PI = False, True, False

#if PRE_EVALUATE_HISTORY:
#    EVAL_WEIGHTS_DIR = PRETRAIN_WEIGHTS_DIR
#    EVAL_FIRST_STEP = 1
#    EVAL_PREFIX = EXPERIMENT + '-pre-'
#elif GAN_EVALUATE_HISTORY:
#    EVAL_WEIGHTS_DIR = GAN_WEIGHTS_DIR
#    EVAL_FIRST_STEP = 0
#    EVAL_PREFIX = EXPERIMENT + '-gan-'

# Report the metric settings that are in effect after the overrides above.
print('MATLAB Metrics:')
print('Ma:', METRIC_MA)
print('NIQE:', METRIC_NIQE)
print('Perceptual Index (PI):', METRIC_PI)

Sensors to generate tiles from: ['WV02', 'GE01']
Areas to generate tiles from: ['La_Spezia', 'Toulon']
Number of images in partitions {'train': 22, 'val': 19, 'test': 21}
Total number of images: 62
MS (LR) tile size: {'train': 32, 'val': 128, 'test': 128}
PAN (HR) tile size: {'train': 128, 'val': 512, 'test': 512}
SR factor: 4
MS (LR) Band Config WV02: {'Blue': 1, 'Green': 2, 'Red': 4, 'NIR': 6}
MS (LR) Band Config GE01: {'Blue': 0, 'Green': 1, 'Red': 2, 'NIR': 3}
Batch sizes: {'train': 16, 'val': 8, 'test': 8}
Pretraining - Total steps: 400000
GAN training - Total steps: 400000
MATLAB Metrics:
Ma: True
NIQE: True
Perceptual Index (PI): True


## 2. Tile generation

### 2.1 Image resizing

Function `resize_sat_img_to_new_pixel_size` available in `modules.tile_generator`. Not used in this notebook

### 2.2 Tile allocation

We allocate `n_tiles` to each satellite image in proportion to the area covered by the satellite image. We adjust `n_tiles` by the argument `tiles_per_m2`. If `tiles_per_m2=1.0` then `n_tiles` is set deterministically to a value so that a square meter of satellite image is expected to be covered by `1.0` tile.

In [3]:
if not GENERATE_TILES:
    # Load meta dataframe that was updated at tile generation time
    meta = load_meta_pickle_csv(DATA_PATH_TILES, 'metadata_tile_allocation', from_pickle=True)
else:
    # Allocate tiles to every image; the result is stored in the new column 'n_tiles'.
    # Both tuple arguments are ordered (train, val, test).
    meta = allocate_tiles_by_expected(
        meta,
        override_pan_pixel_size=RESIZE_TO_PIXEL_SIZE,
        by_partition=True,
        tiles_per_m2_train_val_test=tuple(TILES_PER_M2[part] for part in ('train', 'val', 'test')),
        pan_tile_size_train_val_test=tuple(PAN_SIZE[part] for part in ('train', 'val', 'test')),
        new_column_name='n_tiles')

# Tile counts per partition and in total; the two ways of counting must agree.
n_tiles = {part: count_tiles_in_partition(meta, part) for part in ('train', 'val', 'test')}
n_tiles_total = count_tiles(meta)
assert sum(n_tiles.values()) == n_tiles_total
print('Number of tiles per partition:')
print(n_tiles)
print('Total number of tiles:', n_tiles_total)

Number of tiles per partition:
{'train': 129221, 'val': 8113, 'test': 9293}
Total number of tiles: 146627


### 2.3 Tile generation to disk

In [4]:
if GENERATE_TILES:
    # Write the tiles of all partitions to DATA_PATH_TILES, applying the cloud/sea filter settings
    # configured above. The MS tile sizes are passed in (train, val, test) order.
    meta = generate_all_tiles(
        meta,
        save_dir=DATA_PATH_TILES,
        sr_factor=SR_FACTOR,
        by_partition=True,
        ms_tile_size_train_val_test=tuple(MS_SIZE[part] for part in ('train', 'val', 'test')),
        cloud_sea_removal=CLOUD_SEA_REMOVAL,
        cloud_sea_weights_path=CLOUD_SEA_WEIGHTS_PATH,
        cloud_sea_pred_cutoff=CLOUD_SEA_PRED_CUTOFF,
        cloud_sea_keep_rate=CLOUD_SEA_KEEP_RATE,
        save_meta_to_disk=True)

In [5]:
if TILE_DENSITY_MAPS:
    # iterrows() yields (index, row) pairs; the index is used as image uid and as output file name.
    for img_uid, img_meta in meta.iterrows():
        density = tile_density_map(
            DATA_PATH_TILES,
            img_meta,
            pan_or_ms='pan',
            density_dtype='uint8',
            write_to_disk=True,
            write_dir=DATA_PATH_TILES + '/density-maps',
            write_filename=img_uid)
    # Only the density map of the last image is plotted
    plt.imshow(density)

In [6]:
# Mean and sd of the training tiles (used for scaling): either read the values stored earlier as json,
# or recompute them from all tiles (sample_proportion=1.0) and write them to json.
if not CALCULATE_STATS:
    train_tiles_mean, train_tiles_sd = read_mean_sd_json(DATA_PATH_TILES)
else:
    train_tiles_mean, train_tiles_sd = mean_sd_of_train_tiles(
        DATA_PATH_TILES,
        sample_proportion=1.0,
        write_json=True)

Loaded mean 341.3 and sd 128.4 from json file @ data/toulon-laspezia-tiles/e05/train_mean_sd.json


## 3. Data input pipeline from disk

### 3.1 Training set

In [7]:
# Shuffle buffers are as large as the number of tiles in each partition.
SHUFFLE_BUFFER_SIZE = {'train': n_tiles['train'],  # 100
                       'val': n_tiles['val'],  # 100
                       'test': n_tiles['test']}  # 100

# The training set holds a single sensor and is organized like the validation set: {sensor: dataset}
train_val_test = 'train'
sensor = SENSORS[train_val_test]
ds_train = {sensor: GeotiffDataset(tiles_path=DATA_PATH_TILES_P[train_val_test], 
                                   batch_size=BATCH_SIZE[train_val_test], 
                                   ms_tile_shape=(MS_SIZE[train_val_test], MS_SIZE[train_val_test], N_MS_BANDS), 
                                   pan_tile_shape=(PAN_SIZE[train_val_test], PAN_SIZE[train_val_test], N_PAN_BANDS),
                                   sensor=sensor,
                                   band_selection=MS_BANDS_WV02_IDXS, 
                                   mean_correction=train_tiles_mean,
                                   cache_memory=True,
                                   cache_file=str(DATA_PATH_TILES + '/ds_' + EXPERIMENT + '-' 
                                                  + train_val_test + '-' + sensor + '_cache'), 
                                   repeat=True, 
                                   shuffle=True, 
                                   shuffle_buffer_size=SHUFFLE_BUFFER_SIZE[train_val_test],
                                   # Data augmentation follows the AUGMENT_* settings of the configuration
                                   # cell instead of hard-coded values, so those settings take effect.
                                   augment_flip=AUGMENT_FLIP,
                                   augment_rotate=AUGMENT_ROTATE
                                  )
           }
# Getting the scaled output range from the scaler. Needed to calculate PSNR and SSIM:
scaled_range = ds_train[sensor].get_scaler_output_range(print_ranges=True)

# Returning the actual tf.data.dataset object:
ds_train[sensor] = ds_train[sensor].get_dataset()
print(ds_train.keys())

Scaler ranges:
Input (uint) min, max: 0 2047
Input (uint) range: 2048
Output (float) range 1.2006480509994506
Output (float) min, max: -0.2000617970682984 1.0
dict_keys(['WV02'])


### 3.2 Validation set

In [8]:
# Validation set can have several sensors and is organized in a dictionary
# structure: ds_val = {sensor: dataset} ... ex: ds_val = {'WV02': dataset_with_only_WV02_images}
train_val_test = 'val'
ds_val = {}
for sensor in SENSORS[train_val_test]:
    # Each sensor has its own band layout and therefore its own band selection.
    if sensor == 'WV02':
        band_indices = MS_BANDS_WV02_IDXS
    elif sensor == 'GE01':
        band_indices = MS_BANDS_GE01_IDXS
    else:
        # Fail loudly instead of silently reusing the band selection of a previous sensor.
        raise ValueError('No band selection configured for sensor: ' + str(sensor))
    ds_val[sensor] = GeotiffDataset(tiles_path=DATA_PATH_TILES_P[train_val_test], 
                                    batch_size=BATCH_SIZE[train_val_test], 
                                    ms_tile_shape=(MS_SIZE[train_val_test], MS_SIZE[train_val_test], N_MS_BANDS), 
                                    pan_tile_shape=(PAN_SIZE[train_val_test], PAN_SIZE[train_val_test], N_PAN_BANDS),
                                    sensor=sensor,
                                    band_selection=band_indices, 
                                    mean_correction=train_tiles_mean,
                                    cache_memory=True,
                                    cache_file=str(DATA_PATH_TILES + '/ds_' + EXPERIMENT + '-'
                                                   + train_val_test + '-' + sensor + '_cache'), 
                                    repeat=True, 
                                    shuffle=True, 
                                    shuffle_buffer_size=SHUFFLE_BUFFER_SIZE[train_val_test])
    # Keep only the actual tf.data dataset object
    ds_val[sensor] = ds_val[sensor].get_dataset()
print(ds_val.keys())

dict_keys(['WV02', 'GE01'])


## 4. Build preliminary models

### 4.1 Bicubic baseline model

In [9]:
# Baseline model built with resize_method='bicubic' and the same upsampling factor as the SR task.
# It gets the same PSNR/SSIM metrics and scaled_range as the generator below.
bicubic = build_deterministic_sr_model(upsample_factor=SR_FACTOR,
                                       resize_method='bicubic',
                                       loss='mean_absolute_error',
                                       metrics=('PSNR', 'SSIM'),
                                       scaled_range=scaled_range)

### 4.2 ESRGAN Generator model (pretrain version)

In [10]:
# Build the generator in its pretraining version (only when PRE_BUILD is switched on).
# Optimizer and loss settings come from the PRETRAINING SETTINGS of the configuration cell.
if PRE_BUILD:
    pretrain_model =  build_generator(pretrain_or_gan='pretrain', 
                                      pretrain_learning_rate=PRETRAIN_LR, 
                                      pretrain_loss_l1_l2=PRETRAIN_LOSS,
                                      pretrain_beta_1=PRETRAIN_BETA_1, 
                                      pretrain_beta_2=PRETRAIN_BETA_2, 
                                      pretrain_metrics=('PSNR', 'SSIM'),
                                      scaled_range=scaled_range,  # scaler output range, needed for PSNR and SSIM
                                      n_channels_in=N_MS_BANDS,  # number of MS bands of this variation
                                      n_channels_out=N_PAN_BANDS,  # the single panchromatic band
                                      height_width_in=None,  # None will make network image size agnostic
                                      n_filters=N_FILTERS, 
                                      n_blocks=N_BLOCKS)
    # pretrain_model.summary()

## 5. Pretraining with L1 loss

In [11]:
# Step 1 of the training process: pretraining of the generator (only when PRETRAIN is switched on).
# TensorBoard logs go to LOGS_EXP_DIR + '/tb' and model weights to LOGS_EXP_DIR + '/models'.
if PRETRAIN:
    history = pretrain_esrgan(generator=pretrain_model,
                              ds_train_dict=ds_train,
                              epochs=PRE_EPOCHS,
                              steps_per_epoch=PRE_TRAIN_STEPS,
                              initial_epoch=0,
                              validate=True,
                              ds_val_dict=ds_val,
                              val_steps=PRE_VAL_STEPS,
                              model_name=EXPERIMENT + '-pre',
                              tag=EXPERIMENT,
                              log_tensorboard=True,
                              tensorboard_logs_dir=LOGS_EXP_DIR + '/tb',
                              save_models=True,
                              models_save_dir=LOGS_EXP_DIR + '/models',
                              save_weights_only=True,
                              log_train_images=True,
                              n_train_image_batches=TRAIN_N_BATCHES_SAVE,
                              log_val_images=True,
                              n_val_image_batches=VAL_N_BATCHES_SAVE)

## 6. Build the full ESRGAN Model

In [12]:
# Build the full ESRGAN model (only when GAN_BUILD is switched on).
# The pretraining weights at PRETRAIN_WEIGHTS_PATH are passed in as the first argument;
# they are the starting point of GAN training.
if GAN_BUILD:
    gan_model = build_esrgan_model(PRETRAIN_WEIGHTS_PATH,
                                   n_channels_in=N_MS_BANDS, 
                                   n_channels_out=N_PAN_BANDS, 
                                   n_filters=N_FILTERS, 
                                   n_blocks=N_BLOCKS, 
                                   pan_shape=(PAN_SIZE['train'], PAN_SIZE['train'], N_PAN_BANDS),
                                   G_lr=GAN_G_LR, 
                                   D_lr=GAN_D_LR, 
                                   G_beta_1=G_BETA_1, 
                                   G_beta_2=G_BETA_2, 
                                   D_beta_1=D_BETA_1, 
                                   D_beta_2=D_BETA_2,
                                   G_loss_pixel_w=G_LOSS_PIXEL_W, 
                                   G_loss_pixel_l1_l2='l1',
                                   G_loss_percep_w=G_LOSS_PERCEP_W, 
                                   G_loss_percep_l1_l2='l1', 
                                   G_loss_percep_layer=54,
                                   G_loss_percep_before_act=True,
                                   G_loss_generator_w=G_LOSS_GENERATOR_W,
                                   metric_reg=False, 
                                   metric_ma=METRIC_MA, 
                                   metric_niqe=METRIC_NIQE, 
                                   ma_niqe_proportion=MA_NIQE_PROPORTION,
                                   # Use the configured path to the MATLAB repositories rather than
                                   # repeating the literal, so the setting has a single source of truth.
                                   matlab_wd_path=MATLAB_PATH,
                                   scale_mean=train_tiles_mean, 
                                   scaled_range=scaled_range, 
                                   shave_width=SHAVE_WIDTH)

Starting matlab.engine ...
matlab.engine started


## 7. GAN training

In [13]:
# Step 2 of the training process: GAN training of the full ESRGAN model (only when GAN_TRAIN is switched on).
# TensorBoard logs go to LOGS_EXP_DIR + '/tb' and model weights to LOGS_EXP_DIR + '/models'.
if GAN_TRAIN:
    history = gan_train_esrgan(esrgan_model=gan_model,
                               ds_train_dict=ds_train,
                               epochs=GAN_EPOCHS,
                               steps_per_epoch=GAN_TRAIN_STEPS,
                               initial_epoch=0,
                               validate=True,
                               ds_val_dict=ds_val,
                               val_steps=GAN_VAL_STEPS,
                               model_name=EXPERIMENT + '-gan',
                               tag=EXPERIMENT,
                               log_tensorboard=True,
                               tensorboard_logs_dir=LOGS_EXP_DIR + '/tb',
                               save_models=True,
                               models_save_dir=LOGS_EXP_DIR + '/models',
                               save_weights_only=True,
                               log_train_images=True,
                               n_train_image_batches=TRAIN_N_BATCHES_SAVE,
                               log_val_images=True,
                               n_val_image_batches=VAL_N_BATCHES_SAVE)

## 8. Evaluation

### 8.1 Data input pipelines for final evaluation

The pipeline is modified to include the file paths of the tiles/patches so that it is possible to log performance metrics for individual files and by extension for individual satellite images.

#### 8.1.1 Validation set

In [14]:
# Validation set can have several sensors and is organized in a dictionary
# structure: ds_val = {sensor: dataset} ... ex: ds_val = {'WV02': dataset_with_only_WV02_images}
# This evaluation version is neither repeated nor shuffled and includes the file paths of the tiles.
train_val_test = 'val'
ds_val = {}
for sensor in SENSORS[train_val_test]:
    # Each sensor has its own band layout and therefore its own band selection.
    if sensor == 'WV02':
        band_indices = MS_BANDS_WV02_IDXS
    elif sensor == 'GE01':
        band_indices = MS_BANDS_GE01_IDXS
    else:
        # Fail loudly instead of silently reusing the band selection of a previous sensor.
        raise ValueError('No band selection configured for sensor: ' + str(sensor))
    ds_val[sensor] = GeotiffDataset(tiles_path=DATA_PATH_TILES_P[train_val_test], 
                                    batch_size=BATCH_SIZE[train_val_test], 
                                    ms_tile_shape=(MS_SIZE[train_val_test], MS_SIZE[train_val_test], N_MS_BANDS), 
                                    pan_tile_shape=(PAN_SIZE[train_val_test], PAN_SIZE[train_val_test], N_PAN_BANDS),
                                    sensor=sensor,
                                    band_selection=band_indices, 
                                    mean_correction=train_tiles_mean,
                                    cache_memory=False,
                                    cache_file=str(DATA_PATH_TILES + '/ds_' + EXPERIMENT + '-'
                                                   + train_val_test + '-' + sensor + '_filepath_cache'), 
                                    repeat=False, 
                                    shuffle=False, 
                                    shuffle_buffer_size=0, #SHUFFLE_BUFFER_SIZE[train_val_test], 
                                    include_file_paths=True)
    # Keep only the actual tf.data dataset object
    ds_val[sensor] = ds_val[sensor].get_dataset()
print(ds_val.keys())

dict_keys(['WV02', 'GE01'])


#### 8.1.2 Test set

In [15]:
# Test set, organized like the validation set: ds_test = {sensor: dataset}
# NOTE(review): unlike the validation pipeline for evaluation, include_file_paths is not passed here.
# Confirm whether per-image evaluation on the test set needs it.
train_val_test = 'test'
ds_test = {}
for sensor in SENSORS[train_val_test]:
    # Each sensor has its own band layout and therefore its own band selection.
    if sensor == 'WV02':
        band_indices = MS_BANDS_WV02_IDXS
    elif sensor == 'GE01':
        band_indices = MS_BANDS_GE01_IDXS
    else:
        # Fail loudly instead of silently reusing the band selection of a previous sensor.
        raise ValueError('No band selection configured for sensor: ' + str(sensor))
    ds_test[sensor] = GeotiffDataset(tiles_path=DATA_PATH_TILES_P[train_val_test], 
                                     batch_size=BATCH_SIZE[train_val_test], 
                                     ms_tile_shape=(MS_SIZE[train_val_test], MS_SIZE[train_val_test], N_MS_BANDS), 
                                     pan_tile_shape=(PAN_SIZE[train_val_test], PAN_SIZE[train_val_test], N_PAN_BANDS),
                                     sensor=sensor,
                                     band_selection=band_indices, 
                                     mean_correction=train_tiles_mean,
                                     cache_memory=False,
                                     cache_file=str(DATA_PATH_TILES + '/ds_' + EXPERIMENT + '-'
                                                    + train_val_test + '-' + sensor + '_filepath_cache'), 
                                     repeat=False, 
                                     shuffle=False, 
                                     shuffle_buffer_size=0)
    # Keep only the actual tf.data dataset object
    ds_test[sensor] = ds_test[sensor].get_dataset()
print(ds_test.keys())

dict_keys(['WV02', 'GE01'])


### 8.2 Evaluate last epoch

In [None]:
# Evaluate the final weights of the pretrained and/or GAN-trained generator
# on the validation set of every sensor, writing one per-image CSV file per
# (training stage, sensor) combination.
if PRE_EVALUATE_LAST or GAN_EVALUATE_LAST:
    val_or_test = 'val'

    # Derive the training stages to evaluate directly from the two flags.
    # Previously PRE_GAN was only assigned when exactly one flag was set, so
    # enabling both left it undefined (or stale from an earlier cell).
    # NOTE: computing Ma is ~100x more time consuming than anything else and
    # is not interesting to measure for pretraining; consider disabling
    # PRE_EVALUATE_LAST when the Ma metric is enabled.
    PRE_GAN = []
    if PRE_EVALUATE_LAST:
        PRE_GAN.append('pre')
    if GAN_EVALUATE_LAST:
        PRE_GAN.append('gan')

    print(PRE_GAN, SENSORS[val_or_test])
    for pre_gan in PRE_GAN:
        # Weights file matching the training stage being evaluated.
        weights_path = PRETRAIN_WEIGHTS_PATH if pre_gan == 'pre' else GAN_WEIGHTS_PATH

        for sensor in SENSORS[val_or_test]:
            # Weights are (re)loaded before every evaluation run, exactly as
            # before, so each run starts from the stored generator weights.
            gan_model.G.load_weights(weights_path)

            print(pre_gan, sensor)
            start = time.time()
            results_df = esrgan_evaluate(
                model=gan_model,
                dataset=ds_val[sensor],
                steps='all',
                per_image=True,
                write_csv=True,
                csv_path=f"{LOGS_EXP_DIR}/csv/final_epoch-{pre_gan}-{val_or_test}-{sensor}.csv",
                verbose=1,
            )
            end = time.time()
            # Wall-clock duration of this evaluation run, in minutes.
            print(str((end - start) / 60), 'minutes')

['gan'] ['WV02', 'GE01']
gan WV02
Computed 8 images in  140.15203738212585 seconds
Last image: {'G_pixel_loss': 9.921810124069452e-05, 'G_perceptual_loss': 1.4626184701919556, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 34.548805236816406, 'SSIM': 0.9117075204849243, 'Ma': 4.679090976715088, 'NIQE': 3.850234270095825, 'PI': 4.585571765899658}
Computed 8 images in  131.30784010887146 seconds
Last image: {'G_pixel_loss': 4.8402453103335574e-05, 'G_perceptual_loss': 0.5312403440475464, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 38.11383056640625, 'SSIM': 0.977506697177887, 'Ma': 4.290518760681152, 'NIQE': 6.61899995803833, 'PI': 6.164240837097168}
Computed 8 images in  131.20135831832886 seconds
Last image: {'G_pixel_loss': 0.0002174215333070606, 'G_perceptual_loss': 3.167149305343628, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 27.46923065185547, 'SSIM': 0.8469409346580505, 'Ma': 4.4416003227233

Computed 8 images in  129.22005367279053 seconds
Last image: {'G_pixel_loss': 3.205088432878256e-05, 'G_perceptual_loss': 0.1835993230342865, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 49.07991027832031, 'SSIM': 0.9946539998054504, 'Ma': 3.0635979175567627, 'NIQE': 7.241084575653076, 'PI': 7.088743209838867}
Computed 8 images in  128.41788601875305 seconds
Last image: {'G_pixel_loss': 0.00010816437134053558, 'G_perceptual_loss': 1.120068907737732, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 37.71501541137695, 'SSIM': 0.8864721655845642, 'Ma': 4.522098064422607, 'NIQE': 4.123499393463135, 'PI': 4.800700664520264}
Computed 8 images in  128.85620999336243 seconds
Last image: {'G_pixel_loss': 8.614096441306174e-05, 'G_perceptual_loss': 0.8129532337188721, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 39.393821716308594, 'SSIM': 0.9188840985298157, 'Ma': 4.644932270050049, 'NIQE': 3.915651559829712, 

Computed 8 images in  123.51638579368591 seconds
Last image: {'G_pixel_loss': 0.00016746297478675842, 'G_perceptual_loss': 2.432255268096924, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 31.858972549438477, 'SSIM': 0.82475346326828, 'Ma': 4.662391185760498, 'NIQE': 3.992033004760742, 'PI': 4.664820671081543}
Computed 8 images in  123.2991452217102 seconds
Last image: {'G_pixel_loss': 0.0002284558431711048, 'G_perceptual_loss': 2.888357162475586, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 28.970216751098633, 'SSIM': 0.7890487909317017, 'Ma': 4.683557033538818, 'NIQE': 4.37243127822876, 'PI': 4.844437122344971}
Computed 8 images in  123.5530595779419 seconds
Last image: {'G_pixel_loss': 3.758089951588772e-05, 'G_perceptual_loss': 0.1443498432636261, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 48.92828369140625, 'SSIM': 0.9975225329399109, 'Ma': 2.9010279178619385, 'NIQE': 11.381860733032227, 'PI'

Computed 8 images in  123.69129943847656 seconds
Last image: {'G_pixel_loss': 0.00047308809007517993, 'G_perceptual_loss': 0.7748098373413086, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 28.00346565246582, 'SSIM': 0.8047928810119629, 'Ma': 3.53035306930542, 'NIQE': 6.120121955871582, 'PI': 6.29488468170166}
Computed 8 images in  123.5008053779602 seconds
Last image: {'G_pixel_loss': 0.0004873315920121968, 'G_perceptual_loss': 0.9271759390830994, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 27.545988082885742, 'SSIM': 0.7623817324638367, 'Ma': 4.46664571762085, 'NIQE': 5.439011573791504, 'PI': 5.486183166503906}
Computed 8 images in  123.26654815673828 seconds
Last image: {'G_pixel_loss': 7.744669710518792e-05, 'G_perceptual_loss': 0.9668660163879395, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 41.325103759765625, 'SSIM': 0.9252534508705139, 'Ma': 4.199149131774902, 'NIQE': 4.544976234436035, 'PI

Computed 8 images in  122.98172664642334 seconds
Last image: {'G_pixel_loss': 0.0001179037062684074, 'G_perceptual_loss': 1.5342493057250977, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 35.45677185058594, 'SSIM': 0.8452537655830383, 'Ma': 4.303503036499023, 'NIQE': 4.148168087005615, 'PI': 4.922332763671875}
Computed 8 images in  122.88932800292969 seconds
Last image: {'G_pixel_loss': 0.0001263733283849433, 'G_perceptual_loss': 1.7575498819351196, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 35.59084701538086, 'SSIM': 0.8237109780311584, 'Ma': 4.446253299713135, 'NIQE': 4.338284015655518, 'PI': 4.946015357971191}
Computed 8 images in  123.18253064155579 seconds
Last image: {'G_pixel_loss': 7.233025098685175e-05, 'G_perceptual_loss': 1.6158552169799805, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 38.21706771850586, 'SSIM': 0.9405562877655029, 'Ma': 4.790010452270508, 'NIQE': 4.1135969161987305, '

Computed 8 images in  123.307932138443 seconds
Last image: {'G_pixel_loss': 0.00010120978549821302, 'G_perceptual_loss': 1.1989761590957642, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 38.43614196777344, 'SSIM': 0.8682326078414917, 'Ma': 4.355616569519043, 'NIQE': 4.603621482849121, 'PI': 5.124002456665039}
Computed 8 images in  123.54318761825562 seconds
Last image: {'G_pixel_loss': 0.00018780704704113305, 'G_perceptual_loss': 2.369891881942749, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 31.878555297851562, 'SSIM': 0.7545826435089111, 'Ma': 4.3164286613464355, 'NIQE': 3.988067150115967, 'PI': 4.835819244384766}
Computed 8 images in  123.30418586730957 seconds
Last image: {'G_pixel_loss': 0.0001942677772603929, 'G_perceptual_loss': 3.0002822875976562, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 30.741214752197266, 'SSIM': 0.7877787947654724, 'Ma': 4.302739143371582, 'NIQE': 4.25567626953125, '

Computed 8 images in  120.3559033870697 seconds
Last image: {'G_pixel_loss': 5.564704042626545e-05, 'G_perceptual_loss': 0.602520227432251, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 36.883663177490234, 'SSIM': 0.9671584963798523, 'Ma': 4.601249694824219, 'NIQE': 5.052131652832031, 'PI': 5.225440979003906}
Computed 8 images in  119.78926014900208 seconds
Last image: {'G_pixel_loss': 0.00015493624960072339, 'G_perceptual_loss': 1.824330449104309, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 33.97593688964844, 'SSIM': 0.8114814162254333, 'Ma': 4.490235805511475, 'NIQE': 4.5312819480896, 'PI': 5.0205230712890625}
Computed 8 images in  120.14009523391724 seconds
Last image: {'G_pixel_loss': 0.00022874110436532646, 'G_perceptual_loss': 2.829629421234131, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 30.075960159301758, 'SSIM': 0.7437537908554077, 'Ma': 4.790683746337891, 'NIQE': 3.782520055770874, 'PI

Computed 8 images in  119.56316804885864 seconds
Last image: {'G_pixel_loss': 0.00013393412518780679, 'G_perceptual_loss': 1.90084707736969, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 33.49263000488281, 'SSIM': 0.8479577898979187, 'Ma': 4.695612907409668, 'NIQE': 3.9538841247558594, 'PI': 4.629135608673096}
Computed 8 images in  119.61244225502014 seconds
Last image: {'G_pixel_loss': 4.4195196096552536e-05, 'G_perceptual_loss': 0.4258836805820465, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 45.12042999267578, 'SSIM': 0.9748183488845825, 'Ma': 5.056529521942139, 'NIQE': 6.489216327667236, 'PI': 5.716343402862549}
Computed 8 images in  119.99645733833313 seconds
Last image: {'G_pixel_loss': 4.029331466881558e-05, 'G_perceptual_loss': 0.44507938623428345, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 44.913394927978516, 'SSIM': 0.9772647619247437, 'Ma': 5.037129878997803, 'NIQE': 6.873790264129639,

Computed 8 images in  119.94581460952759 seconds
Last image: {'G_pixel_loss': 0.00012126332148909569, 'G_perceptual_loss': 1.9105284214019775, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 35.32685852050781, 'SSIM': 0.8585326075553894, 'Ma': 4.463947772979736, 'NIQE': 5.006934642791748, 'PI': 5.271493434906006}
Computed 8 images in  119.99514245986938 seconds
Last image: {'G_pixel_loss': 2.9325792638701387e-05, 'G_perceptual_loss': 0.40948763489723206, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 47.03752136230469, 'SSIM': 0.9930523037910461, 'Ma': 3.5314106941223145, 'NIQE': 6.9149489402771, 'PI': 6.691769123077393}
Computed 8 images in  120.08771324157715 seconds
Last image: {'G_pixel_loss': 0.00010212225606665015, 'G_perceptual_loss': 0.7520250678062439, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 39.0391845703125, 'SSIM': 0.8892695903778076, 'Ma': 3.912792682647705, 'NIQE': 6.382262706756592, 

Computed 8 images in  120.27121758460999 seconds
Last image: {'G_pixel_loss': 0.00014320970512926579, 'G_perceptual_loss': 1.922573208808899, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 34.4127082824707, 'SSIM': 0.7985349297523499, 'Ma': 4.67747163772583, 'NIQE': 5.082469463348389, 'PI': 5.202498912811279}
Computed 8 images in  120.84319591522217 seconds
Last image: {'G_pixel_loss': 2.7894404411199503e-05, 'G_perceptual_loss': 0.4231269657611847, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 49.2363395690918, 'SSIM': 0.9904235601425171, 'Ma': 3.6739425659179688, 'NIQE': 6.741617679595947, 'PI': 6.53383731842041}
Computed 8 images in  120.97972846031189 seconds
Last image: {'G_pixel_loss': 0.0001231971400557086, 'G_perceptual_loss': 1.4464502334594727, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 35.8243522644043, 'SSIM': 0.8232472538948059, 'Ma': 4.397040367126465, 'NIQE': 5.142710208892822, 'PI':

Computed 8 images in  120.8314917087555 seconds
Last image: {'G_pixel_loss': 0.0001584184356033802, 'G_perceptual_loss': 1.8253241777420044, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 32.94615173339844, 'SSIM': 0.7967504858970642, 'Ma': 4.504445552825928, 'NIQE': 5.001407146453857, 'PI': 5.248480796813965}
Computed 8 images in  119.94977593421936 seconds
Last image: {'G_pixel_loss': 0.00012742204125970602, 'G_perceptual_loss': 1.6935265064239502, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 35.804508209228516, 'SSIM': 0.8178699612617493, 'Ma': 4.599576473236084, 'NIQE': 4.561769485473633, 'PI': 4.981096267700195}
Computed 8 images in  119.95992732048035 seconds
Last image: {'G_pixel_loss': 6.641718209721148e-05, 'G_perceptual_loss': 1.01634681224823, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 38.51555633544922, 'SSIM': 0.9377175569534302, 'Ma': 4.433074951171875, 'NIQE': 3.7516136169433594, 'P

Computed 8 images in  120.84027647972107 seconds
Last image: {'G_pixel_loss': 0.0001255602401215583, 'G_perceptual_loss': 1.2871732711791992, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 36.27839279174805, 'SSIM': 0.7953267693519592, 'Ma': 4.394082069396973, 'NIQE': 5.539848804473877, 'PI': 5.572883605957031}
Computed 8 images in  120.70609927177429 seconds
Last image: {'G_pixel_loss': 9.117678564507514e-05, 'G_perceptual_loss': 1.0233442783355713, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 40.015228271484375, 'SSIM': 0.8973127007484436, 'Ma': 3.839158296585083, 'NIQE': 5.170252323150635, 'PI': 5.6655473709106445}
Computed 8 images in  120.06927037239075 seconds
Last image: {'G_pixel_loss': 1.9417346265981905e-05, 'G_perceptual_loss': 0.436769962310791, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 52.35008239746094, 'SSIM': 0.9949457049369812, 'Ma': 2.944749355316162, 'NIQE': 8.588314056396484, 

Computed 8 images in  120.65051698684692 seconds
Last image: {'G_pixel_loss': 0.00022194162011146545, 'G_perceptual_loss': 2.892197608947754, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 30.782413482666016, 'SSIM': 0.7511066198348999, 'Ma': 4.567202568054199, 'NIQE': 3.91652774810791, 'PI': 4.6746625900268555}
Computed 8 images in  120.6726496219635 seconds
Last image: {'G_pixel_loss': 9.024287282954901e-05, 'G_perceptual_loss': 1.05653977394104, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 40.04295349121094, 'SSIM': 0.8907086849212646, 'Ma': 4.198223114013672, 'NIQE': 5.192537307739258, 'PI': 5.497157096862793}
Computed 8 images in  120.42351508140564 seconds
Last image: {'G_pixel_loss': 2.6983483621734194e-05, 'G_perceptual_loss': 0.5037476420402527, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 45.79457092285156, 'SSIM': 0.9912317395210266, 'Ma': 3.902679681777954, 'NIQE': 7.621434211730957, 'PI

Computed 8 images in  121.25870180130005 seconds
Last image: {'G_pixel_loss': 0.00010276260582031682, 'G_perceptual_loss': 1.5164811611175537, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 38.15976333618164, 'SSIM': 0.9104282855987549, 'Ma': 4.384437084197998, 'NIQE': 4.537520408630371, 'PI': 5.076541900634766}
Computed 8 images in  120.73612952232361 seconds
Last image: {'G_pixel_loss': 5.4656269639963284e-05, 'G_perceptual_loss': 0.6829650402069092, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 39.22709274291992, 'SSIM': 0.954283595085144, 'Ma': 4.816863059997559, 'NIQE': 4.949092388153076, 'PI': 5.06611442565918}
Computed 8 images in  120.9517297744751 seconds
Last image: {'G_pixel_loss': 3.893834582413547e-05, 'G_perceptual_loss': 0.3943135142326355, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 44.57278823852539, 'SSIM': 0.9776623845100403, 'Ma': 5.100257396697998, 'NIQE': 6.042533874511719, 'PI

Computed 8 images in  121.538489818573 seconds
Last image: {'G_pixel_loss': 5.421461901278235e-05, 'G_perceptual_loss': 1.0108872652053833, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 38.01115036010742, 'SSIM': 0.9683601260185242, 'Ma': 4.701779365539551, 'NIQE': 5.895245552062988, 'PI': 5.596733093261719}
Computed 8 images in  122.20273852348328 seconds
Last image: {'G_pixel_loss': 5.647423677146435e-05, 'G_perceptual_loss': 0.6858334541320801, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 34.82450485229492, 'SSIM': 0.9819942116737366, 'Ma': 4.243229866027832, 'NIQE': 7.5121684074401855, 'PI': 6.634469032287598}
Computed 8 images in  128.12172675132751 seconds
Last image: {'G_pixel_loss': 2.8776283215847798e-05, 'G_perceptual_loss': 0.21880176663398743, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 51.02225112915039, 'SSIM': 0.9972633123397827, 'Ma': 2.923737049102783, 'NIQE': 12.239630699157715, 

Computed 8 images in  125.1634464263916 seconds
Last image: {'G_pixel_loss': 7.711795478826389e-05, 'G_perceptual_loss': 1.0886433124542236, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 40.46367645263672, 'SSIM': 0.932102620601654, 'Ma': 4.219958782196045, 'NIQE': 5.03418493270874, 'PI': 5.407113075256348}
Computed 8 images in  133.2879877090454 seconds
Last image: {'G_pixel_loss': 0.0004802401235792786, 'G_perceptual_loss': 0.8891507983207703, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 27.525753021240234, 'SSIM': 0.7752285003662109, 'Ma': 4.553079605102539, 'NIQE': 5.506618022918701, 'PI': 5.47676944732666}
Computed 8 images in  130.67885541915894 seconds
Last image: {'G_pixel_loss': 0.00025126669788733125, 'G_perceptual_loss': 2.0657236576080322, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 30.79701042175293, 'SSIM': 0.8244332671165466, 'Ma': 4.899949073791504, 'NIQE': 4.646131992340088, 'PI':

Computed 8 images in  127.9337375164032 seconds
Last image: {'G_pixel_loss': 0.00020497506193351, 'G_perceptual_loss': 3.2103664875030518, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 31.091747283935547, 'SSIM': 0.8420071005821228, 'Ma': 4.6519670486450195, 'NIQE': 3.832648754119873, 'PI': 4.590340614318848}
Computed 8 images in  128.2501037120819 seconds
Last image: {'G_pixel_loss': 0.00014223391190171242, 'G_perceptual_loss': 2.2123587131500244, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 33.948951721191406, 'SSIM': 0.8710927963256836, 'Ma': 4.760298728942871, 'NIQE': 4.332188606262207, 'PI': 4.785944938659668}
Computed 8 images in  129.52086281776428 seconds
Last image: {'G_pixel_loss': 0.000490322767291218, 'G_perceptual_loss': 0.8258810639381409, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 27.49958038330078, 'SSIM': 0.8029786944389343, 'Ma': 3.9326770305633545, 'NIQE': 8.0248441696167, 'PI'

Computed 8 images in  121.34143424034119 seconds
Last image: {'G_pixel_loss': 0.00017154899251181632, 'G_perceptual_loss': 2.9938437938690186, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 30.098421096801758, 'SSIM': 0.8939118385314941, 'Ma': 4.594542026519775, 'NIQE': 3.726038932800293, 'PI': 4.56574821472168}
Computed 8 images in  121.60379981994629 seconds
Last image: {'G_pixel_loss': 0.00019358577264938504, 'G_perceptual_loss': 2.831820487976074, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 30.380264282226562, 'SSIM': 0.8241513967514038, 'Ma': 4.582991123199463, 'NIQE': 3.5563435554504395, 'PI': 4.486676216125488}
Computed 8 images in  121.45103049278259 seconds
Last image: {'G_pixel_loss': 0.00022916629677638412, 'G_perceptual_loss': 3.7652664184570312, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 28.453359603881836, 'SSIM': 0.8235191702842712, 'Ma': 4.6647820472717285, 'NIQE': 3.3978469371795

Computed 8 images in  120.00952172279358 seconds
Last image: {'G_pixel_loss': 1.9936000171583146e-05, 'G_perceptual_loss': 0.24413560330867767, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 45.60480499267578, 'SSIM': 0.9969000816345215, 'Ma': 3.0671708583831787, 'NIQE': 10.675905227661133, 'PI': 8.804367065429688}
Computed 8 images in  120.96543264389038 seconds
Last image: {'G_pixel_loss': 1.365947082376806e-05, 'G_perceptual_loss': 0.2072916179895401, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 56.41897964477539, 'SSIM': 0.9981446862220764, 'Ma': 2.9129722118377686, 'NIQE': 11.202659606933594, 'PI': 9.144844055175781}
Computed 8 images in  122.0546932220459 seconds
Last image: {'G_pixel_loss': 0.0002253275306429714, 'G_perceptual_loss': 3.3195347785949707, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 27.31139373779297, 'SSIM': 0.8726480007171631, 'Ma': 4.73441743850708, 'NIQE': 3.802421331405639

Computed 8 images in  128.4503824710846 seconds
Last image: {'G_pixel_loss': 0.000355055759428069, 'G_perceptual_loss': 0.4514385163784027, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 30.50189208984375, 'SSIM': 0.9455184936523438, 'Ma': 2.933380126953125, 'NIQE': 11.070639610290527, 'PI': 9.06863021850586}
Computed 8 images in  131.31293964385986 seconds
Last image: {'G_pixel_loss': 0.0005910327890887856, 'G_perceptual_loss': 2.476684331893921, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 23.299072265625, 'SSIM': 0.7812517285346985, 'Ma': 4.847098350524902, 'NIQE': 4.395219326019287, 'PI': 4.774060249328613}
Computed 8 images in  131.6911280155182 seconds
Last image: {'G_pixel_loss': 0.0010244438890367746, 'G_perceptual_loss': 4.574019908905029, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 20.290264129638672, 'SSIM': 0.3136768043041229, 'Ma': 4.5041279792785645, 'NIQE': 3.6405136585235596, 'PI': 

Computed 8 images in  124.95788788795471 seconds
Last image: {'G_pixel_loss': 0.0009703509276732802, 'G_perceptual_loss': 4.772510051727295, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 19.773836135864258, 'SSIM': 0.44108787178993225, 'Ma': 4.433946132659912, 'NIQE': 3.274780035018921, 'PI': 4.420416831970215}
Computed 8 images in  122.44818043708801 seconds
Last image: {'G_pixel_loss': 0.0005325590027496219, 'G_perceptual_loss': 1.6027761697769165, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 24.907939910888672, 'SSIM': 0.808749794960022, 'Ma': 4.48066520690918, 'NIQE': 4.869050979614258, 'PI': 5.194192886352539}
Computed 8 images in  121.93722105026245 seconds
Last image: {'G_pixel_loss': 0.00037364367744885385, 'G_perceptual_loss': 0.8052957653999329, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 29.476640701293945, 'SSIM': 0.932326078414917, 'Ma': 3.3842356204986572, 'NIQE': 9.651570320129395, 

Computed 8 images in  123.33107209205627 seconds
Last image: {'G_pixel_loss': 0.00039938942063599825, 'G_perceptual_loss': 0.8194326162338257, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 29.52374839782715, 'SSIM': 0.8256050944328308, 'Ma': 2.8943984508514404, 'NIQE': 10.129556655883789, 'PI': 8.617578506469727}
Computed 8 images in  121.74995708465576 seconds
Last image: {'G_pixel_loss': 0.0009261431987397373, 'G_perceptual_loss': 4.015405654907227, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 21.471420288085938, 'SSIM': 0.4204094409942627, 'Ma': 4.711967468261719, 'NIQE': 3.767117738723755, 'PI': 4.5275750160217285}
Computed 8 images in  119.98291349411011 seconds
Last image: {'G_pixel_loss': 0.0008341963985003531, 'G_perceptual_loss': 3.808436155319214, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 22.268348693847656, 'SSIM': 0.3424900770187378, 'Ma': 4.449158191680908, 'NIQE': 3.945773363113403

Computed 8 images in  120.79142904281616 seconds
Last image: {'G_pixel_loss': 0.0007272413349710405, 'G_perceptual_loss': 2.5389163494110107, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 23.681310653686523, 'SSIM': 0.10639496892690659, 'Ma': 4.429962158203125, 'NIQE': 4.270447254180908, 'PI': 4.9202423095703125}
Computed 8 images in  121.07306551933289 seconds
Last image: {'G_pixel_loss': 0.0008383747772313654, 'G_perceptual_loss': 3.9619216918945312, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 22.087217330932617, 'SSIM': 0.19014783203601837, 'Ma': 4.517958164215088, 'NIQE': 4.335137844085693, 'PI': 4.908589839935303}
Computed 8 images in  121.7188196182251 seconds
Last image: {'G_pixel_loss': 0.0006786753074266016, 'G_perceptual_loss': 2.9930241107940674, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 23.60735321044922, 'SSIM': 0.5595707297325134, 'Ma': 4.81378173828125, 'NIQE': 3.7160770893096924

Computed 8 images in  122.30056881904602 seconds
Last image: {'G_pixel_loss': 0.00023257355496753007, 'G_perceptual_loss': 1.8354339599609375, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 31.70792007446289, 'SSIM': 0.7443394660949707, 'Ma': 4.300338268280029, 'NIQE': 4.767414569854736, 'PI': 5.2335381507873535}
Computed 8 images in  122.63335871696472 seconds
Last image: {'G_pixel_loss': 0.000125675811432302, 'G_perceptual_loss': 0.27611681818962097, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 39.517059326171875, 'SSIM': 0.9927166700363159, 'Ma': 2.9415643215179443, 'NIQE': 9.339662551879883, 'PI': 8.19904899597168}
Computed 8 images in  122.80134010314941 seconds
Last image: {'G_pixel_loss': 0.0002932145434897393, 'G_perceptual_loss': 1.0753358602523804, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 31.127662658691406, 'SSIM': 0.6276817917823792, 'Ma': 4.368685722351074, 'NIQE': 4.632490158081055

Computed 8 images in  127.82687497138977 seconds
Last image: {'G_pixel_loss': 0.0002097432443406433, 'G_perceptual_loss': 1.615799069404602, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 32.912166595458984, 'SSIM': 0.8247900009155273, 'Ma': 4.29816198348999, 'NIQE': 4.688397407531738, 'PI': 5.195117950439453}
Computed 8 images in  131.59180688858032 seconds
Last image: {'G_pixel_loss': 0.0002485198201611638, 'G_perceptual_loss': 0.8257766962051392, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 32.43252182006836, 'SSIM': 0.8447335958480835, 'Ma': 4.104452133178711, 'NIQE': 5.246312618255615, 'PI': 5.570930480957031}
Computed 8 images in  131.11864233016968 seconds
Last image: {'G_pixel_loss': 0.00021233964071143419, 'G_perceptual_loss': 0.9881817698478699, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 33.777000427246094, 'SSIM': 0.877316951751709, 'Ma': 4.305965900421143, 'NIQE': 4.466259479522705, 'P

Computed 8 images in  121.44638895988464 seconds
Last image: {'G_pixel_loss': 0.00034893531119450927, 'G_perceptual_loss': 3.140362024307251, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 27.841163635253906, 'SSIM': 0.5681357383728027, 'Ma': 4.604144096374512, 'NIQE': 4.583624839782715, 'PI': 4.989740371704102}
Computed 8 images in  122.2690052986145 seconds
Last image: {'G_pixel_loss': 0.00012785398575942963, 'G_perceptual_loss': 1.1027657985687256, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 37.443603515625, 'SSIM': 0.9534441828727722, 'Ma': 4.187110424041748, 'NIQE': 5.193690299987793, 'PI': 5.503290176391602}
Computed 8 images in  122.57385444641113 seconds
Last image: {'G_pixel_loss': 0.0003507344808895141, 'G_perceptual_loss': 0.9760646820068359, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 29.95828628540039, 'SSIM': 0.5509474277496338, 'Ma': 3.804765462875366, 'NIQE': 5.423925399780273, 'PI

Computed 8 images in  128.15016555786133 seconds
Last image: {'G_pixel_loss': 0.00014789022679906338, 'G_perceptual_loss': 1.2668899297714233, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 36.2098388671875, 'SSIM': 0.7591266632080078, 'Ma': 3.777698516845703, 'NIQE': 6.528995990753174, 'PI': 6.375648498535156}
Computed 8 images in  127.74032425880432 seconds
Last image: {'G_pixel_loss': 0.00015681167133152485, 'G_perceptual_loss': 1.9571490287780762, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 34.107418060302734, 'SSIM': 0.9225459098815918, 'Ma': 4.237111568450928, 'NIQE': 6.691256999969482, 'PI': 6.227072715759277}
Computed 8 images in  130.92233395576477 seconds
Last image: {'G_pixel_loss': 0.00015428618644364178, 'G_perceptual_loss': 1.2842119932174683, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 35.30174255371094, 'SSIM': 0.7663935422897339, 'Ma': 4.213319301605225, 'NIQE': 6.458624362945557,

### 8.3 Evaluate every kth epoch

In [None]:
# Run the epoch evaluator over stored weights for every selected training
# stage and every validation sensor; CSV output goes to one directory per
# (split, sensor) pair.
if PRE_EVALUATE_HISTORY or GAN_EVALUATE_HISTORY:
    val_or_test = 'val'

    # Computing Ma is 100x more time consuming than anything else, and it is
    # not interesting to measure for pretraining, so 'pre' is skipped when
    # the Ma metric is enabled.
    PRE_GAN = ['gan'] if METRIC_MA else ['pre', 'gan']

    for pre_gan in PRE_GAN:
        for sensor in SENSORS[val_or_test]:
            # Sensor-specific MS band indices.
            if sensor == 'GE01':
                band_indices = MS_BANDS_GE01_IDXS
            elif sensor == 'WV02':
                band_indices = MS_BANDS_WV02_IDXS

            # Stage-specific weights directory, first evaluated epoch and
            # weight file prefix.
            model_weights_dir, eval_first_step, eval_prefix = (
                (PRETRAIN_WEIGHTS_DIR, 1, EXPERIMENT + '-pre-')
                if pre_gan == 'pre'
                else (GAN_WEIGHTS_DIR, 0, EXPERIMENT + '-gan-')
            )

            esrgan_epoch_evaluator(
                gan_model,
                model_weights_dir=model_weights_dir,
                model_weight_prefix=eval_prefix,
                dataset=ds_val[sensor],
                n_epochs=EVAL_N_EPOCHS,
                first_epoch=eval_first_step,
                steps_per_epoch=EVAL_STEPS_PER_EPOCH,
                k_epoch=25,
                csv_dir=f"{LOGS_EXP_DIR}/csv/{val_or_test}-{sensor}",
                per_image=EVAL_PER_IMAGE,
                verbose=0,
            )

### 8.4 Comparison plots

In [None]:
import plotly.express as px

# Visual sanity check: plot one MS tile (first three bands) and one PAN tile
# taken from a single batch of the WorldView-2 training dataset.
idx = 10  # position of the inspected tile within the batch
batch = next(iter(ds_train['WV02']))
# NOTE(review): the nested batch layout (batch[0][1] -> MS tiles,
# batch[1][1] -> PAN tiles) is taken as-is from the original indexing;
# confirm against the tile input pipeline.
ms = batch[0][1][idx, :, :, :3]
pan = batch[1][1][idx, :, :, :]

# A notebook cell only renders the value of its last expression, so each
# figure is shown explicitly; otherwise the MS figure is silently discarded.
px.imshow(stretch_img(ms)).show()
px.imshow(stretch_img(pan)[:, :, 0], color_continuous_scale='gray').show()