In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import fsspec
import time
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

from carbonplan_trace.v1.glas_preprocess import preprocess
import carbonplan_trace.v1.glas_allometric_eq as allo
import carbonplan_trace.v1.utils as utils
from carbonplan_trace.v1.glas_height_metrics import get_all_height_metrics


from gcsfs import GCSFileSystem

# Shared GCS filesystem used by the existence checks inside process_one_tile.
# NOTE(review): per the gcsfs docs, cache_timeout=0 should disable the listing
# cache so fs.exists sees stores written earlier in this session — confirm.
fs = GCSFileSystem(cache_timeout=0)

In [None]:
import warnings

# Silence every warning for the rest of the session. This is a blanket filter;
# it also hides deprecation/runtime warnings that may matter while debugging.
warnings.filterwarnings("ignore")

In [None]:
def get_list_of_mask_tiles(include="", skip_preprocessed=True):
    """
    List the (lat, lon) tags of the ecoregion-mask tiles to process.

    The ecoregions mask is stored as 10 degree tiles; each tile path is turned
    into its (lat, lon) tag via ``utils.get_lat_lon_tags_from_tile_path``.

    Parameters
    ----------
    include : str, optional
        Only keep mask paths containing this substring (default ``""`` keeps
        every tile).
    skip_preprocessed : bool, optional
        When True (default, the historical behavior), drop tiles whose
        preprocessed lidar store already has consolidated metadata
        (``.zmetadata``). Set to False to list every tile, e.g. to rerun the
        biomass step on tiles that were preprocessed in an earlier session.

    Returns
    -------
    list of tuple
        ``(lat, lon)`` tag pairs.
    """
    fs = GCSFileSystem(cache_timeout=0)
    mask_folder = "gs://carbonplan-climatetrace/intermediates/ecoregions_mask/"
    preprocessed_folder = (
        "gs://carbonplan-climatetrace/intermediates/preprocessed_lidar/"
    )

    # fs.ls includes the parent folder itself, skip that link
    mask_paths = [
        tp
        for tp in fs.ls(mask_folder)
        if not tp.endswith("/") and include in tp
    ]

    all_lat_lon_tags = [
        utils.get_lat_lon_tags_from_tile_path(tp) for tp in mask_paths
    ]

    # short-circuit: no existence checks at all when not skipping
    return [
        (lat, lon)
        for lat, lon in all_lat_lon_tags
        if not skip_preprocessed
        or not fs.exists(f"{preprocessed_folder}{lat}_{lon}.zarr/.zmetadata")
    ]


lat_lon_tags = get_list_of_mask_tiles()

# Each bounding box is ordered (min_lat, max_lat, min_lon, max_lon).
bounding_boxes = []
for lat_tag, lon_tag in lat_lon_tags:
    bounding_boxes.append(
        utils.parse_bounding_box_from_lat_lon_tags(lat_tag, lon_tag)
    )

In [None]:
# Number of tiles that do not yet have a consolidated preprocessed lidar store.
len(bounding_boxes)

In [None]:
import dask
from dask.distributed import Client

# Local dask cluster: a single worker process running 4 threads.
client = Client(n_workers=1, threads_per_worker=4)
client

In [None]:
# from dask_gateway import Gateway

# gateway = Gateway()
# options = gateway.cluster_options()
# options.worker_cores = 4
# options.worker_memory = 120
# cluster = gateway.new_cluster(cluster_options=options)
# cluster.adapt(minimum=1, maximum=10)
# cluster

In [None]:
# client = cluster.get_client()
# client

In [None]:
# from dask.distributed import PipInstall
# plugin = PipInstall(packages=["git+https://github.com/carbonplan/trace.git@debug_biomass#egg=carbonplan_trace"],
#                     pip_options=["-e"])
# client.register_worker_plugin(plugin)

In [None]:
# zarr chunking shared by the preprocessed-lidar and biomass stores
_TILE_ZARR_CHUNKS = {"record_index": 10000, "shot_number": 40}

# height metrics computed for every shot and saved with the biomass output
_HEIGHT_METRICS = [
    "VH",
    "h25_Neigh",
    "h50_Neigh",
    "h75_Neigh",
    "h90_Neigh",
    "QMCH",
    "MeanH",
    "f_slope",
    "senergy",
]

# non-height variables saved to the biomass store
_BIOMASS_VARIABLES = [
    "lat",
    "lon",
    "time",
    "biomass",
    "allometric_eq",
    "glas_elev",
    "ecoregion",
    "eosd",
    "nlcd",
    "igbp",
    "treecover2000_mean",
    "burned",
]


