<a href="https://colab.research.google.com/github/Max-FM/seagrass/blob/master/notebooks/prepare_training_data_banc_d_arguin.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Preparing Banc d'Arguin imagery to create seagrass training data for machine learning

## Mount Google Drive

In [None]:
# Colab only: mount Google Drive at /content/drive so the raster paths
# defined further down (which all live under /content/drive) resolve.
from google.colab import drive
drive.mount('/content/drive')

## Install the `seagrass` package

In [None]:
%%capture

!pip install git+https://github.com/Max-FM/seagrass.git

## Load Sentinel 2 image and projected seagrass data

In [None]:
from seagrass.raster import open_and_match_rasters

In [None]:
#  Scenes of interest are numbers 7, 8, 11, 12 & 13, particularly 11 and 12.
#  The scene number is interpolated into the Sentinel 2 file name below.
scene_number = '000011'
#  Sentinel 2 image for the selected scene.
s2_filepath = f'/content/drive/MyDrive/Bathymetry/BancDarguin_s2cldmdn_{scene_number}.tif'
#  Seagrass GeoTIFF on the shared drive.
seagrass_filepath = '/content/drive/Shareddrives/1_Satellite_Derived_Bathymetry & coastal veg/Banc dArguin bathymetry & seagrass/seagrass_geotiff/seagrass_combined_clipped.tif'

In [None]:
#  Open the Sentinel 2 image and the seagrass map.
#  NOTE(review): the helper's name suggests the two rasters are aligned to a
#  common grid on load -- confirm in seagrass.raster.
s2, seagrass_map = open_and_match_rasters(s2_filepath, seagrass_filepath)

## Mask out land pixels

In [None]:
def ndwi(s2):
    """Return the Normalised Difference Water Index (NDWI) of a raster.

    Computed as (green - nir) / (green + nir), with green read from
    index 4 and near-infrared from index 9 of ``s2``.
    """
    green, nir = s2[4], s2[9]
    band_difference = green - nir
    band_sum = green + nir
    return band_difference / band_sum

def ndvi(s2):
    """Return the Normalised Difference Vegetation Index (NDVI) of a raster.

    Computed as (nir - red) / (nir + red), with red read from index 5
    and near-infrared from index 9 of ``s2``.
    """
    red, nir = s2[5], s2[9]
    band_difference = nir - red
    band_sum = nir + red
    return band_difference / band_sum

def land_mask(s2, ndwi_threshold=-0.1, ndvi_threshold=0.1):
    """Create a boolean land-pixel mask from NDWI and NDVI thresholds.

    A pixel is flagged as land only when BOTH conditions hold:
    NDWI < ``ndwi_threshold`` and NDVI < ``ndvi_threshold``.

    Parameters:
        s2: raster whose ``ndwi``/``ndvi`` results expose a ``.values``
            attribute (the comparison is done on those raw values).
        ndwi_threshold: upper NDWI bound for land (default -0.1).
        ndvi_threshold: upper NDVI bound for land (default 0.1).

    Returns:
        Boolean array, True where the pixel is classed as land.

    NOTE(review): the NDVI condition presumably keeps vegetated pixels out
    of the land mask -- confirm the intent and the threshold values.
    """
    below_ndwi = ndwi(s2).values < ndwi_threshold
    below_ndvi = ndvi(s2).values < ndvi_threshold
    return below_ndwi & below_ndvi

In [None]:
#  Invert the land mask so that True marks the pixels to keep.
mask = ~land_mask(s2)
#  Pixels outside the kept region are overwritten with -9999.
seagrass_masked = seagrass_map.where(mask, -9999)

## Plot Sentinel 2 and seagrass images

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from rasterio.plot import show
import numpy as np

In [None]:
def normalize(array):
    """Normalizes numpy arrays into scale 0.0 - 1.0

    The array is min-max scaled: the minimum maps to 0.0 and the maximum
    to 1.0. A constant array (max == min) has no range to rescale, so it
    is returned as all zeros instead of the NaNs a 0/0 division would give.
    """
    array_min, array_max = array.min(), array.max()
    value_range = array_max - array_min
    if value_range == 0:
        # Every element equals the minimum, so the numerator is already all
        # zeros; dividing by 1 avoids 0/0 while keeping the usual result dtype.
        value_range = 1
    return (array - array_min) / value_range

def make_composite(band_1, band_2, band_3):
    """Stack three raster bands depth-wise and normalise the result.

    The bands are combined with ``np.dstack`` in the order given and the
    stacked array is passed through ``normalize``.
    """
    stacked_bands = np.dstack([band_1, band_2, band_3])
    return normalize(stacked_bands)

