# Process 3d output for SOM

Requires zonal and meridional winds at 200hPa, 500hPa and 925hPa, and specific humidity at (925 hPa and 500 hPa).
 1. Interpolate in pressure to get those variables
 2. Write out new dataset with only the output variables there
 3. Remap afterwards

In [1]:
%matplotlib inline
import xarray as xr
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import interp1d
from itertools import product
from collections import OrderedDict

In [None]:
# read U, V, Q, from h2 files; ps from h1 file

In [2]:
ds3 = xr.open_dataset('/global/project/projectdirs/m3312/crjones/e3sm/early_science/3hourly_3d_hist/earlyscience.FC5AV1C-H01A.ne120.sp1_64x1_1000m.20190329.cam.h2.0002-05-15-00000.nc', chunks={'time': 1})
ds2 = xr.open_dataset('/global/project/projectdirs/m3312/crjones/e3sm/early_science/hourly_2d_hist/earlyscience.FC5AV1C-H01A.ne120.sp1_64x1_1000m.20190329.cam.h1.0002-05-15-00000.nc')

In [3]:
interp_levels = [200, 500, 925]
interp_q = [925, 500]
u = ds3['U']
v = ds3['V']
q = ds3['Q']

In [64]:
ux = u.isel(ncol=15, time=12)
vx = v.isel(ncol=15, time=12)

In [19]:
px = ps.isel(ncol=15, time=12) * hybm / 100 + p0 * hyam
qx = q.isel(ncol=15, time=12)

In [4]:
ps = ds2['PS'].sel(time=ds3.time)
hyam = ds3['hyam'].load()
hybm = ds3['hybm'].load()
p0 = 1000

In [5]:
# can I use dask.apply_along_axis?
p = ps * hybm / 100 + p0 * hyam
p.shape

(40, 777602, 72)

In [6]:
u = u.load()

In [7]:
u = u.transpose(*p.dims)

## Python version

In [40]:
# alternate approach: apply linear interpolation within a column by calculating effective weights
def plev_linear_weights(p, levs=[250, 500, 925]):
    # make sure order is maintained
    weights = OrderedDict()
    for ll in levs:
        iup = np.searchsorted(p, ll, side='right')
        if iup == 0 or iup >= len(p):
            weights[ll] = [(0, 0), (np.nan, np.nan)]
            continue
        idn = iup - 1
        dp = (ll - p[idn]) / (p[iup] - p[idn])
        weights[ll] = [(idn, iup), (1 - dp, dp)]
    return weights

In [41]:
def apply_weights_to_column(v, w, as_array=True):
    if as_array:
        return np.array([sum(v[i] * wt for i, wt in zip(*val)) for val in w.values()])
    return {key: sum(v[i] * wt for i, wt in zip(*val)) for key, val in w.items()}

In [42]:
def interp_col_to_pres_level(p, levs, da):
    shape = list(p.shape[:-1]) + [len(levs)]
    outx = np.full(shape, fill_value=np.nan)
    for ind in np.ndindex(p.shape[:-1]):
        weights = plev_linear_weights(p[ind], levs)
        outx[ind] = apply_weights_to_column(da[ind], weights)
    return outx

In [44]:
%%time
levs = [200, 500, 925]
ptest = p.isel(ncol=slice(0, 1000))
da = u.isel(ncol=slice(0, 1000))
out = interp_col_to_pres_level(ptest.values, np.array(levs), da.values)

CPU times: user 2.16 s, sys: 45.1 ms, total: 2.21 s
Wall time: 2.15 s


## Numba version

In [8]:
from numba import njit, prange, jit

In [9]:
import numba as nb

In [29]:
# still need to figure out how to deal with types, but this is suuuper promising
@njit
def nplev_linear_weights(p, levs=[250, 500, 925]):
    inds = np.empty((len(levs), 2), dtype=np.int64)
    weights = np.empty((len(levs), 2), dtype=np.float64)
    for n, ll in enumerate(levs):
        iup = np.searchsorted(p, ll, side='right')
        if iup == 0 or iup >= len(p):
            inds[n] = (0, 0)
            weights[n] = (np.nan, np.nan)
            continue
        idn = iup - 1
        dp = (ll - p[idn]) / (p[iup] - p[idn])
        inds[n] = (idn, iup)
        weights[n] = (1 - dp, dp)
    return inds, weights

