In [1]:
import cvplot
import numpy as np
import pandas as pd
from lime.lime_tabular import LimeTabularExplainer
from sklearn.linear_model import Ridge
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

In [2]:
# Prepare dataset
df = pd.read_csv('https://github.com/nyuvis/partial_dependence/raw/master/example_data/test.csv')
X = df.drop('label', axis=1)
y = df['label']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=123)

# Min-max scaling
X_train = (X_train - X.min()) / (X.max() - X.min())
X_test = (X_test - X.min()) / (X.max() - X.min())

In [3]:
# Train classifier
clf = RandomForestClassifier(n_estimators=100)
clf.fit(X_train.values, y_train)

clf.score(X_test.values, y_test)

0.8446115288220551

In [4]:
# Draw Contribution-Value plot
plot = cvplot.CVPlot(X_test, y_test, features=['alcohol', 'pH'], model=clf)
plot

  0%|          | 0/399 [00:00<?, ?it/s]

CVPlot(value={'contributions': {'alcohol': [[[0.0, 0.01983479451909141], [0.05175438596491228, 0.0159792559987…

In [12]:
# Display selected instances
X_test.iloc[plot.selection,:]

Unnamed: 0,alcohol,chlorides,citric acid,density,fixed acidity,free sulfur dioxide,pH,residual sugar,sulphates,total sulfur dioxide,volatile acidity
667,0.716667,0.059347,0.260163,0.078713,0.380952,0.195122,0.203883,0.124233,0.146667,0.294915,0.271186
991,0.255556,0.109792,0.325203,0.146181,0.349206,0.452962,0.349515,0.115031,0.360000,0.515254,0.293785
982,0.700000,0.086053,0.073171,0.054672,0.000000,0.341463,0.601942,0.069018,0.226667,0.267797,0.519774
934,0.500000,0.062315,0.284553,0.097712,0.396825,0.292683,0.009709,0.096626,0.413333,0.332203,0.169492
229,0.250000,0.089021,0.219512,0.102753,0.269841,0.069686,0.427184,0.003067,0.120000,0.216949,0.418079
...,...,...,...,...,...,...,...,...,...,...,...
80,0.433333,0.091988,0.276423,0.118263,0.396825,0.278746,0.398058,0.062883,0.480000,0.484746,0.180791
848,0.333333,0.142433,0.186992,0.160527,0.396825,0.355401,0.427184,0.136503,0.253333,0.501695,0.327684
192,0.233333,0.106825,0.333333,0.151221,0.317460,0.404181,0.291262,0.095092,0.306667,0.671186,0.225989
22,0.166667,0.151335,0.227642,0.145405,0.365079,0.125436,0.466019,0.073620,0.186667,0.294915,0.519774
