In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import fsspec
import time

import numpy as np
import pandas as pd
import xarray as xr

from carbonplan_trace.v1.glas_allometric_eq import (
    get_list_of_mask_tiles,
    subset_data_for_tile,
    apply_allometric_equation,
)
from carbonplan_trace.v1.glas_preprocess import preprocess
from carbonplan_trace.v1.utils import convert_long3_to_long1, save_to_zarr

from dask.diagnostics import ProgressBar

In [None]:
import warnings

warnings.filterwarnings("ignore")

In [None]:
# from carbonplan_trace.v1.glas_extract import (
#     extract_GLAH14_data,
#     extract_GLAH01_data,
# )

# # test extraction combined data

# f01 = "/home/jovyan/data/glas/example/GLAH01_033_2107_003_0241_4_02_0001.H5"
# f14 = "/home/jovyan/data/glas/example/GLAH14_634_2107_003_0239_0_01_0001.H5"

# data01 = extract_GLAH01_data(f01).chunk(
#     {"record_index": 1000, "shot_number": 10}
# )
# data14 = extract_GLAH14_data(f14).chunk(
#     {"record_index": 1000, "shot_number": 10}
# )

# combined = data14.merge(data01, join="inner")
# combined = combined.chunk({"record_index": 1000, "shot_number": 10})

In [None]:
# read in data: lazily open the GLAH01 and GLAH14 zarr stores from the scratch bucket.
# Optional rechunking, currently disabled for both stores:
#     .chunk({"record_index": 1000, "shot_number": 10})
glah01_url = "gs://carbonplan-scratch/glas_01.zarr"
glah14_url = "gs://carbonplan-scratch/glah14.zarr"

mapper01 = fsspec.get_mapper(glah01_url)
mapper14 = fsspec.get_mapper(glah14_url)

data01 = xr.open_zarr(mapper01)
data14 = xr.open_zarr(mapper14)

In [None]:
# Tiles to process: the full set comes from get_list_of_mask_tiles(); a single
# tile is hard-coded here for testing.
# tile_paths = get_list_of_mask_tiles()
tile_paths = [
    "carbonplan-scratch/trace_scratch/ecoregions_mask/50N_120W.zarr"
]  # just for testing

# Output layout shared by every tile.
BIOMASS_OUTPUT_PREFIX = "gs://carbonplan-scratch/trace_scratch/biomass"
OUTPUT_CHUNKS = {"record_index": 1000, "shot_number": 10}
OUTPUT_VARIABLES = [
    "lat",
    "lon",
    "time",
    "x",
    "y",
    "ecoregion",
    "nlcd",
    "eosd",
    "igbp",
    "biomass",
    "allometric_eq",
]

# glas lon data go from 0-360 instead of -180-180, convert.
# NOTE(review): this overwrites data14.lon in place, so re-running this cell
# without re-running the loading cell converts already-converted values again;
# confirm convert_long3_to_long1 is idempotent, or reload data14 first.
data14["lon"] = convert_long3_to_long1(data14.lon)

for tp in tile_paths:
    # subset the data within a 10x10 degree bounding box to make processing easier
    sub14 = subset_data_for_tile(data=data14, tile_path=tp)

    # Select the matching GLAH01 records by boolean indexing. Unlike
    # `where(cond, drop=True)`, this does not mask every variable (which
    # float-promotes int/bool data) and does not broadcast variables that lack
    # a record_index dimension onto it.
    in_tile = data01.record_index.isin(sub14.record_index).values
    sub01 = data01.isel(record_index=in_tile)

    combined = sub14.merge(sub01, join="inner")
    preprocessed = preprocess(combined).compute()

    with_biomass = apply_allometric_equation(preprocessed, tp)

    # output store is named after the tile store, e.g. "50N_120W.zarr"
    fn = tp.split("/")[-1]

    save_to_zarr(
        with_biomass.unstack("unique_index").chunk(OUTPUT_CHUNKS),
        f"{BIOMASS_OUTPUT_PREFIX}/{fn}",
        list_of_variables=OUTPUT_VARIABLES,
        mode="w",
    )

