In [None]:
# reload imported modules automatically before each cell runs, so edits to the
# carbonplan_trace package are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
import fsspec
import time
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

from carbonplan_trace.v1.glas_preprocess import preprocess
import carbonplan_trace.v1.glas_allometric_eq as allo
import carbonplan_trace.v1.utils as utils
from carbonplan_trace.v1.glas_height_metrics import get_all_height_metrics
from carbonplan_trace.utils import zarr_is_complete

# from gcsfs import GCSFileSystem
# fs = GCSFileSystem(cache_timeout=0)

In [None]:
from s3fs import S3FileSystem
from carbonplan_trace.v1.landsat_preprocess import access_credentials

# key pair from the project helper; used for the notebook-wide S3 file system
# and passed on to process_one_tile when the tile jobs are submitted
access_key_id, secret_access_key = access_credentials()
fs = S3FileSystem(key=access_key_id, secret=secret_access_key)

In [None]:
import warnings

# suppress every warning for the rest of the session
warnings.filterwarnings("ignore")

## local or remote cluster


In [None]:
cluster_type = "remote"  # "remote" starts a Dask Gateway cluster, "local" a local Client

In [None]:
import dask
from dask.distributed import Client
from dask_gateway import Gateway

# Create the Dask client on either a Dask Gateway cluster or a local one.
if cluster_type == "remote":
    gateway = Gateway()
    options = gateway.cluster_options()
    options.worker_cores = 1
    options.worker_memory = 120
    options.image = "carbonplan/trace-python-notebook:latest"
    cluster = gateway.new_cluster(cluster_options=options)
    # let the cluster scale between 1 and 10 workers
    cluster.adapt(minimum=1, maximum=10)
    client = cluster.get_client()
elif cluster_type == "local":
    client = Client(n_workers=2, threads_per_worker=1)
else:
    # fail here rather than with a NameError on `client` in a later cell
    raise ValueError(
        f'unsupported cluster_type {cluster_type!r}: only "remote" or "local" are supported'
    )

In [None]:
# show the client summary as the cell output
client

## Plotting example waveforms


In [None]:
# from carbonplan_trace.v1.glas_height_metrics import plot_shot
# from carbonplan_trace.v1.glas_preprocess import get_modeled_waveform

# lat_tag = '00N'
# lon_tag = '010E'
# preprocessed_path = f"gs://carbonplan-climatetrace/intermediates/preprocessed_lidar/{lat_tag}_{lon_tag}.zarr"
# preprocessed = (
#     utils.open_zarr_file(preprocessed_path)
#     .stack(unique_index=("record_index", "shot_number"))
#     .dropna(dim="unique_index", subset=["lat"])
# )
# preprocessed['modeled_wf'] = get_modeled_waveform(preprocessed)
# record = preprocessed.isel(unique_index=0).load()
# plot_shot(record)

In [None]:
# import random
# for _ in range(10):
#     ind = random.randint(0, preprocessed.dims['unique_index'])
#     record = preprocessed.isel(unique_index=ind).load()
#     plot_shot(record)

## To run more tiles


In [None]:
import dask


