https://github.com/TeamHG-Memex/eli5/blob/master/notebooks/TextExplainer.ipynb

In [1]:
from sklearn.datasets import fetch_20newsgroups

categories = ['alt.atheism', 'soc.religion.christian',
              'comp.graphics', 'sci.med']

# Load the train and test splits with identical settings: same category
# filter, same shuffling seed, and headers/footers stripped from each post.
# The generator is consumed in order, so 'train' is fetched before 'test'.
twenty_train, twenty_test = (
    fetch_20newsgroups(
        subset=split_name,
        categories=categories,
        shuffle=True,
        random_state=42,
        remove=('headers', 'footers'),
    )
    for split_name in ('train', 'test')
)

In [2]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import SVC
from sklearn.decomposition import TruncatedSVD
from sklearn.pipeline import Pipeline, make_pipeline

# Text representation: TF-IDF over unigrams and bigrams (English stop words
# removed, terms must occur in at least 3 documents), reduced to 100
# dimensions with a seeded TruncatedSVD -- i.e. an LSA pipeline.
vec = TfidfVectorizer(min_df=3, stop_words='english',
                      ngram_range=(1, 2))
svd = TruncatedSVD(n_components=100, n_iter=7, random_state=42)
lsa = make_pipeline(vec, svd)

# SVC on the LSA features. probability=True makes predict_proba available by
# fitting an internal, randomized cross-validated calibration; without a
# fixed random_state those probabilities (and therefore the TextExplainer
# results computed from pipe.predict_proba below) change from run to run.
clf = SVC(C=150, gamma=2e-2, probability=True, random_state=42)
pipe = make_pipeline(lsa, clf)
pipe.fit(twenty_train.data, twenty_train.target)
# Held-out accuracy of the full black-box pipeline.
pipe.score(twenty_test.data, twenty_test.target)

0.89014647137150471

In [3]:
def print_prediction(doc, model=None, target_names=None):
    """Print the predicted probability of every class for a single document.

    One line per class is printed, formatted as ``"<prob:.3f> <name>"``.

    Parameters
    ----------
    doc : str
        The raw text to classify.
    model : object with ``predict_proba``, optional
        Classifier to query. Defaults to the notebook-global ``pipe``, looked
        up at call time (so behaviour matches the original one-argument form).
    target_names : sequence of str, optional
        Class names, paired positionally with the probability columns.
        Defaults to ``twenty_train.target_names``, looked up at call time.
        NOTE: assumes the model's probability columns follow the same order
        as these names -- confirm when passing a different model.
    """
    # Resolve defaults lazily: `pipe` is rebound later in the notebook, so an
    # explicit `model=` argument makes it clear which classifier is queried.
    if model is None:
        model = pipe
    if target_names is None:
        target_names = twenty_train.target_names
    probabilities = model.predict_proba([doc])[0]
    for target, prob in zip(target_names, probabilities):
        print("{:.3f} {}".format(prob, target))

# Use the first test document as the running example for the explanations
# below, and show what the black-box pipeline predicts for it.
doc = twenty_test.data[0]
print_prediction(doc)

0.001 alt.atheism
0.001 comp.graphics
0.995 sci.med
0.004 soc.religion.christian


In [4]:
from eli5.lime import TextExplainer

# Explain the prediction for `doc` with eli5's LIME-based TextExplainer: it
# queries the black-box `pipe.predict_proba` and fits a white-box model that
# imitates it around this document. random_state seeds the explainer's own
# randomness so the explanation can be reproduced.
te = TextExplainer(random_state=42)
te.fit(doc, pipe.predict_proba)
te.show_prediction(target_names=twenty_train.target_names)

Contribution?,Feature
-0.36,<BIAS>
-9.179,Highlighted in text (sum)

Contribution?,Feature
-0.219,<BIAS>
-8.128,Highlighted in text (sum)

Contribution?,Feature
5.976,Highlighted in text (sum)
-0.119,<BIAS>

Contribution?,Feature
-0.345,<BIAS>
-5.296,Highlighted in text (sum)


In [5]:
import re
# Delete the words the explanation highlighted and re-check the prediction.
# \b anchors each alternative to whole words, so e.g. "tech" does not also
# eat the start of "technique" and "pain" does not mangle "painting".
doc2 = re.sub(r'\b(recall|kidney|stones|medication|pain|tech)\b', '', doc, flags=re.I)
print_prediction(doc2)

0.064 alt.atheism
0.151 comp.graphics
0.359 sci.med
0.426 soc.religion.christian


In [7]:
# How well the white-box model imitates the black-box pipeline: 'score' is
# its accuracy and 'mean_KL_divergence' compares predicted probabilities
# (lower is better); see the note below on how the two can disagree.
te.metrics_

{'mean_KL_divergence': 0.020271305949858723, 'score': 0.98578406015683906}

It may happen that the accuracy score is perfect while the KL divergence is bad. This can happen because the generated texts were not diverse enough: the white-box classifier hasn't learned anything useful, and it has a hard time predicting the probability output of the black-box pipeline on a held-out dataset.

By default, TextExplainer uses words as features and doesn't take word position into account.

In [8]:
from sklearn.tree import DecisionTreeClassifier

# Swap the white-box model for a depth-2 decision tree so the explanation
# reduces to a couple of words (see the weights printed below).
# The tree gets its own random_state: scikit-learn randomly permutes the
# features at every split, so ties between equally good splits would
# otherwise be broken differently on each run.
te5 = TextExplainer(clf=DecisionTreeClassifier(max_depth=2, random_state=0),
                    random_state=0)
