In [None]:
%load_ext autoreload 
%autoreload 2

In [None]:
import json
import xarray as xr
import dask
import numpy as np
from dask.distributed import Client, LocalCluster
import argparse

from bc_module_v2 import bc_module
import modules

In [None]:
# Forecast issue date and target domain.
# These are currently set inline; the JSON-driven variant is kept for reference:
#   f = open('bcsd_parameter.json')
#   parameter = json.load(f)
#   month = parameter["issue_date"]["month"]
#   year = parameter["issue_date"]["year"]
#   domain_names = [domain_names['name'] for domain_names in parameter["domains"]]
year, month, domain_name = '2022', '04', 'Khuzestan'

# Dataset version and the root directories of the gridded data
file_params = dict(
    version='2.2',
    glbldir='/pd/data/regclim_data/gridded_data/seasonal_predictions/seas5/',
    regroot='/pd/data/regclim_data/gridded_data/processed/',
)

# Settings forwarded to the metadata helper and to the bias-correction routine.
# 'window' is the half-width (in days) of the day-of-year window that is built
# around each forecast day in the correction loop.
bc_params = dict(
    dry_thresh=0.01,
    low_extrapol="delta_additive",
    up_extrapol="delta_additive",
    extremes="weibull",
    intermittency=True,
    nquants=2500,
    window=15,
)

In [None]:
# Resolve the names of all input and output files for this forecast.
# Note: `month` is overwritten with the value returned by set_filenames, so
# re-running this cell passes the already-updated month back in.
obs_dict, mdl_dict, pred_dict, month, bc_out_lns = modules.set_filenames(
    month,
    year,
    domain_name,
    file_params["regroot"],
    file_params["version"],
)

# Take the output dimensions from the first entry of the current prediction files
coords = modules.get_coords_from_files(list(pred_dict.values())[0])

# Assemble the global and per-variable metadata for the output file
global_attributes, variable_attributes = modules.set_metadata(coords, bc_params)

In [None]:
# Create the (still empty) NetCDF file that will receive the BCSD output
ds = modules.create_4d_netcdf(
    bc_out_lns,
    global_attributes,
    variable_attributes,
    coords,
)

# Open that file to obtain the dataset handle the results are attached to
ds_out = xr.open_dataset(bc_out_lns)

In [None]:
# Request compute resources: a dask client plus the cluster behind it.
# NOTE(review): the arguments are presumably queue name, number of nodes and
# cores per node - confirm against modules.getCluster.
client, cluster = modules.getCluster('haswell', 2, 40)

# Switch on dask's Active Memory Manager for this client
client.amm.start()

# Print the address of the cluster dashboard
print(cluster.dashboard_link)

In [None]:
# Shut down the dask client and the cluster it is connected to.
# NOTE(review): this cell sits above the correction loop, so running the
# notebook top-to-bottom closes the cluster before the loop starts. Run it
# only after the loop has finished, or move it to the end of the notebook.
client.close()
cluster.close()

In [None]:
    # Bias-correct every variable listed in the output metadata, one forecast
    # timestep at a time, and attach the corrected field to ds_out.
    for variable in variable_attributes:

        ###### Old IO-Module #####
        # Load data as dask objects
        # Obs (1981 - 2016 on daily basis)
        ds_obs = xr.open_mfdataset(obs_dict[variable], chunks={'time': 13149, 'lat': 50, 'lon': 50}, parallel=True, engine='h5netcdf')
        ds_obs = ds_obs[variable].persist()

        # Mdl (historical, 1981 - 2016 for one month and 215 days)  215, 36, 25, 1, 1 ;
        # Preprocess historical mdl-data, create a new time coord, which contains year and day at once and not separately
        ds_mdl = modules.preprocess_mdl_hist(mdl_dict[variable], month) # chunks={'time': 215, 'year': 36, 'ens': 25, 'lat': 1, 'lon': 1})
        ds_mdl = ds_mdl[variable].persist()

        # Pred (current year for one month and 215 days)
        ds_pred = xr.open_mfdataset(pred_dict[variable], chunks={'time': 215, 'ens': 51, 'lat': 50, 'lon': 50}, parallel=True, engine='h5netcdf')
        ds_pred = ds_pred[variable].persist()

        # Change data type of latitude and longitude, otherwise apply_ufunc does not work
        ds_pred = ds_pred.assign_coords(lon=ds_pred.lon.values.astype(np.float32), lat=ds_pred.lat.values.astype(np.float32))

        # Calculate day of the year from the time variable
        dayofyear_obs = ds_obs['time.dayofyear']
        dayofyear_mdl = ds_mdl['time.dayofyear']

        # In-memory buffer for the corrected forecast of this variable.
        # It is filled with NaN (not None) so that the array has a float dtype
        # instead of an object dtype.
        da_temp = xr.DataArray(
            np.nan,
            dims = ['time', 'lat', 'lon', 'ens'],
            coords = {
                'time': ('time', coords['time'], {'standard_name': 'time', 'long_name': 'time'}),
                'ens': ('ens', coords['ens'], {'standard_name': 'realization', 'long_name': 'ensemble_member'}),
                # Latitude is measured in degrees north, longitude in degrees east
                'lat': ('lat', coords['lat'], {'standard_name': 'latitude', 'long_name': 'latitude', 'units': 'degrees_north'}),
                'lon': ('lon', coords['lon'], {'standard_name': 'longitude', 'long_name': 'longitude', 'units': 'degrees_east'})
            }
        )

        # Precipitation ("tp") is flagged separately for the bias-correction routine
        precip = variable == "tp"

        for timestep in range(len(ds_pred.time)):

            print(f'Correcting timestep {timestep}...')

            # Day of year (1-based) of the current forecast timestep
            day = int(dayofyear_mdl[timestep])

            # Symmetric window of +/- bc_params['window'] days around `day`,
            # wrapped into 1..365. Shifting to 0-based before the modulo keeps
            # the window centred on `day`; day 366 is never part of the window.
            day_range = (np.arange(day - bc_params['window'], day + bc_params['window'] + 1) - 1) % 365 + 1

            # Boolean masks selecting all timesteps whose day of year lies in the window
            intersection_day_obs = np.isin(dayofyear_obs, day_range)
            intersection_day_mdl = np.isin(dayofyear_mdl, day_range)

            ds_obs_sub = ds_obs.loc[dict(time=intersection_day_obs)]

            ds_mdl_sub = ds_mdl.loc[dict(time=intersection_day_mdl)]

            # Pool ensemble members and time into one sample dimension
            ds_mdl_sub = ds_mdl_sub.stack(ens_time=("ens", "time"), create_index=True)
            ds_mdl_sub = ds_mdl_sub.drop_vars('time')

            ds_pred_sub = ds_pred.isel(time=timestep)

            # Apply the bias correction for every grid cell; the core dimensions
            # are the sample dimensions handed to bc_module.
            pred_corr_act = xr.apply_ufunc(
                bc_module,
                ds_pred_sub,
                ds_obs_sub,
                ds_mdl_sub,
                kwargs={'bc_params': bc_params, 'precip': precip},
                input_core_dims=[["ens"], ["time"], ['ens_time']],
                output_core_dims=[["ens"]],
                vectorize=True,
                dask="parallelized",
                output_dtypes=[np.float64])

            # Store the corrected timestep in the buffer
            da_temp.loc[dict(time=ds_pred.time.values[timestep])] = pred_corr_act

        ds_out[variable] = da_temp.transpose('time', 'ens', 'lat', 'lon')

    # TODO: We still need to write the NetCDF...