In [2]:
import os 
from glob import glob

import numpy as np
import jax
import jax.numpy as jnp

import xarray as xr
import xesmf as xe
import pandas as pd
import dask
import zarr

from utils.param_names import param_names
from utils.initial_params import constants
from utils.subsets import subsets
from utils.global_paths import project_data_path, project_code_path, loca_path
from src.read_inputs import read_projection_inputs
from src.prediction import make_prediction_vmap
from src.data_processing import _subset_states

In [3]:
############
### Dask ###
############
from dask_jobqueue import SLURMCluster

cluster = SLURMCluster(
    account="pches",
    # account="open",
    cores=1,
    memory="25GiB",
    walltime="04:30:00"
)
cluster.scale(jobs=30)  # ask for jobs

from dask.distributed import Client
client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.6.0.155:46159,Workers: 0
Dashboard: /proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


# LOCA2

In [4]:
##############
### Models ###
##############

models = os.listdir(f"{loca_path}/")
models.remove('training_data')
models.remove('scripts')

loca_all = {}

# Loop through models
for model in models:
    loca_all[model] = {}
    # Loop through members
    members = os.listdir(f"{loca_path}/{model}/0p0625deg/")
    for member in members:
        # Append SSPs
        ssps = os.listdir(f"{loca_path}/{model}/0p0625deg/{member}/")
        loca_all[model][member] = ssps

# Matches website (https://loca.ucsd.edu/loca-version-2-for-north-america-ca-jan-2023/) as of Jan 2023
print(f"# models: {len(models)}")
print(f"# model/expts: {np.sum([len(np.unique([item for row in [loca_all[model][member] for member in loca_all[model].keys()] for item in row])) for model in models])}")
print(f"# model/expts/ens: {np.sum([len(loca_all[model][ssp]) for model in models for ssp in loca_all[model]])}")
print(f"# model/expts/ens (not including historical): {np.sum([len([ssp for ssp in loca_all[model][member] if ssp != 'historical']) for model in models for member in loca_all[model]])}")

# models: 27
# model/expts: 99
# model/expts/ens: 329
# model/expts/ens (not including historical): 221


## Regridding

In [5]:
###################
# Regrid function
###################
def regrid_subset(model, member, ssp, subset_name, list_of_states):
    # Read inputs
    tasmin_in = xr.open_mfdataset(f"{loca_path}/{model}/0p0625deg/{member}/{ssp}/tasmin/*.nc", chunks="auto")
    tasmin_in["tasmin"] = tasmin_in["tasmin"] - 273.15
    tasmin_in["tasmin"].attrs["units"] = "degC"
    
    tasmax_in = xr.open_mfdataset(f"{loca_path}/{model}/0p0625deg/{member}/{ssp}/tasmax/*.nc", chunks="auto")
    tasmax_in["tasmax"] = tasmax_in["tasmax"] - 273.15
    tasmax_in["tasmax"].attrs["units"] = "degC"

    tas_in = (tasmin_in["tasmin"] + tasmax_in["tasmax"]) / 2.0
    tas_in.attrs["units"] = "degC"
    
    pr_in = xr.open_mfdataset(f"{loca_path}/{model}/0p0625deg/{member}/{ssp}/pr/*.nc", chunks="auto")
    pr_in["pr"] = pr_in['pr'] * 86400
    pr_in["pr"].attrs["units"] = "mm/day"

    # Merge
    ds_in = xr.merge([xr.Dataset({"tas": tas_in}), pr_in])

    # Construct out grid
    nldas_lats = np.load(f"{project_code_path}/code/utils/grids/{subset_name}_lat.npy")
    nldas_lons = np.load(f"{project_code_path}/code/utils/grids/{subset_name}_lon.npy")

    dr_out = xr.Dataset({
        "lat": (["lat"], nldas_lats,
                {"standard_name": "latitude", "units": "degrees_north"},),
        "lon": (["lon"], nldas_lons,
                {"standard_name": "longitude", "units": "degrees_east"},),
    })

    # Regrid conservatively
    regridder = xe.Regridder(ds_in, dr_out, "conservative")
    ds_out = regridder(ds_in, skipna=True, na_thres=0.99) # This threshold is somewhat subjective

    # Subset to states
    ds_out = _subset_states(ds_out, list_of_states)
    
    # Store 
    ds_out = ds_out.chunk({'time': 200 , 'lat':-1, 'lon':-1})
    compressor = zarr.Blosc(cname="zstd", clevel=3)
    encoding = {vname: {"compressor": compressor} for vname in ds_out.data_vars}
    ds_out.to_zarr(f"{project_data_path}/projections/{subset_name}/forcing/LOCA2/{model}_{member}_{ssp}.zarr",
                   encoding=encoding, mode='w-', consolidated=True)

