In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import functools
import time

import fsspec
import numpy as np
import pandas as pd
import xarray as xr

from carbonplan_trace.v1.glas_preprocess import *
from carbonplan_trace.v1.glas_height_metrics import *
from carbonplan_trace.v1.glas_allometric_eq import *
from carbonplan_trace.v1.utils import *

from dask.diagnostics import ProgressBar

In [None]:
# Working chunk layout for GLAS datasets (records x shots).
GLAS_CHUNKS = {"record_index": 1000, "shot_number": 10}


def open_glas_zarr(url, chunks=None):
    """Lazily open a zarr store at an fsspec URL and rechunk it.

    Parameters
    ----------
    url : str
        fsspec-compatible location of the zarr store (e.g. ``gs://...``).
    chunks : dict, optional
        Chunk sizes passed to ``Dataset.chunk``; defaults to ``GLAS_CHUNKS``.
    """
    chunks = GLAS_CHUNKS if chunks is None else chunks
    return xr.open_zarr(fsspec.get_mapper(url)).chunk(chunks)


url01 = "gs://carbonplan-scratch/trace_scratch/wa_glah01.zarr"
data01 = open_glas_zarr(url01)

url14 = "gs://carbonplan-scratch/trace_scratch/wa_glah14.zarr"
data14 = open_glas_zarr(url14)

combined = data14.merge(data01)

In [None]:
# Re-apply the working chunk layout to the merged GLAH01 + GLAH14 dataset.
combined = combined.chunk(dict(record_index=1000, shot_number=10))

In [None]:
# url = 'gs://carbonplan-scratch/trace_scratch/wa_combined.zarr'
# save_to_zarr(combined, url)

In [None]:
def main(ds, n):
    """Run the GLAS height-metric / biomass pipeline on one slice of records.

    Parameters
    ----------
    ds : xarray.Dataset
        Combined GLAH01 + GLAH14 records. When it already contains both
        ``rec_wf_sample_distance`` and ``processed_wf`` the ``preprocess`` step
        is skipped (a shallow copy is used so the caller's dataset is not
        modified).
    n : int
        Number of records in ``ds``; only used to normalise the printed
        per-record timings.

    Returns
    -------
    xarray.Dataset
        The dataset with mean / percentile / ground-peak distances added,
        passed through ``get_heights_from_distance`` and
        ``apply_allometric_equation``, and computed (not lazy).
    """
    # Start the clock before branching so the timing report works whether or
    # not preprocessing runs.
    t_start = time.time()
    if "rec_wf_sample_distance" in ds and "processed_wf" in ds:
        print("skipping preprocess")
        # shallow copy: new variables below must not leak into the caller's dataset
        ds = ds.copy()
    else:
        print("preprocess")
        ds = preprocess(ds).compute()
        # url = 'gs://carbonplan-scratch/trace_scratch/wa_preprocessed.zarr'
        # save_to_zarr(ds, url, ['rec_wf_sample_distance', 'processed_wf'])
    t_preprocessed = time.time()

    # specific height metrics (still in units of "distance from satellite")
    ds["meanH_distance"] = get_mean_distance(
        bins=ds.rec_wf_sample_distance, wf=ds.processed_wf
    )

    # 10th and 90th percentile distances
    percentiles = [10, 90]
    percentile_distances = get_percentile_distance(
        bins=ds.rec_wf_sample_distance,
        wf=ds.processed_wf,
        percentiles=percentiles,
    )
    for pct in percentiles:
        ds[f"{pct}th_distance"] = percentile_distances[pct]

    print("getting ground peak distance")
    ds["ground_peak_distance"] = get_ground_peak_distance(
        bins=ds.rec_wf_sample_distance, wf=ds.processed_wf
    )

    # convert the distance metrics to heights, using the ground peak as reference
    distance_vars = ["meanH_distance"] + [f"{pct}th_distance" for pct in percentiles]
    ds = get_heights_from_distance(
        ds=ds,
        list_of_distance_vars=distance_vars,
        # keyword spelling must match get_heights_from_distance's signature
        referece_distance_var="ground_peak_distance",
    )

    ds = apply_allometric_equation(ds).compute()
    t_done = time.time()

    # guard against n == 0 so an empty slice does not crash the report
    minutes_per_record = 60.0 * max(n, 1)
    print(f"preprocess took {(t_preprocessed - t_start) / minutes_per_record} min per record")
    print(f"other processes took {(t_done - t_preprocessed) / minutes_per_record} min per record")

    return ds


# could start a cluster
# ds.to_zarr('cloud')

In [None]:
# Observed ("dummy") per-record timings by batch size, in minutes per record:
#   records/batch   preprocess   other steps
#   10              ~0.05        ~0.005       (total = 139.15 mins)
#   100             ~0.005       ~0.0005
#   1000            ~0.001       ~7e-5
#   10000           ~0.0008      ~4e-5

n = 30000
for i in range(1):  # range(1): only the first batch is processed for now
    sub = combined.isel(record_index=slice(i * n, (i + 1) * n))
    # A slice can hold fewer than n records (end of the data), so normalise
    # the timings by the actual record count rather than the requested size.
    p = main(sub, sub.sizes["record_index"])
    print(p.biomass.values[0, 0])

