In [1]:
import numpy as np
import xarray as xr
import xesmf as xe
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

In [2]:
def _open_level_fields(data_dir, file_tag, **combine_kwargs):
    """Open one ERA-Interim pressure-level field and combine its files by coordinates.

    Parameters
    ----------
    data_dir : str
        Directory holding the ``era-interim.daily.*`` files.
    file_tag : str
        Field tag used in the file names (``'geoheight'``, ``'u'`` or ``'v'``).
    **combine_kwargs
        Forwarded unchanged to ``xr.combine_by_coords``
        (e.g. ``combine_attrs='override'``).

    Returns
    -------
    xr.Dataset
        The ``19*``, ``2000s``, ``2010s`` and ``2019s`` files combined by coordinates.
    """
    prefix = f'{data_dir}/era-interim.daily.{file_tag}'
    levels = slice(100, 1000)
    parts = [
        xr.open_mfdataset(f'{prefix}.19*.nc', combine='by_coords').sel(level=levels),
        xr.open_dataset(f'{prefix}.2000s.nc').sel(level=levels),
        xr.open_dataset(f'{prefix}.2010s.nc').sel(level=levels),
        # NOTE(review): the 2019 file is opened without level subsetting —
        # presumably it already holds only levels inside 100-1000; confirm.
        xr.open_dataset(f'{prefix}.2019s.nc'),
    ]
    return xr.combine_by_coords(parts, **combine_kwargs)


