In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from glob import glob
import os
from datetime import datetime, timedelta
import pandas as pd
from multiprocessing import Pool
import xarray as xr
import numpy as np
import dask
from functools import partial

# Must do this to ensure xESMF (imported below) can find the ESMF library.
# `os.system('export ...')` would only set the variable in a short-lived subshell,
# so set it on this process's environment before `import xesmf`.
os.environ['ESMFMKFILE'] = '/glade/work/klesinger/conda-envs/tf212gpu/lib/esmf.mk'

import xesmf as xe


In [2]:
def xarray_varname(file: "xr.Dataset") -> str:
    """Return the name of the first data variable in ``file``.

    Intended for single-variable datasets. If there are several variables, the first one
    in insertion order is returned, which may not be the one you want.

    Parameters
    ----------
    file : xarray.Dataset (or any mapping with ``keys()``)
        Dataset whose variable name is wanted.

    Raises
    ------
    IndexError
        If ``file`` has no data variables.
    """
    try:
        return next(iter(file.keys()))
    except StopIteration:
        raise IndexError('xarray_varname: dataset has no data variables') from None

In [1]:
# Root folders: raw ECMWF downloads (input) and processed reforecast files (output).
# These are plain module-level names; a `global` statement is unnecessary at this level.
dir_base = '/glade/derecho/scratch/klesinger/FD_RZSM_deep_learning/Data/raw_downloads/ECMWF'
save_dir_base = '/glade/derecho/scratch/klesinger/FD_RZSM_deep_learning/Data/reforecast/ECMWF'

# Notes

## We need to open each file using xarray and cfgrib

## Then we need to regrid the data and save the file (originally planned with CDO operators on the saved file; the code below regrids with xESMF before saving)

## Then we need to convert RZSM from kg/m3 to m3/m3 by dividing by 1000 (the density of water, 1000 kg/m3)

In [3]:
def return_var_info(var, region_name, source_root=None, output_root=None):
    """List the ECMWF control/perturbed files for ``var`` and create the region's output folder.

    Parameters
    ----------
    var : str
        Sub-folder of the raw-download root that holds the files (e.g. ``'soilw_bgrnd'``).
    region_name : str
        Region sub-folder of the output root (e.g. ``'CONUS'``).
    source_root : str, optional
        Raw-download root; defaults to the module-level ``dir_base``.
    output_root : str, optional
        Output root; defaults to the module-level ``save_dir_base``.

    Returns
    -------
    control_ : list of str
        Sorted paths matching ``*control*nc``.
    perturbed_ : list of str
        Sorted paths matching ``*perturbed*nc``.
    date_list : list of str
        Third '_'-separated token of each control file name (the init date).
    save_dir : str
        ``{output_root}/{region_name}/{var}``; created if it does not exist.

    ``control_``, ``perturbed_`` and ``save_dir`` are also stored as module globals
    (kept for backward compatibility with interactive use).
    """
    global control_, perturbed_, save_dir

    if source_root is None:
        source_root = dir_base
    if output_root is None:
        output_root = save_dir_base

    file_source = f'{source_root}/{var}'
    control_ = sorted(glob(f'{file_source}/*control*nc'))
    perturbed_ = sorted(glob(f'{file_source}/*perturbed*nc'))

    # Init date = third '_'-separated field of the file name.
    date_list = [os.path.basename(path).split('_')[2] for path in control_]

    save_dir = f'{output_root}/{region_name}/{var}'
    os.makedirs(save_dir, exist_ok=True)

    return control_, perturbed_, date_list, save_dir

def return_control_info(var, source_root=None):
    """Return the sorted ECMWF control-file paths (``*control*nc``) for ``var``.

    Parameters
    ----------
    var : str
        Sub-folder of the raw-download root holding the files.
    source_root : str, optional
        Raw-download root; defaults to the module-level ``dir_base``.

    The list is also stored in the module global ``control_``.
    """
    global control_

    root = dir_base if source_root is None else source_root
    control_ = sorted(glob(f'{root}/{var}/*control*nc'))
    return control_


