# Collapse SLR draws

This notebook reduces the 20,000 Monte Carlo (MC) draws of sea level rise (SLR) per scenario and workflow to a fixed number of draws by binning the draws into quantile bins based on global mean sea level (GMSL) in each year. Within each bin, the median local sea level (LSL) across all of the draws in that bin is taken for each scenario-year.

This workflow:

1. Blends all modeled scenarios (each a combination of workflow and SSP-RCP).
2. Discards any scenario/draw combinations that exhibit negative global SLR (i.e., declining GMSL) at any point, because these are assumed to be non-physical. Also drops any draw in which a single global SLR component in 2100 falls outside the 0.5th-99.5th percentiles for that component within its scenario; these are considered outliers.
3. Linearly interpolates to annual values.
4. Calculates the quantile of each draw's maximum (and minimum) LSL across all site-years, and drops draws above the 99th (or below the 1st) percentile.
5. Within each year, calculates the 1st and 99th percentiles of GMSL.
6. Chooses 10 equally spaced bin centers including these two end points.
7. Creates 5 cm-wide bins centered at those values.
8. Ensures there are at least 100 draws in each bin.
9. Creates a grouping of draws for each year for each of the 10 bins.
10. Shuffles draws within each bin to ensure randomness across batches.
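
Steps 5 to 8 can be sketched for a single year with plain NumPy on synthetic draws. This follows the steps as listed (centers from `np.linspace`, so both end points are included); the helper name `bin_draws_one_year`, its defaults, and the synthetic normal GMSL draws are illustrative assumptions. The notebook's own implementation (`get_bin_mask` below) additionally jitters centers by year, clips to FaIR-based GMSL, and searches over quantile trims.

```python
import numpy as np


def bin_draws_one_year(gmsl, n_bins=10, width=50.0, q=0.01, min_draws=100):
    """Assign one year's GMSL draws (mm) to fixed-width bins.

    Hypothetical helper that follows steps 5-8 as listed; not pyCIAM code.
    """
    lo, hi = np.quantile(gmsl, [q, 1 - q])
    # equally spaced centers, including both end points
    centers = np.linspace(lo, hi, n_bins)
    lower, upper = centers - width / 2, centers + width / 2
    # boolean mask of shape (n_bins, n_draws): True where draw i falls in bin b
    mask = (gmsl[None, :] >= lower[:, None]) & (gmsl[None, :] < upper[:, None])
    counts = mask.sum(axis=1)
    if (counts < min_draws).any():
        raise ValueError(f"some bins have fewer than {min_draws} draws: {counts}")
    return centers, mask


rng = np.random.default_rng(0)
draws = rng.normal(loc=500.0, scale=100.0, size=20_000)  # synthetic GMSL, mm
centers, mask = bin_draws_one_year(draws)
print(centers.round(1))
print(mask.sum(axis=1))  # draws per bin
```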

The resulting dataset has the following variables (each listed with its dimensions):
- sample_id (tuple of scenario, mc_sample_id): year, slr (0-9), batch (0-X)
- lsl_msl05: year, slr, batch, site_id
- gsl_msl05: year, slr, batch
- gsl_bin_center (median of gsl_msl05): year, slr
- bin_bounds: year, slr, bound (lower, upper)

# Setup

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from math import ceil

import geopandas as gpd
import numpy as np
import pandas as pd
import pint_xarray  # noqa:F401
import xarray as xr
from numpy.dtypes import StringDType
from pyCIAM.io import add_nearest_slrs
from rhg_compute_tools.xarray import dataset_from_delayed
from shared import (
    BIN_MAP_PATH,
    DIR_SLR_GMSL_RAW,
    PATH_PARAMS,
    PATH_SLIIDERS,
    PATH_SLIIDERS_OLD,
    PATH_SLR_INT,
    PATH_SLR_RAW,
)

In [3]:
# NOTE: Specs carried over from EPA SCC exercise

SITE_OUTLIER_QUANTILE = (
    0.01  # quantile of top and bottom of site-year-scenario draws to drop as outliers
)

COMPONENT_OUTLIER_QUANTILE = 0.005  # quantile range of each SLR component allowed

MIN_BIN_DRAWS = 100  # need this many draws per bin
BIN_WIDTH = 50  # bin width in mm
N_BINS = 10  # number of bins

# number of scenarios per bin (EPA SCC was 500, reduced for mem footprint)
N_BATCHES = 15

# our damage functions blend 5 years together, so we stagger the bin centroids such that
# over 5 years we get a denser distribution of GMSL
YEARLY_ROTATION = 5

# We can expand the support of the last N bins by allowing a lower number of draws
# This allows the support of these damage functions to span a range
# that will be encountered in later years (i.e. 2100-2300).
EXPAND_LAST_RANGE_N = 3
MIN_BIN_DRAWS_FINAL = 15

DAMAGE_FUNC_Y0 = 2020

IR_FILE = "gs://impactlab-data/coastal/data/int/exposure/impactregions/impact_regions_v0.1.parquet"
VSL_FILE_PATT = "/gcs/impactlab-data/gcp/estimation/mortality/release_2020/data/valuation/inputs/vsl/SSP{ssp}_new.nc4"
FAIR_GMSL_FILE = (
    "/gcs/impactlab-data/coastal/data/int/hazard/slr/fair/gmsl_pulse_fixed.nc4"
)

In [4]:
SLR_0_YR = pd.read_json(PATH_PARAMS).loc["slr_0_year", "values"]

In [5]:
from shared import start_dask_cluster

client, cluster = start_dask_cluster()
cluster.scale(128)

# from distributed import Client

# client = Client(n_workers=7)
# cluster = client.cluster

cluster




## Adjust old SLIIDERS database

The most recent version of the pyCIAM input database that uses impact regions (IRs) was built for the original EPA SCC modeling, before the SLIIDERS format was finalized and before the DSCIM-Coastal paper (which used ADM1 regions instead of IRs) was published. We therefore make a few updates to bring that old database into the format expected by the current version of pyCIAM. In particular:

* Drop the predefined SLR_site_id (which was matched to LocalizeSL CMIP5-era SLR projections) and merge on segment lat/lon values, which are matched on the fly to the nearest point in an SLR database.
* Rename the segment coordinate from "seg_ir_seg" to "seg".
* Keep only SSP2, SSP3, and SSP4, and update the VSL projections.
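
The on-the-fly matching is done by pyCIAM's `add_nearest_slrs`. To illustrate the idea only (this is not pyCIAM's implementation), a brute-force nearest-neighbor match on great-circle distance could look like the following; the helper name `nearest_site_index` and the toy coordinates are hypothetical.

```python
import numpy as np


def nearest_site_index(seg_lon, seg_lat, site_lon, site_lat):
    """Return, for each segment, the index of the closest SLR site.

    Hypothetical helper: brute-force haversine distance, O(n_seg * n_site) memory.
    """
    lon1, lat1 = np.radians(seg_lon)[:, None], np.radians(seg_lat)[:, None]
    lon2, lat2 = np.radians(site_lon)[None, :], np.radians(site_lat)[None, :]
    # haversine on a unit sphere; the Earth radius does not change the argmin
    a = (
        np.sin((lat2 - lat1) / 2) ** 2
        + np.cos(lat1) * np.cos(lat2) * np.sin((lon2 - lon1) / 2) ** 2
    )
    dist = 2 * np.arcsin(np.sqrt(np.clip(a, 0.0, 1.0)))
    return dist.argmin(axis=1)


site_lon = np.array([0.0, 10.0, 179.0])
site_lat = np.array([0.0, 10.0, 0.0])
seg_lon = np.array([1.0, -179.5, 9.0])  # the second segment sits across the dateline
seg_lat = np.array([1.0, 0.5, 11.0])
idx = nearest_site_index(seg_lon, seg_lat, site_lon, site_lat)
valid_sites = np.unique(idx)  # analogous to valid_site_ids further down
```

Using great-circle distance rather than raw lat/lon differences keeps segments near the dateline matched to sites on the other side of it.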

In [6]:
# find a stray old segment->centroid mapping lying around and use it to append
# seg_lon/seg_lat
locs = gpd.read_parquet(
    "gs://impactlab-data/coastal/data/raw/ciam_inputs/shapefiles/tmp/"
    "gtsm_stations_withiso_ciam.parquet"
).set_index("station_id")[["lon", "lat"]]
locs.index = "seg_" + locs.index.str.split("_").str[-1]
locs = locs.rename_axis(index="seg").to_xarray().rename(lon="seg_lon", lat="seg_lat")

sliiders = (
    xr.open_zarr(str(PATH_SLIIDERS_OLD))
    .drop_encoding()
    .rename(seg_ir_seg="seg")
    .sel(ssp=["SSP2", "SSP3", "SSP4"])
    .drop_vars("SLR_site_id")
)
sliiders["params"] = sliiders.params.astype(StringDType)

sliiders = sliiders.merge(locs.sel(seg=sliiders.seg).reset_coords(drop=True))

### Update VSL

In [7]:
ir_mapping = "IR_" + pd.read_parquet(
    IR_FILE, columns=["hierid"]
).reset_index().set_index("hierid").gadmid.astype(str)

In [11]:
vsl_ds = (
    xr.open_mfdataset(
        [VSL_FILE_PATT.format(ssp=s) for s in [2, 3, 4]],
        concat_dim="ssp",
        combine="nested",
    )
    .sel(valuation="vsl", scaling="epa_scaled", eta=1, drop=True)
    .to_array()
    .squeeze(drop=True)
    .rename(model="iam")
)
vsl_ds["iam"] = (
    vsl_ds.iam.to_series()
    .replace({"IIASA GDP": "IIASA", "OECD Env-Growth": "OECD"})
    .values
)
vsl_ds["region"] = ir_mapping.reindex(vsl_ds.region.values).values
vsl_ds = vsl_ds.dropna("region")
vsl_ds = vsl_ds.sel(region=sliiders.impact_region).load().drop_vars("region")

ypc = sliiders.ypc.sel(ssp="SSP2", iam="IIASA").load()
y0 = ypc.year[0].item()
vsl_y0 = vsl_ds.year[0].item()
vsl0 = vsl_ds.sel(ssp="SSP2", iam="IIASA", year=vsl_y0, drop=True)

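# scale the first available VSL back to y0 in proportion to ypc (i.e., an income
# elasticity of 1); where ypc is 0 in both years the ratio is NaN, so fill it with 1
# to keep VSL constant
ypc0 = (ypc.sel(year=y0) / ypc.sel(year=vsl_y0)).fillna(1) * vsl0

# interpolate log(VSL) linearly between y0 and vsl_y0 (i.e., constant growth rate)
sliiders["vsl"] = np.exp(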
    np.log(
        xr.concat(
            [
                ypc0.broadcast_like(vsl_ds.sel(year=vsl_y0))
                .expand_dims(year=[y0])
                .reindex(year=np.arange(y0, vsl_y0)),
                vsl_ds,
            ],
            dim="year",
        )
    ).interpolate_na("year", method="linear")
)

# avoid fixedlength codec warning for zarr v3 by making this be a Vlen string
sliiders["ssp"] = sliiders.ssp.astype(np.dtypes.StringDType)
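
For a single region, the VSL extension in the cell above amounts to scaling VSL back to `y0` by the income ratio and then interpolating linearly in log space (a constant annual growth rate). A minimal NumPy sketch, assuming scalar inputs; the function name `backcast_vsl` and the example numbers are hypothetical.

```python
import numpy as np


def backcast_vsl(vsl_start, ypc_y0, ypc_start, y0, vsl_y0):
    """Extend a VSL series from its first year (vsl_y0) back to y0.

    Hypothetical single-region helper. VSL at y0 is scaled by the income ratio
    ypc(y0) / ypc(vsl_y0); a 0/0 ratio is replaced by 1 (VSL held constant).
    Returns the years y0..vsl_y0 and the VSL values, interpolated in log space.
    """
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = np.float64(ypc_y0) / np.float64(ypc_start)
    if np.isnan(ratio):
        ratio = 1.0
    vsl_at_y0 = vsl_start * ratio
    years = np.arange(y0, vsl_y0 + 1)
    # linear in log(VSL) == constant annual growth rate between the two end points
    log_vsl = np.interp(years, [y0, vsl_y0], np.log([vsl_at_y0, vsl_start]))
    return years, np.exp(log_vsl)


years, vsl = backcast_vsl(8.0e6, 10_000.0, 12_000.0, y0=2000, vsl_y0=2010)
```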


In [13]:
sliiders.to_zarr(str(PATH_SLIIDERS), mode="w")



<xarray.backends.zarr.ZarrStore at 0x7aca30194b80>

## Adjust SLR dataset

We have a very large dataset of point-wise local SLR projections (20,000 MC draws for each of 5 SSP-RCP scenarios for each of 7 workflows, for each of 9 decades, for each of 51,153 points). Some of the workflow/scenario combinations were interpolated only to provide estimates for those SSPs, so we drop them to reduce data volume. We also drop all points that are not the nearest neighbor of any SLIIDERS segment.

Next, we need to isolate the no-climate-change (i.e., vertical land motion (VLM)-only) component, so we pull a "no-VLM" SLR projection from a Google Cloud Storage bucket and subtract it from the total projection for the same workflow and SSP.

Finally, we pull GMSL data at the sample level. This will be used to filter out samples in which GMSL declines at any point, which are not helpful for fitting the damage function.
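
The two operations just described, isolating the VLM-only component and flagging samples with declining GMSL, reduce to an array subtraction and a first-difference check. A minimal NumPy sketch on made-up arrays (the notebook does the same with xarray over `site_id`, `sample`, and `year`):

```python
import numpy as np

# synthetic local SLR (mm): 3 draws x 4 decadal years at a single site
total = np.array([[10, 40, 80, 130], [12, 35, 30, 90], [9, 50, 95, 160]], dtype="int16")
no_vlm = np.array([[4, 25, 55, 95], [6, 20, 12, 60], [3, 35, 70, 125]], dtype="int16")

# VLM-only (no-climate-change) component: total projection minus the no-VLM projection
vlm = total - no_vlm

# synthetic GMSL (mm) for the same draws; keep only draws whose GMSL never declines
gmsl = np.array([[5, 30, 60, 100], [6, 25, 20, 70], [4, 35, 70, 120]])
gmsl_valid = (np.diff(gmsl, axis=1) >= 0).all(axis=1)
```

Here the second draw is flagged because its GMSL falls from 25 mm to 20 mm between two decades.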

In [6]:
slr = (
    xr.open_zarr(str(PATH_SLR_RAW))
    .drop_encoding()
    .stack(scen=["workflow", "scenario"])
    .rename(location="site_id")
)

valid_site_ids = np.unique(
    add_nearest_slrs(
        xr.open_zarr(str(PATH_SLIIDERS))[["seg_lon", "seg_lat"]], slr
    ).SLR_site_id
)

slr = (
    slr.sel(scen=slr.scenario == slr.correlation_source, site_id=valid_site_ids)
    .drop_vars("correlation_source")
    .sea_level_change.astype("int16")
    .chunk({"site_id": 10, "scen": -1})
)
slr = (
    slr.drop_vars(["scen", "workflow", "scenario"])
    .assign_coords(scen=(slr.scenario + "_" + slr.workflow).values)
    .rename(scen="scenario")
    .persist()
)

# we just pick an arbitrary workflow and SSP because VLM should be independent of both
novlm = (
    xr.open_zarr(
        "gs://ar6-lsl-simulations-requesterpays-standard/gridded/"
        "full_sample_workflows_novlm/wf_2f/ssp245/total-workflow.zarr",
        storage_options={"requester_pays": True},
    )
    .sea_level_change.astype("int16")
    .stack(location=["lon", "lat"])
)
novlm = (
    novlm.drop_vars(["location", "lon", "lat"])
    .assign_coords(
        location=(
            novlm.lon.astype(int).astype(str) + "_" + novlm.lat.astype(int).astype(str)
        ).values
    )
    .rename(location="site_id", samples="sample", years="year")
    .sel(site_id=slr.site_id, year=slr.year)
    .chunk({"site_id": 10, "year": -1, "sample": -1})
)

vlm = slr.sel(scenario="ssp245_wf_2f") - novlm

slr_out = xr.Dataset({"lsl_msl05": slr, "lsl_ncc_msl05": vlm})
slr_out["site_id"] = slr_out.site_id.astype(object)

# add on GSL data
paths = []
for s in slr_out.scenario.values:
    sp = s.split("_")
    wf = "_".join(sp[1:])
    ssp = sp[0]
    paths.append(
        str(DIR_SLR_GMSL_RAW / wf / ssp / "total-workflow.nc").replace("gs://", "/gcs/")
    )


def preproc(ds):
    path = ds.encoding["source"].split("/")
    return (
        ds.sea_level_change.squeeze(drop=True)
        .sel(years=slr.year.values)
        .expand_dims(scenario=[path[-2] + "_" + path[-3]])
    )


gsl = xr.open_mfdataset(
    paths,
    preprocess=preproc,
    combine="nested",
    concat_dim="scenario",
    join="exact",
).rename(samples="sample", years="year")
slr_out["gsl_msl05"] = gsl.astype("int16").chunk(
    {k: slr_out.chunksizes[k][0] for k in gsl.dims}
)


In [7]:
# save
slr_out.to_zarr(str(PATH_SLR_INT), mode="w")



<xarray.backends.zarr.ZarrStore at 0x7d1864a42fc0>

## Get scenarios needed for binned GMSL approach

In [25]:
sliiders = xr.open_zarr(str(PATH_SLIIDERS))

### Filter based on declining GMSL

In [26]:
# identify samples whose GMSL declines between any two projection years; we throw these out
gmsl = xr.open_zarr(str(PATH_SLR_INT)).gsl_msl05.load()
scenarios = gmsl.scenario.values
gmsl_valid = (gmsl.diff("year") >= 0).all("year")
gmsl = gmsl.stack({"full_sample_id": ["sample", "scenario"]})

### Constrain bins to the range of FaIR-based GMSL projections

In [27]:
fair_ds = (
    xr.open_dataset(FAIR_GMSL_FILE)[["control_gmsl", "pulse_gmsl"]]
    .to_array(dim="pulse")
    .sel(year=slice(None, sliiders.year.max()))
    .pint.quantify()
    .pint.to(gmsl.units)
    .pint.dequantify()
)



### Filter based on 0.5th to 99.5th percentile of each SLR component in year 2100

In [45]:
def preproc(ds):
    name = ds.encoding["source"]
    parts = name.split("/")
    comp = parts[-1].split("-")
    if comp[0] == "icesheets":
        comp = comp[-1].split("_")[1]
    else:
        comp = comp[0]
    scenario = parts[-2] + "_" + parts[-3]
    return (
        ds.sea_level_change.squeeze(drop=True)
        .sel(years=2100, drop=True)
        .expand_dims(component=[comp], scenario=[scenario])
    )


def get_paths(dirname):
    return [
        str(path).replace("gs:/", "/gcs")
        for path in dirname.glob("*.nc")
        if "total-" not in path.stem
    ]


ssps = [i.split("_")[0] for i in scenarios]
wfs = ["_".join(i.split("_")[1:]) for i in scenarios]
paths = sum(
    [get_paths(DIR_SLR_GMSL_RAW / wfs[ix] / ssp) for ix, ssp in enumerate(ssps)], []
)

ds = xr.open_mfdataset(paths, preprocess=preproc, join="outer").sea_level_change.rename(
    samples="sample"
)
qs = ds.quantile(
    [COMPONENT_OUTLIER_QUANTILE, 1 - COMPONENT_OUTLIER_QUANTILE], dim="sample"
)
valid_comp = (ds >= qs.isel(quantile=0)) & (ds <= qs.isel(quantile=1))
valid_comp = (valid_comp | ds.isnull()).all("component").load()
valid = (valid_comp & gmsl_valid).stack({"full_sample_id": ["sample", "scenario"]})
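
The component filter keeps a draw only if every modeled component lies within its own 0.5th-99.5th percentile range, and treats components absent from a scenario (all-NaN after the outer join) as passing. A minimal NumPy sketch for one scenario, on synthetic component values:

```python
import warnings

import numpy as np

COMPONENT_OUTLIER_QUANTILE = 0.005  # same value as in the notebook

rng = np.random.default_rng(42)
# synthetic 2100 GMSL contributions (mm), shape (component, sample)
comps = rng.normal(size=(3, 1_000)) * np.array([[50.0], [120.0], [20.0]])
comps[2] = np.nan  # a component that is not modeled in this scenario

with warnings.catch_warnings():
    # the all-NaN component triggers an "All-NaN slice" RuntimeWarning
    warnings.simplefilter("ignore", RuntimeWarning)
    lo = np.nanquantile(comps, COMPONENT_OUTLIER_QUANTILE, axis=1, keepdims=True)
    hi = np.nanquantile(comps, 1 - COMPONENT_OUTLIER_QUANTILE, axis=1, keepdims=True)

in_range = (comps >= lo) & (comps <= hi)
# a missing component counts as in range, mirroring `valid_comp | ds.isnull()`
valid = (in_range | np.isnan(comps)).all(axis=0)
```

With 1,000 draws each modeled component excludes 10 of them (5 per tail), so between 980 and 990 draws survive, depending on how much the excluded sets overlap.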


### Run binning pipeline

In [46]:
slr_in = (
    xr.open_zarr(str(PATH_SLR_INT))[["gsl_msl05", "lsl_msl05"]]
    .chunk({"site_id": 1, "sample": -1})
    .stack({"full_sample_id": ["sample", "scenario"]})
    .chunk({"full_sample_id": -1})
).persist()

# create filter to drop top and bottom 1% based on max and min value observed for
# any site-year
max_val = slr_in.lsl_msl05.max(["site_id", "year"]).load()
allowable = max_val <= max_val.quantile(1 - SITE_OUTLIER_QUANTILE)
min_val = slr_in.lsl_msl05.min(["site_id", "year"]).load()
allowable &= min_val >= min_val.quantile(SITE_OUTLIER_QUANTILE)

# combine w/ GMSL and component-based filtering
allowable = (allowable & valid).drop_vars("quantile")

# keep only valid samples
slr_in = slr_in.sel(full_sample_id=allowable).persist()

# store integer index to tuple index mapping and convert index to integer
slr_in = slr_in.reset_index("full_sample_id")
slr_in["full_sample_id"] = np.arange(len(slr_in.full_sample_id))

# randomly shuffle order of scenarios so choice of draws is not dependent on order
shuffled_scen_order = (
    slr_in.full_sample_id.to_series()
    .sample(frac=1, replace=False, random_state=0)
    .values
)
slr_in = slr_in.sel(full_sample_id=shuffled_scen_order)
slr_in["full_sample_id"] = np.arange(len(slr_in.full_sample_id))

# interpolate
slr_in = slr_in.reindex(
    year=[SLR_0_YR] + slr_in.year.values.tolist(), fill_value=0
).interp(year=np.arange(DAMAGE_FUNC_Y0, slr_in.year.max() + 1))

# get gsl dataset with which we will perform binning
gsl = slr_in.gsl_msl05.load()
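
The interpolation step prepends the reference year `SLR_0_YR` with SLR fixed at 0 before interpolating to annual values, so any damage-function year earlier than the first projection year is interpolated toward zero instead of left undefined. The NumPy equivalent of that `reindex(..., fill_value=0).interp(...)` pattern is below; the value 2005 for `SLR_0_YR` and the projection years are assumptions for illustration (the notebook reads `SLR_0_YR` from `PATH_PARAMS`).

```python
import numpy as np

SLR_0_YR = 2005  # assumed reference year at which SLR is defined as 0
DAMAGE_FUNC_Y0 = 2020

years = np.array([2030, 2040, 2050])  # hypothetical projection years
draws = np.array([[30.0, 80.0, 140.0], [20.0, 50.0, 90.0]])  # mm, (sample, year)

# prepend the reference year with SLR = 0, as reindex(fill_value=0) does
anchor_years = np.concatenate([[SLR_0_YR], years])
anchored = np.concatenate([np.zeros((draws.shape[0], 1)), draws], axis=1)

# linearly interpolate each sample to annual values from DAMAGE_FUNC_Y0 onward
annual_years = np.arange(DAMAGE_FUNC_Y0, years.max() + 1)
annual = np.stack([np.interp(annual_years, anchor_years, row) for row in anchored])
```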

In [47]:
min_gmsl = fair_ds.min(["pulse", "simulation", "pulse_year", "rcp"]).reindex(
    year=gsl.year, fill_value=-np.inf
)
max_gmsl = fair_ds.max(["pulse", "simulation", "pulse_year", "rcp"]).reindex(
    year=gsl.year, fill_value=np.inf
)

In [48]:
def get_bin_mask(q):
    quantiles = gsl.quantile(q=[q, 1 - q], dim="full_sample_id").clip(
        min_gmsl, max_gmsl
    )
    step_size = (quantiles.isel(quantile=1) - quantiles.isel(quantile=0)) / N_BINS

    gmsl_bin_centers = (
        step_size
        * xr.DataArray(
            np.arange(0, N_BINS), dims=["slr"], coords={"slr": np.arange(0, N_BINS)}
        )
        + quantiles.isel(quantile=0)
        # jittering the centroids so we get better coverage for 5-year groupings
        + step_size * (quantiles.year % YEARLY_ROTATION) / YEARLY_ROTATION
    )

    bin_widths = np.minimum(step_size, BIN_WIDTH)
    this_bin_bounds = xr.Dataset(
        {
            "lower": gmsl_bin_centers - bin_widths / 2,
            "upper": gmsl_bin_centers + bin_widths / 2,
        }
    ).to_array(dim="bound")

    # put bin edges at the .5 in order to put them in between integer mm
    # quanta of SLR values
    this_bin_bounds = np.floor(np.round(this_bin_bounds)) + 0.5

    # drop duplicate bins when there aren't enough quanta to make 10 separate GMSL values
    this_bin_bounds = this_bin_bounds.where(
        this_bin_bounds.isel(bound=0) != this_bin_bounds.isel(bound=1)
    )

    # get sample mask
    return xr.Dataset(
        dict(
            mask=(gsl >= this_bin_bounds.isel(bound=0))
            & (gsl < this_bin_bounds.isel(bound=1)),
            bounds=this_bin_bounds,
        )
    )
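
For one year and one quantile trim, the bin-edge arithmetic in `get_bin_mask` can be sketched with scalars. The helper `bin_bounds_one_year` is hypothetical and omits the FaIR clipping; it reproduces the yearly jitter, the width cap, the snapping of edges to x.5 mm, and the removal of zero-width bins.

```python
import numpy as np

N_BINS = 10
BIN_WIDTH = 50  # mm
YEARLY_ROTATION = 5


def bin_bounds_one_year(q_lo, q_hi, year):
    """Return bin edges (mm) as a (2, N_BINS) array of [lower; upper] for one year.

    Hypothetical single-year helper mirroring the arithmetic of get_bin_mask.
    """
    step = (q_hi - q_lo) / N_BINS
    # shift centers by a fraction of a step that cycles every YEARLY_ROTATION years
    jitter = step * (year % YEARLY_ROTATION) / YEARLY_ROTATION
    centers = q_lo + step * np.arange(N_BINS) + jitter
    half_width = min(step, BIN_WIDTH) / 2
    bounds = np.stack([centers - half_width, centers + half_width])
    # snap edges to x.5 so integer-mm GMSL values never land on an edge
    bounds = np.round(bounds) + 0.5
    # blank out bins that collapse to zero width after snapping
    bounds[:, bounds[0] == bounds[1]] = np.nan
    return bounds


b_2020 = bin_bounds_one_year(100.0, 600.0, 2020)  # first bin: 75.5 to 125.5 mm
b_2021 = bin_bounds_one_year(100.0, 600.0, 2021)  # same bins shifted up by 10 mm
```

Over five consecutive years the centers shift by 0, 1/5, ..., 4/5 of a step, so pooling five years (as the damage functions do) samples GMSL more densely than any single year.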

In [49]:
qs = np.arange(0, 0.01, 0.0001)
ds_quants = client.map(get_bin_mask, qs)
ds_quants = dataset_from_delayed(ds_quants, dim="quantile").persist()

# find the widest quantile range valid for each year such that we meet the minimum bin
# draw threshold
min_bin_draws = (
    ds_quants["mask"]
    .sum("full_sample_id")
    .where(ds_quants.bounds.notnull().all(dim="bound").min(dim="slr"))
)

# adjust the threshold for the final N years to allow for wider range
min_draw_da = xr.ones_like(ds_quants.year) * MIN_BIN_DRAWS
min_draw_da.loc[{"year": min_draw_da.year[-EXPAND_LAST_RANGE_N:].values}] = (
    MIN_BIN_DRAWS_FINAL
)
valid = min_bin_draws > min_draw_da

# confirm that there is a choice of quantile endpoints that make each of these choices
# valid
assert valid.any(dim="quantile").all()

# now take the widest quantile range that qualifies for each year
quantile_selection = valid.argmax(dim="quantile").load()
bin_bounds = ds_quants.bounds.isel(quantile=quantile_selection).load()
bin_masks = ((gsl >= bin_bounds.isel(bound=0)) & (gsl < bin_bounds.isel(bound=1))).load()
bin_centroids = bin_bounds.mean(dim="bound").load()
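
The quantile search relies on `argmax` over a boolean array returning the index of the first `True`. Because the candidate trims `qs` increase, that first passing trim is the widest range. The sketch below is a per-year simplification that requires every bin in the year to pass, as the comment in the cell describes. The helper `widest_valid_trim` is hypothetical, the draws are synthetic, and jitter and FaIR clipping are left out.

```python
import numpy as np


def widest_valid_trim(gmsl, qs, n_bins=10, width=50.0, min_draws=100):
    """Return the smallest trim q whose bins all hold more than min_draws draws.

    Hypothetical single-year helper; bin centers are spaced as in get_bin_mask.
    """
    passes = []
    for q in qs:
        lo, hi = np.quantile(gmsl, [q, 1 - q])
        step = (hi - lo) / n_bins
        centers = lo + step * np.arange(n_bins)
        half = min(step, width) / 2
        counts = [((gmsl >= c - half) & (gmsl < c + half)).sum() for c in centers]
        passes.append(min(counts) > min_draws)
    passes = np.array(passes)
    if not passes.any():
        raise ValueError("no trim satisfies the minimum bin count")
    # argmax returns the first True; qs ascends, so this is the widest range
    return qs[passes.argmax()]


rng = np.random.default_rng(0)
draws = rng.normal(500.0, 100.0, size=20_000)  # synthetic GMSL, mm
q_sel = widest_valid_trim(draws, np.arange(0, 0.01, 0.0005))
```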



### Option 1. Choose the closest scenarios to each bin centroid

This minimizes the influence of within-bin GMSL differences on damages, but does not maximize the diversity of scenarios.

In [50]:
# now reduce bin widths in order to get N_BATCHES draws by removing furthest away values
# first calculate distance from bin centers
dist_from_centroids = np.abs(gsl.where(bin_masks) - bin_centroids)

# now order and take the first N values, where N is the minimum of N_BATCHES or the number of
# draws that fell within the bin
argsorter = np.argsort(dist_from_centroids, axis=1).rename(full_sample_id="batch")
argsorter["batch"] = np.arange(len(argsorter.batch))
binned = (
    gsl.full_sample_id.isel(full_sample_id=argsorter)
    .isel(batch=slice(None, N_BATCHES))
    .drop_vars("full_sample_id")
)

valid = binned.batch < bin_masks.sum(dim="full_sample_id")
binned = binned.reset_coords(["scenario", "quantile"]).where(valid)

# now recycle the draws for bins that have fewer than N_BATCHES draws
n_notnull = binned.full_sample_id.notnull().sum(dim="batch")
for y in binned.year:
    for b in binned.slr:
        this_bin = binned.sel(year=y, slr=b)
        this_n = n_notnull.sel(year=y, slr=b).item()
        if this_n < N_BATCHES and this_bin.full_sample_id.notnull().any().item():
            for v in binned.data_vars:
                binned[v].loc[{"year": y, "slr": b}] = np.tile(
                    this_bin[v].isel(batch=slice(None, this_n)),
                    ceil(N_BATCHES / this_n),
                )[:N_BATCHES]
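
For a single year-bin, Option 1 amounts to sorting in-bin draws by distance to the centroid (out-of-bin draws become NaN, which `np.argsort` places last), keeping at most `n`, and tiling them when the bin holds fewer than `n`. A minimal NumPy sketch; the helper `closest_n` and its `-1` sentinel for empty bins are hypothetical (the notebook leaves empty bins as NaN).

```python
from math import ceil

import numpy as np


def closest_n(gmsl, lower, upper, center, n):
    """Indices of the in-bin draws closest to `center`, recycled to length n.

    Hypothetical single-bin helper mirroring Option 1; returns -1s for an empty bin.
    """
    in_bin = (gmsl >= lower) & (gmsl < upper)
    dist = np.where(in_bin, np.abs(gmsl - center), np.nan)
    # argsort puts NaN (out-of-bin draws) last, so in-bin draws come first
    order = np.argsort(dist, kind="stable")[: min(n, int(in_bin.sum()))]
    if order.size == 0:
        return np.full(n, -1)  # sentinel for an empty bin
    # repeat the available draws until there are n of them
    return np.tile(order, ceil(n / order.size))[:n]


gmsl = np.array([100.0, 103.0, 96.0, 120.0, 90.0, 130.0, 101.0])
picks = closest_n(gmsl, lower=90.5, upper=110.5, center=100.0, n=5)
```

Four draws fall in the bin here, so the closest one is reused to fill the fifth batch.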

### Option 2. Randomly shuffle all scenarios that fall within each bin, independently for each year

This maximizes the diversity of scenarios used, but does not minimize the deviation from the GMSL bin centroid (relative to always picking the scenarios closest to the bin centroid, as in Option 1).

In [None]:
# rng = np.random.default_rng(seed=1)
# binned = xr.zeros_like(
#     bin_masks.isel(full_sample_id=slice(0, N_BATCHES))
#     .rename(full_sample_id="batch")
#     .astype(int)
#     .rename("full_sample_id")
#     .reset_coords(["scenario", "sample"])
# )
# for v in ["scenario", "sample"]:
#     binned[v] = binned[v].broadcast_like(binned.full_sample_id).copy()

# n_notnull = bin_masks.sum(dim="full_sample_id")
# for y in bin_masks.year:
#     for b in bin_masks.slr:
#         this_bin = bin_masks.sel(year=y, slr=b)
#         this_n = n_notnull.sel(year=y, slr=b).item()
#         if this_bin.sum().item():
#             this_bin = this_bin[this_bin]
#             order = rng.choice(len(this_bin), size=len(this_bin), replace=False)
#             this_bin = this_bin.isel(full_sample_id=order).copy()
#             for v in binned.data_vars:
#                 binned[v].loc[{"year": y, "slr": b}] = np.tile(
#                     this_bin[v],
#                     ceil(N_BATCHES / this_n),
#                 )[:N_BATCHES]

# # make sure the missing ones are missing
# binned = binned.where(n_notnull > 0)
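
The commented-out Option 2 cell reduces, for one year-bin, to permuting the in-bin draw indices and tiling them up to `N_BATCHES`. A minimal NumPy sketch with a seeded generator; the helper `shuffled_n` and its `-1` sentinel for empty bins are hypothetical.

```python
from math import ceil

import numpy as np


def shuffled_n(in_bin, n, rng):
    """Randomly ordered indices of in-bin draws, recycled to length n (Option 2 sketch)."""
    members = np.flatnonzero(in_bin)
    if members.size == 0:
        return np.full(n, -1)  # sentinel for an empty bin
    members = rng.permutation(members)
    # repeat the shuffled draws until there are n of them
    return np.tile(members, ceil(n / members.size))[:n]


rng = np.random.default_rng(seed=1)
in_bin = np.array([True, False, True, True, False, True])
batch_ids = shuffled_n(in_bin, 10, rng)
```

Calling the helper with a fresh permutation for each year, rather than reusing one ordering, is what lets different years draw on different subsets of the bin.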

### Save bin/batch/year-to-scenario/sample mapping

In [51]:
binned_scen = slr_in.scenario.sel(full_sample_id=binned.full_sample_id).reset_coords(
    drop=True
)
binned_samp = (
    slr_in.sample.sel(full_sample_id=binned.full_sample_id)
    .reset_coords(drop=True)
    .astype("int16")
)
binned_gsl = (
    slr_in.gsl_msl05.sel(full_sample_id=binned.full_sample_id)
    .reset_coords(drop=True)
    .astype("int16")
)
xr.Dataset(
    {"gsl_msl05": binned_gsl, "sample": binned_samp, "scenario": binned_scen}
).to_zarr(str(BIN_MAP_PATH), mode="w")



<xarray.backends.zarr.ZarrStore at 0x7aad09de1940>

## Close cluster

In [55]:
cluster.close(), client.close()

(None, None)