In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import fsspec
import time

import numpy as np
import pandas as pd
import xarray as xr

from carbonplan_trace.v1.glas_extract import (
    extract_GLAH14_data,
    extract_GLAH01_data,
    read_dimensions,
)
from carbonplan_trace.v1.glas_preprocess import *
from carbonplan_trace.v1.glas_height_metrics import *
from carbonplan_trace.v1.glas_allometric_eq import *
from carbonplan_trace.v1.utils import *

from dask.diagnostics import ProgressBar

In [None]:
# # washington data

# url01 = f"gs://carbonplan-scratch/trace_scratch/wa_glah01.zarr"
# mapper01 = fsspec.get_mapper(url01)
# data01 = xr.open_zarr(mapper01).chunk({"record_index": 1000, "shot_number": 10})

# url14 = f"gs://carbonplan-scratch/trace_scratch/wa_glah14.zarr"
# mapper14 = fsspec.get_mapper(url14)
# data14 = xr.open_zarr(mapper14).chunk({"record_index": 1000, "shot_number": 10})

# combined = data14.merge(data01)
# combined = combined.chunk({"record_index": 1000, "shot_number": 10})

# url = 'gs://carbonplan-scratch/trace_scratch/wa_combined.zarr'
# save_to_zarr(combined, url)

In [None]:
# get combined data
# Extract one example GLAH01 file and one example GLAH14 file, chunk both the
# same way, and merge them into `combined`, which the later cells operate on.
# NOTE(review): the absolute /home/jovyan paths are machine-specific - confirm
# they exist before running on another environment.

f01 = "/home/jovyan/data/glas/example/GLAH01_033_2107_003_0241_4_02_0001.H5"
f14 = "/home/jovyan/data/glas/example/GLAH14_634_2107_003_0239_0_01_0001.H5"

data01 = extract_GLAH01_data(f01).chunk(
    {"record_index": 1000, "shot_number": 10}
)
data14 = extract_GLAH14_data(f14).chunk(
    {"record_index": 1000, "shot_number": 10}
)

# presumably xarray datasets: join="inner" would then keep only the index
# values present in both inputs - TODO confirm
combined = data14.merge(data01, join="inner")
# re-apply the same chunk sizes to the merged result
combined = combined.chunk({"record_index": 1000, "shot_number": 10})

In [None]:
# show the merged dataset as this cell's output
combined

In [None]:
def main(ds, n):
    """
    Run the processing pipeline on `ds` and print per-record timings.

    Steps, in order:
      1. call `preprocess(ds)` unless `ds` already contains both
         "rec_wf_sample_dist" and "processed_wf"
      2. add every entry of DISTANCE_METRICS_MAP to `ds`
      3. add every entry of HEIGHT_METRICS_MAP to `ds`

    `n` is only used as the divisor when reporting minutes per record.
    Returns the dataset with the computed variables attached.
    """
    started = time.time()

    # preprocess
    already_preprocessed = "rec_wf_sample_dist" in ds and "processed_wf" in ds
    if not already_preprocessed:
        print("entering preprocess")
        ds = preprocess(ds).compute()
    else:
        print("skipping preprocess")
    preprocess_done = time.time()

    #     url = 'gs://carbonplan-scratch/trace_scratch/wa_preprocessed.zarr'
    #     save_to_zarr(ds, url, ['rec_wf_sample_dist', 'processed_wf'])

    # distance metrics are computed before height metrics; each result is
    # stored on ds right away, so later metric functions see earlier results
    for metrics_map in (DISTANCE_METRICS_MAP, HEIGHT_METRICS_MAP):
        for metric_name, metric_func in metrics_map.items():
            print(metric_name)
            ds[metric_name] = metric_func(ds).compute()

    #     ds = apply_allometric_equation(ds).compute()
    metrics_done = time.time()

    preprocess_minutes = (preprocess_done - started) / 60.0 / n
    other_minutes = (metrics_done - preprocess_done) / 60.0 / n
    print(f"preprocess took {preprocess_minutes} min per record")
    print(f"other processes took {other_minutes} min per record")

    return ds


# could start a cluster
# ds.to_zarr('cloud')

In [None]:
# dummy timing
# on average each record takes ~0.05 mins for preprocessing if processing 10 records at a time
# on average each record takes ~0.005 mins for other process if processing 10 records at a time
# total = 139.15 mins

# on average each record takes ~0.005 mins for preprocessing if processing 100 records at a time
# on average each record takes ~0.0005 mins for other process if processing 100 records at a time

