In [3]:
import os
import datetime as dt
import numpy as np
import xarray as xr
from glob import glob
import pandas as pd
from multiprocessing import Pool
import sys

In [4]:
# Single sample file (soilw_bgrnd, dated 2007-04-04) opened below for a quick inspection.
source = '/'.join((
    '/glade/derecho/scratch/klesinger/FD_RZSM_deep_learning/Data/reforecast/GEFSv12/CONUS/soilw_bgrnd',
    'soilw_bgrnd_EMC_2007-04-04.nc',
))

In [5]:
f = xr.open_dataset(filename_or_obj=source)  # open the sample file defined in the previous cell

In [7]:
# Raw numpy values of the RZSM variable (displayed as the cell output below).
f['RZSM'].values

array([[[[[1.        , 1.        , 1.        , ..., 1.        ,
           1.        , 1.        ],
          [1.        , 1.        , 1.        , ..., 1.        ,
           1.        , 1.        ],
          [1.        , 1.        , 1.        , ..., 1.        ,
           1.        , 1.        ],
          ...,
          [0.43800002, 0.43720004, 0.4236429 , ..., 0.40442857,
           0.4028143 , 0.40088573],
          [0.43700004, 0.43400005, 0.43534288, ..., 0.40695718,
           0.40325716, 0.402     ],
          [0.43800002, 0.4313    , 0.4257429 , ..., 0.40385717,
           0.38892862, 0.38334292]],

         [[1.        , 1.        , 1.        , ..., 1.        ,
           1.        , 1.        ],
          [1.        , 1.        , 1.        , ..., 1.        ,
           1.        , 1.        ],
          [1.        , 1.        , 1.        , ..., 1.        ,
           1.        , 1.        ],
          ...,
          [0.43800002, 0.43745714, 0.4247286 , ..., 0.4082857 ,
    

In [5]:
# Reforecast initialization dates to process.
# GEFS long-term (multi-ensemble) forecasts are only initialized on Wednesdays.
start_date = dt.date(2000, 1, 1)

# Actual end date (only data through that period for this website https://noaa-gefs-retrospective.s3.amazonaws.com/index.html#GEFSv12/reforecast/)
end_date = dt.date(2019, 12, 31)

# Every calendar day in [start_date, end_date], keeping only Wednesdays.
# datetime.weekday(): Monday == 0, so Wednesday == 2
# https://docs.python.org/3/library/datetime.html#datetime.datetime.weekday
dates = [
    day
    for day in (
        start_date + dt.timedelta(days=offset)
        for offset in range((end_date - start_date).days + 1)
    )
    if day.weekday() == 2
]

In [6]:
def return_average_of_ensembles(var, var_name, open_d10, open_d35, template_GEFS_initial, ensemble_number, n_leads=None):
    """Reduce one ensemble member's sub-daily forecast steps to daily values.

    The d10 and d35 files are stitched into lead days as follows (step indices
    are along the first axis of ``open_dXX[var_name]``):

    * lead days 0-9  : d10 steps ``8*lead`` .. ``8*lead + 6`` (7 steps)
    * lead day 10    : the last d10 step plus d35 steps 0, 1 and 2
    * lead days 11-34: d35 steps ``3 + 4*(lead-11)`` .. ``+3`` (4 steps)

    Reduction applied to each lead day:

    * ``var == 'tmax_2m'`` : NaN-ignoring maximum
    * ``var == 'tmin_2m'`` : NaN-ignoring minimum
    * anything else        : NaN-ignoring mean, except lead day 10, which is the
      plain mean of its four steps (a NaN in any of them gives NaN)

    Parameters
    ----------
    var : str
        Variable key; selects the reduction.
    var_name : str
        Data-variable name inside ``open_d10`` / ``open_d35``.
    open_d10, open_d35 : xarray.Dataset (or any mapping)
        ``open_dXX[var_name].values`` must be an array whose first axis is the
        forecast step; a reduced day must be assignable to
        ``template_GEFS_initial[:, ensemble_number, lead, :, :]``.
    template_GEFS_initial : numpy.ndarray
        Output array indexed ``[S, M, L, Y, X]``; modified in place.
    ensemble_number : int
        Position along ``M`` to fill.
    n_leads : int, optional
        Number of lead days to fill (capped at 35). Defaults to the global
        ``len_leads`` set by ``merge_ensemble_members``.

    Returns
    -------
    numpy.ndarray
        ``template_GEFS_initial[:, :, :, :, :]`` (a view of the filled array).

    Notes
    -----
    For 'tmax_2m'/'tmin_2m', if lead day 10 cannot be assembled (IndexError,
    e.g. a broken member) that lead is left untouched; later lead days are
    still written to their own slots rather than shifting down by one.
    """
    if n_leads is None:
        n_leads = len_leads  # global assigned in merge_ensemble_members

    # Load each member once; every lead day below slices the step axis (axis 0),
    # exactly as the previous [start:stop, :, :] slicing assumed.
    steps_d10 = np.asarray(open_d10[f'{var_name}'].values)
    steps_d35 = np.asarray(open_d35[f'{var_name}'].values)

    is_extreme = var in ('tmax_2m', 'tmin_2m')
    if var == 'tmin_2m':
        reduce_steps = np.nanmin  # daily minimum must come from the sub-daily minima
    elif var == 'tmax_2m':
        reduce_steps = np.nanmax
    else:
        reduce_steps = np.nanmean  # same as xarray's default skipna mean on float data

    for lead in range(min(n_leads, 35)):
        if lead < 10:
            # NOTE(review): windows are 7 steps wide but start every 8 steps, so d10
            # steps 7, 15, ..., 71 never contribute to any day. Kept as in the
            # original ("verified on HPC"); confirm this is intended.
            start = 8 * lead
            daily = reduce_steps(steps_d10[start:start + 7], axis=0)
        elif lead == 10:
            # Day 10 straddles the two files: last d10 step + first three d35 steps.
            if is_extreme:
                try:
                    boundary = np.stack([steps_d10[-1], steps_d35[0], steps_d35[1], steps_d35[2]])
                except IndexError:
                    # Some ensembles have broken members; leave this lead as-is.
                    continue
                daily = reduce_steps(boundary, axis=0)
            else:
                daily = (steps_d10[-1] + steps_d35[0] + steps_d35[1] + steps_d35[2]) / 4
        else:
            # d35 day windows restart at step 3 (steps 0-2 were consumed by day 10).
            start = 3 + 4 * (lead - 11)
            daily = reduce_steps(steps_d35[start:start + 4], axis=0)
        # Index by lead day directly so a skipped lead can never shift later ones.
        template_GEFS_initial[:, ensemble_number, lead, :, :] = daily

    return template_GEFS_initial[:, :, :, :, :]

In [7]:
def return_xarray_file(var, template_GEFS_initial, julian_list, _date, open_d10):
    """Wrap a filled ``[S, M, L, Y, X]`` array in an xarray Dataset for ``var``.

    The output data-variable name and ``Description`` attribute are chosen by
    matching ``var`` against an ordered table; the first matching entry wins.
    All entries match by substring except 'hgt_pres', which must match exactly.

    Parameters
    ----------
    var : str
        Source variable key (e.g. 'soilw_bgrnd', 'hgt_pres', 'tmax_2m').
    template_GEFS_initial : numpy.ndarray
        Data indexed ``[S, M, L, Y, X]``.
    julian_list : sequence
        Values for the ``L`` coordinate (one per lead).
    _date : date-like
        Initialization date; becomes the single ``S`` coordinate value.
    open_d10 : xarray.Dataset
        Source dataset providing the ``longitude`` / ``latitude`` coordinates.

    Returns
    -------
    xarray.Dataset
        One data variable on dims (S, M, L, Y, X) with coords X, Y, L,
        M = 0..n_members-1 and S.

    Raises
    ------
    ValueError
        If ``var`` matches no known variable.
    """
    # (key, exact_match, output variable name, Description attribute).
    # Order reproduces the precedence of the former if/elif chain: 'pwat_eatm'
    # first and 'dlwrf' last (its standalone `if` was overridden by any later match).
    specs = (
        # NOTE(review): pwat is stored under the name 'dlwrf' (copy-paste from the
        # dlwrf entry). Kept for compatibility with files already written; rename
        # only together with the downstream readers.
        ('pwat_eatm', False, 'dlwrf', 'Precipitable water GEFSv12. Daily average already computed. All ensembles and Ls in one file'),
        ('dswrf', False, 'dswrf', 'Downwelling shortwave radiation GEFSv12. Daily average already computed. All ensembles and Ls in one file'),
        ('soil', False, 'RZSM', 'Volumetric soil moisture content 0-100cm: 0.0-0.1, 0.1-0.4, 0.4-1.0 and 1.-2. m depth     (fraction between wilting and saturation) GEFSv12. Daily average already computed. All ensembles and Ls in one file'),
        ('tmp', False, 'tmean', 'Average temperature GEFSv12. Daily average already computed. All ensembles and Ls in one file'),
        ('ulwrf', False, 'ulwrf', 'Longwave upwelling radiation. Daily average already computed. All ensembles and Ls in one file'),
        ('uswrf', False, 'uswrf', 'Shortwave upwelling radiation. Daily average already computed. All ensembles and Ls in one file'),
        ('spfh', False, 'spfh', 'Specific humidity. Daily average already computed. All ensembles and Ls in one file'),
        ('tmax', False, 'tasmax', 'Maximum Temperature. Daily average already computed. All ensembles and Ls in one file'),
        ('tmin', False, 'tasmin', 'Minimum Temperature. Daily average already computed. All ensembles and Ls in one file'),
        ('uflx', False, 'uas', 'U component of wind, N/m2 momentum flux. Daily average already computed. All ensembles and Ls in one file'),
        ('vflx', False, 'vas', 'V component of wind, N/m2 momentum flux. Daily average already computed. All ensembles and Ls in one file'),
        ('hgt_pres', True, 'z', 'Geopotential height. Daily average already computed. All ensembles and Ls in one file'),
        ('apcp_sfc', False, 'pr', 'Precipitation. Daily average already computed. All ensembles and Ls in one file'),
        ('ugrd', False, 'uas', 'U component of wind, m/2. Daily average already computed. All ensembles and Ls in one file'),
        ('vgrd', False, 'vas', 'V component of wind, m/2. Daily average already computed. All ensembles and Ls in one file'),
        ('lhtfl', False, 'latentHeat', 'Surface Latent heat, W/m^2. Daily average already computed. All ensembles and Ls in one file'),
        ('shtfl', False, 'sensibleHeat', 'Surface Sensible heat, W/m^2. Daily average already computed. All ensembles and Ls in one file'),
        ('dlwrf', False, 'dlwrf', 'Downwelling longwave radiation GEFSv12. Daily average already computed. All ensembles and Ls in one file'),
    )

    for key, exact, out_name, description in specs:
        if (var == key) if exact else (key in var):
            break
    else:
        # Previously this surfaced as an UnboundLocalError on GEFS_out.
        raise ValueError(f'return_xarray_file: no output mapping for variable {var!r}')

    return xr.Dataset(
        data_vars={out_name: (['S', 'M', 'L', 'Y', 'X'], template_GEFS_initial[:, :, :, :, :])},
        coords=dict(
            X=open_d10.longitude.values,
            Y=open_d10.latitude.values,
            L=julian_list,
            M=range(template_GEFS_initial.shape[1]),
            S=np.atleast_1d(pd.to_datetime(_date)),
        ),
        attrs={'Description': description},
    )

In [8]:
def merge_ensemble_members(_date):
    #testing
    # _date = pd.to_datetime('2004-03-03')
    # region_name='CONUS'
    # var = 'hgt_pres'
    
    
    global len_leads, step_name
    len_leads = 35
    
    #After manually inspecting the files after using NCL operators, the "step" has been replaced by "forecast_time0"
    step_name = 'forecast_time0'
            
    
    for region_name in region_names:
        for var in var_names:
            
            if region_name == 'CONUS':
                save_dir = '/glade/work/klesinger/FD_RZSM_deep_learning/Data/GEFSv12_reforecast'
                source_dir = f'/glade/derecho/scratch/klesinger/GEFSv12_raw/{var}/{var}_processed'
                lat_ = 200 #you must manually know the size of the input file
                lon_ = 380
            elif region_name == 'australia':
                save_dir = '/glade/work/klesinger/FD_RZSM_deep_learning/Data_australia/GEFSv12_reforecast'
                source_dir = f'/glade/derecho/scratch/klesinger/GEFSv12_raw/{var}/{var}_processed_australia'
                lat_ = 48 #you must manually know the size of the input file
                lon_ = 96
            elif region_name == 'china':
                save_dir = '/glade/work/klesinger/FD_RZSM_deep_learning/Data_china/GEFSv12_reforecast'
                source_dir = f'/glade/derecho/scratch/klesinger/GEFSv12_raw/{var}/{var}_processed_china'
                lat_ = 48 #you must manually know the size of the input file
                lon_ = 96        
            
            
            os.system(f'mkdir -p {save_dir}/{var}')
            os.chdir(f'{source_dir}')
            
            #For soil layer depth
            global soil_layer_depth, weighted_RZSM
            soil_layer_depth=3 #0-100cm. Can do 0-2m if number =4
            
            weighted_RZSM=True #weighted sum of the individual layers (we only have it set to 3 layers, you must modify the code if soil_layer_depth != 3)

            os.system(f'mkdir -p {save_dir}')
        
            def name(file):
                return(list(file.keys())[0])
        
            len_steps_d10 = 80 #I checked the good files and this is how many steps it should have
            len_steps_d35 = 100 #I checked the good files and this is how many steps it should have
            

            '''Need to create a new xarray template with 11 ensemble members, 35 lead days, lat/lon same as GMAO 
            (because file is already regridded to be in the same format as all other observations and GMAO '''
        
            #This template was created from the code above
            template_GEFS_initial = np.empty(shape=(1,11,len_leads,lat_,lon_))
            lead_splices = ['d10','d35']
        
            all_possible_ensemble_members = ['c00', 'p01', 'p02', 'p03', 'p04', 'p05', 'p06', 'p07', 'p08', 'p09', 'p10']
            #testing with one at a time
            # vars_to_process = ['soilw_bgrnd']
        
            #Get the dates of the files
        
            # _date=dates[0]
    
            #The date of the file is when it was intialized
            out_date_create = pd.to_datetime(_date) 
            out_date = f'{out_date_create.year}-{out_date_create.month:02}-{out_date_create.day:02}'
                
            final_out_name = f'{var}_EMC_{out_date}.nc'
    
            if os.path.exists(f"{save_dir}/{var}/{final_out_name}"):
                pass
            else:
                            
                print(f'Working on variable {var} to merge ensemble members for date {_date}.')
    
                '''Now combine the files, because there are some missing ensemble members (not sure why)
                we need to account for files with 1.) ALL members, 2.) some members
        
                It gets more complicated when the missing ensemble members are different between each
                d10 (first 10 days) and d35 (last 25 days), but I have figured it out'''
                
                template_GEFS_initial[:,:,:,:,:] = np.nan
                # lead_day=0 #keeps up with which index is correct in template_GEFS_initial (resets with each date)
    
                all_files_d10 = sorted(glob(f'{lead_splices[0]}_{var}_{_date.year}{_date.month:02}{_date.day:02}*.n*'))
                all_files_d35 = sorted(glob(f'{lead_splices[1]}_{var}_{_date.year}{_date.month:02}{_date.day:02}*.n*'))
    
                good_files_d10 = []
                good_files_d35 = []

                #Save name template for later
                
                
                #Some files don't have all the steps, so just remove the bad files from processing (they are going to be replaced later by a different realization)
                for i in all_files_d10:
                    op = xr.open_dataset(i)
                    op.close()
                    if len(op[step_name].values) == len_steps_d10: 
                        good_files_d10.append(i)
                        
                for i in all_files_d35:
                    op = xr.open_dataset(i)
                    op.close()
                    if len(op[step_name].values) == len_steps_d35: 
                        good_files_d35.append(i)            

                tem = good_files_d10[0].split('_')
                
                #some files have doubles (rsync error or from HPC when converting)
                file_len = [len(i) for i in good_files_d10]
                
                mode = max(set(file_len), key=file_len.count)
                #Replace files
                good_files_d10 = [i for i in good_files_d10 if len(i) == mode]
                good_files_d35 = [i for i in good_files_d35 if len(i) == mode]
                  
                    
                #TODO:If all realizations are present
                if (len(good_files_d10) == 11) and (len(good_files_d35) == 11):
                    #If all possible files are there, then this is the easy code processing to add data to single file
                    for ensemble_number,files in enumerate(zip(good_files_d10,good_files_d35)):
                         # break
                        if var != 'soilw_bgrnd' and var != 'hgt_pres':
                            open_d10 = xr.open_dataset(files[0])
                            open_d35 = xr.open_dataset(files[1])
                            try:
                                var_name = [i for i in list(open_d10.keys()) if 'step' not in i][0]
                            except IndexError:
                                pass
                                
                        elif var == 'hgt_pres':
                        
                            open_d10 = xr.open_dataset(files[0])
                            open_d35 = xr.open_dataset(files[1])
    
                            #If you have all the values in a file, you can subset here. 
                            # open_d10 = xr.open_dataset(files[0]).sel(isobaricInhPa=hgt_pressure_level)
                            # open_d35 = xr.open_dataset(files[1]).sel(isobaricInhPa=hgt_pressure_level)
    
                            var_name = [i for i in list(open_d10.keys()) if 'step' not in i][0]
                            
        
                                
                        elif var == 'soilw_bgrnd':
                            #Take the sum of the columns
                            open_d10= xr.open_dataset(files[0])
                            open_d35 = xr.open_dataset(files[1])
                            var_name = [i for i in list(open_d10.keys()) if 'step' not in i][0]
                            #TODO: Take the summation of the first 3 soil layers (0-100cm)
                            if weighted_RZSM == False:
                                open_d10 = open_d10[f'{var_name}'][:,0:soil_layer_depth,:,:].sum(dim=['SOILW_P1_2L106_GLL0']).to_dataset()
                                open_d35 = open_d35[f'{var_name}'][:,0:soil_layer_depth,:,:].sum(dim=['SOILW_P1_2L106_GLL0']).to_dataset()
                            else:
                                # break
                                #take weighted mean by layer
                                open_d10 = (np.multiply(open_d10[f'{var_name}'][:,0,:,:],0.1) + np.multiply(open_d10[f'{var_name}'][:,1,:,:],0.3) + np.multiply(open_d10[f'{var_name}'][:,1,:,:],0.6)).to_dataset()
                                open_d35 = (np.multiply(open_d35[f'{var_name}'][:,0,:,:],0.1) + np.multiply(open_d35[f'{var_name}'][:,1,:,:],0.3) + np.multiply(open_d35[f'{var_name}'][:,1,:,:],0.6)).to_dataset()
    
                        '''EMC has some broken files where there is only 1 time spot when it should have 35'''
                        try:
                            open_d10[f'{step_name}'].values
                            open_d35[f'{step_name}'].values
                            
                            if len(open_d10[f'{step_name}'].values) == 0 or len(open_d35[f'{step_name}'].values) == 0:
                                pass_=True
                            else:
                                pass_=False
                                
                        except AttributeError:
                            #no steps in file, just a single file (equals a bad file)
                            pass_=True
                        ##########################################################
                        if pass_==True:
                            pass
                        elif len(open_d10[f'{step_name}'].values)==1 or len(open_d35[f'{step_name}'].values)==1:
                            pass
                        else:
                        #First get the dates of the files
                            '''Take average of first 7 timesteps if d10 file. I have verified
                            this is correct when looking at HPC'''
                            template_GEFS_initial[:,:,:,:,:] = return_average_of_ensembles(var=var,var_name=var_name,open_d10=open_d10,open_d35=open_d35,template_GEFS_initial=template_GEFS_initial[:,:,:,:,:],ensemble_number=ensemble_number)    
                            
                            
                #If all ensembles are missing, do nothing
                elif (len(good_files_d10) == 0 )and (len(good_files_d35) == 0):
                    # print(_date)
                    pass
                
                #If there are a differnet number of ensembles between leads
                elif (len(good_files_d10) != 11) or (len(good_files_d35) != 11):
    
                    #Some ensembles are missing, split to get the name of ensemble members
                    avail_ensemble_members_d10 = [i.split('_')[-1].split('.')[0] for i in good_files_d10]
                    avail_ensemble_members_d35 = [i.split('_')[-1].split('.')[0] for i in good_files_d35]

                    missing_d10_ = sorted(list(set(all_possible_ensemble_members).difference(set(avail_ensemble_members_d10))))
                    missing_d35_ = sorted(list(set(all_possible_ensemble_members).difference(set(avail_ensemble_members_d35))))
                    
                    #if missing only the exact same realizations                
                    if (len(missing_d35_) ==  len(missing_d10_)) and (missing_d35_==missing_d10_):
                        #Find a way to append the missing ensemble files with np.nan
                        
                        for idx,ensemble in enumerate(all_possible_ensemble_members):
                            
                            if ensemble not in avail_ensemble_members_d10:
                                pass
                                #just keep the data as np.nan
                            else:
                                idx_num = avail_ensemble_members_d10.index(ensemble)
                                open_d10=xr.open_dataset(good_files_d10[idx_num])
                                open_d35=xr.open_dataset(good_files_d35[idx_num])
                                
                                if var == 'hgt_pres':
                                    open_d10 = xr.open_dataset(files[0])
                                    open_d35 = xr.open_dataset(files[1])
            
                                    #If you have all the values in a file, you can subset here. 
                                    # open_d10 = xr.open_dataset(files[0]).sel(isobaricInhPa=hgt_pressure_level)
                                    # open_d35 = xr.open_dataset(files[1]).sel(isobaricInhPa=hgt_pressure_level)
                                elif var == 'soilw_bgrnd':
                                    if weighted_RZSM == False:
                                        open_d10 = open_d10[f'{var_name}'][:,0:soil_layer_depth,:,:].sum(dim=['SOILW_P1_2L106_GLL0']).to_dataset()
                                        open_d35 = open_d35[f'{var_name}'][:,0:soil_layer_depth,:,:].sum(dim=['SOILW_P1_2L106_GLL0']).to_dataset()
                                    else:
                                        open_d10 = (np.multiply(open_d10[f'{var_name}'][:,0,:,:],0.1) + np.multiply(open_d10[f'{var_name}'][:,1,:,:],0.3) + np.multiply(open_d10[f'{var_name}'][:,1,:,:],0.6)).to_dataset()
                                        open_d35 = (np.multiply(open_d35[f'{var_name}'][:,0,:,:],0.1) + np.multiply(open_d35[f'{var_name}'][:,1,:,:],0.3) + np.multiply(open_d35[f'{var_name}'][:,1,:,:],0.6)).to_dataset()
    
                                else:
                                    open_d10=xr.open_dataset(good_files_d10[idx_num])
                                    open_d35=xr.open_dataset(good_files_d35[idx_num]) 
                                
                                var_name = [i for i in list(open_d10.keys()) if 'step' not in i][0]
                                
    
                                '''Take average of first 7 timesteps if d10 file. I have verified
                                this is correct when looking at HPC'''
                                template_GEFS_initial = return_average_of_ensembles(var=var,var_name=var_name,open_d10=open_d10,open_d35=open_d35,template_GEFS_initial=template_GEFS_initial,ensemble_number=idx)    
    
                    # missing different realizations 
                    else:
                         #Because there are missing ensemble members that are supposed to be aligned,
                         #we must delete those ensemble members
            
                         # #remove missing members
                         # out_10 = [i for i in all_possible_ensemble_members if i not in missing_d10_]
                         # out_35 = [i for i in all_possible_ensemble_members if i not in missing_d35_]
                        
                         #replace empty file with the control file. Deep learning doesn't like np.nan values
                         #there are so few of these missing files that it should be fine
                         for i in good_files_d10:
                             #for each set of files
                             for m in all_possible_ensemble_members:
                                 name_out = f'd10_{var}_{tem[3]}_{m}.nc'
                                 
                                 if name_out in good_files_d10:
                                     pass
                                 #for each set of members
                                 #we need to see if the file exists, if not create a blank one
                                 else:
                                     # break
                                     temp_10=xr.open_dataset(good_files_d10[0]) #make a temporary file as the blank file
                                     # temp_10[name(temp_10)][:,:,:] = np.nan
                                     temp_10.to_netcdf(f'd10_{i[4:24]}{m}.nc')
                                 
                         #replace empty file with the control file. Deep learning doesn't like np.nan values
                         #there are so few of these missing files that it should be fine
                         for i in good_files_d35:
                             # print(i)
                             #for each set of files
                             for m in all_possible_ensemble_members:
                                 # print(m)
                                 name_out = f'd35_{var}_{tem[3]}_{m}.nc'
                                 if name_out in good_files_d35:
                                     pass
                                 #for each set of members
                                 #we need to see if the file exists, if not create a blank one
                                 else:
                                     temp_10=xr.open_dataset(all_files_d35[0]) #make a temporary file as the blank file
                                     # temp_10[name(temp_10)][:,:,:] = np.nan
                                     temp_10.to_netcdf(f'd35_{i[4:24]}{m}.nc')
                         #'''Take average of first 7 timesteps if d10 file. I have verified
                         #this is correct when looking at HPC'''
                         
                         for idx,ensemble in enumerate(all_possible_ensemble_members):
                             name_out_d10 = f'd10_{var}_{tem[3]}_{ensemble}.nc'
                             name_out_d35 = f'd35_{var}_{tem[3]}_{ensemble}.nc'
                             
                             # idx_num = avail_ensemble_members_d10.index(ensemble)
                             open_d10=xr.open_dataset(name_out_d10)
                             open_d35=xr.open_dataset(name_out_d35)
                             var_name = [i for i in list(open_d10.keys()) if 'step' not in i][0]
                             
                             template_GEFS_initial[:,:,:,:,:]  = return_average_of_ensembles(var=var,var_name=var_name,open_d10=open_d10,open_d35=open_d35,template_GEFS_initial=template_GEFS_initial,ensemble_number=idx)    
               
    
                def julian_date(_date,template_GEFS_initial):
                    '''Day-of-year (1-366) for every lead of a forecast initialized on _date.

                    The number of leads is read from axis 2 of template_GEFS_initial
                    (presumably the lead/time axis -- TODO confirm against the template).
                    Lead k is _date + k days, so the sequence wraps from 365/366 back to 1
                    across a year boundary.
                    '''
                    n_leads = template_GEFS_initial.shape[2]
                    init_time = pd.to_datetime(_date)

                    return [
                        (init_time + np.timedelta64(lead, 'D')).timetuple().tm_yday
                        for lead in range(n_leads)
                    ]
    
                #Can specifically add the julian date if you want.
                # julian_list = julian_date(_date,template_GEFS_initial)
                #Instead of replacing the below lines, lets just make it a 35 day lead
    
                #This is just the number of leads
                julian_list=np.arange(len_leads)
    
                GEFS_out = return_xarray_file(var, template_GEFS_initial, julian_list, _date, open_d10)
                GEFS_out = GEFS_out.astype(np.float32)
                
                GEFS_out.to_netcdf(path = f"{save_dir}/{var}/{final_out_name}")
                GEFS_out.close()

    return(0)
            

In [None]:
if __name__ == '__main__':
    # Number of worker processes used to merge ensemble members in parallel.
    N_WORKERS = 10

    # Each initialization date is processed independently by merge_ensemble_members.
    # The context manager tears the pool down once map() has returned (map blocks
    # until every date is done) or if it raises. Without it, the worker processes
    # outlive the cell and accumulate in the notebook kernel on every re-run.
    with Pool(N_WORKERS) as p:
        p.map(merge_ensemble_members, dates)

Working on variable soilw_bgrnd to merge ensemble members for date 2003-02-12.
Working on variable soilw_bgrnd to merge ensemble members for date 2004-02-25.Working on variable soilw_bgrnd to merge ensemble members for date 2002-08-07.Working on variable soilw_bgrnd to merge ensemble members for date 2000-01-05.Working on variable soilw_bgrnd to merge ensemble members for date 2001-01-17.Working on variable soilw_bgrnd to merge ensemble members for date 2004-09-01.Working on variable soilw_bgrnd to merge ensemble members for date 2001-07-25.
Working on variable soilw_bgrnd to merge ensemble members for date 2003-08-20.
Working on variable soilw_bgrnd to merge ensemble members for date 2002-01-30.Working on variable soilw_bgrnd to merge ensemble members for date 2000-07-12.






Working on variable hgt_pres to merge ensemble members for date 2001-01-17.Working on variable hgt_pres to merge ensemble members for date 2004-02-25.
Working on variable hgt_pres to merge ensemble members for 