In [None]:
# Persist the processed slice `p` (output of the batch loop) to the scratch bucket.
url = "gs://carbonplan-scratch/trace_scratch/wa_processed.zarr"
save_to_zarr(p, url)

In [None]:
# Inspect one shot of the processed slice `p`.
# NOTE(review): record_index 22232245 / shot_number 33 are hard-coded labels;
# confirm they fall inside the slice processed above, otherwise .sel raises KeyError.
record = p.sel(record_index=22232245, shot_number=33)
plot_shot(record)

In [None]:
from pyproj import transform

In [None]:
@functools.lru_cache(maxsize=None)
def _get_transformer(src_crs, dst_crs):
    """Build the pyproj Transformer for a CRS pair once and reuse it."""
    from pyproj import Transformer

    # always_xy left at its default (False): authority axis order, i.e. (lat, lon)
    # in for EPSG:4326 — the same convention the deprecated pyproj.transform used.
    return Transformer.from_crs(src_crs, dst_crs)


def transform_lat_lon(lat, lon, src_crs=4326, dst_crs=32610):
    """Project a latitude/longitude pair into another CRS.

    Defaults convert WGS84 geographic coordinates (EPSG:4326) to
    WGS84 / UTM zone 10N (EPSG:32610).

    Parameters
    ----------
    lat, lon : float or array-like
        Coordinates in ``src_crs``, given in that CRS's authority axis order
        (latitude first for EPSG:4326).
    src_crs, dst_crs : int or str
        Any CRS identifier accepted by ``pyproj.Transformer.from_crs``.

    Returns
    -------
    numpy.ndarray
        Stacked output coordinates in ``dst_crs`` authority axis order
        (easting, northing for EPSG:32610).
    """
    # The transformer is cached because apply_ufunc(vectorize=True) calls this
    # once per element; rebuilding the CRS transformation each time is costly.
    return np.array(_get_transformer(src_crs, dst_crs).transform(lat, lon))

In [None]:
# Project every shot location with transform_lat_lon. The new "lat_lon" core
# dim holds the two projected coordinates, and dask needs its length (2) up front.
xr.apply_ufunc(
    transform_lat_lon,
    sub.lat,
    sub.lon,
    vectorize=True,
    dask="parallelized",
    dask_gufunc_kwargs={"allow_rechunk": True, "output_sizes": {"lat_lon": 2}},
    output_core_dims=[["lat_lon"]],
    output_dtypes=[np.float64],
).compute()

In [None]:
df = pd.read_csv("index.csv")

# Bounding box in lon/lat degrees (Washington State, per the `wa` naming) and
# the time window of interest (presumably the ICESat/GLAS operating period).
x0, y0, x1, y1 = [-124.763068, 45.543541, -116.915989, 49.002494]
time_start = "2003-02-20T00:00:00Z"
time_end = "2009-10-11T23:59:59Z"

# Compare as timestamps, not strings: lexicographic comparison misorders values
# with fractional seconds (e.g. "...59.5Z" sorts before "...59Z").
sensing_time = pd.to_datetime(df["SENSING_TIME"], utc=True)

# Keep scenes sensed strictly inside the window whose footprint lies entirely
# within the bounding box.
df_wa = df[
    (sensing_time > pd.Timestamp(time_start))
    & (sensing_time < pd.Timestamp(time_end))
    & (df["NORTH_LAT"] < y1)
    & (df["SOUTH_LAT"] > y0)
    & (df["WEST_LON"] > x0)
    & (df["EAST_LON"] < x1)
]
df_wa.head()

In [None]:
# Sorted, unique Landsat sensing times as naive UTC datetime64 values:
# tz-aware values would not give xarray a datetime64 coordinate, and a
# nearest-neighbour .sel needs a unique index.
# NOTE(review): presumably index.csv can list a scene more than once; confirm.
landsat_times = (
    pd.to_datetime(df_wa["SENSING_TIME"], utc=True)
    .dt.tz_convert(None)
    .drop_duplicates()
    .sort_values()
)

In [None]:
# Lookup array indexed by the Landsat sensing times. No data is passed, so the
# values are xarray's default fill (presumably NaN); the "landsat_time" index is
# what later cells use for nearest-time matching.
# NOTE(review): named `ds` although it is a DataArray, not a Dataset.
ds = xr.DataArray(dims=["landsat_time"], coords=[landsat_times])

In [None]:
def get_nearest_time(glastime, list_of_time):
    """Return the entry of ``list_of_time`` closest to ``glastime``.

    Parameters
    ----------
    glastime : scalar
        A single GLAS shot time.
    list_of_time : 1-D array-like
        Candidate (Landsat) times. Must be non-empty and support subtraction
        with ``glastime`` (e.g. both datetime64).

    Raises
    ------
    ValueError
        If ``list_of_time`` is empty.
    """
    candidates = np.asarray(list_of_time)
    if candidates.size == 0:
        raise ValueError("list_of_time must contain at least one time")
    return candidates[np.argmin(np.abs(candidates - glastime))]