# on average each record takes ~0.001 mins for preprocessing if processing 1000 records at a time
# on average each record takes ~7 * 10^-5 mins for other process if processing 1000 records at a time

# on average each record takes ~0.0008 mins for preprocessing if processing 10000 records at a time
# on average each record takes ~4 * 10^-5 mins for other process if processing 10000 records at a time

# Run the pipeline on slices of `n` records along record_index.
# range(1) means only the first slice (records 0 .. n-1) is processed here.
n = 928
for i in range(1):
    sub = combined.isel(record_index=slice(i * n, (i + 1) * n))
    # `p` holds the result of the last processed slice; later cells read it
    p = main(sub, n)
#     print(p.biomass.values[0, 0])

In [None]:
def plot_shot(record):
    """
    Draw a diagnostic figure for a single shot.

    The raw and filtered waveforms are plotted against sample distance, with
    overlays for signal beginning / end, the noise mean, the noise threshold
    (noise mean + 3.5 * noise sd) and the Gaussian fit locations. The y axis
    is inverted before the figure is shown and closed.

    `record` must expose rec_wf_sample_dist, rec_wf, processed_wf,
    sig_begin_dist, sig_end_dist, noise_mean, noise_sd, gaussian_amp and
    gaussian_fit_dist.

    NOTE(review): `plt` is not imported in any cell visible here - presumably
    it arrives through one of the star imports; confirm.
    """
    # the trailing `n_dropped` samples of each waveform-length array are not plotted
    n_dropped = 250
    keep = slice(None, -n_dropped)
    sample_dist = record.rec_wf_sample_dist.values[keep]

    plt.figure(figsize=(6, 10))
    #     plt.scatter(record.rec_wf.values[keep], sample_dist, s=5, label="Raw")  # raw wf
    plt.plot(record.rec_wf.values[keep], sample_dist, "b", label="Raw")

    # horizontal markers for variables found in GLAH14
    x_extent = [-0.05, 0.5]
    for attr, fmt, label in (
        ("sig_begin_dist", "r--", "Signal Beginning"),
        ("sig_end_dist", "g--", "Signal End"),
    ):
        dist = getattr(record, attr)
        plt.plot(x_extent, np.array([dist, dist]), fmt, label=label)

    # vertical markers for noise mean and noise threshold from GLAH01
    noise_mean = record.noise_mean
    dist_extent = [sample_dist.min(), sample_dist.max()]
    plt.plot([noise_mean, noise_mean], dist_extent, "0.5", label="Noise Mean")

    sigma_multiplier = 3.5
    threshold = noise_mean + sigma_multiplier * record.noise_sd
    plt.plot(
        [threshold, threshold],
        dist_extent,
        color="0.5",
        linestyle="dashed",
        label="Noise Threshold",
    )

    # filtered waveform, offset by the noise mean
    filtered = record.processed_wf.values[keep] + noise_mean.values
    plt.plot(filtered, sample_dist, "k-", label="Filtered Waveform")

    plt.scatter(
        record.gaussian_amp,
        record.gaussian_fit_dist,
        s=20,
        c="orange",
        label="Gaussian fits",
    )

    # Disabled overlays (percentile heights, mean height, ground peak):
    #     plt.plot(
    #         [-0.05, 0.5],
    #         [record["10th_distance"], record["10th_distance"]],
    #         "b--",
    #         label="10th Percentile",
    #     )
    #     plt.plot([-0.05, 0.5], [record.meanH_dist, record.meanH_dist], "c--", label="Mean H")
    #     plt.plot(
    #         [-0.05, 0.5],
    #         [record["90th_distance"], record["90th_distance"]],
    #         "m--",
    #         label="90th Percentile",
    #     )
    #     plt.plot(
    #         [-0.05, 0.5],
    #         [record.ground_peak_dist, record.ground_peak_dist],
    #         "y--",
    #         label="Ground Peak",
    #     )

    axes = plt.gca()
    axes.invert_yaxis()
    plt.xlabel("lidar return (volt)")
    plt.ylabel("distance from satelite (m)")
    plt.legend()
    plt.show()
    plt.close()

In [None]:
import random

In [None]:
# Locate shots that pass the mask and have more than two Gaussian peaks, then
# plot ten of them picked at random (with replacement).
# NOTE(review): pos[0] / pos[1] are used as record_index / shot_number positions,
# which assumes the boolean array's axes are in that order - confirm.
pos = np.where((p.num_gaussian_peaks > 2) & p.mask)

