In [1]:
from xarray.core.dataarray import DataArray
import metpy.constants as mpconsts
import numpy as np
from metpy import calc as mpcalc
from metpy.interpolate import log_interpolate_1d as log_interp_1d
from metpy.units import units
import os
from pandas.core.series import Series
from xarray.core.dataarray import DataArray
from tqdm import tqdm


import os
import glob
import numpy as np 
import xarray as xr
import pandas as pd
import datetime
from datetime import date, timedelta
import dask
import metpy as mp
from metpy.units import units
import metpy.calc as mpc
import Ngl

# Plotting utils 
import matplotlib
import matplotlib.pyplot as plt 
import cartopy
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cartopy.util

In [None]:
# Grabbed from Brian M. to use time midpoints, not end periods
def cesm_correct_time(ds):
    """Given a Dataset, check for time_bnds,
       and use avg(time_bnds) to replace the time coordinate.
       Purpose is to center the timestamp on the averaging inverval.   
       NOTE: ds should have been loaded using `decode_times=False`
    """
    assert 'time_bnds' in ds
    assert 'time' in ds
    correct_time_values = ds['time_bnds'].mean(dim='nbnd')
    # copy any metadata:
    correct_time_values.attrs = ds['time'].attrs
    ds = ds.assign_coords({"time": correct_time_values})
    ds = xr.decode_cf(ds)  # decode to datetime objects
    return ds

## Function to regrid data from ADF
def regrid_data(fromthis, tothis, method=1):
    """Regrid data using various different methods"""

    #Import necessary modules:
    import xarray as xr

    if method == 1:
        # kludgy: spatial regridding only, seems like can't automatically deal with time
        if 'time' in fromthis.coords:
            result = [fromthis.isel(time=t).interp_like(tothis) for t,time in enumerate(fromthis['time'])]
            result = xr.concat(result, 'time')
            return result
        else:
            return fromthis.interp_like(tothis)
    elif method == 2:
        newlat = tothis['lat']
        newlon = tothis['lon']
        coords = dict(fromthis.coords)
        coords['lat'] = newlat
        coords['lon'] = newlon
        return fromthis.interp(coords)
    elif method == 3:
        newlat = tothis['lat']
        newlon = tothis['lon']
        ds_out = xr.Dataset({'lat': newlat, 'lon': newlon})
        regridder = xe.Regridder(fromthis, ds_out, 'bilinear')
        return regridder(fromthis)
    elif method==4:
        # geocat
        newlat = tothis['lat']
        newlon = tothis['lon']
        result = geocat.comp.linint2(fromthis, newlon, newlat, False)
        result.name = fromthis.name
        
        return result

# - - - - - - - - - - - - - - - 
# Pre-process data while reading in 
# - - - - - - - - - - - - - - - 

def preprocess_atm(ds):
    ds         = cesm_correct_time(ds)
    dsSel      = ds[['SHFLX','LHFLX','PRECT']]
    # Also select spatial range
    dsSel = dsSel.sel(lat=slice(10,50), lon=slice(190,310))
    return dsSel

def preprocess_atm_profiles(ds):
    ds         = cesm_correct_time(ds)
    dsSel      = ds[['T','Q','PS','hyam','hyai','hybm','hybi','P0']]
    # Also select spatial range
    dsSel = dsSel.sel(lat=slice(10,50), lon=slice(190,310), lev=slice(500,1000))
    return dsSel
    
def interpolateToPressure(DS, varName, pressGoals):
    p0mb = DS.P0.values/100        # mb

    # Pull out hya/hyb profiles 
    hyam = np.squeeze(DS.hyam.values)[:]
    hybm = np.squeeze(DS.hybm.values)[:]
    hyai = np.squeeze(DS.hyai.values)[:]
    hybi = np.squeeze(DS.hybi.values)[:]

    # Surface pressure with time dimension
    PS   = DS.PS.values              # Pa 

    # Converting variables: 
    varInterp = Ngl.vinth2p(DS[varName].values,hyam,hybm,pressGoals,PS,1,p0mb,1,True)
    
    return varInterp



In [10]:
import dask
dask.config.set({'logging.distributed': 'error'})

from dask_jobqueue import PBSCluster

# For Casper
cluster = PBSCluster(
    queue="casper",
    walltime="02:00:00",
    account="P93300042",
    memory="4GB",
    resource_spec="select=1:ncpus=2:mem=4GB",
    cores=1,
    processes=1,
)

from dask.distributed import Client

# Connect client to the remote dask workers
client = Client(cluster)
print(client)

