In [2]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import glob
import sys
import dask
import pop_tools 
from functools import partial
import matplotlib.pyplot as plt

from CASutils import lensread_utils as lens

# Destination for processed output, and root of the CESM2-LE archive.
pathout = "/glade/scratch/islas/python/singleforcing/DATA_SORT/cesm2_le/"
topdir = "/glade/campaign/cgd/cesm/CESM2-LE/"

# Ensemble-member identifiers; each entry is used below as a glob fragment to
# select that member's timeseries files.
nmem = 50
memstr = lens.lens2memnamegen_second50(nmem)

In [3]:
from dask_jobqueue import PBSCluster
from dask.distributed import Client

# Each dask worker is one single-core, 20 GB PBS job on the Casper queue.
# local_directory is passed through literally; presumably '$TMPDIR' is expanded
# by the shell inside the generated job script — confirm if workers misbehave.
cluster = PBSCluster(
    queue='casper',
    project='P04010022',
    walltime='06:00:00',
    cores=1,
    processes=1,
    memory='20GB',
    resource_spec='select=1:ncpus=1:mem=20GB',
    local_directory='$TMPDIR',
    interface='ib0',
)

n_workers = 40
cluster.scale(n_workers)

# Route dashboard links through the JupyterHub proxy; {USER} and {port} are
# placeholders filled in by dask, not Python format fields.
dashboard_link = 'https://jupyterhub.hpc.ucar.edu/stable/user/{USER}/proxy/{port}/status'
dask.config.set({'distributed.dashboard.link': dashboard_link})

client = Client(cluster)

In [5]:
# Display the client summary (scheduler address, dashboard link, worker count).
client

0,1
Client  Scheduler: tcp://10.12.206.54:34132  Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/islas/proxy/8787/status,Cluster  Workers: 40  Cores: 40  Memory: 800.00 GB


In [6]:
def fixtime(ds, bounds='time_bound'):
    """Return `ds` with its ``time`` coordinate replaced by the bounds midpoints.

    The midpoint of each row of the bounds variable (mean over axis 1) is taken
    at 1-second resolution. The first axis of the bounds must be ``time``. A
    new object is returned and the input dataset is not modified, unlike an
    in-place ``ds['time'] = ...`` assignment.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with a ``time`` dimension and a bounds variable whose values
        can be converted to ``datetime64[s]``.
    bounds : str, default 'time_bound'
        Name of the bounds variable in `ds`.

    Returns
    -------
    xarray.Dataset
        Copy of `ds` with the new ``time`` coordinate.
    """
    # Work on integer seconds so the mean is a plain numeric average.
    edges = np.array(ds[bounds], dtype='datetime64[s]').view('i8')
    midpoints = edges.mean(axis=1).astype('datetime64[s]')
    return ds.assign_coords(time=midpoints)

In [7]:
def vertintegrate(ds, dz, nlev=20):
    """Thickness-weighted vertical mean of `ds` over the first `nlev` z_t levels.

    Despite the name, this returns a mean rather than an integral: the sum of
    ``ds * dz`` divided by the summed thickness.

    Parameters
    ----------
    ds : xarray.DataArray or xarray.Dataset
        Field with a ``z_t`` dimension.
    dz : xarray.DataArray
        Layer thicknesses along ``z_t``, used as weights. Units cancel.
    nlev : int, default 20
        Number of levels, starting at z_t index 0, to include.

    Returns
    -------
    Same type as `ds`, with ``z_t`` reduced.

    Only levels where `ds` is valid contribute to the normalising thickness. A
    plain ``(ds*dz).sum('z_t') / dz.sum('z_t')`` instead turns all-NaN (land)
    columns into 0 and divides partial (shelf) columns by the full thickness,
    which biases any later horizontal average low.
    """
    top = slice(0, nlev)
    # weighted() rejects NaN weights; thickness should never be NaN, but be safe.
    thickness = dz.isel(z_t=top).fillna(0)
    # weighted().mean masks the weights wherever ds is NaN and returns NaN where
    # no valid level remains.
    return ds.isel(z_t=top).weighted(thickness).mean('z_t')

