In [1]:
import sys, os
import matplotlib.pyplot as plt
import xarray as xr
import datetime
import numpy as np
import multiprocessing

from dateutil.relativedelta import relativedelta
from multiprocessing import Pool

In [2]:
# Show the xarray version this notebook was run with (displayed as the cell output).
xr.__version__

'0.16.1'

In [3]:
#import dask
#dask.config.set(scheduler="single-threaded")

# Part 1. Setup

In [4]:
# Base directory; the case directory used by this notebook is <workspace>/<casename>.
workspace = "/glade/work/lvank/for_Melchior/"

# Select the case to process by leaving exactly one `casename` assignment active.
# Historical
casename = "b.e21.BHIST.f09_g17.CMIP6-historical.011"

# SSP126
#casename = "b.e21.BSSP126cmip6.f09_g17.CMIP6-SSP1-2.6.102"

In [5]:
# The reader and writer functions below use relative paths ('hus/*.nc', 'monthly/...'),
# so make the case directory the current working directory.
os.chdir(os.path.join(workspace, casename))

In [6]:
!echo $OMP_NUM_THREADS

1


In [8]:
# Number of worker processes handed to multiprocessing.Pool in Part 3.
# Alternative kept for reference: NPROC = 10
NPROC = 1

# Part 2. Functions to read 6-hourly and daily data

In [9]:
def read_ATM(var_dict):
    """
    Open the atmosphere (ATM) variables and store them in `var_dict`.

    For each of hus, ta, ua, va, ps and tas, every NetCDF file matching
    '<name>/*.nc' (relative to the current working directory) is opened as
    one dataset, and the variable of the same name is stored under
    var_dict[<name>].  No corrections are needed for these variables.

    hus, ta, ua, va and ps are 6-hourly.  NOTE: tas has daily frequency
    and has to be dealt with later.

    Nothing is returned; `var_dict` is filled in place.
    """
    atm_names = ('hus', 'ta', 'ua', 'va', 'ps', 'tas')
    for name in atm_names:
        dataset = xr.open_mfdataset(f'{name}/*.nc', combine='by_coords')
        # Attribute-style access, i.e. dataset.hus, dataset.ta, ...
        var_dict[name] = getattr(dataset, name)

    # Disabled sanity checks on the expected number of records:
    #   len(var_dict[name]) == 125561 for hus, ta, ua, va, ps
    #   len(var_dict[name]) == 31391  for tas

In [10]:
def read_oceanvar(var_dict, varname, lat, lon):
    """
    Open one ocean variable and store it in `var_dict`.

    Ocean variables like TOS have been manually interpolated to the ATM grid
    using CDO, which messed up their LAT/LON.  The coordinates are therefore
    replaced by the `lat` and `lon` passed in, so the variable is aligned with
    the ATM variables before merging.

    Because this is POP/CICE output, the first of January at the beginning of
    the simulation is missing.  That is fixed later, when that particular month
    is processed; doing it here would require a large overhead (xr.concat is
    very slow on big datasets).

    Files are read from '<varname>/fv1_grid/*.nc'.  Nothing is returned; the
    result is stored under var_dict[varname].
    """
    dataset = xr.open_mfdataset(f'{varname}/fv1_grid/*.nc', combine='by_coords')
    ocean_da = dataset[varname]

    # WORKAROUND (disabled, kept for reference): since we are combining two
    # separate files, there is a single time step at 2065-01-01 which should
    # contain 4 time steps.  The idea was to drop this single time step here
    # and re-fill this day later in the script:
    #
    #     dropme = siconc.sel(time='2065-01-01').time.item()
    #     print('DROPPING ', dropme)
    #     siconc = siconc.drop([dropme], dim='time')

    # Discard the lat/lon coordinates that came with the file ...
    for stale_coord in ('lat', 'lon'):
        del ocean_da[stale_coord]

    # ... rename y/x to lat/lon ...
    ocean_da = ocean_da.rename({'y': 'lat', 'x': 'lon'})

    # ... and attach the coordinate values supplied by the caller.
    ocean_da['lat'] = lat
    ocean_da['lon'] = lon

    # Disabled sanity check: len(ocean_da) == 31390
    var_dict[varname] = ocean_da

