## Post processing of aboveground biomass dataset

### Input

Random forest model prediction results from inference.ipynb. These are parquet files (one for each
Landsat scene × year) with columns x, y, and biomass. x and y are in lat/lon coordinates, and biomass
is in units of Mg biomass/ha and accounts only for aboveground, live, woody biomass.

### Processes

For each 10x10 degree tile in our template:

1. Merge and mosaic all Landsat scenes within the 10x10 degree tile for all available years, and store
   the data in zarr format.
2. Fill gaps in the biomass dataset with xarray `interpolate_na` using the linear method (first along
   the time dimension, then along x, then along y).
3. Mask with the MODIS MCD12Q1 land cover dataset to select only forest pixels.
4. Calculate belowground biomass, deadwood, and litter.


In [None]:
%load_ext autoreload
%autoreload 2

from pyproj import CRS
import boto3
from rasterio.session import AWSSession
from s3fs import S3FileSystem
# Shared AWS handles for this kernel, both configured for requester-pays buckets.
aws_session = AWSSession(
    boto3.Session(),  # or: boto3.Session(profile_name='default')
    requester_pays=True,
)
fs = S3FileSystem(requester_pays=True)
import xgboost as xgb

from osgeo.gdal import VSICurlClearCache
import rasterio as rio
import numpy as np
import xarray as xr
import dask
import os
import fsspec

import rioxarray # for the extension to load
import pandas as pd
from datetime import datetime

from dask_gateway import Gateway
from carbonplan_trace.v1.landsat_preprocess import access_credentials, test_credentials
from carbonplan_trace.v1.inference import predict, predict_delayed 
from carbonplan_trace.v1 import utils, postprocess, load
from carbonplan_trace.tiles import tiles
from carbonplan_trace.v1.landsat_preprocess import access_credentials, test_credentials
import prefect
from prefect import task, Flow, Parameter
from prefect.executors import DaskExecutor
from prefect.utilities.debug import raise_on_exception
from datetime import datetime as time


In [None]:
# Record the carbonplan_trace version and (re)load the watermark extension
# used by the next cell to log environment details for reproducibility.
from carbonplan_trace import version
%reload_ext watermark
print(version)

In [None]:
watermark -d -n -t -u -v -p carbonplan_trace -h -m -g -r -b

In [None]:
# Dask settings for this kernel session: disable array.slicing.split_large_chunks
# and raise the distributed TCP/connect timeouts to 50s.
# NOTE(review): these are set in this process; confirm whether remote Gateway
# workers need the same values passed via their environment.
dask.config.set(
    {
        "array.slicing.split_large_chunks": False,
        "distributed.comm.timeouts.tcp": "50s",
        "distributed.comm.timeouts.connect": "50s",
    }
)

In [None]:
# Where the Prefect/Dask work runs: "local" (one big machine) or "remote"
# (a Dask Gateway cluster).
kind_of_cluster = "remote"
if kind_of_cluster == "local":
    # Spin up a local cluster; the machine must be big enough for 30 workers.
    # Each worker advertises one "workertoken" resource.
    from dask.distributed import Client

    local_cluster_client = Client(n_workers=30, threads_per_worker=1, resources={"workertoken": 1})
elif kind_of_cluster == "remote":
    gateway = Gateway()
    options = gateway.cluster_options()
    # Workers read requester-pays S3 data in us-west-2.
    options.environment = {
        "AWS_REQUEST_PAYER": "requester",
        "AWS_REGION_NAME": "us-west-2",
        #         "DASK_DISTRIBUTED__WORKER__RESOURCES__WORKERTOKEN": "1",
    }
    options.worker_cores = 1
    # NOTE(review): unit is whatever the gateway's option defines (presumably GiB).
    options.worker_memory = 31

    options.image = "carbonplan/trace-python-notebook:latest"
    cluster = gateway.new_cluster(cluster_options=options)
    #     cluster.adapt(minimum=0,maximum=2)
    cluster.scale(25)
