<img width="150" alt="Logo_ER10" src="https://user-images.githubusercontent.com/3244249/151994514-b584b984-a148-4ade-80ee-0f88b0aefa45.png">

### Interpreting a coffee bean classification model with LIME
This notebook demonstrates how to use DIANNA with the LIME method for time series on the coffee dataset, in which each series is classified as Arabica or Robusta.

LIME (Local Interpretable Model-agnostic Explanations) is an explainable-AI method that explains an individual prediction by fitting a simple, interpretable surrogate model that approximates the classifier locally, around the instance being explained. For more details, see the [LIME paper](https://arxiv.org/abs/1602.04938).
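
For time series, the interpretable features are usually contiguous slices of the series. The NumPy-only sketch below illustrates the core LIME loop under our own simplifying assumptions. It switches random subsets of slices off by replacing them with the series mean and queries the black-box model on each perturbed series. Each sample is then weighted by its cosine distance to the original through an exponential kernel, and a weighted linear surrogate is fitted whose coefficients serve as slice importances. It is not DIANNA's implementation. `lime_timeseries_sketch`, `toy_predict_proba` and the default `kernel_width=0.25` are hypothetical choices made for illustration.

```python
import numpy as np


def lime_timeseries_sketch(predict_proba, series, label, num_slices=24,
                           num_samples=1000, kernel_width=0.25, seed=0):
    """Hypothetical helper (not DIANNA's API): one LIME weight per slice.

    A positive weight means that keeping the slice, rather than replacing
    it with the series mean, raises the predicted probability of `label`.
    """
    rng = np.random.default_rng(seed)
    series = np.asarray(series, dtype=float)
    # Interpretable features: contiguous slices of the time axis
    slices = np.array_split(np.arange(series.size), num_slices)
    slice_of_step = np.concatenate(
        [np.full(len(s), j) for j, s in enumerate(slices)])
    # Random on/off masks; row 0 is the unperturbed instance
    masks = rng.integers(0, 2, size=(num_samples, num_slices))
    masks[0] = 1
    keep = masks[:, slice_of_step].astype(bool)
    perturbed = np.where(keep, series, series.mean())
    # Black-box probabilities for the class being explained
    target = predict_proba(perturbed)[:, label]
    # Cosine distance between each mask and the all-ones mask,
    # turned into a sample weight by an exponential kernel
    distance = 1.0 - np.sqrt(masks.sum(axis=1) / num_slices)
    weights = np.sqrt(np.exp(-distance ** 2 / kernel_width ** 2))
    # Weighted least squares: target ~ intercept + masks @ coef
    design = np.hstack([np.ones((num_samples, 1)), masks])
    root_w = np.sqrt(weights)
    coef, *_ = np.linalg.lstsq(design * root_w[:, None], target * root_w,
                               rcond=None)
    return coef[1:]  # drop the intercept


def toy_predict_proba(x):
    # Hypothetical model: P(class 1) is the mean of time steps 18..23
    p1 = np.clip(x[:, 18:24].mean(axis=1), 0.0, 1.0)
    return np.column_stack([1.0 - p1, p1])


toy_series = np.zeros(48)
toy_series[18:24] = 1.0
w = lime_timeseries_sketch(toy_predict_proba, toy_series, label=1,
                           num_slices=8)
print(np.round(w, 3))  # slice 3 (steps 18..23) dominates
```

In the toy model only time steps 18 to 23 (slice 3 of 8) drive the class-1 probability, so that slice should be the only one with a clearly non-zero weight. The reference `lime` package fits a ridge regression surrogate by default; a plain weighted least-squares fit keeps the sketch short.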

*NOTE*: This tutorial is still a work in progress; the final results need to be improved by tuning the LIME parameters.

#### 1. Imports and paths

```python
from pathlib import Path

import pandas as pd
from sklearn.metrics import accuracy_score as acc
from sklearn.neighbors import KNeighborsClassifier as KNN

import dianna
```

#### 2. Loading the data and training the model

```python
# Load the coffee dataset: one series per row, class label in column 0
# Replace with the directory that contains coffee_train.csv and coffee_test.csv
# (e.g. the demo/data folder of the Lime-For-Time repository)
path_to_data = "path_to_data"
coffee_train = pd.read_csv(Path(path_to_data, "coffee_train.csv"),
                           sep=',', header=None).astype(float)
coffee_train_y = coffee_train.loc[:, 0]   # class labels
coffee_train_x = coffee_train.loc[:, 1:]  # time-series values
coffee_test = pd.read_csv(Path(path_to_data, "coffee_test.csv"),
                          sep=',', header=None).astype(float)
coffee_test_y = coffee_test.loc[:, 0]
coffee_test_x = coffee_test.loc[:, 1:]
```
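
The loading cell assumes that every CSV row is one sample, with the class label in column 0 and the time-series values in the remaining columns; this layout is inferred from how the cell slices the frames. The self-contained snippet below uses two made-up rows (not coffee data) to show that `header=None` yields integer column labels, so `.loc[:, 0]` selects the labels and `.loc[:, 1:]` the time steps.

```python
import io

import pandas as pd

# Two made-up rows in the assumed layout: label first, then the time points
csv_text = "0,0.1,0.2,0.3,0.4\n1,0.5,0.6,0.7,0.8\n"
frame = pd.read_csv(io.StringIO(csv_text), sep=',', header=None).astype(float)
labels = frame.loc[:, 0]   # column 0: class label
values = frame.loc[:, 1:]  # columns 1..n: one column per time step
print(labels.tolist())     # [0.0, 1.0]
print(values.shape)        # (2, 4)
```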

```python
# Define and train a k-nearest-neighbors classifier (scikit-learn defaults)
knn = KNN()
knn.fit(coffee_train_x, coffee_train_y)
print('Accuracy KNN for coffee dataset: %f' % acc(coffee_test_y, knn.predict(coffee_test_x)))
```

Accuracy KNN for coffee dataset: 0.964286
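
With default settings (5 neighbors, uniform weights, Euclidean distance), `KNeighborsClassifier` estimates class probabilities as the share of the 5 nearest training series in each class, so `predict_proba` can only return multiples of 0.2. This matters for LIME, which fits its surrogate to these probabilities: a perturbation changes the output only when it changes the neighbor vote. The NumPy sketch below re-implements that rule for illustration; `knn_predict_proba` is a hypothetical helper, not scikit-learn code.

```python
import numpy as np


def knn_predict_proba(train_x, train_y, query_x, k=5):
    """Hypothetical helper: share of the k nearest training samples
    (Euclidean distance, uniform weights) that belong to each class."""
    classes = np.unique(train_y)  # sorted, like scikit-learn's classes_
    # Pairwise Euclidean distances, shape (n_query, n_train)
    dists = np.linalg.norm(query_x[:, None, :] - train_x[None, :, :], axis=2)
    nearest = np.argsort(dists, axis=1)[:, :k]  # indices of the k closest
    votes = train_y[nearest]                     # their labels, shape (n_query, k)
    return np.stack([(votes == c).mean(axis=1) for c in classes], axis=1)


train_x = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
                    [1.0, 1.0], [1.1, 1.0], [1.0, 1.1]])
train_y = np.array([0, 0, 0, 1, 1, 1])
proba = knn_predict_proba(train_x, train_y, np.array([[0.05, 0.05]]))
print(proba)  # [[0.6 0.4]]
```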


```python
# Select the instance to explain and set the LIME parameters
idx = 5            # index of the test instance to explain
num_features = 10  # number of features (time slices) included in the explanation
num_slices = 24    # number of slices the time series is split into
series = coffee_test_x.iloc[idx, :]
```
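
`num_slices` sets the time resolution of the explanation, because LIME switches whole slices on and off. DIANNA's exact slicing rule is not shown in this notebook; the sketch below assumes contiguous, near-equal slices as produced by `np.array_split`. In the standard UCR release each coffee series has 286 points (in the notebook this would be `len(series)`), so 24 slices are 11 or 12 time steps wide.

```python
import numpy as np

num_slices = 24
# Assumed series length (286 points in the standard UCR coffee data);
# in the notebook this would be len(series)
series_length = 286

# Assumed slicing rule: contiguous, near-equal slices of the time axis
slices = np.array_split(np.arange(series_length), num_slices)
sizes = [len(s) for s in slices]
print(sizes[:3], sizes[-3:])        # [12, 12, 12] [12, 11, 11]
print(slices[0][0], slices[0][-1])  # first slice covers steps 0..11
```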

#### 3. Applying LIME with DIANNA

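```python
# Work in progress: the settings below still need to be replaced or reproduced
class_names = ['Arabica', 'Robusta']

# Explain the KNN class probabilities for the selected series
explanation_heatmap = dianna.explain_timeseries(knn.predict_proba,
                                                series,
                                                'LIME',
                                                labels=class_names[0],
                                                num_features=num_features,
                                                num_samples=5000,  # number of perturbed samples
                                                num_slices=num_slices,
                                                mask_type='mean')  # how masked slices are filled
```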
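
How `explanation_heatmap` is laid out depends on the DIANNA version, so the sketch below does not touch it. Instead it shows, on made-up per-slice weights, two steps that turn a LIME slice explanation into something plottable: keep the `num_features` slices with the largest absolute weight, and spread each slice's weight over its time steps. `slice_weights_to_heatmap` is a hypothetical helper, and the slicing again assumes `np.array_split`. The reference `lime` package additionally refits its surrogate on the selected features, which this sketch skips.

```python
import numpy as np


def slice_weights_to_heatmap(slice_weights, series_length, num_features):
    """Hypothetical helper: keep the num_features slices with the largest
    absolute weight and spread each weight over that slice's time steps."""
    slice_weights = np.asarray(slice_weights, dtype=float)
    num_slices = slice_weights.size
    top = np.argsort(-np.abs(slice_weights))[:num_features]
    kept = np.zeros(num_slices)
    kept[top] = slice_weights[top]
    # Same near-equal slicing as np.array_split (an assumption)
    sizes = [len(s) for s in np.array_split(np.arange(series_length), num_slices)]
    return np.repeat(kept, sizes)  # one value per time step


# Made-up weights for 4 slices of a 10-step series
heatmap = slice_weights_to_heatmap([0.1, -0.5, 0.05, 0.3],
                                   series_length=10, num_features=2)
print(heatmap)  # slices 1 and 3 kept, slices 0 and 2 set to zero
```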