In [None]:
# Basic plots
%matplotlib inline
import matplotlib.pyplot as plt
# plt.rcParams['figure.figsize'] = [12, 8]

# Common imports and settings
import os, sys
os.environ['USE_PYGEOS'] = '0'
from IPython.display import Markdown
import pandas as pd
pd.set_option("display.max_rows", None)
import xarray as xr

# Datacube
import datacube
from datacube.utils.rio import configure_s3_access
from datacube.utils import masking
from datacube.utils.cog import write_cog
# https://github.com/GeoscienceAustralia/dea-notebooks/tree/develop/Tools
from dea_tools.plotting import display_map, rgb
from dea_tools.datahandling import mostcommon_crs

# EASI defaults
# Put the local easi-notebooks checkout on the import path before importing `easi_tools`
# (presumably that is where the package lives — TODO confirm). The membership guard keeps
# re-running this cell from appending the same path twice.
easinotebooksrepo = '/home/jovyan/easi-notebooks'
if easinotebooksrepo not in sys.path: sys.path.append(easinotebooksrepo)
from easi_tools import EasiDefaults, xarray_object_size, notebook_utils

In [None]:
# Data tools
import numpy as np
from datetime import datetime

# Datacube
from datacube.utils import masking  # https://github.com/opendatacube/datacube-core/blob/develop/datacube/utils/masking.py
from odc.algo import enum_to_bool   # https://github.com/opendatacube/odc-tools/blob/develop/libs/algo/odc/algo/_masking.py
from odc.algo import xr_reproject   # https://github.com/opendatacube/odc-tools/blob/develop/libs/algo/odc/algo/_warp.py
from datacube.utils.geometry import GeoBox, box  # https://github.com/opendatacube/datacube-core/blob/develop/datacube/utils/geometry/_base.py

# Holoviews, Datashader and Bokeh
import hvplot.pandas
import hvplot.xarray
import holoviews as hv
import panel as pn
import colorcet as cc
import cartopy.crs as ccrs
from datashader import reductions
from holoviews import opts
# import geoviews as gv
# from holoviews.operation.datashader import rasterize
hv.extension('bokeh', logo=False)

# Dask
from dask.distributed import Client, LocalCluster

In [None]:
# Resolve the deployment defaults and the default product for the chosen family
easi = EasiDefaults()
family = 'sentinel-2'
product = easi.product(family)

# Show a Markdown link to the product's Explorer page
product_link = f'Default {family} product for "{easi.name}": [{product}]({easi.explorer}/products/{product})'
display(Markdown(product_link))

In [None]:
# Local cluster
# Default is to run on a compute node with 28 GiB of available memory and 8 cores.
# We make that explicit here, but it should be adjusted to suit your workflow.
#
# Other combinations of worker count and memory per worker worth trying:
#   LocalCluster(n_workers=4, threads_per_worker=2, memory_limit="6GiB")
#   LocalCluster(n_workers=8, threads_per_worker=1, memory_limit="3GiB")
# A "dask-gateway" cluster is another option: it spins up many new worker pods, each
# with its own cpu/memory.
#
# NOTE: per-worker memory is set with `memory_limit` when the cluster is created.
# `cluster.scale(memory=...)` is a *target total* for the cluster and only changes the
# number of workers, so it cannot be used to size individual workers.

cluster = LocalCluster(n_workers=2, threads_per_worker=4, memory_limit="14GiB")
client = Client(cluster)
display(client)

# Build a dashboard URL that is reachable through the hub
dashboard_address = notebook_utils.localcluster_dashboard(client=client,server=easi.hub)
display(dashboard_address)

In [None]:
# Connect to the datacube index
dc = datacube.Datacube()

# Configure signed, requester-pays S3 reads (the dask `client` is passed so the setting
# can be applied to its workers as well). Most third-party AWS S3 buckets, such as those
# for Landsat and Sentinel-2, need this. The trailing `;` hides the cell output.
configure_s3_access(
    requester_pays=True,
    aws_unsigned=False,
    client=client,
);

