# Experiment 01: Band configurations

In this experiment we will study how different band configurations affect the training process and end results.

### Experiment variations:

- E01-8: All 8 MS WorldView-2 bands are used.
- E01-6: The 6 MS WorldView-2 bands that overlap the PAN band (RGB+Y+RedEdge+NIR1)
- E01-4: The 4 MS WorldView-2 bands that are also available in GeoEye-1 (MS1 array RGB+NIR1)
 - In this variation we will also validate performance on the GeoEye-1 validation set
- E01-3: Only the 3 RGB bands from WorldView-2
 - Also validated on the GeoEye-1 validation set
 
### The notebook is divided into the following main sections:
1. Imports and configuration parameters
2. Tile generation (sampling of tiles from the satellite images)
3. Building of models and tile input pipelines (`tf.data.Dataset` objects reading tiles from disk)
4. Pretraining with L1 loss
5. GAN-training with L1 + Percep + GAN loss
6. Inspection of results

Training history is logged with TensorBoard.

## 1. Imports and config

In [None]:
import tensorflow as tf

# Internal modules
from modules.helpers import *
from modules.tile_generator import *
from modules.matlab_metrics import *
from modules.image_utils import *
from modules.tile_input_pipeline import *
from modules.models import *
from modules.logging import *

# Check GPUs and enable dynamic GPU memory use:
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            # Prevent TensorFlow from allocating all memory of all GPUs:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Report the error and carry on instead of aborting the notebook.
        print(e)

# Do not truncate rows or columns when displaying dataframes.
# NOTE(review): `pd` is not imported in this cell; presumably it is pandas
# re-exported by one of the `modules.*` star imports -- confirm.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)

In [None]:
# Experiment variation to run: the index picks one entry of EXPERIMENT_NAMES.
EXPERIMENT_NAMES = ['e01-8', 'e01-6', 'e01-4', 'e01-3']
EXPERIMENT = EXPERIMENT_NAMES[2]  # 'e01-4'
GENERATE_TILES = False  # True -> tiles are (re)generated to disk in section 2.3

# Metadata dataframe, loaded from the repository root
meta = load_meta_pickle_csv('.', 'metadata_df', from_pickle=True)

# Data locations: one tile directory per partition below DATA_PATH_TILES
DATA_PATH = 'data/toulon-laspezia'
DATA_PATH_TILES = 'data/toulon-laspezia-tiles/e01'
DATA_PATH_TILES_P = {partition: '/'.join((DATA_PATH_TILES, partition))
                     for partition in ('train', 'val', 'test')}

# Log locations: each experiment variation gets its own directory
LOGS_DIR = 'logs/'
LOGS_EXP_DIR = ''.join((LOGS_DIR, EXPERIMENT))

# Keep only the sensors and areas that tiles are generated from
SENSORS_GENERATE = ['WV02', 'GE01']
AREAS_GENERATE = ['La_Spezia', 'Toulon']
meta = subset_by_areas_sensor(meta, areas=AREAS_GENERATE, sensors=SENSORS_GENERATE)
print('Sensors to generate tiles from:', SENSORS_GENERATE)
print('Areas to generate tiles from:', AREAS_GENERATE)

# Image counts: overall and per partition
N_IMAGES_TOTAL = count_images(meta)
N_IMAGES = {partition: count_images_in_partition(meta, partition)
            for partition in ('train', 'val', 'test')}
# Sanity check: the per-partition counts must add up to the overall count
assert sum(N_IMAGES.values()) == N_IMAGES_TOTAL
print('Number of images in partitions', N_IMAGES)
print('Total number of images:', N_IMAGES_TOTAL)

