In [None]:
import xarray as xr
import numpy as np
import zarr
from datetime import datetime
import metpy
from metpy.constants import water_heat_vaporization, dry_air_gas_constant, earth_gravity
from scipy.interpolate import interpn, NearestNDInterpolator

In [None]:
from matplotlib import pyplot as plt

In [None]:
# Analysis domain: a square lat/lon box centred on the ENA site coordinates.
ena_lat = 39.0916  # deg N
ena_lon = -28.0257  # deg E
DOMAIN_HALF_WIDTH = 10.0  # deg; applied to both latitude and longitude
lat1, lat2 = ena_lat - DOMAIN_HALF_WIDTH, ena_lat + DOMAIN_HALF_WIDTH
lon1, lon2 = ena_lon - DOMAIN_HALF_WIDTH, ena_lon + DOMAIN_HALF_WIDTH
print(lat1, lat2)
print(lon1, lon2)

In [None]:
# Case day; the MERRA-2 files for the day before, of, and after are read.
# Alternative case: '2018-11-21'.
datestring = '2016-10-22'
casedate = np.datetime64(datestring)
# YYYYMMDD stamps used in the MERRA-2 file names.
datem1, date0, datep1 = (
    (casedate + np.timedelta64(offset, "D")).astype(object).strftime("%Y%m%d")
    for offset in (-1, 0, 1)
)

In [None]:
# extract MERRA2 surface fluxes and large-scale forcings
forc_dir = "forcing"


def _merra2_paths(collection):
    """Daily MERRA-2 files of `collection` for the day before, of, and after the case."""
    return [
        f"{forc_dir}/merra2/data/MERRA2_400.{collection}.{day}.nc4"
        for day in (datem1, date0, datep1)
    ]


atm = xr.open_mfdataset(_merra2_paths("inst3_3d_asm_Nv"))   # model-level fields
sfc = xr.open_mfdataset(_merra2_paths("tavg1_2d_flx_Nx"))   # surface fluxes
sfc2 = xr.open_mfdataset(_merra2_paths("tavg1_2d_slv_Nx"))  # single-level fields

In [None]:
# Time stamps of the model-level data; the surface fields are interpolated onto these.
time = atm.time.values

In [None]:
# 2-D longitude/latitude grids of the domain, shaped (nlat, nlon); written
# next to each field in the zarr output.
lon, lat = np.meshgrid(
    sfc.lon.sel(lon=slice(lon1, lon2)).values,
    sfc.lat.sel(lat=slice(lat1, lat2)).values,
)

In [None]:
# Quick look at the surface height PHIS / g at one model-level time.
# Uses the same gravity constant as extract_ml_merra2 and a dedicated name so
# the temperature array `t` is never confused with a plot timestamp.
plot_time = atm.time[8]
(atm.PHIS.loc[plot_time, lat1:lat2, lon1:lon2] / earth_gravity.magnitude).plot()
plt.show()


In [None]:
def extract_ml_merra2(d, lat1, lat2, lon1, lon2):
    """Cut the MERRA-2 model-level fields of `d` out over a lat/lon box.

    Parameters
    ----------
    d : xarray.Dataset
        Model-level MERRA-2 data providing PHIS, PS, SLP, H, PL, T, QV, U, V.
    lat1, lat2, lon1, lon2 : float
        Inclusive label bounds of the box, applied with ``.loc`` to the two
        trailing (lat, lon) axes of every variable.

    Returns
    -------
    tuple of numpy.ndarray
        ``(zs, ps, slp, z, p, t, qv, u, v)``. ``zs = PHIS / g`` is the surface
        height. ``z = H - zs`` is the height of each level above that surface.
        The 3-D fields (PHIS, PS, SLP) keep their leading time axis. The 4-D
        fields keep their two leading axes (the second presumably model level).
    """
    box = (slice(lat1, lat2), slice(lon1, lon2))

    def window(var):
        # Keep every leading axis whole; slice only the trailing lat/lon axes.
        leading = (slice(None),) * (var.ndim - 2)
        return var.loc[leading + box].values

    surface_height = window(d.PHIS) / earth_gravity.magnitude
    # Broadcast the 3-D surface height over the level axis of H.
    height_above_sfc = window(d.H) - surface_height[:, np.newaxis, :, :]

    return (
        surface_height,
        window(d.PS),
        window(d.SLP),
        height_above_sfc,
        window(d.PL),
        window(d.T),
        window(d.QV),
        window(d.U),
        window(d.V),
    )