def getFields(data_dir='/shared/ccsm4/khuang/obs/era-interim'):
    """Read ERA-Interim fields, regrid them to a 5-degree grid and return daily anomalies.

    The fields read are SST, geopotential height (500 and 850), zonal and
    meridional wind (200 and 850) and top net thermal radiation (renamed to
    ``olr`` with its sign flipped). All fields are bilinearly regridded with
    xESMF onto lat = -90..90 and lon = 0..355 in 5-degree steps, merged into a
    single dataset, and the day-of-year mean is subtracted.

    Parameters
    ----------
    data_dir : str, optional
        Directory holding the ``era-interim.daily.*`` NetCDF files. Defaults
        to the location that was previously hard-coded.

    Returns
    -------
    xr.Dataset
        Anomalies of ``sst``, ``z500``, ``z850``, ``olr``, ``u850``, ``v850``,
        ``u200`` and ``v200`` relative to the ``time.dayofyear`` mean.
    """
    print("READING Z")
    # Z500 and Z850 (combined with combine_by_coords' default attribute handling)
    ds_z = _open_level_fields(data_dir, 'geoheight')
    ds_z500 = ds_z.sel(level=500).rename({'z': 'z500'})
    ds_z850 = ds_z.sel(level=850).rename({'z': 'z850'})
    print(ds_z500)

    print("READING U")
    # U200 and U850
    ds_u = _open_level_fields(data_dir, 'u', combine_attrs='override')
    ds_u200 = ds_u.sel(level=200).rename({'u': 'u200'})
    ds_u850 = ds_u.sel(level=850).rename({'u': 'u850'})

    print("READING V")
    # V200 and V850
    ds_v = _open_level_fields(data_dir, 'v', combine_attrs='override')
    ds_v200 = ds_v.sel(level=200).rename({'v': 'v200'})
    ds_v850 = ds_v.sel(level=850).rename({'v': 'v850'})

    print("READING SST")
    # SST
    sst_prefix = f'{data_dir}/era-interim.daily.sea.surface.temperature.0.75.x.0.75'
    ds_sst1 = xr.open_dataset(f'{sst_prefix}.1979-2018.nc')
    ds_sst2 = xr.open_dataset(f'{sst_prefix}.2019s.nc')
    # NOTE(review): label-based time slices include the end label, so
    # 2019-09-01 is kept here if the file contains it. The other fields appear
    # to end on 2019-08-31, which would leave one trailing SST-only time step
    # after the merge — confirm whether '2019-08-31' was intended.
    ds_sst2 = ds_sst2.sel(time=slice('2019-01-01', '2019-09-01'))
    ds_sst = xr.combine_by_coords([ds_sst1, ds_sst2], combine_attrs='override')
    # Keep only the time stamps whose hour is 0.
    ds_sst = ds_sst.sel(time=ds_sst['time.hour'] == 0)

    print("READING OLR")
    # OLR
    olr_prefix = f'{data_dir}/era-interim.daily.top.net.thermal.radiation'
    ds_olr1 = xr.open_mfdataset(f'{olr_prefix}.19*s.nc', combine='by_coords',
                                combine_attrs='drop_conflicts', chunks={'time': -1})
    # Floor the pre-2000 time stamps to whole days before combining.
    ds_olr1['time'] = ds_olr1['time'].dt.floor('d')
    ds_olr2 = xr.open_mfdataset(f'{olr_prefix}.20*s.nc', combine='by_coords',
                                chunks={'time': -1})
    ds_olr = xr.combine_by_coords([ds_olr1, ds_olr2],
                                  combine_attrs='drop_conflicts').rename({'ttr': 'olr'})
    # Flip the sign of the radiation field (presumably so outgoing is positive — confirm).
    ds_olr['olr'] = ds_olr['olr'] * -1

    print("COARSE GRAIN")
    # Coarse Grain Data: 5-degree target grid, lat -90..90 and lon 0..355.
    new_lon = np.arange(0, 360, 5)
    new_lat = np.arange(-90, 92, 5)

    ds_out = xr.Dataset({'lat': (['lat'], new_lat),
                         'lon': (['lon'], new_lon)})

    # One regridder built from the z500 grid is reused for every non-SST field;
    # each field is rechunked to a single chunk along 'lon' before regridding.
    regridder = xe.Regridder(ds_z500, ds_out, 'bilinear', periodic=True)
    ds_z500, ds_z850, ds_olr, ds_u200, ds_u850, ds_v200, ds_v850 = (
        regridder(field.chunk({'lon': -1}))
        for field in (ds_z500, ds_z850, ds_olr, ds_u200, ds_u850, ds_v200, ds_v850)
    )

    # SST gets its own regridder built from the SST grid.
    regridder_sst = xe.Regridder(ds_sst, ds_out, 'bilinear', periodic=True)
    ds_sst = regridder_sst(ds_sst)

    print("MERGE")
    ds = xr.merge([ds_sst, ds_z500, ds_z850, ds_olr, ds_u850, ds_v850, ds_u200, ds_v200],
                  compat='override')

    print("ANOMS")
    # NOTE(review): 'time.dayofyear' shifts by one after Feb 28 in leap years,
    # so the climatology mixes calendar dates there — confirm this is acceptable.
    by_day = ds.groupby('time.dayofyear')
    ds_anoms = by_day - by_day.mean()

    return ds_anoms

In [3]:
# Minimum value of 'probs' for a sample to enter the composites.
confidence_thresh = 0.80

# Label selected along the 'rules' dimension of the validation files.
r = "lrp.alpha_1_beta_0"

# Dimensions to average over (not referenced by the visible composite cells).
mean_dims = ["time", "model"]

#### Get data for composites

In [4]:
# Build the regridded daily anomaly dataset and display its repr.
ds_anoms = getFields()
ds_anoms

READING Z
<xarray.Dataset>
Dimensions:  (time: 14853, lat: 73, lon: 144)
Coordinates:
  * time     (time) datetime64[ns] 1979-01-01 1979-01-02 ... 2019-08-31
    level    float32 500.0
  * lat      (lat) float32 90.0 87.5 85.0 82.5 80.0 ... -82.5 -85.0 -87.5 -90.0
  * lon      (lon) float32 0.0 2.5 5.0 7.5 10.0 ... 350.0 352.5 355.0 357.5
Data variables:
    z500     (time, lat, lon) float32 dask.array<chunksize=(365, 73, 72), meta=np.ndarray>