# Sensors per experiment variation. 'train' is a single sensor name (always
# WV02); 'val' and 'test' are lists, because the 4- and 3-band variations are
# additionally evaluated on GE01.
SENSORS_EXP = {
    name: {'train': 'WV02', 'val': list(eval_sensors), 'test': list(eval_sensors)}
    for name, eval_sensors in (('e01-8', ('WV02',)),
                               ('e01-6', ('WV02',)),
                               ('e01-4', ('WV02', 'GE01')),
                               ('e01-3', ('WV02', 'GE01')))
}
SENSORS = SENSORS_EXP[EXPERIMENT]

# Areas per experiment variation: every variation uses all generated areas
# in every partition.
AREAS_EXP = {
    name: {partition: AREAS_GENERATE for partition in ('train', 'val', 'test')}
    for name in ('e01-8', 'e01-6', 'e01-4', 'e01-3')
}
AREAS = AREAS_EXP[EXPERIMENT]


# Tile dimensions in pixels. A PAN (HR) tile is SR_FACTOR times as wide and
# high as the MS (LR) tile of the same partition.
SR_FACTOR = 4
MS_SIZE = {'train': 32, 'val': 128, 'test': 128}
PAN_SIZE = {partition: ms_px * SR_FACTOR for partition, ms_px in MS_SIZE.items()}
print('MS (LR) tile size:', MS_SIZE)
print('PAN (HR) tile size:', PAN_SIZE)
print('SR factor:', SR_FACTOR)

# Band (channel) configurations:
N_PAN_BANDS = 1

# Number of MS bands each variation must end up with (it is encoded in the
# variation name). Used below to catch band-name mismatches early.
EXPECTED_N_MS_BANDS = {'e01-8': 8, 'e01-6': 6, 'e01-4': 4, 'e01-3': 3}

# WorldView-2: mapping of band name -> band index, filtered per variation.
WV02_FULL_BAND_CONFIG = get_sensor_bands('WV02', meta)
WV02_EXP_BAND_CONFIGS = {
    'e01-8': WV02_FULL_BAND_CONFIG,
    'e01-6': {k: v for (k, v) in WV02_FULL_BAND_CONFIG.items()
              if k not in ('Coastal', 'NIR2')},
    'e01-4': {k: v for (k, v) in WV02_FULL_BAND_CONFIG.items()
              if k in ('Blue', 'Green', 'Red', 'NIR')},
    'e01-3': {k: v for (k, v) in WV02_FULL_BAND_CONFIG.items()
              if k in ('Blue', 'Green', 'Red')},
}
MS_BANDS_WV02_CONFIG = WV02_EXP_BAND_CONFIGS[EXPERIMENT]
# 'all' is passed instead of an index list when no band is dropped
MS_BANDS_WV02_IDXS = ('all' if EXPERIMENT == 'e01-8'
                      else list(MS_BANDS_WV02_CONFIG.values()))
N_MS_BANDS = len(MS_BANDS_WV02_CONFIG)

# The filters above match band names as plain strings. A name that differs
# from the keys returned by get_sensor_bands (e.g. 'NIR' vs 'NIR1') would
# silently drop a band and change the network's input channel count, so fail
# loudly here instead.
if N_MS_BANDS != EXPECTED_N_MS_BANDS[EXPERIMENT]:
    raise ValueError(
        'Experiment {} expects {} WV02 MS bands but the band filter selected {}: {}. '
        'Check the band names against get_sensor_bands("WV02", meta).'.format(
            EXPERIMENT, EXPECTED_N_MS_BANDS[EXPERIMENT], N_MS_BANDS,
            list(MS_BANDS_WV02_CONFIG)))

# GeoEye-1: only used by the 4- and 3-band variations; the other variations
# get a {None: None} placeholder.
GE01_FULL_BAND_CONFIG = get_sensor_bands('GE01', meta)
GE01_EXP_BAND_CONFIGS = {
    'e01-8': {None: None},
    'e01-6': {None: None},
    'e01-4': GE01_FULL_BAND_CONFIG,
    'e01-3': {k: v for (k, v) in GE01_FULL_BAND_CONFIG.items() if k not in ('NIR',)},
}
MS_BANDS_GE01_CONFIG = GE01_EXP_BAND_CONFIGS[EXPERIMENT]
MS_BANDS_GE01_IDXS = ('all' if EXPERIMENT == 'e01-4'
                      else list(MS_BANDS_GE01_CONFIG.values()))