for i in range(10):
    # randrange excludes its stop value, so `ind` is always a valid position.
    # random.randint(0, len(pos[0])) is inclusive of the upper bound and could
    # return len(pos[0]), which indexes one past the end.
    ind = random.randrange(len(pos[0]))
    r = p.isel(record_index=pos[0][ind], shot_number=pos[1][ind])
    plot_shot(r)

In [None]:
# Write the processed dataset `p` to the scratch location below.
# save_to_zarr is not defined in this notebook; presumably it comes from
# carbonplan_trace.v1.utils through the star import - confirm.
url = "gs://carbonplan-scratch/trace_scratch/wa_processed.zarr"
save_to_zarr(p, url)

In [None]:
def interpolate_wf(bins, wf, target, area_to_include):
    """
    Compute the partial-bin energy correction needed at `target` (not vectorized).

    `bins` holds the bin locations in decreasing order and `wf` the waveform
    value of each bin. The boundary between two neighbouring bins is taken to
    be their midpoint.

    if area_to_include = 'above', the function returns the area between the target location to the "upper bound" (larger value) bin
    if area_to_include = 'below', the function returns the area between the target location to the "lower bound" (smaller value) bin

    Returns
    -------
    (bin_to_modify, energy) : index of the bin the correction applies to and the
        signed amount to add to that bin. The sign assumes the caller has already
        copied every bin strictly on the included side of `target` and zeroed the rest.

    Raises
    ------
    NotImplementedError
        if `area_to_include` is neither "above" nor "below".
    ValueError
        if `target` has no bin on one of its sides (outside the bin range).
    AssertionError
        if the bins on either side of `target` are not adjacent, e.g. because
        `bins` is not in decreasing order.
    """
    if area_to_include not in ("above", "below"):
        raise NotImplementedError(
            "Please specify whether we want to include area above or below the target to the bounds"
        )

    # A target sitting exactly on a bin location is assigned to the side that a
    # strict (bins > target) / (bins < target) selection leaves out, so that the
    # returned energy is the half of that bin lying on the included side.
    if area_to_include == "above":
        upper_candidates = np.where(bins > target)[0]
        lower_candidates = np.where(bins <= target)[0]
    else:
        upper_candidates = np.where(bins >= target)[0]
        lower_candidates = np.where(bins < target)[0]

    if len(upper_candidates) == 0 or len(lower_candidates) == 0:
        raise ValueError(
            f"target {target} does not fall between two bins; cannot interpolate"
        )

    upper_ind = upper_candidates.max()
    lower_ind = lower_candidates.min()
    # since bins goes from large to small values, the "upper bound" index would be smaller than the "lower bound" index
    assert (
        lower_ind - upper_ind == 1
    ), "bins on either side of target must be adjacent; bins must be in decreasing order"

    x_upper = bins[upper_ind]
    x_lower = bins[lower_ind]
    y_upper = wf[upper_ind]
    y_lower = wf[lower_ind]

    x_mid = (x_upper + x_lower) / 2.0
    x_span = x_upper - x_lower

    # The target lies in the lower bin's territory when it is below the midpoint,
    # otherwise in the upper bin's territory.
    if target < x_mid:
        bin_to_modify, y_value = lower_ind, y_lower
    else:
        bin_to_modify, y_value = upper_ind, y_upper

    # Signed distance from the target to the midpoint, oriented so that a positive
    # value means energy to add and a negative value means energy to subtract.
    if area_to_include == "above":
        offset = x_mid - target
    else:
        offset = target - x_mid
    energy = offset / x_span * y_value

    return bin_to_modify, energy


def interpolate_and_select_valid_area(bins, wf, beg, end):
    """
    Keep only the part of the waveform between `beg` and `end` (not vectorized).

    Bins strictly between `beg` and `end` are copied from `wf`; the bins that
    contain `beg` and `end` receive a partial-bin correction from
    `interpolate_wf`; everything else is zero. Negative results are clipped to 0.

    `bins` and `wf` are not modified. Returns a new float array of len(wf).
    """
    # initialize output
    output = np.zeros(len(wf))

    # within signal beginning and end locations, set output to be equal to input wf
    valid = np.where((bins > beg) & (bins < end))[0]
    output[valid] = wf[valid]

    # for the beginning and end bin, interpolate
    # bins goes from large values (furthest away from satellite) to small (closest to satellite)
    # The corrections are applied to `output`; writing them into `bins` would
    # corrupt the caller's distance array and leave `output` uncorrected.
    for boundary, side in ((beg, "above"), (end, "below")):
        bin_to_modify, energy = interpolate_wf(bins, wf, boundary, side)
        output[bin_to_modify] += energy

    # min at 0
    output = np.maximum(output, 0)

    return output