In [None]:
# Model-level fields over the domain. NOTE: `t` is the temperature array
# and is written to zarr at the end, so no later cell may rebind it.
(zs, ps, slp, z, p, t, qv, u, v) = extract_ml_merra2(
    atm, lat1=lat1, lat2=lat2, lon1=lon1, lon2=lon2
)

In [None]:
# Sanity check of the time interpolation used in extract_sfc_merra2:
# domain-mean T2M on its native time axis vs. on the model-level `time` axis.
# A separate name keeps a re-run of this cell from clobbering the `t2m`
# array that extract_sfc_merra2 returns and that is written to zarr below.
t2m_check = sfc2.T2M.loc[:, lat1:lat2, lon1:lon2].interp(
    time=time, method='linear', kwargs={"fill_value": "extrapolate"}
)
fig, ax = plt.subplots()
sfc2.T2M.loc[:, lat1:lat2, lon1:lon2].mean(axis=(1, 2)).plot(ax=ax, label='original')
t2m_check.mean(axis=(1, 2)).plot(ax=ax, label='interpolated')
ax.legend()
plt.show()

In [None]:
def extract_sfc_merra2(d, atm_time, zs, lat1, lat2, lon1, lon2):
    """Extract MERRA-2 near-surface fields on the model-level time axis.

    Each field is cut to the lat/lon box with ``.loc``. It is then linearly
    interpolated in time onto ``atm_time``, extrapolating past the ends of the
    source record.

    The skin temperature TS is turned into an SST-like field for PINACLES.
    Every grid point with positive surface height (``zs[0] > 0``) takes the TS
    of the nearest point, in lat/lon degrees, that has ``zs[0] <= 0``. Points
    where ``zs[0]`` is NaN belong to neither set and are left untouched.

    Parameters
    ----------
    d : xarray.Dataset
        MERRA-2 single-level data providing T2M, QV2M, U10M, V10M and TS.
    atm_time : array-like
        Target times, passed to ``interp(time=...)``.
    zs : numpy.ndarray
        Surface height. Only ``zs[0]`` is used, as an (nlat, nlon) land/ocean
        mask on the same grid as the box.
    lat1, lat2, lon1, lon2 : float
        Inclusive label bounds of the box.

    Returns
    -------
    tuple
        ``(t2m, qv2m, u10m, v10m, sst, ts1, ts2)``.
        ``t2m``, ``qv2m``, ``u10m`` and ``v10m`` are numpy arrays of the
        interpolated fields. ``sst`` is the filled skin temperature.
        ``ts1`` is the interpolated TS object. ``ts2`` is a copy of ``ts1``
        holding ``sst``.

    Raises
    ------
    ValueError
        If the box has points with ``zs[0] > 0`` but none with
        ``zs[0] <= 0`` to fill them from.
    """

    def on_atm_time(name):
        # Subset one variable to the box and move it onto atm_time.
        return getattr(d, name).loc[:, lat1:lat2, lon1:lon2].interp(
            time=atm_time, method="linear", kwargs={"fill_value": "extrapolate"}
        )

    t2m = on_atm_time("T2M")
    qv2m = on_atm_time("QV2M")
    u10m = on_atm_time("U10M")
    v10m = on_atm_time("V10M")
    ts = on_atm_time("TS")
    ts1 = ts.copy(deep=True)

    lon, lat = t2m.lon.values, t2m.lat.values
    sst_in = ts.values
    sst_out = sst_in.copy()

    # The land/ocean split comes from the first time step of zs only.
    zs0 = zs[0, :, :]
    land_y, land_x = np.nonzero(zs0 > 0.0)
    ocean_y, ocean_x = np.nonzero(zs0 <= 0.0)
    print(list(zip(lat[land_y], lon[land_x])))
    print(zs0[land_y, land_x])

    if land_y.size > 0:
        if ocean_y.size == 0:
            raise ValueError(
                "extract_sfc_merra2: the box has no points with zs <= 0 "
                "to fill the skin temperature of land points from"
            )
        # Tree points and query points are both (lat, lon) pairs, so the
        # coordinate order is consistent. One interpolator covers every
        # time step. Values have shape (n_ocean, nt); the result has shape
        # (n_land, nt).
        nearest_ocean = NearestNDInterpolator(
            np.column_stack((lat[ocean_y], lon[ocean_x])),
            sst_in[:, ocean_y, ocean_x].T,
        )
        sst_out[:, land_y, land_x] = nearest_ocean(
            np.column_stack((lat[land_y], lon[land_x]))
        ).T
    ts2 = ts1.copy(deep=True, data=sst_out)

    return (
        t2m.values,
        qv2m.values,
        u10m.values,
        v10m.values,
        sst_out,
        ts1,
        ts2,
    )

