# DroughtPredict Visualization notebook

This notebook lets you visualize the outputs of the MINT Drought Prediction model as well as the training accuracy.


This notebook uses interactive visualizations, so its cells must be executed (in cell order) each time the notebook is initialized.

In [1]:
%matplotlib inline

In [2]:
#Import Packages

import cartopy.crs as ccrs
import cartopy.feature as cfeature
from IPython.display import display
import ipywidgets as widgets
import matplotlib.pyplot as plt
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import xarray as xr
import matplotlib.cm as cm
import numpy as np
import pandas as pd
from scipy.stats import pearsonr

## Visualising the output of the prediction model

The dashboard below shows SPI values from the CNN model at lead times of 1, 2, 3 and 4 months (selected with the time slider). The model was initialized with values from the ECMWF ERA5 dataset, which ends in September 2019.

The top plot shows the spatial values. The bottom plot is a time series of the index at the location specified in the lat/lon boxes. Note that the resolution of the model is coarse (1.5x1 deg grid), so selecting nearby points in lat/lon will not change the time series plot.

In [4]:
# Module-level handle to the prediction output. The widget setup in this cell
# reads its time axis, and the functions in later cells read its
# latitude/longitude grid, so this cell has to run before them.
d_pred = xr.open_dataset('./data/results.nc')

def find_nearest(array, value):
    """Return the index of the entry in `array` that is closest to `value`.

    `array` is converted with ``np.asarray``; closeness is the absolute
    difference, and the first minimum wins when several entries tie.
    """
    distances = np.abs(np.asarray(array) - value)
    return distances.argmin()

def plot(time='2020.10.01',lon=23.75,lat=3.30):
    """Draw the prediction dashboard: an SPI map plus a point time series.

    The top panel is a filled-contour map of ``spi`` for the selected time;
    the bottom panel is the ``spi`` series at the grid point nearest to the
    requested coordinates.

    Parameters
    ----------
    time : str
        Time label handed to ``.sel(time=...)`` to pick the map to draw.
    lon, lat : float or str
        Requested coordinates; converted with ``float`` and snapped to the
        nearest grid value with ``find_nearest``.
    """
    levels = np.arange(-4,4.2,0.2)

    # The context manager releases the file even when something below raises
    # (e.g. an unparsable lat/lon entry); previously the handle was only
    # closed when the whole function ran to completion.
    with xr.open_dataset('./data/results.nc') as d_pred:
        #Grab the proper lat/lon
        lon_idx = find_nearest(d_pred.longitude.values,float(lon))
        lat_idx = find_nearest(d_pred.latitude.values,float(lat))
        grid_lat = d_pred.latitude.values[lat_idx]
        grid_lon = d_pred.longitude.values[lon_idx]

        #Create the figure
        plt.figure(figsize=[10,8])

        #Make the temporal plot
        ax2 = plt.subplot(2,1,2)
        ax2.plot(d_pred.time.values,d_pred.spi.values[:,lat_idx,lon_idx],marker='o')
        ax2.set_title('SPI for coordinates '+'{0:.2f}'.format(grid_lat)+' latitude and '+'{0:.2f}'.format(grid_lon)+ ' longitude')
        ax2.set_ylabel('Standardized Precipitation Index (unitless)')
        ax2.set_ylim(-4,4)
        # One tick per time step in the file
        ax2.set_xticks(d_pred.time.values)

        #Make the spatial plot
        ax1 = plt.subplot(2,1,1,projection=ccrs.PlateCarree())
        d_pred['spi'].sel(time=time).plot.contourf(ax=ax1,levels = levels,
                      transform=ccrs.PlateCarree(), cmap=cm.BrBG, cbar_kwargs={'orientation':'vertical'})
        #Add borders and coastlines
        ax1.add_feature(cfeature.BORDERS)
        ax1.add_feature(cfeature.COASTLINE)
        # Pretty up the plot: keep only the bottom/left tick labels and hide
        # the grid lines themselves
        gl = ax1.gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                          linewidth=2, color='gray', alpha=0.5, linestyle='--')
        gl.top_labels = False
        gl.right_labels = False
        gl.xlines = False
        gl.ylines = False
        gl.xformatter = LONGITUDE_FORMATTER
        gl.yformatter = LATITUDE_FORMATTER
        gl.xlabel_style = {'size': 12, 'color': 'gray'}
        gl.ylabel_style = {'size': 12, 'color': 'gray'}

