In [1]:
import numpy as np
import xarray as xr
import xesmf as xe
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

In [2]:
def _read_levels(pattern, var, levels):
    """Open a multi-file dataset and return one single-level dataset per requested level.

    Parameters
    ----------
    pattern : str
        Glob pattern passed to ``xr.open_mfdataset``.
    var : str
        Name of the data variable in the files (e.g. ``'z'``, ``'u'``, ``'v'``).
    levels : sequence of int
        Values of the ``level`` coordinate to select (e.g. ``(500, 850)``).

    Returns
    -------
    list of xr.Dataset
        One dataset per entry of ``levels`` (same order), with ``var`` renamed
        to ``f'{var}{level}'`` (e.g. ``'z500'``).
    """
    ds = xr.open_mfdataset(pattern, combine='by_coords', chunks={'time': -1})
    # drop=True removes the now-scalar `level` coordinate. Without it, the
    # compat='override' merge in getFields keeps the first dataset's level
    # value (500) and attaches it to every variable, including u200/v200.
    return [ds.sel(level=lev, drop=True).rename({var: f'{var}{lev}'}) for lev in levels]


def getFields(obs_dir='/shared/ccsm4/khuang/obs/era-interim'):
    """Read ERA-Interim daily fields, regrid them to a 5-degree grid and return daily anomalies.

    Fields read: SST, Z500, Z850, U200, U850, V200, V850 and OLR (the sign-flipped
    ``ttr`` variable). Everything is bilinearly regridded to a regular grid with
    lat = -90..90 and lon = 0..355 in 5-degree steps, merged into one dataset,
    and converted to anomalies from the day-of-year mean over the full record.

    Parameters
    ----------
    obs_dir : str, optional
        Directory holding the ``era-interim.daily.*`` files. Defaults to the
        original hard-coded location.

    Returns
    -------
    xr.Dataset
        Daily anomalies (lazy, dask-backed) of all merged variables.
    """
    import os

    prefix = os.path.join(obs_dir, 'era-interim.daily.')
    tfiles = prefix + 'sea.surface.temperature.0.75.x.0.75.1979-2018.nc'
    zfiles = prefix + 'geoheight.*.nc'
    ufiles = prefix + 'u.*.nc'
    vfiles = prefix + 'v.*.nc'
    olrfiles1 = prefix + 'top.net.thermal.radiation.19*s.nc'
    olrfiles2 = prefix + 'top.net.thermal.radiation.20*s.nc'

    print("READING Z")
    ds_z500, ds_z850 = _read_levels(zfiles, 'z', (500, 850))

    print("READING U")
    ds_u200, ds_u850 = _read_levels(ufiles, 'u', (200, 850))

    print("READING V")
    ds_v200, ds_v850 = _read_levels(vfiles, 'v', (200, 850))

    print("READING SST")
    ds_sst = xr.open_dataset(tfiles, chunks={'time': -1})
    # Keep only the records stamped at hour 0 of each day.
    ds_sst = ds_sst.sel(time=ds_sst['time.hour'] == 0)

    print("READING OLR")
    ds_olr1 = xr.open_mfdataset(olrfiles1, combine='by_coords',
                                combine_attrs='drop_conflicts', chunks={'time': -1})
    # Floor the 19xx timestamps to midnight, presumably so they line up with
    # the daily timestamps of the other fields.
    ds_olr1['time'] = ds_olr1['time'].dt.floor('d')
    # NOTE(review): the 20xx files are not floored -- assumes they are already
    # stamped at 00:00; confirm.
    ds_olr2 = xr.open_mfdataset(olrfiles2, combine='by_coords',
                                combine_attrs='drop_conflicts', chunks={'time': -1})
    ds_olr = xr.combine_by_coords([ds_olr1, ds_olr2],
                                  combine_attrs='drop_conflicts').rename({'ttr': 'olr'})
    # Sign flip of ttr (presumably so that outgoing longwave is positive).
    ds_olr['olr'] = ds_olr['olr'] * -1

    print("COARSE GRAIN")
    new_lon = np.arange(0, 360, 5)
    new_lat = np.arange(-90, 92, 5)   # stop=92 so that 90 is included
    ds_out = xr.Dataset({'lat': (['lat'], new_lat),
                         'lon': (['lon'], new_lon)})

    # One set of bilinear weights, built from the Z500 grid, is reused for
    # Z/U/V/OLR -- assumes they all share that input grid. SST gets its own.
    regridder = xe.Regridder(ds_z500, ds_out, 'bilinear', periodic=True)
    ds_z500, ds_z850, ds_olr, ds_u200, ds_u850, ds_v200, ds_v850 = (
        regridder(d) for d in (ds_z500, ds_z850, ds_olr, ds_u200, ds_u850, ds_v200, ds_v850)
    )

    regridder_sst = xe.Regridder(ds_sst, ds_out, 'bilinear', periodic=True)
    ds_sst = regridder_sst(ds_sst)

    print("MERGE")
    # compat='override': conflicting variables are taken from the first
    # dataset in the list instead of being compared.
    ds = xr.merge([ds_sst, ds_z500, ds_z850, ds_olr, ds_u850, ds_v850, ds_u200, ds_v200],
                  compat='override')

    print("ANOMS")
    # Anomalies relative to the day-of-year mean over the whole record.
    # NOTE: dayofyear shifts by one after Feb 28 in leap years, and day 366
    # occurs only in leap years, so its climatology uses fewer samples.
    clim = ds.groupby('time.dayofyear').mean()
    ds_anoms = ds.groupby('time.dayofyear') - clim

    return ds_anoms