def make_s2_rgb(s2_raster):
    """Build a normalised RGB composite from a Sentinel 2 raster.

    Red, green and blue are read from indices 5, 4 and 3 of ``s2_raster``
    respectively and handed to ``make_composite`` in R, G, B order.
    """
    red, green, blue = (s2_raster[band_index] for band_index in (5, 4, 3))
    return make_composite(red, green, blue)

In [None]:
#  Build the normalised RGB composite used for plotting.
rgb = make_s2_rgb(s2)

#  Display the composite's dimensions as a quick sanity check.
rgb.shape

In [None]:
#  Work on a copy so the unmasked composite stays available for comparison.
rgb_masked = rgb.copy()
#  Set land pixels to 1, the top of the normalised 0.0 - 1.0 range.
rgb_masked[land_mask(s2)] = 1

In [None]:
# fig, (ax1, ax2) = plt.subplots(2,2, figsize=(30,30))
# ndwi_fig = ax1[0].imshow(ndwi(s2))
# fig.colorbar(ndwi_fig, ax=ax1[0])
# ax1[0].set_title('NDWI')
# ax2[0].hist(ndwi(s2).values.ravel(), bins=np.linspace(-1, 1, 50))
# ndvi_fig = ax1[1].imshow(ndvi(s2))
# ax1[1].set_title('NDVI')
# fig.colorbar(ndvi_fig, ax=ax1[1])
# ax2[1].hist(ndvi(s2).values.ravel(), bins=np.linspace(-1, 1, 50))

In [None]:
#  Side-by-side view of the RGB composite before and after land masking.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(30, 15))
ax1.imshow(rgb)
ax1.set_title('Sentinel 2 RGB composite')
ax2.imshow(rgb_masked)
#  The trailing semicolon keeps the last call's repr out of the cell output.
ax2.set_title('Sentinel 2 RGB composite (land masked)');

In [None]:
#  Side-by-side view of the seagrass map before and after land masking.
#  (No extra plt.figure() call here: it would only add an empty second figure.)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(30, 15))
#  Drop the -9999 no-data value before plotting so it does not dominate the colour scale.
show(seagrass_map.where(seagrass_map != -9999), ax=ax1)
ax1.set_title('Seagrass map')
show(seagrass_masked.where(seagrass_masked != -9999), ax=ax2)
#  The trailing semicolon keeps the last call's repr out of the cell output.
ax2.set_title('Seagrass map (land masked)');

## Define features and targets for machine learning

In [None]:
from seagrass.prepare import create_training_data
from seagrass.utils import save_training_data

In [None]:
%%time

#  Build the features (X) and targets (y) from the Sentinel 2 values and the
#  land-masked seagrass map. -9999 is the fill value written into
#  seagrass_masked above; it is passed here as the no-data value.
#  Seven bands (indices 3-9) are selected.
#  NOTE(review): how no-data pixels are treated is decided inside
#  seagrass.prepare.create_training_data -- confirm there.
X, y = create_training_data(s2.values, seagrass_masked.values, no_data_value=-9999, bands=[3,4,5,6,7,8,9])

In [None]:
#  Show the feature and target arrays, followed by their shapes.
display(X, y, X.shape, y.shape)

## Save training data to a Modulos-compatible tar file

In [None]:
# Output directory for the training archive. The commented-out line is the
# shared-drive location; the active value writes to the current working directory.
# training_dir = '/content/drive/Shareddrives/1_Satellite_Derived_Bathymetry & coastal veg/Banc dArguin bathymetry & seagrass/seagrass_training_data'
training_dir = '.'

# Optional column header labels.
# Alternative label set with an additional '_g'-suffixed label per band (currently unused):
# cols = ['b', 'g', 'r', 're1', 're2', 're3', 'nir', 'b_g', 'g_g', 'r_g', 're1_g', 're2_g', 're3_g', 'nir_g', 'seagrass']
# Seven band labels followed by 'seagrass' (presumably the target column -- confirm in save_training_data).
cols = ['b', 'g', 'r', 're1', 're2', 're3', 'nir', 'seagrass']

In [None]:
from datetime import date

# Today's date in ISO form (YYYY-MM-DD), used to stamp the output file name.
timestamp = date.today().isoformat()
# Archive path: output directory + scene number + date stamp.
train_filepath = f'{training_dir}/banc_d_arguin_seagrass_train_{scene_number}_{timestamp}.tar'

# Echo the path so the destination is visible in the cell output.
train_filepath

In [None]:
# Write the features, targets and column labels to the tar file at train_filepath.
# NOTE(review): the archive layout is defined in seagrass.utils.save_training_data -- confirm there.
save_training_data(train_filepath, X, y, column_labels=cols)