In [None]:
from netCDF4 import Dataset
import numpy as np
import matplotlib.pyplot as plt
from wrf import interplevel, getvar, destagger
# Physical constants shared by the cells below.
RD = 287.04      # gas constant for dry air, J kg^-1 K^-1
CP = 1005.7      # specific heat of dry air at constant pressure, J kg^-1 K^-1
P0 = 1000.0      # Exner reference pressure, in the same units (hPa) as the pressure fields it divides
TR = 300.0       # reference temperature (K) weighting the temperature term of the energy norm
LV = 2.501e6     # latent heat of vaporization, J kg^-1 (not used by the cells shown here)
EPS = 1.0        # presumably a weight for a moisture energy term (not used by the cells shown here)

In [None]:
# Horizontal grid shared by the WRF and COAMPS domains (spacing in metres,
# consistent with the wavelength being converted km -> m later on).
dx = dy = 30000.
nx, ny = 500, 240
# Number of vertical levels in each model.
nz_coamps, nz_wrf = 45, 100
shape_coamps = [nz_coamps, ny, nx]
shape_wrf = [nz_wrf, ny, nx]

# wrf_dir is the WRF input file itself (it is opened directly with Dataset);
# coamps_dir is a directory prefix that the COAMPS record filenames are appended to.
wrf_dir = '/p/work1/lloveras/adj_30km/in_files/wrfin_adj_wave'
coamps_dir = '/p/work1/lloveras/adj_30km/pert_0h/'

# True: write the energy-matched, sine-modulated perturbation instead of the
# plain interpolated one.
wave = True

In [None]:
# Open the WRF input file read/write: the base-state fields are read here and
# the perturbed fields are written back into it in a later cell.
ncfile = Dataset(wrf_dir,'r+')
wrf_u = getvar(ncfile,'U')
wrf_v = getvar(ncfile,'V')
wrf_p = getvar(ncfile,'P')
wrf_th = getvar(ncfile,'T')   # perturbation theta; 300 K is added back later
wrf_qv = getvar(ncfile,'QVAPOR')
wrf_full_p = getvar(ncfile,'pressure')   # wrf-python diagnostic: full pressure (hPa by default)


def _read_coamps(prefix):
    """Read one flat big-endian float32 COAMPS field record.

    prefix -- the field tag that begins the filename (e.g. 'aaauu1').
    Returns the raw 1-D array; reshaping to (nz, ny, nx) happens in the next cell.
    Raises ValueError if the record size does not match the configured grid,
    so a wrong or truncated file is reported by name instead of at reshape time.
    """
    # Every field file shares the same level/grid/date-time suffix.
    suffix = '_sig_020000_000010_1a0500x0240_2021060100_00000000_fcstfld'
    field = np.fromfile(coamps_dir + prefix + suffix, dtype='>f4')
    expected = nz_coamps*ny*nx
    if field.size != expected:
        raise ValueError('%s: read %d values, expected %d' % (prefix, field.size, expected))
    return field


coamps_u = _read_coamps('aaauu1')
coamps_v = _read_coamps('aaavv1')
coamps_ex = _read_coamps('aaapp1')
coamps_th = _read_coamps('aaath1')
coamps_qv = _read_coamps('aaaqv1')
coamps_full_p = _read_coamps('ttlprs')


In [None]:
def _bottom_up(field):
    """Reshape a flat COAMPS record to shape_coamps and reverse the vertical axis.

    The flip presumably converts COAMPS top-down level ordering to the bottom-up
    ordering WRF uses (TODO confirm). Arrays that are already 3-D are returned
    unchanged, so re-running this cell cannot flip them back the other way.
    """
    if field.ndim != 1:
        return field
    return np.flip(np.reshape(field, shape_coamps), axis=0)


coamps_u = _bottom_up(coamps_u)
coamps_v = _bottom_up(coamps_v)
coamps_ex = _bottom_up(coamps_ex)
coamps_th = _bottom_up(coamps_th)
coamps_qv = _bottom_up(coamps_qv)
coamps_full_p = _bottom_up(coamps_full_p)