In [None]:
# Inspect the last processed tile in the same unstacked, re-chunked layout that
# was written to zarr (`with_biomass` holds the final loop iteration's result).
unstacked_biomass = with_biomass.unstack("unique_index")
unstacked_biomass.chunk({"record_index": 1000, "shot_number": 10})

In [None]:
# Re-open the store written for the last processed tile (`fn` is left over from
# the processing loop) to check what was actually saved.
biomass_url = f"gs://carbonplan-scratch/trace_scratch/biomass/{fn}"
mapper = fsspec.get_mapper(biomass_url)
check = xr.open_zarr(mapper)
check

In [None]:
# Number of entries in the saved store that received a biomass estimate.
# Counting non-null values directly yields an exact integer (no float round-off)
# and does not assume that `biomass` spans exactly (record_index, shot_number);
# it also avoids indexing `Dataset.dims` like a mapping, which recent xarray
# deprecates in favour of `Dataset.sizes`.
int(check.biomass.notnull().sum().values)

In [None]:
# def main(ds, n):
#     # preprocess
#     t1 = time.time()

#     if "rec_wf_sample_dist" in ds and "processed_wf" in ds:
#         print("skipping preprocess")
#     else:
#         print("entering preprocess")
#         ds = preprocess(ds).compute()
#     t2 = time.time()

#     for dist_metric, func in DISTANCE_METRICS_MAP.items():
#         ds[dist_metric] = func(ds).compute()

#     for ht_metric, func in HEIGHT_METRICS_MAP.items():
#         ds[ht_metric] = func(ds).compute()

#     #     ds = apply_allometric_equation(ds).compute()
#     t3 = time.time()

#     print(f"preprocess took {(t2-t1) / 60. / n} min per record")
#     print(f"other processes took {(t3-t2) / 60. / n} min per record")

#     return ds

In [None]:
# p = main(combined, len(combined.record_index))

In [None]:
# dummy timing (minutes per record, by number of records processed at a time)
# 10 records:    preprocessing ~0.05,   other processes ~0.005  (total = 139.15 mins)
# 100 records:   preprocessing ~0.005,  other processes ~0.0005
# 1000 records:  preprocessing ~0.001,  other processes ~7e-5
# 10000 records: preprocessing ~0.0008, other processes ~4e-5

# NOTE(review): `main` is only defined in a commented-out cell above, so this
# cell raises NameError on a fresh kernel until that definition is restored.
# `combined` is whatever the last iteration of the tile loop produced.
n = 928  # records per batch
n_batches = 1
for i in range(n_batches):
    sub = combined.isel(record_index=slice(i * n, (i + 1) * n))
    # A batch can hold fewer than `n` records (the tail of `combined`, or a tile
    # with fewer than `n` records). main() divides elapsed time by its second
    # argument, so pass the real batch size to get a correct per-record timing.
    p = main(sub, sub.sizes["record_index"])
#     print(p.biomass.values[0, 0])

