In [37]:
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor

In [55]:
from shapash.explainer.smart_explainer import SmartExplainer
from shapash.utils.load_smartpredictor import load_smartpredictor

## Load Datasets

In [5]:
df  = sns.load_dataset('tips')
df.head(2)

Unnamed: 0,total_bill,tip,sex,smoker,day,time,size
0,16.99,1.01,Female,No,Sun,Dinner,2
1,10.34,1.66,Male,No,Sun,Dinner,3


In [11]:
X = df[df.columns.difference(['tip'])]
y = df['tip']

In [12]:
X.head(2)

Unnamed: 0,day,sex,size,smoker,time,total_bill
0,Sun,Female,2,No,Dinner,16.99
1,Sun,Male,3,No,Dinner,10.34


In [13]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 244 entries, 0 to 243
Data columns (total 7 columns):
 #   Column      Non-Null Count  Dtype   
---  ------      --------------  -----   
 0   total_bill  244 non-null    float64 
 1   tip         244 non-null    float64 
 2   sex         244 non-null    category
 3   smoker      244 non-null    category
 4   day         244 non-null    category
 5   time        244 non-null    category
 6   size        244 non-null    int64   
dtypes: category(4), float64(2), int64(1)
memory usage: 7.4 KB


## Handle Categorical features

In [16]:
X.dtypes

day           category
sex           category
size             int64
smoker        category
time          category
total_bill     float64
dtype: object

In [20]:
X.day

0       Sun
1       Sun
2       Sun
3       Sun
4       Sun
       ... 
239     Sat
240     Sat
241     Sat
242     Sat
243    Thur
Name: day, Length: 244, dtype: category
Categories (4, object): ['Thur', 'Fri', 'Sat', 'Sun']

In [29]:
X.dtypes

day           category
sex           category
size             int64
smoker        category
time          category
total_bill     float64
dtype: object

In [30]:
cat_features = [col for col in X.columns if X[col].dtypes == 'category']
cat_features

['day', 'sex', 'smoker', 'time']

In [31]:
for cat_feature in cat_features:
    X[cat_feature] = X[cat_feature].cat.codes

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  


In [32]:
X

Unnamed: 0,day,sex,size,smoker,time,total_bill
0,3,1,2,1,1,16.99
1,3,0,3,1,1,10.34
2,3,0,3,1,1,21.01
3,3,0,2,1,1,23.68
4,3,1,4,1,1,24.59
...,...,...,...,...,...,...
239,2,0,3,1,1,29.03
240,2,1,2,0,1,27.18
241,2,0,2,0,1,22.67
242,2,0,2,1,1,17.82


## Train Test split

In [36]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
len(X_train), len(X_test), len(y_train), len(y_test)

(195, 49, 195, 49)

## Create Model

In [47]:
regressor  =  RandomForestRegressor().fit(X_train, y_train)

## Understand the model using shapash

In [48]:
xpl = SmartExplainer()

In [49]:
xpl.compile(
    x = X_test, 
    model = regressor,
)

Backend: Shap TreeExplainer


### Identify 3 most important features

In [52]:
xpl.run_app(title_story='Tips Dataset')



INFO:root:Your Shapash application run on http://LTIN263504:8050/


Dash is running on http://0.0.0.0:8050/



INFO:root:Use the method .kill() to down your app.
INFO:shapash.webapp.smart_app:Dash is running on http://0.0.0.0:8050/



<CustomThread(Thread-6, started 6904)>

 * Serving Flask app 'shapash.webapp.smart_app' (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: off


INFO:werkzeug: * Running on http://192.168.61.147:8050/ (Press CTRL+C to quit)
INFO:werkzeug:192.168.61.147 - - [06/Nov/2021 10:49:51] "GET / HTTP/1.1" 200 -
INFO:werkzeug:192.168.61.147 - - [06/Nov/2021 10:49:51] "GET /assets/material-icons.css?m=1636161305.4098678 HTTP/1.1" 200 -
INFO:werkzeug:192.168.61.147 - - [06/Nov/2021 10:49:51] "GET /assets/style.css?m=1636161305.4918709 HTTP/1.1" 200 -
INFO:werkzeug:192.168.61.147 - - [06/Nov/2021 10:49:51] "GET /_dash-component-suites/dash_renderer/react@16.v1_8_3m1636161194.14.0.min.js HTTP/1.1" 200 -
INFO:werkzeug:192.168.61.147 - - [06/Nov/2021 10:49:51] "GET /_dash-component-suites/dash_renderer/prop-types@15.v1_8_3m1636161194.7.2.min.js HTTP/1.1" 200 -
INFO:werkzeug:192.168.61.147 - - [06/Nov/2021 10:49:51] "GET /_dash-component-suites/dash_renderer/polyfill@7.v1_8_3m1636161194.8.7.min.js HTTP/1.1" 200 -
INFO:werkzeug:192.168.61.147 - - [06/Nov/2021 10:49:52] "GET /_dash-component-suites/dash_renderer/react-dom@16.v1_8_3m1636161194.14.0

### Save the SmartPredictor object

In [53]:
predictor = xpl.to_smartpredictor()

In [54]:
predictor.save('shapash_predictor.pkl')

### Load SmartPredictor

In [57]:
predictor_load  = load_smartpredictor('shapash_predictor.pkl')

In [58]:
predictor_load.add_input(x=X, ypred=y)

In [59]:
detailed_contributions = predictor_load.detail_contributions()

In [61]:
detailed_contributions.head()

Unnamed: 0,tip,day,sex,size,smoker,time,total_bill
0,1.01,-0.112152,-0.045576,-0.079049,-0.079562,-0.010776,-0.843569
1,1.66,0.109291,-0.036972,-0.078256,-0.021585,-0.000909,-1.178751
2,3.5,0.117993,-0.006172,-0.112403,-0.020205,0.011796,0.459509
3,3.31,0.192581,0.003129,-0.098057,0.01528,-0.000129,0.480913
4,3.61,0.125073,-0.029428,0.056646,0.084232,-0.014634,0.391328