# Match each GLAS shot time to its nearest Landsat time. "landsat_time" is a core
# dim, so every call receives the full list instead of broadcasting against sub.time.
# NOTE(review): assumes sub.time is datetime64 like the Landsat times; confirm.
xr.apply_ufunc(
    get_nearest_time,
    sub.time,
    ds.landsat_time,
    input_core_dims=[[], ["landsat_time"]],
    vectorize=True,
    dask="parallelized",
    output_dtypes=[ds.landsat_time.dtype],
)

In [None]:
# For each GLAS record/shot, find the closest available Landsat sensing time.
# A DataArray indexer makes .sel pointwise: the result takes sub.time's dims,
# and its "landsat_time" coordinate holds the matched Landsat time.
# NOTE(review): assumes sub.time is datetime64 like ds.landsat_time; confirm.
glas_time = sub.time.compute()
ds.sel(landsat_time=glas_time, method="nearest")

In [None]:
def interpolate_wf(bins, wf, target, area_to_include):
    """
    Energy correction for the bin straddling ``target``.

    not vectorized: ``bins`` and ``wf`` are aligned 1-D arrays, with ``bins``
    ordered from large to small values. The boundary between two neighbouring
    bins is taken to be midway between their locations.

    If area_to_include = 'above', the correction keeps the area between the
    target location and the "upper bound" (larger value) bin. If
    area_to_include = 'below', it keeps the area between the target location
    and the "lower bound" (smaller value) bin.

    Parameters
    ----------
    bins : array-like
        Bin locations, strictly decreasing.
    wf : array-like
        Waveform value for each bin.
    target : float
        Location strictly between two adjacent bin locations.
    area_to_include : {"above", "below"}

    Returns
    -------
    bin_to_modify : int
        Index of the bin whose energy should be corrected.
    energy : float
        Amount to add to that bin (negative means subtract).

    Raises
    ------
    NotImplementedError
        If ``area_to_include`` is neither "above" nor "below".
    ValueError
        If ``target`` is outside the range of ``bins`` or coincides with a bin
        location (or ``bins`` is not strictly decreasing).
    """
    if area_to_include not in ("above", "below"):
        raise NotImplementedError(
            "Please specify whether we want to include area above or below the target to the bounds"
        )

    bins = np.asarray(bins)
    wf = np.asarray(wf)

    above = np.flatnonzero(bins > target)
    below = np.flatnonzero(bins < target)
    if above.size == 0 or below.size == 0:
        raise ValueError(f"target {target} is outside the range of bins")
    # since bins goes from large to small values, the "upper bound" index is
    # the last bin above target and sits right before the "lower bound" index
    upper_ind = above.max()
    lower_ind = below.min()
    if lower_ind - upper_ind != 1:
        raise ValueError(
            f"target {target} coincides with a bin location or bins are not strictly decreasing"
        )

    x_upper, x_lower = bins[upper_ind], bins[lower_ind]
    y_upper, y_lower = wf[upper_ind], wf[lower_ind]
    x_mid = (x_upper + x_lower) / 2.0
    x_span = x_upper - x_lower

    # target lies in the lower bin's domain below x_mid, otherwise in the upper bin's
    if target < x_mid:
        bin_to_modify, y = lower_ind, y_lower
    else:
        bin_to_modify, y = upper_ind, y_upper

    # "above": add (target < x_mid) or subtract (target >= x_mid) the sliver
    # between target and x_mid; "below" is the mirror image.
    offset = (x_mid - target) if area_to_include == "above" else (target - x_mid)
    energy = offset / x_span * y

    return bin_to_modify, energy


def interpolate_and_select_valid_area(bins, wf, beg, end):
    """
    Keep only the waveform energy between ``beg`` and ``end``, pro-rating the edge bins.

    not vectorized: ``bins`` and ``wf`` are aligned 1-D arrays, with ``bins``
    going from large values (furthest away from satellite) to small (closest
    to satellite). Neither input is modified.

    Returns
    -------
    numpy.ndarray
        New float array: ``wf`` inside ``(beg, end)``, zero outside, with the
        partial contribution of the bins straddling ``beg`` and ``end`` applied,
        clipped at 0.

    Raises
    ------
    ValueError, NotImplementedError
        Propagated from ``interpolate_wf`` (e.g. ``beg``/``end`` not strictly
        between two bins).
    """
    bins = np.asarray(bins)
    wf = np.asarray(wf)

    # within signal beginning and end locations, copy the input waveform
    output = np.zeros(len(wf))
    valid = (bins > beg) & (bins < end)
    output[valid] = wf[valid]

    # pro-rate the beginning and end bins; the correction goes into the output
    # waveform, never into `bins`, which must keep the original bin locations
    for target, area in ((beg, "above"), (end, "below")):
        bin_to_modify, energy = interpolate_wf(bins, wf, target, area)
        output[bin_to_modify] += energy

    # min at 0
    return np.maximum(output, 0)