In [1]:
import logging
import numpy as np
import xarray as xr
from distributed import Client

from nwsspc.sharp.calc import constants
from nwsspc.sharp.calc import parcel
from nwsspc.sharp.calc import params
from nwsspc.sharp.calc import thermo
from nwsspc.sharp.calc import layer

client = Client(n_workers=8, silence_logs=logging.ERROR)
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 8
Total threads: 16,Total memory: 62.32 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:38595,Workers: 0
Dashboard: http://127.0.0.1:8787/status,Total threads: 0
Started: Just now,Total memory: 0 B

0,1
Comm: tcp://127.0.0.1:35129,Total threads: 2
Dashboard: http://127.0.0.1:38801/status,Memory: 7.79 GiB
Nanny: tcp://127.0.0.1:35919,
Local directory: /tmp/dask-scratch-space/worker-_3n0wdta,Local directory: /tmp/dask-scratch-space/worker-_3n0wdta

0,1
Comm: tcp://127.0.0.1:46539,Total threads: 2
Dashboard: http://127.0.0.1:37945/status,Memory: 7.79 GiB
Nanny: tcp://127.0.0.1:40263,
Local directory: /tmp/dask-scratch-space/worker-95xxfg0w,Local directory: /tmp/dask-scratch-space/worker-95xxfg0w

0,1
Comm: tcp://127.0.0.1:45359,Total threads: 2
Dashboard: http://127.0.0.1:41137/status,Memory: 7.79 GiB
Nanny: tcp://127.0.0.1:44687,
Local directory: /tmp/dask-scratch-space/worker-cp9gm_5z,Local directory: /tmp/dask-scratch-space/worker-cp9gm_5z

0,1
Comm: tcp://127.0.0.1:33105,Total threads: 2
Dashboard: http://127.0.0.1:37115/status,Memory: 7.79 GiB
Nanny: tcp://127.0.0.1:35999,
Local directory: /tmp/dask-scratch-space/worker-9lis593s,Local directory: /tmp/dask-scratch-space/worker-9lis593s

0,1
Comm: tcp://127.0.0.1:36261,Total threads: 2
Dashboard: http://127.0.0.1:35217/status,Memory: 7.79 GiB
Nanny: tcp://127.0.0.1:43605,
Local directory: /tmp/dask-scratch-space/worker-p1c1203c,Local directory: /tmp/dask-scratch-space/worker-p1c1203c

0,1
Comm: tcp://127.0.0.1:40467,Total threads: 2
Dashboard: http://127.0.0.1:37007/status,Memory: 7.79 GiB
Nanny: tcp://127.0.0.1:46311,
Local directory: /tmp/dask-scratch-space/worker-_nqipvff,Local directory: /tmp/dask-scratch-space/worker-_nqipvff

0,1
Comm: tcp://127.0.0.1:33991,Total threads: 2
Dashboard: http://127.0.0.1:41307/status,Memory: 7.79 GiB
Nanny: tcp://127.0.0.1:44705,
Local directory: /tmp/dask-scratch-space/worker-pprpin6j,Local directory: /tmp/dask-scratch-space/worker-pprpin6j

0,1
Comm: tcp://127.0.0.1:44577,Total threads: 2
Dashboard: http://127.0.0.1:33583/status,Memory: 7.79 GiB
Nanny: tcp://127.0.0.1:37029,
Local directory: /tmp/dask-scratch-space/worker-szo89adw,Local directory: /tmp/dask-scratch-space/worker-szo89adw


In [None]:
ds_hybrid = xr.open_dataset(
    "hrrr-hybrid.json", 
    engine="kerchunk", 
    decode_timedelta=True
)[["pres", "gh", "t", "q"]].chunk(dict(hybrid=-1,)).load().astype("float32")

ds_2m = xr.open_dataset(
    "hrrr-2m.json", 
    engine="kerchunk", 
    decode_timedelta=True
)[["t2m", "sh2"]].chunk().load().astype("float32")

ds_sfc = xr.open_dataset(
    "hrrr-surface.json", 
    engine="kerchunk", 
    decode_timedelta=True
)[["sp"]].chunk().load().astype("float32")

