In [None]:
proj_dir='/path/to/main_project_folder/' # edit this line

import numpy as np
import pandas as pd
import xarray as xr
import dask.array as da
import glob as glob
import sys
sys.path.append(proj_dir) 
from project_utils import parameters as param
from project_utils import load_region
from project_utils import prepare_inputs
import importlib
importlib.reload(param)
importlib.reload(load_region)
importlib.reload(prepare_inputs)

def prep_ncep_files(ds):
    """Strip auxiliary variables from one NCEP file before concatenation.

    Passed as the ``preprocess`` callback to ``xr.open_mfdataset``, so it is
    called once per input file.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset opened from a single NCEP file.

    Returns
    -------
    xarray.Dataset
        ``ds`` without the weight/bounds/area variables listed below.
        Variables that are not present in ``ds`` are skipped silently.
    """
    # `Dataset.drop(names)` is deprecated; `drop_vars` is the direct
    # replacement and is what the rest of this script already uses.
    return ds.drop_vars(
        ['gw', 'lat_bnds', 'lon_bnds', 'area', 'time_bnds', 'level_bnds'],
        errors='ignore',
    )

def prep_era5_files(ds):
    """Put one ERA5 file on ``lat``/``lon`` dimensions before concatenation.

    Passed as the ``preprocess`` callback to ``xr.open_mfdataset``, so it is
    called once per input file.  It reads the module-level ``z_era5_files``
    list, which the main loop reassigns for every region; that list must be
    set (and hold at least six files) before this function runs.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset opened from a single ERA5 file, dimensioned on
        ``latitude``/``longitude``.

    Returns
    -------
    xarray.Dataset
        ``ds`` re-dimensioned on ``lat``/``lon`` with the weight/bounds/area
        variables removed.  Unlike ``prep_ncep_files``, a missing variable is
        not ignored here.
    """
    # Every file takes its latitude values from the same reference file.
    # NOTE(review): presumably this keeps latitudes identical across files so
    # they align on concatenation; the choice of index 5 looks arbitrary --
    # confirm.
    # The reference file is closed again right away (it used to stay open for
    # every preprocessed file), so the values are loaded into memory first.
    with xr.open_dataset(z_era5_files[5]) as reference:
        ds['lat'] = reference.lat.load()
    ds = ds.swap_dims({'latitude': 'lat', 'longitude': 'lon'})
    # `Dataset.drop(names)` is deprecated; `drop_vars` is the direct replacement.
    return ds.drop_vars(['gw', 'lat_bnds', 'lon_bnds', 'area'])

# Names of the regions to process.  Each name is handed to
# load_region.load_region_constants() and is also used as the folder name in
# the ERA5 input path and in the processed-data output paths.
region_list = [
    # North America
    'northcentral_north_america',
    'southcentral_north_america',
    'southeastern_north_america',
    # Europe
    'southwestern_europe',
    'western_europe',
    'central_europe',
    'eastern_europe',
    'northeastern_europe',
    # Asia
    'northeastern_asia',
    'southeastern_asia',
    # South America
    'northsouthern_south_america',
    'southsouthern_south_america',
    # Africa
    'southwestern_africa',
    'southeastern_africa',
    # Australia
    'southwestern_australia',
    'southeastern_australia',
    # Texas
    'west_texas',
    'east_texas',
]

# The NCEP-R2 inputs are not split by region, so they are listed once here and
# reused for every region.  Each list is sorted so the files are handed to
# open_mfdataset in a reproducible order.
hgt_ncep_files = glob.glob("../input_data_NCEP/NCEP-R2/hgt/*.nc")
hgt_ncep_files.sort()

soilw_ncep_files = glob.glob("../input_data_NCEP/NCEP-R2/soilw/*.nc")
soilw_ncep_files.sort()

tmax_ncep_files = glob.glob("../input_data_NCEP/NCEP-R2/tmax/*.nc")
tmax_ncep_files.sort()
    
# Divisor that converts ERA5 geopotential to geopotential height (standard gravity).
_STANDARD_GRAVITY = 9.80665
# Year at which the fitted linear trend contributes zero change when detrending.
_TREND_BASE_YEAR = 1979


