In [None]:
from datetime import datetime

import metpy
import metpy.calc  # imported explicitly so that metpy.calc.* lookups below do not rely on a side-effect import
import numpy as np
import xarray as xr
import zarr
from metpy.constants import water_heat_vaporization, dry_air_gas_constant, earth_gravity
from scipy.interpolate import interpn, NearestNDInterpolator

In [None]:
from matplotlib import pyplot as plt
import colorcet as cc

In [None]:
def _yyyymmdd(day):
    """Format a day-resolution numpy.datetime64 as a compact YYYYMMDD string."""
    return day.astype(object).strftime("%Y%m%d")


# Case day; the day before and the day after are used to build the file names.
# Alternative case: datestring = '2016-10-22'
datestring = '2018-11-21'
casedate = np.datetime64(datestring)
datem1 = _yyyymmdd(casedate - np.timedelta64(1, "D"))
date0 = _yyyymmdd(casedate)
datep1 = _yyyymmdd(casedate + np.timedelta64(1, "D"))

In [None]:
# Open the model-level ("ml") and surface ("sfc") ERA5 files covering the
# day before through the day after the case date.
atm, sfc = (
    xr.open_dataset(path)
    for path in (
        f'forcing/era5/data/ERA5-{datem1}-{datep1}-ml.nc',
        f'forcing/era5/data/ERA5-{datem1}-{datep1}-sfc.nc',
    )
)

In [None]:
# Confirm that each time coordinate is strictly increasing by comparing every
# value with its successor. Nothing is printed for a dataset that fails.
if np.all(sfc.time.values[:-1] < sfc.time.values[1:]):
    print('sfc data all sorted!')
if np.all(atm.time.values[:-1] < atm.time.values[1:]):
    print('atm data all sorted!')

In [None]:
# Time axis of the model-level dataset as a plain NumPy array (written to the output store later).
time = atm.time.values

In [None]:
# Show the time stamps taken from the model-level dataset.
print(time)

In [None]:
# 2-D longitude/latitude grids. The latitude vector is reversed here; the data
# arrays get the same latitude flip when they are written out.
lon, lat = np.meshgrid(sfc.longitude.values, np.flip(sfc.latitude.values))

In [None]:
def extract_ml_era5(d, coeff_file="era5_table.csv"):
    """Compute height and pressure on ERA5 model levels and return the model-level fields.

    The calculation of pressure and geopotential height on model levels follows
    https://confluence.ecmwf.int/display/CKB/ERA5%3A+compute+pressure+and+geopotential+on+model+levels%2C+geopotential+height+and+geometric+height

    Parameters
    ----------
    d : xarray.Dataset
        Model-level dataset providing ``t``, ``q``, ``u``, ``v``, ``lnsp`` and
        ``z``. The 4-D fields are unpacked as (time, level, y, x). ``lnsp`` and
        ``z`` are read at level *label* 1 (see note in the body).
    coeff_file : str, optional
        CSV file (one header row) holding the hybrid coefficients: column 1 is
        ``a`` and column 2 is ``b``. It must contain one row per half level,
        i.e. ``nz + 1`` rows.

    Returns
    -------
    tuple of numpy.ndarray
        ``(height, p, t, q, u, v)`` where ``height`` is the full-level height
        above the surface (geopotential height minus surface geopotential
        height) and ``p`` is the full-level pressure.

    Raises
    ------
    ValueError
        If the number of coefficient rows does not equal ``nz + 1``.
    """
    # Plain floats in SI units; everything below is ordinary NumPy arithmetic.
    rd = dry_air_gas_constant.to("J / kg / K").magnitude
    g = earth_gravity.to("m / s**2").magnitude

    ab = np.genfromtxt(coeff_file, delimiter=",", skip_header=1, missing_values="-")
    a = ab[:, 1]
    b = ab[:, 2]

    t = d.t.values
    qv = d.q.values
    # Note that the second index (1) here means getting the first level data
    # by label, not the actual array index. Also, only first level data is valid.
    lnsp = d.lnsp.loc[:, 1, :, :].values
    # The surface geopotential looks so noisy because of the spectral
    # decomposition/representation used in IFS.
    sgp = d.z.loc[:, 1, :, :].values

    nt, nz, ny, nx = t.shape
    if a.shape[0] != nz + 1:
        raise ValueError(
            f"{coeff_file!r} provides {a.shape[0]} half-level coefficients, "
            f"but the data has {nz} levels and needs {nz + 1}"
        )

    # Half-level pressure: p_half = a + b * ps, broadcast to (nt, nz + 1, ny, nx).
    ps = np.exp(lnsp)
    pi = (
        a[np.newaxis, :, np.newaxis, np.newaxis]
        + ps[:, np.newaxis, :, :] * b[np.newaxis, :, np.newaxis, np.newaxis]
    )
    # Full-level pressure is the mean of the two bounding half levels. It is
    # taken before the top half level is overwritten below.
    p = (pi[:, 1:, :, :] + pi[:, :-1, :, :]) * 0.5
    # The recipe replaces the top half-level pressure by 0.1 so that the
    # logarithm below stays finite.
    pi[:, 0, :, :] = 0.1
    dpi = pi[:, 1:, :, :] - pi[:, :-1, :, :]
    dlnpi = np.log(pi[:, 1:, :, :] / pi[:, :-1, :, :])

    # I have not got time to derive alpha, this is just what is given in the
    # ERA5 documentation.
    alpha = 1.0 - dlnpi * pi[:, :-1, :, :] / dpi
    alpha[:, 0, :, :] = np.log(2.0)

    # Moist (virtual) temperature. The ECMWF recipe applies the 0.609133
    # factor to *specific humidity*, so q is used directly rather than the
    # mixing ratio.
    rd_tm = rd * (t * (1.0 + 0.609133 * qv))

    # Half-level geopotential: accumulate the layer increments from the
    # surface upwards (levels are ordered top to bottom along axis 1), then
    # add the surface geopotential. The last half level is the surface itself.
    dphi = rd_tm * dlnpi
    phi = np.zeros((nt, nz + 1, ny, nx))
    phi[:, :-1, :, :] = np.flip(np.cumsum(dphi[:, ::-1, :, :], axis=1), 1)
    phi += sgp[:, np.newaxis, :, :]
    # Full-level geopotential: lower bounding half level plus the alpha term.
    ph = phi[:, 1:, :, :] + rd_tm * alpha

    z = ph / g
    zs = sgp / g

    return (
        z - zs[:, np.newaxis, :, :],
        p,
        t,
        qv,
        d.u.values,
        d.v.values,
    )