In [None]:
from utils import load_data_geo
import geopandas as gpd
from deafrica_tools.areaofinterest import define_area
from datacube.utils.geometry import Geometry
import xarray as xr
# Shapefile with the labelled training points (path is relative to the working directory)
train_path = "train/Soc Trang_Traning.shp"
# `load_data_geo` is a project-local helper from `utils`; presumably it returns a
# GeoDataFrame (later cells call .to_crs/.explore/.crs on it) — TODO confirm
train = load_data_geo(train_path)
# Preview the first rows
train.head()

In [None]:
# Reproject the training points to lat/lon (EPSG:4326), the CRS the datacube query assumes
# for its "x"/"y" ranges. Re-running this cell is harmless: the target CRS is fixed.
train = train.to_crs('EPSG:4326')

In [None]:
# Preview the first rows again, now in EPSG:4326
train.head()

In [None]:
# Map only the previewed rows, styled by the "Name" column, with a legend
train.head().explore(column="Name", legend=True)

In [None]:
# Confirm the CRS of the training data after reprojection
train.crs

In [None]:
min_date = '2022-01-01' # 2021-11-01
max_date = '2022-02-01' # 2022-01-01
product = 's2_l2a'

loaded_datasets = {}


# Current workflow (I think)
# for each training point
# - get S2 data (utm)
# - apply mask
# - for each red,green,blue,nir
#    - apply scale/offset
#    - persist
#    - stack (result.merge)
# - add (dask xarray) to loaded_datasets dict
# - for each dict item
#    - calculate NDVI for the point's dask xarray
#    - do the actual read data and calculations
#    - = value at that point


# Proposed workflow
# 1. get bounding polygon for all training data points
# 2. dc.load with dask for bounding polygon (and all times when you're ready to try that)
#     - consider also remapping S2 data to lat/lon projection (e.g., epsg:4326) - may not be necessary
# 2a. apply S2 masking, scale, offset
# 3. calculate NDVI (still in dask so it's a "virtual" on-demand calculation)
# 3a. use xarray.persist() to pre-calculate NDVI for all pixels in our bounding polygon
#     - more efficient to read and process all pixels than process each training point
# 4. for idx, point in train.iterrows():
#     -  get points from xarray (dask)
#        need to convert point lat/lon to S2 UTM or dc.load into epsg:4326
#        xarray data in S2 UTM projection (output_crs, resolution)
#        point data in epsg:4326 (train.crs)
#     -  Store the loaded point data in the dictionary with a key based on the point index

# Test or check CRSs
# - Change training data (and geopolygon) CRS to "ncrs" (most common)
# - Or load dc data into training data CRS (output_crs=epsg:4326, resolution=(-0.0001, 0.0001)),
#   the approximate degrees equivalent of 10 m


# Measurement metadata belongs to the product, not to a point, so it is queried once
# here instead of once per training point.
measurement_info = dc.list_measurements().loc[product]

# Separate lists of measurement names and flag names
measurement_names = measurement_info[ pd.isnull(measurement_info.flags_definition)].index
flag_names        = measurement_info[pd.notnull(measurement_info.flags_definition)].index

# Layer that carries the scene classification used for the good-pixel mask
flag_name = 'scl'
# Bands kept for each training point
layer_names = ['red', 'green', 'blue', 'nir']