In [None]:
# return_control_info(var)

In [4]:
# Get our GEFSv12 initialized dates
def return_init_info(region_name):
    """Load the GEFSv12 soil-moisture template for ``region_name`` and list its init dates.

    We need to wait until GEFSv12 data has been processed to get the correct coordinates.

    Parameters
    ----------
    region_name : str
        ``'CONUS'``, ``'australia'`` and ``'china'`` use their dedicated GEFSv12 folders.
        Any other name falls back to
        ``.../reforecast/GEFSv12/{region_name}/soilw_bgrnd``.

    Returns
    -------
    template : xarray.Dataset
        Entry at index 10 of the sorted listing, fully loaded into memory (file closed).
    lon, lat : numpy.ndarray
        ``template.X`` and ``template.Y`` values.
    gefs_dates : list of str
        Date token (after the last '_', before the first '.') of every entry except the
        first, which is a folder.

    Raises
    ------
    FileNotFoundError
        If the listing has fewer than 11 entries, so there is no template at index 10.

    ``gefs_init``, ``template``, ``lon`` and ``lat`` are also stored as module globals.
    """
    global gefs_init, template, lon, lat

    region_patterns = {
        'CONUS': '/glade/work/klesinger/FD_RZSM_deep_learning/Data/GEFSv12_reforecast/soilw_bgrnd/*',
        'australia': '/glade/work/klesinger/FD_RZSM_deep_learning/Data_australia/GEFSv12_reforecast/soilw_bgrnd/*',
        'china': '/glade/work/klesinger/FD_RZSM_deep_learning/Data_china/GEFSv12_reforecast/soilw_bgrnd/*',
    }
    # NOTE(review): this fallback has no wildcard, so glob only matches the folder itself.
    # Confirm whether '/*' was intended.
    fallback = f'/glade/derecho/scratch/klesinger/FD_RZSM_deep_learning/Data/reforecast/GEFSv12/{region_name}/soilw_bgrnd'
    pattern = region_patterns.get(region_name, fallback)
    gefs_init = sorted(glob(pattern))

    gefs_dates = [path.split('/')[-1].split('_')[-1].split('.')[0] for path in gefs_init]
    # The first entry of the listing is a folder, not a dated file.
    gefs_dates = gefs_dates[1:]

    if len(gefs_init) <= 10:
        raise FileNotFoundError(
            f'Expected at least 11 GEFSv12 entries for {region_name!r} matching {pattern}, found {len(gefs_init)}'
        )

    # load_dataset reads everything into memory and closes the file handle.
    template = xr.load_dataset(gefs_init[10])
    lon, lat = template.X.values, template.Y.values

    return template, lon, lat, gefs_dates
# gefs_dates.reverse()


In [None]:
# Now that we have the proper init dates, re-create a new dataset for ECMWF (we select the lead dates and make a new file with exactly the same shape as the GEFSv12 files).


# We need to make masks for bilinear or conservative interpolation to work

In [5]:


