# Create Training Samples UNET

Script to create the training samples for the U-Net.

## 1. Importing libraries and changing working directory

Reference for the libraries:

+ [numpy](https://numpy.org/)
+ [gdal](https://gdal.org/api/python.html)
+ [deepgeo](https://github.com/rvmaretto/deepgeo)
+ [skimage](https://scikit-image.org/)
+ [os](https://docs.python.org/3/library/os.html)

In [None]:
import numpy as np
import gdal
from deepgeo.dataset import rasterizer
import deepgeo.dataset.preprocessor as prep
import deepgeo.common.geofunctions as gf
import deepgeo.common.visualization as vis
import deepgeo.dataset.dataset_generator as dg
import deepgeo.dataset.data_augment as dtaug
import deepgeo.networks.model_builder as mb
import skimage
import os

In [None]:
# Folder where all data is stored: the "Data" directory located two levels
# above the directory the notebook is started from. The path is assembled
# with os.path so it does not depend on '/' being the path separator
# (splitting on a literal '/' leaves the path unchanged on Windows).
data_dir = os.path.join(os.path.dirname(os.path.dirname(os.getcwd())), 'Data')
os.chdir(data_dir)

## 2. Defining parameters

In [None]:
# --- Sample identification ---------------------------------------------
cell = '089098'        # cell from which the training samples are created
state = 'BA'           # state
situation = 'appr1'    # training samples approach identifier
platform = 'Sentinel'  # platform
year = '2019'          # year

# --- Input files ---------------------------------------------------------
# shapefile with chips centroids
shp_chips = './ref/unet_chips.shp'

# raster under ./predictions/LSTM; the cell id appears twice in its name
# (once plain, once prefixed with "tr")
raster_file = './predictions/LSTM/' + '.'.join(
    [situation, year, cell, 'tr' + cell, 'PS', 'tif']
)

# reference shapefile for this year and cell
shape_file = './ref/ref_UNET_' + '_'.join([year, cell]) + '.shp'

In [None]:
# Name shared by the output folder and the dataset files written into it.
ds_file_name = '.'.join((situation, state, platform))
output_ds = './training_samples/UNET/' + ds_file_name

# Create the output folder together with any missing parent folders.
# exist_ok=True keeps this cell re-runnable: a plain os.mkdir() raises
# FileExistsError as soon as the folder is already there.
os.makedirs(output_ds, exist_ok=True)

## 3. Opening the data and the reference

In [None]:
# Load the raster, passing -9999 as its no_data value.
img_raster = gf.load_image(raster_file, no_data=-9999)

# Preview the loaded raster (cast to uint16) using bands 0, 0 and 1 as the
# three displayed channels, with contrast enabled.
vis.plot_rgb_img(
    img_raster.astype(np.uint16),
    bands=[0, 0, 1],
    contrast=True,
    title="Training Reference",
)

In [None]:
# Settings for rasterizing the reference shapefile.
non_class = 'no_data'       # handed to the rasterizer as non_class_name
class_column = 'Class'      # presumably the shapefile attribute holding the class - confirm
out_labels = 'labels.tiff'  # file the labeled raster is saved to

In [None]:
# Rasterizes the reference shapefile. The Rasterizer receives the shapefile,
# the raster file (presumably used as the grid/extent template - confirm),
# the class column and the name to use for the non-class.
# The numbered prints only report progress through the steps.
print('1...')
rstzr = rasterizer.Rasterizer(shape_file,
                              raster_file,
                              class_column,
                              #classes_interest=classes_of_interest,
                              non_class_name=non_class)
print('2...')
rstzr.collect_class_names()
print('3...')
rstzr.rasterize_layer()
print('4...')
# Class-name list: the non-class name first, then the names reported by the
# rasterizer. `non_class` is reused instead of repeating the 'no_data'
# literal so this list cannot drift away from the name given to the
# Rasterizer above.
m_class_names = [non_class] + rstzr.get_class_names()
print('5...')
rasterized_layer = rstzr.get_labeled_raster()
print('done')

In [None]:
# Save the labeled raster to the file named by `out_labels` (a GeoTIFF,
# going by the method name). NOTE(review): `out_labels` is a bare file name,
# so the file presumably lands in the current working directory - confirm.
rstzr.save_labeled_raster_to_gtiff(out_labels)

In [None]:
# Plot the rasterized labels. The palette is 'white' followed by four
# red/green pairs (nine colors in total); presumably matched positionally
# against m_class_names - confirm.
label_colors = ['white'] + ['red', 'green'] * 4
vis.plot_labels(rasterized_layer, m_class_names, colors=label_colors)

## 4. Training samples metadata

In [None]:
# Metadata describing the training samples; individual entries are read
# back by the pre-processing and chip-extraction cells.
dataset_description = {
    'years': year,
    # standardization method and target value range
    'standardization': 'norm_range',
    'range': {'min': -1, 'max': 1},
    'indexes_to_compute': 'none',
    'bands': ['LSTM Prediction', 'Slope'],
    'sensor': 'derived',
    'classes': m_class_names,
    # NOTE(review): -1 is also the lower bound of 'range' - confirm that
    # valid standardized pixels cannot be mistaken for no-data.
    'img_no_data': -1,
    'chip_size': 284,
    'tolerance_nodata': 0.95,
    # paths of the input files, one per line
    'notes': '\n'.join((raster_file, shp_chips, shape_file)),
}

## 5. Pre-process the data

In [None]:
# Create the preprocessor for the raster file, passing -9999 as its no_data
# value (the same value used when the raster is loaded for display).
preproc = prep.Preprocessor(raster_file, no_data=-9999)
# Switch the no-data value to the one recorded in the dataset metadata.
# NOTE(review): presumably this re-codes the no-data pixels in the loaded
# data - confirm against the Preprocessor implementation.
preproc.set_nodata_value(dataset_description['img_no_data'])

In [None]:
# Standardize the image with the method and value range recorded in the
# dataset metadata, then fetch the resulting stacked raster array.
preproc.standardize_image(dataset_description['standardization'], dataset_description['range'])
raster_img = preproc.get_array_stacked_raster()

In [None]:
# Histogram of the preprocessor's stacked raster, with the band names from
# the dataset metadata as legend.
vis.plot_image_histogram(
    preproc.get_array_stacked_raster(),
    title="Normalized Raster",
    legend=dataset_description['bands'],
)

## 6. Extract training samples

In [None]:
# Drop the references to the rasterizer and the preprocessor; this cell only
# works with the arrays obtained from them (rasterized_layer, raster_img).
rstzr = None
preproc = None

# Settings handed to generate_chips().
params = {
    'win_size': dataset_description['chip_size'],
    'class_of_interest': ['Deforestation'],  # , 'Not Deforestation'],
    'quantity': 1000,
    'class_names': m_class_names,
    'shp_path': shp_chips,
    # NOTE(review): this is the input raster file, not the rasterized labels
    # file (out_labels) - confirm that this is what 'labels_tif' expects.
    'labels_tif': raster_file,
}

# Build the generator from the standardized raster and the rasterized
# reference, using the 'centroids' strategy.
generator = dg.DatasetGenerator(
    [raster_img],
    [rasterized_layer],
    strategy='centroids',
    description=dataset_description,
)
generator.generate_chips(params)
chip_struct = generator.get_samples()
# vis.plot_chips(chip_struct, raster_img, bands=[0,2,1], contrast=True)

# Filter chips by no-data content using the tolerance from the metadata.
generator.remove_no_data(tolerance=dataset_description['tolerance_nodata'])

# NOTE(review): chip_struct was fetched before remove_no_data(); if that call
# rebinds the generator's samples instead of editing them in place, this plot
# still shows the unfiltered chips - confirm.
vis.plot_chips(chip_struct, raster_img, bands=[0, 0, 1], contrast=True)

## 7. Save samples

In [None]:
# Shuffle the generated samples, then split them.
# NOTE(review): perc_test / perc_val are presumably percentages (20 % test,
# 5 % validation, remainder for training) - confirm against split_ds().
generator.shuffle_ds()
generator.split_ds(perc_test=20, perc_val=5)

In [None]:
# Re-fetch the samples after shuffling/splitting, then write the dataset to
# the output folder using ds_file_name as the file name.
# NOTE(review): chip_struct is not used again in the visible cells.
chip_struct = generator.get_samples()
generator.save_to_disk(output_ds, ds_file_name)