## Import libraries

In [None]:
# Optional environment setup (left commented out): reinstall dea_ml and the
# deafrica-sandbox-notebooks Tools package. Uncomment only the lines needed.
# !pip uninstall dea_ml -y
# !pip install -e dea_ml
# !pip install git+https://github.com/digitalearthafrica/deafrica-sandbox-notebooks.git@minty-fresh-sandbox#subdirectory=Tools

In [1]:
import warnings
warnings.filterwarnings("ignore")

import json
import joblib
import fsspec
from odc.io.cgroups import get_cpu_quota

from dea_ml.core.feature_layer import create_features, get_xy_from_task
from dea_ml.helpers.json_to_taskstr import extract_taskstr_from_geojson
from dea_ml.helpers.io import download_file
from dea_ml.core.africa_geobox import AfricaGeobox
from dea_ml.core.predict_from_feature import PredictContext, predict_with_model
from dea_ml.config.product_feature_config import FeaturePathConfig

%load_ext autoreload
%autoreload 2 

## Analysis Params


In [2]:
# Callables plugged into the workflow: the feature layer function and the
# post-processing function.
from gm_mads_two_seasons import gm_mads_two_seasons
from post_processing import post_processing

feature_layer_function = gm_mads_two_seasons
post_process = post_processing

# Input locations: the tiles GeoJSON and the saved model.
tiles_geojson = '../testing/eastern_cropmask/data/s2_tiles_eastern_aez.geojson'
model_path = '../testing/eastern_cropmask/results/gm_mads_two_seasons_ml_model_20210401.joblib'

# Chunk sizes to use for dask.
dask_chunks = dict(x='auto', y='auto')


## Initiate configuration class

In [3]:
# import the configuration
config = FeaturePathConfig
config

dea_ml.config.product_feature_config.FeaturePathConfig

## Open tiles and model

In [4]:
# Record the tiles GeoJSON location on the config, then read it back as the
# URL to open.
config.tiles_geojson = tiles_geojson
tile_geojson_url = config.tiles_geojson

# Parse the GeoJSON into a plain dictionary.
with fsspec.open(tile_geojson_url) as geojson_file:
    tiles_geojson_dict = json.load(geojson_file)

In [5]:
# Open model
config.model_path = model_path
ml_model_url = config.model_path

# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code; only load model files that come from a trusted source.
with fsspec.open(ml_model_url) as fh:
    model = joblib.load(fh)

# Update the model with the CPUs available on this machine.
# get_cpu_quota() can return None when no CPU quota applies, which would make
# round() raise a TypeError; fall back to n_jobs=-1 (all processors) then.
# A fractional quota below 0.5 would round to 0, which is not a valid n_jobs,
# so the value is clamped to at least 1.
cpu_quota = get_cpu_quota()
model.n_jobs = max(1, round(cpu_quota)) if cpu_quota else -1
print(model)

RandomForestClassifier(max_features='log2', n_estimators=300, n_jobs=15,
                       random_state=1)


## Generate 'tasks' based on tiles


In [6]:
# Build the task strings from the tiles GeoJSON for the chosen time range,
# then show how many tasks were produced.
tasks = extract_taskstr_from_geojson(
    geojson=tiles_geojson_dict,
    time_range='2019-01--P6M',
)
len(tasks)

390

## Generate features for model


First, generate a dictionary of geoboxes for each tile

In [7]:
# Start from the tile indices of the first task...
x, y = get_xy_from_task(tasks[0])
# ...then update the tile id by the fixed offsets.
x, y = x + 181, y + 77
print(x, y)

# Dictionary of geoboxes exposed by AfricaGeobox.
geobox_dict = AfricaGeobox().geobox_dict

210 77


Pass the feature layer function into the `create_features` function

**Note:** This will take a couple of minutes to run.


In [8]:
# Generate the features for tile (x, y).
# Use the `feature_layer_function` chosen in the analysis-params cell rather
# than a hard-coded function, so that changing the parameter takes effect here.
subfld, geobox, data = create_features(
    x,
    y,
    config,
    geobox_dict,
    feature_func=feature_layer_function,
    dask_chunks=dask_chunks,
)

print(data)