def resave_regrid_merge_multiprocessing_added_mask_RZSM(file, make_plots=False):
    """Regrid one ECMWF root-zone soil moisture (RZSM) reforecast onto each region's grid and save it.

    For every region in ``['CONUS', 'australia', 'china']`` this:

    1. finds the perturbed file whose path contains the same init-date token as ``file``;
    2. reads control and perturbed members (float32) and converts ``sm100`` from kg/m3 to
       m3/m3 by dividing by 1000;
    3. regrids both conservatively onto the region's mask grid with xESMF;
    4. writes them into a copy of the region's GEFSv12 template (``RZSM`` variable);
    5. sets remaining NaNs to 1 and saves to ``{save_dir}/<file name without '_control'>``.

    Regions whose output already exists are skipped.

    Parameters
    ----------
    file : str
        Path to an ECMWF control file (an entry of the control-file list).
    make_plots : bool, optional
        Draw diagnostic maps. Off by default because this normally runs inside a
        ``multiprocessing.Pool`` worker, where figures are never shown and would only
        accumulate in memory.

    Notes
    -----
    Reads the module-level ``var`` (set in the driver cell) to choose the input/output folders.
    """
    # Use one variable name for both the mask and the data.
    # (`var` is the download folder name, not the cfgrib variable name.)
    sm_name = 'sm100'

    file_name = file.split('/')[-1]
    # Init-date token: third '_'-separated field of the control file name.
    find_date = file_name.split('_')[2]
    # Output name = control file name with '_control' removed.
    name_parts = file_name.split('_control')
    out_name = f'{name_parts[0]}{name_parts[1]}'

    for region_name in ['CONUS', 'australia', 'china']:
        template, *_ = return_init_info(region_name)
        _, perturbed_, _, save_dir = return_var_info(var, region_name)

        save_name = f'{save_dir}/{out_name}'
        if os.path.exists(save_name):
            continue  # this region was already processed

        perturbed_date = [i for i in perturbed_ if find_date in i]
        if len(perturbed_date) != 1:
            print(f'Could not find a perturbed file for {find_date}')
            continue

        # Read both files fully (closing them) as float32, converting kg/m3 -> m3/m3.
        ctl = xr.load_dataset(file, engine='cfgrib').astype(np.float32) / 1000
        process_perturbed = xr.load_dataset(perturbed_date[0], engine='cfgrib').astype(np.float32) / 1000

        # Source mask for xESMF: 1 where the first lead has data, 0 where it is NaN.
        ctl['mask'] = xr.where(~np.isnan(ctl[sm_name].isel(step=0)), 1, 0)

        if make_plots:
            plt.figure()
            ax = plt.axes(projection=ccrs.PlateCarree())
            ctl.isel(step=0)[sm_name].plot.pcolormesh(ax=ax, vmin=0, vmax=1)
            ax.coastlines()
            plt.figure()
            ctl['mask'].plot(cmap='binary_r')

        if region_name == 'CONUS':
            # Keep longitudes < 90 and shift the rest by -360 (matches the GEFSv12 CONUS mask convention).
            new_lon_values = [i if i < 90 else i - 360 for i in ctl.longitude.values]
            # NH_grid.nc was created manually for this purpose.
            high_res_grid = xr.load_dataset('masks/NH_grid.nc')
        else:
            new_lon_values = ctl.longitude.values
            high_res_grid = xr.load_dataset(f'masks/{region_name}_mask.nc4')

        ctl = ctl.assign_coords({'longitude': new_lon_values})
        process_perturbed = process_perturbed.assign_coords({'longitude': new_lon_values})

        # Destination mask for xESMF.
        name_ = list(high_res_grid.keys())[0]
        if region_name == 'CONUS':
            # In NH_grid, cells equal to 1 become 0 in the mask; all others become 1.
            high_res_grid['mask'] = xr.where(high_res_grid[name_] == 1, 0, 1)
            high_res_grid = high_res_grid.rename({'X': 'longitude', 'Y': 'latitude'})
        else:
            # Region masks: NaN -> 0, any value -> 1.
            high_res_grid['mask'] = xr.where(np.isnan(high_res_grid[name_]), 0, 1)
            high_res_grid = high_res_grid.isel(time=0)

        if make_plots:
            plt.figure()
            high_res_grid['mask'].plot(cmap='binary_r')

        # Conservative regridding: bilinear produced no RZSM values along the coasts.
        regridder = xe.Regridder(ctl, high_res_grid, 'conservative')
        ctl_out = regridder(ctl, keep_attrs=True)
        perturbed_out = regridder(process_perturbed, keep_attrs=True)

        if make_plots:
            plt.figure()
            ax = plt.axes(projection=ccrs.PlateCarree())
            perturbed_out.isel(step=0).isel(number=1)[sm_name].plot.pcolormesh(ax=ax, vmin=0, vmax=1)
            ax.coastlines()

        # Copy of the template for this init date, with one L entry per ECMWF lead (step).
        n_lead = len(ctl_out.step.values)
        fill_this_file = (
            template.copy(deep=True)
            .assign_coords({'S': np.atleast_1d(pd.to_datetime(find_date))})
            .reindex(L=np.arange(n_lead), fill_value=np.nan)
        )
        # Index 0 of the second axis (presumably M, the member dim) holds the control run;
        # indices 1.. hold the perturbed members.
        fill_this_file.RZSM[:, 0, :n_lead, :, :] = ctl_out[sm_name].values
        fill_this_file.RZSM[:, 1:, :len(perturbed_out.step.values), :, :] = perturbed_out[sm_name].values

        # Remaining NaNs (ocean/water bodies) are set to 1.
        fill_this_file = xr.where(np.isnan(fill_this_file), 1, fill_this_file)

        if make_plots:
            plt.figure()
            ax = plt.axes(projection=ccrs.PlateCarree())
            fill_this_file.isel(L=0).isel(M=1).isel(S=0)['RZSM'].plot.pcolormesh(ax=ax, vmin=0, vmax=1)
            ax.coastlines()

        fill_this_file.to_netcdf(save_name)
    
        
            
        