@njit
def napply_weights_to_column(v, inds, weights):
    return np.array([np.sum(v[inds[k]] * weights[k]) for k in range(len(inds))])

@njit
def ninterp_col_to_pres_level(p, levs, da):
    shape = (*p.shape[:-1], len(levs))
    outx = np.full(shape, fill_value=np.nan)
    for ind in np.ndindex(p.shape[:-1]):
        inds, weights = nplev_linear_weights(p[ind], levs)
        outx[ind] = napply_weights_to_column(da[ind], inds, weights)
    return outx

In [66]:
np.array([[i + j for i in range(3)] for j in range(4)]).shape

(4, 3)

In [113]:
# still need to figure out how to deal with types, but this is suuuper promising
@njit
def nlev_inds_and_weights(p, levs=[250, 500, 925]):
    shape = (*p.shape[:-1], len(levs), 2)
    inds = np.empty(shape, dtype=np.int64)
    weights = np.empty(shape, dtype=np.float64)
    
    id0 = np.empty((len(levs), 2), dtype=np.int64)
    wt0 = np.empty((len(levs), 2), dtype=np.float64)
    for rows in np.ndindex(p.shape[:-1]):
        pcol = p[rows]
        for n, ll in enumerate(levs):
            iup = np.searchsorted(pcol, ll, side='right')
            if iup == 0 or iup >= len(pcol):
                id0[n] = (0, 0)
                wt0[n] = (np.nan, np.nan)
                continue
            idn = iup - 1
            dp = (ll - pcol[idn]) / (pcol[iup] - pcol[idn])
            id0[n] = (idn, iup)
            wt0[n] = (1 - dp, dp)
        inds[rows] = id0
        weights[rows] = wt0
    return inds, weights

In [114]:
inds, weights = nlev_inds_and_weights(p.isel(ncol=slice(0, 20)).values)

LoweringError: Failed in nopython mode pipeline (step: nopython mode backend)
Cannot lower constant of type 'reflected list(int64)'

File "<ipython-input-113-f0ad241be961>", line 4:
def nlev_inds_and_weights(p, levs=[250, 500, 925]):
    shape = (*p.shape[:-1], len(levs), 2)
    ^

[1] During: lowering "levs = arg(1, name=levs)" at <ipython-input-113-f0ad241be961> (4)
-------------------------------------------------------------------------------
This should not have happened, a problem has occurred in Numba's internals.

Please report the error message and traceback, along with a minimal reproducer
at: https://github.com/numba/numba/issues/new

If more help is needed please feel free to speak to the Numba core developers
directly at: https://gitter.im/numba/numba

Thanks in advance for your help in improving Numba!



In [101]:
# still need to figure out how to deal with types, but this is suuuper promising
@njit
def nlev_inds_and_weights(p, levs=[250, 500, 925]):
    shape = (*p.shape[:-1], len(levs), 2)
    inds = np.empty(shape, dtype=int)
    weights = np.empty(shape, dtype=float)
    
    id0 = np.empty((len(levs), 2), dtype=int)
    wt0 = np.empty((len(levs), 2), dtype=float)
    for rows in np.ndindex(p.shape[:-1]):
        pcol = p[rows]
        for n, ll in enumerate(levs):
            iup = np.searchsorted(pcol, ll, side='right')
            if iup == 0 or iup >= len(pcol):
                id0[n] = (0, 0)
                wt0[n] = (np.nan, np.nan)
                continue
            idn = iup - 1
            dp = (ll - p[idn]) / (p[iup] - p[idn])
            id0[n] = (idn, iup)
            wt0[n] = (1 - dp, dp)
        inds[rows] = id0
        weights[rows] = wt0
    return inds, weights

@njit
def napply_weights_to_columns(inds, weights, *arr):
    return np.array([[np.sum(v[inds[k]] * weights[k]) for v in arr] for k in range(len(inds))])

