# Transformers Interpret Multiclass Classification Example

In [39]:
import os
import sys
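# Make a local clone of transformers-interpret importable (this path is specific to the author's machine).
# Building the path from parts avoids backslash escape problems ('\D', '\G') in Windows path strings.
module_path = os.path.abspath(os.path.join('.', 'richa', 'Documents', 'GitHub', 'transformers-interpret'))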
if module_path not in sys.path:
    sys.path.append(module_path)
print(module_path)

C:\Users\richa\Documents\GitHub\transformers-interpret


In [40]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer

## Import Industry Classification Model
This fine-tuned model by @sampathkethineedi uses a DistilBERT base to predict the industry a text refers to.

In [41]:
tokenizer = AutoTokenizer.from_pretrained("sampathkethineedi/industry-classification")
model = AutoModelForSequenceClassification.from_pretrained("sampathkethineedi/industry-classification")

Let's explore the classes. There are 62 unique classes, and many of them overlap or are closely related, such as __Health Care Equipment__ and __Health Care Supplies__.
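One rough way to surface such overlaps is to group label names that share their leading words. The function `group_by_prefix` and the word-prefix heuristic are our own illustration and are not part of transformers or transformers-interpret. The dictionary is a hand-copied subset of `model.config.id2label`.

```python
from collections import defaultdict

# Hand-copied subset of model.config.id2label; with the loaded model,
# pass model.config.id2label directly instead.
id2label = {
    2: "Apparel Retail",
    3: "Apparel, Accessories & Luxury Goods",
    22: "Gold",
    23: "Health Care Equipment",
    24: "Health Care Facilities",
    25: "Health Care Services",
    26: "Health Care Supplies",
    27: "Health Care Technology",
}


def group_by_prefix(labels, n_words=2):
    """Group label names that share their first `n_words` words (a rough overlap heuristic)."""
    groups = defaultdict(list)
    for idx, name in sorted(labels.items()):
        # Drop commas so "Apparel, Accessories" and "Apparel Retail" tokenize alike.
        key = " ".join(name.replace(",", "").split()[:n_words])
        groups[key].append(name)
    # Keep only prefixes shared by more than one label.
    return {key: names for key, names in groups.items() if len(names) > 1}


print(group_by_prefix(id2label))             # groups the five "Health Care" labels
print(group_by_prefix(id2label, n_words=1))  # also groups the two "Apparel" labels
```

A prefix match only catches overlaps that share wording. Related pairs such as __Gold__ and __Diversified Metals & Mining__ are missed.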

In [42]:
model.config.id2label

{0: 'Advertising',
 1: 'Aerospace & Defense',
 2: 'Apparel Retail',
 3: 'Apparel, Accessories & Luxury Goods',
 4: 'Application Software',
 5: 'Asset Management & Custody Banks',
 6: 'Auto Parts & Equipment',
 7: 'Biotechnology',
 8: 'Building Products',
 9: 'Casinos & Gaming',
 10: 'Commodity Chemicals',
 11: 'Communications Equipment',
 12: 'Construction & Engineering',
 13: 'Construction Machinery & Heavy Trucks',
 14: 'Consumer Finance',
 15: 'Data Processing & Outsourced Services',
 16: 'Diversified Metals & Mining',
 17: 'Diversified Support Services',
 18: 'Electric Utilities',
 19: 'Electrical Components & Equipment',
 20: 'Electronic Equipment & Instruments',
 21: 'Environmental & Facilities Services',
 22: 'Gold',
 23: 'Health Care Equipment',
 24: 'Health Care Facilities',
 25: 'Health Care Services',
 26: 'Health Care Supplies',
 27: 'Health Care Technology',
 28: 'Homebuilding',
 29: 'Hotels, Resorts & Cruise Lines',
 30: 'Human Resource & Employment Services',
 31: 'IT Co

Import __SequenceClassificationExplainer__ from transformers-interpret. This class should work with most, if not all, language models from the transformers package that have a sequence classification head.


In [56]:
from transformers_interpret import SequenceClassificationExplainer

In [57]:
sample_text = """
Stocks ended a choppy session mixed as investors digested a host of corporate earnings results and considered policymakers’ next moves to support the still virus-stricken economy.
The S&P 500 shook off earlier declines to narrowly eke out a record closing high. The Dow ended a tick below its recent record closing level."""

In [58]:
multiclass_explainer = SequenceClassificationExplainer(model=model, tokenizer=tokenizer)

In [59]:
# call the explainer
word_attributions = multiclass_explainer(text=sample_text)

In [47]:
# seems to be an appropriate prediction
multiclass_explainer.predicted_class_name

'Investment Banking & Brokerage'
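The predicted class is the label with the highest model score. The probabilities shown in parentheses in the visualizations are presumably softmax outputs over the model's logits. The sketch below illustrates that step with made-up logits and placeholder labels; it is not the library's actual code.

```python
import math

# Hypothetical logits for three placeholder labels; the real model has 62 classes.
id2label = {0: "Label A", 1: "Label B", 2: "Label C"}
logits = [2.0, 0.5, -1.0]


def softmax(values):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax(logits)
# argmax over the probabilities gives the predicted class index
pred_index = max(range(len(probs)), key=probs.__getitem__)
print(id2label[pred_index], round(probs[pred_index], 2))
```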

In [48]:
# look at the raw word attributions
word_attributions

[('[CLS]', 0.0),
 ('stocks', 0.17552737633544166),
 ('ended', -0.03850303276079214),
 ('a', 0.005213843130977056),
 ('chop', -0.371121096479891),
 ('##py', 0.05474096670708312),
 ('session', 0.11222960595661986),
 ('mixed', 0.04523500620798781),
 ('as', 0.00854365827661389),
 ('investors', 0.2255227571532523),
 ('digest', 0.0921733560914598),
 ('##ed', 0.14644429572594486),
 ('a', -0.1024961457360126),
 ('host', -0.07218376391197992),
 ('of', 0.22421007586224576),
 ('corporate', 0.27895964392789346),
 ('earnings', 0.6027407556184162),
 ('results', -0.09964947197433494),
 ('and', -0.017121916846579367),
 ('considered', 0.24777280032883112),
 ('policy', 0.10375664300957889),
 ('##makers', 0.011601848195960555),
 ('’', 0.011302658016292036),
 ('next', -0.0463631890735315),
 ('moves', -0.004164084793551867),
 ('to', 0.08991350690810906),
 ('support', 0.07168822734392323),
 ('the', 0.027829733880404526),
 ('still', 0.07106037827499724),
 ('virus', 0.10655638470524464),
 ('-', -0.07300297909
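The raw list is easier to scan when it is sorted by magnitude, because strongly negative tokens matter as much as strongly positive ones. Below is a minimal sketch in plain Python. `top_tokens` is a hypothetical helper, and the list is a hand-copied subset of the output above.

```python
# Hand-copied subset of the (token, attribution) pairs for the predicted class.
sample_attributions = [
    ("[CLS]", 0.0),
    ("stocks", 0.17552737633544166),
    ("chop", -0.371121096479891),
    ("investors", 0.2255227571532523),
    ("corporate", 0.27895964392789346),
    ("earnings", 0.6027407556184162),
    ("considered", 0.24777280032883112),
]


def top_tokens(attributions, k=3):
    """Return the k (token, score) pairs with the largest absolute score, strongest first."""
    return sorted(attributions, key=lambda pair: abs(pair[1]), reverse=True)[:k]


for token, score in top_tokens(sample_attributions):
    print(f"{token:>10} {score:+.3f}")
```

On this subset, "earnings" dominates, followed by a strongly negative "chop" and a positive "corporate".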

## Visualizing Explanations
With a single call to the `visualize()` method we get an inline display of how much each input token contributed to this prediction. Positive attributions push the model towards the class, and negative ones push it away. **Note: the attributions are calculated with Layer Integrated Gradients. To read more about the algorithm, click [here](https://captum.ai/docs/algorithms).**
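To make the idea concrete, here is plain (input-level) integrated gradients on a toy two-variable function with a hand-written gradient. Layer Integrated Gradients applies the same path integral to the output of an internal layer instead of the raw inputs. In transformers-interpret we understand this to be the embedding layer. Captum computes the gradients with autograd rather than analytically. The function names and the toy model here are illustrative only.

```python
def integrated_gradients(grad_f, x, baseline, steps=50):
    """Approximate integrated gradients with a midpoint Riemann sum.

    IG_i = (x_i - baseline_i) * mean over alpha in [0, 1] of dF/dx_i(baseline + alpha * (x - baseline))
    """
    n = len(x)
    avg_grad = [0.0] * n
    for k in range(steps):
        alpha = (k + 0.5) / steps  # midpoint of the k-th sub-interval of [0, 1]
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        g = grad_f(point)
        for i in range(n):
            avg_grad[i] += g[i] / steps
    return [(xi - b) * gi for xi, b, gi in zip(x, baseline, avg_grad)]


# Toy model F(x) = x0**2 + 3*x1 and its analytic gradient.
def f(x):
    return x[0] ** 2 + 3 * x[1]


def grad_f(x):
    return [2 * x[0], 3.0]


x = [2.0, 1.0]
baseline = [0.0, 0.0]
attr = integrated_gradients(grad_f, x, baseline)
print(attr)                            # approximately [4.0, 3.0]
print(sum(attr), f(x) - f(baseline))   # completeness: both approximately 7.0
```

The completeness check, where the attributions sum to F(x) - F(baseline), is what makes summing token attributions meaningful. Real models typically need more steps, because their gradients vary non-linearly along the path.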

In [49]:
html = multiclass_explainer.visualize()

| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|
| 37.0 | Investment Banking & Brokerage (0.78) | Investment Banking & Brokerage | 3.09 | [CLS] stocks ended a chop ##py session mixed as investors digest ##ed a host of corporate earnings results and considered policy ##makers ’ next moves to support the still virus - stricken economy . the s & p 500 shook off earlier declines to narrowly ek ##e out a record closing high . the dow ended a tick below its recent record closing level . [SEP] |
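Judging from the printed values, the token attributions appear to be normalized so that the whole vector has unit L2 norm. The Attribution Score column, 3.09 here, appears to be the sum of those normalized values. That is our reading of the output, not something documented in this notebook. The sketch below shows the assumed computation on a toy vector.

```python
import math


def normalize_and_score(raw_scores):
    """L2-normalize per-token scores and return (normalized scores, their sum).

    Assumption: this mirrors how the Attribution Score appears to be derived
    from the token attributions; it is not copied from transformers-interpret.
    """
    norm = math.sqrt(sum(s * s for s in raw_scores))
    normalized = [s / norm for s in raw_scores] if norm > 0 else list(raw_scores)
    return normalized, sum(normalized)


normalized, score = normalize_and_score([3.0, 4.0, 0.0])
print(normalized, score)  # normalized ≈ [0.6, 0.8, 0.0], score ≈ 1.4
```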


## Explaining The Same Text For A Different Class 
Let's say we think this text could also fall somewhat under the class __Asset Management & Custody Banks__. If we want, we can also get an explanation (attributions) for the text with respect to that class by passing its name as `class_name`.
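We assume `class_name` has to match one of the labels in `model.config.id2label` exactly. The sketch below shows the kind of name-to-index lookup involved, with a close-match suggestion from the standard library's `difflib` for typos. `class_index` is a hypothetical helper, and we have not checked how transformers-interpret itself reports an unknown `class_name`.

```python
import difflib

# Hand-copied subset of model.config.id2label.
id2label = {5: "Asset Management & Custody Banks", 23: "Health Care Equipment", 26: "Health Care Supplies"}


def class_index(id2label, class_name):
    """Look up a class index by exact name, suggesting close matches on a miss."""
    label2id = {name: idx for idx, name in id2label.items()}
    if class_name in label2id:
        return label2id[class_name]
    suggestions = difflib.get_close_matches(class_name, list(label2id), n=3)
    raise KeyError(f"{class_name!r} is not a class name; did you mean {suggestions}?")


print(class_index(id2label, "Asset Management & Custody Banks"))  # 5
```

With the loaded model, `model.config.label2id` already provides the name-to-index mapping.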

In [50]:
word_attributions = multiclass_explainer(sample_text, class_name="Asset Management & Custody Banks")

In [51]:
# look at the raw word attributions
word_attributions

[('[CLS]', 0.0),
 ('stocks', -0.1438578556922923),
 ('ended', 0.17111803209429655),
 ('a', -0.06059852939470913),
 ('chop', 0.5140056199117679),
 ('##py', -0.08750479024951703),
 ('session', -0.07754127591758762),
 ('mixed', 0.03618514716747898),
 ('as', -0.009426781836862802),
 ('investors', 0.4988444871081522),
 ('digest', -0.2130213282068211),
 ('##ed', -0.06429117362960683),
 ('a', 0.17772627518824835),
 ('host', 0.15808860741275138),
 ('of', -0.06705090033315818),
 ('corporate', -0.15002084127206863),
 ('earnings', -0.23341488417954476),
 ('results', 0.1262585100394356),
 ('and', 0.05719579442609673),
 ('considered', -0.11990905747738642),
 ('policy', 0.14461759247624129),
 ('##makers', 0.15514267780305824),
 ('’', -0.02760059016487817),
 ('next', 0.02122594849198958),
 ('moves', 0.10101868519168605),
 ('to', -0.05263521974391582),
 ('support', -0.0014980207009961854),
 ('the', 0.08523795410290046),
 ('still', 0.00359120003645743),
 ('virus', 0.008249723049030216),
 ('-', -0.05561

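The raw attributions differ noticeably from those for the predicted class. "earnings" drops from the strongest positive contribution (0.60) to a negative one (-0.23). Among the tokens listed above, "chop" (0.51) and "investors" (0.50) become the strongest positive contributors. The overall attribution score in the visualization also drops from 3.09 to 0.43. This is expected, since the attributions now explain a different output of the model, even though the two classes are related.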
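A quick way to quantify how similar two explanations are is to compare the attribution vectors position by position. The sketch uses cosine similarity and sign agreement. These metrics are our choice and are not built into transformers-interpret. It runs on five tokens hand-copied from the two outputs above, and assumes both calls tokenized the same text, so positions line up.

```python
import math

# Hand-copied attributions for five tokens (predicted class vs. "Asset Management & Custody Banks").
tokens = ["stocks", "chop", "investors", "corporate", "earnings"]
predicted = [0.17552737633544166, -0.371121096479891, 0.2255227571532523, 0.27895964392789346, 0.6027407556184162]
asset_mgmt = [-0.1438578556922923, 0.5140056199117679, 0.4988444871081522, -0.15002084127206863, -0.23341488417954476]


def compare_attributions(a, b):
    """Return (cosine similarity, fraction of positions whose signs agree).

    Zero scores are treated as non-positive when comparing signs.
    """
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    cosine = dot / (norm_a * norm_b)
    agree = sum((x > 0) == (y > 0) for x, y in zip(a, b)) / len(a)
    return cosine, agree


cosine, agree = compare_attributions(predicted, asset_mgmt)
print(f"cosine={cosine:.2f}, sign agreement={agree:.0%}")
```

On this subset, the signs agree only for "investors", and the cosine similarity is negative.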

In [52]:
html = multiclass_explainer.visualize()

| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|
| 5.0 | Investment Banking & Brokerage (0.19) | Asset Management & Custody Banks | 0.43 | [CLS] stocks ended a chop ##py session mixed as investors digest ##ed a host of corporate earnings results and considered policy ##makers ’ next moves to support the still virus - stricken economy . the s & p 500 shook off earlier declines to narrowly ek ##e out a record closing high . the dow ended a tick below its recent record closing level . [SEP] |


What if we get attributions for a class that makes no sense in this context, such as __Restaurants__?


In [53]:
word_attributions = multiclass_explainer(sample_text, class_name="Restaurants")

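There isn't much support for this class. The overall attribution score is negative (-0.93), and among the listed tokens "stocks" (-0.69) pushes most strongly against it. It is worth noting, however, that the token "chop" (from "choppy") has a slightly positive attribution here, whereas it was strongly negative for the predicted class. That seems plausible given the industry.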
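The WordPiece tokenizer splits "choppy" into "chop" and "##py", so word-level statements are easier to check after merging subword pieces. Below is a minimal sketch that adds the attribution of each "##" continuation piece to the preceding token. Summing is one reasonable convention we chose for illustration; transformers-interpret does not do this for you as far as this notebook shows. The values are hand-copied from the __Restaurants__ attributions.

```python
def merge_subwords(word_attributions):
    """Merge WordPiece continuation tokens ("##...") into the preceding token by summing scores."""
    merged = []
    for token, score in word_attributions:
        if token.startswith("##") and merged:
            prev_token, prev_score = merged[-1]
            merged[-1] = (prev_token + token[2:], prev_score + score)
        else:
            merged.append((token, score))
    return merged


# Hand-copied subset of the Restaurants attributions.
restaurants = [
    ("stocks", -0.6854857857533402),
    ("ended", -0.02903191299115817),
    ("a", 0.11090252153715577),
    ("chop", 0.00665206632166715),
    ("##py", -0.20624897442880935),
    ("session", -0.07065545152648157),
]

for word, score in merge_subwords(restaurants):
    print(f"{word:>8} {score:+.3f}")
```

Merged this way, "choppy" still comes out negative (about -0.20), because the "##py" piece outweighs the slightly positive "chop".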

In [54]:
html = multiclass_explainer.visualize()

| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|
| 52.0 | Investment Banking & Brokerage (0.00) | Restaurants | -0.93 | [CLS] stocks ended a chop ##py session mixed as investors digest ##ed a host of corporate earnings results and considered policy ##makers ’ next moves to support the still virus - stricken economy . the s & p 500 shook off earlier declines to narrowly ek ##e out a record closing high . the dow ended a tick below its recent record closing level . [SEP] |


In [55]:
word_attributions

[('[CLS]', 0.0),
 ('stocks', -0.6854857857533402),
 ('ended', -0.02903191299115817),
 ('a', 0.11090252153715577),
 ('chop', 0.00665206632166715),
 ('##py', -0.20624897442880935),
 ('session', -0.07065545152648157),
 ('mixed', -0.13288203112051317),
 ('as', 0.07492886056924308),
 ('investors', -0.12119252866962944),
 ('digest', -0.12445952999212648),
 ('##ed', -0.09281782081714353),
 ('a', 0.04926223728051275),
 ('host', 0.002458014965772025),
 ('of', -0.04181010382080314),
 ('corporate', 0.0959686863791755),
 ('earnings', 0.19780768353695843),
 ('results', 0.26498239130490375),
 ('and', 0.10765587313687698),
 ('considered', 0.08802611240041186),
 ('policy', 0.11527482485493165),
 ('##makers', -0.09797676987609753),
 ('’', -0.04952404541310052),
 ('next', -0.08798398385255239),
 ('moves', 0.016330765485513606),
 ('to', -0.013984971098826775),
 ('support', -0.07466647321209693),
 ('the', 0.01843903056263634),
 ('still', -0.07228695865498085),
 ('virus', 0.07444843187021678),
 ('-', -0.06