In [None]:
def extract_sfc_era5(d):
    """Extract the surface fields handed to PINACLES from an ERA5 surface dataset.

    Supplying surface pressure to PINACLES for now.
    But sea-level pressure is smoother and may be better for our purpose.
    Especially given that pressure on the vertical levels are only used for the
    domain mean profiles in setting up reference state and in radiation.
    We always want to supply heights in terms of heights above the surface.

    Supplying SST (instead of skin temperature) from ERA5 to PINACLES.
    Missing SST values (NaNs) are filled from the nearest valid grid point.

    Parameters
    ----------
    d : xarray.Dataset
        Surface dataset providing ``sp``, ``msl``, ``t2m``, ``d2m``, ``z``,
        ``u10``, ``v10``, ``sst`` and the ``longitude``/``latitude``
        coordinates. ``sst`` is indexed as (time, latitude, longitude).

    Returns
    -------
    tuple of numpy.ndarray
        ``(zs, ps, slp, t2m, qv2m, u10m, v10m, sst)`` where ``zs`` is ``z``
        divided by gravity, ``qv2m`` is the specific humidity derived from
        ``sp`` and ``d2m``, and ``sst`` is the gap-filled SST.

    Raises
    ------
    ValueError
        If a time step has missing SST values but no valid value to fill from.
    """
    ps = d.sp
    slp = d.msl
    t2m = d.t2m
    d2m = d.d2m
    zs = d.z.metpy.quantify() / earth_gravity
    qv2m = metpy.calc.specific_humidity_from_dewpoint(
        ps.metpy.quantify(), d2m.metpy.quantify()
    )
    u10m = d.u10
    v10m = d.v10

    lon2d, lat2d = np.meshgrid(d.longitude.values, d.latitude.values)
    sst_in = d.sst.values
    sst_out = sst_in.copy()
    nt = sst_in.shape[0]
    n_filled = 0
    for k in range(nt):
        sst = sst_in[k, :, :]
        # The mask is evaluated for every time step, so the fill stays correct
        # even if the set of missing points changes with time.
        missing = np.isnan(sst)
        if not missing.any():
            continue
        valid = ~missing
        if not valid.any():
            raise ValueError(f"SST has no valid values at time index {k}")
        # NearestNDInterpolator takes points as an (npoints, ndims) array. The
        # column order is arbitrary as long as the query points use the same
        # order; (lat, lon) is used for both here.
        interp = NearestNDInterpolator(
            np.column_stack((lat2d[valid], lon2d[valid])),
            sst[valid],
        )
        sst_out[k][missing] = interp(
            np.column_stack((lat2d[missing], lon2d[missing]))
        )
        n_filled += int(missing.sum())
    # Short summary instead of dumping every missing coordinate pair.
    print(f"Filled {n_filled} missing SST values across {nt} time steps")

    return (
        zs.values,
        ps.values,
        slp.values,
        t2m.values,
        qv2m.values,
        u10m.values,
        v10m.values,
        sst_out,
    )

