In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import fsspec
import time
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

from carbonplan_trace.v1.glas_preprocess import preprocess
import carbonplan_trace.v1.glas_allometric_eq as allo
import carbonplan_trace.v1.utils as utils
from carbonplan_trace.v1.glas_height_metrics import get_all_height_metrics


from gcsfs import GCSFileSystem

# Shared Google Cloud Storage handle used by the cells below for existence checks,
# listings, deletes and uploads. cache_timeout=0 disables gcsfs's metadata/listing cache.
fs = GCSFileSystem(cache_timeout=0)

In [None]:
# import warnings

# warnings.filterwarnings("ignore")

In [None]:
# run all tiles that don't exist in output yet

# def get_list_of_mask_tiles(include=""):
#     """
#     Ecoregions mask is stored in 10 degree tiles, grab the filepaths
#     """
#     no_data_tiles = ["00N_070E", "20N_120W", "30N_170W", "40N_070W"]

#     fs = GCSFileSystem(cache_timeout=0)
#     mask_folder = "gs://carbonplan-climatetrace/intermediates/ecoregions_mask/"
#     # fs.ls includes the parent folder itself, skip that link
#     mask_paths = [
#         tp
#         for tp in fs.ls(mask_folder)
#         if not tp.endswith("/") and include in tp
#     ]

#     all_lat_lon_tags = [
#         utils.get_lat_lon_tags_from_tile_path(tp) for tp in mask_paths
#     ]

#     lat_lon_tags = []
#     for lat, lon in all_lat_lon_tags:
#         fn = f"{lat}_{lon}"
#         output_path = f"gs://carbonplan-climatetrace/intermediates/biomass/{lat}_{lon}.zarr/.zmetadata"
#         if not fs.exists(output_path) and not fn in no_data_tiles:
#             lat_lon_tags.append((lat, lon))

#     return lat_lon_tags


# lat_lon_tags = get_list_of_mask_tiles()
# # this should be in the order of min_lat, max_lat, min_lon, max_lon
# bounding_boxes = [
#     utils.parse_bounding_box_from_lat_lon_tags(lat, lon)
#     for lat, lon in lat_lon_tags
# ]

# len(bounding_boxes)

In [None]:
import dask
from dask.distributed import Client

# Start a local dask cluster with three workers of one thread each and connect to it.
client = Client(
    threads_per_worker=1,
    n_workers=3,
)
# Bare expression so the notebook renders the client summary.
client

In [None]:
# from dask_gateway import Gateway

# gateway = Gateway()
# options = gateway.cluster_options()
# options.worker_cores = 4
# options.worker_memory = 120
# cluster = gateway.new_cluster(cluster_options=options)
# cluster.adapt(minimum=1, maximum=10)
# cluster

In [None]:
# client = cluster.get_client()
# client

In [None]:
# from dask.distributed import PipInstall
# plugin = PipInstall(packages=["git+https://github.com/carbonplan/trace.git@debug_biomass#egg=carbonplan_trace"],
#                     pip_options=["-e"])
# client.register_worker_plugin(plugin)

In [None]:
import dask


