# LIME
Local Interpretable Model-agnostic Explanations
- Can be used for any ML model (agnostic)
- Well suited to black-box models, where we can only observe the inputs and outputs
![alt text](../images/lime_ex.png)
- The prediction in this example is highly non-linear
- The model learns complex patterns as a combination of those two features
- Therefore, we zoom in to a **local area** and create a simple explanation without taking the whole model into account
- In that area, LIME fits an interpretable linear model, often called a **surrogate**, which gives a local approximation of the complex model. <br> <br>
- Using prior (domain) knowledge, we can then validate the explanations and build trust in the model
- ***Cons:***
    - Explanations are locally faithful, but not necessarily globally faithful <br>
- **Math used in LIME:** <br>
$x$ - input data point <br>
$f$ - complex model <br>
$g$ - simple interpretable model (**surrogate**) <br>
$G$ - family of interpretable models (linear regression and its variants) <br>
$\pi_x$ - defines the local neighbourhood of the data point $x$ through some **proximity measure** <br>
$\mathcal{L}(f, g, \pi_x)$ - measures how poorly the simple model $g$ approximates the model $f$ in the neighbourhood $\pi_x$ of our data point $x$ <br>
$\Omega(g)$ - penalizes the complexity of the **surrogate model** $g$ (regularization that keeps it simple) <br>

$$ \xi(x) = \arg\min_{g \in G} \; \mathcal{L}(f, g, \pi_x) + \Omega(g) $$ <br>
- We look for the simple model $g$ that most closely approximates $f$ around $x$ while staying as simple as possible (low complexity $\Omega(g)$)
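
The original LIME paper (Ribeiro et al., 2016) makes $\mathcal{L}$ and $\pi_x$ concrete for sparse linear explanations: a locally weighted square loss over perturbed samples $z$ (with interpretable representations $z'$), an exponential kernel with distance function $D$ and width $\sigma$, and a linear surrogate with weights $w_g$. Other losses and kernels are possible, so treat this as one instance of the general objective:

```latex
\mathcal{L}(f, g, \pi_x) = \sum_{z, z' \in \mathcal{Z}} \pi_x(z)\,\bigl(f(z) - g(z')\bigr)^2,
\qquad
\pi_x(z) = \exp\!\left(-\frac{D(x, z)^2}{\sigma^2}\right),
\qquad
g(z') = w_g \cdot z'
```

Samples close to $x$ get $\pi_x(z) \approx 1$ and distant ones get $\pi_x(z) \approx 0$, so approximation errors far from $x$ barely affect $g$. The paper bounds the complexity with $\Omega(g) = \infty \cdot \mathbb{1}[\lVert w_g \rVert_0 > K]$, i.e. the surrogate may use at most $K$ non-zero weights.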

# LIME: how to train the surrogate
![alt text](../images/lime_ex2.png)
Steps:
1. Generate new data points in the neighbourhood of our input data point $x$ (yellow points); they will be weighted according to their distance to $x$
2. These data points are generated by perturbation: each feature is sampled from a normal distribution with that feature's mean and standard deviation
3. We get predictions for these data points from our **complex model $f$**, so we end up with a new labelled dataset
4. Data points closer to the input $x$ (see the heatmap) get the highest weights, which keeps the surrogate locally faithful
5. To keep $\Omega(g)$ low, we choose $g$ to be a **sparse linear model**, which aims to produce as many zero weights as possible
6. In practice, we can obtain this sparsity with a regularization technique such as **Lasso regression**
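
The six steps above can be sketched from scratch with NumPy and scikit-learn's `Lasso`. This is a minimal illustration under stated assumptions, not the implementation used by `interpret` or the `lime` package: the black-box function `complex_model`, the helper `explain_locally`, the kernel width and the Lasso `alpha` are hypothetical choices, and the kernel is the exponential kernel from the LIME paper applied to standardized Euclidean distances.

```python
import numpy as np
from sklearn.linear_model import Lasso


def complex_model(X):
    # Hypothetical black box f: linear in feature 0, quadratic in feature 1,
    # and it ignores feature 2 entirely.
    return 3.0 * X[:, 0] + X[:, 1] ** 2


def explain_locally(f, x, X_train, n_samples=5000, alpha=0.01, seed=0):
    """Hypothetical helper: fit a sparse linear surrogate g around x."""
    rng = np.random.default_rng(seed)
    mu, sd = X_train.mean(axis=0), X_train.std(axis=0)
    n_features = X_train.shape[1]

    # Steps 1-2: perturbed samples drawn from N(mean, std) of each feature
    Z = rng.normal(loc=mu, scale=sd, size=(n_samples, n_features))

    # Step 3: label the new dataset with the complex model f
    y = f(Z)

    # Step 4: weight samples by proximity to x (exponential kernel on the
    # standardized Euclidean distance; the kernel width is an assumption)
    kernel_width = 0.75 * np.sqrt(n_features)
    dist = np.linalg.norm((Z - x) / sd, axis=1)
    weights = np.exp(-(dist ** 2) / kernel_width ** 2)

    # Steps 5-6: sparse linear surrogate g fitted with a weighted Lasso
    g = Lasso(alpha=alpha)
    g.fit(Z, y, sample_weight=weights)
    return g.coef_


# Hypothetical "training data" and the instance x we want to explain
X_train = np.random.default_rng(42).normal(size=(1000, 3))
x = np.array([0.0, 1.0, 0.0])
coef = explain_locally(complex_model, x, X_train)
print(np.round(coef, 2))
```

Feature 0 should get a weight close to 3, feature 1 a positive weight reflecting the local slope of the quadratic term, and feature 2 (ignored by `complex_model`) a weight at or near zero. Because the samples are drawn around the training mean rather than around `x`, the slope for feature 1 comes out below the exact derivative of 2 at `x`, a small illustration of how sampling and kernel width shape the explanation.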


```python
import pandas as pd
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import RandomOverSampler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score, accuracy_score
from interpret.blackbox import LimeTabular
from interpret import show
```

```python
# Load data
path = "src/healthcare-dataset-stroke-data.csv"
stroke_data = pd.read_csv(path)
```

```python
# Preprocess data
# One-hot encode all categorical columns
categorical_cols = ["gender",
                    "ever_married",
                    "work_type",
                    "Residence_type",
                    "smoking_status"]
encoded = pd.get_dummies(stroke_data[categorical_cols],
                         prefix=categorical_cols)

# Replace the categorical columns with their one-hot encodings
# (encoded columns go first, so the target "stroke" stays the last column)
stroke_data = pd.concat([encoded, stroke_data], axis=1)
stroke_data.drop(categorical_cols, axis=1, inplace=True)

# Fill missing BMI values with 0 (a simple placeholder, not a realistic BMI)
stroke_data.bmi = stroke_data.bmi.fillna(0)

# Drop id as it is not relevant
stroke_data.drop(["id"], axis=1, inplace=True)

# Standardization
# Usually we would standardize here and convert the values back later,
# but for simplicity we do not standardize / normalize the features
```
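
Had we standardized, the "convert it back later" step could look like this minimal NumPy sketch. The toy matrix is hypothetical; in practice the mean and standard deviation would be computed on the training split only and reused for the test split.

```python
import numpy as np

# Hypothetical toy feature matrix (rows = patients, columns = numeric features)
X = np.array([[67.0, 228.7],
              [61.0, 202.2],
              [80.0, 105.9],
              [49.0, 171.2]])

# Standardize: each column gets mean 0 and standard deviation 1
mean, std = X.mean(axis=0), X.std(axis=0)
X_std = (X - mean) / std

# ... train the model and compute explanations on X_std ...

# Convert back to the original units, e.g. to report readable feature values
X_back = X_std * std + mean
print(np.allclose(X_back, X))
```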

```python
# Split the data for evaluation (the last column, "stroke", is the target)
def get_data_split(dataset):
    X = dataset.iloc[:, :-1]
    y = dataset.iloc[:, -1]
    return train_test_split(X, y, test_size=0.20, random_state=2021)


X_train, X_test, y_train, y_test = get_data_split(stroke_data)
```


```python
# Oversample the minority class in the training data only
oversample = RandomOverSampler(sampling_strategy='minority')

# Convert to numpy and oversample
x_np = X_train.to_numpy()
y_np = y_train.to_numpy()
x_np, y_np = oversample.fit_resample(x_np, y_np)

# Convert back to pandas
X_train = pd.DataFrame(x_np, columns=X_train.columns)
y_train = pd.Series(y_np, name=y_train.name)
```
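
With `sampling_strategy='minority'`, `RandomOverSampler` draws minority-class rows at random with replacement until that class is as frequent as the majority class. A NumPy-only sketch of the same idea for a binary target (the function `random_oversample` and the toy data are hypothetical, not part of imblearn):

```python
import numpy as np


def random_oversample(X, y, seed=0):
    """Duplicate randomly chosen minority-class rows until both classes are equally frequent."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    minority = classes[np.argmin(counts)]
    n_extra = counts.max() - counts.min()
    minority_idx = np.flatnonzero(y == minority)
    # Sample with replacement from the minority rows
    extra_idx = rng.choice(minority_idx, size=n_extra, replace=True)
    keep = np.concatenate([np.arange(len(y)), extra_idx])
    return X[keep], y[keep]


# Hypothetical imbalanced toy data: 8 negatives, 2 positives
X = np.arange(20).reshape(10, 2)
y = np.array([0] * 8 + [1] * 2)
X_res, y_res = random_oversample(X, y)
print(np.unique(y_res, return_counts=True))  # class counts are now 8 and 8
```

Only exact copies of existing minority rows are added, so the test set is left untouched and still reflects the real class distribution.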


```python
# Fit the black-box model
rf = RandomForestClassifier()
rf.fit(X_train, y_train)
y_pred = rf.predict(X_test)
print(f"F1 Score {f1_score(y_test, y_pred, average='macro')}")
print(f"Accuracy {accuracy_score(y_test, y_pred)}")
```


```
F1 Score 0.5342599524755053
Accuracy 0.9452054794520548
```
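
The gap between accuracy and macro F1 above is typical for an imbalanced target. A sketch with hypothetical labels (5% positives, an assumed rate) shows that a constant majority-class prediction already reaches high accuracy while macro F1 stays below 0.5:

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

# Hypothetical test labels: 95 negatives, 5 positives
y_true = np.array([0] * 95 + [1] * 5)
# A trivial "model" that always predicts the majority class
y_pred = np.zeros_like(y_true)

acc = accuracy_score(y_true, y_pred)
# zero_division=0 scores the never-predicted positive class as 0 instead of warning
macro_f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
print(f"Accuracy {acc:.3f}, macro F1 {macro_f1:.3f}")
```

Macro F1 averages the per-class scores without weighting by class size, so the missed minority class pulls it down even though accuracy looks strong.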


```python
# Apply LIME
# Initialize LIME for tabular data
lime = LimeTabular(predict_fn=rf.predict_proba,
                   data=X_train,
                   random_state=1)
# Get local explanations for the last 20 test instances
lime_local = lime.explain_local(X_test[-20:],
                                y_test[-20:],
                                name='LIME')
```

```python
show(lime_local)
```