In [3]:
print(ds_hybrid)

<xarray.Dataset> Size: 3GB
Dimensions:     (hybrid: 50, y: 1059, x: 1799)
Coordinates:
  * hybrid      (hybrid) float64 400B 1.0 2.0 3.0 4.0 ... 47.0 48.0 49.0 50.0
    latitude    (y, x) float64 15MB 21.14 21.15 21.15 ... 47.86 47.85 47.84
    longitude   (y, x) float64 15MB 237.3 237.3 237.3 ... 299.0 299.0 299.1
    step        (hybrid) timedelta64[ns] 400B 02:00:00 02:00:00 ... 02:00:00
    time        (hybrid) datetime64[ns] 400B 2025-05-19T21:00:00 ... 2025-05-...
    valid_time  (hybrid) datetime64[ns] 400B 2025-05-19T23:00:00 ... 2025-05-...
Dimensions without coordinates: y, x
Data variables:
    pres        (hybrid, y, x) float64 762MB 1.015e+05 1.015e+05 ... 1.731e+03
    gh          (hybrid, y, x) float64 762MB 10.71 10.71 ... 2.77e+04 2.77e+04
    t           (hybrid, y, x) float64 762MB 292.6 292.6 292.6 ... 226.0 226.0
    q           (hybrid, y, x) float64 762MB 0.01159 0.01159 ... 1.947e-06
Attributes:
    GRIB_edition:            2
    GRIB_centre:             kwbc
  

In [4]:
def compute_everything(pres, hght, tmpk, spfh, sp, t2m, sh2, use_2m=True):
    mixr = thermo.mixratio(spfh)
    mixr_2m = thermo.mixratio(sh2)

    mixr[mixr < constants.TOL] = constants.TOL
    if (mixr_2m < constants.TOL): mixr_2m = constants.TOL

    vtmp = thermo.virtual_temperature(tmpk, mixr)

    if (use_2m):
        dwpk_2m = thermo.temperature_at_mixratio(mixr_2m, sp)
        pcl = parcel.Parcel(sp, t2m, dwpk_2m, parcel.LPL.SFC)
    else: 
        pcl_dwpk = thermo.temperature_at_mixratio(mixr[0], pres[0])
        pcl = parcel.Parcel(pres[0], tmpk[0], pcl_dwpk, parcel.LPL.SFC)

    lifter = parcel.lifter_cm1()
    lifter.ma_type = thermo.adiabat.pseudo_liq
    lifter.converge = 0.1

    pcl_vtmp = pcl.lift_parcel(lifter, pres)
    pcl_buoy = thermo.buoyancy(pcl_vtmp, vtmp)
    cape, cinh = pcl.cape_cinh(pres, hght, pcl_buoy)
    return cape

In [5]:
input_core_dims = [
    ["hybrid"], ["hybrid"],
    ["hybrid"], ["hybrid"],
    [], [], []
]



In [10]:
cape_2m = xr.apply_ufunc(
    compute_everything, 
    ds_hybrid["pres"],
    ds_hybrid["gh"],
    ds_hybrid["t"],
    ds_hybrid["q"],
    ds_sfc["sp"],
    ds_2m["t2m"],
    ds_2m["sh2"],
    input_core_dims=input_core_dims,
    output_core_dims=[[]],
    vectorize=True
    dask="allowed",
    output_dtypes=[np.float32], # Good practice to match output dtype
    kwargs={"use_2m": True},
).compute()

print(cape_2m.max())

TypeError: mixratio(): incompatible function arguments. The following argument types are supported:
    1. mixratio(q: float) -> float
    2. mixratio(pressure: float, temperature: float) -> float
    3. mixratio(spfh_arr: ndarray[dtype=float32, shape=(*), order='C', device='cpu', writable=False]) -> numpy.ndarray[dtype=float32, shape=(*)]
    4. mixratio(pres_arr: ndarray[dtype=float32, shape=(*), order='C', device='cpu', writable=False], tmpk_arr: ndarray[dtype=float32, shape=(*), order='C', device='cpu', writable=False]) -> numpy.ndarray[dtype=float32, shape=(*)]

Invoked with types: ndarray