In [None]:
%load_ext autoreload
%autoreload 2

from pyproj import CRS
import boto3
from rasterio.session import AWSSession
from s3fs import S3FileSystem

# Requester-pays handles for S3: a rasterio AWS session and an s3fs filesystem.
# boto3.Session() also accepts a profile name, e.g. boto3.Session(profile_name='default').
# NOTE(review): `aws_session` is not referenced by any other visible cell — confirm it is still needed.
aws_session = AWSSession(boto3.Session(), requester_pays=True)
fs = S3FileSystem(requester_pays=True)

import rasterio as rio
import numpy as np
import xarray as xr
import dask
import os
import fsspec

import rioxarray  # for the extension to load
import pandas as pd

from dask_gateway import Gateway
from carbonplan_trace.v1.landsat_preprocess import access_credentials, test_credentials
from carbonplan_trace.v1 import utils, load
from carbonplan_trace.tiles import tiles
import prefect
from prefect import task, Flow, Parameter
from prefect.executors import DaskExecutor
from prefect.utilities.debug import raise_on_exception
from datetime import datetime

In [None]:
from carbonplan_trace.v1 import postprocess as p
from carbonplan_trace.v1 import change_point_detection as c

In [None]:
# Dask settings for this run, applied in a single call:
# - do not split large chunks produced by array slicing
# - use 50 s for both the TCP and the connect communication timeouts
dask.config.set(
    {
        "array.slicing.split_large_chunks": False,
        "distributed.comm.timeouts.tcp": "50s",
        "distributed.comm.timeouts.connect": "50s",
    }
)

In [None]:
# Which cluster to run on: "local" (workers on this machine) or "remote" (dask-gateway).
kind_of_cluster = "remote"
if kind_of_cluster == "local":
    # spin up local cluster. must be on big enough machine
    from dask.distributed import Client

    # every worker advertises a single `workertoken` resource
    local_cluster_client = Client(n_workers=15, threads_per_worker=1, resources={"workertoken": 1})
    # Later cells refer to `cluster`; expose the LocalCluster the client created under
    # that name so they do not hit a NameError in local mode.
    # NOTE(review): the later `cluster.get_client()` call needs a distributed release
    # whose LocalCluster provides that method — confirm for the pinned version.
    cluster = local_cluster_client.cluster
elif kind_of_cluster == "remote":
    gateway = Gateway()
    options = gateway.cluster_options()
    # environment variables handed to the cluster; the last one gives every worker
    # a single `workertoken` resource, mirroring the local setup
    options.environment = {
        "AWS_REQUEST_PAYER": "requester",
        "AWS_REGION_NAME": "us-west-2",
        "DASK_DISTRIBUTED__WORKER__RESOURCES__WORKERTOKEN": "1",
    }
    options.worker_cores = 1
    options.worker_memory = 10

    options.image = "carbonplan/trace-python-notebook:latest"
    cluster = gateway.new_cluster(cluster_options=options)
    # scale between 0 and 150 workers on demand
    cluster.adapt(minimum=0, maximum=150)
    # fixed-size alternative:
    # cluster.scale(150)
else:
    # fail here, not several cells later with a confusing NameError on `cluster`
    raise ValueError(
        f"Unknown kind_of_cluster {kind_of_cluster!r}; expected 'local' or 'remote'"
    )

In [None]:
cluster

In [None]:
# local_cluster_client
client = cluster.get_client()
client

In [None]:
# p._set_thread_settings()

In [None]:
access_key_id, secret_access_key = access_credentials()

In [None]:
# define starting and ending years
# NOTE(review): this comment previously said 2014 "might not be ready right now",
# but year0 is 2014 — confirm the 2014 inputs are available.
year0, year1 = 2014, 2021
# define the size (in degrees) of subtile you want to work in
# NOTE(review): this comment previously recommended 2 degrees, but the value is 1 —
# confirm which one is intended.
tile_degree_size = 1
# if you want to write the metadata for the zarr store
write_tile_metadata = True
# version tag; used to build the log bucket path
version = "v1.2"

In [None]:
log_bucket = f"s3://carbonplan-climatetrace/{version}/carbonpool/"

# Reduce every object listed under the log prefix to its bare name: drop the
# directory part, then everything from the first ".txt" onwards. These names are
# compared against subtile task tags when the parameter list is built.
completed_subtiles = [
    log_path.rsplit("/", 1)[-1].partition(".txt")[0] for log_path in fs.ls(log_bucket)
]
len(completed_subtiles)

In [None]:
# Tiles to process in this run. To restrict the run to a subset, filter here,
# e.g. [tile for tile in tiles if "E" in tile].
running_tiles = list(tiles)

In [None]:
len(running_tiles)

