# Flash drought dataset extraction

In [1]:
import os
import glob as glob
import xarray as xr
import sys
import dask
import tempfile
from dask.diagnostics import ProgressBar
from dask.distributed import Client, LocalCluster
import warnings
import logging
# Suppress every Python warning for the rest of the session.
warnings.filterwarnings('ignore') 
# Quieten the noisy loggers: "distributed" only reports errors, "flox" only
# reports warnings and above.
logging.getLogger("distributed").setLevel(logging.ERROR)
logging.getLogger('flox').setLevel(logging.WARNING)

sys.path.append('/g/data/mn51/users/jb6465/code/flash-drought/attribution-python')
from extract import *

## Extract reanalysis

#### Extract BARRA (faster without dask)

In [5]:
# Domain key and output directory passed to every BARRA extraction call.
nation_domain = 'barra_domain'
extracted_data_save_dir = '/g/data/mn51/users/jb6465/data/flash_drought/reanalysis/BARRA-R2/'

barra_variables = ('sfcWind', 'tas', 'uas', 'vas', 'rsds', 'huss', 'ps')
barra_years = range(1979, 2025)  # 1979 through 2024 inclusive

# Work through one variable at a time, printing its name before the yearly
# extractions start so progress is visible in the cell output.
for target_var in barra_variables:
    print(target_var)
    for year in barra_years:
        barra_daily_extract(
            target_var,
            extracted_data_save_dir,
            nation_domain,
            year,
        )

sfcWind
tas
uas
vas
rsds
huss
ps


#### Extract ERA5 (faster with dask)

In [3]:
# Dask configuration applied before the client is created.
dask_settings = {
    'array.chunk-size': "256 MiB",
    'array.slicing.split_large_chunks': True,
    'distributed.comm.timeouts.connect': '120s',
    'distributed.comm.timeouts.tcp': '120s',
    'distributed.comm.retry.count': 10,
    'distributed.scheduler.allowed-failures': 20,
    "distributed.scheduler.worker-saturation": 1.1,
}
dask.config.set(dask_settings)

# 12 single-threaded workers, with a freshly created temporary directory
# passed as local_directory.
client = Client(
    n_workers=12,
    threads_per_worker=1,
    local_directory=tempfile.mkdtemp(),
    memory_limit="63000mb",
)

In [None]:
# Domain key and output directory passed to every ERA5 extraction call.
nation_domain = 'barra_domain'
extracted_data_save_dir = '/g/data/mn51/users/jb6465/data/flash_drought/reanalysis/ERA5/'

era5_variables = ('10u', '10v', '10w', '2t', 'ssrd', 'q', 'sp')
era5_years = range(1979, 2025)  # 1979 through 2024 inclusive

# Work through one variable at a time, printing its name before the yearly
# extractions start so progress is visible in the cell output.
for target_var in era5_variables:
    print(target_var)
    for year in era5_years:
        era5_daily_extract(
            target_var,
            extracted_data_save_dir,
            nation_domain,
            year,
        )

10u


In [17]:
# Build a Dataset holding w10_cube as the variable 'w10'.
# NOTE(review): w10_cube is not defined in the visible cells — confirm where it
# comes from before re-running this cell.
xr.Dataset({'w10': w10_cube})

In [11]:
# Hourly wind speed: magnitude of the (u10, v10) vector.
v10_squared = era5_vas_hly.v10 ** 2
u10_squared = era5_uas_hly.u10 ** 2
era5_ws10_hly = (v10_squared + u10_squared) ** 0.5

# Order the hourly series by time, then average it into daily bins.
daily_mean_lazy = era5_ws10_hly.sortby("time").resample(time='D').mean()

# Rechunk (720 steps along time, automatic lat/lon sizes) and compute the result.
era5_ws10_dly = daily_mean_lazy.chunk(
    {'time': 720, 'lat': 'auto', 'lon': 'auto'}
).compute()

Unnamed: 0,Array,Chunk
Bytes,6.78 GiB,253.52 MiB
Shape,"(8784, 283, 366)","(8784, 39, 97)"
Dask graph,32 chunks in 40 graph layers,32 chunks in 40 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 6.78 GiB 253.52 MiB Shape (8784, 283, 366) (8784, 39, 97) Dask graph 32 chunks in 40 graph layers Data type float64 numpy.ndarray",366  283  8784,

Unnamed: 0,Array,Chunk
Bytes,6.78 GiB,253.52 MiB
Shape,"(8784, 283, 366)","(8784, 39, 97)"
Dask graph,32 chunks in 40 graph layers,32 chunks in 40 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [3]:
# Domain key and output directory passed to every ERA5 extraction call.
nation_domain = 'barra_domain'
extracted_data_save_dir = '/g/data/mn51/users/jb6465/data/flash_drought/reanalysis/ERA5/'