In [6]:
## File path function
def make_loca_file_path(loca_path, model, member, ssp, var):
    """
    Returns list of file paths for a given downscaled LOCA output.
    """
    out_path = f"{loca_path}/{model}/0p0625deg/{member}/{ssp}/{var}"

    if os.path.isdir(out_path):
        files = os.listdir(out_path)
        files = [file for file in files if file[-7:] != 'ORIG.nc'] # Skip ORIGs (had to fix tasmin naming errors)
        return files
    else:
        return []

### eCONUS

In [7]:
subset_name = "eCONUS"
list_of_states = subsets[subset_name]

In [8]:
%%time
# Loop through models
for model in models:
    # Loop through members
    for member in loca_all[model].keys():
        # Loop through SSPs
        for ssp in loca_all[model][member]:
            if ssp == "historical":
                continue
            # Some vars are missing for some outputs: skip
            file_paths = make_loca_file_path(loca_path, model, member, ssp, "tasmin")
            if len(file_paths) == 0:
                print(f"Missing: {model} {ssp} {member}")

            # Check if done
            if not os.path.exists(f"{project_data_path}/projections/{subset_name}/forcing/LOCA2/{model}_{member}_{ssp}.zarr"):
                # Re-grid and subset
                try:
                    regrid_subset(model=model,
                                  member=member,
                                  ssp=ssp,
                                  subset_name=subset_name,
                                  list_of_states=list_of_states)
                except:
                    print(f"{model}_{member}_{ssp}")

    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]
This may cause some slowdown.
Consider scattering data ahead of time and using futures.


INM-CM5-0_r3i1p1f1_ssp370


    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]
This may cause some slowdown.
Consider scattering data ahead of time and using futures.


INM-CM5-0_r4i1p1f1_ssp370


This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]
This may cau

IPSL-CM6A-LR_r4i1p1f1_ssp245


This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]
This may cause some slowdown.
Consider scattering data ahead of time and using futures.


IPSL-CM6A-LR_r5i1p1f1_ssp245


This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]
This may cause some slowdown.
Consider scattering data ahead of time and using futures.


IPSL-CM6A-LR_r7i1p1f1_ssp370


This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Co

Missing: MPI-ESM1-2-LR ssp585 r10i1p1f1
MPI-ESM1-2-LR_r10i1p1f1_ssp585