# Base-state Exner function plus the COAMPS 'pp1' record (assumed to be the
# Exner perturbation -- TODO confirm).
coamps_pert_ex = (coamps_full_p/P0)**(RD/CP) + coamps_ex
# Pressure implied by that Exner function (same units as coamps_full_p and P0).
coamps_pert_p = P0*coamps_pert_ex**(CP/RD)
# Pressure perturbation; presumably the factor 100 converts hPa to Pa
# (P0 = 1000 implies hPa inputs).
coamps_p = (coamps_pert_p - coamps_full_p)*100.

In [None]:
# Peak absolute value of each COAMPS perturbation field; qv is printed x1000
# (presumably kg/kg -> g/kg).
umax, vmax, thmax, pmax, qvmax = (
    np.abs(field).max()
    for field in (coamps_u, coamps_v, coamps_th, coamps_p, coamps_qv)
)
print(umax, vmax, thmax, pmax, qvmax*1000)

In [None]:
# Interpolate each COAMPS perturbation from COAMPS pressure onto every WRF model
# level, using WRF's full pressure on that level as the 2-D target surface.
# U and V get one extra column/row to match WRF's staggered dimensions.
u_pert = np.zeros([nz_wrf, ny, nx + 1])
v_pert = np.zeros([nz_wrf, ny + 1, nx])
p_pert = np.zeros(shape_wrf)
th_pert = np.zeros(shape_wrf)
qv_pert = np.zeros(shape_wrf)

for lev in range(nz_wrf):
    target_p = wrf_full_p[lev, :, :]
    u_pert[lev, :, :-1] = interplevel(coamps_u, coamps_full_p, target_p)
    v_pert[lev, :-1, :] = interplevel(coamps_v, coamps_full_p, target_p)
    th_pert[lev, :, :] = interplevel(coamps_th, coamps_full_p, target_p)
    p_pert[lev, :, :] = interplevel(coamps_p, coamps_full_p, target_p)
    qv_pert[lev, :, :] = interplevel(coamps_qv, coamps_full_p, target_p)

# The extra staggered column/row has no COAMPS counterpart: copy its neighbour.
u_pert[:, :, -1] = u_pert[:, :, -2]
v_pert[:, -1, :] = v_pert[:, -2, :]

# Targets outside the COAMPS column come back missing (NaN, presumably);
# treat them as zero perturbation.
u_pert, v_pert, th_pert, p_pert, qv_pert = (
    np.nan_to_num(field) for field in (u_pert, v_pert, th_pert, p_pert, qv_pert)
)


In [None]:
# Dry total-energy norm of the interpolated perturbation.
# WRF 'T' is perturbation potential temperature about 300 K.
wrf_full_th = np.array(wrf_th) + 300
wrf_full_p = np.array(wrf_full_p)   # full pressure, hPa (wrf-python default)
# p_pert is in Pa (coamps_p was scaled by 100 and is added to WRF 'P' in Pa),
# whereas wrf_full_p is in hPa: convert before combining.
wrf_fullpert_p = wrf_full_p + p_pert/100.
wrf_fullpert_th = wrf_full_th + th_pert

# Temperature from theta via the Exner function (P0 in hPa, matching the pressures).
wrf_full_tk = wrf_full_th*(wrf_full_p/P0)**(RD/CP)
wrf_fullpert_tk = wrf_fullpert_th*(wrf_fullpert_p/P0)**(RD/CP)
tk_pert = wrf_fullpert_tk - wrf_full_tk

# DTE = 0.5 * sum over all grid points of u'^2 + v'^2 + (CP/TR) T'^2.
# No moisture term is included.
dte1_adj = np.sum(u_pert**2.0)
dte2_adj = np.sum(v_pert**2.0)
dte3_adj = np.sum(CP/TR*tk_pert**2.0)
dte_adj = 0.5*(dte1_adj + dte2_adj + dte3_adj)