# GE01 tiles are read with the same N_MS_BANDS as WV02 tiles, so whenever GE01
# is actually used (i.e. not the placeholder) the band counts must agree.
if None not in MS_BANDS_GE01_CONFIG and len(MS_BANDS_GE01_CONFIG) != N_MS_BANDS:
    raise ValueError(
        'Experiment {} selects {} GE01 MS bands ({}) but {} WV02 MS bands.'.format(
            EXPERIMENT, len(MS_BANDS_GE01_CONFIG),
            list(MS_BANDS_GE01_CONFIG), N_MS_BANDS))

print('MS (LR) Band Config WV02:', MS_BANDS_WV02_CONFIG)
print('MS (LR) Band Config GE01:', MS_BANDS_GE01_CONFIG)

# Tile sampling density (expected number of tiles per square metre of image),
# identical for all partitions
TILES_PER_M2 = dict.fromkeys(('train', 'val', 'test'), 0.01)

# Image resizing options
RESIZE_TO_PIXEL_SIZE = False
RESIZE_RESAMPLING_METHOD = 'nearest'  # 'nearest', 'bicubic', 'bilinear'

# Cloud/sea tile removal
CLOUD_SEA_REMOVAL = True
CLOUD_SEA_WEIGHTS_PATH = ('models/cloud-sea-classifier/'
                          'cloudsea-effb0-augm-bicubic-pan-0.0005--200-0.129130.h5')
CLOUD_SEA_PRED_CUTOFF = 0.95

BATCH_SIZE = {'train': 16, 'val': 8, 'test': 8}

SHAVE_WIDTH = 4
MATLAB_PATH = 'modules/matlab'


# Optimizer parameters
PRETRAIN_LOSS = 'l1'
PRETRAIN_LR = 5e-5  # commented-out alternative: 0.0002 (marked "Official")
PRETRAIN_BETA_1, PRETRAIN_BETA_2 = 0.9, 0.999

# Generator and discriminator share the same optimizer settings.
# Commented-out alternatives for both learning rates: 1e-4 (marked "Official"), 5e-5
GAN_G_LR = GAN_D_LR = 2e-5
G_BETA_1 = D_BETA_1 = 0.9
G_BETA_2 = D_BETA_2 = 0.999

# Weights of the generator loss terms
G_LOSS_PIXEL_W = 0.01
G_LOSS_PERCEP_W = 1.0
G_LOSS_GENERATOR_W = 0.005

# RRDB Model parameters
N_BLOCKS = 16  # commented-out alternative: 23
N_FILTERS = 64

## 2. Tile generation

### 2.1 Image resizing

### 2.2 Tile allocation

We allocate `n_tiles` to each satellite image in proportion to the area covered by the satellite image. We adjust `n_tiles` by the argument `tiles_per_m2`. If `tiles_per_m2=1.0` then `n_tiles` is set deterministically to a value so that a square meter of satellite image is expected to be covered by `1.0` tile.

In [None]:
# Allocate a tile count to every image; the result is stored in the
# 'n_tiles' column of the metadata dataframe.
meta = allocate_tiles_by_expected(
    meta,
    override_pan_pixel_size=RESIZE_TO_PIXEL_SIZE,
    by_partition=True,
    tiles_per_m2_train_val_test=tuple(TILES_PER_M2[partition]
                                      for partition in ('train', 'val', 'test')),
    pan_tile_size_train_val_test=tuple(PAN_SIZE[partition]
                                       for partition in ('train', 'val', 'test')),
    new_column_name='n_tiles')