def _open_region_subset(files, preprocess, lat_bbox, lon_bbox):
    """Open ``files`` as one dataset and cut it to the region box and study period.

    Parameters
    ----------
    files : list of str
        Paths handed to ``xr.open_mfdataset`` and concatenated along ``time``.
    preprocess : callable
        Applied to each file's dataset before concatenation.
    lat_bbox
        Selector passed to ``.sel(lat=...)``.
    lon_bbox : slice or pair of selectors
        A single ``slice`` is used directly.  Anything else is indexed as
        ``lon_bbox[0]`` and ``lon_bbox[1]``: the two pieces (east and west
        hemispheres) are selected separately and joined along ``lon``.

    Returns
    -------
    xarray.Dataset
        The subset, restricted to ``param.time_period``.
    """
    ds = xr.open_mfdataset(files, combine='nested', concat_dim='time', preprocess=preprocess)
    if isinstance(lon_bbox, slice):
        return ds.sel(lat=lat_bbox, lon=lon_bbox, time=param.time_period)
    # Both halves come from the same opened dataset, so the files are only
    # opened once.
    left = ds.sel(lat=lat_bbox, lon=lon_bbox[0], time=param.time_period)
    right = ds.sel(lat=lat_bbox, lon=lon_bbox[1], time=param.time_period)
    return xr.concat([left, right], dim="lon")


def _annual_domain_mean_trend(field, name):
    """Fit a straight line to the annual, area-weighted domain mean of ``field``.

    Parameters
    ----------
    field : xarray.DataArray
        Data with ``time``, ``lat`` and ``lon`` dimensions.
    name : str
        Column name for the returned dataframe.

    Returns
    -------
    (pandas.DataFrame, numpy.ndarray)
        The annual domain means (indexed by ``year``) and the degree-1
        ``np.polyfit`` coefficients; element 0 is the slope per year.
    """
    # The weights must be the cos(latitude) values themselves.  Accessing
    # `.lat` on the weights array would hand `.weighted` the latitude
    # coordinate in degrees, which is not an area weight (and is negative in
    # the Southern Hemisphere).
    cos_lat = np.cos(np.deg2rad(field['lat']))
    domain_mean = field.weighted(cos_lat).mean(dim=['lat', 'lon'])
    # Averaging by year first removes seasonal variability from the fit.
    annual = domain_mean.groupby('time.year').mean(dim="time").to_dataframe(name=name)
    coeffs = np.polyfit(annual[name].index.get_level_values('year'), annual[name], 1)
    return annual, coeffs


def _remove_linear_trend(field, slope):
    """Subtract ``slope`` per year, counted from ``_TREND_BASE_YEAR``, from ``field``."""
    return field - (field['time'].dt.year - _TREND_BASE_YEAR) * slope


def _add_calday_standardized_anomalies(ds, source, mean_name, sd_name, anom_name):
    """Add day-of-year mean, standard deviation and standardized anomalies to ``ds``.

    The statistics of ``ds[source]`` are grouped by ``time.dayofyear`` and
    stored under ``mean_name`` and ``sd_name``; the anomalies
    ``(value - mean) / sd`` are stored under ``anom_name``.  ``ds`` is
    modified in place and also returned.
    """
    ds[mean_name] = ds[source].groupby('time.dayofyear').mean(dim='time')
    ds[sd_name] = ds[source].groupby('time.dayofyear').std(dim='time')
    centered = ds[source].groupby('time.dayofyear') - ds[mean_name]
    ds[anom_name] = centered.groupby('time.dayofyear') / ds[sd_name]
    return ds


def _save_netcdf(data, path):
    """Write ``data`` to ``path`` as netCDF, then print the path."""
    data.to_netcdf(path)
    print(path)