def _open_preprocessed_lidar(preprocessed_path):
    """Open a cached preprocessed-lidar zarr store, one entry per valid shot.

    The store is saved unstacked (record_index x shot_number), so it is stacked
    back onto ``unique_index`` and entries whose ``lat`` is NaN are dropped.
    """
    return (
        xr.open_zarr(fs.get_mapper(preprocessed_path))
        .stack(unique_index=("record_index", "shot_number"))
        .dropna(dim="unique_index", subset=["lat"])
    )


def _build_preprocessed_lidar(bounding_box, preprocessed_path):
    """Subset raw GLAH01/GLAH14 data to one tile, preprocess it and cache it.

    Returns the computed preprocessed dataset (stacked on ``unique_index``),
    or None when the tile contains no lidar records.
    """
    min_lat, max_lat, min_lon, max_lon = bounding_box

    # read in data, this step takes about 5 mins
    data01 = utils.open_glah01_data()
    data14 = utils.open_glah14_data()

    # subset GLAH14 to the bounding box, then keep only the GLAH01 records
    # that also appear in that subset
    sub14 = utils.subset_data_for_bounding_box(
        data14, min_lat, max_lat, min_lon, max_lon
    )
    sub01 = data01.where(
        data01.record_index.isin(sub14.record_index), drop=True
    )
    combined = sub14.merge(sub01, join="inner")

    if len(combined.record_index) == 0:
        return None

    # preprocess data and persist
    preprocessed = preprocess(
        combined, min_lat, max_lat, min_lon, max_lon
    ).compute()
    del combined, sub14, sub01

    preprocessed["datetime"] = preprocessed.datetime.astype("datetime64[ns]")
    utils.save_to_zarr(
        ds=preprocessed.unstack("unique_index").chunk(_TILE_ZARR_CHUNKS),
        url=preprocessed_path,
        mode="w",
    )
    return preprocessed


@dask.delayed
def process_one_tile(bounding_box, skip_existing):
    """
    Compute and save per-shot biomass and height metrics for one tile.

    Decorated with ``dask.delayed``: calling it returns a lazy task.

    Parameters
    ----------
    bounding_box : sequence
        ``(min_lat, max_lat, min_lon, max_lon)`` of the tile.
    skip_existing : bool
        If True and the biomass store already contains ``MeanH``, return
        immediately without recomputing.

    Returns
    -------
    tuple of (str, str)
        ``(status, biomass_path)`` where status is ``"skipped"``,
        ``"no data"`` or ``"processed"``.
    """
    min_lat, max_lat, min_lon, max_lon = bounding_box
    # tile tags are derived from the (max_lat, min_lon) corner
    lat_tag, lon_tag = utils.get_lat_lon_tags_from_bounding_box(
        max_lat, min_lon
    )
    biomass_path = f"gs://carbonplan-climatetrace/intermediates/biomass/{lat_tag}_{lon_tag}.zarr"
    preprocessed_path = f"gs://carbonplan-climatetrace/intermediates/preprocessed_lidar/{lat_tag}_{lon_tag}.zarr"

    # nested dask computations inside this task run single-threaded
    with dask.config.set(scheduler="single-threaded"):
        if skip_existing and fs.exists(biomass_path + "/MeanH"):
            return ("skipped", biomass_path)

        # reuse the cached preprocessed store when its consolidated metadata exists
        if fs.exists(preprocessed_path + "/.zmetadata"):
            preprocessed = _open_preprocessed_lidar(preprocessed_path)
        else:
            preprocessed = _build_preprocessed_lidar(
                bounding_box, preprocessed_path
            )
            if preprocessed is None:
                return ("no data", biomass_path)

        # calculate biomass, then the height metrics
        with_biomass = allo.apply_allometric_equation(
            preprocessed, min_lat, max_lat, min_lon, max_lon
        )
        # pass a copy so the shared module-level list can never be mutated
        with_biomass = get_all_height_metrics(
            with_biomass, list(_HEIGHT_METRICS)
        ).compute()

        # saving output
        utils.save_to_zarr(
            ds=with_biomass.unstack("unique_index").chunk(_TILE_ZARR_CHUNKS),
            url=biomass_path,
            list_of_variables=_BIOMASS_VARIABLES + _HEIGHT_METRICS,
            mode="w",
        )

        return ("processed", biomass_path)

In [None]:
skip_existing = True

# Build one lazy task per tile and compute them together, retrying each
# failed task once.
tasks = [
    process_one_tile(tile_bbox, skip_existing) for tile_bbox in bounding_boxes
]
(results,) = dask.compute(tasks, retries=1)

In [None]:
# One (status, biomass_path) tuple per tile; status is "skipped", "no data" or "processed".
results

In [None]:
# Simple marker that the cells above ran to completion.
print("done")