In [11]:
def get_var_month(var_dict, mon_str, varname):
    """
    Return an xarray DataArray holding a single variable for a single month.

    Daily variables (tas, siconc, tos) are interpolated to 6-hourly values;
    every other variable is simply sliced to `mon_str` along time.

    For a few hard-coded (month, variable) combinations the POP-derived
    variable misses the first day of the month.  In those cases the first
    four time steps are copied, re-stamped with day 1, and prepended.
    """
    daily_vars = ('tas', 'siconc', 'tos')
    if varname in daily_vars:
        # Daily variables, interpolate to 6-hourly
        month_da = get_var_month_from_daily(var_dict, mon_str, varname)
    else:
        month_da = var_dict[varname].sel(time=slice(mon_str, mon_str))

    # (month, variable) pairs for which the first day is missing
    first_day_missing = (
        ('2015-01', 'tos'),
        ('2015-01', 'siconc'),
        ('2065-01', 'siconc'),
    )
    if (mon_str, varname) not in first_day_missing:
        return month_da

    # Variable from POP misses first day: copy day 2 into day 1
    first_day = month_da[[0, 1, 2, 3]].copy()
    day1_stamps = [stamp.replace(day=1) for stamp in first_day.time.data]
    first_day['time'] = xr.DataArray(day1_stamps, coords=[day1_stamps], dims='time')

    return xr.concat([first_day, month_da], dim='time')

In [12]:
def get_var_month_from_daily(var_dict, mon_str, varname):
    """
    Return one month of a daily variable, upsampled to 6-hourly values.

    Originally a special case for TAS, which has daily frequency and was
    troublesome to resample with CDO; it is now used for all daily variables
    because that is easier.

    The data are selected from `mon_str` ('YYYY-MM') through the first day of
    the following month, resampled to a 6-hour frequency with linear
    interpolation, and finally restricted to `mon_str` again.
    """
    # Year and month as integers -> first day of the requested month
    month_start = datetime.datetime(*map(int, mon_str.split('-')), 1)

    # First day of the next month, as a 'YYYY-MM-DD' string
    next_month_start = month_start + relativedelta(months=1)
    mon_str2 = next_month_start.strftime('%Y-%m-%d')

    window = var_dict[varname].sel(time=slice(mon_str, mon_str2))
    six_hourly = window.resample(time='6H').interpolate('linear')
    return six_hourly.sel(time=mon_str)

In [13]:
def get_DataSet_month(var_dict, mon_str):
    """
    Return an xarray Dataset with all variables for a single month.

    This is the main entry point: it calls get_var_month for every output
    variable, asserts that all resulting arrays have the same length, and
    merges them into one Dataset.
    """
    names = ('tas', 'ta', 'ua', 'va', 'ps', 'siconc', 'hus', 'tos')
    arrays = [get_var_month(var_dict, mon_str, name) for name in names]

    # All DataArrays must have the same length; the lengths are shown on failure
    lengths = [len(array) for array in arrays]
    assert len(set(lengths)) == 1, lengths

    return xr.merge(arrays)

In [14]:
def set_global_attributes(ds):
    """
    Write the global attributes of an output dataset, in place.

    Sets 'description', 'author', 'creation_date' (the current date as
    YYYY-MM-DD) and 'source_script' on ds.attrs.  Nothing is returned.
    """
    global_attrs = {
        'description': "6-hourly CESM output for forcing RACMO2 RCM",
        'author': "L.vankampenhout@uu.nl",
        'creation_date': datetime.datetime.now().strftime('%Y-%m-%d'),
        'source_script': "make_monthly.ipynb, archived at https://github.com/lvankampenhout/ssp126_scripts",
    }
    for key, value in global_attrs.items():
        ds.attrs[key] = value

## some serial testing

In [15]:
# var_dict = {}