In [None]:
def plot_shot(record, cut=250, n_sig=3.5):
    """Plot one GLAS shot's raw and filtered waveform with reference markers.

    Draws the raw waveform, the GLAH14 signal begin/end distances, the GLAH01
    noise mean and noise threshold, the filtered waveform, and the Gaussian fit
    points, with distance on an inverted y-axis.

    Parameters
    ----------
    record : xarray.Dataset
        A single shot (one ``record_index`` / ``shot_number``) providing
        ``rec_wf``, ``rec_wf_sample_dist``, ``processed_wf``, ``sig_begin_dist``,
        ``sig_end_dist``, ``noise_mean``, ``noise_sd``, ``gaussian_amp`` and
        ``gaussian_fit_dist``.
    cut : int, optional
        Number of trailing waveform samples to drop before plotting
        (default 250). ``0`` plots every sample.
    n_sig : float, optional
        Number of noise standard deviations above the noise mean at which the
        "Noise Threshold" line is drawn (default 3.5).

    NOTE(review): ``plt`` (matplotlib.pyplot) is not imported in this notebook's
    import cell; add ``import matplotlib.pyplot as plt`` there.
    """
    # `values[:-0]` would select nothing, so map cut == 0 to "keep every sample".
    keep = slice(None, -cut if cut else None)
    bins = record.rec_wf_sample_dist.values[keep]
    # x-extent (volts) of the horizontal distance markers
    marker_x = [-0.05, 0.5]
    # y-extent of the vertical amplitude markers: the whole plotted distance range
    full_y = [bins.min(), bins.max()]

    plt.figure(figsize=(6, 10))
    #     plt.scatter(record.rec_wf.values[:-cut], bins, s=5, label="Raw")  # raw wf
    plt.plot(record.rec_wf.values[keep], bins, "b", label="Raw")

    # plot various variables found in GLAH14
    plt.plot(
        marker_x,
        np.array([record.sig_begin_dist, record.sig_begin_dist]),
        "r--",
        label="Signal Beginning",
    )
    plt.plot(
        marker_x,
        np.array([record.sig_end_dist, record.sig_end_dist]),
        "g--",
        label="Signal End",
    )

    # plot noise mean and std from GLAH01
    plt.plot(
        [record.noise_mean, record.noise_mean],
        full_y,
        "0.5",
        label="Noise Mean",
    )
    noise_threshold = record.noise_mean + n_sig * record.noise_sd
    plt.plot(
        [noise_threshold, noise_threshold],
        full_y,
        color="0.5",
        linestyle="dashed",
        label="Noise Threshold",
    )

    # plot filtered wf plus the noise mean
    # (presumably processed_wf has the noise mean removed — confirm in preprocess)
    plt.plot(
        record.processed_wf.values[keep] + record.noise_mean.values,
        bins,
        "k-",
        label="Filtered Waveform",
    )

    plt.scatter(
        record.gaussian_amp,
        record.gaussian_fit_dist,
        s=20,
        c="orange",
        label="Gaussian fits",
    )
    #     # plot percentile heights
    #     plt.plot(
    #         [-0.05, 0.5],
    #         [record["10th_distance"], record["10th_distance"]],
    #         "b--",
    #         label="10th Percentile",
    #     )
    #     plt.plot([-0.05, 0.5], [record.meanH_dist, record.meanH_dist], "c--", label="Mean H")
    #     plt.plot(
    #         [-0.05, 0.5],
    #         [record["90th_distance"], record["90th_distance"]],
    #         "m--",
    #         label="90th Percentile",
    #     )
    #     plt.plot(
    #         [-0.05, 0.5],
    #         [record.ground_peak_dist, record.ground_peak_dist],
    #         "y--",
    #         label="Ground Peak",
    #     )

    # distance from the satellite grows downward, so flip the y-axis
    plt.gca().invert_yaxis()
    plt.xlabel("lidar return (volt)")
    plt.ylabel("distance from satellite (m)")
    plt.legend()
    plt.show()
    plt.close()

In [None]:
import random

In [None]:
# Spot-check: plot a few randomly chosen masked shots whose waveform was fit with
# more than two Gaussian peaks.
n_examples = 10
# Index arrays of the qualifying shots; like the isel below, this assumes the
# mask's dimension order is (record_index, shot_number).
pos = np.where((p.num_gaussian_peaks > 2) & p.mask)
n_candidates = len(pos[0])

# `random.randint(0, n)` includes n and could index one past the end. A seeded,
# private RNG keeps the sample reproducible without touching global `random`
# state. Sampling without replacement avoids plotting the same shot twice and
# handles having fewer than `n_examples` candidates, including zero.
rng = random.Random(0)
for ind in rng.sample(range(n_candidates), k=min(n_examples, n_candidates)):
    r = p.isel(record_index=pos[0][ind], shot_number=pos[1][ind])
    plot_shot(r)