In [18]:
import xarray as xr
import numpy as np
import atmos_physics as atmos_physics
import math
import dask.array as da

In [19]:
def get_train_test_split(training_split, longitudes, times):
    pure_split = int(longitudes*times*training_split)
    return math.floor(float(pure_split) / longitudes) * longitudes

In [20]:
#filepath = "/ocean/projects/ees220005p/gmooers/GM_Data/**00000[1]**.nc4"
filepath = "/ocean/projects/ees220005p/gmooers/GM_Data/**0000012[26]**.nc4"
savepath = "/ocean/projects/ees220005p/gmooers/GM_Data/training_data/"
n_z_input = 49
train_size=0.9

In [21]:
variables = xr.open_mfdataset(filepath)

In [22]:
x = variables.lon  # m
y = variables.lat  # m
z = variables.z  # m
p = variables.p  # hPa
rho = variables.rho  # kg/m^3
terra = variables.TERRA[:,:n_z_input]
SFC_PRES = variables.SFC_REFERENCE_P
SKT = variables.SKT
n_x = x.size
n_y = y.size
n_z = z.size
n_files = terra.shape[0]

In [23]:
cos_lat = np.zeros((n_files, n_y, n_x))
sin_lon = np.zeros((n_files, n_y, n_x))
cos_lat[:, :, :] = xr.ufuncs.cos(xr.ufuncs.radians(y.values[None, :, None]))
sin_lon[:, :, :] = xr.ufuncs.sin(xr.ufuncs.radians(x.values[None, None, :]))

In [24]:
adz = xr.zeros_like(z[:n_z_input]) 
dz = 0.5*(z[0]+z[1]) 
adz[0] = 1.

for k in range(1,n_z_input-1): # range doesn't include stopping number
    adz[k] = 0.5*(z[k+1]-z[k-1])/dz

adz[n_z_input-1] = (z[n_z_input-1]-z[n_z_input-2])/dz
rho_dz = adz*dz*rho

In [25]:
Tin = variables.TABS_SIGMA[:,:n_z_input] #originally just called tabs
Qrad = variables.QRAD_SIGMA[:,:n_z_input] / 86400
qt = (variables.QV_SIGMA[:,:n_z_input] + variables.QC_SIGMA[:,:n_z_input] + variables.QI_SIGMA[:,:n_z_input]) / 1000.0 # originally called qt
qp = variables.QP_SIGMA[:,:n_z_input] / 1000.0
q_auto_out = -1.0*variables.QP_MICRO_SIGMA[:,:n_z_input] / 1000.0
qpflux_z_coarse = variables.RHOQPW_SIGMA[:,:n_z_input] / 1000.0
T_adv_out = variables.T_FLUX_Z_OUT_SUBGRID_SIGMA[:,:n_z_input]     #originally tflux_z
q_adv_out = variables.Q_FLUX_Z_OUT_SUBGRID_SIGMA[:,:n_z_input] / 1000.0 #originally qtflux_z
qpflux_z = variables.QP_FLUX_Z_OUT_SUBGRID_SIGMA[:,:n_z_input] / 1000.0 
w = variables.W[:,:n_z_input]  # m/s
precip = variables.PREC_SIGMA[:,:n_z_input]  # precipitation flux kg/m^2/s
cloud_qt_flux = variables.SED_SIGMA[:,:n_z_input] / 1000.0
cloud_lat_heat_flux = variables.LSED_SIGMA[:,:n_z_input] 
qpflux_diff_coarse_z = variables.RHOQPS_SIGMA[:,:n_z_input] / 1000.0  # SGS qp flux kg/m^2/s Note that I need this variable
#q_auto_out = - dqp

In [26]:
a_pr = 1.0 / (atmos_physics.tprmax - atmos_physics.tprmin)
omp = np.maximum(0.0, np.minimum(1.0, (Tin - atmos_physics.tprmin) * a_pr))
fac = (atmos_physics.L + atmos_physics.Lf * (1.0 - omp)) / atmos_physics.cp

In [27]:
q_sed_fluxc_out = ((atmos_physics.L + atmos_physics.Lf) * cloud_qt_flux + cloud_lat_heat_flux) / atmos_physics.Lf
q_sed_fluxi_out = - (atmos_physics.L * cloud_qt_flux + cloud_lat_heat_flux) / atmos_physics.Lf
q_sed_flux_tot  = cloud_qt_flux