Attributes:
    Conventions:  None
    source_file:  /homes/khuang8/obs/reanalysis/era-interim/ncformat//homes/k...
    title:        Daily mean Geopotential Height from Era-Interim
READING U
READING V
READING SST
READING OLR
COARSE GRAIN
MERGE
ANOMS


Unnamed: 0,Array,Chunk
Bytes,150.95 MiB,20.81 kiB
Shape,"(14854, 37, 72)","(2, 37, 72)"
Count,53821 Tasks,14853 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 150.95 MiB 20.81 kiB Shape (14854, 37, 72) (2, 37, 72) Count 53821 Tasks 14853 Chunks Type float32 numpy.ndarray",72  37  14854,

Unnamed: 0,Array,Chunk
Bytes,150.95 MiB,20.81 kiB
Shape,"(14854, 37, 72)","(2, 37, 72)"
Count,53821 Tasks,14853 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,150.95 MiB,20.81 kiB
Shape,"(14854, 37, 72)","(2, 37, 72)"
Count,53821 Tasks,14853 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 150.95 MiB 20.81 kiB Shape (14854, 37, 72) (2, 37, 72) Count 53821 Tasks 14853 Chunks Type float32 numpy.ndarray",72  37  14854,

Unnamed: 0,Array,Chunk
Bytes,150.95 MiB,20.81 kiB
Shape,"(14854, 37, 72)","(2, 37, 72)"
Count,53821 Tasks,14853 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,150.95 MiB,10.41 kiB
Shape,"(14854, 37, 72)","(1, 37, 72)"
Count,26241 Tasks,14854 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 150.95 MiB 10.41 kiB Shape (14854, 37, 72) (1, 37, 72) Count 26241 Tasks 14854 Chunks Type float32 numpy.ndarray",72  37  14854,

Unnamed: 0,Array,Chunk
Bytes,150.95 MiB,10.41 kiB
Shape,"(14854, 37, 72)","(1, 37, 72)"
Count,26241 Tasks,14854 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,150.95 MiB,20.81 kiB
Shape,"(14854, 37, 72)","(2, 37, 72)"
Count,53821 Tasks,14853 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 150.95 MiB 20.81 kiB Shape (14854, 37, 72) (2, 37, 72) Count 53821 Tasks 14853 Chunks Type float32 numpy.ndarray",72  37  14854,

Unnamed: 0,Array,Chunk
Bytes,150.95 MiB,20.81 kiB
Shape,"(14854, 37, 72)","(2, 37, 72)"
Count,53821 Tasks,14853 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,150.95 MiB,20.81 kiB
Shape,"(14854, 37, 72)","(2, 37, 72)"
Count,53821 Tasks,14853 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 150.95 MiB 20.81 kiB Shape (14854, 37, 72) (2, 37, 72) Count 53821 Tasks 14853 Chunks Type float32 numpy.ndarray",72  37  14854,

Unnamed: 0,Array,Chunk
Bytes,150.95 MiB,20.81 kiB
Shape,"(14854, 37, 72)","(2, 37, 72)"
Count,53821 Tasks,14853 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,150.95 MiB,20.81 kiB
Shape,"(14854, 37, 72)","(2, 37, 72)"
Count,53821 Tasks,14853 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 150.95 MiB 20.81 kiB Shape (14854, 37, 72) (2, 37, 72) Count 53821 Tasks 14853 Chunks Type float32 numpy.ndarray",72  37  14854,

Unnamed: 0,Array,Chunk
Bytes,150.95 MiB,20.81 kiB
Shape,"(14854, 37, 72)","(2, 37, 72)"
Count,53821 Tasks,14853 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,150.95 MiB,20.81 kiB
Shape,"(14854, 37, 72)","(2, 37, 72)"
Count,53821 Tasks,14853 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 150.95 MiB 20.81 kiB Shape (14854, 37, 72) (2, 37, 72) Count 53821 Tasks 14853 Chunks Type float32 numpy.ndarray",72  37  14854,