@njit
def ninterp_arrays_to_pres_level(p, levs, *arr):
    narr = len(arr)
    shape = (*p.shape[:-1], len(levs), narr)
    outx = np.empty(shape)
    for ind in np.ndindex(p.shape[:-1]):
        inds, weights = nplev_linear_weights(p[ind], levs)
        outx[ind] = napply_weights_to_columns(inds, weights, *arr[ind])
        # for outx, ar in zip(out, arr):
        #     outx[ind] = napply_weights_to_column(ar[ind], inds, weights)
    return outx

In [88]:
# original version that worked
# still need to figure out how to deal with types, but this is suuuper promising
@njit((nb.float32[:],nb.float32[:]))
def nplev_linear_weights0(p, levs=[250, 500, 925]):
    # inds = np.empty((len(levs), 2), dtype=np.int)
    # weights = np.empty((len(levs), 2), dtype=np.float32)
    inds = []
    weights = []
    for n, ll in enumerate(levs):
        iup = np.searchsorted(p, ll, side='right')
        if iup == 0 or iup >= len(p):
            inds.append((0, 0))
            weights.append((np.nan, np.nan))
            continue
        idn = iup - 1
        dp = (ll - p[idn]) / (p[iup] - p[idn])
        inds.append((idn, iup))
        weights.append((1 - dp, dp))
    return inds, weights

@njit
def napply_weights_to_column0(v, inds, weights, as_array=True):
    out = np.zeros(len(inds))
    for k in prange(len(inds)):
        ind = inds[k]
        weight = weights[k]
        for i, w in zip(ind, weight):
            out[k] += v[i] * w
    return out

@njit
def ninterp_col_to_pres_level0(p, levs, da):
    shape = (*p.shape[:-1], len(levs))
    # shape = list(p.shape[:-1]) + [len(levs)]
    outx = np.full(shape, fill_value=np.nan)
    for ind in np.ndindex(p.shape[:-1]):
        inds, weights = nplev_linear_weights(p[ind], levs)
        outx[ind] = napply_weights_to_column(da[ind], inds, weights)
    return outx

In [100]:
np.split(tmp[(0, 1)], tmp[(0,1)].shape[-1], axis=-1)

[array([[10],
        [12],
        [14],
        [16],
        [18]]), array([[11],
        [13],
        [15],
        [17],
        [19]])]

In [24]:
%%time
out = interp_col_to_pres_level(ptest, levs, da)

CPU times: user 1min 39s, sys: 728 ms, total: 1min 40s
Wall time: 1min 38s


In [46]:
%%time
levs = [200, 500, 925]
ptest = p.isel(ncol=slice(0, 1000))
da = u.isel(ncol=slice(0, 1000))
out3 = ninterp_col_to_pres_level(ptest.values, np.array(levs), da.values)

CPU times: user 51.6 ms, sys: 0 ns, total: 51.6 ms
Wall time: 51.3 ms


In [47]:
%%time
levs = [200, 500, 925]
ptest = p
da = u
out4 = ninterp_col_to_pres_level(ptest.values, np.array(levs), da.values)

CPU times: user 37.5 s, sys: 569 ms, total: 38 s
Wall time: 37.8 s


In [102]:
%%time
levs = [200, 500, 925]
ptest = p
da = u
out3 = ninterp_arrays_to_pres_level(ptest.values, np.array(levs), da.values, da.values)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (tuple(array(float32, 3d, A) x 2), tuple(int64 x 2))
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
In definition 2:
    All templates rejected with literals.
In definition 3:
    All templates rejected without literals.
In definition 4:
    All templates rejected with literals.
In definition 5:
    All templates rejected without literals.
In definition 6:
    All templates rejected with literals.
In definition 7:
    All templates rejected without literals.
In definition 8:
    All templates rejected with literals.
In definition 9:
    All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of intrinsic-call at <ipython-input-101-00be0301bcd4> (29)

File "<ipython-input-101-00be0301bcd4>", line 29:
def ninterp_arrays_to_pres_level(p, levs, *arr):
    <source elided>
        inds, weights = nplev_linear_weights(p[ind], levs)
        outx[ind] = napply_weights_to_columns(inds, weights, *arr[ind])
        ^

This is not usually a problem with Numba itself but instead often caused by
the use of unsupported features or an issue in resolving types.

