# Experiment 01: Band configurations

In this experiment we will study how different band configurations affect the training process and end results.

### Experiment variations:

- E01-8: All 8 MS WorldView-2 bands are used.
- E01-6: The 6 MS WorldView-2 bands that overlap the PAN band (RGB+Yellow+RedEdge+NIR1).
- E01-4: The 4 MS WorldView-2 bands that are also available in GeoEye-1 (the MS1 array: RGB+NIR1).
  - In this variation we will also validate performance on the GeoEye-1 validation set.
- E01-3: Only the 3 RGB bands from WorldView-2.
  - Also validated on the GeoEye-1 validation set.
 
### The notebook is divided into the following main sections:
1. Imports and configuration parameters
2. Tile generation (sampling of tiles from the satellite images)
3. Building of models and tile input pipelines (`tf.data.Dataset` objects reading tiles from disk)
4. Pretraining with L1 loss
5. GAN training with L1 + perceptual + GAN (adversarial) loss
6. Inspection of results

Training history is logged with TensorBoard.

## 1. Imports and config

In [None]:
import pandas as pd
import tensorflow as tf

# Internal modules
from modules.helpers import *
from modules.tile_generator import *
from modules.matlab_metrics import *
from modules.image_utils import *

# Check GPUs and enable dynamic GPU memory use:
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            # Prevent TensorFlow from allocating all memory of all GPUs:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Per the TensorFlow docs, memory growth cannot be changed once the GPUs
        # have been initialized; report the error instead of aborting the notebook.
        print(e)

# Show every row and column when pandas DataFrames (e.g. `meta`) are displayed.
# `pd` is imported explicitly at the top of the notebook rather than relying on a star import.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)

In [None]:
EXPERIMENT_NAMES = ['e01-8', 'e01-6', 'e01-4', 'e01-3']
EXPERIMENT = EXPERIMENT_NAMES[0]  # Experiment variation run in this session

# Load metadata dataframe from repository root
meta = load_meta_pickle_csv('.', 'metadata_df', from_pickle=True)

# Paths
DATA_PATH = 'data/toulon-laspezia'
DATA_PATH_TILES = 'data/toulon-laspezia-tiles/e01'

# Subset by sensor and area (tiles are generated for every sensor/area listed here)
SENSORS_GENERATE = ['WV02', 'GE01']
AREAS_GENERATE = ['La_Spezia', 'Toulon']
meta = subset_by_areas_sensor(meta, areas=AREAS_GENERATE, sensors=SENSORS_GENERATE)
print('Sensors to generate tiles from:', SENSORS_GENERATE)
print('Areas to generate tiles from:', AREAS_GENERATE)

# Count images in partitions
N_IMAGES_TOTAL = count_images(meta)
N_IMAGES = {'train': count_images_in_partition(meta, 'train'), 
            'val': count_images_in_partition(meta, 'val'), 
            'test': count_images_in_partition(meta, 'test')}
assert N_IMAGES_TOTAL == sum(N_IMAGES.values())  # Verify that different ways of counting add up
print('Number of images in partitions', N_IMAGES)
print('Total number of images:', N_IMAGES_TOTAL)

# Sensors used in each experiment variation, per partition
SENSORS_EXP = {'e01-8': {'train': 'WV02', 'val': 'WV02', 'test': 'WV02'}, 
               'e01-6': {'train': 'WV02', 'val': 'WV02', 'test': 'WV02'}, 
               'e01-4': {'train': 'WV02', 'val': ['WV02', 'GE01'], 'test': ['WV02', 'GE01']}, 
               'e01-3': {'train': 'WV02', 'val': ['WV02', 'GE01'], 'test': ['WV02', 'GE01']}}
SENSORS = SENSORS_EXP[EXPERIMENT]

# Areas used in each experiment variation, per partition
AREAS_EXP = {'e01-8': {'train': AREAS_GENERATE, 'val': AREAS_GENERATE, 'test': AREAS_GENERATE}, 
             'e01-6': {'train': AREAS_GENERATE, 'val': AREAS_GENERATE, 'test': AREAS_GENERATE}, 
             'e01-4': {'train': AREAS_GENERATE, 'val': AREAS_GENERATE, 'test': AREAS_GENERATE}, 
             'e01-3': {'train': AREAS_GENERATE, 'val': AREAS_GENERATE, 'test': AREAS_GENERATE}}
AREAS = AREAS_EXP[EXPERIMENT]


# Set tile dimensions: PAN (HR) tile side = MS (LR) tile side * SR_FACTOR
SR_FACTOR = 4
MS_SIZE = {'train': 32, 'val': 128, 'test': 128}
PAN_SIZE = {partition: ms_size * SR_FACTOR for partition, ms_size in MS_SIZE.items()}
print('MS (LR) tile size:', MS_SIZE)
print('PAN (HR) tile size:', PAN_SIZE)
print('SR factor:', SR_FACTOR)

# Band (channel) configurations:
N_PAN_BANDS = 1

WV02_FULL_BAND_CONFIG = get_sensor_bands('WV02', meta)
# NOTE(review): 'e01-4' keeps a band named 'NIR' while 'e01-6' drops 'NIR2'; confirm that the
# WV02 band names returned by get_sensor_bands contain 'NIR' (and not only 'NIR1').
# The band-count check further down fails loudly if the filter selects the wrong number of bands.
WV02_EXP_BAND_CONFIGS = {'e01-8': WV02_FULL_BAND_CONFIG, 
                        'e01-6': {k:v for (k,v) in WV02_FULL_BAND_CONFIG.items() if k not in ['Coastal', 'NIR2']}, 
                        'e01-4': {k:v for (k,v) in WV02_FULL_BAND_CONFIG.items() if k in ['Blue', 'Green', 'Red', 'NIR']},
                        'e01-3': {k:v for (k,v) in WV02_FULL_BAND_CONFIG.items() if k in ['Blue', 'Green', 'Red']}}
MS_BANDS_WV02_CONFIG = WV02_EXP_BAND_CONFIGS[EXPERIMENT]
MS_BANDS_WV02_IDXS = list(MS_BANDS_WV02_CONFIG.values())

GE01_FULL_BAND_CONFIG = get_sensor_bands('GE01', meta)
# {None: None} is a placeholder for variations that do not use GE01 data
GE01_EXP_BAND_CONFIGS = {'e01-8': {None: None}, 
                        'e01-6': {None: None}, 
                        'e01-4': GE01_FULL_BAND_CONFIG, 
                        'e01-3': {k:v for (k,v) in GE01_FULL_BAND_CONFIG.items() if k not in ['NIR']}}
MS_BANDS_GE01_CONFIG = GE01_EXP_BAND_CONFIGS[EXPERIMENT]
MS_BANDS_GE01_IDXS = list(MS_BANDS_GE01_CONFIG.values())
print('MS (LR) Band Config WV02:', MS_BANDS_WV02_CONFIG)
print('MS (LR) Band Config GE01:', MS_BANDS_GE01_CONFIG)

# The experiment name suffix encodes the intended number of MS bands ('e01-4' -> 4).
# Fail early if a band-name filter above selected a different number of bands
# (e.g. because of a band-name mismatch), and check that GE01 agrees with WV02 when it is used.
N_MS_BANDS_EXPECTED = int(EXPERIMENT.rsplit('-', 1)[1])
assert len(MS_BANDS_WV02_CONFIG) == N_MS_BANDS_EXPECTED, \
    f'{EXPERIMENT}: expected {N_MS_BANDS_EXPECTED} WV02 bands, got {list(MS_BANDS_WV02_CONFIG)}'
if None not in MS_BANDS_GE01_CONFIG:
    assert len(MS_BANDS_GE01_CONFIG) == N_MS_BANDS_EXPECTED, \
        f'{EXPERIMENT}: expected {N_MS_BANDS_EXPECTED} GE01 bands, got {list(MS_BANDS_GE01_CONFIG)}'

TILES_PER_M2 = {'train': 0.1, 
                'val': 0.1, 
                'test': 0.1}

RESIZE_TO_PIXEL_SIZE = False
RESIZE_RESAMPLING_METHOD = 'nearest'  # 'nearest', 'bicubic', 'bilinear'

CLOUD_SEA_REMOVAL = False
CLOUD_SEA_WEIGHTS_PATH = 'models/cloud-sea-classifier/cloudsea-effb0-augm-bicubic-pan-0.0005--200-0.129130.h5'
CLOUD_SEA_PRED_CUTOFF = 0.9

BATCH_SIZE = {'train': 16, 'val': 8, 'test': 8}

SHAVE_WIDTH = 4
MATLAB_PATH = 'modules/matlab'

## 2. Tile generation

### 2.1 Image resizing

### 2.2 Tile allocation

We allocate `n_tiles` to each satellite image in proportion to the area it covers. The argument `tiles_per_m2` scales `n_tiles`: with `tiles_per_m2=1.0`, `n_tiles` is set deterministically so that each square meter of the satellite image is expected to be covered by one tile.

In [None]:
# Give every image a tile budget proportional to its area (per partition), stored in 'n_tiles'
meta = allocate_tiles_by_expected(
    meta,
    override_pan_pixel_size=RESIZE_TO_PIXEL_SIZE,
    by_partition=True,
    tiles_per_m2_train_val_test=tuple(TILES_PER_M2[part] for part in ('train', 'val', 'test')),
    pan_tile_size_train_val_test=tuple(PAN_SIZE[part] for part in ('train', 'val', 'test')),
    new_column_name='n_tiles',
)

# Tally the allocated tiles per partition and check they add up to the overall count
n_tiles = {part: count_tiles_in_partition(meta, part) for part in ('train', 'val', 'test')}
n_tiles_total = count_tiles(meta)
assert n_tiles_total == sum(n_tiles.values())
print('Number of tiles per partition:')
print(n_tiles)
print('Total number of tiles:', n_tiles_total)

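### 2.3 Tile generation to disk

The cell below computes the mean and standard deviation of the training tiles in `DATA_PATH_TILES + '/train'` and writes them to JSON (`write_json=True`). Section 3 reads these values back and uses the mean for mean correction.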

In [None]:
# Compute the mean and standard deviation over every training tile and persist them as JSON
mean_sd_of_train_tiles(f'{DATA_PATH_TILES}/train', write_json=True, sample_proportion=1)

## 3. Data input pipeline from disk

In [None]:
# Read back the training-tile mean/sd written by mean_sd_of_train_tiles in section 2.3.
# The training mean is used for mean correction in all partitions.
mean_of_train_tiles, sd_of_train_tiles = read_mean_sd_json(DATA_PATH_TILES + '/train')

# NOTE(review): DATA_PATH_TILES_{TRAIN,VAL,TEST}, MS_HEIGHT/MS_WIDTH/MS_BANDS,
# PAN_HEIGHT/PAN_WIDTH/PAN_BANDS and SHUFFLE_BUFFER_{TRAIN,VAL,TEST} are not defined in the
# config cell above -- confirm where they are defined.
# NOTE(review): BATCH_SIZE is a per-partition dict in the config cell, and MS_SIZE/PAN_SIZE differ
# between train and val/test; confirm that dataset_from_tif_tiles accepts the dict and that one
# tile shape for all partitions is intended.

ds_train = dataset_from_tif_tiles(DATA_PATH_TILES_TRAIN, BATCH_SIZE, 
                                  ms_tile_shape=(MS_HEIGHT,MS_WIDTH,MS_BANDS), 
                                  pan_tile_shape=(PAN_HEIGHT,PAN_WIDTH,PAN_BANDS),
                                  imitate_ge01=False, imitation_bands=None, mean_correction=mean_of_train_tiles,
                                  shuffle_buffer_size=SHUFFLE_BUFFER_TRAIN,
                                  cache_memory=True, cache_file=str(DATA_PATH_TILES+'/ds_train_cache'))

ds_val = dataset_from_tif_tiles(DATA_PATH_TILES_VAL, BATCH_SIZE, 
                                ms_tile_shape=(MS_HEIGHT,MS_WIDTH,MS_BANDS), 
                                pan_tile_shape=(PAN_HEIGHT,PAN_WIDTH,PAN_BANDS),
                                imitate_ge01=False, imitation_bands=None, mean_correction=mean_of_train_tiles,
                                shuffle_buffer_size=SHUFFLE_BUFFER_VAL, 
                                cache_memory=True, cache_file=str(DATA_PATH_TILES+'/ds_val_cache'))

# Each partition needs its own cache file; sharing the train cache would mix up the datasets.
ds_test = dataset_from_tif_tiles(DATA_PATH_TILES_TEST, BATCH_SIZE, 
                                 ms_tile_shape=(MS_HEIGHT,MS_WIDTH,MS_BANDS), 
                                 pan_tile_shape=(PAN_HEIGHT,PAN_WIDTH,PAN_BANDS),
                                 imitate_ge01=False, imitation_bands=None, mean_correction=mean_of_train_tiles,
                                 shuffle_buffer_size=SHUFFLE_BUFFER_TEST,
                                 cache_memory=True, cache_file=str(DATA_PATH_TILES+'/ds_test_cache'))

In [None]:
# Display the training-tile mean used for mean correction in the input pipelines
mean_of_train_tiles

In [None]:
# Display the training-tile mean rounded to one decimal. np.round works whether the value read
# from JSON is a scalar, a list or an array; the builtin round() raises TypeError on a list.
import numpy as np
np.round(mean_of_train_tiles, 1)