# Tile counts: per partition and overall
n_tiles = {partition: count_tiles_in_partition(meta, partition)
           for partition in ('train', 'val', 'test')}
n_tiles_total = count_tiles(meta)
# Sanity check: the per-partition counts must add up to the overall count
assert sum(n_tiles.values()) == n_tiles_total
print('Number of tiles per partition:')
print(n_tiles)
print('Total number of tiles:', n_tiles_total)

### 2.3 Tile generation to disk

In [None]:
# Tiles are only written to disk when explicitly requested via GENERATE_TILES.
if GENERATE_TILES:
    ms_tile_sizes = tuple(MS_SIZE[partition] for partition in ('train', 'val', 'test'))
    meta = generate_all_tiles(
        meta,
        save_dir=DATA_PATH_TILES,
        sr_factor=SR_FACTOR,
        by_partition=True,
        ms_tile_size_train_val_test=ms_tile_sizes,
        cloud_sea_removal=CLOUD_SEA_REMOVAL,
        cloud_sea_weights_path=CLOUD_SEA_WEIGHTS_PATH,
        cloud_sea_pred_cutoff=CLOUD_SEA_PRED_CUTOFF,
        save_meta_to_disk=True)

In [None]:
# Mean and standard deviation of the training tiles under DATA_PATH_TILES.
# sample_proportion=1.0 presumably means all tiles are used -- confirm in
# modules. With write_json=True the result is also written to disk; section 3.1
# reads it back with read_mean_sd_json(DATA_PATH_TILES).
train_tiles_mean, train_tiles_sd = mean_sd_of_train_tiles(DATA_PATH_TILES, 
                                                          sample_proportion=1.0, 
                                                          write_json=True)

## 3. Data input pipeline from disk

### 3.1 Training set

In [None]:
# Shuffle buffers are as large as the number of tiles allocated to each
# partition (the original noted a fixed 100 as an alternative).
SHUFFLE_BUFFER_SIZE = {partition: n_tiles[partition]
                       for partition in ('train', 'val', 'test')}

train_tiles_mean, train_tiles_sd = read_mean_sd_json(DATA_PATH_TILES)

train_val_test = 'train'
sensor = SENSORS[train_val_test]  # a single sensor name for the training set
train_tile_px = {'ms': MS_SIZE[train_val_test], 'pan': PAN_SIZE[train_val_test]}
train_cache_file = '{}/ds_{}-{}-{}_cache'.format(DATA_PATH_TILES, EXPERIMENT,
                                                 train_val_test, sensor)

train_pipeline = GeotiffDataset(
    tiles_path=DATA_PATH_TILES_P[train_val_test],
    batch_size=BATCH_SIZE[train_val_test],
    ms_tile_shape=(train_tile_px['ms'], train_tile_px['ms'], N_MS_BANDS),
    pan_tile_shape=(train_tile_px['pan'], train_tile_px['pan'], N_PAN_BANDS),
    sensor=sensor,
    band_selection=MS_BANDS_WV02_IDXS,
    mean_correction=train_tiles_mean,
    cache_memory=True,
    cache_file=train_cache_file,
    repeat=True,
    shuffle=True,
    shuffle_buffer_size=SHUFFLE_BUFFER_SIZE[train_val_test])

# Output range of the scaler; needed later to calculate PSNR and SSIM
scaled_range = train_pipeline.get_scaler_output_range(print_ranges=True)

# The actual tf.data dataset object that is fed to the models
ds_train = train_pipeline.get_dataset()

### 3.2 Validation set

In [None]:
# Validation set can have several sensors and is organized in a dictionary
# structure: ds_val = {sensor: dataset} ... ex: ds_val = {'WV02': dataset_with_only_WV02_images}
train_val_test = 'val'