In [None]:
# Spot-check one finished biomass tile; change the tag to inspect another tile.
check_tile = "60N_120W"
mapper = fsspec.get_mapper(
    f"gs://carbonplan-climatetrace/intermediates/biomass/{check_tile}.zarr"
)
check = xr.open_zarr(mapper)
check

In [None]:
# Number of shots in the checked tile with a non-NaN biomass estimate.
# Counting directly is exact and avoids the deprecated `Dataset.dims[...]`
# mapping access.
int(check.biomass.notnull().sum())

In [None]:
# import h5py
# import pandas as pd

# data_dir = "/home/jovyan/data/glas/example/"
# f01 = h5py.File(data_dir + 'GLAH01_033_2107_003_0241_4_02_0001.H5', "r")
# table1 = f01["ANCILLARY_DATA"].attrs['volt_table_1']
# volt_table = pd.DataFrame(
#     {
#         'ind': np.arange(len(table1)),
#         'volt_value': table1
#     }
# )

# volt_table.to_csv('/home/jovyan/trace/data/volt_to_digital_count.csv', index=False)

In [None]:
def plot_shot(record, cut=250, n_sig=3.5):
    """
    Plot one lidar shot's waveform with its signal bounds, noise level and fits.

    Distance from the satellite is on the y axis, inverted so that larger
    distances appear lower in the figure.

    Parameters
    ----------
    record : xarray.Dataset
        A single shot (e.g. from ``.isel(record_index=..., shot_number=...)``)
        providing rec_wf, rec_wf_sample_dist, processed_wf, sig_begin_dist,
        sig_end_dist, noise_mean, noise_sd, gaussian_amp and gaussian_fit_dist.
    cut : int, optional
        Number of trailing waveform samples to leave out of the plot
        (default 250; 0 keeps every sample).
    n_sig : float, optional
        The noise threshold is drawn at ``noise_mean + n_sig * noise_sd``
        (default 3.5).
    """
    # slice explicitly: `values[:-0]` would be empty, so cut=0 maps to no trim
    trim = slice(None, -cut if cut else None)
    bins = record.rec_wf_sample_dist.values[trim]
    bin_range = [bins.min(), bins.max()]
    # horizontal extent (volts) of the distance marker lines
    x_span = [-0.05, 0.5]

    fig, ax = plt.subplots(figsize=(6, 10))
    ax.plot(record.rec_wf.values[trim], bins, "b", label="Raw")

    # plot various variables found in GLAH14
    ax.plot(
        x_span,
        np.array([record.sig_begin_dist, record.sig_begin_dist]),
        "r--",
        label="Signal Beginning",
    )
    ax.plot(
        x_span,
        np.array([record.sig_end_dist, record.sig_end_dist]),
        "g--",
        label="Signal End",
    )

    # plot noise mean and detection threshold from GLAH01
    ax.plot(
        [record.noise_mean, record.noise_mean],
        bin_range,
        color="0.5",
        label="Noise Mean",
    )
    noise_threshold = record.noise_mean + n_sig * record.noise_sd
    ax.plot(
        [noise_threshold, noise_threshold],
        bin_range,
        color="0.5",
        linestyle="dashed",
        label="Noise Threshold",
    )

    # filtered wf is shifted by the noise mean so it overlays the raw waveform
    # (presumably it is stored with the noise mean removed — confirm)
    ax.plot(
        record.processed_wf.values[trim] + record.noise_mean.values,
        bins,
        "k-",
        label="Filtered Waveform",
    )

    ax.scatter(
        record.gaussian_amp,
        record.gaussian_fit_dist,
        s=20,
        c="orange",
        label="Gaussian fits",
    )

    ax.invert_yaxis()
    ax.set_xlabel("lidar return (volt)")
    ax.set_ylabel("distance from satellite (m)")
    ax.legend()
    plt.show()
    plt.close(fig)

In [None]:
import random

In [None]:
# NOTE(review): `p` is not defined anywhere in this notebook. It must be a
# dataset with dims (record_index, shot_number), in that order, holding
# `num_gaussian_peaks`, `mask` and the variables plot_shot reads (e.g. an
# unstacked preprocessed tile). TODO: load it explicitly in a cell above.
N_EXAMPLES = 10
SAMPLE_SEED = None  # set an int for a reproducible sample of shots

# indices of masked-in shots with more than 2 gaussian peaks
rec_idx, shot_idx = np.where((p.num_gaussian_peaks > 2) & p.mask)

# sample distinct shots; random.sample stays in [0, n) and copes with fewer
# matches than N_EXAMPLES (including none)
rng = random.Random(SAMPLE_SEED)
for k in rng.sample(range(len(rec_idx)), k=min(N_EXAMPLES, len(rec_idx))):
    plot_shot(p.isel(record_index=rec_idx[k], shot_number=shot_idx[k]))