# read_ATM(var_dict)
# lat, lon = var_dict['hus'].lat, var_dict['hus'].lon
# read_oceanvar(var_dict, 'siconc', lat, lon)
# read_oceanvar(var_dict, 'tos', lat, lon)

# ds = get_DataSet_month(var_dict, '1950-01')

# Part 3. Multi-core processing

For better throughput, we create a pool of workers that can process the years simultaneously. 

On Cheyenne, these commands can be used to create a pool of 10 DAV workers: 

```
execdav -m 100G -t 8:00:00 -n 1 --cpus-per-task=10

export OMP_NUM_THREADS=10

start_jupyterLab.sh
```

Note that I've struggled quite a lot to get this working. It only works when the input datasets are opened afresh inside every call to the worker function (`do_work`). Sharing a single global copy of the input data (the dictionary `var_dict`) did not work well together with `multiprocessing.Pool`; see my post here: https://github.com/pydata/xarray/issues/3781

In [16]:
import warnings
warnings.filterwarnings("ignore", message="has multiple fill values")

In [17]:
def do_work(year):
    """
    Worker function: build and write the twelve monthly files of one year.

    Each call opens its own private set of input datasets (`var_dict`) rather
    than sharing a global one, because of HDF5 errors (Dask + multiprocessing
    don't like each other).

    Parameters
    ----------
    year : int
        Calendar year to process; months 1..12 of this year are written to
        'monthly/<year>-<MM>.nc', relative to the current working directory.

    Returns
    -------
    None
    """
    print('processing ', year)

    # Make sure the output directory exists before any expensive work is done,
    # so that to_netcdf cannot fail on a missing folder. exist_ok=True keeps
    # this harmless when the directory is already there or when several
    # workers reach this line at the same time.
    os.makedirs('monthly', exist_ok=True)

    # Private input data for this call (see docstring)
    var_dict = {}
    read_ATM(var_dict)
    # The ocean variables are put on the lat/lon coordinates of the ATM variables
    lat, lon = var_dict['hus'].lat, var_dict['hus'].lon
    read_oceanvar(var_dict, 'siconc', lat, lon)
    read_oceanvar(var_dict, 'tos', lat, lon)

    for month in range(1, 13):
        mon_str = f'{year}-{month:02d}'
        ds = get_DataSet_month(var_dict, mon_str)
        set_global_attributes(ds)
        ds.to_netcdf(f'monthly/{mon_str}.nc')

In [18]:
# Report how many worker processes the pool below will use.
print(f'running {NPROC} cores')

running 1 cores


In [19]:
# Years to process, one do_work call per year.
# Alternative kept for reference: years = range(1950, 2015)
years = [2011, 2013]
# Pool.map blocks until do_work has returned for every year; leaving the
# `with` block then shuts the worker processes down.
with Pool(processes=NPROC) as pool:
    pool.map(do_work, years)

processing  2011


  new_vars[k] = decode_cf_variable(
  new_vars[k] = decode_cf_variable(
  new_vars[k] = decode_cf_variable(
  new_vars[k] = decode_cf_variable(
  new_vars[k] = decode_cf_variable(
  new_vars[k] = decode_cf_variable(


processing  2013


  new_vars[k] = decode_cf_variable(
  new_vars[k] = decode_cf_variable(
  new_vars[k] = decode_cf_variable(
  new_vars[k] = decode_cf_variable(
  new_vars[k] = decode_cf_variable(
  new_vars[k] = decode_cf_variable(


In [None]:
# q = queue.Queue()


# def worker():
#     while True:
#         item = q.get()
#         print(f'Working on {item}')
        
#         ds = get_DataSet_month(mon_str)
#         set_global_attributes(ds)
#         ds.to_netcdf(f'monthly/{mon_str}.nc')
        
#         print(f'Finished {item}')
#         q.task_done()

# # turn-on the worker thread
# threading.Thread(target=worker, daemon=True).start()

# # send thirty task requests to the worker
# for year in range(2015,2016):
#     for month in range(1,13):
#         q.put(f'{year}-{month:02d}')
        
# print('All task requests sent\n', end='')

# # block until all tasks are done
# q.join()
# print('All work completed')