#Widgets
# Slider labels use the same 'YYYY.MM.DD' format as plot()'s default `time`.
time_strings = pd.to_datetime(d_pred.time.values).strftime('%Y.%m.%d')
# ipywidgets sizes widgets through `layout`; a bare `width=` keyword is not a
# SelectionSlider trait and was not being applied.
time_widget = widgets.SelectionSlider(description='Time', options=time_strings,
                                      layout=widgets.Layout(width='40%'))
lon_widget = widgets.FloatText(value=23.75,description='longitude')
lat_widget = widgets.FloatText(value=3.30,description='latitude')
# Re-run plot() whenever any of the three controls changes.
x = widgets.interactive(plot,time=time_widget,lon=lon_widget,lat=lat_widget)
display(x)

interactive(children=(SelectionSlider(description='Time', options=('2020.10.01', '2020.11.01', '2020.12.01', '…

## Visualizing the accuracy of the prediction model based on test results

### Spatial correlation

In the dashboard below, the top plots show the SPI values from ECMWF-ERA5 (test dataset, ground truth) and from the CNN model. Users can select a start date for the simulation using the 'time' slider and a lead time using the 'lead time' slider. For example, selecting time='2019.01.01' with a lead time of 1 will show spatial plots for 2019.02.01.

The timeseries plot represents the SPI value for a specific location (adjusted through the lat/lon boxes) with lead time of 1 to 4 months. The start time can be selected using the time slider. Note that the CNN prediction is NaN for the start time (no prediction has been made) but this value from ECMWF is used to initialize the prediction.

In [5]:
# Module-level handle to the ECMWF SPI file; the widget setup in this cell
# reads its time axis to build the start-date slider options.
d_test = xr.open_dataset('./data/ECMWF_EA_SPI.nc')

def find_nearest(array, value):
    """Return the index of the entry in `array` that is closest to `value`.

    `array` is converted with ``np.asarray``; closeness is the absolute
    difference, and the first minimum wins when several entries tie.
    """
    distances = np.abs(np.asarray(array) - value)
    return distances.argmin()

def plot_accuracy(time='1982.01.01',lead_time='1',lon=23.75,lat=3.30):
    """Draw the test-accuracy dashboard for one start date and lead time.

    Top row: SPI maps at ``time + lead_time`` months, from the ECMWF-ERA5 file
    (left) and from the CNN result file for that lead time (right).
    Bottom row: SPI at the grid point nearest to (lat, lon) from ``time`` to
    ``time + 4`` months, with the CNN values at lead 1-4 months overlaid.

    Parameters
    ----------
    time : str
        Start date, parsed with ``pd.to_datetime``.
    lead_time : str or int
        Lead time in months, converted with ``int``. Only 1-4 have a CNN
        result file; for any other value no CNN contours are drawn.
    lon, lat : float or str
        Requested coordinates; converted with ``float`` and snapped to the
        nearest grid value with ``find_nearest``.
    """
    lead_time = int(lead_time)
    start = pd.to_datetime(time)
    # Date shown on both maps; computed once instead of at every use.
    target = start + pd.DateOffset(months=lead_time)
    target_label = target.strftime('%Y-%m-%d')
    levels = np.arange(-4,4.2,0.2)

    def _draw_map(ax, field, title):
        # One map panel: filled SPI contours (skipped when `field` is None),
        # borders/coastlines, title, and bottom/left gridline labels.
        if field is not None:
            field.plot.contourf(ax=ax,levels = levels,
                      transform=ccrs.PlateCarree(), cmap=cm.BrBG, cbar_kwargs={'orientation':'vertical'})
        ax.add_feature(cfeature.BORDERS)
        ax.add_feature(cfeature.COASTLINE)
        ax.set_title(title)
        gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                          linewidth=2, color='gray', alpha=0.5, linestyle='--')
        gl.top_labels = False
        gl.right_labels = False
        gl.xlines = False
        gl.ylines = False
        gl.xformatter = LONGITUDE_FORMATTER
        gl.yformatter = LATITUDE_FORMATTER
        gl.xlabel_style = {'size': 12, 'color': 'gray'}
        gl.ylabel_style = {'size': 12, 'color': 'gray'}

    # A single `with` closes every file that was opened, including when a
    # later open or any plotting call raises; previously the handles were only
    # closed when the function ran to completion.
    with xr.open_dataset('./data/ECMWF_EA_SPI.nc') as d_test, \
            xr.open_dataset('./data/test_results_1_month_lead.nc') as d_1month, \
            xr.open_dataset('./data/test_results_2_month_lead.nc') as d_2month, \
            xr.open_dataset('./data/test_results_3_month_lead.nc') as d_3month, \
            xr.open_dataset('./data/test_results_4_month_lead.nc') as d_4month:
        cnn_by_lead = {1: d_1month, 2: d_2month, 3: d_3month, 4: d_4month}

        #Grab the proper lat/lon
        # NOTE(review): indices are looked up on the module-level `d_pred` grid
        # and then applied to the test/result arrays - this presumes all files
        # share one grid and that the prediction cell ran first; confirm.
        lon_idx = find_nearest(d_pred.longitude.values,float(lon))
        lat_idx = find_nearest(d_pred.latitude.values,float(lat))

        #Create the figure
        plt.figure(figsize=[15,8])

        #Make the 2 spatial plots
        #ECMWF
        ax1 = plt.subplot(2,2,1,projection=ccrs.PlateCarree())
        _draw_map(ax1, d_test['spi'].sel(time=target), 'ECMWF-ERA5; time='+target_label)

        #CNN
        ax2 = plt.subplot(2,2,2,projection=ccrs.PlateCarree())
        cnn = cnn_by_lead.get(lead_time)
        cnn_field = None if cnn is None else cnn['spi'].sel(time=target)
        _draw_map(ax2, cnn_field, 'CNN; time='+target_label)

        #make the temporal plot
        ax3 = plt.subplot(2,2,(3,4))
        window = slice(start, start + pd.DateOffset(months=4))
        truth = d_test.sel(time=window)
        t_val = truth.time.values
        truth_series = truth.spi.values[:,lat_idx,lon_idx]

        # get the values for the predictions: the k-month-lead file is sampled
        # at start + k months, for k = 1..4
        pred = [
            cnn_by_lead[k].sel(time=start + pd.DateOffset(months=k)).spi.values[lat_idx,lon_idx]
            for k in (1, 2, 3, 4)
        ]

        ax3.plot(t_val,truth_series,marker='o',label='ECMWF-ERA5')
        # No CNN value exists for the start date itself, hence t_val[1:]
        ax3.plot(t_val[1:],np.array(pred),marker='^',label='CNN')
        ax3.set_title('SPI for coordinates '+'{0:.2f}'.format(d_pred.latitude.values[lat_idx])+' latitude and '+'{0:.2f}'.format(d_pred.longitude.values[lon_idx])+ ' longitude')
        ax3.set_ylabel('Standardized Precipitation Index (unitless)')
        ax3.set_ylim(-4,4)
        ax3.set_xticks(t_val)
        ax3.legend(loc='best')

#Widgets
# Start dates offered by the slider: the file's time axis with the first 12
# and last 36 entries left out, formatted like plot_accuracy()'s default.
time2_strings = pd.to_datetime(d_test.time.values[12:-36]).strftime('%Y.%m.%d')
# ipywidgets sizes widgets through `layout`; a bare `width=` keyword is not a
# SelectionSlider trait and was not being applied.
time2_widget = widgets.SelectionSlider(description='Time', options=time2_strings,
                                       layout=widgets.Layout(width='40%'))
lead_time_strings = ['1','2','3','4']
lead_time_widget = widgets.SelectionSlider(description='Lead Time', options=lead_time_strings,
                                           layout=widgets.Layout(width='40%'))
lon_widget = widgets.FloatText(value=23.75,description='longitude')
lat_widget = widgets.FloatText(value=3.30,description='latitude')
# Re-run plot_accuracy() whenever any of the four controls changes.
x2 = widgets.interactive(plot_accuracy,time=time2_widget,lead_time=lead_time_widget,lon=lon_widget,lat=lat_widget)
display(x2)

interactive(children=(SelectionSlider(description='Time', options=('1982.01.01', '1982.02.01', '1982.03.01', '…

### Temporal correlation

This dashboard shows the temporal correlation of the SPI at various lead times (chosen with the lead_time slider) over the entire test period for a specific geographical location, set in the lat/lon text boxes.

In [6]:
# Rebinds the module-level `d_test` to a fresh handle on the ECMWF SPI file.
# NOTE(review): nothing else in this cell reads this module-level handle
# (plot_corr opens its own copy) - confirm whether it is still needed.
d_test = xr.open_dataset('./data/ECMWF_EA_SPI.nc')

def find_nearest(array, value):
    """Return the index of the entry in `array` that is closest to `value`.

    `array` is converted with ``np.asarray``; closeness is the absolute
    difference, and the first minimum wins when several entries tie.
    """
    distances = np.abs(np.asarray(array) - value)
    return distances.argmin()

def plot_corr(lead_time='1',lon=23.75,lat=3.30):
    """Plot ECMWF-ERA5 vs. CNN SPI over the test period at one grid point.

    Both series are drawn for the fixed 1982.01.01 - 2017.12.01 window at the
    grid point nearest to (lat, lon), and their Pearson correlation is shown
    in the title.

    Parameters
    ----------
    lead_time : str or int
        Lead time in months, converted with ``int``. Only 1-4 have a CNN
        result file; for any other value an empty figure is produced.
    lon, lat : float or str
        Requested coordinates; converted with ``float`` and snapped to the
        nearest grid value with ``find_nearest``.
    """
    # CNN test output, one file per lead time (months).
    prediction_files = {
        1: './data/test_results_1_month_lead.nc',
        2: './data/test_results_2_month_lead.nc',
        3: './data/test_results_3_month_lead.nc',
        4: './data/test_results_4_month_lead.nc',
    }
    #set the start/end date
    start_date = '1982.01.01'
    end_date = '2017.12.01'
    t = slice(start_date,end_date)
    lead_time = int(lead_time)

    #Grab the proper lat/lon
    # NOTE(review): indices are looked up on the module-level `d_pred` grid and
    # then applied to the test/result arrays - this presumes all files share
    # one grid and that the prediction cell ran first; confirm.
    lon_idx = find_nearest(d_pred.longitude.values,float(lon))
    lat_idx = find_nearest(d_pred.latitude.values,float(lat))

    # `.values` copies the data into memory, so each file can be closed as
    # soon as its series is extracted; `with` also closes it on error, which
    # the previous explicit close() calls at the end did not.
    with xr.open_dataset('./data/ECMWF_EA_SPI.nc') as d_test:
        truth = d_test.sel(time=t)
        t_val = truth.time.values
        d_test_ = truth.spi.values[:,lat_idx,lon_idx]

    # Only the file for the selected lead time is read; the other three were
    # previously loaded on every call without being plotted.
    cnn_series = None
    if lead_time in prediction_files:
        with xr.open_dataset(prediction_files[lead_time]) as d_cnn:
            cnn_series = d_cnn.sel(time=t).spi.values[:,lat_idx,lon_idx]

    plt.figure(figsize=[15,8])

    if cnn_series is not None:
        plt.plot(t_val,d_test_,label='ECMWF-ERA5')
        plt.plot(t_val,cnn_series,label='CNN')
        # pearsonr returns (coefficient, p-value); only the coefficient is shown
        corr_ = pearsonr(d_test_,cnn_series)
        plt.ylabel('Standardized Precipitation Index (unitless)')
        lat_label = '{0:.2f}'.format(d_pred.latitude.values[lat_idx])
        lon_label = '{0:.2f}'.format(d_pred.longitude.values[lon_idx])
        plt.title('SPI at lead time='+ str(lead_time)+' month for coordinates '+lat_label+' latitude and '+lon_label+ ' longitude, corr='+'{0:.2f}'.format(corr_[0]))
    plt.legend(loc='best')

#widgets
lead_time_strings = ['1','2','3','4']
# ipywidgets sizes widgets through `layout`; a bare `width=` keyword is not a
# SelectionSlider trait and was not being applied.
lead_time_widget = widgets.SelectionSlider(description='Lead Time', options=lead_time_strings,
                                           layout=widgets.Layout(width='40%'))
lon_widget = widgets.FloatText(value=23.75,description='longitude')
lat_widget = widgets.FloatText(value=3.30,description='latitude')
# Re-run plot_corr() whenever any of the three controls changes.
x3 = widgets.interactive(plot_corr,lead_time=lead_time_widget,lon=lon_widget,lat=lat_widget)
display(x3)

interactive(children=(SelectionSlider(description='Lead Time', options=('1', '2', '3', '4'), value='1'), Float…