else:
    # Fail here rather than later with a NameError on `cluster`/`client`.
    raise ValueError(f"kind_of_cluster must be 'local' or 'remote', got {kind_of_cluster!r}")

In [None]:
# Private helper from carbonplan_trace.v1.postprocess, called in this kernel
# process. Presumably limits thread counts of numeric libraries — TODO confirm
# in postprocess._set_thread_settings.
postprocess._set_thread_settings()

In [None]:
# Shut the cluster down only when you are finished. Kept commented out so that
# running the notebook top-to-bottom does not tear the cluster down before
# `cluster.get_client()` and `flow.run(...)` in later cells use it.
# cluster.shutdown()
# local_cluster_client.shutdown()

In [None]:
# gateway = Gateway()
# clusters = gateway.list_clusters()
# cluster = gateway.connect(clusters[0].name)

In [None]:
# cluster.shutdown()

In [None]:
# Dask client for the active cluster. In "local" mode no Gateway `cluster`
# object exists, so reuse the local Client instead of raising a NameError.
if kind_of_cluster == "local":
    client = local_cluster_client
else:
    client = cluster.get_client()
client

In [None]:
# cluster.shutdown()

In [None]:
# AWS key pair, later forwarded to every postprocess task in its parameters.
(access_key_id, secret_access_key) = access_credentials()

In [None]:
# Appears unused by the Prefect flow below; left over from an earlier
# manual-submission approach.
tasks = []
# Range of years to process. 2014 is already the earliest target year.
# NOTE(review): whether year1 is inclusive is decided by
# postprocess.initialize_empty_dataset — confirm.
year0, year1 = 2014, 2021
# Side length in degrees of the subtiles each 10x10 degree tile is split
# into (2 degrees recommended -> 25 subtiles per tile).
tile_degree_size = 2
# Whether initialize_empty_dataset should also write the zarr store metadata.
write_tile_metadata = True
# Chunk sizes along x and y, passed to each task as CHUNKS_DICT.
chunks_dict = {"x": 1000, "y": 1000}

In [None]:
# Log file names (minus everything from ".txt" on) are compared against
# subtile task tags later, so already-finished subtiles can be skipped.
log_bucket = "s3://carbonplan-climatetrace/v1.1/postprocess_log/"
completed_subtiles = [
    log_path.rsplit("/", 1)[-1].partition(".txt")[0]
    for log_path in fs.ls(log_bucket)
]
len(completed_subtiles)

In [None]:
# try running africa first
running_tiles = [tile for tile in tiles if "N" in tile and "E" in tile and "1" not in tile]
running_tiles.extend([tile for tile in tiles if "W" in tile])
running_tiles.extend([tile for tile in tiles if "S" in tile])
running_tiles = ["00N_020E"]

In [None]:
# Number of tiles selected for this run.
len(running_tiles)

In [None]:
# Build one Prefect parameter dict per pending subtile. Each 10x10 degree tile
# is split into subtiles of side `tile_degree_size` (2 degrees -> 5 x 5 = 25
# subtiles per tile); subtiles whose task tag appears in the completion log
# are skipped.
completed_subtile_set = set(completed_subtiles)  # O(1) lookups in the loop below
parameters_list = []
for tile in running_tiles:
    # Find the lat/lon bounding box for this tile.
    lat_tag, lon_tag = utils.get_lat_lon_tags_from_tile_path(tile)
    lat_lon_box = utils.parse_bounding_box_from_lat_lon_tags(lat_tag, lon_tag)
    min_lat, max_lat, min_lon, max_lon = lat_lon_box

    # Subtile offsets (degrees from the tile's min lat/lon) still to process.
    # The tag format must match the completion-log file names.
    pending_increments = [
        (lat_increment, lon_increment)
        for lat_increment in np.arange(0, 10, tile_degree_size)
        for lon_increment in np.arange(0, 10, tile_degree_size)
        if "{}_{}_{}_{}".format(min_lat, min_lon, lat_increment, lon_increment)
        not in completed_subtile_set
    ]
    if not pending_increments:
        # Tile already fully processed: don't re-initialize its dataset.
        continue

    # Initialize the empty output dataset for this tile (optionally writing
    # its zarr metadata); only done for tiles that still have work left.
    data_path = postprocess.initialize_empty_dataset(
        lat_tag, lon_tag, year0, year1, write_tile_metadata=write_tile_metadata
    )

    # NOTE(review): AWS credentials travel inside the flow parameters; avoid
    # logging or persisting these dicts.
    prefect_parameters = {
        "MIN_LAT": min_lat,
        "MIN_LON": min_lon,
        "YEAR_0": year0,
        "YEAR_1": year1,
        "TILE_DEGREE_SIZE": tile_degree_size,
        "DATA_PATH": data_path,
        "ACCESS_KEY_ID": access_key_id,
        "SECRET_ACCESS_KEY": secret_access_key,
        "CHUNKS_DICT": chunks_dict,
    }

    for lat_increment, lon_increment in pending_increments:
        parameters_list.append(
            {
                **prefect_parameters,
                "LAT_INCREMENT": lat_increment,
                "LON_INCREMENT": lon_increment,
            }
        )