### Identify the Most Skillful Model in the Test Data

In [3]:
# Settings shared with the composite cell further down.
confidence_thresh = 0.80        # minimum `probs` for a forecast to count as confident
r = 'lrp.alpha_1_beta_0'        # value of the `rules` coordinate to analyse
mean_dims = ['time', 'model']

# For each season, find the model index with the highest `acc` on the Test split.
# (Swap 'Test' for 'Train-Val' in the pattern to rank on training/validation data.)
best_list = []
for seas in ('DJF', 'JJA'):
    cnn_val_fname = f'../data/cnn/model_validate.Test.{seas}.cnn_cat.*.nc'
    ds_cnn_val = (
        xr.open_mfdataset(cnn_val_fname, combine='nested', concat_dim='model')
        .sel(rules=r)
    )
    ibest = ds_cnn_val['acc'].argmax(dim='model').values
    print(ibest)
    best_list.append(ibest)


#### Get data for composites

In [4]:
# Load the regridded daily anomaly fields used for the composites below.
ds_anoms = getFields()
ds_anoms

READING Z
READING U
READING V
READING SST
READING OLR
COARSE GRAIN
MERGE
ANOMS


Unnamed: 0,Array,Chunk
Bytes,148.47 MiB,10.41 kiB
Shape,"(14610, 37, 72)","(1, 37, 72)"
Count,17931 Tasks,14610 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 148.47 MiB 10.41 kiB Shape (14610, 37, 72) (1, 37, 72) Count 17931 Tasks 14610 Chunks Type float32 numpy.ndarray",72  37  14610,

Unnamed: 0,Array,Chunk
Bytes,148.47 MiB,10.41 kiB
Shape,"(14610, 37, 72)","(1, 37, 72)"
Count,17931 Tasks,14610 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,148.47 MiB,10.41 kiB
Shape,"(14610, 37, 72)","(1, 37, 72)"
Count,24563 Tasks,14610 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 148.47 MiB 10.41 kiB Shape (14610, 37, 72) (1, 37, 72) Count 24563 Tasks 14610 Chunks Type float32 numpy.ndarray",72  37  14610,

Unnamed: 0,Array,Chunk
Bytes,148.47 MiB,10.41 kiB
Shape,"(14610, 37, 72)","(1, 37, 72)"
Count,24563 Tasks,14610 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,148.47 MiB,10.41 kiB
Shape,"(14610, 37, 72)","(1, 37, 72)"
Count,24563 Tasks,14610 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 148.47 MiB 10.41 kiB Shape (14610, 37, 72) (1, 37, 72) Count 24563 Tasks 14610 Chunks Type float32 numpy.ndarray",72  37  14610,

Unnamed: 0,Array,Chunk
Bytes,148.47 MiB,10.41 kiB
Shape,"(14610, 37, 72)","(1, 37, 72)"
Count,24563 Tasks,14610 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,148.47 MiB,10.41 kiB
Shape,"(14610, 37, 72)","(1, 37, 72)"
Count,24568 Tasks,14610 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 148.47 MiB 10.41 kiB Shape (14610, 37, 72) (1, 37, 72) Count 24568 Tasks 14610 Chunks Type float32 numpy.ndarray",72  37  14610,

Unnamed: 0,Array,Chunk
Bytes,148.47 MiB,10.41 kiB
Shape,"(14610, 37, 72)","(1, 37, 72)"
Count,24568 Tasks,14610 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,148.47 MiB,10.41 kiB
Shape,"(14610, 37, 72)","(1, 37, 72)"
Count,24563 Tasks,14610 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 148.47 MiB 10.41 kiB Shape (14610, 37, 72) (1, 37, 72) Count 24563 Tasks 14610 Chunks Type float32 numpy.ndarray",72  37  14610,