te5.fit(doc, pipe.predict_proba)
print(te5.metrics_)
te5.show_weights()

{'mean_KL_divergence': 0.038434472643997117, 'score': 0.98259788563297445}


Weight,Feature
0.5447,kidney
0.4553,pain


https://github.com/TeamHG-Memex/eli5/blob/master/notebooks/Debugging%20scikit-learn%20text%20classification%20pipeline.ipynb

In [15]:
from sklearn.datasets import fetch_20newsgroups

categories = ['alt.atheism', 'soc.religion.christian',
              'comp.graphics', 'sci.med']

# Both splits use the same category filter and shuffling seed, and strip
# headers and footers (quoted replies are kept). `remove` is passed as a
# tuple, the type documented for fetch_20newsgroups.
twenty_train = fetch_20newsgroups(
    subset='train',
    categories=categories,
    shuffle=True,
    random_state=42,
    remove=('headers', 'footers'),
)
twenty_test = fetch_20newsgroups(
    subset='test',
    categories=categories,
    shuffle=True,
    random_state=42,
    remove=('headers', 'footers'),
)

In [18]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegressionCV
from sklearn.pipeline import make_pipeline

# Bag-of-words TF-IDF features (English stop words removed) feeding a
# logistic regression that picks its regularization strength by
# cross-validation.
vec = TfidfVectorizer(stop_words='english')
clf = LogisticRegressionCV()
# Pipeline.fit returns the pipeline itself, so build-and-fit is one
# statement. eli5 below inspects `vec` and `clf` directly, so they must be
# the very objects placed in the pipeline.
pipe = make_pipeline(vec, clf).fit(twenty_train.data, twenty_train.target)

In [19]:
from sklearn import metrics

def print_report(pipe, data=None, target=None, target_names=None):
    """Print a per-class classification report and the overall accuracy.

    Parameters
    ----------
    pipe : fitted estimator exposing ``predict``
        The model to evaluate.
    data : sequence of str, optional
        Evaluation documents. Defaults to ``twenty_test.data``.
    target : array-like, optional
        True labels for ``data``. Defaults to ``twenty_test.target``.
    target_names : sequence of str, optional
        Display names for the classes. Defaults to
        ``twenty_test.target_names``.

    Defaults are looked up at call time, so ``print_report(pipe)`` behaves
    exactly as before, while other datasets can now be evaluated too.
    """
    if data is None:
        data = twenty_test.data
    if target is None:
        target = twenty_test.target
    if target_names is None:
        target_names = twenty_test.target_names
    y_pred = pipe.predict(data)
    report = metrics.classification_report(target, y_pred,
                                           target_names=target_names)
    print(report)
    print("accuracy: {:0.3f}".format(metrics.accuracy_score(target, y_pred)))


print_report(pipe)

                        precision    recall  f1-score   support

           alt.atheism       0.93      0.77      0.84       319
         comp.graphics       0.84      0.97      0.90       389
               sci.med       0.95      0.89      0.92       396
soc.religion.christian       0.88      0.92      0.90       398

           avg / total       0.90      0.89      0.89      1502

accuracy: 0.893


In [20]:
import eli5

# Ten highest-weighted features per class of the fitted logistic regression;
# `vec` lets eli5 map coefficient indices back to terms.
eli5.show_weights(
    clf,
    vec=vec,
    top=10,
    target_names=twenty_test.target_names,
)

Weight?,Feature,Unnamed: 2_level_0,Unnamed: 3_level_0
Weight?,Feature,Unnamed: 2_level_1,Unnamed: 3_level_1
Weight?,Feature,Unnamed: 2_level_2,Unnamed: 3_level_2
Weight?,Feature,Unnamed: 2_level_3,Unnamed: 3_level_3
+19.035,atheism,,
+14.480,writes,,
+12.999,motto,,
+12.596,livesey,,
+12.517,keith,,
+12.137,mathew,,
+12.091,morality,,
… 7782 more positive …,… 7782 more positive …,,
… 24393 more negative …,… 24393 more negative …,,
-13.003,christ,,

Weight?,Feature
+19.035,atheism
+14.480,writes
+12.999,motto
+12.596,livesey
+12.517,keith
+12.137,mathew
+12.091,morality
… 7782 more positive …,… 7782 more positive …
… 24393 more negative …,… 24393 more negative …
-13.003,christ

Weight?,Feature
+14.083,graphics
+8.147,image
+7.313,code
+7.285,3d
+6.969,files
+6.688,images
… 8207 more positive …,… 8207 more positive …
… 23968 more negative …,… 23968 more negative …
-6.891,edu
-7.547,writes

Weight?,Feature
+12.065,msg
+11.688,disease
+11.632,health
+11.251,doctor
+10.675,treatment
+10.097,pitt
+9.601,com
+8.820,cancer
… 11395 more positive …,… 11395 more positive …
… 20780 more negative …,… 20780 more negative …

Weight?,Feature
+18.488,rutgers
+16.964,christians
+16.647,church
+16.621,christ
+13.858,athos
+11.620,1993
+11.615,christian
+10.766,heaven
+10.458,love
… 8899 more positive …,… 8899 more positive …


In [22]:
# Explain the logistic regression's prediction for the first test document,
# restricted to the sci.med class.
eli5.show_prediction(
    clf,
    twenty_test.data[0],
    vec=vec,
    target_names=twenty_test.target_names,
    targets=['sci.med'],
)

Contribution?,Feature
5.484,Highlighted in text (sum)
-3.578,<BIAS>
