## Import libraries

In [None]:
# Optional environment setup (disabled by default): uncomment to reinstall dea_ml
# in editable mode and to install the Tools subdirectory of the
# deafrica-sandbox-notebooks repository (minty-fresh-sandbox branch).
# !pip uninstall dea_ml -y
# !pip install -e dea_ml
# !pip install git+https://github.com/digitalearthafrica/deafrica-sandbox-notebooks.git@minty-fresh-sandbox#subdirectory=Tools

In [None]:
import warnings
warnings.filterwarnings("ignore")

import json
import joblib
import fsspec
from odc.io.cgroups import get_cpu_quota

from dea_ml.core.feature_layer import create_features, get_xy_from_task
from dea_ml.helpers.json_to_taskstr import extract_taskstr_from_geojson
from dea_ml.helpers.io import download_file
from dea_ml.core.africa_geobox import AfricaGeobox
from dea_ml.core.predict_from_feature import PredictContext, predict_with_model
from dea_ml.config.product_feature_config import FeaturePathConfig

%load_ext autoreload
%autoreload 2 

## Analysis Params


In [None]:
# Inputs, given as relative paths: the tiles to process (GeoJSON) and the
# trained model (joblib file).
tiles_geojson = "../testing/eastern_cropmask/data/s2_tiles_eastern_aez.geojson"
model_path = "../testing/eastern_cropmask/results/gm_mads_two_seasons_ml_model_20210401.joblib"

# Chunk sizes to use for dask along each spatial dimension.
dask_chunks = dict(x=5000, y=5000)

# Feature layer function: passed to `create_features` further down.
from gm_mads_two_seasons import gm_mads_two_seasons
feature_layer_function = gm_mads_two_seasons

# Post-processing function: applied to the predictions further down.
from post_processing import post_processing
post_process = post_processing


## Initialise configuration class

In [None]:
# import the configuration
config = FeaturePathConfig
config

## Open tiles and model

In [None]:
# Record the tiles file on the configuration, then read the location back from
# it so the config stays the single source of truth.
config.tiles_geojson = tiles_geojson
tile_geojson_url = config.tiles_geojson

# Open the file through fsspec and parse its JSON content.
with fsspec.open(tile_geojson_url) as geojson_file:
    tiles_geojson_dict = json.load(geojson_file)

In [None]:
# Open model: record its location on the configuration and read it back from there.
config.model_path = model_path
ml_model_url = config.model_path

# NOTE: joblib.load unpickles the file, which can execute arbitrary code —
# only load models from a trusted location.
with fsspec.open(ml_model_url) as fh:
    model = joblib.load(fh)

# Update the model with the CPUs available on this machine.
# get_cpu_quota() returns None when no cgroup CPU limit is set (e.g. outside a
# container), so fall back to the host CPU count in that case. A small
# fractional quota would round to 0, which is not a valid job count, so never
# go below one job.
cpu_quota = get_cpu_quota()
model.n_jobs = max(1, round(cpu_quota)) if cpu_quota else joblib.cpu_count()
print(model)

## Generate 'tasks' based on tiles


In [None]:
# Build the task list from the tiles GeoJSON for the given time range.
# NOTE(review): '2019-01--P6M' presumably means a six-month period starting
# January 2019 — confirm against extract_taskstr_from_geojson.
tasks = extract_taskstr_from_geojson(
    geojson=tiles_geojson_dict,
    time_range='2019-01--P6M',
)
len(tasks)

## Generate features for model


First, generate a dictionary of geoboxes for each tile

In [None]:
# Tile indices of the first task only — this notebook processes a single tile.
task_x, task_y = get_xy_from_task(tasks[0])

# Update tile id.
# NOTE(review): 181 and 77 are undocumented offsets added to the task's x/y
# indices — presumably a shift between two tile-grid numbering schemes; confirm
# and consider naming them as constants.
x, y = task_x + 181, task_y + 77
print(x, y)

# Geobox dictionary exposed by AfricaGeobox; passed to `create_features` and
# `PredictContext` in the cells below.
geobox_dict = AfricaGeobox().geobox_dict

Pass the feature layer function into the `create_features` function

**Note:** This will take a couple of minutes to run.


In [None]:
%%time
subfld, geobox, data = create_features(x,
                                       y,
                                       config,
                                       geobox_dict,
                                       feature_func=gm_mads_two_seasons,
                                       dask_chunks=dask_chunks) 

print(data)

## Run prediction

**Note:** This will take a couple of minutes to run as the calculations are computed and brought into memory

In [None]:
# Create the prediction context from the configuration and the geobox
# dictionary; it is used again further down to save the results.
pff = PredictContext(config, geobox_dict)
# Display the context's client (presumably the dask client — TODO confirm).
pff.client

In [None]:
%%time
predicted = predict_with_model(config,model,data).compute()
print(predicted)

## Post processing

In [None]:
%%time
# Apply the post-processing function selected in the "Analysis Params" cell.
# It returns three outputs, which are passed to `pff.save_data` below.
predict, proba, filtered = post_process(data, predicted, config, geobox)

## Save data
- Results will be written to a folder such as:
```/home/jovyan/wa/u23/data/crop_mask_eastern/v0.1.7/x+029/y+000/2019```

In [None]:
# Save the three post-processed outputs for this tile through the prediction
# context, using the sub-folder and geobox returned by `create_features`.
pff.save_data(subfld, predict, proba, filtered, geobox)

In [None]:
import xarray as xr

# Open one tile's `_prob.tif` output and plot it.
# NOTE(review): this is a hard-coded absolute path; it does not follow `config`,
# `x` or `y`, so update it to match where `pff.save_data` wrote the results.
# NOTE(review): xr.open_rasterio is deprecated/removed in newer xarray releases
# (rioxarray.open_rasterio is the documented replacement) — confirm the
# installed xarray version still provides it.
prob_tif_path = '/g/data/crop_mask_eastern_data/crop_mask_eastern/v0.1.8/x210/y077/2019/crop_mask_eastern_x210_y077_2019_prob.tif'
xr.open_rasterio(prob_tif_path).plot(size=10)

In [None]:
# Optional check (disabled by default): uncomment to list the files in the
# output folder; adjust the path to match where the results were saved.
# !ls /home/jovyan/wa/u23/data/crop_mask_eastern/v0.1.7/x+029/y+000/2019