## Exploring LimeForTime on the weather dataset

This notebook shows how to use the LimeForTime explainer to explain a trained ONNX model on the weather dataset. <br>
For a general introduction to the LimeForTime explainer, see its demo notebook:<br>
https://github.com/emanuel-metzenthin/Lime-For-Time/blob/master/demo/LIME-Pipeline.ipynb

### Load the weather dataset

In [1]:
import os
import pandas as pd
import numpy as np
from lime_timeseries import LimeTimeSeriesExplainer
from scipy.special import softmax
from sklearn.model_selection import train_test_split

import onnx
import onnxruntime as ort

In [2]:
# Seed NumPy's global random number generator so that stochastic steps are reproducible.
SEED = 42
np.random.seed(SEED)

In [3]:
# Read the dataset from a local copy when one exists, otherwise download it from Zenodo.
fname = "weather_prediction_dataset_light.csv"
csv_source = (fname if os.path.isfile(fname)
              else f"https://zenodo.org/record/5071376/files/{fname}?download=1")
data = pd.read_csv(csv_source)
data.describe()

Unnamed: 0,DATE,MONTH,BASEL_cloud_cover,BASEL_humidity,BASEL_pressure,BASEL_global_radiation,BASEL_precipitation,BASEL_sunshine,BASEL_temp_mean,BASEL_temp_min,...,SONNBLICK_temp_mean,SONNBLICK_temp_min,SONNBLICK_temp_max,TOURS_humidity,TOURS_pressure,TOURS_global_radiation,TOURS_precipitation,TOURS_temp_mean,TOURS_temp_min,TOURS_temp_max
count,3654.0,3654.0,3654.0,3654.0,3654.0,3654.0,3654.0,3654.0,3654.0,3654.0,...,3654.0,3654.0,3654.0,3654.0,3654.0,3654.0,3654.0,3654.0,3654.0,3654.0
mean,20045680.0,6.520799,5.418446,0.745107,1.017876,1.33038,0.234849,4.661193,11.022797,6.989135,...,-4.626327,-6.884319,-2.352244,0.781872,1.016639,1.369787,0.1861,12.205802,7.860536,16.551779
std,28742.87,3.450083,2.325497,0.107788,0.007962,0.935348,0.536267,4.330112,7.414754,6.653356,...,6.98708,7.120333,6.972886,0.115572,0.018885,0.926472,0.422151,6.467155,5.692256,7.714924
min,20000100.0,1.0,0.0,0.38,0.9856,0.05,0.0,0.0,-9.3,-16.0,...,-26.6,-30.3,-24.7,0.33,0.0003,0.05,0.0,-6.2,-13.0,-3.1
25%,20020700.0,4.0,4.0,0.67,1.0133,0.53,0.0,0.5,5.3,2.0,...,-9.4,-11.8,-7.1,0.7,1.0121,0.55,0.0,7.6,3.7,10.8
50%,20045670.0,7.0,6.0,0.76,1.0177,1.11,0.0,3.6,11.4,7.3,...,-4.4,-6.4,-2.2,0.8,1.0173,1.235,0.0,12.3,8.3,16.6
75%,20070700.0,10.0,7.0,0.83,1.0227,2.06,0.21,8.0,16.9,12.4,...,0.7,-1.1,2.7,0.87,1.0222,2.09,0.16,17.2,12.3,22.4
max,20100100.0,12.0,8.0,0.98,1.0408,3.55,7.57,15.3,29.0,20.8,...,13.8,8.7,14.3,1.0,1.0414,3.56,6.2,31.2,22.6,39.8


### Prepare the dataset
We prepare the test data for prediction in the same way as the data used to train the classification model.

In [4]:
# Keep only the De Bilt columns, i.e. those whose name starts with 'DE_BILT'.
columns = [name for name in data.columns if name.startswith('DE_BILT')]
data_debilt = data.loc[:, columns]
data_debilt.describe()