for dset in ['ERA5', 'NCEP']:
    for jj, region_str in enumerate(region_list):
        print(jj)
        print(region_str)

        (hem, region_input_lat_bbox, region_input_lon_bbox, region_box_x, region_box_y,
         region_lat, region_lon, region_lon_EW, region_t62_lats, region_t62_lons
         ) = load_region.load_region_constants(region_str)

        ## ERA5 inputs are stored per region.  `z_era5_files` has to remain a
        ## module-level name: prep_era5_files reads it.
        era5_path = "../input_data_ERA5/"+region_str+"/daily/"
        z_era5_files = sorted(glob.glob(era5_path+"z/*t62.nc"))
        soilw_era5_files = sorted(glob.glob(era5_path+"swvl1/*t62.nc"))
        tmax_era5_files = sorted(glob.glob(era5_path+"tmax/*t62.nc"))

        if dset == 'ERA5':
            hgt_files = z_era5_files
            soilw_files = soilw_era5_files
            tmax_files = tmax_era5_files
            prep_files = prep_era5_files
        elif dset == 'NCEP':
            hgt_files = hgt_ncep_files
            soilw_files = soilw_ncep_files
            tmax_files = tmax_ncep_files
            prep_files = prep_ncep_files
        else:
            ## fail here rather than later with `None`/stale file lists
            raise ValueError("DATASET NOT FOUND!")

        out_dir = "../processed_data_"+dset+"/"+region_str+"/"

        ######################### 500-hPa geopotential height: #############################
        # read daily data, subtract domain average 500-hPa trend, calculate daily standardized anomalies

        hgt_ds = _open_region_subset(hgt_files, prep_files, region_input_lat_bbox, region_input_lon_bbox)
        if dset == 'NCEP':
            hgt_ds = hgt_ds.sel(level = param.hgt_level)
        if dset == 'ERA5':
            hgt_ds = hgt_ds / _STANDARD_GRAVITY # convert geopotential to GPH
            hgt_ds = hgt_ds.rename({'z':'hgt'})

        print(hgt_ds)
        lats = hgt_ds['lat']
        lons = hgt_ds['lon']

        print('removing hgt trend')
        ## linear trend of the annual, area-weighted domain average 500-hPa GPH
        hgt_domain_mean, hgt_trend = _annual_domain_mean_trend(hgt_ds['hgt'], "hgt")
        print("Slope of hgt_trend: ", hgt_trend[0], "m per year")
        ## calculate detrended hgt
        hgt_ds['hgt_detrend'] = _remove_linear_trend(hgt_ds['hgt'], hgt_trend[0])
        ## calculate calendar-day standardized anomalies
        hgt_ds = _add_calday_standardized_anomalies(
            hgt_ds, 'hgt_detrend', 'mean_detrend', 'sd_detrend', 'hgt_anom_no_trend')
        ## save as netcdf files
        _save_netcdf(hgt_ds['hgt_detrend'], out_dir+"hgt.nc")
        _save_netcdf(hgt_ds['hgt_anom_no_trend'], out_dir+"hgt_calday_anomalies.nc")

        ######################### Soil Moisture: #############################
        # read daily data, calculate daily standardized anomalies

        soilw_ds = _open_region_subset(
            soilw_files, prep_files, region_input_lat_bbox, region_input_lon_bbox).astype(np.float64)
        if dset == 'NCEP':
            soilw_ds = soilw_ds.squeeze('level')
        if dset == 'ERA5':
            soilw_ds = soilw_ds.rename({'swvl1':'soilw'})

        lats = soilw_ds['lat']
        lons = soilw_ds['lon']

        ## linear trend of the annual, area-weighted mean over the entire input map
        soilw_domain_mean, soilw_trend = _annual_domain_mean_trend(soilw_ds['soilw'], "soilw")
        print("Slope of soilw_trend: ", soilw_trend[0], "per year (units: Volumetric soil moisture content fraction )")
        print("40-year trend results in total change approx: ", np.round(40*soilw_trend[0] / soilw_ds['soilw'].std(dim = 'time').mean(dim = ['lat', 'lon']).values,3)*100  ," % of standard deviation")
        ## calculate detrended soilw
        soilw_ds['soilw_detrend'] = _remove_linear_trend(soilw_ds['soilw'], soilw_trend[0])
        ## calculate calendar-day standardized anomalies
        soilw_ds = _add_calday_standardized_anomalies(
            soilw_ds, 'soilw_detrend', 'cal_day_mean', 'cal_day_sd', 'soilw_daily_anom')
        ## save as netcdf files
        _save_netcdf(soilw_ds['soilw_daily_anom'], out_dir+"soilw_calday_anomalies.nc")
        _save_netcdf(soilw_ds['cal_day_sd'], out_dir+"soilw_calday_stdev.nc")
        _save_netcdf(soilw_ds['cal_day_mean'], out_dir+"soilw_calday_mean.nc")
        _save_netcdf(soilw_ds['soilw_detrend'], out_dir+"soilw.nc")

        ######################### TMAX: #############################
        # read daily data, remove the domain average trend

        tmax_ds = _open_region_subset(
            tmax_files, prep_files, region_input_lat_bbox, region_input_lon_bbox)
        ## `level` is dropped when present; `drop_vars` replaces the deprecated `drop`
        tmax_ds = tmax_ds.drop_vars(['level'], errors='ignore').astype(np.float64)

        lats = tmax_ds['lat']
        lons = tmax_ds['lon']

        ## linear trend of the annual, area-weighted mean over the entire input map
        tmax_domain_mean, tmax_trend = _annual_domain_mean_trend(tmax_ds['tmax'], "tmax")
        print("Slope of tmax_trend: ", tmax_trend[0], "per year (units: Kelvin )")
        ## calculate detrended tmax
        tmax_ds['tmax_detrend'] = _remove_linear_trend(tmax_ds['tmax'], tmax_trend[0])
        ## save as netcdf files
        _save_netcdf(tmax_ds['tmax_detrend'], out_dir+"tmax.nc")