@dask.delayed
def process_one_tile(bounding_box, skip_existing):
    """
    Compute and save GLAS lidar biomass for a single tile.

    Decorated with ``dask.delayed``: calling it only builds a task; the work
    runs when that task is computed.

    The preprocessed lidar store of the tile is reused when its ``.zmetadata``
    exists; otherwise it is built from the GLAH01/GLAH14 data and written to
    the preprocessed-lidar path. The allometric equations and height metrics
    are then applied and the result is written to the biomass path.

    Parameters
    ----------
    bounding_box : sequence of 4 numbers
        ``(min_lat, max_lat, min_lon, max_lon)`` of the tile.
    skip_existing : bool
        When true, a tile whose biomass store already has ``.zmetadata`` is
        not recomputed.

    Returns
    -------
    tuple of (str, str)
        ``(status, biomass_path)`` where status is one of ``"skipped"``,
        ``"failed to open lidar"``, ``"no data in lidar"``,
        ``"no data after preprocess"``, ``"failed in preprocess"``,
        ``"processed"`` or ``"failed"``. Errors are reported through the
        status instead of being raised; their traceback is printed to stderr.
    """
    # Function-scope import: only needed to report otherwise-swallowed errors.
    import traceback

    min_lat, max_lat, min_lon, max_lon = bounding_box
    lat_tag, lon_tag = utils.get_lat_lon_tags_from_bounding_box(max_lat, min_lon)
    biomass_path = f"gs://carbonplan-climatetrace/intermediates/biomass/baccini_ground_all/{lat_tag}_{lon_tag}.zarr"
    preprocessed_path = (
        f"gs://carbonplan-climatetrace/intermediates/preprocessed_lidar/{lat_tag}_{lon_tag}.zarr"
    )

    # Any dask computation triggered inside this task uses the single-threaded scheduler.
    with dask.config.set(scheduler="single-threaded"):
        if skip_existing and fs.exists(biomass_path + "/.zmetadata"):
            return ("skipped", biomass_path)

        if fs.exists(preprocessed_path + "/.zmetadata"):
            # Reuse the stored preprocessed lidar: flatten to one row per shot and
            # drop the shots that have no latitude.
            try:
                preprocessed = (
                    utils.open_zarr_file(preprocessed_path)
                    .stack(unique_index=("record_index", "shot_number"))
                    .dropna(dim="unique_index", subset=["lat"])
                )
            except Exception:
                # Was a bare `except:`; keep the best-effort status but do not hide
                # the cause or intercept KeyboardInterrupt/SystemExit.
                traceback.print_exc()
                return ("failed to open lidar", biomass_path)

        else:
            try:
                # read in data, this step takes about 5 mins
                data01 = utils.open_glah01_data()
                data14 = utils.open_glah14_data()

                # subset data to the bounding box
                sub14 = utils.subset_data_for_bounding_box(
                    data14, min_lat, max_lat, min_lon, max_lon
                )
                # keep only the GLAH01 records that also appear in the GLAH14 subset
                sub01 = data01.where(data01.record_index.isin(sub14.record_index), drop=True)
                combined = sub14.merge(sub01, join="inner")

                if len(combined.record_index) == 0:
                    return ("no data in lidar", biomass_path)

                # preprocess data and persist
                preprocessed = preprocess(combined, min_lat, max_lat, min_lon, max_lon).compute()
                # release the large intermediates before the next steps
                del combined, sub14, sub01

                if len(preprocessed.record_index) == 0:
                    return ("no data after preprocess", biomass_path)

                preprocessed["datetime"] = preprocessed.datetime.astype("datetime64[ns]")
                utils.save_to_zarr(
                    ds=preprocessed.unstack("unique_index").chunk(
                        {"record_index": 10000, "shot_number": 40}
                    ),
                    url=preprocessed_path,
                    mode="w",
                )
            except Exception:
                traceback.print_exc()
                return ("failed in preprocess", biomass_path)

        # calculate biomass
        try:
            with_biomass = allo.apply_allometric_equation(
                preprocessed, min_lat, max_lat, min_lon, max_lon
            )

            # saving output
            height_metrics = [
                "VH",
                "h25_Neigh",
                "h50_Neigh",
                "h75_Neigh",
                "h90_Neigh",
                "QMCH",
                "MeanH",
                "f_slope",
                "senergy",
            ]

            with_biomass = get_all_height_metrics(with_biomass, height_metrics).compute()
            variables = [
                "lat",
                "lon",
                "time",
                "biomass",
                "allometric_eq",
                "glas_elev",
                "ecoregion",
                "eosd",
                "nlcd",
                "igbp",
                "treecover2000_mean",
                "burned",
            ]
            utils.save_to_zarr(
                ds=with_biomass.unstack("unique_index").chunk(
                    {"record_index": 10000, "shot_number": 40}
                ),
                url=biomass_path,
                list_of_variables=variables + height_metrics,
                mode="w",
            )

            return ("processed", biomass_path)
        except Exception:
            traceback.print_exc()
            return ("failed", biomass_path)

In [None]:
# preprocessed_path = f"gs://carbonplan-climatetrace/intermediates/preprocessed_lidar/{tiles[10]}.zarr"

# preprocessed = (
#     utils.open_zarr_file(preprocessed_path)
#     .stack(unique_index=("record_index", "shot_number"))
#     .dropna(dim="unique_index", subset=["lat"])
# )