Unnamed: 0,DE_BILT_cloud_cover,DE_BILT_humidity,DE_BILT_pressure,DE_BILT_global_radiation,DE_BILT_precipitation,DE_BILT_sunshine,DE_BILT_temp_mean,DE_BILT_temp_min,DE_BILT_temp_max
count,3654.0,3654.0,3654.0,3654.0,3654.0,3654.0,3654.0,3654.0,3654.0
mean,5.303229,0.817882,1.015299,1.190903,0.236888,4.744444,10.70353,6.397099,14.798604
std,2.279416,0.097465,0.009861,0.870267,0.459495,3.995637,6.19077,5.639597,7.21074
min,0.0,0.37,0.9732,0.11,0.0,0.0,-7.9,-14.4,-4.7
25%,4.0,0.76,1.0094,0.41,0.0,1.1,6.2,2.3,9.2
50%,6.0,0.83,1.0157,1.02,0.01,4.1,11.0,6.8,14.9
75%,7.0,0.89,1.0217,1.86,0.29,7.5,15.5,10.8,20.2
max,8.0,1.0,1.0449,3.41,4.25,15.5,26.9,20.8,35.7


In [5]:
# Find where the month changes: month_ends holds the row index of the last day of
# each month, except for the last month. Of the last month only a single day is
# recorded, so we discard it.
month_ends = np.where(np.diff(data['MONTH']) != 0)[0]
nmonth = len(month_ends)

# Prepend 0 (the first row) so that consecutive entries delimit one month each.
month_bounds = np.insert(month_ends, 0, 0)
ncol = len(columns)

# Create a single array containing each timeseries: (month, day, column).
# For simplicity we truncate each timeseries to the same length, i.e. `nday` days.
# NOTE(review): for every month except the first, the window
# month_bounds[m]:month_bounds[m + 1] starts on the *last day of the previous
# month*, because month_ends holds last-day indices. This presumably mirrors the
# preprocessing used when the model was trained (the notebook states the data is
# prepared as for training), so it is kept unchanged here. If the model is
# retrained, start the window at month_bounds[m] + 1 for m > 0 in both places.
nday = 28
data_ts = np.zeros((nmonth, nday, ncol))
for m, (start, end) in enumerate(zip(month_bounds[:-1], month_bounds[1:])):
    # .iloc makes the positional row slicing explicit.
    data_ts[m] = data_debilt.iloc[start:end].to_numpy()[:nday]

print(data_ts.shape)

(120, 28, 9)


In [6]:
# The labels are based on the calendar month of each timeseries, in range 1 to 12.
# This assumes consecutive timeseries cover consecutive months, starting from the
# month of the first row of the dataset.
months = (np.arange(nmonth) + data['MONTH'].iloc[0] - 1) % 12 + 1

# one class per meteorological season
spring = (3 <= months) & (months <= 5)   # mar - may
summer = (6 <= months) & (months <= 8)   # jun - aug
autumn = (9 <= months) & (months <= 11)  # sep - nov
winter = (months <= 2) | (months == 12)  # dec - feb

labels = np.zeros_like(months, dtype=int)
labels[spring] = 0
labels[summer] = 1
labels[autumn] = 2
labels[winter] = 3

# One-hot encode the labels. dtype=int keeps the columns numeric: since pandas 2.0
# get_dummies returns bool columns by default. That changes target.describe()
# below (count/unique/top/freq instead of mean/std) and the dtype seen downstream.
target = pd.get_dummies(labels, dtype=int)

# class names, indexed by label value
classes = ['spring', 'summer', 'autumn', 'winter']
nclass = len(classes)

target.describe()

Unnamed: 0,0,1,2,3
count,120.0,120.0,120.0,120.0
mean,0.25,0.25,0.25,0.25
std,0.434828,0.434828,0.434828,0.434828
min,0.0,0.0,0.0,0.0
25%,0.0,0.0,0.0,0.0
50%,0.0,0.0,0.0,0.0
75%,0.25,0.25,0.25,0.25
max,1.0,1.0,1.0,1.0