# Iterate over each point in the GeoDataFrame
for idx, point in train.iterrows():
    # Create a (zero-buffer) bounding box around the point
    aoi = define_area(lat=point.geometry.y, lon=point.geometry.x, buffer=0)
    geopolygon = Geometry(aoi["features"][0]["geometry"], crs=train.crs)
    geopolygon_gdf = gpd.GeoDataFrame(geometry=[geopolygon], crs=train.crs)
    # Get the latitude and longitude range of the geopolygon
    # total_bounds is (minx, miny, maxx, maxy)
    lat_range = (geopolygon_gdf.total_bounds[1], geopolygon_gdf.total_bounds[3])
    lon_range = (geopolygon_gdf.total_bounds[0], geopolygon_gdf.total_bounds[2])
    query = {
            "product": product,
            "x": lon_range,   # default assumed crs is epsg:4326, which is fine
            "y": lat_range,
            "time": (min_date, max_date),
    }
    ncrs = notebook_utils.mostcommon_crs(dc, query)   # UTM for the "most common" S2 MGRS grid
    query.update({
            "output_crs": ncrs,
            "resolution": (-10, 10),
            # Same scene grouping as the area-wide prediction query, so that training and
            # prediction features are built from identically grouped time steps.
            "group_by": "solar_day",
            "dask_chunks": {'x': 2048, 'y': 2048}
     })
    data = dc.load(**query)  # UTM for the "most common" S2 MGRS grid

    # Store the loaded dataset in the dictionary with a key based on the point index
    # (assumes an integer index on `train`)
    key = f'point_{idx + 1}'

    # Per-variable "not nodata" masks
    valid_mask = masking.valid_data_mask(data)

    # Flag values of the layer that is actually masked on (mapping "class id" -> class name).
    # Read it from `flag_name` directly instead of from a leftover loop variable.
    flags_def = masking.describe_variable_flags(data[flag_name]).values
    flags_def = flags_def.tolist()[0][1]

    # Scene classes 4, 5 and 6 are treated as good pixels
    good_pixel_flags = [flags_def[str(i)] for i in [4, 5, 6]]
    good_pixel_mask = enum_to_bool(data[flag_name], good_pixel_flags)

    rs = []
    for layer_name in layer_names:

        # Get scaling and offset values from product description
        scale = measurement_info.loc[layer_name].scale_factor
        offset = measurement_info.loc[layer_name].add_offset

        # Apply valid mask and good pixel mask, then scale to physical values
        layer = data[[layer_name]].where(valid_mask[layer_name] & good_pixel_mask) * scale + offset
        layer = layer.persist()
        rs.append(layer)

    # Combine the single-band datasets into one dataset for this point
    result = rs[0]
    for band_ds in rs[1:]:
        result = result.merge(band_ds)

    loaded_datasets[key] = result

In [None]:
# Number of per-point datasets that were stored by the loading loop
len(loaded_datasets)

In [None]:
from deafrica_tools.bandindices import calculate_indices
import numpy as np

In [None]:
# For every stored point dataset: compute NDVI, then average it over the time dimension.
# The result is keyed exactly like `loaded_datasets`.
ndivi_dataset = {
    point_key: calculate_indices(point_ds, index='NDVI', satellite_mission='s2').NDVI.mean(dim='time')
    for point_key, point_ds in loaded_datasets.items()
}

In [None]:
# Materialise the time-averaged NDVI of the first point (accessing .values returns a numpy array)
ndivi_dataset["point_1"].values

In [None]:
# Show the whole dict of per-point NDVI arrays (output grows with the number of points)
ndivi_dataset

In [None]:
# ndivi_dataset['point_1'].plot(cmap='RdYlGn',
#            size=6, vmin=-2, vmax=2,
# col_wrap=2)

In [None]:
# Same check as above: numpy values of the first point's mean NDVI
ndivi_dataset['point_1'].values

In [None]:
# Class label of every training point, taken from the "Name" column in row order
labels = train.Name.values

In [None]:
# One numpy array of mean NDVI per point, in the insertion order of `ndivi_dataset`
X = [ndivi_dataset[point_key].values for point_key in ndivi_dataset.keys()]

In [None]:
# Largest array shape among the per-point features. Shapes are tuples, so `max` compares
# them lexicographically (first dimension first), not by element count.
max(arr.shape for arr in X)

In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from sklearn.preprocessing import LabelEncoder

In [None]:
# Encoder that maps the class names to integer codes
label_encoder = LabelEncoder()

# Fit and transform the labels
numeric_labels = label_encoder.fit_transform(labels)