# Band selection per sensor. An explicit lookup makes an unsupported sensor
# fail with a clear message instead of silently reusing the band selection of
# the previous loop iteration (or raising a NameError on the first one).
BAND_IDXS_BY_SENSOR = {'WV02': MS_BANDS_WV02_IDXS, 'GE01': MS_BANDS_GE01_IDXS}

ds_val = {}
for sensor in SENSORS[train_val_test]:
    if sensor not in BAND_IDXS_BY_SENSOR:
        raise ValueError('No band selection defined for sensor: {}'.format(sensor))
    band_indices = BAND_IDXS_BY_SENSOR[sensor]
    val_pipeline = GeotiffDataset(
        tiles_path=DATA_PATH_TILES_P[train_val_test],
        batch_size=BATCH_SIZE[train_val_test],
        ms_tile_shape=(MS_SIZE[train_val_test], MS_SIZE[train_val_test], N_MS_BANDS),
        pan_tile_shape=(PAN_SIZE[train_val_test], PAN_SIZE[train_val_test], N_PAN_BANDS),
        sensor=sensor,
        band_selection=band_indices,
        mean_correction=train_tiles_mean,  # statistics of the training tiles
        cache_memory=True,
        cache_file=str(DATA_PATH_TILES + '/ds_' + EXPERIMENT + '-'
                       + train_val_test + '-' + sensor + '_cache'),
        repeat=True,
        shuffle=True,
        shuffle_buffer_size=SHUFFLE_BUFFER_SIZE[train_val_test])
    # Keep only the actual tf.data dataset object
    ds_val[sensor] = val_pipeline.get_dataset()
print(ds_val.keys())

### 3.3 Test set

In [None]:
# Test set: like the validation set, a dictionary {sensor: dataset}. It is
# read once and in order (no repeat, no shuffle, no in-memory cache).
train_val_test = 'test'

# Band selection per sensor. An explicit lookup makes an unsupported sensor
# fail with a clear message instead of silently reusing the band selection of
# the previous loop iteration (or raising a NameError on the first one).
BAND_IDXS_BY_SENSOR = {'WV02': MS_BANDS_WV02_IDXS, 'GE01': MS_BANDS_GE01_IDXS}

ds_test = {}
for sensor in SENSORS[train_val_test]:
    if sensor not in BAND_IDXS_BY_SENSOR:
        raise ValueError('No band selection defined for sensor: {}'.format(sensor))
    band_indices = BAND_IDXS_BY_SENSOR[sensor]
    test_pipeline = GeotiffDataset(
        tiles_path=DATA_PATH_TILES_P[train_val_test],
        batch_size=BATCH_SIZE[train_val_test],
        ms_tile_shape=(MS_SIZE[train_val_test], MS_SIZE[train_val_test], N_MS_BANDS),
        pan_tile_shape=(PAN_SIZE[train_val_test], PAN_SIZE[train_val_test], N_PAN_BANDS),
        sensor=sensor,
        band_selection=band_indices,
        mean_correction=train_tiles_mean,  # statistics of the training tiles
        cache_memory=False,
        cache_file=str(DATA_PATH_TILES + '/ds_' + EXPERIMENT + '-'
                       + train_val_test + '-' + sensor + '_cache'),
        repeat=False,
        shuffle=False,
        shuffle_buffer_size=SHUFFLE_BUFFER_SIZE[train_val_test])
    # Keep only the actual tf.data dataset object
    ds_test[sensor] = test_pipeline.get_dataset()
print(ds_test.keys())

In [None]:
# Peek at (at most) the first 11 GE01 validation batches and print the shape
# of the first element of each batch.
for batch_no, batch in enumerate(ds_val['GE01'], start=1):
    print(batch[0].shape)
    if batch_no > 10:
        break

## 4. Models


### 4.1 Bicubic baseline model