In [None]:
# fns = [
#     '70N_010E', '70N_020E',
#     '60N_000E', '60N_010E', '60N_020E', '60N_030E', '60N_040E', '60N_050E',
#     '50N_090W', '50N_080W', '50N_010W', '50N_000E', '50N_010E', '50N_020E', '50N_030E',
#     '40N_100W', '40N_090W', '40N_080W', '40N_030W', '40N_020W',
#     '30N_100W', '30N_090W',
# ]

In [None]:
# Region to process, in degrees.
min_lat, max_lat = 30, 40
min_lon, max_lon = 60, 140

# Tile paths that utils reports for this region.
tiles = utils.find_tiles_for_bounding_box(
    min_lat=min_lat,
    max_lat=max_lat,
    min_lon=min_lon,
    max_lon=max_lon,
)
all_lat_lon_tags = list(map(utils.get_lat_lon_tags_from_tile_path, tiles))

# One box per tile; process_one_tile unpacks each as (min_lat, max_lat, min_lon, max_lon).
bounding_boxes = [
    utils.parse_bounding_box_from_lat_lon_tags(lat_tag, lon_tag)
    for lat_tag, lon_tag in all_lat_lon_tags
]

In [None]:
len(tiles)

In [None]:
# Leave tiles whose biomass output already exists untouched.
skip_existing = True

# process_one_tile is wrapped in dask.delayed, so this only builds one lazy task per
# tile; nothing runs until the tasks are computed.
tasks = [process_one_tile(tile_box, skip_existing) for tile_box in bounding_boxes]

In [None]:
# dask.compute returns one result per positional argument; a single list of tasks was
# passed, so unpack the 1-tuple to get the list of (status, biomass_path) results.
(results,) = dask.compute(tasks, retries=10)

In [None]:
results

In [None]:
# Open the annual-averaged WorldClim zarr store through its consolidated metadata.
mapper = fsspec.get_mapper(
    "s3://carbonplan-climatetrace/v1/data/intermediates/annual_averaged_worldclim.zarr"
)
worldclim_ds = xr.open_zarr(store=mapper, consolidated=True)

In [None]:
worldclim_ds.nbytes / 1e9

In [None]:
worldclim_ds["BIO15"].max().values

In [None]:
worldclim_ds["BIO15"].min().values

In [None]:
# Print summary statistics for every variable of the annual-averaged WorldClim dataset.
# `df` is deliberately left at module level: after the loop it holds the frame of the
# last variable, which the following cell displays.
for var in worldclim_ds.data_vars:
    # Converts the whole variable to a pandas DataFrame in memory.
    df = worldclim_ds[var].to_dataframe()
    print(df.describe())

In [None]:
df

## Add ancillary data


In [None]:
biomass_folder = "gs://carbonplan-climatetrace/intermediates/biomass/"
# Keep every listed entry except those returned with a trailing slash.
biomass_paths = list(filter(lambda entry: not entry.endswith("/"), fs.ls(biomass_folder)))

In [None]:
# Raw 30s WorldClim raster; its x/y dimensions are renamed to lon/lat, the names used
# for the point lookups further down.
mapper = fsspec.get_mapper("gs://carbonplan-data/raw/worldclim/30s/raster.zarr")
worldclim = xr.open_zarr(mapper, consolidated=True)
worldclim = worldclim.rename(x="lon", y="lat")

In [None]:
worldclim

In [None]:
# Lookup tables for grouping monthly WorldClim data into the seasons DJF, MAM, JJA, SON.

# Month number (1-12) -> number of days, used as averaging weights.
# February is 28.25, presumably to account for leap years.
days_in_month = dict(
    enumerate([31, 28.25, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31], start=1)
)

# (season key, months in that season); each season is keyed by its middle month.
months_in_season = [
    (1, [12, 1, 2]),  # DJF
    (4, [3, 4, 5]),  # MAM
    (7, [6, 7, 8]),  # JJA
    (10, [9, 10, 11]),  # SON
]

# Inverse lookup: month number -> season key.
month_to_season = {
    month: season_key for season_key, season_months in months_in_season for month in season_months
}

In [None]:
import pandas as pd

monthly_variables = ["prec", "srad", "tavg", "tmax", "tmin", "vapr", "wind"]


def _season_mean(months):
    """Mean of the monthly variables over `months`, weighted by days per month."""
    weights = xr.DataArray(
        [days_in_month[m] for m in months],
        dims=["month"],
        coords={"month": months},
    )
    selected = worldclim[monthly_variables].sel(month=months)
    return selected.weighted(weights).mean(dim="month")