<xarray.Dataset>
Dimensions:        (x: 4800, y: 4800)
Coordinates:
    time           datetime64[ns] 2019-07-02T11:59:59.999999
  * y              (y) float64 9.599e+04 9.597e+04 9.595e+04 ... 50.0 30.0 10.0
  * x              (x) float64 2.784e+06 2.784e+06 ... 2.88e+06 2.88e+06
    spatial_ref    int32 6933
    band           int64 1
Data variables:
    blue_S1        (y, x) float32 dask.array<chunksize=(4800, 4800), meta=np.ndarray>
    green_S1       (y, x) float32 dask.array<chunksize=(4800, 4800), meta=np.ndarray>
    red_S1         (y, x) float32 dask.array<chunksize=(4800, 4800), meta=np.ndarray>
    nir_S1         (y, x) float32 dask.array<chunksize=(4800, 4800), meta=np.ndarray>
    swir_1_S1      (y, x) float32 dask.array<chunksize=(4800, 4800), meta=np.ndarray>
    swir_2_S1      (y, x) float32 dask.array<chunksize=(4800, 4800), meta=np.ndarray>
    red_edge_1_S1  (y, x) float32 dask.array<chunksize=(4800, 4800), meta=np.ndarray>
    red_edge_2_S1  (y, x) float32 dask.arra

## Run prediction

**Note:** This will take a couple of minutes to run as the calculations are computed and brought into memory

In [9]:
# Build the prediction context from the configuration and the geobox
# dictionary, then display its `client` attribute as the cell output.
pff = PredictContext(config, geobox_dict)
pff.client

0,1
Client  Scheduler: inproc://10.95.105.28/4031/1  Dashboard: /user/chad/proxy/8787/status,Cluster  Workers: 1  Cores: 15  Memory: 96.64 GB


In [10]:
# Run the model over the generated features.
predicted = predict_with_model(
    config,
    model,
    data,
)

predicting...
   probabilities...
   input features...


In [11]:
# Print the text summary of the prediction result.
print(predicted)

<xarray.Dataset>
Dimensions:        (x: 4800, y: 4800)
Coordinates:
  * x              (x) float64 2.784e+06 2.784e+06 ... 2.88e+06 2.88e+06
  * y              (y) float64 9.599e+04 9.597e+04 9.595e+04 ... 50.0 30.0 10.0
    spatial_ref    int32 0
Data variables:
    Predictions    (y, x) int64 dask.array<chunksize=(200, 4800), meta=np.ndarray>
    Probabilities  (y, x) float64 dask.array<chunksize=(200, 4800), meta=np.ndarray>
    red_S1         (y, x) float32 dask.array<chunksize=(4800, 4800), meta=np.ndarray>
    blue_S1        (y, x) float32 dask.array<chunksize=(4800, 4800), meta=np.ndarray>
    green_S1       (y, x) float32 dask.array<chunksize=(4800, 4800), meta=np.ndarray>
    nir_S1         (y, x) float32 dask.array<chunksize=(4800, 4800), meta=np.ndarray>
    swir_1_S1      (y, x) float32 dask.array<chunksize=(4800, 4800), meta=np.ndarray>
    swir_2_S1      (y, x) float32 dask.array<chunksize=(4800, 4800), meta=np.ndarray>
    red_edge_1_S1  (y, x) float32 dask.array<chunksi

## Post processing

In [13]:
# NOTE: there is a minor issue with dask_ml functions that raises a warning.
# Run the post-processing function on the features and predictions; it
# returns three outputs, bound here to `predict`, `proba` and `mode`.
predict, proba, mode = post_process(data, predicted, config, geobox)

In [None]:
# Print the probability output from post-processing. The post-processing cell
# binds this result to `proba`; the previous name `prob` was never defined and
# raised a NameError.
print(proba)

## Save data
- The result will be saved in:
```/home/jovyan/wa/u23/data/crop_mask_eastern/v0.1.7/x+029/y+000/2019```

In [None]:
# Uncomment to save the results via the prediction context.
# The probability output is bound to `proba` in the post-processing cell.
# pff.save_data(subfld, predict, proba, geobox)

In [None]:
# !ls /home/jovyan/wa/u23/data/crop_mask_eastern/v0.1.7/x+029/y+000/2019