# Process benchmarking inputs
Read in the raw benchmarking inputs and process them to make plottable data.

Note: we split this out from the [`benchmarking-make-inputs.ipynb`](benchmarking-make-inputs.ipynb) notebook because this task requires significantly less memory (and therefore a less expensive machine), but it still takes some time, so we only want to do it once. Operations should be achievable with < 35 GB of memory.

In [None]:
import os

import pandas as pd
import xarray as xr

: 

In [None]:
# --- open the benchmarking input dataset
# The store location is assembled from the versioned evaluation prefix on S3
# plus the name of the zarr store written by the make-inputs notebook.
s3_path = 's3://carbonplan-ocr/evaluation/1.0.0/'
savename = 'benchmarking-input-dat.zarr'
zarr_store = os.path.join(s3_path, savename)
ds = xr.open_zarr(zarr_store)

In [None]:
# --- filter to four slices
#     (CONUS, west, east, testbox) and later (all data, non-burnable)
# Each entry maps a slice name to the bounds it needs (None = no bounds).
# Comment entries in or out to choose which slices get processed below.
slicenames = {
    # "testbox": {"minlat": 42.7, "maxlat": 46.3, "minlon": -116.8, "maxlon": -112.8},
    'CONUS': None,
    # "West of -98": {"maxlon": -98},
    # "East of -98": {"minlon": -98},
}

outdict = {}  # to hold outputs: slice name -> subset of `ds`
for slc_key, slc_val in slicenames.items():
    if slc_key == 'testbox':
        # label-based box selection on both coordinates
        # NOTE(review): slice(min, max) presumes latitude and longitude are
        # stored in ascending order -- confirm against the zarr store
        outdict[slc_key] = ds.sel(
            latitude=slice(slc_val['minlat'], slc_val['maxlat']),
            longitude=slice(slc_val['minlon'], slc_val['maxlon']),
        )
    elif slc_key == 'CONUS':
        # full domain, no spatial filtering
        outdict[slc_key] = ds.copy()

    elif slc_key == 'West of -98':
        # strictly west of the boundary longitude
        outdict[slc_key] = ds.where(ds['longitude'] < slc_val['maxlon'], drop=True)

    elif slc_key == 'East of -98':
        # boundary longitude itself belongs to the eastern slice
        outdict[slc_key] = ds.where(ds['longitude'] >= slc_val['minlon'], drop=True)

    else:
        # fail loudly: a mistyped name would otherwise be skipped silently and
        # the downstream cells would produce no output for it
        raise ValueError(
            f'Unrecognized slice name {slc_key!r}; expected one of '
            "'testbox', 'CONUS', 'West of -98', 'East of -98'"
        )

## Create N bins of data (for plotting distributions)

In [5]:
# --- histogram binning shared by every distribution below
# number of bins; controls the resolution of the plotted distributions
n_bins = 10_000
# (lower, upper) burn-probability bounds covered by the bins
# to confirm the max: dsslice['bp_2011'].max(skipna=True).compute()
bp_range = (0, 0.14)

## Bins across burn mask conditions
Each histogram is weighted by cell area, so the distributions reflect area rather than raw cell counts.

In [10]:
import dask.array as da

df_dict = {}  # to hold results


def weighted_histogram(data, weights, n_bins, data_range):
    """Set up a weighted histogram of ``data`` with ``dask.array.histogram``.

    Parameters
    ----------
    data
        Values to bin. Its ``.chunks`` attribute is read, so a dask array is
        expected.
    weights
        Per-value weights. Rechunked to ``data.chunks`` before binning so the
        two arrays share one chunk layout.
    n_bins
        Number of bins, forwarded as ``bins``.
    data_range
        ``(lower, upper)`` bounds of the binned interval, forwarded as
        ``range``.

    Returns
    -------
    tuple
        ``(hist, bins)`` exactly as returned by ``da.histogram``.
    """
    # the weights have to follow the same chunking as the data
    aligned_weights = weights.rechunk(data.chunks)
    counts, edges = da.histogram(
        data, bins=n_bins, range=data_range, weights=aligned_weights
    )
    return counts, edges


# --- build one dataframe of area-weighted BP densities per slice
for dsname, dsslice in outdict.items():
    # progress marker
    print(f'Now solving: {dsname}')

    # variables pulled from the slice
    bp = dsslice['bp_2011']
    burn_mask = dsslice['burn_mask']
    riley_mask = dsslice['riley_burnable_mask']
    cell_area = dsslice['cell_area']  # assumed same shape as bp

    # boolean conditions used to subset the cells
    is_burned = burn_mask == 1
    is_unburned = burn_mask == 0
    is_riley_unburnable = riley_mask == 0

    print('CHECK 1: Masks defined')

    # --- set up the histograms lazily; keys double as output column names
    lazy_hists = {}

    # ALL DATA
    lazy_hists['burned_BPdensity'], bin_edges = weighted_histogram(
        bp.where(is_burned).data, cell_area.where(is_burned).data, n_bins, bp_range
    )
    lazy_hists['unburned_BPdensity'], _ = weighted_histogram(
        bp.where(is_unburned).data, cell_area.where(is_unburned).data, n_bins, bp_range
    )

    print('CHECK 2: All data wtd hists done ')

    # UNBURNABLE ONLY (cells where riley_burnable_mask == 0)
    burned_nb = is_burned & is_riley_unburnable
    unburned_nb = is_unburned & is_riley_unburnable
    lazy_hists['burned_BPdensity_NB'], _ = weighted_histogram(
        bp.where(burned_nb).data, cell_area.where(burned_nb).data, n_bins, bp_range
    )
    lazy_hists['unburned_BPdensity_NB'], _ = weighted_histogram(
        bp.where(unburned_nb).data, cell_area.where(unburned_nb).data, n_bins, bp_range
    )

    print('CHECK 3: Unburnable data wtd hists done ')

    # --- bring all four histograms into memory with a single compute call
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    hists_in_memory = da.compute(*lazy_hists.values())

    print('CHECK 4: Hists in memory ')

    # --- normalize to density: divide each histogram by its own total
    densities = {
        column: hist / hist.sum() for column, hist in zip(lazy_hists, hists_in_memory)
    }

    print('CHECK 5: Hists normalized')

    # --- combine into single dataframe, bin centers first
    df_dict[dsname] = pd.DataFrame({'bin_centers': bin_centers, **densities})

Now solving: CONUS
CHECK 1: Masks defined
CHECK 2: All data wtd hists done 
CHECK 3: Unburnable data wtd hists done 
CHECK 4: Hists in memory 
CHECK 5: Hists normalized


#### Save mask dict


In [11]:
import s3fs

s3 = s3fs.S3FileSystem()

# Write one parquet file per slice.
# The destination is derived from `s3_path` (set where the inputs are read)
# rather than a second hard-coded copy of the versioned prefix, so inputs and
# outputs cannot drift onto different versions.
out_prefix = f"{s3_path.rstrip('/')}/benchmarking-processed"

for name, df in df_dict.items():
    path = f'{out_prefix}/{name}_area-wt_maskdf.parquet'
    with s3.open(path, 'wb') as f:
        df.to_parquet(f)