In [1]:
import xarray as xr
import numpy as np
import glob
import sys
import dask
import pop_tools
from functools import partial
import matplotlib.pyplot as plt

# Number of xAER ensemble members to process (members 001 .. nmemsxaer).
nmemsxaer = 3

# Directory the output netCDF file is written to.
pathout = "/glade/scratch/islas/python/singleforcing/DATA_SORT/cesm2_xaer/"

# Root directory under which the per-member case directories are globbed.
topdir = "/glade/campaign/cesm/collections/CESM2-SF/timeseries/"

In [2]:
from dask_jobqueue import PBSCluster
from dask.distributed import Client

# Options for the PBS jobs that host the dask workers
# (one core, one worker process and 20GB of memory per job).
pbs_options = dict(
    cores=1,
    processes=1,
    memory='20GB',
    resource_spec='select=1:ncpus=1:mem=20GB',
    queue='casper',
    project='P04010022',
    walltime='02:00:00',
    interface='ib0',
    local_directory='$TMPDIR',
)
cluster = PBSCluster(**pbs_options)

# Scale the cluster up to 25 workers.
cluster.scale(25)

# Route the dashboard link through the JupyterHub proxy so it can be opened
# from the browser.
dashboard_url = 'https://jupyterhub.hpc.ucar.edu/stable/user/{USER}/proxy/{port}/status'
dask.config.set({'distributed.dashboard.link': dashboard_url})

# Attach a client to the cluster.
client = Client(cluster)

Perhaps you already have a cluster running?
Hosting the HTTP server on port 46478 instead
  http_address["port"], self.http_server.port


In [3]:
#cluster.close()

In [4]:
# Display the client summary (scheduler address, dashboard link, cluster resources).
client

0,1
Client  Scheduler: tcp://10.12.206.54:36424  Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/islas/proxy/46478/status,Cluster  Workers: 0  Cores: 0  Memory: 0 B


In [5]:
def fixtime(ds):
    """Replace ``ds['time']`` with the mean of ``ds.time_bound`` along axis 1.

    The bounds are converted to ``datetime64[s]``, reinterpreted as integer
    seconds, averaged across axis 1, and cast back to ``datetime64[s]``.

    Parameters
    ----------
    ds : dataset-like
        Object exposing a 2-D ``time_bound`` and supporting item assignment.
        It is modified in place.

    Returns
    -------
    The same ``ds`` object, with ``time`` overwritten.
    """
    bounds = np.array(ds.time_bound, dtype='datetime64[s]')
    bounds_as_seconds = bounds.view('i8')
    midpoint_seconds = bounds_as_seconds.mean(axis=1)
    ds['time'] = midpoint_seconds.astype('datetime64[s]')
    return ds

In [6]:
def vertintegrate(ds, dz, nlev=20):
    """Thickness-weighted vertical mean of ``ds`` over the first ``nlev`` levels.

    Despite the name, the result is a mean, not an integral: the sum of
    ``ds * dz`` over ``z_t`` divided by the sum of ``dz`` over ``z_t``.

    Levels where ``ds`` is missing are left out of both the numerator and the
    denominator, so a column with fewer valid levels is averaged over the
    thickness it actually has. A column with no valid level at all yields NaN
    rather than 0, which keeps such points out of any later NaN-skipping
    average.

    Parameters
    ----------
    ds : xarray.DataArray
        Field with a ``z_t`` dimension.
    dz : xarray.DataArray
        Weights along ``z_t`` (layer thickness in the caller).
    nlev : int, optional
        Number of leading ``z_t`` levels to include. Defaults to 20.

    Returns
    -------
    xarray.DataArray
        ``ds`` with the ``z_t`` dimension reduced.
    """
    upper = slice(0, nlev)
    ds = ds.isel(z_t=upper)
    # Count thickness only where the field has data; otherwise missing levels
    # inflate the denominator and bias the mean low.
    dz = dz.isel(z_t=upper).where(ds.notnull())
    # min_count=1 makes an all-missing column sum to NaN instead of 0.
    weighted_sum = (ds * dz).sum('z_t', min_count=1)
    total_thickness = dz.sum('z_t', min_count=1)
    return weighted_sum / total_thickness

In [7]:
def labseaavg(ds, tarea):
    """Area-weighted mean of ``ds`` over the box 300 < TLONG < 325, 50 < TLAT < 65.

    Parameters
    ----------
    ds : xarray.DataArray
        Field with ``nlon`` and ``nlat`` dimensions.
    tarea : xarray.DataArray
        Weights (cell area in the caller) carrying ``TLONG`` and ``TLAT``.

    Returns
    -------
    xarray.DataArray
        ``ds`` averaged over ``nlon`` and ``nlat`` with the box weights.
    """
    lon = tarea.TLONG
    lat = tarea.TLAT
    in_box = (lon > 300) & (lon < 325) & (lat > 50) & (lat < 65)

    # Outside the box the weight is zero; missing weights are also zeroed.
    box_weights = xr.where(in_box, tarea, 0).fillna(0)
    return ds.weighted(box_weights).mean(("nlon", "nlat"))