In [8]:
def labseaavg(ds, tarea, lon_bounds=(300, 325), lat_bounds=(50, 65)):
    """Cell-area-weighted mean of `ds` over a lon/lat box (default: Labrador Sea).

    Parameters
    ----------
    ds : xarray.DataArray or xarray.Dataset
        Field with horizontal dims ``nlat`` and ``nlon``.
    tarea : xarray.DataArray
        Cell areas carrying ``TLONG`` and ``TLAT`` coordinates.
    lon_bounds, lat_bounds : (float, float)
        Exclusive (lower, upper) box limits compared directly against TLONG and
        TLAT. The defaults assume TLONG is in degrees east on a 0-360 range.

    Returns
    -------
    `ds` averaged over ``nlat``/``nlon``. Cells outside the box get zero weight,
    and NaN values of `ds` are skipped by the weighted mean.
    """
    lon_lo, lon_hi = lon_bounds
    lat_lo, lat_hi = lat_bounds
    in_box = ((tarea.TLONG > lon_lo) & (tarea.TLONG < lon_hi)
              & (tarea.TLAT > lat_lo) & (tarea.TLAT < lat_hi))
    # weighted() does not accept NaN weights, so missing areas become 0.
    wgts = xr.where(in_box, tarea, 0).fillna(0)
    return ds.weighted(wgts).mean(("nlon", "nlat"))

In [9]:
# For every ensemble member, build Labrador Sea upper-ocean density anomalies and
# split them into temperature- and salinity-driven parts using the density
# derivatives returned by pop_tools.eos.

MFDS_CHUNKS = {"time": 20, "z_t": 60, "nlat": 120, "nlon": 120}
TIME_SLICE = slice("1920-01", "2050-12")
# Reference depth (m) for the equation of state; the author's note: "because
# I'm using top 203m", i.e. the middle of the vertically averaged layer.
REF_DEPTH_M = 101.5


def _member_files(var):
    """Per-member sorted lists of monthly timeseries files for POP variable `var`."""
    return [sorted(glob.glob(topdir+"/ocn/proc/tseries/month_1/"+var+"/*"+imem+"*.nc"))
            for imem in memstr]


def _lab_sea_series(files, var, member):
    """Load one member's Labrador Sea, upper-ocean mean time series of `var`.

    The series is computed into memory before the files are closed, so the
    underlying data is read exactly once.
    """
    if not files:
        raise FileNotFoundError(f"no {var} files found for member {member}")
    with xr.open_mfdataset(files, concat_dim='time', parallel=True, chunks=MFDS_CHUNKS) as raw:
        dat = fixtime(raw).sel(time=TIME_SLICE)
        column_mean = vertintegrate(dat[var], dat.dz)
        return labseaavg(column_mean, dat.TAREA).load()


filelist_salt = _member_files("SALT")
filelist_temp = _member_files("TEMP")

rho_members, salt_members, temp_members = [], [], []
for imem, (member, salt_files, temp_files) in enumerate(zip(memstr, filelist_salt, filelist_temp)):
    print(imem)

    salt_lab = _lab_sea_series(salt_files, "SALT", member)
    temp_lab = _lab_sea_series(temp_files, "TEMP", member)

    # Constant reference depth laid out like the series (same dims and coords).
    ref_depth = xr.DataArray(np.full(temp_lab.shape, REF_DEPTH_M),
                             dims=temp_lab.dims, coords=temp_lab.coords)
    rho, drhods, drhodt = pop_tools.eos(salt=salt_lab, temp=temp_lab,
                                        return_coefs=True, depth=ref_depth)

    temp_anom = temp_lab - temp_lab.mean('time')
    salt_anom = salt_lab - salt_lab.mean('time')
    rho_anom = rho - rho.mean('time')

    rho_temp = drhodt * temp_anom
    # NOTE(review): the 1/1000 factor was found empirically by the author ("not
    # sure why, but this is necessary"). Presumably drhods is per salinity unit
    # in kg/kg while SALT is in g/kg; confirm against pop_tools.eos.
    rho_salt = drhods * salt_anom / 1000.

    rho_members.append(rho_anom.rename('RHO').load())
    temp_members.append(rho_temp.rename('RHO_TEMP').load())
    salt_members.append(rho_salt.rename('RHO_SALT').load())

rho_all = xr.concat(rho_members, dim='M')
salt_all = xr.concat(salt_members, dim='M')
temp_all = xr.concat(temp_members, dim='M')

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49


In [None]:
# Write RHO, RHO_SALT and RHO_TEMP into one netCDF file in a single call.
# Repeated to_netcdf(mode='a') calls re-write the shared coordinates each time
# and can leave a half-written file if a later append fails.
xr.merge([rho_all, salt_all, temp_all]).to_netcdf(pathout+'RHO_components_LENS2.nc')

In [None]:
# Disconnect the client first so it does not error on a vanished scheduler,
# then release the PBS worker jobs.
client.close()
cluster.close()
