In [1]:
#translates Steve Yeager's NCL scripts into python w/dask
#Processes and bias corrects CESM-DPLE output and writes out as a netcdf4
#Will work for CAM and POP fields, 2d fields
#Should be able to (eventually) handle annual, seasonal, and monthly means
#this version does not use dask b/c throughput is the main issue, so limitation
#comes from I/O, and adding more processors does not speed things up

#I've marked in ALL CAPS places that need to be set by user
#-Liz Maroon 9/3/2018

#Updated for xarray v0.11.2, 1/23/19
#Updated for 2d or 3d POP/CAM annual means


#import packages
import xarray as xr                   #for netcdf manipulation
import numpy as np                    #for numerics
from collections import OrderedDict   #for setting netcdf attributes
from dask.distributed import Client, LocalCluster   #dask stuff
from dask_jobqueue import PBSCluster
import os                             #these last three packages used to detect username/script location
import pwd
import sys
import glob
import datetime

%matplotlib inline

The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.
  s = dedents('\n' + '\n'.join(lines[first:]))


In [3]:
#SET VARIABLE INFO HERE:
VAR='HMXL' #VAR='TS'
MODEL='OCN'# MODEL='ATM' #SET HERE IF PROCESSING CAM OR POP OUTPUT - can write catches for LND/ICE later as needed
WHICHMEAN='ANN'  #For ocean monthly means, use a with-dask script #can also do any mean of consecutive months
NUMDIMS=2  #2d or 3d variable?

#WHERE ARE DPLE FILES CURRENTLY:
DPLE_DIR='/glade/scratch/kristenk/dple/netcdf'

#WHERE AND WHAT DO YOU WANT TO CALL OUTPUT FILES:
DPOUT='/glade/scratch/kristenk/dple_stuff/CESM-DP-LE.'+VAR+'.'+WHICHMEAN.lower()+'mean'#.nc'


Ideas for troubleshooting pesky, unexplained dask issues like killed or dropped workers, dangling streams
1. Try increasing amount of memory in PBSCluster call
2. Increase core/processes ratio in PBSCluster call
3. Increase/decrease chunk size in xr.mf_dataset call

In [4]:
#cluster = LocalCluster(n_workers=2,threads_per_worker=1)  #for geyser with 2 cores
#cluster = LocalCluster(n_workers=36,threads_per_worker=1,diagnostics_port=8861)  #for 1 cheyenne node with 36 cores
#cluster = LocalCluster(n_workers=72,threads_per_worker=1)  #for 2 cheyenne node, total 72 cores

#for geyser cluster
#cluster = SLURMCluster

#for cheyenne cluster
numnodes=4
if NUMDIMS==3: 
    memory='50GB'
else: memory='60GB'

cluster=PBSCluster(cores=36,processes=9,memory=memory,project='P93300670',queue='regular',walltime='01:30:00')
#cluster.scale(1)

client = Client(cluster)

cluster.scale(numnodes*9)

Port 8787 is already in use. 
Perhaps you already have a cluster running?
Hosting the diagnostics dashboard on a random port instead.


In [5]:
#client.restart()
client

0,1
Client  Scheduler: tcp://10.148.10.15:39133  Dashboard: http://10.148.10.15:38185/status,Cluster  Workers: 0  Cores: 0  Memory: 0 B


In [7]:
#Make array for start years
first_syear=1954
last_syear=2015
S=xr.DataArray(np.arange(first_syear+1,last_syear+1.5,1,dtype='int'),dims=['S'],coords={'S':np.arange(first_syear+1,last_syear+1.5,1,dtype='int')},name='S')


In [8]:
#fix for time bounds #changed zcoord to z_t_150m!!!
if MODEL=='OCN': tbname='time_bound'; zcoord='z_t'; latcoord='nlat'; loncoord='nlon'
elif MODEL=='ATM': tbname='time_bnds'; zcoord='plev'; latcoord='lat'; loncoord='lon'


#function for seasonal means - come back to later
def seamean(dataarray,sea):
    monthord='JFMAMJJASOND'
    #pri
    #mon1=sea[0]
    for m2 in monthord:
        if m1==m2: break
    #mon2=sea[1]
    #if 