#Then save as a netcdf



In [None]:
# Download folder of the RZSM files; read as a module global by the RZSM worker function.
var = 'soilw_bgrnd'
file_list = return_control_info(var)
# file=file_list[0] # for testing

if __name__ == '__main__':
    # The context manager shuts the worker processes down once map() has returned
    # (map blocks until every file has been processed).
    with Pool(20) as p:
        p.map(resave_regrid_merge_multiprocessing_added_mask_RZSM, file_list)


# Now we need to make a new function for the other variables (tmax, tmin, specific humidity); the downloaded ECMWF fields are tcw, t2m and d2m.

In [None]:
# Download folder of the combined temperature / precipitable water / dewpoint files.
# Plain module-level assignment; `global` is a no-op at this level.
var_full = 'temp_pwat_dewpoint'
other_var_file_list = return_control_info(var=var_full)

In [None]:
# Open one control file to inspect which variables the download contains.
test_ = xr.open_dataset(other_var_file_list[0], engine='cfgrib')

# ECMWF short names in that download: total column water, 2 m temperature,
# 2 m dewpoint temperature. Plain module-level assignment (`global` is a no-op here).
vars_downloaded = ['tcw', 't2m', 'd2m']

In [None]:
# One control file for interactive testing of the worker function below.
# The worker receives its own `file` argument, so it does not use this global.
file = other_var_file_list[0]


In [None]:


def _regrid_other_vars_to_region(control_file, perturbed_file, region_name):
    """Read an ECMWF control/perturbed pair and bilinearly regrid both onto ``region_name``'s grid.

    Returns ``(ctl_out, perturbed_out)``: the regridded control and perturbed datasets.
    """
    # Read fully into memory (float32) so the GRIB files are closed right away.
    ctl = xr.load_dataset(control_file, engine='cfgrib').astype(np.float32)
    process_perturbed = xr.load_dataset(perturbed_file, engine='cfgrib').astype(np.float32)

    if region_name == 'CONUS':
        # Keep longitudes < 90 and shift the rest by -360 (matches the GEFSv12 CONUS mask convention).
        new_lon_values = [i if i < 90 else i - 360 for i in ctl.longitude.values]
        # NH_grid.nc was created manually; its X/Y axes are renamed to the ECMWF coordinate names.
        high_res_grid = xr.load_dataset('masks/NH_grid.nc').rename({'X': 'longitude', 'Y': 'latitude'})
    else:
        new_lon_values = ctl.longitude.values
        high_res_grid = xr.load_dataset(f'masks/{region_name}_mask.nc4')

    ctl = ctl.assign_coords({'longitude': new_lon_values})
    process_perturbed = process_perturbed.assign_coords({'longitude': new_lon_values})

    # Bilinear regridding of every variable in the files.
    regridder = xe.Regridder(ctl, high_res_grid, 'bilinear')
    return regridder(ctl, keep_attrs=True), regridder(process_perturbed, keep_attrs=True)