Unnamed: 0,Array,Chunk
Bytes,150.95 MiB,20.81 kiB
Shape,"(14854, 37, 72)","(2, 37, 72)"
Count,53821 Tasks,14853 Chunks
Type,float32,numpy.ndarray


### Calculate Composites of LRP and Anomalies (Train-Val and Test Combined)

In [5]:
# Composite the LRP output and the reanalysis anomalies over the samples whose
# prediction matches the verification with 'probs' >= confidence_thresh.
for seas in ['DJF']:
#for seas in ['DJF','JJA']:

    print(seas)
    ds_list = []

    # Combine Train-Val and Test Data; the files of each set are stacked along 'model'.
    for tt in ['Train-Val', 'Test']:
        fname = f'../data/cnn_test/model_validate.{tt}.{seas}.cnn_cat.*.nc'
        ds = xr.open_mfdataset(fname, combine='nested', concat_dim='model')
        ds_list.append(ds)
    ds_cnn_val = xr.combine_by_coords(ds_list).sel(rules=r).drop_vars(['rules'])

    # Correct-and-confident mask, built once and reused for both composites.
    correct_confident = np.logical_and(ds_cnn_val['pred'] == ds_cnn_val['verif'],
                                       ds_cnn_val['probs'] >= confidence_thresh)

    # Composite LRP
    tmp1 = ds_cnn_val.where(correct_confident)
    # Dataset.chunk returns a new object, so the result must be assigned for
    # the rechunking to take effect.
    tmp1 = tmp1.chunk({'time': 10, 'model': 10, 'cat': 2, 'lat': 37, 'lon': 72})
    ds_lrp = tmp1.mean(dim=['model', 'time']).compute()

    print("Writing LRP Data")
    ds_lrp.to_netcdf(f'../data/cnn_test/model_lrpcomp.{seas}.cnn_cat.nc')

    del ds_lrp
    del tmp1

    # Composite Fields
    # Only the mask of model 0 is used for the anomaly composites. Selecting it
    # before masking avoids broadcasting every anomaly field across all models.
    # NOTE(review): the LRP composite averages over all models while this one
    # uses model 0 only — confirm that is intended.
    mask_model0 = correct_confident.sel(model=0)

    ds_comp_vars = []
    for v in ds_anoms.data_vars:
        print(v)
        tmp3 = ds_anoms[v].where(mask_model0).mean(dim=['time'], skipna=True)
        ds_comp_vars.append(tmp3.to_dataset(name=v))

    ds_comp = xr.merge(ds_comp_vars)

    del ds_comp_vars

    print("Writing Comp Data")
    ds_comp.to_netcdf(f'../data/cnn_test/model_anomscomp.{seas}.cnn_cat.nc')

DJF
Writing LRP Data
sst
z500
z850
olr
u850
v850
u200
v200
Writing Comp Data


for seas in ['JJA']: 
#for seas in ['DJF','JJA']: 

    print(seas)
    ds_list=[]
    # Combine Train-Val and Test Data
    for tt in ['Train-Val','Test']:
        fname='../data/cnn_test/model_validate.'+tt+'.'+seas+'.cnn_cat.*.nc'
        ds=xr.open_mfdataset(fname,combine='nested',concat_dim='model')
        ds_list.append(ds)
    ds_cnn_val=xr.combine_by_coords(ds_list).sel(rules=r).drop(['rules'])

    # Composite LRP
    tmp1=ds_cnn_val.where(np.logical_and(ds_cnn_val['pred']==ds_cnn_val['verif'],
                                        ds_cnn_val['probs']>=confidence_thresh))
    
    # Count number of True, Confident, Positive and Negative
    #pos_count=np.count_nonzero(~np.isnan(tmp1['pred'].sel(cat='Positive')))
    #neg_count=np.count_nonzero(~np.isnan(tmp1['pred'].sel(cat='Negative')))

    #print("True, Confident(>=80), Positive: ",pos_count)
    #print("True, Confident(>=80), Negative: ",neg_count)