To see Python/NumPy features supported by the latest release of Numba visit:
http://numba.pydata.org/numba-doc/dev/reference/pysupported.html
and
http://numba.pydata.org/numba-doc/dev/reference/numpysupported.html

For more information about typing errors and how to debug them visit:
http://numba.pydata.org/numba-doc/latest/user/troubleshoot.html#my-code-doesn-t-compile

If you think your code should work with Numba, please report the error message
and traceback, along with a minimal reproducer at:
https://github.com/numba/numba/issues/new


In [None]:
def interp_to_levs_for_array(p, levs, da):
    # step 1: preprocess so that lev dimension is last
    # assert da.dims[-1] == 'lev' and p.dims[-1] == 'lev'
    print(da.shape)
    out_shape = list(da.shape[:-1]) + [len(levs)]
    print(out_shape)
    out = np.full(out_shape, fill_value=np.nan)
    for ind in np.ndindex(p.shape[:-1]):
        print(ind)
        weights = plev_linear_weights(p[ind], levs)
        print(weights)
        out[ind] = apply_weights_to_column(da[ind], weights)
    return xr.DataArray(out, dims=

In [99]:
def interp_to_levs_for_array(p, levs, da):
    # step 1: preprocess so that lev dimension is last
    # assert da.dims[-1] == 'lev' and p.dims[-1] == 'lev'
    print(da.shape)
    out_shape = list(da.shape[:-1]) + [len(levs)]
    print(out_shape)
    out = np.full(out_shape, fill_value=np.nan)
    for ind in np.ndindex(p.shape[:-1]):
        print(ind)
        weights = plev_linear_weights(p[ind], levs)
        print(weights)
        out[ind] = apply_weights_to_column(da[ind], weights)
    return xr.DataArray(out, dims=

In [None]:
from multiprocessing.pool import ThreadPool
pool = ThreadPool()

out = np.empty((40, 777602))

def f(i):
    out[i] = function(ar[i], *args)

pool.map(f, range(len(ar)))

In [114]:
out = np.apply_along_axis(plev_linear_weights, 0, p.isel(time=slice(0, 40), ncol=slice(0, 10000)))
out.shape

(40, 10000)

In [124]:
ptest = plev_linear_weights(p.isel(time=5, ncol=75))

In [125]:
ptest == out[5, 75]

True

In [133]:
ptest

OrderedDict([(250, [(36, 37), (<xarray.DataArray ()>
                array(0.440817)
                Coordinates:
                    lev      float64 236.6
                    time     object 0002-05-15 15:00:00, <xarray.DataArray ()>
                array(0.559183)
                Coordinates:
                    lev      float64 236.6
                    time     object 0002-05-15 15:00:00)]),
             (500, [(44, 45), (<xarray.DataArray ()>
                array(0.064707)
                Coordinates:
                    lev      float64 454.4
                    time     object 0002-05-15 15:00:00, <xarray.DataArray ()>
                array(0.935293)
                Coordinates:
                    lev      float64 454.4
                    time     object 0002-05-15 15:00:00)]),
             (925, [(60, 61), (<xarray.DataArray ()>
                array(0.004737)
                Coordinates:
                    lev      float64 889.2
                    time     object 0002-05

In [134]:
p.dims.index('lev')

0

In [None]:
def interp_for_da(pres, levs, da, weights=None):
    """Intented for subarrays / multidimensional arrays"""
    # get weights, apply along appropriate axis for pressure
    lev_axis = pres.dims.index('lev')  # input: pres = dataArray
    weights = np.apply_along_axis(plev_linear_weights, lev_axis, p.values)  # array of [(i1, i2), (w1, w2)]

    dims = [d for d in da.dims]
    assert all(da.values.shape == pres.values.shape)
    lev_axis = dims.index('lev')
    out_shape = 
    

In [61]:
def interp_da_to_levs(pres, levs, *das, weights=None):
    if weights is None:
        weights = plev_linear_weights(pres, levs)
    out_array = np.empty((len(das), len(levs)))
    for i, da in enumerate(das):
        out_array[i, :] = apply_weights(da, weights, as_array=True)
    return out_array

In [100]:
def interp_full_arrays(ps, p0, hyam, hybm, interp_levs, *das):
    da0 = das[0]
    # loop over non-pressure dimensions:
    # dims = OrderedDict([(da.name, [v for v in da.dims if 'lev' not in v]) for da in das])
    dims = [[v for v in da.dims if 'lev' not in v] for da in das]
    assert len(set(tuple(d) for d in dims)) == 1  # make sure they all share same non-pressure dims
    
    dims = dims[0]
    limits = [len(da0[d]) for d in dims]
    
    # verify all dims values are the same
    ni = len(interp_levs)
    out_das = [np.full([*limits, ni], fill_value=np.nan) for _ in range(len(das))]
    
    # loop over dimensions here ...
    for sel_dims in product(*[range(lim) for lim in limits]):
        sel = {dim: val for dim, val in zip(dims, sel_dims)}  # loop over each non-lev dimension
        
    
    # for sel_dims in product(*[range(lim) for lim in limits])
    # assume that each da passed in has the same shape (limitation will be met this use)
    # limits = len(das[0][d]) for d in dims   # length of each dimension in dims

    """p = p0 * hyam + ps * hybm / 100  # pressure levels in hPa (assumed to be of same dims as dat)
    dims = [v for v in dat.dims if 'lev' not in v]  # loop over non-pressure dimensions
    
    # out_data will have same order as dimensions, except lev in last dimension
    out_data = np.full([*limits, len(levs)], fill_value=np.nan) 

    if log_interp:
        p = np.log(p)  # interpolate in log-space
        levs = np.log(levs)
        # debug
        print(levs)
    
    for sel_dims in product(*[range(lim) for lim in limits]):
        sel_dict = {dim: val for dim, val in zip(dims, sel_dims)}
        interp_fun = interp1d(p.isel(sel_dict), dat.isel(sel_dict), **interp_kwargs)
        if extrap_check:
            for n, lev in enumerate(levs):
                ind = tuple([*sel_dims] + [n])
                try:
                    out_data[ind] = interp_fun(lev)
                except ValueError:
                    # Some interp1d methods raise ValueError if asked to extrapolate
                    pass
        else:
            # use vectorized input
            ind = tuple([*sel_dims] + [...])
            out_data[ind] = interp_fun(levs)
    return out_data
    """
    return dims, limits, out_das

In [78]:
len(v['time'])

40

In [65]:
interp_da_to_levs(pres, levs, qx, ux, vx)

array([[ 6.64616311e-05,  7.52676921e-04,  6.70447916e-03],
       [ 3.56072969e+01,  1.69881860e+01,  5.86682081e+00],
       [ 2.34070718e+01,  1.15193150e+01, -2.05407336e+00]])

In [58]:
levs = [250, 500, 925]
pres = px.values
weights = plev_linear_weights(pres, levs)
out_array = np.empty((1, len(levs)))
apply_weights(qx.values, weights, as_array=False)

{250: 6.646163110153282e-05,
 500: 0.000752676920976892,
 925: 0.006704479160862865}

In [50]:
interp_da_to_levs(px.values, [250, 500, 925], q.values)

IndexError: index 44 is out of bounds for axis 0 with size 40

In [9]:
def interp_to_plevs(ps, hyam, hybm, dat, levs, p0=1000, p=None,
                    log_interp=True, extrap_check=False, **interp_kwargs):
    """ Interpolate dataArray dat to pressure levels levs
    
    inputs: ps is surface pressure dataArray
            hyam, hybm: hybrid coeffients
            dat: dataArray of variable to interpolate
            levs: pressure levels to interpolate dat onto
            interp_kwargs: keyword arguments to pass to scipy.interpolate.interp1d
    
    Returns: out_data: numpy array of dat interpolated to levs.
    Note: out_data dimensions have same order as dat, except that the new levs 
          dimension is last
    """
    if p is None:
        p = p0 * hyam + ps * hybm / 100  # pressure levels in hPa (assumed to be of same dims as dat)
    dims = [v for v in dat.dims if 'lev' not in v]  # loop over non-pressure dimensions
    limits = [len(dat[d]) for d in dims]   # length of each dimension in dims
    
    # out_data will have same order as dimensions, except lev in last dimension
    out_data = np.full([*limits, len(levs)], fill_value=np.nan) 

    if log_interp:
        p = np.log(p)  # interpolate in log-space
        levs = np.log(levs)
        # debug
        print(levs)
    
    for sel_dims in product(*[range(lim) for lim in limits]):
        sel_dict = {dim: val for dim, val in zip(dims, sel_dims)}
        interp_fun = interp1d(p.isel(sel_dict), dat.isel(sel_dict), **interp_kwargs)
        if extrap_check:
            for n, lev in enumerate(levs):
                ind = tuple([*sel_dims] + [n])
                try:
                    out_data[ind] = interp_fun(lev)
                except ValueError:
                    # Some interp1d methods raise ValueError if asked to extrapolate
                    pass
        else:
            # use vectorized input
            ind = tuple([*sel_dims] + [...])
            out_data[ind] = interp_fun(levs)
    return out_data

In [24]:
interp_levels

[925, 500, 200]

In [39]:
wlog = plev_linear_weights(np.log(px.values), np.log([200, 500, 925]))
wnolog = plev_linear_weights(px.values, [200, 500, 925])

In [40]:
apply_weights(px.values, wlog)

{5.298317366548036: 200.05955816599482,
 6.214608098422191: 500.0931658284276,
 6.829793737512425: 925.0002379983664}

In [41]:
apply_weights(px.values, wnolog)

{200: 200.0, 500: 500.0, 925: 924.9999999999999}

In [33]:
for k in wlog.values():
    print(k)

[(33, 34), (<xarray.DataArray ()>
array(0.096229)
Coordinates:
    time     object 0002-05-16 12:00:00
    lev      float64 185.3, <xarray.DataArray ()>
array(0.903771)
Coordinates:
    time     object 0002-05-16 12:00:00
    lev      float64 185.3)]
[(44, 45), (<xarray.DataArray ()>
array(0.061285)
Coordinates:
    time     object 0002-05-16 12:00:00
    lev      float64 454.4, <xarray.DataArray ()>
array(0.938715)
Coordinates:
    time     object 0002-05-16 12:00:00
    lev      float64 454.4)]
[(61, 62), (<xarray.DataArray ()>
array(0.997147)
Coordinates:
    time     object 0002-05-16 12:00:00
    lev      float64 901.7, <xarray.DataArray ()>
array(0.002853)
Coordinates:
    time     object 0002-05-16 12:00:00
    lev      float64 901.7)]


In [7]:
p0 = 1000
p = p0 * hyam + ps * hybm / 100

MemoryError: 

In [14]:
x = p.isel(ncol=15, time=12)
x

<xarray.DataArray (lev: 72)>
array([1.238254e-01, 1.828292e-01, 2.699489e-01, 3.985817e-01, 5.885091e-01,
       8.689386e-01, 1.282995e+00, 1.894352e+00, 2.797027e+00, 4.129833e+00,
       5.968449e+00, 8.377404e+00, 1.147379e+01, 1.533394e+01, 1.999634e+01,
       2.544470e+01, 3.159325e+01, 3.836628e+01, 4.567120e+01, 5.330956e+01,
       6.101518e+01, 6.847639e+01, 7.535534e+01, 8.194628e+01, 8.891054e+01,
       9.646667e+01, 1.046650e+02, 1.135600e+02, 1.232110e+02, 1.336822e+02,
       1.450433e+02, 1.573699e+02, 1.707441e+02, 1.854316e+02, 2.016171e+02,
       2.192509e+02, 2.383834e+02, 2.591418e+02, 2.816645e+02, 3.061012e+02,
       3.326147e+02, 3.613815e+02, 3.925931e+02, 4.264572e+02, 4.631992e+02,
       5.025018e+02, 5.429381e+02, 5.831828e+02, 6.224298e+02, 6.602907e+02,
       6.966051e+02, 7.308843e+02, 7.626485e+02, 7.914373e+02, 8.168215e+02,
       8.384133e+02, 8.561961e+02, 8.712327e+02, 8.851844e+02, 8.988048e+02,
       9.120721e+02, 9.249646e+02, 9.374613e+02

In [50]:
weights = plev_linear_weights(plev, levs)

In [51]:
weights

OrderedDict([(5.298317366548036,
              [(9, 10), (0.3644762817614362, 0.6355237182385638)]),
             (6.214608098422191,
              [(10, 11), (0.8978151397199566, 0.10218486028004341)]),
             (6.829793737512425,
              [(10, 11), (0.6424406573943159, 0.3575593426056841)]),
             (7.090076835776092,
              [(10, 11), (0.534392521590144, 0.46560747840985595)])])

In [40]:
levs = [200, 500, 925, 1200]
plev = x.values
weights = {}
for ll in levs:
    iup = np.searchsorted(plev, ll, side='right')
    # iup = np.argmax(plev > ll)
    if iup == 0 or iup >= len(plev):
        continue
    idn = iup - 1
    dp = (ll - plev[idn]) / (plev[iup] - plev[idn])
    weights[ll] = [(idn, iup), (1 - dp, dp)]
weights

{200: [(33, 34), (0.09990860086471232, 0.9000913991352877)],
 500: [(44, 45), (0.06365568676381494, 0.936344313236185)],
 925: [(61, 62), (0.9971663025538356, 0.002833697446164301)]}

In [None]:
q.isel(ncol=15, time=12).values

In [46]:
apply_weights(q.isel(ncol=15, time=12).values, weights)

{200: 1.5652536476177286e-05,
 500: 0.000752676920976892,
 925: 0.006704479160862865}

In [44]:
apply_weights(plev, weights)

{200: 200.0, 500: 500.0, 925: 924.9999999999999}

In [36]:
np.argmax(plev > 500)

45

In [32]:
plev[34]

201.617063040152

In [26]:
np.argmax(x.values > 500)

45

In [28]:
x.values[44]

463.1992365015949

In [12]:
q2 = interp_to_plevs(ps, hyam, hybm, q, interp_q, p=p, extrap_check=True)

KeyboardInterrupt: 

In [22]:
ps = ds2['PS']
ds2['P0']

<xarray.DataArray 'P0' ()>
array(100000.)
Attributes:
    long_name:  reference pressure
    units:      Pa

In [23]:
ps

<xarray.DataArray 'PS' (time: 120, ncol: 777602)>
[93312240 values with dtype=float32]
Coordinates:
  * time     (time) object 0002-05-15 00:00:00 ... 0002-05-19 23:00:00
Dimensions without coordinates: ncol
Attributes:
    units:         Pa
    long_name:     Surface pressure
    cell_methods:  time: mean

In [26]:
ds2['hyam'].attrs

OrderedDict([('long_name', 'hybrid A coefficient at layer midpoints')])

In [37]:
np.all(ds2['PS'].sel(time=ds3.time) == ds2['PS'].reindex(time=ds3.time))

<xarray.DataArray 'PS' ()>
array(True)

In [36]:
ds2['PS'].reindex(time=ds3.time)

<xarray.DataArray 'PS' (time: 40, ncol: 777602)>
[31104080 values with dtype=float32]
Coordinates:
  * time     (time) object 0002-05-15 00:00:00 ... 0002-05-19 21:00:00
Dimensions without coordinates: ncol
Attributes:
    units:         Pa
    long_name:     Surface pressure
    cell_methods:  time: mean

## Multiprocessing approach (never worked that well)

In [12]:
from multiprocessing.pool import ThreadPool
pool = ThreadPool()

In [23]:
out = np.empty((40, 777602, 3))

def interp_col_to_pres_level(p, levs, da):
    shape = list(p.shape[:-1]) + [len(levs)]
    outx = np.full(shape, fill_value=np.nan)
    for ind in np.ndindex(p.shape[:-1]):
        weights = plev_linear_weights(p[ind], levs)
        outx[ind] = apply_weights_to_column(da[ind], weights)
    return outx

levs = [200, 500, 925]
ptest = p.isel(ncol=slice(0, 100))
da = u.isel(ncol=slice(0, 100))

def do_it(p, levs, da):
    shape = list(p.shape[:-1]) + [len(levs)]
    out = np.empty(shape)
    for ind in np.ndindex(p.shape[0]):
        out[ind] = interp_col_to_pres_level(p[ind], levs, da[ind])
    return out