# Data generator demonstration

This notebook demonstrates the ramp data generators, which feed training data to the ramp model during training. It shows them with and without augmentation, and with and without class weights used to weight the loss function.

In [None]:
# Render matplotlib figures inline, directly below the cell that creates them.
%matplotlib inline

In [None]:
import os

# TF_CPP_MIN_LOG_LEVEL is read by TensorFlow's C++ runtime and cached on first
# use, so it must be set *before* `import tensorflow`; setting it afterwards may
# be ignored. Levels: '0' = show all, '1' = hide INFO, '2' = hide INFO and
# WARNING (errors still shown), '3' = also hide ERROR.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # only print errors

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

In [None]:
# set up logging
import logging
logging.basicConfig(level = logging.INFO)

In [None]:
import sys

# RAMP_HOME must be the directory that contains the `ramp-code` repository;
# the sample-data paths in the next cell are built relative to it.
try:
    RAMP_HOME = os.environ["RAMP_HOME"]
except KeyError as err:
    # Same exception type as before, but tell the user how to fix it.
    raise KeyError(
        "RAMP_HOME is not set: set the RAMP_HOME environment variable to the "
        "directory that contains the 'ramp-code' repository."
    ) from err

from ramp.data_mgmt.data_generator import training_batches_from_gtiff_dirs, add_class_weights

In [None]:
# Bundled sample training data: image chips and their masks ("multimasks").
train_base = os.path.join(RAMP_HOME, 'ramp-code/notebooks/sample-data/training_data')
image_dir, mask_dir = (os.path.join(train_base, subdir) for subdir in ("chips", "multimasks"))

# Batch size and input/output chip sizes passed to the data generators.
batch_size = 16
input_image_size = output_image_size = (256, 256)

### Demonstration with class weights for weighting the loss function

In [None]:
# Batched dataset of (chip, label) pairs with no augmentation; class weights
# are attached afterwards with .map(add_class_weights).
train_batches_wts = training_batches_from_gtiff_dirs(
    image_dir,
    mask_dir,
    batch_size,
    input_image_size,
    output_image_size,
)

In [None]:
def visualize_all(image, mask, weights=None, fontsize=18):
    """Plot an image chip beside its label mask and, optionally, its weight mask.

    Parameters
    ----------
    image : array-like
        Image to display; anything ``plt.imshow`` accepts (e.g. one sample
        sliced from an eager batch tensor).
    mask : array-like
        Label mask for ``image``.
    weights : array-like, optional
        Class-weight mask for ``image``. When given, a third panel is drawn.
    fontsize : int, optional
        Font size of the panel titles (default 18).
    """
    panels = [(image, 'image'), (mask, 'label')]
    if weights is not None:
        panels.append((weights, 'class weight mask'))

    # Give each panel the same width, so three panels are not squeezed into
    # the space used for two and the large titles do not overlap.
    fig, axes = plt.subplots(1, len(panels), figsize=(4 * len(panels), 4))
    for ax, (data, title) in zip(axes, panels):
        ax.imshow(data)
        ax.set_title(title, fontsize=fontsize)
    fig.tight_layout()

In [None]:
# Weights handed to add_class_weights; presumably one entry per mask class
# index. NOTE(review): confirm the ordering against add_class_weights.
class_weights = tf.constant([1.0, 1.0, 2.0, 2.0])

# Attach a class-weight mask to every (chip, label) batch, then take the first.
iterator_wts = iter(
    train_batches_wts.map(
        lambda chip, label: add_class_weights(chip, label, class_weights)
    )
)
batch = iterator_wts.get_next()

In [None]:
# batch has length 3: image, mask, class weights.
# Check it, so a changed pipeline fails here instead of in the plotting loop.
assert len(batch) == 3, f"expected (image, mask, class weights), got {len(batch)} items"
len(batch)

In [None]:
# Print the shape of every tensor in the batch (image, mask, class weights).
for tensor in batch:
    print(tensor.shape)

In [None]:
# Plot every sample in the batch. Iterating the tensors directly (instead of
# range(batch_size)) stays correct if a batch holds fewer than batch_size
# samples, e.g. with a small dataset.
for image, mask, wts in zip(batch[0], batch[1], batch[2]):
    visualize_all(image, mask, wts)

### Demonstration with augmentation

In [None]:
import albumentations as A
from cv2 import BORDER_CONSTANT, INTER_NEAREST

# Augmentation pipeline passed to the data generator below:
#  - random rotation with nearest-neighbour resampling, filling the exposed
#    border with 0 in both the image and the mask;
#  - random brightness/contrast jitter (limits of 0.2, brightness by max).
# Each transform is applied with probability 0.9.
aug = A.Compose(
    [
        A.Rotate(
            p=0.9,
            border_mode=BORDER_CONSTANT,
            interpolation=INTER_NEAREST,
            value=(0.0, 0.0, 0.0),
            mask_value=0,
        ),
        A.RandomBrightnessContrast(
            p=0.9,
            brightness_limit=0.2,
            contrast_limit=0.2,
            brightness_by_max=True,
        ),
    ]
)

In [None]:
# note addition of augmentation transform parameter
# Same generator as before, plus the augmentation transform `aug` as the
# extra (sixth) argument.
train_batches_aug = training_batches_from_gtiff_dirs(
    image_dir, mask_dir, batch_size,
    input_image_size, output_image_size,
    aug,
)

In [None]:
# Fetch the first augmented batch; no class weights are attached here.
iterator_aug = iter(train_batches_aug)
batch = iterator_aug.get_next()

len(batch)

In [None]:
# Plot every augmented sample. Iterating the tensors directly (instead of
# range(batch_size)) stays correct if a batch holds fewer than batch_size
# samples.
for image, mask in zip(batch[0], batch[1]):
    visualize_all(image, mask)

### Demonstration with simultaneous augmentation and class weighting

In [None]:
# Rebuild the augmented generator, then attach the same class weights as in
# the class-weighting demo, so each batch carries (image, mask, class weights).
train_batches_aug = training_batches_from_gtiff_dirs(
    image_dir, mask_dir, batch_size,
    input_image_size, output_image_size,
    aug,
)

iterator_aug_wts = iter(
    train_batches_aug.map(
        lambda chip, label: add_class_weights(chip, label, class_weights)
    )
)
batch = iterator_aug_wts.get_next()

len(batch)

In [None]:
# Plot every augmented, class-weighted sample. Iterating the tensors directly
# (instead of range(batch_size)) stays correct if a batch holds fewer than
# batch_size samples.
for image, mask, wts in zip(batch[0], batch[1], batch[2]):
    visualize_all(image, mask, wts)

##### Created for ramp project, August 2022
##### Author: carolyn.johnston@dev.global