In [None]:
# Number of subtile tasks to run (subtiles already in the completion log were skipped).
len(parameters_list)

In [None]:
# postprocess.postprocess_subtile(parameters_list[4])

In [None]:
# Prefect executor that submits the mapped tasks to the Dask cluster created
# earlier in the notebook.
if kind_of_cluster == "local":
    executor = DaskExecutor(address=local_cluster_client.scheduler.address)
elif kind_of_cluster == "remote":
    # Gateway clusters need the cluster's security object for the executor's
    # client to connect.
    executor = DaskExecutor(
        address=client.scheduler.address,
        client_kwargs={"security": cluster.security},
        debug=True,
    )
else:
    # Fail here rather than with a NameError on `executor` at flow.run time.
    raise ValueError(f"kind_of_cluster must be 'local' or 'remote', got {kind_of_cluster!r}")

In [None]:
def fail_nicely(task, old_state, new_state):
    """Prefect state handler: report transitions and escalate failures.

    Parameters
    ----------
    task
        The task whose state changed. The name shadows the ``prefect.task``
        import only inside this function.
    old_state
        The previous state (unused).
    new_state
        The new state; must provide ``is_running()`` and ``is_failed()``.

    Returns
    -------
    The unchanged ``new_state`` when it is not a failed state.

    Raises
    ------
    ValueError
        When ``new_state`` is failed. The failed state's message is included,
        because an exception raised in a state handler presumably becomes the
        task's reported failure reason. This is a placeholder for sending a
        notification.
    """
    if new_state.is_running():
        print("running!")
    if new_state.is_failed():
        print("this task {} failed".format(task))
        detail = getattr(new_state, "message", None)
        raise ValueError(f"task {task} failed: {detail}")
    return new_state

In [None]:
# prefect.engine.signals.state.Skipped()

In [None]:
# Wrap postprocess.postprocess_subtile as a Prefect task, with fail_nicely
# registered as its state handler (Prefect calls it as
# handler(task, old_state, new_state)).
# Commented-out alternatives: postprocess.test_to_zarr for testing, and a
# "dask-resource:workertoken=1" tag, presumably to cap concurrency via the
# "workertoken" worker resource.
postprocess_task = task(
    postprocess.postprocess_subtile,  # .test_to_zarr,#
    #     tags=["dask-resource:workertoken=1"],
    state_handlers=[fail_nicely],
)

In [None]:
# Define the Prefect flow: postprocess_task is mapped over parameters_list,
# one run per pending subtile parameter dict.
with Flow("Postprocessing") as flow:
    # Run postprocess
    postprocess_task.map(parameters_list)

In [None]:
# Execute the flow.
# To debug without a cluster, run in-process instead (optionally inside
# `with raise_on_exception():`):
#     flow.run()
# On a Dask cluster, use the executor configured earlier:
flow.run(executor=executor)