cluster.scale(8)

client.wait_for_workers(8)


Perhaps you already have a cluster running?
Hosting the HTTP server on port 33395 instead


<Client: 'tcp://128.117.208.85:38043' processes=0 threads=0, memory=0 B>


## Get CESM data

In [5]:
caseNames = [
            'f.e21.F2000climo.f09_f09_mg17.S2S_LandAtmCoupling_output.002',
            'f.e21.F2000climo.f09_f09_mg17.ReduceDSL_0p8to0p5.S2S_LandAtmCoupling_output.002',
            # 'i.e21.I2000Clm50Sp.f09_f09_mg17.S2S_LandAtmCoupling_output.002',
            # 'f.e21.FHIST.f09_f09_mg17.S2S_LandAtmCoupling_output.002',
           ]

case_IDs = [
            'F2000climo_ctrl',
            'F2000climo_dsl0p5',
            # 'I2000Clm50Sp',
            # 'FHIST_ctrl', 
            ]

dataDir = '/glade/campaign/cgd/tss/people/mdfowler/LandAtmCoupling_longRuns/'


In [6]:
%%time 

iCase=0

# for iCase in range(len(case_IDs)):
print('*** Starting on case %s ***' % (case_IDs[iCase]))

## Select files with daily means (h1 files in this case)
listFiles_atm_lev = np.sort(glob.glob(dataDir+caseNames[iCase]+'/atm/hist/*cam.h3.????-*'))


DS_case0 = xr.open_mfdataset(listFiles_atm_lev,  preprocess=preprocess_atm_profiles, concat_dim='time', 
                            combine='nested', decode_times=False, 
                            data_vars='minimal', parallel=True)

print('atm 3D files loaded')


*** Starting on case F2000climo_ctrl ***
atm 3D files loaded
CPU times: user 28.9 s, sys: 5.96 s, total: 34.8 s
Wall time: 5min 1s


In [7]:
## Get landfrac
h0_files = np.sort(glob.glob(dataDir+caseNames[iCase]+'/atm/hist/*cam.h0.????-*'))
landfrac = xr.open_dataset(h0_files[1]).LANDFRAC
landfrac = landfrac.isel(time=0).sel(lat=slice(10,50), lon=slice(190,310))

landMask = np.ones([len(landfrac.lat.values), len(landfrac.lon.values)])
landMask[np.squeeze(landfrac.values)<=0.45] = np.nan


In [8]:
# Drop first year and only keep JJA 
iTimes = np.where( (DS_case0['time.year'].values>=(DS_case0['time.year'].values[0]+1))  & 
                   (DS_case0['time.month'].values>=6) & (DS_case0['time.month'].values<=8) 
                  )[0]

DS_case0 = DS_case0.isel(time=iTimes)

In [14]:
selHr = 13

iHours = np.where( (DS_case0['time.hour'].values==13) )[0]
len(iHours)

DS_case0 = DS_case0.isel(time=iHours)


In [15]:
P_interp = np.append(np.arange(550, 900, 50), np.arange(900,1000, 25))
# P_interp = np.append(P_interp, 990)

T_interp_temp = interpolateToPressure(DS_case0, 'T', P_interp)

T_interp = xr.DataArray(T_interp_temp, 
    coords={
            'time': DS_case0.time.values,
            'lev':  P_interp,
            'lat':  DS_case0.lat.values, 
            'lon':  DS_case0.lon.values}, 
    dims=["time", "lev", "lat", "lon"])

T_invertLev = T_interp.isel(lev=slice(None, None, -1))


In [12]:
P_interp = np.append(np.arange(550, 900, 50), np.arange(900,1000, 25))
# P_interp = np.append(P_interp, 990)

Q_interp_temp = interpolateToPressure(DS_case0, 'Q', P_interp)

Q_interp = xr.DataArray(Q_interp_temp, 
    coords={
            'time': DS_case0.time.values,
            'lev':  P_interp,
            'lat':  DS_case0.lat.values, 
            'lon':  DS_case0.lon.values}, 
    dims=["time", "lev", "lat", "lon"])

Q_invertLev = Q_interp.isel(lev=slice(None, None, -1))


In [16]:
outdir = '/glade/derecho/scratch/mdfowler/S2S_processed/'

ctp_10utc = xr.open_dataset(outdir+caseNames[0]+'_CTP_JJA_10utc.nc')
ctp_13utc = xr.open_dataset(outdir+caseNames[0]+'_CTP_JJA_13utc.nc')