In [None]:
# Build one parameter dict per subtile that still has to be processed.
parameters_list = []
# set lookup: the membership test below runs once per subtile of every tile
completed_lookup = set(completed_subtiles)
# Offsets of the subtiles inside a tile, from 0 up to (not including) 10 in steps of
# `tile_degree_size`: 10 x 10 = 100 subtiles per tile for size 1, 5 x 5 = 25 for size 2.
subtile_offsets = np.arange(0, 10, tile_degree_size)
# for tile in tiles:
for tile in running_tiles:
    lat_tag, lon_tag = utils.get_lat_lon_tags_from_tile_path(tile)
    lat_lon_box = utils.parse_bounding_box_from_lat_lon_tags(lat_tag, lon_tag)
    # find the lat_lon_box for that tile
    min_lat, max_lat, min_lon, max_lon = lat_lon_box
    # initialize empty dataset. only need to do this once, and not if the tile has already been processed
    # NOTE(review): the call below runs for every tile, including tiles whose subtiles
    # are all completed — confirm that re-initializing is harmless.
    data_path = p.initialize_empty_dataset(
        lat_tag, lon_tag, year0, year1, write_tile_metadata=write_tile_metadata
    )

    # parameters shared by every subtile of this tile
    # NOTE(review): the AWS credentials travel inside the task parameters — make sure
    # they are not logged.
    prefect_parameters = {
        "MIN_LAT": min_lat,
        "MIN_LON": min_lon,
        "YEAR_0": year0,
        "YEAR_1": year1,
        "TILE_DEGREE_SIZE": tile_degree_size,
        "DATA_PATH": data_path,
        "ACCESS_KEY_ID": access_key_id,
        "SECRET_ACCESS_KEY": secret_access_key,
        "LOG_BUCKET": log_bucket,
    }

    # now split the tile into subtiles of length `tile_degree_size` and queue each
    # one that has not been run yet
    for lat_increment in subtile_offsets:
        for lon_increment in subtile_offsets:
            task_tag = f"{min_lat}_{min_lon}_{lat_increment}_{lon_increment}"
            if task_tag in completed_lookup:
                # this subtile has already been run
                continue

            increment_parameters = {
                **prefect_parameters,
                "LAT_INCREMENT": lat_increment,
                "LON_INCREMENT": lon_increment,
            }
            parameters_list.append(increment_parameters)

In [None]:
len(parameters_list)

In [None]:
import random

# Shuffle the pending subtiles in place, so they are submitted in random order.
# NOTE(review): no seed is set, so the order differs on every run.
random.shuffle(parameters_list)

In [None]:
# c.run_change_point_detection_for_subtile(parameters_list[0])

In [None]:
# Point Prefect's DaskExecutor at the scheduler of whichever cluster was started.
if kind_of_cluster == "remote":
    # hand the gateway cluster's security object through to the dask client
    executor = DaskExecutor(
        address=client.scheduler.address,
        client_kwargs={"security": cluster.security},
        debug=True,
    )
elif kind_of_cluster == "local":
    executor = DaskExecutor(address=local_cluster_client.scheduler.address)

In [None]:
def fail_nicely(task, old_state, new_state):
    """State handler that reports task progress and raises on failure.

    Prints "running!" when ``new_state`` reports it is running. When
    ``new_state`` reports failure, prints which task failed and raises
    ``ValueError``; otherwise returns ``new_state`` unchanged.

    Parameters
    ----------
    task : the task whose state changed (only used in the failure message)
    old_state : the previous state (unused)
    new_state : the state being entered; must provide ``is_running()`` and
        ``is_failed()``

    Returns
    -------
    ``new_state``, when it is not a failed state.

    Raises
    ------
    ValueError
        When ``new_state.is_failed()`` is true.
    """
    if new_state.is_running():
        print("running!")
    if not new_state.is_failed():
        return new_state
    print(f"this task {task} failed")
    # placeholder for a function that sends a notification
    raise ValueError("OH NO")

In [None]:
# Wrap the per-subtile change point detection routine as a Prefect task.
# - the "dask-resource:workertoken=1" tag asks the Dask executor for one
#   `workertoken` resource per task run; the cluster setup gives each worker exactly
#   one token, so presumably each worker runs one of these tasks at a time
# - `fail_nicely` is registered as the task's state handler
change_point_detection_task = task(
    c.run_change_point_detection_for_subtile,
    tags=["dask-resource:workertoken=1"],
    state_handlers=[fail_nicely],
)

In [None]:
# Run the pending subtiles in batches, one Prefect flow per batch. At most
# `max_batches` batches of `batch_size` subtiles are submitted per pass of this cell.
batch_size = 1500
max_batches = 4
for i in range(max_batches):
    batch = parameters_list[i * batch_size : (i + 1) * batch_size]
    if not batch:
        # fewer pending subtiles than max_batches * batch_size: nothing is left, so
        # skip the empty flow run and the cluster restart that would follow it
        break
    print(i)
    with Flow("ChangePoint") as flow:
        change_point_detection_task.map(batch)
    flow.run(executor=executor)
    # restart the workers before the next batch
    client.restart()

# make the cap visible instead of silently leaving work behind
not_submitted = len(parameters_list) - max_batches * batch_size
if not_submitted > 0:
    print(f"{not_submitted} subtiles were not submitted in this pass")

In [None]:
client.shutdown()

In [None]:
print(datetime.now())

In [None]:
# with raise_on_exception():
# if running locally (no cluster)
#     flow.run()
# if running on cluster
# NOTE(review): this re-runs whichever `flow` the batch loop built last, and an earlier
# cell has already called client.shutdown() — confirm this cell is still needed.
flow.run(executor=executor)

In [None]:
print(datetime.now())

In [None]:
client.shutdown()

In [None]:
from carbonplan_trace.v1.biomass_rollup import open_biomass_tile

In [None]:
# Open one biomass tile and choose a point to inspect.
# NOTE(review): the version is hard-coded here instead of reusing `version` from the
# configuration cell — confirm that is intended.
ds = open_biomass_tile(version="v1.2", tile_id="30S_170E")
lat = -36.021375
lon = 173.877375
# half-width of the window selected around (lat, lon); 1 / 40 = 0.025, presumably degrees
buffer = 1 / 40

In [None]:
# Select a window of +/- buffer around the point.
# NOTE(review): both slices run from the lower to the higher value; if the tile stores
# `lat` in descending order this selects nothing — confirm the coordinate order.
patch = ds.sel(lat=slice(lat - buffer, lat + buffer), lon=slice(lon - buffer, lon + buffer))

In [None]:
patch.load()

In [None]:
patch.AGB_raw.plot(col="time", col_wrap=3, vmin=0)

In [None]:
patch.AGB.plot(col="time", col_wrap=3, vmin=0)