In [28]:
dfac_dz = np.zeros((n_files, n_z_input, n_y, n_x))
for k in range(n_z_input - 1):
    kb = max(0, k - 1)
    dfac_dz[:, k, :, :] = (fac[:, k + 1, :, :] - fac[:, k, :, :]) / rho_dz[k, :] * rho[:, k]

In [29]:
Tout = dfac_dz * (qpflux_z_coarse + qpflux_diff_coarse_z - precip) / rho

In [30]:
split_index = get_train_test_split(train_size, n_x, n_files)

In [31]:
my_dict_train = {}
my_dict_test = {}
my_weight_dict = {}

In [32]:
new_chunks = {'time':100, 'z': 49, 'lat': 100, 'lon': 100}
Tin =  da.from_array(Tin, chunks=new_chunks)
Tin = Tin.reshape(n_z_input, n_y, n_files * n_x)
my_dict_train["Tin"] = (("z","lat","sample"), Tin[...,:split_index])
my_dict_test["Tin"] = (("z","lat","sample"), Tin[...,split_index:])

qin = da.from_array(qt, chunks=new_chunks) 
qin = qin.reshape(n_z_input, n_y, n_files * n_x) 
my_dict_train["qin"] = (("z","lat","sample"), qin[...,:split_index])
my_dict_test["qin"] = (("z","lat","sample"), qin[...,split_index:])
    

Tout = da.from_array(Tout, chunks=new_chunks)  
Tout = Tout.reshape(n_z_input, n_y, n_files * n_x)
my_dict_train["Tout"] = (("z","lat","sample"), Tout[...,:split_index])
my_dict_test["Tout"] = (("z","lat","sample"), Tout[...,split_index:])

T_adv_out = da.from_array(T_adv_out, chunks=new_chunks)  
T_adv_out = T_adv_out.reshape(n_z_input, n_y, n_files * n_x) 
my_dict_train["T_adv_out"] = (("z","lat","sample"), T_adv_out[...,:split_index])
my_dict_test["T_adv_out"] = (("z","lat","sample"), T_adv_out[...,split_index:])
        
q_adv_out = da.from_array(q_adv_out, chunks=new_chunks)  
q_adv_out = q_adv_out.reshape(n_z_input, n_y, n_files * n_x) 
my_dict_train["q_adv_out"] = (("z","lat","sample"), q_adv_out[...,:split_index])
my_dict_test["q_adv_out"] = (("z","lat","sample"), q_adv_out[...,split_index:])
        
q_auto_out = da.from_array(q_auto_out, chunks=new_chunks) 
q_auto_out = q_auto_out.reshape(n_z_input, n_y, n_files * n_x) 
my_dict_train["q_auto_out"] = (("z","lat","sample"), q_auto_out[...,:split_index])
my_dict_test["q_auto_out"] = (("z","lat","sample"), q_auto_out[...,split_index:])

q_sed_flux_tot = da.from_array(q_sed_flux_tot, chunks=new_chunks)
q_sed_flux_tot = q_sed_flux_tot.reshape(n_z_input, n_y, n_files * n_x) 
my_dict_train["q_sed_flux_tot"] = (("z","lat","sample"), q_sed_flux_tot[...,:split_index])
my_dict_test["q_sed_flux_tot"] = (("z","lat","sample"), q_sed_flux_tot[...,split_index:])
    

terra = da.from_array(terra, chunks=new_chunks) 
terra = terra.reshape(n_z_input, n_y, n_files * n_x)
my_dict_train["terra"] = (("z","lat","sample"), terra[...,:split_index])
my_dict_test["terra"] = (("z","lat","sample"), terra[...,split_index:])
    
small_chunks = {'time':100, 'lat': 100, 'lon': 100}
sfc_pres = da.from_array(SFC_PRES, chunks=small_chunks)
sfc_pres = sfc_pres.reshape(n_y, n_files * n_x) 
my_dict_train["sfc_pres"] = (("lat","sample"), sfc_pres[...,:split_index])
my_dict_test["sfc_pres"] = (("lat","sample"), sfc_pres[...,split_index:])
    
skt = da.from_array(SKT, chunks=small_chunks) 
skt = skt.reshape(n_y, n_files * n_x)
my_dict_train["skt"] = (("lat","sample"), skt[...,:split_index])
my_dict_test["skt"] = (("lat","sample"), skt[...,split_index:])
    