In [None]:
# Model-level fields as NumPy arrays: height above the surface, pressure,
# temperature, specific humidity, and the u and v fields.
z, p, t, qv, u, v = extract_ml_era5(atm)

In [None]:
# Surface fields as NumPy arrays; `sst` is the gap-filled SST returned by the extractor.
zs, ps, slp, t2m, qv2m, u10m, v10m, sst = extract_sfc_era5(sfc)

In [None]:
# Display the surface dataset for inspection.
sfc

In [None]:
# Surface pressure (sp) at index 1 of the leading axis, contoured every 250 from 100000 to 102250.
sfc.sp[1].plot(levels=np.arange(100000, 102251, 250))
plt.show()

In [None]:
# Mean sea-level pressure (msl) at the same index and with the same levels as the sp plot.
sfc.msl[1].plot(levels=np.arange(100000, 102251, 250))
plt.show()

In [None]:
# Skin temperature (skt) at index 1 of the leading axis.
sfc.skt[1].plot()
plt.show()

In [None]:
# Difference between SST and skin temperature at index 1, on a fixed +/-0.5 colour range.
(sfc.sst[1] - sfc.skt[1]).plot(
    vmin=-0.5,
    vmax=0.5,
    extend='neither',
    cmap=cc.cm.coolwarm,
)
plt.show()

In [None]:
# Wrap the gap-filled SST array in a copy of the original DataArray so that it
# keeps the coordinates for plotting, then show index 1 of the leading axis.
sst_out = sfc.sst.copy(deep=True, data=sst)
sst_out[1].plot()
plt.show()

In [None]:
# Gap-filled SST minus the original SST at index 1 of the leading axis.
(sst_out[1] - sfc.sst[1]).plot()
plt.show()

In [None]:
# Surface z divided by gravity at index 1, contoured every 2 from -10 to 10.
(sfc.z[1].metpy.quantify() / earth_gravity).plot(levels=np.arange(-10, 11, 2))
plt.show()

In [None]:
# Shape of the model-level height array returned by extract_ml_era5.
z.shape

In [None]:
# Write the extracted fields to a Zarr store. mode='w' replaces any existing
# store at this path.
out = zarr.open(f'pinacles_zarr/era5_{datem1}-{datep1}.zarr', mode='w')


def _write_fields(store, fields, with_grid):
    """Write each named array to `store`, in order.

    When `with_grid` is true, the 2-D `lon`/`lat` grids are also written next to
    each field as ``longitude_<name>`` and ``latitude_<name>``.
    """
    for name, data in fields.items():
        store[name] = data
        if with_grid:
            store[f'longitude_{name}'] = lon
            store[f'latitude_{name}'] = lat


out['time'] = time

# Surface fields: axis 1 (latitude) is reversed, matching the reversed latitude
# vector that `lat` was built from.
_write_fields(
    out,
    {
        'SST': sst[:, ::-1, :],
        'PSFC': ps[:, ::-1, :],
    },
    with_grid=True,
)

# These surface fields are stored without their own longitude/latitude arrays.
_write_fields(
    out,
    {
        'T2m': t2m[:, ::-1, :],
        'QV2m': qv2m[:, ::-1, :],
        'u10m': u10m[:, ::-1, :],
        'v10m': v10m[:, ::-1, :],
    },
    with_grid=False,
)

# Model-level fields: both axis 1 (level) and axis 2 (latitude) are reversed.
# NOTE(review): assumes the model-level and surface files share the same
# horizontal grid, since `lon`/`lat` come from the surface dataset.
_write_fields(
    out,
    {
        'Height': z[:, ::-1, ::-1, :],
        'T': t[:, ::-1, ::-1, :],
        'QV': qv[:, ::-1, ::-1, :],
        'P': p[:, ::-1, ::-1, :],
        'U': u[:, ::-1, ::-1, :],
        'V': v[:, ::-1, ::-1, :],
    },
    with_grid=True,
)