In [None]:
# Lookup from original class name to its integer code (repeated names collapse to one entry)
label_mapping = dict(zip(labels, numeric_labels))

In [None]:
# Stack the per-point arrays as rows of a 2-D feature matrix (one flattened array per row);
# this requires every point to contribute the same number of values.
X_flat = np.vstack([np.ravel(point_values) for point_values in X])

In [None]:
# Row index of every NaN cell in the feature matrix (a row appears once per NaN it holds)
nan_cells = np.argwhere(np.isnan(X_flat))
vt_nan = list(nan_cells[:, 0])

In [None]:
# Keep only the samples (and their labels) whose features contain no NaN.
# `vt_nan` lists a row once per NaN cell, so build a set for O(1) membership tests
# instead of scanning the list for every sample.
nan_rows = set(vt_nan)

x_new = []
lb_new = []
for i in range(len(X)):
    if i not in nan_rows:
        x_new.append(X[i])
        lb_new.append(numeric_labels[i])

In [None]:
# Turn the list of NaN-free samples into a 2-D feature matrix, one flattened sample per row
x_new = np.vstack([np.ravel(sample) for sample in x_new])

In [None]:
# Hold out 30% of the NaN-free samples for testing; the fixed random_state makes the split repeatable
X_train, X_test, y_train, y_test = train_test_split(x_new, lb_new, test_size=0.3, random_state=42)

In [None]:
# Random forest of 150 shallow trees (max_depth=2), Gini criterion, seeded for repeatability
model = RandomForestClassifier(n_estimators=150, random_state=42, criterion='gini', max_depth=2)
# Fit on the training split; as the last expression, the fitted estimator is displayed
model.fit(X_train, y_train)

In [None]:
# Predicted integer class codes for the held-out samples
predictions = model.predict(X_test)

In [None]:
# Fraction of held-out samples whose predicted code equals the true code
accuracy = accuracy_score(y_test, predictions)
print(f'Accuracy: {accuracy}')

In [None]:
# Inspect the raw predicted codes (decode with label_encoder.inverse_transform if names are needed)
predictions

In [None]:
# Vietnam: area and period to classify
min_longitude, max_longitude = 105.5, 106.4
min_latitude, max_latitude = 9.2, 10.0
min_date = '2022-01-01' # 2021-11-01
max_date = '2022-02-01' # 2022-01-01
product = 's2_l2a'

# Spatial/temporal part of the query ("x"/"y" are longitude/latitude bounds)
query1 = dict(
    product=product,
    x=(min_longitude, max_longitude),
    y=(min_latitude, max_latitude),
    time=(min_date, max_date),   # any parsable date strings
)

# Most common CRS among the datasets matched by the query
native_crs = notebook_utils.mostcommon_crs(dc, query1)

# Output grid and loading options
query1.update(
    output_crs=native_crs,                # EPSG code
    resolution=(-10, 10),                 # target resolution
    group_by='solar_day',                 # scene ordering
    dask_chunks={'x': 2048, 'y': 2048},   # dask chunks
)


In [None]:
# Load the area of interest (the query sets dask_chunks, so this is a dask-backed load)
data = dc.load(**query1)

# Report the size of the loaded object, then show its structure
size_summary = notebook_utils.xarray_object_size(data)
notebook_utils.heading(size_summary)
display(data)

# Boolean "not nodata" mask for every layer
valid_mask = masking.valid_data_mask(data)
notebook_utils.heading('Valid data masks for each variable')
display(valid_mask)

In [None]:
# Measurement definitions for the selected product
product_name = query1['product']
measurement_info = dc.list_measurements().loc[product_name]
notebook_utils.heading(f'Measurement table for product: {product_name}')
display(measurement_info)

# Split the layers into plain measurements and flag layers, depending on whether a
# flags_definition is present
has_flags = measurement_info.flags_definition.notnull()
measurement_names = measurement_info[~has_flags].index
flag_names = measurement_info[has_flags].index