cos_lat = np.expand_dims(cos_lat, axis=0)
cos_lat = np.moveaxis(cos_lat, 2, 3)
cos_lat = np.reshape(cos_lat, (1, n_y, -1)).squeeze()
my_dict_train["cos_lat"] = (("lat","sample"), cos_lat[...,:split_index])
my_dict_test["cos_lat"] = (("lat","sample"), cos_lat[...,split_index:])

sin_lon = np.expand_dims(sin_lon, axis=0)
sin_lon = np.moveaxis(sin_lon, 2, 3)
sin_lon = np.reshape(sin_lon, (1, n_y, -1)).squeeze()
my_dict_train["sin_lon"] = (("lat","sample"), sin_lon[...,:split_index])
my_dict_test["sin_lon"] = (("lat","sample"), sin_lon[...,split_index:])


In [44]:
def calculate_renormalization_factors(Tout, tflux_z, qtflux_z, qmic, qsed, rho_dz):
    '''Renormalize outputs assuming flux form renormalization for T_rad+rest, Tadv, qadv, qmic, qsed.'''

    # Calculate the differences along the vertical axis (assumed to be the first dimension here)
    zTout = -(da.diff(tflux_z, axis=0) / rho_dz[:, None, None])
    zqout = -(da.diff(qtflux_z, axis=0) / rho_dz[:, None, None])
    zqsed = -(da.diff(qsed, axis=0) / rho_dz[:, None, None])

    # Handle the top boundary condition where the flux is defined to be zero at the top half-level
    zTout = da.concatenate([zTout, -tflux_z[-1, :, :][None, :, :] / rho_dz[-1]], axis=0)
    zqout = da.concatenate([zqout, -qtflux_z[-1, :, :][None, :, :] / rho_dz[-1]], axis=0)
    zqsed = da.concatenate([zqsed, -qsed[-1, :, :][None, :, :] / rho_dz[-1]], axis=0)

    # Rescale humidity tendencies
    L_cp_ratio = atmos_physics.L / atmos_physics.cp
    zqsed = zqsed * L_cp_ratio
    qmic = qmic * L_cp_ratio
    zqout = zqout * L_cp_ratio

    # Compute standard deviations
    std1 = da.std(Tout, axis=(1, 2))
    std2 = da.std(zTout, axis=(1, 2))
    std3 = da.std(zqout, axis=(1, 2))
    std4 = da.std(qmic, axis=(1, 2))
    std5 = da.std(zqsed, axis=(1, 2))

    # Normalize the standard deviations
    std_min = da.min([std1, std2, std3, std4, std5])
    std_factors = da.array([std1, std2, std3, std4, std5]) / std_min

    return std_factors.compute()

In [47]:
T_adv_out

dask.array<reshape, shape=(49, 426, 1536), dtype=float32, chunksize=(49, 426, 1536)>

In [48]:
rho_dz[:,0]

<xarray.DataArray (z: 49)>
dask.array<shape=(49,), dtype=float64, chunksize=(49,)>
Coordinates:
  * z        (z) float64 20.0 61.2 104.9 151.2 ... 1.247e+04 1.297e+04 1.347e+04
    time     float64 21.42

In [45]:
norm_list = calculate_renormalization_factors(Tout[...,:split_index],
                                                          T_adv_out[...,:split_index],
                                                          q_adv_out[...,:split_index],
                                                          q_auto_out[...,:split_index],
                                                          q_sed_flux_tot[...,split_index:],
                                                          rho_dz[:,0]) 

IndexError: too many indices

In [None]:
my_weight_dict['norms'] = (("norm"), norm_list)

In [None]:
ds_train = xr.Dataset(
    my_dict_train,
    coords={
        "z": z[:n_z_input].values,
        "lat": y.values,
        "sample": np.arange(0,n_files*len(x.values), 1)[:split_index],
    },
)

In [None]:
ds_test = xr.Dataset(
    my_dict_test,
    coords={
        "z": z[:n_z_input].values,
        "lat": y.values,
        "sample": np.arange(0,n_files*len(x.values), 1)[split_index:],
    },
)

In [None]:
ds_train = xr.Dataset(
    my_weight_dict,
    coords={
        "norm":np.arange(1,6,1),
    },
)

In [None]:
ds_train.to_netcdf(savepath + "_train.nc")
ds_test.to_netcdf(savepath + "_test.nc")