era5_variables = ('10u', '10v', '2t', 'ssrd', 'q', 'sp')
era5_years = range(1979, 2025)  # 1979 through 2024 inclusive

# NOTE(review): this cell calls era5_extract, while the other ERA5 cell in this
# notebook calls era5_daily_extract — confirm which name is current.
for target_var in era5_variables:
    print(target_var)
    for year in era5_years:
        era5_extract(
            target_var,
            extracted_data_save_dir,
            nation_domain,
            year,
        )

10u




KeyboardInterrupt: 

In [None]:
# Close the dask distributed Client bound to `client`.
client.close()

2025-11-20 14:41:57,567 - distributed.semaphore - ERROR - Release failed for id=effc857861224e2ba81fedc23a7513c2, lease_id=925e7c81dd984b94bf9006d320982dd7, name=/g/data/mn51/users/jb6465/data/flash_drought/reanalysis/ERA5/10u/barra_domain_ERA5_1987_10u.nc. Cluster network might be unstable?
Traceback (most recent call last):
  File "/g/data/xp65/public/apps/med_conda/envs/analysis3-25.07/lib/python3.11/site-packages/distributed/semaphore.py", line 486, in _release
    await retry_operation(
  File "/g/data/xp65/public/apps/med_conda/envs/analysis3-25.07/lib/python3.11/site-packages/distributed/utils_comm.py", line 416, in retry_operation
    return await retry(
           ^^^^^^^^^^^^
  File "/g/data/xp65/public/apps/med_conda/envs/analysis3-25.07/lib/python3.11/site-packages/distributed/utils_comm.py", line 385, in retry
    return await coro()
           ^^^^^^^^^^^^
  File "/g/data/xp65/public/apps/med_conda/envs/analysis3-25.07/lib/python3.11/site-packages/distributed/core.py", li

#### Extract MERRA2 (faster with dask)

In [2]:
# Dask configuration applied before the client is created.
dask_settings = {
    'array.chunk-size': "256 MiB",
    'array.slicing.split_large_chunks': True,
    'distributed.comm.timeouts.connect': '120s',
    'distributed.comm.timeouts.tcp': '120s',
    'distributed.comm.retry.count': 10,
    'distributed.scheduler.allowed-failures': 20,
    "distributed.scheduler.worker-saturation": 1.1,
}
dask.config.set(dask_settings)

# 12 single-threaded workers, with a freshly created temporary directory
# passed as local_directory.
client = Client(
    n_workers=12,
    threads_per_worker=1,
    local_directory=tempfile.mkdtemp(),
    memory_limit="63000mb",
)

In [None]:
# Domain key and output directory passed to every MERRA2 extraction call.
extract_domain = 'barra_domain'
extracted_data_save_dir = '/g/data/mn51/users/jb6465/data/flash_drought/reanalysis/MERRA2/'

# Each entry is one variable group handled by merra2_daily_extract.
merra2_variable_groups = (
    ['SWGDN'],
    ['T2M', 'QV2M', 'PS'],
    ['U2M', 'V2M'],
)
merra2_years = range(1980, 2023)  # 1980 through 2022 inclusive

# Print the group before its yearly extractions so progress is visible.
for target_var in merra2_variable_groups:
    print(target_var)
    for year in merra2_years:
        merra2_daily_extract(
            target_var,
            extracted_data_save_dir,
            extract_domain,
            year,
        )

['SWGDN']
['T2M', 'QV2M', 'PS']


In [3]:
def merra2_daily_extract(target_var, extracted_data_save_dir, nation_domain, year):
    """
    Extract one year of MERRA2 data over a domain and save it as daily means.

    Inputs:
    - target_var - list of MERRA2 variable keys selecting the extraction group:
      ['U2M', 'V2M'], ['T2M', 'QV2M', 'PS'] or ['SWGDN']. Any other value
      makes the call a no-op.
    - extracted_data_save_dir - string of directory to save extracted data in.
      Files are written into its U2M_V2M, W2M, T2M_QV2M_PS and SWGDN
      sub-directories, which are not created here.
    - nation_domain - key into domain_dict giving the lat/lon bounds to crop to
    - year - target extraction year (used in the input glob and output names)
    Returns:
    None. Hourly MERRA2 files for the year are read, cropped to the domain,
    aggregated to daily means and saved as compressed float32 netCDF. For the
    wind group the cropped hourly U2M/V2M are saved first, then the daily mean
    of the hourly wind speed (W2M) is derived from that file. A group whose
    daily output file already exists is skipped.
    """
    import os
    import glob
    import logging
    import xarray as xr
    import gc

    # Compression/dtype settings shared by every variable written below.
    def _encoding(*names):
        return {name: {'zlib': True, 'complevel': 5, 'dtype': 'float32'} for name in names}

    # Crop to the lat/lon box configured for nation_domain.
    def _crop(ds):
        bounds = domain_dict[nation_domain]
        return ds.sel(
            lat=slice(bounds['lat_min'], bounds['lat_max']),
            lon=slice(bounds['lon_min'], bounds['lon_max']),
        )

    # Preprocess functions run per input file to save memory and time. The
    # flox log level is set inside each one so it also applies wherever the
    # function is executed by open_mfdataset(parallel=True).
    def preprocess_U2M_V2M(ds):
        logging.getLogger('flox').setLevel(logging.WARNING)
        return _crop(ds[['V2M', 'U2M']])

    def preprocess_T2M_QV2M_PS(ds):
        logging.getLogger('flox').setLevel(logging.WARNING)
        return _crop(ds[['T2M', 'QV2M', 'PS']].resample(time='1D').mean())

    def preprocess_SWdn(ds):
        logging.getLogger('flox').setLevel(logging.WARNING)
        return _crop(ds['SWGDN'].resample(time='1D').mean())

    spatial_auto_chunks = {'time': -1, 'lat': 'auto', 'lon': 'auto'}

    if target_var == ['U2M', 'V2M']:
        daily_path = f"{extracted_data_save_dir}/W2M/{nation_domain}_MERRA2_{year}_W2M_day.nc"
        if not os.path.isfile(daily_path):
            hourly_path = f'{extracted_data_save_dir}/U2M_V2M/{nation_domain}_{year}_MERRA2_hly_U2M_V2M.nc'
            files = sorted(glob.glob(f"{merra2_M2T1NXSLV_dir}/{year}/*/*.nc4"))
            U2M_V2M_cube = xr.open_mfdataset(files, combine='nested', concat_dim='time', parallel=True, preprocess=preprocess_U2M_V2M, engine='netcdf4').chunk(spatial_auto_chunks)
            U2M_V2M_cube.to_netcdf(hourly_path, encoding=_encoding('U2M', 'V2M'))
            U2M_V2M_cube.close()
            del U2M_V2M_cube; gc.collect()

            # Re-open exactly the hourly file written above. Matching on
            # "year in filename" would also pick up other domains' files that
            # live in the same directory and concatenate them along time.
            wind = xr.open_mfdataset([hourly_path], combine='nested', concat_dim='time', parallel=True, engine='netcdf4').chunk(spatial_auto_chunks)
            # Wind speed is computed hourly and only then averaged to daily.
            W2M = ((((wind['U2M']**2)+(wind['V2M']**2))**0.5).rename('W2M')).resample(time='1D').mean()
            W2M.to_netcdf(daily_path, encoding=_encoding('W2M'))
            wind.close()
            del W2M, wind; gc.collect()

    if target_var == ['T2M', 'QV2M', 'PS']:
        # The same path is used for the existence check and the write so the
        # skip-if-done check always matches the file this branch produces.
        daily_path = f'{extracted_data_save_dir}/T2M_QV2M_PS/{nation_domain}_MERRA2_{year}_T2M_QV2M_PS_day.nc'
        if not os.path.isfile(daily_path):
            files = sorted(glob.glob(f"{merra2_M2T1NXSLV_dir}/{year}/*/*.nc4"))
            T2M_QV2M_cube = xr.open_mfdataset(files, combine='nested', concat_dim='time', parallel=True, preprocess=preprocess_T2M_QV2M_PS, engine='netcdf4').chunk(spatial_auto_chunks)
            T2M_QV2M_cube.to_netcdf(daily_path, encoding=_encoding('T2M', 'QV2M', 'PS'))
            T2M_QV2M_cube.close()
            del T2M_QV2M_cube; gc.collect()

    if target_var == ['SWGDN']:
        daily_path = f"{extracted_data_save_dir}/SWGDN/{nation_domain}_MERRA2_{year}_SWGDN_day.nc"
        if not os.path.isfile(daily_path):
            files = sorted(glob.glob(f"{merra2_M2T1NXRAD_dir}/{year}/*/*.nc4"))
            SWGDN_cube = xr.open_mfdataset(files, combine='nested', concat_dim='time', parallel=True, preprocess=preprocess_SWdn, engine='netcdf4').chunk(spatial_auto_chunks)
            SWGDN_cube.to_netcdf(daily_path, encoding=_encoding('SWGDN'))
            SWGDN_cube.close()
            del SWGDN_cube; gc.collect()

    return

## Extract projections