monthord='JFMAMJJASOND'
for m2 in monthord:
    if 

In [9]:
%%time

oceancoords=['TLAT','TLONG']
da_list=[]

#for year in S[0:5].values-1:
def makeyear(year):
    print(year)
    loadthesefiles=sorted(glob.glob(f"{DPLE_DIR}/b.e11.BDP.f09_g16.{year}-11.*.nc"))
    if NUMDIMS==3:
        ds=xr.open_mfdataset(loadthesefiles,concat_dim="M", chunks={"time": 122,zcoord:1})
    else: ds=xr.open_mfdataset(loadthesefiles,concat_dim="M", chunks={"time": 24})
    ds['time'].values=ds[tbname][0,:,0].values
    
    ds=ds.assign_coords(M=np.arange(1,len(loadthesefiles)+1,1,dtype='int'))
    
    #fix for ocean coordinate size xarray bug
    if MODEL=='OCN': 
        for cc in oceancoords:
            if len(ds[cc].shape)>2:
                tempc=ds[cc][0,:,:].load()
                tempd=tempc.dims
                ds=ds.drop(cc)
                ds=ds.assign_coords(cc=xr.DataArray(tempc.values,coords=[ds[tt].values for tt in tempd],dims=tempd))
                ds=ds.rename({'cc':cc})
    
    da=ds[VAR]
    
    if 1 in np.shape(ds[VAR]): zdim_sq=True
    else: zdim_sq=False
    
    #removing z_t if present for 2d variable :
    if (MODEL=='OCN') and ('z_t' in da.dims) and zdim_sq:
        da=da.isel(z_t=0)
        da=da.drop('z_t')
#     elif ('z_t_150m' in da.dims):
#         da=da.isel(z_t_150m=0)
#         da=da.drop('z_t_150m')
#commented out lines above!! kk
     
    #get rid of any remaining unwanted dims/coords - want to have left only dims of time, M, lat, lon
    for dd in da.coords:
        if (dd not in oceancoords) and (dd not in da.dims): da=da.drop(dd)

    #time mean
    if WHICHMEAN=='ANN':
        da=da.groupby('time.year').mean('time').isel(year=slice(1,11))  
        #Really should be weighting by months - not doing this for now so can check against Steve's scripts
        da=da.rename({'year':'L'})  
        da['L'].values=np.arange(1,11,1,dtype='int')
                
    #add here later for seasonal means
    
    return da
    
    #da_list.append(da)
da_list=client.map(makeyear,list(S.values-1))
da_list=client.gather(da_list)


#FOR SST 5S
#2 geyser cores: 1 min 19s
#36 cheyenne cores: 14 s #50 s
#72 cheyenne cores: 40 s
#144 cheyenne cores: 30 s

#FOR SST all S
#144 cheyenne cores: 1 min 9 s

#FOR TEMP 5S
#72 cheyenne cores : 1 min 25s

#FOR TEMP all S
#144 cheyenne cores (36/9, 24,1 chunk): 1 min 54 s
#144 cheyenne cores (36/9, 50,1 chunk): 1 min 34 s
#144 cheyenne cores (36/9, 122,1 chunk): 1 min 30s


#FOR TREFHT 5S
#144 cheyenne cores: 18 s

#FOR TREFHT all S with 144 cores, 36 workers: 17-27 s
#144 cheyenne cores: 25 s

CPU times: user 35 s, sys: 3.26 s, total: 38.2 s
Wall time: 1min 4s


In [10]:
#ds

In [11]:
%%time

#get attributes
year=1954
loadthesefiles=sorted(glob.glob(f"{DPLE_DIR}/b.e11.BDP.f09_g16.{year}-11.*.nc"))
ds=xr.open_mfdataset(loadthesefiles,concat_dim="M", chunks={"time": 1})

ncattrs=ds.attrs
varattrs=ds[VAR].attrs
dimattrs={}
for dd in ds.dims:
    dimattrs[dd]=ds[dd].attrs
for cc in ds.coords:
    dimattrs[cc]=ds[cc].attrs


CPU times: user 6.79 s, sys: 886 ms, total: 7.68 s
Wall time: 18.3 s


In [12]:
%%time