In [8]:
# Zero-padded ensemble member labels: "001", "002", ...
memstr = [str(i).zfill(3) for i in range(1, nmemsxaer + 1)]


def _member_files(imem, varname):
    """Return the month_1 time-series files of `varname` for member `imem`.

    The sorted BHISTcmip6 files come first, followed by the sorted
    BSSP370cmip6 files. Raises FileNotFoundError when nothing matches.
    """
    files = []
    for compset in ("BHISTcmip6", "BSSP370cmip6"):
        pattern = (topdir + "b.e21." + compset + ".f09_g17.CESM2-SF-xAER." + imem
                   + "/ocn/proc/tseries/month_1/*." + varname + ".*.nc")
        files.extend(sorted(glob.glob(pattern)))
    if not files:
        raise FileNotFoundError(
            "no " + varname + " files found for member " + imem + " under " + topdir)
    return files


def _labsea_upper_mean(files, varname):
    """Open `files` as one dataset along time and reduce `varname` to a time series.

    The variable is averaged vertically with `vertintegrate` (weights `dz`)
    and then horizontally with `labseaavg` (weights `TAREA`).
    """
    # combine='nested' is required for concat_dim to take effect: the files
    # are concatenated along time in the order given.
    dat = xr.open_mfdataset(files, combine='nested', concat_dim='time', parallel=True,
                            chunks={"time": 20, "z_t": 60, "nlat": 120, "nlon": 120})
    dat = fixtime(dat)
    column_mean = vertintegrate(dat[varname], dat.dz)
    return labseaavg(column_mean, dat.TAREA)


filelist_salt = [_member_files(imem, "SALT") for imem in memstr]
filelist_temp = [_member_files(imem, "TEMP") for imem in memstr]

# Per-member results. Note that salt_all / temp_all collect the salinity and
# temperature *contributions to density* (RHO_SALT / RHO_TEMP), not S and T.
rho_all = []
salt_all = []
temp_all = []

for imem, (salt_files, temp_files) in enumerate(zip(filelist_salt, filelist_temp)):
    print(imem)

    salt_lab = _labsea_upper_mean(salt_files, "SALT")
    temp_lab = _labsea_upper_mean(temp_files, "TEMP")

    # Constant reference depth of 101.5 with the same shape and coordinates
    # as the time series (original note: because the top 203m is used).
    ref_depth = xr.DataArray(np.zeros(np.shape(temp_lab)), dims=temp_lab.dims,
                             coords=temp_lab.coords) + 101.5
    rho, drhods, drhodt = pop_tools.eos(salt=salt_lab, temp=temp_lab,
                                        return_coefs=True, depth=ref_depth)

    # Anomalies relative to each member's own time mean.
    temp_anom = temp_lab - temp_lab.mean('time')
    salt_anom = salt_lab - salt_lab.mean('time')
    rho_anom = rho - rho.mean('time')

    # Linearised density contributions of the temperature and salinity anomalies.
    rho_temp = drhodt * temp_anom
    # Original note: "not sure why, but this is necessary".
    # NOTE(review): presumably a salinity-unit mismatch between SALT and the
    # drhods returned by pop_tools.eos -- TODO confirm against pop_tools.
    rho_salt = drhods * salt_anom / 1000.

    # Name the results and pull them into memory before the next member.
    rho_anom = rho_anom.rename('RHO').load()
    rho_temp = rho_temp.rename('RHO_TEMP').load()
    rho_salt = rho_salt.rename('RHO_SALT').load()

    rho_all.append(rho_anom)
    salt_all.append(rho_salt)
    temp_all.append(rho_temp)

# Stack the members along a new ensemble dimension 'M'.
rho_all = xr.concat(rho_all, dim='M')
salt_all = xr.concat(salt_all, dim='M')
temp_all = xr.concat(temp_all, dim='M')

0
1
2


In [9]:
# All three components go into the same file: the first write creates
# (or overwrites) it, the remaining two append to it.
outfile = pathout + 'RHO_components_xAER2.nc'
for component, write_mode in ((rho_all, 'w'), (salt_all, 'a'), (temp_all, 'a')):
    component.to_netcdf(outfile, mode=write_mode)

In [10]:
# Disconnect the client first, then shut down the cluster, so that no client
# is left attached to a scheduler that has gone away.
client.close()
cluster.close()