In [None]:
# Bicubic upsampling baseline (upsamples by SR_FACTOR), reporting the same
# PSNR/SSIM metrics as the trained model. scaled_range is the scaler output
# range obtained from the training pipeline; it is needed to calculate PSNR
# and SSIM.
bicubic = build_deterministic_sr_model(upsample_factor=SR_FACTOR,
                                       resize_method='bicubic',
                                       loss='mean_absolute_error',
                                       metrics=('PSNR', 'SSIM'),
                                       scaled_range=scaled_range)

### 4.2 ESRGAN Generator model (pretrain version)

In [None]:
# Generator built in 'pretrain' mode, using the pretraining loss and optimizer
# settings from section 1. It takes the selected MS bands as input channels
# (N_MS_BANDS) and outputs N_PAN_BANDS channels.
pretrain_model =  build_generator(pretrain_or_gan='pretrain', 
                                  pretrain_learning_rate=PRETRAIN_LR, 
                                  pretrain_loss_l1_l2=PRETRAIN_LOSS,
                                  pretrain_beta_1=PRETRAIN_BETA_1, 
                                  pretrain_beta_2=PRETRAIN_BETA_2, 
                                  pretrain_metrics=('PSNR', 'SSIM'),
                                  scaled_range=scaled_range, 
                                  n_channels_in=N_MS_BANDS, 
                                  n_channels_out=N_PAN_BANDS, 
                                  height_width_in=None,  # None will make network image size agnostic
                                  n_filters=N_FILTERS, 
                                  n_blocks=N_BLOCKS)
# Uncomment to print the layer overview of the model:
# pretrain_model.summary()

## 5. Pretraining

In [None]:
# Secondary validation sensor. GE01 is only part of the validation sets of the
# 4- and 3-band variations, so secondary logging is switched off when it is
# missing instead of failing with a KeyError.
SECOND_VAL_SENSOR = 'GE01'
HAS_SECOND_VAL_SENSOR = SECOND_VAL_SENSOR in ds_val

logger = EsrganLogger(
    model_name=EXPERIMENT + '-pretrain',
    log_tensorboard=True,
    tensorboard_logs_dir=LOGS_EXP_DIR + '/tb',
    save_models=True,
    models_save_dir=LOGS_EXP_DIR + '/models',
    save_weights_only=True,
    log_train_images=False,
    model=pretrain_model,
    n_train_image_batches=1,
    train_image_dataset=ds_train,
    log_val_images=False,
    n_val_image_batches=1,
    val_image_dataset=None,  # This can also be dict with different sensors
    log_val_secondary_sensor=HAS_SECOND_VAL_SENSOR,
    # Monitor on the GE01 *validation* set, as stated in the experiment
    # description. The test set must stay untouched during training, otherwise
    # it leaks into model selection.
    # NOTE(review): assumes EsrganLogger accepts None here when
    # log_val_secondary_sensor is False (as val_image_dataset does) -- confirm.
    val_second_dataset=ds_val.get(SECOND_VAL_SENSOR),
    val_second_name=SECOND_VAL_SENSOR,
    val_second_steps=10)

In [None]:
# Show what the logger returns as callbacks (displayed as the cell output);
# the same call is passed to `fit(callbacks=...)` in the pretraining cell.
logger.get_callbacks()

In [None]:
# List the validation datasets: one line per sensor with its dataset object
for sensor_and_dataset in ds_val.items():
    print(*sensor_and_dataset)

In [None]:
#show_batch(iter(ds_val['GE01']).get_next())

In [None]:
EPOCHS = 2

# Pretraining run: EPOCHS epochs of 10 training steps each, validated on
# 10 batches of the WV02 validation set after every epoch.
fit_options = dict(
    epochs=EPOCHS,
    validation_data=ds_val['WV02'],
    steps_per_epoch=10,
    validation_steps=10,
    initial_epoch=0,
    callbacks=logger.get_callbacks(),
)
history = pretrain_model.fit(ds_train, **fit_options)
#pretrain_model.save_weights("models/esrgan-psnr-train-WV02-val-WV02-98-0.001750.h5")