This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
    >>> with dask.config.set(**{

MPI-ESM1-2-LR_r4i1p1f1_ssp585


This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.


Missing: MPI-ESM1-2-LR ssp585 r5i1p1f1
MPI-ESM1-2-LR_r5i1p1f1_ssp585


This may cause some slowdown.
Consider scattering data ahead of time and using futures.


Missing: MPI-ESM1-2-LR ssp585 r6i1p1f1
MPI-ESM1-2-LR_r6i1p1f1_ssp585


This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.


Missing: MPI-ESM1-2-LR ssp585 r7i1p1f1
MPI-ESM1-2-LR_r7i1p1f1_ssp585


This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.


Missing: MPI-ESM1-2-LR ssp585 r8i1p1f1
MPI-ESM1-2-LR_r8i1p1f1_ssp585


This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Co

CPU times: user 42min 7s, sys: 1min 40s, total: 43min 47s
Wall time: 2h 36min 3s




## Run projections

In [3]:
def get_training_res(subset_name, obs_name, loss_metric, iden):
    """
    Reads the training results
    """
    # Loop through files
    files = glob(f'{project_data_path}/WBM/calibration/{subset_name}/{obs_name}/training_res/{iden}.txt')

    df_out = []
    for file in files:
        # Read
        df = pd.read_csv(file, sep = ' ')

        # Add identifiers
        _, param_id, val_id, _ = file.split('/')[-1].split('_')
        df['param_id'] = param_id
        df['val_id'] = val_id
        
        # Take best val
        df_best = df.sort_values(by=loss_metric).iloc[:1]
        df_out.append(df_best)

    # Join and return
    df_out = pd.concat(df_out)
    return df_out

In [4]:
df = get_training_res('eCONUS', 'SMAP', 'val_loss', '*')

In [5]:
theta = jnp.array([df.iloc[50][param] for param in param_names])

In [11]:
def run_projection(subset_name, obs_name, projection_id, theta, sim_id):
    # Read all
    x_forcing_nt, x_forcing_nyrs, x_maps, valid_inds = read_projection_inputs(subset_name, obs_name, projection_id, True)

    # Do it
    out = make_prediction_vmap(theta, constants, x_forcing_nt, x_forcing_nyrs, x_maps)

    # Store
    np.savez(f'{project_data_path}/projections/{subset_name}/out/{projection_id}_{sim_id}.npy',
             out=out,
             valid_inds=valid_inds)

### eCONUS

In [7]:
subset_name = 'eCONUS'
obs_name = 'SMAP'
projection_id = 'LOCA2/CanESM5_r1i1p1f1_ssp585'
sim_id = 'test'

In [8]:
%%time
run_projection(subset_name, obs_name, projection_id, theta, sim_id)

CPU times: user 4min 39s, sys: 1min 5s, total: 5min 45s
Wall time: 5min 50s


In [9]:
npy = np.load(f'{project_data_path}/projections/{subset_name}/out/{projection_id}_{sim_id}.npy')

In [10]:
npy.shape

(30212, 28470)

In [9]:
out_full = np.full((len(valid_inds), out.shape[1]), np.nan)
out_full[valid_inds] = out

CPU times: user 2.7 s, sys: 4.47 s, total: 7.17 s
Wall time: 7.22 s


In [49]:
# Construct xr 
ds_grid = xr.open_dataset(f"{project_data_path}/WBM/calibration/{subset_name}/{obs_name}/{obs_name}_validation.nc")
lons = ds_grid.lon
lats = ds_grid.lat
nt = out.shape[1]

In [55]:
ds_sim = xr.Dataset(
    data_vars=dict(soilMoist=(["time", "lat", "lon"], np.transpose(out_full.reshape(len(lons), len(lats), nt), (2,1,0)))),
    coords=dict(
    lon=lons,
    lat=lats,
    time=xr.cftime_range(start='2023-01-01', end='2100-12-31', calendar='365_day')))

In [63]:
%%time
ds_sim = ds_sim.chunk({'time': 200 , 'lat':-1, 'lon':-1})
compressor = zarr.Blosc(cname="zstd", clevel=3)
encoding = {vname: {"compressor": compressor} for vname in ds_sim.data_vars}
ds_sim.to_zarr(f"{project_data_path}/projections/test.zarr",
               encoding=encoding, mode='w-', consolidated=True)

CPU times: user 1min 38s, sys: 6.93 s, total: 1min 45s
Wall time: 1min 51s


<xarray.backends.zarr.ZarrStore at 0x14eb2b1e3d40>