In [None]:
# Horizontal wavelength of the sinusoidal modulation: km, then converted to m.
l = 210.
lm = l*1000.

# Separable modulation sin(2*pi*x/lm) * sin(2*pi*y/lm), with x = i*dx and
# y = j*dy counted from the first grid point. The 1-D factors are built one
# point longer than the mass grid so they also cover U's extra staggered column
# and V's extra staggered row, instead of leaving those at zero.
wave_x = np.sin(2*np.pi*np.arange(nx + 1)*dx/lm)
wave_y = np.sin(2*np.pi*np.arange(ny + 1)*dy/lm)

mask_mass = np.outer(wave_y[:ny], wave_x[:nx])   # (ny, nx)
mask_u = np.outer(wave_y[:ny], wave_x)           # (ny, nx+1)
mask_v = np.outer(wave_y, wave_x[:nx])           # (ny+1, nx)

# Broadcasting applies the same horizontal mask on every vertical level.
u_pert_wave = u_pert*mask_u
v_pert_wave = v_pert*mask_v
p_pert_wave = p_pert*mask_mass
th_pert_wave = th_pert*mask_mass
qv_pert_wave = qv_pert*mask_mass
        

In [None]:
# Dry total-energy norm of the sine-modulated perturbation.
# p_pert_wave is in Pa while wrf_full_p is in hPa: convert before combining.
wrf_wavepert_p = wrf_full_p + p_pert_wave/100.
wrf_wavepert_th = wrf_full_th + th_pert_wave

# Temperature from theta via the Exner function (P0 in hPa, matching the pressures).
wrf_wavepert_tk = wrf_wavepert_th*(wrf_wavepert_p/P0)**(RD/CP)
tk_pert_wave = wrf_wavepert_tk - wrf_full_tk

# DTE = 0.5 * sum over all grid points of u'^2 + v'^2 + (CP/TR) T'^2.
dte1_wave = np.sum(u_pert_wave**2.0)
dte2_wave = np.sum(v_pert_wave**2.0)
dte3_wave = np.sum(CP/TR*tk_pert_wave**2.0)
dte_wave = 0.5*(dte1_wave + dte2_wave + dte3_wave)


In [None]:
# Rescale the modulated perturbation by sqrt(dte_adj / dte_wave). This matches
# the u and v energy terms exactly; the temperature term only approximately,
# because T' is not linear in th' and p'.
# NOTE(review): the fields are rebound in place, so running this cell a second
# time applies the factor again.
fac = dte_adj/dte_wave
s = np.sqrt(fac)
u_pert_wave, v_pert_wave, p_pert_wave, th_pert_wave, qv_pert_wave = (
    s*field
    for field in (u_pert_wave, v_pert_wave, p_pert_wave, th_pert_wave, qv_pert_wave)
)

In [None]:
# Pick which perturbation to write without rebinding u_pert etc., so that
# re-running earlier cells still sees the plain interpolated perturbation.
if wave:
    u_out, v_out, p_out, th_out, qv_out = (
        u_pert_wave, v_pert_wave, p_pert_wave, th_pert_wave, qv_pert_wave)
else:
    u_out, v_out, p_out, th_out, qv_out = u_pert, v_pert, p_pert, th_pert, qv_pert

# Add the perturbation to the base state read at load time and write it into
# the first time index of the WRF input file (all in the file's native units,
# e.g. 'P' and p_out both in Pa).
ncfile.variables['U'][0,:,:,:] = wrf_u + u_out
ncfile.variables['V'][0,:,:,:] = wrf_v + v_out
ncfile.variables['T'][0,:,:,:] = wrf_th + th_out
ncfile.variables['P'][0,:,:,:] = wrf_p + p_out
# NOTE(review): nothing keeps QVAPOR non-negative after adding qv_out --
# confirm whether negative moisture should be clipped.
ncfile.variables['QVAPOR'][0,:,:,:] = wrf_qv + qv_out


In [None]:
# Close the dataset opened with 'r+' so the fields written above are flushed to the file.
ncfile.close()