@dask.delayed
def process_one_tile(bounding_box, skip_existing, access_key_id, secret_access_key):
    """
    Compute biomass and height metrics for one tile and write them to S3.

    The preprocessed lidar store for the tile is reused when it exists and still
    contains shots after filtering; otherwise it is rebuilt from the GLAH01 and
    GLAH14 data and written back before the biomass step runs.

    Parameters
    ----------
    bounding_box : sequence
        (min_lat, max_lat, min_lon, max_lon) of the tile.
    skip_existing : bool
        When true, return immediately if the biomass store for this tile already
        has a `.zmetadata` file.
    access_key_id, secret_access_key : str
        Credentials used to build the S3 file system inside the task.

    Returns
    -------
    tuple
        (status, biomass_path), where status is one of "skipped",
        "no data in lidar", "no data after preprocess" or "processed".
    """
    min_lat, max_lat, min_lon, max_lon = bounding_box
    lat_tag, lon_tag = utils.get_lat_lon_tags_from_bounding_box(max_lat, min_lon)
    biomass_path = f"s3://carbonplan-climatetrace/v1/biomass/{lat_tag}_{lon_tag}.zarr"
    preprocessed_path = (
        f"s3://carbonplan-climatetrace/v1/preprocessed_lidar/{lat_tag}_{lon_tag}.zarr"
    )
    with dask.config.set(scheduler="single-threaded"):
        from s3fs import S3FileSystem

        fs = S3FileSystem(key=access_key_id, secret=secret_access_key)

        if skip_existing and fs.exists(biomass_path + "/.zmetadata"):
            return ("skipped", biomass_path)

        # Try the cached preprocessed store first. Any ordinary failure (missing
        # store, unreadable data, nothing left after filtering) falls through to
        # a rebuild; KeyboardInterrupt / SystemExit are deliberately not caught.
        preprocessed = None
        try:
            if fs.exists(preprocessed_path + "/.zmetadata"):
                mapper = fs.get_mapper(preprocessed_path)
                cached = (
                    xr.open_zarr(mapper)
                    .stack(unique_index=("record_index", "shot_number"))
                    .dropna(dim="unique_index", subset=["lat"])
                )
                # filtering of null values stored as the maximum number for the datatype
                cached = cached.where((cached.rec_wf < 1e35).all(dim="rec_bin"), drop=True)
                if cached.dims["unique_index"] > 0:
                    preprocessed = cached
        except Exception:
            preprocessed = None

        if preprocessed is None:
            # read in data, this step takes about 5 mins
            data01 = utils.open_glah01_data()
            data14 = utils.open_glah14_data()

            # subset data to the bounding box
            sub14 = utils.subset_data_for_bounding_box(data14, min_lat, max_lat, min_lon, max_lon)
            sub01 = data01.where(data01.record_index.isin(sub14.record_index), drop=True)
            combined = sub14.merge(sub01, join="inner")

            if len(combined.record_index) == 0:
                return ("no data in lidar", biomass_path)

            # preprocess data and persist
            preprocessed = preprocess(combined, min_lat, max_lat, min_lon, max_lon)
            del combined, sub14, sub01

            if len(preprocessed.record_index) == 0:
                return ("no data after preprocess", biomass_path)

            preprocessed["datetime"] = preprocessed.datetime.astype("datetime64[ns]")

            # NOTE(review): from here on `preprocessed` is unstacked and chunked,
            # while the cached branch above hands a stacked dataset to the biomass
            # step — confirm the downstream helpers accept both layouts.
            preprocessed = preprocessed.unstack("unique_index")
            preprocessed = preprocessed.chunk({"record_index": 500, "shot_number": 40})

            mapper = fs.get_mapper(preprocessed_path)
            mapper.clear()
            # drop the stored chunk encoding before writing with the new chunking
            for v in list(preprocessed.keys()):
                if "chunks" in preprocessed[v].encoding:
                    del preprocessed[v].encoding["chunks"]
            preprocessed.to_zarr(mapper, mode="w", consolidated=True)

        # calculate biomass
        with_biomass = allo.apply_allometric_equation(
            preprocessed, min_lat, max_lat, min_lon, max_lon
        )

        # saving output
        height_metrics = [
            "VH",
            "h25_Neigh",
            "h50_Neigh",
            "h75_Neigh",
            "h90_Neigh",
            "QMCH",
            "MeanH",
            "f_slope",
            "senergy",
        ]

        with_biomass = get_all_height_metrics(with_biomass, height_metrics).compute()
        variables = [
            "lat",
            "lon",
            "time",
            "biomass",
            "allometric_eq",
            "glas_elev",
            "ecoregion",
            "eosd",
            "nlcd",
            "igbp",
            "treecover2000_mean",
            "burned",
        ]

        # keep only the output variables and the requested height metrics
        with_biomass = with_biomass.unstack("unique_index")[variables + height_metrics]
        with_biomass = with_biomass.chunk({"record_index": 500, "shot_number": 40})
        mapper = fs.get_mapper(biomass_path)
        for v in list(with_biomass.keys()):
            if "chunks" in with_biomass[v].encoding:
                del with_biomass[v].encoding["chunks"]
        with_biomass.to_zarr(mapper, mode="w", consolidated=True)

        return ("processed", biomass_path)