## Around longitude = 110, switch from 13 UTC to 10 UTC. Could also try 105, which is ~longitude of Denver (going with Denver) 
iLon = np.where( (T_invertLev.lon.values>=254) & (T_invertLev.lon.values<=256))[0] 

ctp_combined = ctp_13utc.__xarray_dataarray_variable__.values
ctp_combined[:,:,iLon[0]::] = ctp_10utc.__xarray_dataarray_variable__.values[:,:,iLon[0]::]

ctp = xr.DataArray(ctp_combined, 
    coords={
            'time': T_invertLev.time.values,
            'lat':  T_invertLev.lat.values, 
            'lon':  T_invertLev.lon.values}, 
    dims=["time", "lat", "lon"])


In [17]:
ctp_noZeros = ctp.where(ctp.values>=0)


In [18]:
## Also get humidity index piece 

dewpoint_950 = mpc.dewpoint_from_specific_humidity(Q_invertLev.sel(lev=950, method='nearest').lev.values * units.hPa, 
                                                   Q_invertLev.sel(lev=950, method='nearest').values * units(DS_case0.Q.units))

dewpoint_850 = mpc.dewpoint_from_specific_humidity(Q_invertLev.sel(lev=850, method='nearest').lev.values * units.hPa, 
                                                   Q_invertLev.sel(lev=850, method='nearest').values * units(DS_case0.Q.units))

tmpBot =  T_invertLev.sel(lev=950, method='nearest').values -273.15
tmpTop =  T_invertLev.sel(lev=850, method='nearest').values -273.15

HIlow = (tmpBot-dewpoint_950.m) + (tmpTop-dewpoint_850.m)


  val = np.log(vapor_pressure / mpconsts.nounit.sat_pressure_0c)


ValueError: operands could not be broadcast together with shapes (2300,42,97) (18400,42,97) 

In [None]:
HI_low = xr.DataArray(HIlow, 
    coords={
            'time': T_invertLev.time.values,
            'lat':  T_invertLev.lat.values, 
            'lon':  T_invertLev.lon.values}, 
    dims=["time", "lat", "lon"])


## Get ERA5 data

In [None]:
def preprocess_era5(DS):
    ## Select only hours 12-15 UTC
    # iHours = np.where( (DS['time.hour']>=12) & (DS['time.hour']<=14))[0]
    ## Select only hours 9-11 UTC
    iHours = np.where( (DS['time.hour']>=9) & (DS['time.hour']<=11))[0]
    iLevs  = np.where(DS.level.values>=500)[0]
    DS_sel = DS.isel(time=iHours, level=iLevs).resample(time='1D').mean()

    
    ## Regrid
    DS_sel  = DS_sel.rename({'longitude': 'lon','latitude': 'lat'})
    camGrid = DS_cam['PS'].isel(time=0).load().squeeze()
    regridERA = regrid_data(DS_sel, camGrid, method=1)
        
    DS_sel2 = regridERA.sel(lat=slice(15,55), lon=slice(225, 300))

        
    return DS_sel2

In [None]:
dataDir = '/glade/campaign/collections/rda/data/d633000/e5.oper.an.pl/'

years = 1995+np.arange(11)
print(years)

months = ['06','07','08']

count = 0 
for iYr in range(len(years)):
    for iMon in range(len(months)):
        listFiles_q_temp = np.sort(glob.glob(dataDir+str(years[iYr])+months[iMon]+'/*_q*.nc'))
        listFiles_t_temp = np.sort(glob.glob(dataDir+str(years[iYr])+months[iMon]+'/*_t*.nc'))

        if count==0:
            listFiles_q = listFiles_q_temp
            listFiles_t = listFiles_t_temp
        else: 
            listFiles_q = np.append(listFiles_q, listFiles_q_temp)
            listFiles_t = np.append(listFiles_t, listFiles_t_temp)

        count = count+1
            


In [None]:
%%time
q_era5 = xr.open_mfdataset(np.sort(listFiles_q), preprocess=preprocess_era5, 
                           # combine='by_coords',
                           combine='nested', concat_dim='time', 
                           decode_times=True, data_vars='minimal', parallel=True, 
                           # chunks={'time': 150},
                          )


In [None]:
%%time
t_era5 = xr.open_mfdataset(np.sort(listFiles_t), preprocess=preprocess_era5, 
                                combine='by_coords',
                               #concat_dim='time', combine='nested', 
                                 decode_times=True, data_vars='minimal', parallel=True,
                                 chunks={'time': 150},
                          )