Unnamed: 0,Array,Chunk
Bytes,148.47 MiB,10.41 kiB
Shape,"(14610, 37, 72)","(1, 37, 72)"
Count,24563 Tasks,14610 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,148.47 MiB,10.41 kiB
Shape,"(14610, 37, 72)","(1, 37, 72)"
Count,24563 Tasks,14610 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 148.47 MiB 10.41 kiB Shape (14610, 37, 72) (1, 37, 72) Count 24563 Tasks 14610 Chunks Type float32 numpy.ndarray",72  37  14610,

Unnamed: 0,Array,Chunk
Bytes,148.47 MiB,10.41 kiB
Shape,"(14610, 37, 72)","(1, 37, 72)"
Count,24563 Tasks,14610 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,148.47 MiB,10.41 kiB
Shape,"(14610, 37, 72)","(1, 37, 72)"
Count,24563 Tasks,14610 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 148.47 MiB 10.41 kiB Shape (14610, 37, 72) (1, 37, 72) Count 24563 Tasks 14610 Chunks Type float32 numpy.ndarray",72  37  14610,

Unnamed: 0,Array,Chunk
Bytes,148.47 MiB,10.41 kiB
Shape,"(14610, 37, 72)","(1, 37, 72)"
Count,24563 Tasks,14610 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,148.47 MiB,10.41 kiB
Shape,"(14610, 37, 72)","(1, 37, 72)"
Count,24563 Tasks,14610 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 148.47 MiB 10.41 kiB Shape (14610, 37, 72) (1, 37, 72) Count 24563 Tasks 14610 Chunks Type float32 numpy.ndarray",72  37  14610,

Unnamed: 0,Array,Chunk
Bytes,148.47 MiB,10.41 kiB
Shape,"(14610, 37, 72)","(1, 37, 72)"
Count,24563 Tasks,14610 Chunks
Type,float32,numpy.ndarray


### Calculate Composites of LRP and Anomalies (Combined Train-Val and Test Data)

In [15]:
# Composite the LRP maps and the anomaly fields over forecasts that are both
# correct (pred == verif) and confident (probs >= confidence_thresh), pooling
# the Train-Val and Test splits. Uses `r`, `confidence_thresh` and `ds_anoms`
# from earlier cells.

# Chunking applied to the masked data before averaging: whole time/model axes
# per chunk, so the mean over ('model', 'time') reduces within each chunk.
COMPOSITE_CHUNKS = {'time': -1, 'model': -1, 'cat': 2, 'lat': 4, 'lon': 3}


def _composite(ds, mask):
    """Mean of ``ds`` over ('model', 'time') where ``mask`` is True, computed into memory."""
    # Dataset.chunk returns a new object; it must be assigned to take effect.
    masked = ds.where(mask).chunk(COMPOSITE_CHUNKS)
    return masked.mean(dim=['model', 'time']).compute()


# Integer percentage used in the printed labels, e.g. 0.80 -> 80.
thresh_pct = round(confidence_thresh * 100)

for seas in ['DJF', 'JJA']:
    print(seas)

    # Combine Train-Val and Test data
    ds_list = [
        xr.open_mfdataset('../data/cnn_test/model_validate.' + tt + '.' + seas + '.cnn_cat.*.nc',
                          combine='nested', concat_dim='model')
        for tt in ['Train-Val', 'Test']
    ]
    ds_cnn_val = xr.combine_by_coords(ds_list).sel(rules=r).drop_vars('rules')

    # Correct AND confident forecasts; shared by both composites and the counts.
    correct_confident = np.logical_and(ds_cnn_val['pred'] == ds_cnn_val['verif'],
                                       ds_cnn_val['probs'] >= confidence_thresh)

    # Composite LRP
    ds_lrp = _composite(ds_cnn_val, correct_confident)

    # Count number of True, Confident, Positive and Negative (non-NaN after masking)
    masked_pred = ds_cnn_val['pred'].where(correct_confident)
    pos_count = int(masked_pred.sel(cat='Positive').count())
    neg_count = int(masked_pred.sel(cat='Negative').count())

    print(f"True, Confident(>={thresh_pct}), Positive: ", pos_count)
    print(f"True, Confident(>={thresh_pct}), Negative: ", neg_count)

    # Composite Fields
    ds_comp = _composite(ds_anoms, correct_confident)

    # Write Data to File
    print("Writing Data")
    ds_lrp.to_netcdf('../data/cnn_test/model_lrpcomp.' + seas + '.cnn_cat.nc')
    ds_comp.to_netcdf('../data/cnn_test/model_anomscomp.' + seas + '.cnn_cat.nc')


DJF
True, Confident(>=80), Positive:  1983
True, Confident(>=80), Negative:  4038
JJA
True, Confident(>=80), Positive:  704
True, Confident(>=80), Negative:  111