In [None]:
# Near-surface fields on the model-level time axis, plus the raw (ts1) and
# land-filled (ts2) skin temperature for the diagnostic plots below.
(t2m, qv2m, u10m, v10m, sst, ts1, ts2) = extract_sfc_merra2(
    sfc2, atm_time=time, zs=zs, lat1=lat1, lat2=lat2, lon1=lon1, lon2=lon2
)

In [None]:
# Difference between TS (slv collection) and TSH (flx collection) at one
# hourly time. The timestamp deliberately does NOT reuse the name `t`:
# that holds the model-level temperature array, which the zarr-output cell
# writes with `t[:, ::-1, :, :]` and which would fail on a scalar timestamp.
plot_time = sfc.time[23]
(sfc2.TS.loc[plot_time, lat1:lat2, lon1:lon2] - sfc.TSH.loc[plot_time, lat1:lat2, lon1:lon2]).plot()
plt.show()

In [None]:
# Change made by the land fill at time index 10 (non-zero only where the
# skin temperature was replaced).
fill_change = ts2[10] - ts1[10]
fill_change.plot()
plt.show()

In [None]:
# Time-interpolated MERRA-2 skin temperature at time index 1, before the land fill.
ts1[1].plot()
plt.show()

In [None]:
# Same time index after the land fill: the field supplied as SST.
ts2[1].plot()
plt.show()

In [None]:
# Write the PINACLES forcing store. Every field is stored as a plain array.
# Fields with their own horizontal grid also get `longitude_<name>` and
# `latitude_<name>` companions holding the 2-D `lon`/`lat` grids.
surface_fields = {'SST': sst, 'PSFC': ps}
near_surface_fields = {'T2m': t2m, 'QV2m': qv2m, 'u10m': u10m, 'v10m': v10m}
level_fields = {'Height': z, 'T': t, 'QV': qv, 'P': p, 'U': u, 'V': v}

# Validate ranks BEFORE zarr.open(mode='w'), which discards any existing store.
# A name accidentally rebound by another cell (e.g. `t` reused for a plot
# timestamp) then fails here with a clear message instead of mid-write.
for fields, expected_ndim in (
    (surface_fields, 3),
    (near_surface_fields, 3),
    (level_fields, 4),
):
    for name, field in fields.items():
        if np.ndim(field) != expected_ndim:
            raise ValueError(
                f"{name}: expected a {expected_ndim}-D array, got ndim={np.ndim(field)}; "
                "was the variable overwritten by another cell?"
            )

out = zarr.open(f'pinacles_zarr/merra2_{datem1}-{datep1}.zarr', mode='w')

out['time'] = time

for name, field in surface_fields.items():
    out[name] = field[:, :, :]
    out[f'longitude_{name}'] = lon
    out[f'latitude_{name}'] = lat

# Near-surface fields are written without separate coordinate arrays.
for name, field in near_surface_fields.items():
    out[name] = field[:, :, :]

# Model-level fields: axis 1 (the level axis) is reversed before writing.
# NOTE(review): presumably this puts the lowest model level first for
# PINACLES; confirm against the reader.
for name, field in level_fields.items():
    out[name] = field[:, ::-1, :, :]
    out[f'longitude_{name}'] = lon
    out[f'latitude_{name}'] = lat