# Transformers Interpret Multiclass Classification Example

In [1]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer

## Import Industry Classification Model
This fine-tuned model by @sampathkethineedi uses a DistilBERT base to predict the professional industry a text refers to.

In [2]:
tokenizer = AutoTokenizer.from_pretrained("sampathkethineedi/industry-classification")
model = AutoModelForSequenceClassification.from_pretrained(
    "sampathkethineedi/industry-classification"
)

Let's explore the classes. There are 62 unique classes, many of which are overlapping or related industries, such as __Health Care Equipment__ and __Health Care Supplies__.

In [3]:
model.config.id2label

{0: 'Advertising',
 1: 'Aerospace & Defense',
 2: 'Apparel Retail',
 3: 'Apparel, Accessories & Luxury Goods',
 4: 'Application Software',
 5: 'Asset Management & Custody Banks',
 6: 'Auto Parts & Equipment',
 7: 'Biotechnology',
 8: 'Building Products',
 9: 'Casinos & Gaming',
 10: 'Commodity Chemicals',
 11: 'Communications Equipment',
 12: 'Construction & Engineering',
 13: 'Construction Machinery & Heavy Trucks',
 14: 'Consumer Finance',
 15: 'Data Processing & Outsourced Services',
 16: 'Diversified Metals & Mining',
 17: 'Diversified Support Services',
 18: 'Electric Utilities',
 19: 'Electrical Components & Equipment',
 20: 'Electronic Equipment & Instruments',
 21: 'Environmental & Facilities Services',
 22: 'Gold',
 23: 'Health Care Equipment',
 24: 'Health Care Facilities',
 25: 'Health Care Services',
 26: 'Health Care Supplies',
 27: 'Health Care Technology',
 28: 'Homebuilding',
 29: 'Hotels, Resorts & Cruise Lines',
 30: 'Human Resource & Employment Services',
 31: 'IT Co
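
One way to see which labels overlap is to group them by a shared keyword. The sketch below works on a hand-copied subset of the mapping printed above rather than the real `model.config.id2label`, and `labels_containing` is a hypothetical helper name, not part of the transformers package.

```python
# Subset of the id2label mapping printed above, copied by hand for illustration.
id2label_subset = {
    22: "Gold",
    23: "Health Care Equipment",
    24: "Health Care Facilities",
    25: "Health Care Services",
    26: "Health Care Supplies",
    27: "Health Care Technology",
    28: "Homebuilding",
}


def labels_containing(id2label, keyword):
    """Return (id, label) pairs whose label contains keyword, ignoring case."""
    needle = keyword.lower()
    return [
        (class_id, name)
        for class_id, name in sorted(id2label.items())
        if needle in name.lower()
    ]


health_care = labels_containing(id2label_subset, "health care")
for class_id, name in health_care:
    print(class_id, name)
```

With the loaded model, passing `model.config.id2label` in place of the subset should behave the same way.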

Import __SequenceClassificationExplainer__ from `transformers_interpret`. This class should work with most, if not all, language models from the transformers package that have a sequence classification head.


In [4]:
from transformers_interpret import SequenceClassificationExplainer

In [5]:
sample_text = """
Stocks ended a choppy session mixed as investors digested a host of corporate earnings results and considered policymakers’ next moves to support the still virus-stricken economy.
The S&P 500 shook off earlier declines to narrowly eke out a record closing high.The Dow ended a tick below its recent record closing level."""

In [6]:
multiclass_explainer = SequenceClassificationExplainer(text=sample_text, model=model, tokenizer=tokenizer)

In [7]:
# call the explainer
attributions = multiclass_explainer()

In [8]:
# seems to be an appropriate prediction
multiclass_explainer.predicted_class_name

'Investment Banking & Brokerage'

In [9]:
# look at the raw word attributions
attributions.word_attributions

[('[CLS]', 0.0),
 ('stocks', -0.004456745041476635),
 ('ended', 0.13193417174421274),
 ('a', 0.05641152917733269),
 ('chop', -0.012832735630589245),
 ('##py', -0.13861640177798698),
 ('session', 0.04677203070064154),
 ('mixed', 0.00456270865206808),
 ('as', -0.02196107395623291),
 ('investors', 0.3463117654678203),
 ('digest', -0.1500723809400845),
 ('##ed', -0.27232389332855483),
 ('a', 0.02145133233175586),
 ('host', 0.19881990197329233),
 ('of', 0.08638398072671191),
 ('corporate', 0.20302091157480334),
 ('earnings', 0.2154456398173302),
 ('results', 0.16101571665374576),
 ('and', 0.026943728367538014),
 ('considered', -0.07904327915567533),
 ('policy', 0.23334850964528223),
 ('##makers', 0.03744834172027172),
 ('’', -0.044502146745965795),
 ('next', 0.04380059324826015),
 ('moves', 0.15665502037005144),
 ('to', -0.0479011674293024),
 ('support', 0.07189598606497517),
 ('the', 0.021216427378130336),
 ('still', -0.23291363921919042),
 ('virus', 0.047506291065723814),
 ('-', -0.151059
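
The list above is per WordPiece token, so a word like "choppy" appears as `chop` and `##py`. A minimal sketch of merging those pieces back into words and ranking them follows. Summing the scores of the pieces is an assumption made here for illustration, not something the library is known to do, and both helper names are hypothetical.

```python
def merge_wordpieces(word_attributions):
    """Merge WordPiece continuation tokens ('##...') into the preceding token.

    Scores of merged pieces are summed (an illustrative aggregation choice).
    """
    merged = []
    for token, score in word_attributions:
        if token.startswith("##") and merged:
            prev_token, prev_score = merged[-1]
            merged[-1] = (prev_token + token[2:], prev_score + score)
        else:
            merged.append((token, score))
    return merged


def top_tokens(word_attributions, k=3):
    """Return the k (token, score) pairs with the largest scores."""
    return sorted(word_attributions, key=lambda pair: pair[1], reverse=True)[:k]


# First ten entries of the output above.
sample = [
    ("[CLS]", 0.0),
    ("stocks", -0.004456745041476635),
    ("ended", 0.13193417174421274),
    ("a", 0.05641152917733269),
    ("chop", -0.012832735630589245),
    ("##py", -0.13861640177798698),
    ("session", 0.04677203070064154),
    ("mixed", 0.00456270865206808),
    ("as", -0.02196107395623291),
    ("investors", 0.3463117654678203),
]

merged = merge_wordpieces(sample)
print(merged[4])                 # the merged "choppy" entry
print(top_tokens(merged, k=3))   # highest-scoring tokens in this slice
```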

## Visualizing Explanations
With a single call to the `visualize()` method we get an inline display showing which input tokens contributed most to this prediction. **Note: the algorithm used to calculate the attributions is Layer Integrated Gradients. To read more about it, see the [Captum documentation](https://captum.ai/docs/algorithms).**
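
The idea behind integrated gradients can be sketched on a toy model. The code below is plain integrated gradients on a hand-written sigmoid function with an all-zero baseline. It is not Captum's Layer Integrated Gradients, which applies the same path integral to the inputs or outputs of a chosen layer, such as an embedding layer. All names and values are illustrative.

```python
import math


def model_output(x, weights, bias):
    """Toy differentiable model: a sigmoid over a weighted sum."""
    z = sum(w * xi for w, xi in zip(weights, x)) + bias
    return 1.0 / (1.0 + math.exp(-z))


def model_gradient(x, weights, bias):
    """Analytic gradient of the toy model with respect to each input."""
    s = model_output(x, weights, bias)
    return [s * (1.0 - s) * w for w in weights]


def integrated_gradients(x, baseline, weights, bias, steps=200):
    """Approximate integrated gradients with a midpoint Riemann sum.

    Gradients are averaged along the straight path from baseline to x,
    then scaled by (x - baseline) per input dimension.
    """
    avg_grad = [0.0] * len(x)
    for step in range(steps):
        alpha = (step + 0.5) / steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        grad = model_gradient(point, weights, bias)
        avg_grad = [a + g / steps for a, g in zip(avg_grad, grad)]
    return [(xi - b) * g for xi, b, g in zip(x, baseline, avg_grad)]


weights = [0.8, -1.5, 0.3]
bias = 0.1
x = [1.0, 0.5, -2.0]
baseline = [0.0, 0.0, 0.0]

ig_attributions = integrated_gradients(x, baseline, weights, bias)
output_change = model_output(x, weights, bias) - model_output(baseline, weights, bias)

print(ig_attributions)
# Completeness check: the attributions should sum to roughly the output change.
print(sum(ig_attributions), output_change)
```

The final line illustrates the completeness property: the attributions add up, to within the numerical error of the approximation, to the difference between the model output at the input and at the baseline.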

In [10]:
multiclass_explainer.visualize()

| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
| --- | --- | --- | --- | --- |
| Investment Banking & Brokerage | Investment Banking & Brokerage (0.00) | Investment Banking & Brokerage | 1.3 | [CLS] stocks ended a chop ##py session mixed as investors digest ##ed a host of corporate earnings results and considered policy ##makers ’ next moves to support the still virus - stricken economy . the s & p 500 shook off earlier declines to narrowly ek ##e out a record closing high . the dow ended a tick below its recent record closing level . [SEP] |


## Explaining The Same Text For A Different Class 
Let's say we think this text could also fall under the class __Asset Management & Custody Banks__. It is also possible to get an explanation (attributions) for the text with respect to that class.

In [11]:
attributions = multiclass_explainer(class_name="Asset Management & Custody Banks")
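
A misspelled class name is easy to miss, so it can help to check the name against the label mapping before requesting attributions. The following is a small sketch: `resolve_class_id` is a hypothetical helper, and the mapping is a hand-copied subset of the `id2label` output shown earlier.

```python
def resolve_class_id(id2label, class_name):
    """Return the integer id for class_name, or raise ValueError if unknown."""
    label2id = {name: class_id for class_id, name in id2label.items()}
    if class_name not in label2id:
        raise ValueError(f"Unknown class name: {class_name!r}")
    return label2id[class_name]


# Subset of the mapping shown earlier, copied by hand for illustration.
id2label_subset = {
    4: "Application Software",
    5: "Asset Management & Custody Banks",
    6: "Auto Parts & Equipment",
}

print(resolve_class_id(id2label_subset, "Asset Management & Custody Banks"))  # prints 5
```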

In [12]:
# look at the raw word attributions
attributions.word_attributions

[('[CLS]', 0.0),
 ('stocks', -0.013076095817261638),
 ('ended', 0.21332437398379314),
 ('a', -0.0650173968095115),
 ('chop', -0.053776494369439945),
 ('##py', -0.1361883302504456),
 ('session', 0.05245961920595882),
 ('mixed', 0.11133688837284314),
 ('as', 0.0290092563922162),
 ('investors', 0.3508248779709538),
 ('digest', -0.031376772246840434),
 ('##ed', -0.08242556105918372),
 ('a', -0.1282496697255458),
 ('host', 0.1932059547552784),
 ('of', 0.14453142789631288),
 ('corporate', 0.174956812088488),
 ('earnings', 0.18457683227048555),
 ('results', 0.16502234537811186),
 ('and', 0.057246362109120805),
 ('considered', 0.044706561381177),
 ('policy', 0.419625329483318),
 ('##makers', 0.07724468902782176),
 ('’', 0.0641782420831589),
 ('next', 0.0863311384704364),
 ('moves', 0.13913528586360044),
 ('to', 0.06091511183016304),
 ('support', 0.12192537817088814),
 ('the', -0.09255245322690696),
 ('still', -0.07162789933654426),
 ('virus', 0.10373773961955873),
 ('-', -0.054282238520226525)

The results are close to those in the first visualization, a good sign that the model is generalizing well across these two related classes.
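
"Close" can be made a little more concrete by comparing the two attribution vectors directly. The sketch below computes cosine similarity over the first ten scores of each output. Cosine similarity is one possible measure chosen here for illustration, and ten tokens are too few to draw firm conclusions.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length score vectors."""
    if len(a) != len(b):
        raise ValueError("Vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("Cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


# First ten scores for the predicted class (Investment Banking & Brokerage).
predicted_scores = [
    0.0, -0.004456745041476635, 0.13193417174421274, 0.05641152917733269,
    -0.012832735630589245, -0.13861640177798698, 0.04677203070064154,
    0.00456270865206808, -0.02196107395623291, 0.3463117654678203,
]

# First ten scores for Asset Management & Custody Banks.
asset_mgmt_scores = [
    0.0, -0.013076095817261638, 0.21332437398379314, -0.0650173968095115,
    -0.053776494369439945, -0.1361883302504456, 0.05245961920595882,
    0.11133688837284314, 0.0290092563922162, 0.3508248779709538,
]

similarity = cosine_similarity(predicted_scores, asset_mgmt_scores)
print(similarity)
```

A value near 1 would indicate that the same tokens push both classes in the same direction, while a value near 0 would indicate little relationship.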

In [13]:
multiclass_explainer.visualize()

| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
| --- | --- | --- | --- | --- |
| Investment Banking & Brokerage | Investment Banking & Brokerage (0.00) | Asset Management & Custody Banks | 2.61 | [CLS] stocks ended a chop ##py session mixed as investors digest ##ed a host of corporate earnings results and considered policy ##makers ’ next moves to support the still virus - stricken economy . the s & p 500 shook off earlier declines to narrowly ek ##e out a record closing high . the dow ended a tick below its recent record closing level . [SEP] |


What if we get attributions for a class that makes no sense in this context, such as __Restaurants__?


In [14]:
attributions = multiclass_explainer(class_name="Restaurants")

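There isn't much to this prediction. It is worth noting, however, that the subword "chop" (from "choppy") had a more positive attribution in this instance (0.046, versus -0.013 for the predicted class), which seems plausible given the industry, although the "##py" piece was more negative (-0.252 versus -0.139).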
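
To compare how a single token scores across classes, the per-class values can be placed side by side. The sketch below uses the scores for the two pieces of "choppy" copied by hand from the three `word_attributions` outputs in this notebook. `class_with_highest` is a hypothetical helper name.

```python
# Scores for the WordPiece tokens of "choppy", copied from the outputs in this notebook.
choppy_pieces = {
    "Investment Banking & Brokerage": {
        "chop": -0.012832735630589245,
        "##py": -0.13861640177798698,
    },
    "Asset Management & Custody Banks": {
        "chop": -0.053776494369439945,
        "##py": -0.1361883302504456,
    },
    "Restaurants": {
        "chop": 0.04632777499253846,
        "##py": -0.25187041889583356,
    },
}


def class_with_highest(piece_scores, piece):
    """Return the class name whose attribution for the given piece is largest."""
    return max(piece_scores, key=lambda name: piece_scores[name][piece])


for class_name, scores in choppy_pieces.items():
    print(f"{class_name}: chop={scores['chop']:+.3f}, ##py={scores['##py']:+.3f}")

print(class_with_highest(choppy_pieces, "chop"))  # prints Restaurants
```

Only the `chop` piece is highest for Restaurants. The `##py` piece is lowest for that class, so the picture for the whole word is mixed.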

In [15]:
multiclass_explainer.visualize()

| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
| --- | --- | --- | --- | --- |
| Investment Banking & Brokerage | Investment Banking & Brokerage (0.00) | Restaurants | 1.02 | [CLS] stocks ended a chop ##py session mixed as investors digest ##ed a host of corporate earnings results and considered policy ##makers ’ next moves to support the still virus - stricken economy . the s & p 500 shook off earlier declines to narrowly ek ##e out a record closing high . the dow ended a tick below its recent record closing level . [SEP] |


In [16]:
attributions.word_attributions

[('[CLS]', 0.0),
 ('stocks', -0.09201294942361385),
 ('ended', 0.04407271448219919),
 ('a', 0.16321588179875865),
 ('chop', 0.04632777499253846),
 ('##py', -0.25187041889583356),
 ('session', -0.03191380380572038),
 ('mixed', -0.08486779783022193),
 ('as', 0.0990228369494118),
 ('investors', 0.15616562033570283),
 ('digest', -0.05320900579431406),
 ('##ed', -0.1943873312313048),
 ('a', 0.0701817057427743),
 ('host', 0.02375706910921956),
 ('of', 0.1987030179538432),
 ('corporate', 0.21816854083740192),
 ('earnings', 0.1034336318064541),
 ('results', 0.25163873610032916),
 ('and', -0.009519313354122897),
 ('considered', 0.14355278315043185),
 ('policy', 0.25695840883129323),
 ('##makers', -0.0067529899810760685),
 ('’', 0.042844477015168185),
 ('next', -0.09635841194957916),
 ('moves', 0.04030577027611355),
 ('to', -0.05089957413373526),
 ('support', -0.13898610509971007),
 ('the', 0.16586873963353052),
 ('still', -0.1971275206247004),
 ('virus', 0.048957715766651075),
 ('-', 0.01272739