In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr

In [None]:
# Saved test-set predictions of one training run, identified by its timestamp.
timestamp = '2021_08_05_07_01'
DIR1 = 'runs/NN/TCN/NN_embedding/{}/predictions/test/'.format(timestamp)
forecast_path = os.path.join(DIR1, 'forecast.nc')
ds = xr.open_dataset(forecast_path)

# 2-D coordinate grids, each of shape (n_latitude, n_longitude) because
# np.meshgrid uses 'xy' indexing by default; consumed by the map plots.
longitudes = ds.longitude.values
latitudes = ds.latitude.values
Lon, Lat = np.meshgrid(longitudes, latitudes)

# True and predicted t2ma fields as xarray DataArrays.
ytrue, ypred = ds['t2ma_true'], ds['t2ma_pred']

def skill(y_true, y_pred):
    """Uncentred pattern correlation (cosine similarity) of two fields.

    No mean is removed: the score is ``dot(a, b) / (|a|_2 * |b|_2)``.

    Parameters
    ----------
    y_true, y_pred : array_like
        Fields with the same number of elements. Any shape is accepted;
        both are flattened (C order) before comparison, so a 2-D grid can
        be passed directly instead of being reshaped by the caller.

    Returns
    -------
    float
        Score in [-1, 1]. NaN when either field has zero norm (including
        empty input) or when either field contains NaN.

    Raises
    ------
    ValueError
        If the two inputs differ in number of elements.
    """
    a = np.asarray(y_true).reshape(-1)
    b = np.asarray(y_pred).reshape(-1)
    if a.size != b.size:
        raise ValueError(
            'skill: inputs differ in size ({} vs {})'.format(a.size, b.size))
    denom = np.linalg.norm(a, 2) * np.linalg.norm(b, 2)
    # Handle the 0/0 case explicitly rather than via a numpy RuntimeWarning;
    # the result is NaN either way.
    if denom == 0:
        return np.nan
    return np.dot(a, b) / denom

In [None]:
# Time series of the spatially averaged true vs predicted t2ma for one
# `forecast` index and one calendar year, with a ±1 spatial-std band, and the
# mean per-date pattern skill in the title.
week = 4  # 1-based position along `forecast` (presumably lead week; TODO confirm)
year = 2016  # keep rows whose forecast_date at this `forecast` index is in `year`
spatial_dims = ('latitude', 'longitude')

in_year = ds.forecast_date[:, week - 1].dt.year == year
ds_slice = ds.sel(time=in_year, forecast=ds.forecast[week - 1])
dates = ds_slice.forecast_date

fig, ax = plt.subplots(figsize=(8, 4))
# Give each mean line the same colour as its band so true/pred are easy to match.
for var, color in (('t2ma_true', 'blue'), ('t2ma_pred', 'red')):
    mean = ds_slice[var].mean(spatial_dims)
    std = ds_slice[var].std(spatial_dims)
    ax.plot(dates, mean, color=color, label=var)
    ax.fill_between(dates, mean - std, mean + std, color=color, alpha=0.2)
ax.set_xlabel('forecast date')
ax.set_ylabel('spatial mean ± 1 std')
ax.legend()
ax.grid()

# Order axes by name instead of assuming `time` is the last axis; putting both
# fields in the same (time, lat, lon) order also guarantees their flattened
# vectors line up element-for-element. Each field is materialised once.
true_stack = ds_slice['t2ma_true'].transpose('time', *spatial_dims).values
pred_stack = ds_slice['t2ma_pred'].transpose('time', *spatial_dims).values
pred_skill = np.array(
    [skill(t.reshape(-1), p.reshape(-1)) for t, p in zip(true_stack, pred_stack)],
    dtype=float,
)
ax.set_title('{} mean skill = {:0.2f}'.format(year, pred_skill.mean()))
plt.show()

In [None]:
# Side-by-side maps of true (left) and predicted (right) t2ma for one date of
# `ds_slice` from the time-series cell.
# Import kept local: basemap is a heavy optional dependency used only here.
from mpl_toolkits.basemap import Basemap

idt = -1  # position along `time` in ds_slice; -1 = last date of the selected year
map_levels = np.linspace(-5, 5, 41)  # shared colour scale so panels are comparable


def _draw_map_panel(ax, field, title):
    """Contour a (latitude, longitude) field on a Mercator basemap drawn in `ax`.

    Returns ``(bmap, x, y)``: the Basemap and the projected Lon/Lat grids.
    """
    bmap = Basemap(projection='merc',
                   llcrnrlat=np.min(Lat), urcrnrlat=np.max(Lat),
                   llcrnrlon=np.min(Lon), urcrnrlon=np.max(Lon),
                   resolution='c', ax=ax)
    bmap.drawcoastlines()
    bmap.drawstates()
    bmap.drawcountries(linewidth=1, linestyle='solid', color='blue')
    x, y = bmap(Lon, Lat)
    # Lon/Lat come from np.meshgrid(longitude, latitude) and have shape
    # (n_lat, n_lon), so order the field as (latitude, longitude) by name.
    contours = ax.contourf(x, y, field.transpose('latitude', 'longitude').values,
                           cmap='jet', levels=map_levels, extend='both')
    ax.set_xlabel('longitude')
    ax.set_ylabel('latitude')
    ax.set_title(title)
    ax.figure.colorbar(contours, ax=ax, shrink=0.1)
    return bmap, x, y


date_label = str(dates[idt].values)
fig, (ax_true, ax_pred) = plt.subplots(1, 2, figsize=(15, 20))
_draw_map_panel(ax_true, ds_slice['t2ma_true'].isel(time=idt),
                't2ma_true  {}'.format(date_label))
m, lon, lat = _draw_map_panel(ax_pred, ds_slice['t2ma_pred'].isel(time=idt),
                              't2ma_pred  {}'.format(date_label))
plt.show()