# Experiment 01: Band configurations

In this experiment we will study how different band configurations affect the training process and end results.

### Experiment variations:

- E01-8: All 8 MS WorldView-2 bands are used.
- E01-6: The 6 MS WorldView-2 bands that overlap the PAN band (RGB+Y+RedEdge+NIR1)
- E01-4: The 4 MS WorldView-2 bands that are also available in GeoEye-1 (MS1 array RGB+NIR1)
  - In this variation we will also validate performance on the GeoEye-1 validation set
- E01-3: Only the 3 RGB bands from WorldView-2
  - Also validated on the GeoEye-1 validation set
 
### The notebook is divided into the following main sections:
1. Imports and configuration parameters
2. Tile generation (sampling of tiles from the satellite images)
3. Tile input pipelines (`tf.data.Dataset` objects reading tiles from disk)
4. Building of models
5. Pretraining with L1 loss
6. Build the full ESRGAN model
7. GAN-training with L1 + Percep + GAN loss
8. Inspection of results

Training history is logged with TensorBoard.

## 1. Imports and configuration parameters

In [1]:
from modules.helpers import *
from modules.tile_generator import *
from modules.matlab_metrics import *
from modules.image_utils import *
from modules.tile_input_pipeline import *
from modules.models import *
from modules.evaluation import *

from modules.logging import *
from modules.train import *

import time
import tensorflow as tf

# Check GPUs and enable dynamic GPU memory use:
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            # Prevent TensorFlow from allocating all memory of all GPUs:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        print(e)

1 Physical GPUs, 1 Logical GPUs


In [2]:
### MAIN SETTINGS ###############################################################################################
EXPERIMENT_NAMES = ['e01-8', 'e01-6', 'e01-4', 'e01-3']

# Select experiment variation to be run in THIS notebook:
EXPERIMENT = EXPERIMENT_NAMES[0]

# Turn time-consuming processes in the notebook on or off:
GENERATE_TILES = False   # This should only be done once in experiment 01. All variations read from the same tiles
TILE_DENSITY_MAPS = False  # Loops through all tiles and computes density maps of where tiles have been sampled
CALCULATE_STATS = False  # Loops through all tiles and calculates mean and sd. Used for scaling
PRE_BUILD = True          # Step 1 of the training process
PRETRAIN = False          # Step 1 of the training process
GAN_BUILD = True          # Step 2 of the training process
GAN_TRAIN = False         # Step 2 of the training process
PRE_EVALUATE_LAST = True
GAN_EVALUATE_LAST = False
PRE_EVALUATE_HISTORY = False
GAN_EVALUATE_HISTORY = False

# Load metadata dataframe "meta" from repository root.
# This dataframe keeps track of images and is used and updated throughout the notebook
meta = load_meta_pickle_csv('.', 'metadata_df', from_pickle=True)
#################################################################################################################

### PATHS #######################################################################################################
DATA_PATH = 'data/toulon-laspezia'
DATA_PATH_TILES = 'data/toulon-laspezia-tiles/e01'
DATA_PATH_TILES_P = {'train': DATA_PATH_TILES + '/train', 
                     'val': DATA_PATH_TILES + '/val', 
                     'test': DATA_PATH_TILES + '/test'}
LOGS_DIR = 'logs/' # Path to tensorboard logs and model checkpoint saves
LOGS_EXP_DIR = LOGS_DIR + EXPERIMENT
#################################################################################################################

### TILE GENERATION #############################################################################################
SENSORS_GENERATE = ['WV02', 'GE01']
AREAS_GENERATE = ['La_Spezia', 'Toulon']
meta = subset_by_areas_sensor(meta, areas=AREAS_GENERATE, sensors=SENSORS_GENERATE)
print('Sensors to generate tiles from:', SENSORS_GENERATE)
print('Areas to generate tiles from:', AREAS_GENERATE)

# Count images in partitions (train/val/test):
N_IMAGES_TOTAL = count_images(meta)
N_IMAGES = {'train': count_images_in_partition(meta, 'train'), 
            'val': count_images_in_partition(meta, 'val'), 
            'test': count_images_in_partition(meta, 'test')}
assert N_IMAGES_TOTAL == sum(N_IMAGES.values())  # Verify that different ways of counting adds up
print('Number of images in partitions', N_IMAGES)
print('Total number of images:', N_IMAGES_TOTAL)

TILES_PER_M2 = {'train': 2.0, 
                'val': 2.0, 
                'test': 2.0}

# Settings for whether to send every generated tile through a sea and cloud classifier.
# This is useful if images contain a lot of sea and clouds and you want to reduce the number of tiles
# with such monotone and less meaningful content. The classifier was trained on 2500 labeled tiles of various sizes,
# where only tiles COMPLETELY covered by sea and/or clouds were labeled "cloud/sea".
# Validation accuracy is around 0.95
CLOUD_SEA_REMOVAL = True
CLOUD_SEA_WEIGHTS_PATH = 'models/cloud-sea-classifier/cloudsea-effb0-augm-bicubic-pan-0.0005--200-0.129130.h5'
# Cutoff at inference time. Tiles with a (quasi-)probability higher than the cutoff are classified as cloud and/or sea:
CLOUD_SEA_PRED_CUTOFF = 0.95
# Setting to keep a certain proportion of cloud/sea tiles through the filter:
CLOUD_SEA_KEEP_RATE = 0.10

# GE01 images have slight variations in resolution (0.5 +/- 0.05 m per pixel), while WV02 is fixed at 0.5 m
# Setting this to True will resize to as close as possible to 0.5m
# Not used in this notebook, but function is ready for use in module tile_generator.py
RESIZE_TO_PIXEL_SIZE = False
if RESIZE_TO_PIXEL_SIZE:
    RESIZE_RESAMPLING_METHOD = 'nearest'  # 'nearest', 'bicubic', 'bilinear'
    NEW_PIXEL_SIZE_PAN = 0.5
    RESIZE_DIR = DATA_PATH + '-resized'
#################################################################################################################

### SENSORS AND AREA EXPERIMENT SELECTION #######################################################################
# Sensors used in which experiment variation
SENSORS_EXP = {'e01-8': {'train': 'WV02', 'val': ['WV02'], 'test': ['WV02']}, 
               'e01-6': {'train': 'WV02', 'val': ['WV02'], 'test': ['WV02']}, 
               'e01-4': {'train': 'WV02', 'val': ['WV02', 'GE01'], 'test': ['WV02', 'GE01']}, 
               'e01-3': {'train': 'WV02', 'val': ['WV02', 'GE01'], 'test': ['WV02', 'GE01']}}
SENSORS = SENSORS_EXP[EXPERIMENT]

# Areas used in which experiment variation
AREAS_EXP = {'e01-8': {'train': AREAS_GENERATE, 'val': AREAS_GENERATE, 'test': AREAS_GENERATE}, 
             'e01-6': {'train': AREAS_GENERATE, 'val': AREAS_GENERATE, 'test': AREAS_GENERATE}, 
             'e01-4': {'train': AREAS_GENERATE, 'val': AREAS_GENERATE, 'test': AREAS_GENERATE}, 
             'e01-3': {'train': AREAS_GENERATE, 'val': AREAS_GENERATE, 'test': AREAS_GENERATE}}
AREAS = AREAS_EXP[EXPERIMENT]
#################################################################################################################

### TILE DIMENSIONS #############################################################################################
# Note the larger size of val and test tiles. This is needed for a sensible calculation of Ma, NIQE and PI
SR_FACTOR = 4
MS_SIZE = {'train': 32, 'val': 128, 'test': 128}
PAN_SIZE = {'train': MS_SIZE['train'] * SR_FACTOR, 
            'val': MS_SIZE['val'] * SR_FACTOR, 
            'test': MS_SIZE['test'] * SR_FACTOR}
print('MS (LR) tile size:', MS_SIZE)
print('PAN (HR) tile size:', PAN_SIZE)
print('SR factor:', SR_FACTOR)
#################################################################################################################

### BAND (CHANNEL) CONFIGURATIONS ###############################################################################
# This is the essence of experiment 01
# Selection of bands is done in the tile input pipeline

# Selecting bands from the 8 bands of WV02:
WV02_FULL_BAND_CONFIG = get_sensor_bands('WV02', meta)
WV02_EXP_BAND_CONFIGS = {'e01-8': WV02_FULL_BAND_CONFIG,                          # 8 (all) bands
                        'e01-6': {k:v for (k,v) in WV02_FULL_BAND_CONFIG.items()  # 6 bands (BGYR+RE+NIR)
                                  if k not in ['Coastal', 'NIR2']}, 
                        'e01-4': {k:v for (k,v) in WV02_FULL_BAND_CONFIG.items()  # 4 bands (BGR+NIR)
                                  if k in ['Blue', 'Green', 'Red', 'NIR']},
                        'e01-3': {k:v for (k,v) in WV02_FULL_BAND_CONFIG.items()  # 3 bands (BGR)
                                  if k in ['Blue', 'Green', 'Red']}}
MS_BANDS_WV02_CONFIG = WV02_EXP_BAND_CONFIGS[EXPERIMENT]
if EXPERIMENT == 'e01-8':
    # We set this to 'all' in order to not pass e01-8 tiles through a band selection function (no reason to)
    MS_BANDS_WV02_IDXS = 'all' 
else:
    # For the other experiment variations we need lists of indices of the bands to be selected
    MS_BANDS_WV02_IDXS = list(MS_BANDS_WV02_CONFIG.values())

N_MS_BANDS = len(MS_BANDS_WV02_CONFIG.values()) # The number of MS bands in this experiment variation

# Selecting bands from the 4 bands of GE01:
GE01_FULL_BAND_CONFIG = get_sensor_bands('GE01', meta)                            
GE01_EXP_BAND_CONFIGS = {'e01-8': {None: None},                                   # not enough bands in GE01
                        'e01-6': {None: None},                                    # not enough bands in GE01
                        'e01-4': GE01_FULL_BAND_CONFIG,                           # 4 (all) bands (BGR+NIR)
                        'e01-3': {k:v for (k,v) in GE01_FULL_BAND_CONFIG.items()  # 3 bands (BGR)
                                  if k not in ['NIR']}}
MS_BANDS_GE01_CONFIG = GE01_EXP_BAND_CONFIGS[EXPERIMENT]
if EXPERIMENT == 'e01-4':
    MS_BANDS_GE01_IDXS = 'all'
else:
    MS_BANDS_GE01_IDXS = list(MS_BANDS_GE01_CONFIG.values())
print('MS (LR) Band Config WV02:', MS_BANDS_WV02_CONFIG)
print('MS (LR) Band Config GE01:', MS_BANDS_GE01_CONFIG)

N_PAN_BANDS = 1 # There is only one panchromatic band
#################################################################################################################

### MODEL PARAMETERS ############################################################################################
BATCH_SIZE = {'train': 16, 'val': 8, 'test': 8}
print('Batch sizes:', BATCH_SIZE)

# RRDB Generator Model parameters 
N_BLOCKS = 16 # Official: 23. Deeper means potential to capture more complex relationships, at the cost of training time
N_FILTERS = 64 # Baseline setting that is not tinkered with in this repository
#################################################################################################################

### PRETRAINING SETTINGS ########################################################################################
PRE_EPOCHS = 400
PRE_TRAIN_STEPS = 1000  # per epoch
PRE_VAL_STEPS = 0     # per epoch
print('Pretraining - Total steps:', PRE_EPOCHS * PRE_TRAIN_STEPS)

# Number of batches to save every epoch in TensorBoard
TRAIN_N_BATCHES_SAVE = 2
VAL_N_BATCHES_SAVE = 2

# Optimizer settings:
PRETRAIN_LOSS = 'l1'    # Official
PRETRAIN_LR = 5e-5      # Tuned and found stable for this particular experiment
#PRETRAIN_LR = 0.0002   # Official
PRETRAIN_BETA_1 = 0.9   # Official
PRETRAIN_BETA_2 = 0.999 # Official
# Note: The official implementation also uses a stepwise learning rate scheduler.
# It is omitted here: squeezing out the last bit of performance is not central to the experiment, and a
# schedule complicates comparisons between experiment variations
#################################################################################################################

### GAN TRAINING SETTINGS #######################################################################################
GAN_EPOCHS = 400
GAN_TRAIN_STEPS = 1000
GAN_VAL_STEPS = 250
# Proportion of val batches that will go through the Ma and NIQE metric calculation
# MA_NIQE_PROPORTION = 0.04  # The calculation is very time consuming
MA_NIQE_PROPORTION = 1  # The calculation is very time consuming
print('GAN training - Total steps:', GAN_EPOCHS * GAN_TRAIN_STEPS)

# Weights for each loss in the composite loss function
G_LOSS_PIXEL_W = 0.01       # Official
G_LOSS_PERCEP_W = 1.0       # Official
G_LOSS_GENERATOR_W = 0.005  # Official

# Optimizer settings:
#GAN_G_LR = 1e-4 # Official
#GAN_D_LR = 1e-4 # Official
GAN_G_LR = 2e-5
GAN_D_LR = 2e-5
G_BETA_1, D_BETA_1 = 0.9, 0.9      # Official
G_BETA_2, D_BETA_2 = 0.999, 0.999  # Official
# Note: The official implementation also uses a stepwise learning rate scheduler.
# It is omitted here: squeezing out the last bit of performance is not central to the experiment, and a
# schedule complicates comparisons between experiment variations

# Path to the pretraining weights that are the starting point of GAN training:
PRETRAIN_WEIGHTS_DIRS = {'e01-8': LOGS_EXP_DIR + '/models/' + 'e01-8-pre_20210116-194500/', 
                         'e01-6': LOGS_EXP_DIR + '/models/' + 'e01-6-pre_20210122-091421/', 
                         'e01-4': LOGS_EXP_DIR + '/models/' + 'e01-4-pre_20210119-101939/', 
                         'e01-3': LOGS_EXP_DIR + '/models/' + 'e01-3-pre_20210124-153415/'
                        }
PRETRAIN_WEIGHTS_DIR = PRETRAIN_WEIGHTS_DIRS[EXPERIMENT]
PRETRAIN_WEIGHTS_PATH = PRETRAIN_WEIGHTS_DIR + EXPERIMENT + '-pre-400.h5'

# Path to the GAN-training weights that will be evaluated:
GAN_WEIGHTS_DIRS = {'e01-8': LOGS_EXP_DIR + '/models/' + 'e01-8-gan_20210222-124759/', 
                    'e01-6': LOGS_EXP_DIR + '/models/' + 'e01-6-gan_20210216-152032/', 
                    'e01-4': LOGS_EXP_DIR + '/models/' + 'e01-4-gan_20210212-144735/', 
                    'e01-3': LOGS_EXP_DIR + '/models/' + 'e01-3-gan_20210219-135014/'
                   }
GAN_WEIGHTS_DIR = GAN_WEIGHTS_DIRS[EXPERIMENT]
GAN_WEIGHTS_PATH = GAN_WEIGHTS_DIR + EXPERIMENT + '-gan-G-399.h5'
#################################################################################################################

### MATLAB METRICS ##############################################################################################
# Calculate Ma, NIQE and Perceptual Index (PI) metrics on the validation set(s) during GAN training:
# PI was the metric used in the PIRM2018 competition https://github.com/roimehrez/PIRM2018
METRIC_MA = True
METRIC_NIQE = True
if METRIC_MA and METRIC_NIQE:
    METRIC_PI = True
else:
    METRIC_PI = False

# The number of pixels to be shaved off the border of the tile before calculating Ma/NIQE/PI (ignore border effects)
SHAVE_WIDTH = 4 # Official (as used in PIRM2018 evaluation)
# Ma/NIQE/PI calculation is done with official matlab repositories through MATLAB Engine API for Python
MATLAB_PATH = 'modules/matlab' # path to repositories
#################################################################################################################

### EVALUATION ##################################################################################################
if PRE_EVALUATE_LAST or GAN_EVALUATE_LAST:
    METRIC_MA = True
    METRIC_NIQE = True
    if METRIC_MA and METRIC_NIQE:
        METRIC_PI = True
    else:
        METRIC_PI = False
        

EVAL_STEPS_PER_EPOCH = 'all'
EVAL_N_EPOCHS = 400
#EVAL_SENSOR = 'WV02'
EVAL_PER_IMAGE = True
    
if PRE_EVALUATE_HISTORY or GAN_EVALUATE_HISTORY:
    METRIC_MA = False
    METRIC_NIQE = True
    if METRIC_MA and METRIC_NIQE:
        METRIC_PI = True
    else:
        METRIC_PI = False
        
#if PRE_EVALUATE_HISTORY:
#    EVAL_WEIGHTS_DIR = PRETRAIN_WEIGHTS_DIR
#    EVAL_FIRST_STEP = 1
#    EVAL_PREFIX = EXPERIMENT + '-pre-'
#elif GAN_EVALUATE_HISTORY:
#    EVAL_WEIGHTS_DIR = GAN_WEIGHTS_DIR
#    EVAL_FIRST_STEP = 0
#    EVAL_PREFIX = EXPERIMENT + '-gan-'
    
print('MATLAB Metrics:')
print('Ma:', METRIC_MA)
print('NIQE:', METRIC_NIQE)
print('Perceptual Index (PI):', METRIC_PI)

Sensors to generate tiles from: ['WV02', 'GE01']
Areas to generate tiles from: ['La_Spezia', 'Toulon']
Number of images in partitions {'train': 22, 'val': 19, 'test': 21}
Total number of images: 62
MS (LR) tile size: {'train': 32, 'val': 128, 'test': 128}
PAN (HR) tile size: {'train': 128, 'val': 512, 'test': 512}
SR factor: 4
MS (LR) Band Config WV02: {'Coastal': 0, 'Blue': 1, 'Green': 2, 'Yellow': 3, 'Red': 4, 'RedEdge': 5, 'NIR': 6, 'NIR2': 7}
MS (LR) Band Config GE01: {None: None}
Batch sizes: {'train': 16, 'val': 8, 'test': 8}
Pretraining - Total steps: 400000
GAN training - Total steps: 400000
MATLAB Metrics:
Ma: True
NIQE: True
Perceptual Index (PI): True
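Each band configuration above reduces to a list of channel indices applied to the last axis of every MS tile. The NumPy sketch below uses the WV02 band dictionary printed above and the same filtering rules as `WV02_EXP_BAND_CONFIGS`. `select_bands` is a hypothetical stand-in for the band selection done inside `GeotiffDataset`, not its actual implementation.

```python
import numpy as np

# WV02 band name -> channel index, as printed by the configuration cell above
WV02_BANDS = {'Coastal': 0, 'Blue': 1, 'Green': 2, 'Yellow': 3,
              'Red': 4, 'RedEdge': 5, 'NIR': 6, 'NIR2': 7}

# Same filtering rules as WV02_EXP_BAND_CONFIGS
CONFIGS = {
    'e01-8': WV02_BANDS,
    'e01-6': {k: v for k, v in WV02_BANDS.items() if k not in ['Coastal', 'NIR2']},
    'e01-4': {k: v for k, v in WV02_BANDS.items() if k in ['Blue', 'Green', 'Red', 'NIR']},
    'e01-3': {k: v for k, v in WV02_BANDS.items() if k in ['Blue', 'Green', 'Red']},
}

def select_bands(ms_tile, band_idxs):
    """Hypothetical helper: keep only the given channels of an (H, W, C) tile."""
    if band_idxs == 'all':  # e01-8: no selection needed
        return ms_tile
    return np.take(ms_tile, band_idxs, axis=-1)

# A dummy 32x32 training tile with 8 bands of 11-bit values
ms_tile = np.random.randint(0, 2048, size=(32, 32, 8)).astype(np.uint16)
for name, cfg in CONFIGS.items():
    idxs = 'all' if name == 'e01-8' else list(cfg.values())
    print(name, idxs, select_bands(ms_tile, idxs).shape)
```

Because dictionary order is preserved, the selected channels keep the sensor's original band order (e.g. `[1, 2, 4, 6]` for e01-4).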


## 2. Tile generation

### 2.1 Image resizing

The function `resize_sat_img_to_new_pixel_size` is available in `modules.tile_generator`, but it is not used in this notebook.

### 2.2 Tile allocation

We allocate `n_tiles` to each satellite image in proportion to the area it covers, scaled by the argument `tiles_per_m2`. With `tiles_per_m2=1.0`, `n_tiles` is set deterministically so that each square meter of the satellite image is, in expectation, covered by `1.0` tile.
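Read literally, "each square meter is covered by `tiles_per_m2` tiles in expectation" means the total tile area equals `tiles_per_m2` times the image area. The sketch below implements that reading, assuming square PAN tiles and a known PAN pixel size. `allocate_n_tiles` is hypothetical and is not the implementation inside `allocate_tiles_by_expected`.

```python
def allocate_n_tiles(image_area_m2, pan_tile_size_px, pan_pixel_size_m, tiles_per_m2):
    """Hypothetical: number of tiles so that the summed tile area is tiles_per_m2 * image area."""
    tile_side_m = pan_tile_size_px * pan_pixel_size_m  # e.g. 128 px * 0.5 m = 64 m
    tile_area_m2 = tile_side_m ** 2                     # e.g. 4096 m^2 per tile
    return round(tiles_per_m2 * image_area_m2 / tile_area_m2)

# A 10 km^2 image with 128 px training tiles at 0.5 m per PAN pixel:
print(allocate_n_tiles(10e6, 128, 0.5, 1.0))  # 2441
print(allocate_n_tiles(10e6, 128, 0.5, 2.0))  # 4883
```

Under this rule the larger 512 px val/test tiles receive 16 times fewer tiles per unit area than the 128 px training tiles at the same `tiles_per_m2`.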

In [3]:
if GENERATE_TILES:
    meta = allocate_tiles_by_expected(meta, 
                                      override_pan_pixel_size=RESIZE_TO_PIXEL_SIZE,
                                      by_partition=True, 
                                      tiles_per_m2_train_val_test=(TILES_PER_M2['train'], 
                                                                   TILES_PER_M2['val'], 
                                                                   TILES_PER_M2['test']),
                                      pan_tile_size_train_val_test=(PAN_SIZE['train'], 
                                                                    PAN_SIZE['val'], 
                                                                    PAN_SIZE['test']),
                                      new_column_name='n_tiles')
else:
    # Load meta dataframe that was updated at tile generation time
    meta = load_meta_pickle_csv(DATA_PATH_TILES, 'metadata_tile_allocation', from_pickle=True)

n_tiles = {'train': count_tiles_in_partition(meta, 'train'),
           'val': count_tiles_in_partition(meta, 'val'), 
           'test':  count_tiles_in_partition(meta, 'test')}
n_tiles_total = count_tiles(meta)
assert n_tiles_total == sum(n_tiles.values())
print('Number of tiles per partition:')
print(n_tiles)
print('Total number of tiles:', n_tiles_total)

Number of tiles per partition:
{'train': 129221, 'val': 8113, 'test': 9293}
Total number of tiles: 146627


### 2.3 Tile generation to disk

In [4]:
if GENERATE_TILES:
    meta = generate_all_tiles(meta, 
                              save_dir=DATA_PATH_TILES, 
                              sr_factor=SR_FACTOR, 
                              by_partition=True,
                              ms_tile_size_train_val_test=(MS_SIZE['train'], MS_SIZE['val'], MS_SIZE['test']), 
                              cloud_sea_removal=CLOUD_SEA_REMOVAL, 
                              cloud_sea_weights_path=CLOUD_SEA_WEIGHTS_PATH, 
                              cloud_sea_pred_cutoff=CLOUD_SEA_PRED_CUTOFF,
                              cloud_sea_keep_rate=CLOUD_SEA_KEEP_RATE,
                              save_meta_to_disk=True)
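The cloud/sea settings in the configuration cell describe a keep-or-discard rule: tiles whose classifier score is above `CLOUD_SEA_PRED_CUTOFF` count as cloud/sea, and only a fraction `CLOUD_SEA_KEEP_RATE` of those passes the filter. The sketch assumes exactly that rule. `keep_tile` is hypothetical; the actual logic lives in `generate_all_tiles` in `modules/tile_generator.py` and may differ.

```python
import random

def keep_tile(cloud_sea_prob, cutoff=0.95, keep_rate=0.10, rng=random):
    """Hypothetical filter: keep ordinary tiles, keep cloud/sea tiles with probability keep_rate."""
    if cloud_sea_prob <= cutoff:      # not classified as cloud/sea
        return True
    return rng.random() < keep_rate   # cloud/sea tile: keep a random ~10 %

rng = random.Random(0)
probs = [0.10] * 1000 + [0.99] * 1000  # 1000 ordinary tiles, 1000 cloud/sea tiles
kept = [p for p in probs if keep_tile(p, rng=rng)]
print(len(kept))  # all 1000 ordinary tiles plus roughly 100 cloud/sea tiles
```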

In [5]:
if TILE_DENSITY_MAPS:
    for img_uid, img_row in meta.iterrows():
        density = tile_density_map(DATA_PATH_TILES, 
                                   img_row, 
                                   pan_or_ms='pan',
                                   density_dtype='uint8',
                                   write_to_disk=True,
                                   write_dir=DATA_PATH_TILES + '/density-maps', 
                                   write_filename=img_uid)
    # Plot last density
    plt.imshow(density)
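A tile density map counts, for every PAN pixel, how many sampled tiles cover it. The NumPy sketch below assumes tile positions are known as upper-left `(row, col)` offsets in the PAN image. `density_map` is hypothetical and does not share the signature of `tile_density_map`.

```python
import numpy as np

def density_map(img_shape, tile_corners, tile_size, dtype=np.uint8):
    """Hypothetical: count how many square tiles cover each pixel of an (H, W) image."""
    counts = np.zeros(img_shape, dtype=np.int64)
    for row, col in tile_corners:
        counts[row:row + tile_size, col:col + tile_size] += 1
    # Clip before casting so counts above the dtype's maximum saturate instead of wrapping
    return np.clip(counts, 0, np.iinfo(dtype).max).astype(dtype)

d = density_map((8, 8), [(0, 0), (2, 2), (4, 4)], tile_size=4)
print(d.max(), d[3, 3], d[7, 7], d[0, 7])  # 2 2 1 0
```

The clipping step matters when a `uint8` map is written to disk, as with `density_dtype='uint8'` above: a pixel covered by more than 255 tiles would otherwise wrap around to a small value.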

In [6]:
if CALCULATE_STATS:
    train_tiles_mean, train_tiles_sd = mean_sd_of_train_tiles(DATA_PATH_TILES, 
                                                              sample_proportion=1.0, 
                                                              write_json=True)
else:
    train_tiles_mean, train_tiles_sd = read_mean_sd_json(DATA_PATH_TILES)

Loaded mean 341.3 and sd 128.4 from json file @ data/toulon-laspezia-tiles/e01/train_mean_sd.json
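`mean_sd_of_train_tiles` produces one mean and one standard deviation for all training tiles. Exact pooled statistics can be computed tile by tile, without holding every tile in memory, by merging per-tile counts, means and sums of squared deviations (Chan et al.'s parallel update). The sketch assumes pooling over all pixels and bands; whether the module pools the same way is not shown here, and `pooled_mean_sd` is hypothetical.

```python
import numpy as np

def pooled_mean_sd(tiles):
    """Pooled mean and population standard deviation over all values of all tiles."""
    n, mean, m2 = 0, 0.0, 0.0
    for tile in tiles:
        t = np.asarray(tile, dtype=np.float64).ravel()
        n_b, mean_b = t.size, t.mean()
        m2_b = np.square(t - mean_b).sum()   # sum of squared deviations within this tile
        delta = mean_b - mean
        total = n + n_b
        mean += delta * n_b / total           # merge the running mean with the tile mean
        m2 += m2_b + delta ** 2 * n * n_b / total
        n = total
    return mean, np.sqrt(m2 / n)

rng = np.random.default_rng(0)
tiles = [rng.integers(0, 2048, size=shape) for shape in [(32, 32, 8), (32, 32, 8), (16, 16, 8)]]
mean, sd = pooled_mean_sd(tiles)
all_values = np.concatenate([t.ravel() for t in tiles])
print(np.isclose(mean, all_values.mean()), np.isclose(sd, all_values.std()))  # True True
```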


## 3. Data input pipeline from disk

### 3.1 Training set

In [7]:
# A buffer as large as each partition gives a full (uniform) shuffle; a smaller alternative is 100
SHUFFLE_BUFFER_SIZE = {'train': n_tiles['train'],
                       'val': n_tiles['val'],
                       'test': n_tiles['test']}

train_val_test = 'train'
sensor = SENSORS[train_val_test]
ds_train = {sensor: GeotiffDataset(tiles_path=DATA_PATH_TILES_P[train_val_test], 
                                   batch_size=BATCH_SIZE[train_val_test], 
                                   ms_tile_shape=(MS_SIZE[train_val_test], MS_SIZE[train_val_test], N_MS_BANDS), 
                                   pan_tile_shape=(PAN_SIZE[train_val_test], PAN_SIZE[train_val_test], N_PAN_BANDS),
                                   sensor=sensor,
                                   band_selection=MS_BANDS_WV02_IDXS, 
                                   mean_correction=train_tiles_mean,
                                   cache_memory=True,
                                   cache_file=str(DATA_PATH_TILES + '/ds_' + EXPERIMENT + '-'
                                                    + train_val_test + '-' + sensor + '_cache'), 
                                   repeat=True, 
                                   shuffle=True, 
                                   shuffle_buffer_size=SHUFFLE_BUFFER_SIZE[train_val_test])
           }
# Getting the scaled output range from the scaler. Needed to calculate PSNR and SSIM:
scaled_range = ds_train[sensor].get_scaler_output_range(print_ranges=True)

# Returning the actual tf.data.dataset object:
ds_train[sensor] = ds_train[sensor].get_dataset()
print(ds_train.keys())

Scaler ranges:
Input (uint) min, max: 0 2047
Input (uint) range: 2048
Output (float) range 1.2006480509994506
Output (float) min, max: -0.2000617970682984 1.0
dict_keys(['WV02'])
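The printed scaler ranges are consistent with a mean-corrected scaling `(x - mean) / (2047 - mean)`, which maps the training mean to 0 and the 11-bit maximum to 1.0, with the reported output range equal to `2048 / (2047 - mean)`. That formula is inferred from the numbers above, not read from `modules/tile_input_pipeline.py`. The sketch also shows why PSNR needs the scaled range: PSNR is defined relative to the data range of the signal, so it must be computed in the same units as the scaled tiles.

```python
import numpy as np

IN_MAX = 2047   # maximum of the 11-bit input, as printed above
MEAN = 341.3    # training-tile mean as printed (rounded)

def scale(x, mean=MEAN, in_max=IN_MAX):
    """Assumed mean-corrected scaling: maps `mean` to 0 and `in_max` to 1."""
    return (np.asarray(x, dtype=np.float64) - mean) / (in_max - mean)

# Assumed convention for the reported range: the 2048 input levels expressed in scaled units
scaled_range = (IN_MAX + 1) / (IN_MAX - MEAN)

def psnr(y_true, y_pred, data_range):
    """Peak signal-to-noise ratio in dB for a given data range."""
    mse = np.mean((np.asarray(y_true, np.float64) - np.asarray(y_pred, np.float64)) ** 2)
    return 10.0 * np.log10(data_range ** 2 / mse)

print(scale([0, IN_MAX]), scaled_range)  # compare with the scaler output printed above
```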


### 3.2 Validation set

In [8]:
# Validation set can have several sensors and is organized in a dictionary
# structure: ds_val = {sensor: dataset} ... ex: ds_val = {'WV02': dataset_with_only_WV02_images}
train_val_test = 'val'
ds_val = {}
for sensor in SENSORS[train_val_test]:
    if sensor == 'WV02':
        band_indices = MS_BANDS_WV02_IDXS
    elif sensor == 'GE01':
        band_indices = MS_BANDS_GE01_IDXS
    ds_val[sensor] = GeotiffDataset(tiles_path=DATA_PATH_TILES_P[train_val_test], 
                                    batch_size=BATCH_SIZE[train_val_test], 
                                    ms_tile_shape=(MS_SIZE[train_val_test], MS_SIZE[train_val_test], N_MS_BANDS), 
                                    pan_tile_shape=(PAN_SIZE[train_val_test], PAN_SIZE[train_val_test], N_PAN_BANDS),
                                    sensor=sensor,
                                    band_selection=band_indices, 
                                    mean_correction=train_tiles_mean,
                                    cache_memory=True,
                                    cache_file=str(DATA_PATH_TILES + '/ds_' + EXPERIMENT + '-'
                                                   + train_val_test + '-' + sensor + '_cache'), 
                                    repeat=True, 
                                    shuffle=True, 
                                    shuffle_buffer_size=SHUFFLE_BUFFER_SIZE[train_val_test])
    ds_val[sensor] = ds_val[sensor].get_dataset()
print(ds_val.keys())

dict_keys(['WV02'])


## 4. Build preliminary models

### 4.1 Bicubic baseline model

In [9]:
bicubic = build_deterministic_sr_model(upsample_factor=SR_FACTOR,
                                       resize_method='bicubic',
                                       loss='mean_absolute_error',
                                       metrics=('PSNR', 'SSIM'),
                                       scaled_range=scaled_range)

### 4.2 ESRGAN Generator model (pretrain version)

In [10]:
if PRE_BUILD:
    pretrain_model =  build_generator(pretrain_or_gan='pretrain', 
                                      pretrain_learning_rate=PRETRAIN_LR, 
                                      pretrain_loss_l1_l2=PRETRAIN_LOSS,
                                      pretrain_beta_1=PRETRAIN_BETA_1, 
                                      pretrain_beta_2=PRETRAIN_BETA_2, 
                                      pretrain_metrics=('PSNR', 'SSIM'),
                                      scaled_range=scaled_range, 
                                      n_channels_in=N_MS_BANDS, 
                                      n_channels_out=N_PAN_BANDS, 
                                      height_width_in=None,  # None will make network image size agnostic
                                      n_filters=N_FILTERS, 
                                      n_blocks=N_BLOCKS)
    # pretrain_model.summary()

## 5. Pretraining with L1 loss

In [11]:
if PRETRAIN:
    history = pretrain_esrgan(generator=pretrain_model,
                              ds_train_dict=ds_train,
                              epochs=PRE_EPOCHS,
                              steps_per_epoch=PRE_TRAIN_STEPS,
                              initial_epoch=0,
                              validate=True,
                              ds_val_dict=ds_val,
                              val_steps=PRE_VAL_STEPS,
                              model_name=EXPERIMENT + '-pre',
                              tag=EXPERIMENT,
                              log_tensorboard=True,
                              tensorboard_logs_dir=LOGS_EXP_DIR + '/tb',
                              save_models=True,
                              models_save_dir=LOGS_EXP_DIR + '/models',
                              save_weights_only=True,
                              log_train_images=True,
                              n_train_image_batches=TRAIN_N_BATCHES_SAVE,
                              log_val_images=True,
                              n_val_image_batches=VAL_N_BATCHES_SAVE)
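Because the training dataset is repeated and each epoch is capped at `PRE_TRAIN_STEPS` batches, an "epoch" here is a fixed number of batches rather than a full pass over the training tiles. The arithmetic below uses the settings and the tile count printed above; the same numbers hold for GAN training, which uses the same number of epochs, steps per epoch and batch size.

```python
PRE_EPOCHS, PRE_TRAIN_STEPS, TRAIN_BATCH = 400, 1000, 16
N_TRAIN_TILES = 129221  # training tiles, from the tile allocation output above

tiles_per_epoch = PRE_TRAIN_STEPS * TRAIN_BATCH   # tiles seen per "epoch"
tiles_total = PRE_EPOCHS * tiles_per_epoch        # tiles seen over the whole run
passes = tiles_total / N_TRAIN_TILES              # equivalent full passes over the training set
print(tiles_per_epoch, tiles_total, round(passes, 1))  # 16000 6400000 49.5
```

Each "epoch" therefore covers roughly 12 % of the training tiles.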

## 6. Build the full ESRGAN Model

In [12]:
if GAN_BUILD:
    gan_model = build_esrgan_model(PRETRAIN_WEIGHTS_PATH,
                                   n_channels_in=N_MS_BANDS, 
                                   n_channels_out=N_PAN_BANDS, 
                                   n_filters=N_FILTERS, 
                                   n_blocks=N_BLOCKS, 
                                   pan_shape=(PAN_SIZE['train'], PAN_SIZE['train'], N_PAN_BANDS),
                                   G_lr=GAN_G_LR, 
                                   D_lr=GAN_D_LR, 
                                   G_beta_1=G_BETA_1, 
                                   G_beta_2=G_BETA_2, 
                                   D_beta_1=D_BETA_1, 
                                   D_beta_2=D_BETA_2,
                                   G_loss_pixel_w=G_LOSS_PIXEL_W, 
                                   G_loss_pixel_l1_l2='l1',
                                   G_loss_percep_w=G_LOSS_PERCEP_W, 
                                   G_loss_percep_l1_l2='l1', 
                                   G_loss_percep_layer=54,
                                   G_loss_percep_before_act=True,
                                   G_loss_generator_w=G_LOSS_GENERATOR_W,
                                   metric_reg=False, 
                                   metric_ma=METRIC_MA, 
                                   metric_niqe=METRIC_NIQE, 
                                   ma_niqe_proportion=MA_NIQE_PROPORTION,
                                   matlab_wd_path=MATLAB_PATH,
                                   scale_mean=train_tiles_mean, 
                                   scaled_range=scaled_range, 
                                   shave_width=SHAVE_WIDTH)

Starting matlab.engine ...
matlab.engine started
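The three loss weights passed to `build_esrgan_model` combine into the ESRGAN generator objective `L_G = w_percep * L_percep + w_gen * L_G^Ra + w_pixel * L_1`, with the weights `G_LOSS_PERCEP_W`, `G_LOSS_GENERATOR_W` and `G_LOSS_PIXEL_W`. The NumPy sketch below shows that weighting, with the relativistic average GAN (RaGAN) generator term computed from raw discriminator outputs `C(x)`. It illustrates the ESRGAN formulation and is not the code in `modules/models.py`; the function names are hypothetical.

```python
import numpy as np

def log_sigmoid(z):
    # Numerically stable log(sigmoid(z)) = -log(1 + exp(-z))
    return -np.logaddexp(0.0, -z)

def ragan_generator_loss(c_real, c_fake):
    """RaGAN generator loss from raw (pre-sigmoid) discriminator outputs."""
    d_real_vs_fake = c_real - np.mean(c_fake)  # logit of D_Ra(x_r, x_f)
    d_fake_vs_real = c_fake - np.mean(c_real)  # logit of D_Ra(x_f, x_r)
    # -E[log(1 - D_Ra(x_r, x_f))] - E[log D_Ra(x_f, x_r)], using 1 - sigmoid(z) = sigmoid(-z)
    return -np.mean(log_sigmoid(-d_real_vs_fake)) - np.mean(log_sigmoid(d_fake_vs_real))

def generator_loss(l1, percep, c_real, c_fake, w_pixel=0.01, w_percep=1.0, w_gen=0.005):
    """Weighted sum of pixel (L1), perceptual and adversarial terms."""
    return w_pixel * l1 + w_percep * percep + w_gen * ragan_generator_loss(c_real, c_fake)

# An undecided discriminator (all logits 0) gives 2 * log(2) for the adversarial term
print(ragan_generator_loss(np.zeros(4), np.zeros(4)))
```

With these weights the perceptual term dominates, while the small adversarial weight nudges the generator toward outputs the discriminator rates as more realistic than real PAN tiles.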


## 7. GAN training

In [13]:
if GAN_TRAIN:
    history = gan_train_esrgan(esrgan_model=gan_model,
                               ds_train_dict=ds_train,
                               epochs=GAN_EPOCHS,
                               steps_per_epoch=GAN_TRAIN_STEPS,
                               initial_epoch=0,
                               validate=True,
                               ds_val_dict=ds_val,
                               val_steps=GAN_VAL_STEPS,
                               model_name=EXPERIMENT + '-gan',
                               tag=EXPERIMENT,
                               log_tensorboard=True,
                               tensorboard_logs_dir=LOGS_EXP_DIR + '/tb',
                               save_models=True,
                               models_save_dir=LOGS_EXP_DIR + '/models',
                               save_weights_only=True,
                               log_train_images=True,
                               n_train_image_batches=TRAIN_N_BATCHES_SAVE,
                               log_val_images=True,
                               n_val_image_batches=VAL_N_BATCHES_SAVE)
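During GAN training, Ma and NIQE are computed on the validation set by the official MATLAB code, after `SHAVE_WIDTH` border pixels are removed. The Perceptual Index then combines the two scores as in PIRM2018: `PI = ((10 - Ma) + NIQE) / 2`, where lower is better. The sketch shows only the shaving and the combination step; the scores in the example are illustrative, not results from this experiment, and the function names are hypothetical.

```python
import numpy as np

def shave(img, width=4):
    """Drop `width` pixels on every border of an (H, W, ...) image before scoring."""
    return img[width:-width, width:-width] if width > 0 else img

def perceptual_index(ma, niqe):
    """PIRM2018 Perceptual Index: Ma is higher-is-better, NIQE and PI are lower-is-better."""
    return 0.5 * ((10.0 - ma) + niqe)

print(shave(np.zeros((512, 512, 1))).shape)  # (504, 504, 1)
print(perceptual_index(ma=8.0, niqe=3.0))    # 2.5
```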

## 8. Evaluation

### 8.1 Data input pipelines for final evaluation

The pipelines are modified to also return the file path of each tile (patch), so that performance metrics can be logged for individual tiles and, by extension, for individual satellite images.
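Going from per-tile metrics to per-image metrics is a group-by on an image identifier derived from each tile path. The sketch assumes, hypothetically, that a tile's parent directory name identifies its satellite image; the real tile layout on disk is not shown in this notebook, and the scores below are dummy values.

```python
from collections import defaultdict
from pathlib import PurePosixPath
from statistics import mean

def per_image_means(file_paths, scores):
    """Average per-tile scores by image id (hypothetical: the tile's parent directory name)."""
    by_image = defaultdict(list)
    for path, score in zip(file_paths, scores):
        by_image[PurePosixPath(path).parent.name].append(score)
    return {img: mean(vals) for img, vals in by_image.items()}

paths = ['val/img_a/tile_0.tif', 'val/img_a/tile_1.tif', 'val/img_b/tile_0.tif']
print(per_image_means(paths, [30.0, 32.0, 28.0]))  # {'img_a': 31.0, 'img_b': 28.0}
```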

#### 8.1.1 Validation set

In [14]:
# Validation set can have several sensors and is organized in a dictionary
# structure: ds_val = {sensor: dataset} ... ex: ds_val = {'WV02': dataset_with_only_WV02_images}
train_val_test = 'val'
ds_val = {}
for sensor in SENSORS[train_val_test]:
    if sensor == 'WV02':
        band_indices = MS_BANDS_WV02_IDXS
    elif sensor == 'GE01':
        band_indices = MS_BANDS_GE01_IDXS
    ds_val[sensor] = GeotiffDataset(tiles_path=DATA_PATH_TILES_P[train_val_test], 
                                    batch_size=BATCH_SIZE[train_val_test], 
                                    ms_tile_shape=(MS_SIZE[train_val_test], MS_SIZE[train_val_test], N_MS_BANDS), 
                                    pan_tile_shape=(PAN_SIZE[train_val_test], PAN_SIZE[train_val_test], N_PAN_BANDS),
                                    sensor=sensor,
                                    band_selection=band_indices, 
                                    mean_correction=train_tiles_mean,
                                    cache_memory=False,
                                    cache_file=str(DATA_PATH_TILES + '/ds_' + EXPERIMENT + '-'
                                                   + train_val_test + '-' + sensor + '_filepath_cache'), 
                                    repeat=False, 
                                    shuffle=False, 
                                    shuffle_buffer_size=0,  # no shuffling during evaluation
                                    include_file_paths=True)
    ds_val[sensor] = ds_val[sensor].get_dataset()
print(ds_val.keys())

dict_keys(['WV02'])


#### 8.1.2 Test set

In [15]:
train_val_test = 'test'
ds_test = {}
for sensor in SENSORS[train_val_test]:
    if sensor == 'WV02':
        band_indices = MS_BANDS_WV02_IDXS
    elif sensor == 'GE01':
        band_indices = MS_BANDS_GE01_IDXS
    ds_test[sensor] = GeotiffDataset(tiles_path=DATA_PATH_TILES_P[train_val_test], 
                                     batch_size=BATCH_SIZE[train_val_test], 
                                     ms_tile_shape=(MS_SIZE[train_val_test], MS_SIZE[train_val_test], N_MS_BANDS), 
                                     pan_tile_shape=(PAN_SIZE[train_val_test], PAN_SIZE[train_val_test], N_PAN_BANDS),
                                     sensor=sensor,
                                     band_selection=band_indices, 
                                     mean_correction=train_tiles_mean,
                                     cache_memory=False,
                                     cache_file=str(DATA_PATH_TILES + '/ds_' + EXPERIMENT + '-'
                                                    + train_val_test + '-' + sensor + '_filepath_cache'), 
                                     repeat=False, 
                                     shuffle=False, 
                                     shuffle_buffer_size=0)
    ds_test[sensor] = ds_test[sensor].get_dataset()
print(ds_test.keys())

dict_keys(['WV02'])
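Both the validation and test datasets receive `mean_correction=train_tiles_mean`, so held-out tiles are shifted by statistics computed on the training tiles only. How `GeotiffDataset` applies this internally is not shown here; the following is a minimal sketch assuming `train_tiles_mean` holds one mean per band that is subtracted from every tile. The helpers `per_band_mean` and `mean_correct` are hypothetical.

```python
import numpy as np

def per_band_mean(tiles):
    """Mean of each band over a stack of channels-last tiles with shape (N, H, W, C)."""
    return tiles.mean(axis=(0, 1, 2))

def mean_correct(tile, band_means):
    """Subtract per-band training means; broadcasting aligns them with the channel axis."""
    return tile - band_means

rng = np.random.default_rng(0)
train_tiles = rng.random((10, 16, 16, 4)).astype(np.float32)
val_tile = rng.random((16, 16, 4)).astype(np.float32)

means = per_band_mean(train_tiles)          # one value per band, shape (4,)
corrected = mean_correct(val_tile, means)   # same shape as the input tile
print(means.shape, corrected.shape)
```

Reusing the training means for validation and test avoids leaking statistics from the held-out sets into preprocessing.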


### 8.2 Evaluate last epoch

In [16]:
if PRE_EVALUATE_LAST or GAN_EVALUATE_LAST:
    val_or_test = 'val'
    # SENSORS = {'val': ['GE01']}

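    # Computing Ma is 100x more time-consuming than any other metric, and it is not
    # interesting to measure it for pretraining. Alternative selection based on METRIC_MA:
    #if METRIC_MA:
    #    PRE_GAN = ['gan']
    #else:
    #    PRE_GAN = ['pre', 'gan']

    # Evaluate the pretrained generator ('pre'), the GAN generator ('gan'), or both.
    # The enclosing `if` guarantees at least one flag is set, so PRE_GAN is always defined.
    if PRE_EVALUATE_LAST and GAN_EVALUATE_LAST:
        PRE_GAN = ['pre', 'gan']
    elif PRE_EVALUATE_LAST:
        PRE_GAN = ['pre']
    else:
        PRE_GAN = ['gan']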

    print(PRE_GAN, SENSORS[val_or_test])
    for pre_gan in PRE_GAN:
        
        for sensor in SENSORS[val_or_test]:
            # if sensor == 'GE01':
            #     continue
                
            if sensor == 'GE01':
                band_indices = MS_BANDS_GE01_IDXS
            elif sensor == 'WV02':
                band_indices = MS_BANDS_WV02_IDXS
            if pre_gan == 'pre':
                gan_model.G.load_weights(PRETRAIN_WEIGHTS_PATH)
            else:
                gan_model.G.load_weights(GAN_WEIGHTS_PATH)

            print(pre_gan, sensor)
            start = time.time()
            results_df = esrgan_evaluate(model=gan_model, 
                                         dataset=ds_val[sensor], 
                                         steps='all', 
                                         per_image=True, 
                                         write_csv=True,
                                         csv_path=str(LOGS_EXP_DIR + '/csv/' + 'final_epoch-' 
                                                      + pre_gan + '-' + val_or_test + '-' + sensor + '.csv'), 
                                         verbose=1
                                        )
            end = time.time()
            print(str((end - start) / 60), 'minutes')

['pre'] ['WV02']
pre WV02
Computed 8 images in  145.23655319213867 seconds
Last image: {'G_pixel_loss': 7.47361482353881e-05, 'G_perceptual_loss': 1.6200963258743286, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 36.599449157714844, 'SSIM': 0.9404297471046448, 'Ma': 2.301645278930664, 'NIQE': 8.172527313232422, 'PI': 7.935441017150879}
Computed 8 images in  131.40617179870605 seconds
Last image: {'G_pixel_loss': 2.3487824364565313e-05, 'G_perceptual_loss': 0.9412182569503784, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 41.20809555053711, 'SSIM': 0.9869775176048279, 'Ma': 3.6105332374572754, 'NIQE': 10.508476257324219, 'PI': 8.44897174835205}
Computed 8 images in  129.67709708213806 seconds
Last image: {'G_pixel_loss': 0.00015053717652335763, 'G_perceptual_loss': 3.5583670139312744, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 30.037029266357422, 'SSIM': 0.9002884030342102, 'Ma': 2.9595882892608643

Computed 8 images in  128.1093578338623 seconds
Last image: {'G_pixel_loss': 1.0520916475798003e-05, 'G_perceptual_loss': 0.4996829628944397, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 54.45320510864258, 'SSIM': 0.9971944093704224, 'Ma': 2.8136515617370605, 'NIQE': 14.863899230957031, 'PI': 11.025123596191406}
Computed 8 images in  127.28531694412231 seconds
Last image: {'G_pixel_loss': 8.59976134961471e-05, 'G_perceptual_loss': 1.308807373046875, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 39.319515228271484, 'SSIM': 0.922167181968689, 'Ma': 2.5772342681884766, 'NIQE': 5.8749098777771, 'PI': 6.648838043212891}
Computed 8 images in  127.5659065246582 seconds
Last image: {'G_pixel_loss': 6.631102587562054e-05, 'G_perceptual_loss': 0.9050847291946411, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 40.918697357177734, 'SSIM': 0.9442502856254578, 'Ma': 2.254753828048706, 'NIQE': 7.310122013092041, 'P

Computed 8 images in  125.78887391090393 seconds
Last image: {'G_pixel_loss': 0.0001381325419060886, 'G_perceptual_loss': 2.9790825843811035, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 33.03386688232422, 'SSIM': 0.8572788834571838, 'Ma': 2.699518918991089, 'NIQE': 6.519314765930176, 'PI': 6.909897804260254}
Computed 8 images in  126.06020140647888 seconds
Last image: {'G_pixel_loss': 0.00015595524746458977, 'G_perceptual_loss': 3.235534906387329, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 31.97201919555664, 'SSIM': 0.8734737634658813, 'Ma': 2.8327901363372803, 'NIQE': 5.970077037811279, 'PI': 6.568643569946289}
Computed 8 images in  126.45787954330444 seconds
Last image: {'G_pixel_loss': 9.729892553878017e-06, 'G_perceptual_loss': 0.5042688250541687, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 58.065399169921875, 'SSIM': 0.9990319013595581, 'Ma': 2.719608783721924, 'NIQE': 19.941120147705078,

Computed 8 images in  126.45327973365784 seconds
Last image: {'G_pixel_loss': 4.1853560105664656e-05, 'G_perceptual_loss': 0.9674561619758606, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 45.825069427490234, 'SSIM': 0.9772932529449463, 'Ma': 2.100830554962158, 'NIQE': 10.568958282470703, 'PI': 9.234064102172852}
Computed 8 images in  127.47836947441101 seconds
Last image: {'G_pixel_loss': 0.00010354364349041134, 'G_perceptual_loss': 1.2291299104690552, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 36.356689453125, 'SSIM': 0.9321893453598022, 'Ma': 2.7871248722076416, 'NIQE': 9.107011795043945, 'PI': 8.159943580627441}
Computed 8 images in  127.50968837738037 seconds
Last image: {'G_pixel_loss': 0.0005202078027650714, 'G_perceptual_loss': 1.521957516670227, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 26.762971878051758, 'SSIM': 0.5065419673919678, 'Ma': 2.823967933654785, 'NIQE': 6.587610244750977,

Computed 8 images in  126.32273149490356 seconds
Last image: {'G_pixel_loss': 0.00022886201622895896, 'G_perceptual_loss': 1.4842956066131592, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 30.77513885498047, 'SSIM': 0.8567222356796265, 'Ma': 2.2031922340393066, 'NIQE': 7.847066879272461, 'PI': 7.821937561035156}
Computed 8 images in  126.00715398788452 seconds
Last image: {'G_pixel_loss': 0.0002637822472024709, 'G_perceptual_loss': 1.95941960811615, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 29.813541412353516, 'SSIM': 0.696705162525177, 'Ma': 2.796229362487793, 'NIQE': 6.721340656280518, 'PI': 6.962555885314941}
Computed 8 images in  126.25784826278687 seconds
Last image: {'G_pixel_loss': 0.000649286201223731, 'G_perceptual_loss': 2.3508718013763428, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 24.820737838745117, 'SSIM': 0.2311670184135437, 'Ma': 3.033571720123291, 'NIQE': 6.863311767578125, 'P

Computed 8 images in  126.25260162353516 seconds
Last image: {'G_pixel_loss': 3.6964633181924e-05, 'G_perceptual_loss': 1.3807603120803833, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 43.702274322509766, 'SSIM': 0.966127872467041, 'Ma': 3.040699005126953, 'NIQE': 7.568190097808838, 'PI': 7.263745307922363}
Computed 8 images in  126.1197030544281 seconds
Last image: {'G_pixel_loss': 0.0006464230245910585, 'G_perceptual_loss': 2.2891290187835693, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 24.73272705078125, 'SSIM': 0.3258768320083618, 'Ma': 3.0449726581573486, 'NIQE': 7.367308616638184, 'PI': 7.161168098449707}
Computed 8 images in  126.63765454292297 seconds
Last image: {'G_pixel_loss': 0.0005923075950704515, 'G_perceptual_loss': 1.4764176607131958, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 25.823392868041992, 'SSIM': 0.41505345702171326, 'Ma': 2.5662953853607178, 'NIQE': 7.192307472229004, '

Computed 8 images in  127.48994445800781 seconds
Last image: {'G_pixel_loss': 9.830576891545206e-05, 'G_perceptual_loss': 2.06339430809021, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 37.68433380126953, 'SSIM': 0.9246266484260559, 'Ma': 2.7441234588623047, 'NIQE': 6.598148822784424, 'PI': 6.9270124435424805}
Computed 8 images in  127.71677803993225 seconds
Last image: {'G_pixel_loss': 7.612198533024639e-05, 'G_perceptual_loss': 1.3975906372070312, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 40.1971321105957, 'SSIM': 0.9377127885818481, 'Ma': 2.8082804679870605, 'NIQE': 7.200484752655029, 'PI': 7.196102142333984}
Computed 8 images in  126.84067702293396 seconds
Last image: {'G_pixel_loss': 3.4770833735819906e-05, 'G_perceptual_loss': 0.9089785814285278, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 43.49952697753906, 'SSIM': 0.9849100112915039, 'Ma': 3.0480754375457764, 'NIQE': 9.831851959228516, 

Computed 8 images in  127.70089030265808 seconds
Last image: {'G_pixel_loss': 2.9920838642283343e-05, 'G_perceptual_loss': 0.955250084400177, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 45.1644401550293, 'SSIM': 0.986068069934845, 'Ma': 2.9649906158447266, 'NIQE': 11.091964721679688, 'PI': 9.06348705291748}
Computed 8 images in  128.5061900615692 seconds
Last image: {'G_pixel_loss': 6.329677125904709e-05, 'G_perceptual_loss': 1.2464447021484375, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 41.100257873535156, 'SSIM': 0.9586089253425598, 'Ma': 2.85874080657959, 'NIQE': 7.326507091522217, 'PI': 7.233882904052734}
Computed 8 images in  127.28545999526978 seconds
Last image: {'G_pixel_loss': 7.848865789128467e-05, 'G_perceptual_loss': 1.8034541606903076, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 37.91291427612305, 'SSIM': 0.9381793141365051, 'Ma': 2.85160231590271, 'NIQE': 6.363280773162842, 'PI':

Computed 8 images in  127.3896701335907 seconds
Last image: {'G_pixel_loss': 5.811162191093899e-05, 'G_perceptual_loss': 1.57468581199646, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 40.11328125, 'SSIM': 0.9555395245552063, 'Ma': 2.9852545261383057, 'NIQE': 6.420994281768799, 'PI': 6.717869758605957}
Computed 8 images in  127.67100048065186 seconds
Last image: {'G_pixel_loss': 6.411399226635695e-05, 'G_perceptual_loss': 1.7028744220733643, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 40.835140228271484, 'SSIM': 0.9547625780105591, 'Ma': 2.8451943397521973, 'NIQE': 6.339192867279053, 'PI': 6.746999263763428}
Computed 8 images in  127.38433051109314 seconds
Last image: {'G_pixel_loss': 1.2599103683896828e-05, 'G_perceptual_loss': 0.8141200542449951, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 50.942230224609375, 'SSIM': 0.9962795972824097, 'Ma': 2.773202419281006, 'NIQE': 14.054986953735352, 'PI':

Computed 8 images in  127.7380862236023 seconds
Last image: {'G_pixel_loss': 7.653308421140537e-05, 'G_perceptual_loss': 1.4773311614990234, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 39.808372497558594, 'SSIM': 0.9340379238128662, 'Ma': 2.8470025062561035, 'NIQE': 7.6036577224731445, 'PI': 7.378327369689941}
Computed 8 images in  127.2171401977539 seconds
Last image: {'G_pixel_loss': 1.8094315237249248e-05, 'G_perceptual_loss': 0.8279811143875122, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 52.07252883911133, 'SSIM': 0.9950926899909973, 'Ma': 2.507859468460083, 'NIQE': 12.457260131835938, 'PI': 9.974700927734375}
Computed 8 images in  126.98083686828613 seconds
Last image: {'G_pixel_loss': 6.8830449890811e-05, 'G_perceptual_loss': 1.2022511959075928, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 40.843910217285156, 'SSIM': 0.9369421005249023, 'Ma': 2.7202110290527344, 'NIQE': 8.240160942077637,

Computed 8 images in  121.9372410774231 seconds
Last image: {'G_pixel_loss': 0.00012876898108515888, 'G_perceptual_loss': 2.127835988998413, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 32.96796417236328, 'SSIM': 0.91249680519104, 'Ma': 2.8528149127960205, 'NIQE': 6.5163254737854, 'PI': 6.831755638122559}
Computed 8 images in  122.15741658210754 seconds
Last image: {'G_pixel_loss': 9.387306636199355e-05, 'G_perceptual_loss': 1.6325818300247192, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 37.05823516845703, 'SSIM': 0.9186371564865112, 'Ma': 2.6672611236572266, 'NIQE': 7.662398815155029, 'PI': 7.4975690841674805}
Computed 8 images in  122.14211511611938 seconds
Last image: {'G_pixel_loss': 6.593346915906295e-05, 'G_perceptual_loss': 1.295230507850647, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 41.138465881347656, 'SSIM': 0.9461078643798828, 'Ma': 2.885856866836548, 'NIQE': 7.797994136810303, 'PI'

Computed 8 images in  121.67159175872803 seconds
Last image: {'G_pixel_loss': 0.00012545708159450442, 'G_perceptual_loss': 2.2765605449676514, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 32.450355529785156, 'SSIM': 0.9318416118621826, 'Ma': 3.2173030376434326, 'NIQE': 7.207655906677246, 'PI': 6.995176315307617}
Computed 8 images in  121.47310614585876 seconds
Last image: {'G_pixel_loss': 9.869928908301517e-05, 'G_perceptual_loss': 2.2080273628234863, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 35.600616455078125, 'SSIM': 0.9390168190002441, 'Ma': 2.8881380558013916, 'NIQE': 6.178308010101318, 'PI': 6.645085334777832}
Computed 8 images in  121.35752654075623 seconds
Last image: {'G_pixel_loss': 7.660155097255483e-05, 'G_perceptual_loss': 1.2215192317962646, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 39.927249908447266, 'SSIM': 0.9183610081672668, 'Ma': 2.827385187149048, 'NIQE': 7.9653902053833

Computed 8 images in  121.7257239818573 seconds
Last image: {'G_pixel_loss': 0.00010349901276640594, 'G_perceptual_loss': 2.310232639312744, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 36.78249740600586, 'SSIM': 0.9103509783744812, 'Ma': 3.0078537464141846, 'NIQE': 6.618074893951416, 'PI': 6.805110931396484}
Computed 8 images in  121.76836395263672 seconds
Last image: {'G_pixel_loss': 7.571745663881302e-05, 'G_perceptual_loss': 2.00276517868042, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 37.38750076293945, 'SSIM': 0.9443415403366089, 'Ma': 2.696770668029785, 'NIQE': 6.079867362976074, 'PI': 6.6915483474731445}
Computed 8 images in  122.07204151153564 seconds
Last image: {'G_pixel_loss': 0.00012222272926010191, 'G_perceptual_loss': 2.5182762145996094, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 35.7931022644043, 'SSIM': 0.9010745286941528, 'Ma': 2.9439468383789062, 'NIQE': 7.578108310699463, 'P

Computed 8 images in  121.757009267807 seconds
Last image: {'G_pixel_loss': 0.00010385558562120423, 'G_perceptual_loss': 2.6618497371673584, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 36.00935363769531, 'SSIM': 0.9182174205780029, 'Ma': 2.7792131900787354, 'NIQE': 5.914134502410889, 'PI': 6.567461013793945}
Computed 8 images in  121.42056250572205 seconds
Last image: {'G_pixel_loss': 6.960541941225529e-05, 'G_perceptual_loss': 1.6954981088638306, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 41.60639953613281, 'SSIM': 0.9552214741706848, 'Ma': 2.814223527908325, 'NIQE': 6.649964809417725, 'PI': 6.91787052154541}
Computed 8 images in  122.12681937217712 seconds
Last image: {'G_pixel_loss': 3.8128164305817336e-05, 'G_perceptual_loss': 1.1141538619995117, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 42.60576248168945, 'SSIM': 0.9757848381996155, 'Ma': 3.3276238441467285, 'NIQE': 7.983783721923828, '

Computed 8 images in  134.16783499717712 seconds
Last image: {'G_pixel_loss': 2.8492655474110506e-05, 'G_perceptual_loss': 0.7848737239837646, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 43.30661392211914, 'SSIM': 0.9930489659309387, 'Ma': 3.421401262283325, 'NIQE': 11.816807746887207, 'PI': 9.19770336151123}
Computed 8 images in  130.1010286808014 seconds
Last image: {'G_pixel_loss': 4.0846774936653674e-05, 'G_perceptual_loss': 1.1945489645004272, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 41.121253967285156, 'SSIM': 0.9808323979377747, 'Ma': 3.6551871299743652, 'NIQE': 9.302292823791504, 'PI': 7.823553085327148}
Computed 8 images in  131.33597779273987 seconds
Last image: {'G_pixel_loss': 4.440336488187313e-05, 'G_perceptual_loss': 0.920660138130188, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 37.16284942626953, 'SSIM': 0.9884056448936462, 'Ma': 4.166677474975586, 'NIQE': 8.795194625854492, 

Computed 8 images in  122.06510877609253 seconds
Last image: {'G_pixel_loss': 0.0001405970542691648, 'G_perceptual_loss': 2.9448142051696777, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 33.9202766418457, 'SSIM': 0.9086755514144897, 'Ma': 2.8582491874694824, 'NIQE': 7.345115661621094, 'PI': 7.243432998657227}
Computed 8 images in  121.60939812660217 seconds
Last image: {'G_pixel_loss': 2.5166915293084458e-05, 'G_perceptual_loss': 0.6433499455451965, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 50.79806137084961, 'SSIM': 0.9963196516036987, 'Ma': 2.668452501296997, 'NIQE': 16.26991844177246, 'PI': 11.80073356628418}
Computed 8 images in  121.81633949279785 seconds
Last image: {'G_pixel_loss': 0.0008806281839497387, 'G_perceptual_loss': 2.288360118865967, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 22.451963424682617, 'SSIM': -0.488547682762146, 'Ma': 2.659843921661377, 'NIQE': 7.916483402252197, '

Computed 8 images in  131.96045899391174 seconds
Last image: {'G_pixel_loss': 0.0009658020571805537, 'G_perceptual_loss': 2.8752212524414062, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 21.458465576171875, 'SSIM': -0.3478356599807739, 'Ma': 2.70342755317688, 'NIQE': 6.9614577293396, 'PI': 7.12901496887207}
Computed 8 images in  129.0265736579895 seconds
Last image: {'G_pixel_loss': 0.0014066052390262485, 'G_perceptual_loss': 4.248592376708984, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 18.056589126586914, 'SSIM': 0.44814595580101013, 'Ma': 2.7431931495666504, 'NIQE': 6.244375228881836, 'PI': 6.750591278076172}
Computed 8 images in  131.26443219184875 seconds
Last image: {'G_pixel_loss': 0.0016824688063934445, 'G_perceptual_loss': 5.388067722320557, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 16.56036949157715, 'SSIM': 0.6096059083938599, 'Ma': 2.8849382400512695, 'NIQE': 6.069187164306641, 'PI

Computed 8 images in  121.82954025268555 seconds
Last image: {'G_pixel_loss': 5.188381328480318e-05, 'G_perceptual_loss': 1.4574863910675049, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 38.10504150390625, 'SSIM': 0.9679086208343506, 'Ma': 2.6383614540100098, 'NIQE': 6.686849594116211, 'PI': 7.02424430847168}
Computed 8 images in  121.69794344902039 seconds
Last image: {'G_pixel_loss': 0.000159216855536215, 'G_perceptual_loss': 3.1567695140838623, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 30.338884353637695, 'SSIM': 0.9100728034973145, 'Ma': 2.912851333618164, 'NIQE': 6.770427703857422, 'PI': 6.928788185119629}
Computed 8 images in  122.13967537879944 seconds
Last image: {'G_pixel_loss': 0.0001302766613662243, 'G_perceptual_loss': 3.501293659210205, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 31.892086029052734, 'SSIM': 0.9186112284660339, 'Ma': 2.8699607849121094, 'NIQE': 6.3864827156066895, 

Computed 8 images in  122.4184582233429 seconds
Last image: {'G_pixel_loss': 4.130108209210448e-05, 'G_perceptual_loss': 1.362821102142334, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 42.309993743896484, 'SSIM': 0.9696282744407654, 'Ma': 2.3870458602905273, 'NIQE': 6.244194984436035, 'PI': 6.928574562072754}
Computed 8 images in  122.62092161178589 seconds
Last image: {'G_pixel_loss': 0.00019345733744557947, 'G_perceptual_loss': 3.513425350189209, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 31.307083129882812, 'SSIM': 0.8461294174194336, 'Ma': 2.809901475906372, 'NIQE': 6.951351642608643, 'PI': 7.070725440979004}
Computed 8 images in  122.7730565071106 seconds
Last image: {'G_pixel_loss': 1.8302982425666414e-05, 'G_perceptual_loss': 0.6718294620513916, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 47.99388885498047, 'SSIM': 0.9979493618011475, 'Ma': 2.9007153511047363, 'NIQE': 18.76014518737793, 

Computed 8 images in  130.94302678108215 seconds
Last image: {'G_pixel_loss': 0.00013567657151725143, 'G_perceptual_loss': 2.841007947921753, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 34.060577392578125, 'SSIM': 0.877464234828949, 'Ma': 2.6952950954437256, 'NIQE': 6.716294288635254, 'PI': 7.010499477386475}
Computed 8 images in  130.6101541519165 seconds
Last image: {'G_pixel_loss': 0.00014591838407795876, 'G_perceptual_loss': 3.022062301635742, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 32.72694396972656, 'SSIM': 0.8922284245491028, 'Ma': 2.7558624744415283, 'NIQE': 6.133430480957031, 'PI': 6.688784122467041}
Computed 8 images in  131.5996458530426 seconds
Last image: {'G_pixel_loss': 1.1920709766854998e-05, 'G_perceptual_loss': 0.5746155381202698, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 51.07958984375, 'SSIM': 0.9978247284889221, 'Ma': 2.947535991668701, 'NIQE': 15.641704559326172, 'PI

Computed 8 images in  132.78123021125793 seconds
Last image: {'G_pixel_loss': 4.813238047063351e-05, 'G_perceptual_loss': 1.495706558227539, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 38.196868896484375, 'SSIM': 0.9712569117546082, 'Ma': 3.341750383377075, 'NIQE': 7.024674892425537, 'PI': 6.841462135314941}
Computed 8 images in  129.00508975982666 seconds
Last image: {'G_pixel_loss': 0.0001238369004568085, 'G_perceptual_loss': 2.7180728912353516, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 33.011417388916016, 'SSIM': 0.9236901998519897, 'Ma': 3.1596767902374268, 'NIQE': 5.890652656555176, 'PI': 6.365488052368164}
Computed 8 images in  129.1276307106018 seconds
Last image: {'G_pixel_loss': 0.00015103384794201702, 'G_perceptual_loss': 3.2387747764587402, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 31.80850601196289, 'SSIM': 0.9013134241104126, 'Ma': 2.5318586826324463, 'NIQE': 6.955719947814941,

Computed 8 images in  122.74593377113342 seconds
Last image: {'G_pixel_loss': 0.000794718274846673, 'G_perceptual_loss': 2.9641330242156982, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 22.299179077148438, 'SSIM': 0.014831995591521263, 'Ma': 2.7404253482818604, 'NIQE': 7.028446197509766, 'PI': 7.144010543823242}
Computed 8 images in  122.97165656089783 seconds
Last image: {'G_pixel_loss': 0.0012015531538054347, 'G_perceptual_loss': 5.297347068786621, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 18.74985122680664, 'SSIM': -0.1298009753227234, 'Ma': 2.675398588180542, 'NIQE': 6.4071125984191895, 'PI': 6.865857124328613}
Computed 8 images in  123.00821352005005 seconds
Last image: {'G_pixel_loss': 0.00023179478012025356, 'G_perceptual_loss': 1.0860364437103271, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 34.17641067504883, 'SSIM': 0.918163001537323, 'Ma': 2.875134229660034, 'NIQE': 20.40656280517578

Computed 8 images in  130.30859637260437 seconds
Last image: {'G_pixel_loss': 0.00023295659048017114, 'G_perceptual_loss': 0.8577008843421936, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 33.74969482421875, 'SSIM': 0.9192195534706116, 'Ma': 2.8366332054138184, 'NIQE': 20.729812622070312, 'PI': 13.946589469909668}
Computed 8 images in  128.36317682266235 seconds
Last image: {'G_pixel_loss': 0.00023930080351419747, 'G_perceptual_loss': 0.7255091667175293, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 33.98802947998047, 'SSIM': 0.8969789147377014, 'Ma': 2.710846185684204, 'NIQE': 21.195566177368164, 'PI': 14.24236011505127}
Computed 8 images in  122.82224178314209 seconds
Last image: {'G_pixel_loss': 0.0016390973469242454, 'G_perceptual_loss': 8.195181846618652, 'G_generator_loss': nan, 'G_loss_total': nan, 'D_loss_total': 0.0, 'PSNR': 16.52741050720215, 'SSIM': -0.0036678456235677004, 'Ma': 2.4640729427337646, 'NIQE': 7.826578617
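The logged `PI` values are consistent with the perceptual index from the PIRM 2018 super-resolution challenge, $\mathrm{PI} = \tfrac{1}{2}\left((10 - \mathrm{Ma}) + \mathrm{NIQE}\right)$, where lower is better. The check below recomputes PI from the `Ma` and `NIQE` values of two of the logged images above. It is a sanity check on the log, not the `esrgan_evaluate` implementation.

```python
def perceptual_index(ma, niqe):
    """Perceptual index: 0.5 * ((10 - Ma) + NIQE). Lower is better."""
    return 0.5 * ((10.0 - ma) + niqe)

# (Ma, NIQE, logged PI) copied from two 'Last image' lines in the output above
logged = [
    (2.301645278930664, 8.172527313232422, 7.935441017150879),
    (3.6105332374572754, 10.508476257324219, 8.44897174835205),
]
for ma, niqe, pi in logged:
    print(round(perceptual_index(ma, niqe), 4), round(pi, 4))
```

Because PI mixes a "higher is better" score (Ma) with a "lower is better" one (NIQE), images with very high NIQE, such as the near-uniform tiles with PSNR above 50 dB in the log, end up with a high (poor) PI despite excellent PSNR and SSIM.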

### 8.3 Evaluate every kth epoch

In [17]:
if PRE_EVALUATE_HISTORY or GAN_EVALUATE_HISTORY:
    val_or_test = 'val'
    
    # Computing Ma is 100x more time-consuming than any other metric and is not interesting
    # for pretraining, so only the GAN weights are evaluated when METRIC_MA is enabled.
    if METRIC_MA:
        PRE_GAN = ['gan']
    else:
        PRE_GAN = ['pre', 'gan']

    for pre_gan in PRE_GAN:
        for sensor in SENSORS[val_or_test]:
            if sensor == 'GE01':
                band_indices = MS_BANDS_GE01_IDXS
            elif sensor == 'WV02':
                band_indices = MS_BANDS_WV02_IDXS
            if pre_gan == 'pre':
                model_weights_dir = PRETRAIN_WEIGHTS_DIR
                eval_first_step = 1
                eval_prefix = EXPERIMENT + '-pre-'
            else:
                model_weights_dir = GAN_WEIGHTS_DIR
                eval_first_step = 0
                eval_prefix = EXPERIMENT + '-gan-'

            esrgan_epoch_evaluator(gan_model,
                                   model_weights_dir=model_weights_dir,
                                   model_weight_prefix=eval_prefix,
                                   dataset=ds_val[sensor],
                                   n_epochs=EVAL_N_EPOCHS,
                                   first_epoch=eval_first_step,
                                   steps_per_epoch=EVAL_STEPS_PER_EPOCH,
                                   k_epoch=25,
                                   csv_dir=LOGS_EXP_DIR + '/csv/' + val_or_test + '-' + sensor, 
                                   per_image=EVAL_PER_IMAGE, 
                                   verbose=0)
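`esrgan_epoch_evaluator` is called with `first_epoch`, `n_epochs` and `k_epoch=25`, but its internals are not shown here. Assuming it evaluates the checkpoint of every `k_epoch`-th epoch starting at `first_epoch` and up to and including `n_epochs`, the evaluated epochs could be enumerated as in this hypothetical sketch. Both the helper name and the inclusive upper bound are assumptions.

```python
def epochs_to_evaluate(first_epoch, n_epochs, k_epoch):
    """Hypothetical schedule: every k-th epoch from first_epoch up to and including n_epochs."""
    if k_epoch < 1:
        raise ValueError("k_epoch must be a positive integer")
    return list(range(first_epoch, n_epochs + 1, k_epoch))

# Illustrative values only; EVAL_N_EPOCHS is set elsewhere in the notebook
print(epochs_to_evaluate(first_epoch=0, n_epochs=100, k_epoch=25))
print(epochs_to_evaluate(first_epoch=1, n_epochs=100, k_epoch=25))
```

With this rule the starting offset shifts every evaluated epoch, so the real schedule depends on how `esrgan_epoch_evaluator` and the checkpoint naming under `PRETRAIN_WEIGHTS_DIR` and `GAN_WEIGHTS_DIR` are defined.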

### 8.4 Comparison plots