# Season keys, in the same order as the per-season means below.
seasons = [season for season, _ in months_in_season]

# Stack the per-season means along a new "season" dimension labelled by those keys.
seasonal_data = xr.concat(
    [_season_mean(months) for _, months in months_in_season],
    pd.Index(seasons, name="season"),
)

In [None]:
seasonal_data

In [None]:
# Time-invariant WorldClim variables: BIO01 ... BIO19 followed by elevation.
static_vars = [f"BIO{n:02d}" for n in range(1, 20)]
static_vars.append("elev")
# Subset of the WorldClim dataset restricted to the variables named in static_vars.
static_data = worldclim[static_vars]

In [None]:
# Full list of variables written back for each tile: the variables already saved with
# the biomass output, then the static and the monthly WorldClim variables.
all_vars = [
    # height metrics
    "VH",
    "h25_Neigh",
    "h50_Neigh",
    "h75_Neigh",
    "h90_Neigh",
    "QMCH",
    "MeanH",
    "f_slope",
    "senergy",
    # biomass variables
    "lat",
    "lon",
    "time",
    "biomass",
    "allometric_eq",
    "glas_elev",
    "ecoregion",
    "eosd",
    "nlcd",
    "igbp",
    "treecover2000_mean",
    "burned",
    # ancillary WorldClim variables
    *static_vars,
    *monthly_variables,
]

In [None]:
from datetime import datetime, timezone
import time
import shutil

In [None]:
# NOTE(review): `failed` is never appended to in this cell; an error on any tile
# aborts the whole loop.
failed = []

for path in biomass_paths:
    # A store that already has BIO01 has been augmented before; leave it alone.
    if fs.exists(path + "/BIO01"):
        continue

    print(path)
    lat, lon = utils.get_lat_lon_tags_from_tile_path(path)

    # Load the biomass data as one row per shot, dropping shots without a latitude.
    biomass = utils.open_zarr_file(path)
    biomass = biomass.stack(unique_index=("record_index", "shot_number"))
    biomass = biomass.dropna(dim="unique_index", subset=["lat"])

    # Static WorldClim variables, looked up at each shot's coordinates.
    records = utils.find_matching_records(
        data=static_data,
        lats=biomass.lat,
        lons=biomass.lon,
    )
    for v in static_vars:
        biomass[v] = records[v]

    # Derive each shot's season from its timestamp.
    # NOTE(review): datetime.fromtimestamp without a tz argument converts to the
    # machine's local time (`timezone` is imported but unused) - confirm this is intended.
    biomass["datetime"] = xr.apply_ufunc(
        datetime.fromtimestamp,
        biomass.time,
        vectorize=True,
        dask="parallelized",
    )
    biomass["datetime"] = biomass.datetime.astype("datetime64[ns]")
    biomass["month"] = biomass.datetime.dt.month
    biomass["season"] = xr.apply_ufunc(
        month_to_season.__getitem__,
        biomass.month.astype(int),
        vectorize=True,
        dask="parallelized",
        output_dtypes=[int],
    )

    # Seasonal WorldClim variables: nearest grid cell and season for every shot.
    records = seasonal_data.sel(
        lat=biomass.lat,
        lon=biomass.lon,
        season=biomass.season,
        method="nearest",
    )
    records = records.drop_vars(["lat", "lon", "season"])
    for v in monthly_variables:
        biomass[v] = records[v]

    # Fixed-width unicode dtype for the equation names before writing.
    biomass["allometric_eq"] = biomass.allometric_eq.astype(np.dtype("<U35"))
    local_path = f"/home/jovyan/temp/{lat}_{lon}.zarr"

    # Write the augmented tile to local disk first.
    utils.save_to_zarr(
        ds=biomass.unstack("unique_index").chunk({"record_index": 10000, "shot_number": 40}),
        url=local_path,
        list_of_variables=all_vars,
        mode="w",
    )

    # Swap the remote store for the local copy, pausing a minute after each step.
    # NOTE(review): the remote store is deleted before the upload starts, so a failed
    # upload leaves the tile missing remotely (the local copy is only removed afterwards).
    fs.rm(path, recursive=True)
    time.sleep(60)
    fs.put(local_path, path, recursive=True)
    time.sleep(60)
    shutil.rmtree(local_path)

In [None]:
print("done")