notebook_utils.heading('Selected Measurement and Flag names')
name_summary = {
    'group': ['Measurement names', 'Flag names'],
    'names': [', '.join(measurement_names), ', '.join(flag_names)],
}
display(pd.DataFrame(name_summary))

# One flag-definition table per flag layer
for flag in flag_names:
    notebook_utils.heading(f'Flag definition table for flag name: {flag}')
    display(masking.describe_variable_flags(data[flag]))

In [None]:
# Make SCL flags image
flag_name = 'scl'

# Read the flag values of the layer that is masked on below (mapping "class id" -> class
# name). Use `flag_name` directly rather than `flag`, which is only a leftover loop
# variable from another cell and may be undefined or point at a different layer.
flags_def = masking.describe_variable_flags(data[flag_name]).values
flags_def = flags_def.tolist()[0][1]

# Flag layer with nodata pixels masked out, persisted on the cluster
flag_data = data[[flag_name]].where(valid_mask[flag_name]).persist()   # Dataset
display(flag_data)

# Create Mask layer: scene classes 4, 5 and 6 are treated as good pixels
good_pixel_flags = [flags_def[str(i)] for i in [4, 5, 6]]

good_pixel_mask = enum_to_bool(data[flag_name], good_pixel_flags)  # -> DataArray
# display(good_pixel_mask)  # Type: bool
# display(good_pixel_mask)  # Type: bool

In [None]:
# For each band: mask invalid / non-good pixels, apply scale and offset, persist in dask
rs = []
band_names = ['red', 'green', 'blue', 'nir']
for layer_name in band_names:
    # Scaling and offset come from the product's measurement description
    band_meta = measurement_info.loc[layer_name]
    scale = band_meta.scale_factor
    offset = band_meta.add_offset

    # Keep a pixel only if it is valid for this band and flagged as a good pixel
    keep = valid_mask[layer_name] & good_pixel_mask
    layer = (data[[layer_name]].where(keep) * scale + offset).persist()
    rs.append(layer)

In [None]:
import xarray as xr
# Fold the four single-band datasets into one dataset, left to right
result = rs[0]
for band_ds in rs[1:4]:
    result = result.merge(band_ds)

In [None]:
# NDVI for every time step of the merged bands, then the per-pixel mean over time
ds1 = calculate_indices(result, index='NDVI', satellite_mission='s2')
ndvi = ds1.NDVI
average_ndvi = ndvi.mean('time')

In [None]:
# Show the time-averaged NDVI array (structure only; the values stay lazy if it is dask-backed)
average_ndvi

In [None]:
# Zero-filled raster that will receive one predicted label per pixel. Its shape is taken
# from the NDVI composite so that it always matches the loaded area, instead of being
# hard-coded to the size of one particular query.
data_array = xr.DataArray(np.zeros(average_ndvi.shape), dims=('y', 'x'))

In [None]:
# Classify every pixel of the NDVI composite; pixels without NDVI get the label -1.
#
# The values are materialised once up front: reading `average_ndvi.values` inside the
# loop would re-evaluate the array for every single pixel. Prediction is done one image
# row at a time, which avoids one `model.predict` call per pixel while keeping the
# memory needed per call small.
ndvi_values = average_ndvi.values
n_rows, n_cols = ndvi_values.shape

for i in range(n_rows):
    row = ndvi_values[i]
    has_ndvi = ~np.isnan(row)

    # Start from "no data" and fill in predictions where NDVI is available
    row_labels = np.full(n_cols, -1.0)
    if has_ndvi.any():
        # The model expects a 2-D input: one sample per pixel, one feature (NDVI)
        row_labels[has_ndvi] = model.predict(row[has_ndvi].reshape(-1, 1))

    # Write into the pre-allocated raster in place
    data_array.data[i, :n_cols] = row_labels

In [None]:
# Attach the predicted label raster to the NDVI array under the name "labels"
# (item assignment on a DataArray stores it as a coordinate — verify this is the intent)
average_ndvi["labels"] = data_array

In [None]:
# Inspect the predicted label raster (-1 marks pixels without NDVI)
data_array