#concatenate list into array by dimension S
array=xr.concat(da_list,dim=S)

#FOR TEMP all S, 144 cheyenne cores (36/9, chunks 20,1): 31 s
#FOR TEMP all S, 144 cheyenne cores (36/9, chunks 122,1): 30s

CPU times: user 325 ms, sys: 7.7 ms, total: 333 ms
Wall time: 339 ms


In [13]:
if NUMDIMS==2:
    array=array.transpose('S','L','M',latcoord,loncoord)
elif NUMDIMS==3:
    array=array.transpose('S','L','M',zcoord,latcoord,loncoord)

#fast step

In [14]:
%%time

#preparing to turn the DataArray back into a DataSet (so it can be written out as a netcdf)
array.name=VAR
array.attrs=varattrs
dimattrs['S']=OrderedDict([('long_name','start year')])
dimattrs['L']=OrderedDict([('long_name','lead year')])
dimattrs['M']=OrderedDict([('long_name','ensemble member')])

array.attrs=varattrs
for cc in array.coords:
    array[cc].attrs=dimattrs[cc]

#turning DataArray into DataSet and adding ncattrs

newds=array.to_dataset()

newds.attrs=ncattrs
newds.attrs['script']=os.path.basename(sys.argv[0])
now=datetime.datetime.now()
newds.attrs['history']='created by '+pwd.getpwuid(os.getuid()).pw_name+' on '+str(now)        

#fast step

CPU times: user 545 µs, sys: 108 µs, total: 653 µs
Wall time: 757 µs


In [15]:
%%time

climy0=1964
climy1=2014

vtime=newds['S']+0.5+newds['L']-1
vtime.values[~((vtime.values>climy0) & (vtime.values<(climy1+1)))]=np.nan
vtime.values[~np.isnan(vtime.values)]=1

CPU times: user 0 ns, sys: 2.89 ms, total: 2.89 ms
Wall time: 30.2 ms


In [16]:
def writeallout(ds,fstartname):
    ds.to_netcdf(fstartname+'.nc',engine='netcdf4')
    ensmean=ds[VAR].mean('M')
    drift=(ensmean*vtime).mean('S')
    biascorr=ds[VAR]-drift
    
    ds_drift=drift.to_dataset(name='drift')
    ds_drift.attrs['climatology']=str(climy0)+"-"+str(climy1)+", computed separately for each lead time"
    ds_drift.to_netcdf(fstartname+'.drift.nc',engine='netcdf4')
    
    ds_anom=biascorr.to_dataset(name='anom')
    ds_anom.attrs['climatology']=str(climy0)+"-"+str(climy1)+", computed separately for each lead time"
    ds_anom.to_netcdf(fstartname+'.anom.nc',engine='netcdf4')    

In [17]:
%%time

if NUMDIMS==2: 
    newds.load()
    writeallout(newds,DPOUT)
elif NUMDIMS==3:
    numlevs=len(newds[zcoord])
    #for zz in range(numlevs)[0:2]:
    for zz in range(numlevs)[3:16]:
        print(zz)
        onelevelds=newds.isel({zcoord:zz})
        onelevelds.load()
        strlev=str(zz+1).zfill(2)
        fnamebeg=DPOUT+'.LEV'+strlev
        writeallout(onelevelds,fnamebeg)
        del onelevelds
        
#cheyenne 144 cores for 5S SST: 59s
#cheyenne 144 cores for all S SST: 3min 23s

#cheyenne 144 cores for 5S TREFHT: 50s
#cheyenne 144 cores for 12 S TREFHT: 2 min 27s
#cheyenne 144 cores for all S TREFHT: 2min 51s

#cheyenne 144 cores for 2 layers of TEMP all S (36/9, 122,1 chunks): 1h 7min 17s

  return np.nanmean(a, axis=axis, dtype=dtype)


CPU times: user 2min 39s, sys: 1min 8s, total: 3min 47s
Wall time: 3min 36s


In [18]:
client

0,1
Client  Scheduler: tcp://10.148.10.15:39133  Dashboard: http://10.148.10.15:38185/status,Cluster  Workers: 36  Cores: 144  Memory: 240.12 GB


In [19]:
client.close()
cluster.close()