In [None]:
# run all tiles that don't exist in the output yet


def get_list_of_mask_tiles(include=""):
    """
    List the tiles that still need a biomass output.

    The ecoregions mask is stored in 10 degree tiles. Every mask tile whose path
    contains `include` is considered; a tile is returned unless it is listed in
    `no_data_tiles` or its biomass store already has a `.zmetadata` file.

    Parameters
    ----------
    include : str, optional
        Substring a mask path must contain to be considered. The default empty
        string matches every path.

    Returns
    -------
    list of tuple
        (lat, lon) tag pairs, in the order the mask paths were listed.
    """
    no_data_tiles = {"40N_070W", "30N_170W", "20N_120W", "00N_070E"}

    # NOTE(review): no key/secret is passed here, unlike the file system built in
    # process_one_tile — confirm the default credentials can read this bucket.
    fs = S3FileSystem()
    mask_folder = "s3://carbonplan-climatetrace/intermediate/ecoregions_mask/"
    # fs.ls includes the parent folder itself, skip that link
    mask_paths = [tp for tp in fs.ls(mask_folder) if not tp.endswith("/") and include in tp]

    lat_lon_tags = []
    for tile_path in mask_paths:
        lat, lon = utils.get_lat_lon_tags_from_tile_path(tile_path)
        # check the static skip list first so skipped tiles cost no S3 request
        if f"{lat}_{lon}" in no_data_tiles:
            continue
        output_path = f"s3://carbonplan-climatetrace/v1/biomass/{lat}_{lon}.zarr/.zmetadata"
        if not fs.exists(output_path):
            lat_lon_tags.append((lat, lon))

    return lat_lon_tags


# tiles still waiting for a biomass output, as (lat, lon) tag pairs
lat_lon_tags = get_list_of_mask_tiles()

# each bounding box should be in the order of min_lat, max_lat, min_lon, max_lon
bounding_boxes = []
for lat, lon in lat_lon_tags:
    bounding_boxes.append(utils.parse_bounding_box_from_lat_lon_tags(lat, lon))

len(bounding_boxes)

In [None]:
# run all tiles within the lat/lon box

# min_lat = -90
# max_lat = 90
# min_lon = -180
# max_lon = 180

# tiles = utils.find_tiles_for_bounding_box(
#     min_lat=min_lat, max_lat=max_lat, min_lon=min_lon, max_lon=max_lon
# )
# all_lat_lon_tags = [utils.get_lat_lon_tags_from_tile_path(tp) for tp in tiles]
# bounding_boxes = [
#     utils.parse_bounding_box_from_lat_lon_tags(lat, lon)
#     for lat, lon in all_lat_lon_tags
# ]
# len(bounding_boxes)

In [None]:
skip_existing = True

# submit one delayed tile job per bounding box and keep what the client returns
tasks = []
for bounding_box in bounding_boxes:
    delayed_tile = process_one_tile(
        bounding_box, skip_existing, access_key_id, secret_access_key
    )
    tasks.append(client.compute(delayed_tile))

In [None]:
# NOTE(review): `tasks` already holds what client.compute returned — confirm that
# dask.compute resolves them here and that `retries` still takes effect.
(results,) = dask.compute(tasks, retries=1)
results

In [None]:
# Report every task that is no longer pending. A failed task is printed and
# skipped over instead of aborting the loop, so later tiles are still shown.
for i, task in enumerate(tasks):
    if task.status == "pending":
        continue
    print(i)
    try:
        print(task.result())
    except Exception as err:
        print(f"{task.status}: {err!r}")

In [None]:
# for task in tasks:
#     task.cancel()

In [None]:
# open the biomass store of one tile for inspection
sample_biomass_path = "s3://carbonplan-climatetrace/v1/biomass/50N_120W.zarr"
mapper = fs.get_mapper(sample_biomass_path)
ds = xr.open_zarr(mapper)

In [None]:
# combine record_index and shot_number into one dimension, then drop entries without a latitude
shots = ds.stack(unique_index=("record_index", "shot_number"))
shots.dropna(dim="unique_index", subset=["lat"])