### Train/test split

In [7]:
# Hold out 12% of the timeseries as a test set, then 12% of the remainder as a
# validation set. Both splits are stratified on the one-hot labels and use a
# fixed random_state so they are reproducible.
split_kwargs = dict(random_state=0, test_size=.12)
data_trainval, data_test, target_trainval, target_test = train_test_split(
    data_ts, target, stratify=target, **split_kwargs)
data_train, data_val, target_train, target_val = train_test_split(
    data_trainval, target_trainval, stratify=target_trainval, **split_kwargs)
print(data_train.shape, data_val.shape, data_test.shape)

(92, 28, 9) (13, 28, 9) (15, 28, 9)


### Check predictions with the ONNX model

In [8]:
# Location of the trained season-prediction model (relative path).
onnx_file = 'season_prediction_model.onnx'

# Load the model and run ONNX's checker on it; check_model raises if the model
# is not valid, so a silent pass means the file can be used below.
onnx_model = onnx.load(onnx_file)
onnx.checker.check_model(model=onnx_model)

In [9]:
# ONNX Runtime sessions keyed by model path, so the model is loaded once instead of
# on every call (LIME may call run_model several times). Re-running this cell
# clears the cache, e.g. after replacing the model file on disk.
_onnx_sessions = {}


def _get_onnx_session(model_path):
    """Return an ``ort.InferenceSession`` for ``model_path``, creating it on first use."""
    session = _onnx_sessions.get(model_path)
    if session is None:
        session = ort.InferenceSession(model_path)
        _onnx_sessions[model_path] = session
    return session


def run_model(data):
    """Run the ONNX model stored at ``onnx_file`` on ``data``.

    Parameters
    ----------
    data : array-like
        Model input, converted to a C-contiguous float32 array before inference.
        Presumably a batch of timeseries shaped like ``data_instance``, e.g.
        (n_samples, 28, 9) in this notebook. TODO confirm against the model's
        input signature.

    Returns
    -------
    numpy.ndarray
        The model's first output, returned as-is (no softmax is applied).
    """
    sess = _get_onnx_session(onnx_file)
    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name

    onnx_input = {input_name: np.ascontiguousarray(data, dtype=np.float32)}
    return sess.run([output_name], onnx_input)[0]

In [10]:
# Index (into the test set) of the explained instance. A dedicated name avoids
# reusing `idx`, which an earlier cell uses for an array of month boundaries.
instance_idx = 5
# Add a leading axis of length 1 so the input is a batch of one timeseries.
data_instance = data_test[instance_idx][np.newaxis, ...]
# get ONNX predictions
pred_onnx = run_model(data_instance)
pred_class = classes[np.argmax(pred_onnx)]
print("The predicted class is:", pred_class)
print("The actual class is:", classes[np.argmax(target_test.iloc[instance_idx])])

The predicted class is: winter
The actual class is: winter


In [13]:
# num_features: how many features the explanation should contain.
# num_slices:   how many slices to split each timeseries into.
num_features, num_slices = 5, 14

In [14]:
# Explain the model's prediction for the selected test instance with LIME for timeseries.
explainer = LimeTimeSeriesExplainer(class_names=classes)
exp = explainer.explain_instance(
    data_instance,
    run_model,
    num_features=num_features,
    num_samples=5000,
    num_slices=num_slices,
    replacement_method='total_mean',
)
# exp.as_pyplot_figure() presumably plots this explanation; it is not called here.

In [15]:
# Prediction stored on the explanation, presumably the model output for the explained
# instance; entries presumably follow the order of `classes` — TODO confirm in lime_timeseries.
exp.predict_proba

array([0.28958738, 0.18529397, 0.19098334, 0.3341353 ], dtype=float32)