def resave_regrid_merge_multiprocessing_other_vars(file):
    """Regrid one ECMWF control/perturbed pair of the non-RZSM fields; save one file per variable and region.

    For every region in ``['CONUS', 'australia', 'china']`` and every name in the module-level
    ``vars_downloaded``:

    - the control and perturbed members of that variable are regridded (bilinear, xESMF);
    - they are written into a copy of the region's GEFSv12 template, whose variable is
      renamed to the ECMWF short name;
    - remaining NaNs are set to 1;
    - the result is saved as ``{save_dir}/{var}_{date}.nc``.

    Existing outputs are skipped. The input pair is read and regridded at most once per
    region and shared by all variables.

    Parameters
    ----------
    file : str
        Path to an ECMWF control file. The init date is the last '_'-separated token
        before ``_control`` in its name.

    Notes
    -----
    Reads the module-level ``var_full`` and ``vars_downloaded``.
    """
    # Init date: last '_'-separated token before '_control' in the file name.
    date = file.split('/')[-1].split('_control')[0].split('_')[-1]

    for region_name in ['CONUS', 'australia', 'china']:
        # Neither lookup depends on the variable, so do them once per region.
        template, *_ = return_init_info(region_name)
        _, perturbed_, _, save_dir = return_var_info(var_full, region_name)

        ctl_out = perturbed_out = None  # regridded on first need, then reused for every variable
        for var in vars_downloaded:
            save_name = f'{save_dir}/{var}_{date}.nc'
            if os.path.exists(save_name):
                continue  # already processed

            if ctl_out is None:
                perturbed_date = [i for i in perturbed_ if date in i]
                if len(perturbed_date) != 1:
                    print(f'Could not find a perturbed file for {date}')
                    break  # nothing can be produced for this region
                ctl_out, perturbed_out = _regrid_other_vars_to_region(file, perturbed_date[0], region_name)

            # Copy of the template for this init date, with one L entry per ECMWF lead (step).
            n_lead = len(ctl_out.step.values)
            fill_this_file = (
                template.copy(deep=True)
                .assign_coords({'S': np.atleast_1d(pd.to_datetime(date))})
                .reindex(L=np.arange(n_lead), fill_value=np.nan)
            )
            # The template is expected to hold a single variable; rename it to this ECMWF short name.
            fill_this_file = fill_this_file.rename({xarray_varname(fill_this_file): var})
            # Select `var` explicitly: the files hold every variable in vars_downloaded, so taking
            # the first variable would write the same field into every output.
            fill_this_file[var][:, 0, :n_lead, :, :] = ctl_out[var].values
            fill_this_file[var][:, 1:, :len(perturbed_out.step.values), :, :] = perturbed_out[var].values

            # Remaining NaNs (ocean/water bodies) are set to 1.
            fill_this_file = xr.where(np.isnan(fill_this_file), 1, fill_this_file)
            fill_this_file.to_netcdf(save_name)
        
            
                
        

#Then save as a netcdf



In [None]:
# Process files in reverse path order (presumably newest init date first).
# The list comes from return_control_info, which sorts it, so sorting in reverse gives the
# same order as list.reverse() but stays put if this cell is re-run.
other_var_file_list = sorted(other_var_file_list, reverse=True)

In [None]:
if __name__ == '__main__':
    # The worker already loops over every region, so one pass over the files is enough.
    # The context manager shuts the worker processes down once map() has returned.
    with Pool(20) as p:
        p.map(resave_regrid_merge_multiprocessing_other_vars, other_var_file_list)