# Transformers Interpret Multiclass Classification Example

In [17]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer

## Import Industry Classification Model
This fine-tuned model by @sampathkethineedi uses a DistilBERT base to predict the professional industry that a piece of text refers to.

In [18]:
tokenizer = AutoTokenizer.from_pretrained("sampathkethineedi/industry-classification")
model = AutoModelForSequenceClassification.from_pretrained(
    "sampathkethineedi/industry-classification"
)

Let's explore the classes: there are 62 unique classes. Many of them are overlapping or related industries, such as __Health Care Equipment__ and __Health Care Supplies__.

In [19]:
model.config.id2label

{0: 'Advertising',
 1: 'Aerospace & Defense',
 2: 'Apparel Retail',
 3: 'Apparel, Accessories & Luxury Goods',
 4: 'Application Software',
 5: 'Asset Management & Custody Banks',
 6: 'Auto Parts & Equipment',
 7: 'Biotechnology',
 8: 'Building Products',
 9: 'Casinos & Gaming',
 10: 'Commodity Chemicals',
 11: 'Communications Equipment',
 12: 'Construction & Engineering',
 13: 'Construction Machinery & Heavy Trucks',
 14: 'Consumer Finance',
 15: 'Data Processing & Outsourced Services',
 16: 'Diversified Metals & Mining',
 17: 'Diversified Support Services',
 18: 'Electric Utilities',
 19: 'Electrical Components & Equipment',
 20: 'Electronic Equipment & Instruments',
 21: 'Environmental & Facilities Services',
 22: 'Gold',
 23: 'Health Care Equipment',
 24: 'Health Care Facilities',
 25: 'Health Care Services',
 26: 'Health Care Supplies',
 27: 'Health Care Technology',
 28: 'Homebuilding',
 29: 'Hotels, Resorts & Cruise Lines',
 30: 'Human Resource & Employment Services',
 31: 'IT Co
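The overlap is easy to spot by grouping labels that share a leading word. This is a naive heuristic, and `group_by_leading_word` is a hypothetical helper (not part of transformers). It runs on labels 0 to 30 copied from the output above, because that output is truncated; in the notebook you would pass `model.config.id2label` instead.

```python
from collections import defaultdict

# Labels 0-30 copied from the (truncated) id2label output above.
# In the notebook, pass model.config.id2label instead.
ID2LABEL = dict(enumerate([
    "Advertising", "Aerospace & Defense", "Apparel Retail",
    "Apparel, Accessories & Luxury Goods", "Application Software",
    "Asset Management & Custody Banks", "Auto Parts & Equipment",
    "Biotechnology", "Building Products", "Casinos & Gaming",
    "Commodity Chemicals", "Communications Equipment",
    "Construction & Engineering", "Construction Machinery & Heavy Trucks",
    "Consumer Finance", "Data Processing & Outsourced Services",
    "Diversified Metals & Mining", "Diversified Support Services",
    "Electric Utilities", "Electrical Components & Equipment",
    "Electronic Equipment & Instruments",
    "Environmental & Facilities Services", "Gold",
    "Health Care Equipment", "Health Care Facilities",
    "Health Care Services", "Health Care Supplies",
    "Health Care Technology", "Homebuilding",
    "Hotels, Resorts & Cruise Lines",
    "Human Resource & Employment Services",
]))


def group_by_leading_word(id2label):
    """Group label names by their first word; keep only groups with 2+ labels."""
    groups = defaultdict(list)
    for idx in sorted(id2label):
        name = id2label[idx]
        first_word = name.split()[0].rstrip(",")  # "Apparel," -> "Apparel"
        groups[first_word].append(name)
    return {word: names for word, names in groups.items() if len(names) > 1}


for word, names in group_by_leading_word(ID2LABEL).items():
    print(f"{word}: {names}")
```

The heuristic is crude: it also pairs the two __Diversified__ labels, which are not closely related, so treat its output as a starting point for eyeballing the label set rather than a definitive list of overlaps.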

Import __SequenceClassificationExplainer__ from Transformers Interpret. This class should work with most, if not all, models from the transformers package that have a sequence classification head.


In [20]:
from transformers_interpret import SequenceClassificationExplainer

In [21]:
sample_text = """
Stocks ended a choppy session mixed as investors digested a host of corporate earnings results and considered policymakers’ next moves to support the still virus-stricken economy.
The S&P 500 shook off earlier declines to narrowly eke out a record closing high. The Dow ended a tick below its recent record closing level."""
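The model's tokenizer splits words that are not in its vocabulary into WordPiece subwords, which is why tokens such as `chop` / `##py`, `digest` / `##ed` and `ek` / `##e` appear in the attributions below. A minimal sketch of the greedy longest-match-first split, using a tiny hypothetical vocabulary chosen to reproduce those splits. The real tokenizer also lowercases (the lowercase tokens below suggest an uncased vocabulary) and splits off punctuation first, and its vocabulary is far larger.

```python
def wordpiece_tokenize(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first WordPiece split of a single lowercase word."""
    tokens = []
    start = 0
    while start < len(word):
        end = len(word)
        piece = None
        # Try the longest remaining substring first, shrinking until a vocab hit
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a ## prefix
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk]  # no valid split: the whole word becomes [UNK]
        tokens.append(piece)
        start = end
    return tokens


# Hypothetical toy vocabulary; the real one is much larger
TOY_VOCAB = {"chop", "##py", "digest", "##ed", "ek", "##e", "stocks"}

for w in ["Choppy", "digested", "eke", "stocks"]:
    print(w, "->", wordpiece_tokenize(w.lower(), TOY_VOCAB))
```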

In [22]:
multiclass_explainer = SequenceClassificationExplainer(text=sample_text, model=model, tokenizer=tokenizer)

In [23]:
# call the explainer to compute attributions for the predicted class
attributions = multiclass_explainer()

In [24]:
# seems to be an appropriate prediction
multiclass_explainer.predicted_class_name

'Investment Banking & Brokerage'

In [25]:
# look at the raw word attributions
attributions.word_attributions

[('[CLS]', 0.0),
 ('stocks', 0.12852050555427189),
 ('ended', 0.0678023663980611),
 ('a', 0.010843463306493134),
 ('chop', -0.1741370883504282),
 ('##py', -0.2869499284798822),
 ('session', 0.02951569420018835),
 ('mixed', -0.05576730134017583),
 ('as', -0.09759507221934578),
 ('investors', 0.2785585664351567),
 ('digest', -0.21567059524531962),
 ('##ed', -0.3518323360336279),
 ('a', -0.08061975667352177),
 ('host', 0.06370071734102233),
 ('of', -0.07011347958188589),
 ('corporate', 0.08471542890790655),
 ('earnings', 0.14865712166334163),
 ('results', 0.09851925006407185),
 ('and', -0.05313666655795184),
 ('considered', -0.16200717034085704),
 ('policy', 0.18906722719996574),
 ('##makers', 0.07233153257954199),
 ('’', -0.08503453329118392),
 ('next', -0.030800841145391163),
 ('moves', 0.09755593103322785),
 ('to', -0.1086500149688071),
 ('support', 0.005649268164363375),
 ('the', 0.003786696146925574),
 ('still', -0.20186011585753563),
 ('virus', 0.01782458344696517),
 ('-', -0.171144
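Because of subword tokenization, a single word can be spread over several tokens (`chop` + `##py`). One option, sketched below with a hypothetical helper `merge_subwords`, is to sum each word's subword attributions and then rank whole words. Summing is our assumption (integrated-gradients attributions are additive, so a sum keeps its meaning); averaging or taking the maximum are common alternatives. Only the first twelve pairs from the output above are used.

```python
# First twelve (token, attribution) pairs copied from the output above
WORD_ATTRIBUTIONS = [
    ("[CLS]", 0.0),
    ("stocks", 0.12852050555427189),
    ("ended", 0.0678023663980611),
    ("a", 0.010843463306493134),
    ("chop", -0.1741370883504282),
    ("##py", -0.2869499284798822),
    ("session", 0.02951569420018835),
    ("mixed", -0.05576730134017583),
    ("as", -0.09759507221934578),
    ("investors", 0.2785585664351567),
    ("digest", -0.21567059524531962),
    ("##ed", -0.3518323360336279),
]


def merge_subwords(word_attributions, special=("[CLS]", "[SEP]")):
    """Join '##' continuation tokens onto the previous token, summing scores."""
    merged = []
    for token, score in word_attributions:
        if token in special:
            continue  # special tokens carry no word content
        if token.startswith("##") and merged:
            prev_word, prev_score = merged[-1]
            merged[-1] = (prev_word + token[2:], prev_score + score)
        else:
            merged.append((token, score))
    return merged


words = merge_subwords(WORD_ATTRIBUTIONS)
ranked = sorted(words, key=lambda pair: pair[1])  # most negative first
print("most negative:", ranked[:2])
print("most positive:", ranked[-2:])
```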

## Visualizing Explanations
With a single call to the `visualize()` method we get an inline display of which input tokens pushed the model towards (positive attribution) or away from (negative attribution) this prediction. **Note: the attributions are calculated with Layer Integrated Gradients; to read more about this algorithm, see the [Captum documentation](https://captum.ai/docs/algorithms).**
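For intuition, here is a minimal NumPy sketch of integrated gradients on a toy model: each input's attribution is (input - baseline) times the average gradient along the straight path from the baseline to the input, and the attributions sum approximately to the change in the target-class output. The mean-pooling model, random weights, zero baseline, midpoint Riemann sum and the choice to attribute the probability (rather than the logit) are all assumptions for illustration. The layer variant (Captum's `LayerIntegratedGradients`) applies the same idea to the output of a chosen layer, such as the embedding layer, rather than to raw inputs; per-token scores then come from summing over the embedding dimension, as below.

```python
import numpy as np


def softmax(z):
    z = z - z.max()  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()


def prob_and_grad(emb, W, target):
    """Target-class probability of a toy mean-pool + linear + softmax model,
    and its gradient with respect to the token embeddings `emb` (T x D)."""
    T = emb.shape[0]
    pooled = emb.mean(axis=0)
    p = softmax(W @ pooled)
    onehot = np.zeros_like(p)
    onehot[target] = 1.0
    # d p_t / d logits = p_t * (onehot - p); chain rule through logits = W @ pooled
    grad_pooled = W.T @ (p[target] * (onehot - p))
    # Every token contributes 1/T of the mean, so each row gets grad_pooled / T
    grad_emb = np.tile(grad_pooled / T, (T, 1))
    return p[target], grad_emb


def integrated_gradients(emb, baseline, W, target, steps=50):
    """Per-token IG scores, using a midpoint Riemann sum along the path."""
    alphas = (np.arange(steps) + 0.5) / steps
    total = np.zeros_like(emb)
    for a in alphas:
        _, g = prob_and_grad(baseline + a * (emb - baseline), W, target)
        total += g
    attributions = (emb - baseline) * total / steps  # T x D
    return attributions.sum(axis=1)  # one score per token


rng = np.random.default_rng(0)
emb = rng.normal(size=(5, 8))  # 5 tokens, 8-dim embeddings (toy values)
baseline = np.zeros_like(emb)  # stand-in for reference embeddings
W = rng.normal(size=(3, 8))    # 3 classes
scores = integrated_gradients(emb, baseline, W, target=1)

f_x, _ = prob_and_grad(emb, W, 1)
f_b, _ = prob_and_grad(baseline, W, 1)
print("per-token scores:", np.round(scores, 4))
print("sum of scores:", scores.sum(), " f(x) - f(baseline):", f_x - f_b)
```

The last line checks the "completeness" property: the per-token scores should add up to the difference in the target-class probability between the input and the baseline, up to the integration error.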

In [26]:
html = multiclass_explainer.visualize()

| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|
| Investment Banking & Brokerage | Investment Banking & Brokerage (0.78) | Investment Banking & Brokerage | -1.89 | [CLS] stocks ended a chop ##py session mixed as investors digest ##ed a host of corporate earnings results and considered policy ##makers ’ next moves to support the still virus - stricken economy . the s & p 500 shook off earlier declines to narrowly ek ##e out a record closing high . the dow ended a tick below its recent record closing level . [SEP] |


## Explaining The Same Text For A Different Class 
Let's say we think this text could also fall, to some extent, under the class __Asset Management & Custody Banks__. If we want, we can also get an explanation (the attributions) for the text with respect to that class.
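Explaining a different class changes only which model output the gradients are taken with respect to; the input text stays the same. Resolving a class name to that output index amounts to inverting `id2label`. A minimal sketch with a hypothetical helper `class_index`, assuming names must match exactly; we have not checked how the explainer handles this internally. Hugging Face model configs also expose the inverted mapping directly as `model.config.label2id`.

```python
def class_index(id2label, class_name):
    """Return the output index for `class_name` (hypothetical helper)."""
    label2id = {name: idx for idx, name in id2label.items()}
    try:
        return label2id[class_name]
    except KeyError:
        raise ValueError(
            f"unknown class {class_name!r}; expected one of {sorted(label2id)}"
        ) from None


# Excerpt of model.config.id2label, copied from the output above
ID2LABEL = {5: "Asset Management & Custody Banks", 9: "Casinos & Gaming"}
print(class_index(ID2LABEL, "Asset Management & Custody Banks"))
```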

In [27]:
attributions = multiclass_explainer(class_name="Asset Management & Custody Banks")

In [28]:
# look at the raw word attributions
attributions.word_attributions

[('[CLS]', 0.0),
 ('stocks', 0.1443226499964759),
 ('ended', 0.16898651064297826),
 ('a', -0.12579349788252478),
 ('chop', -0.17038107468722083),
 ('##py', -0.3670204232006531),
 ('session', 0.024688000298861642),
 ('mixed', 0.061788930164107796),
 ('as', -0.016986853634390717),
 ('investors', 0.3335520518628928),
 ('digest', -0.05479571667862658),
 ('##ed', -0.13245183450109696),
 ('a', -0.13574195708985184),
 ('host', 0.09813786234733723),
 ('of', 0.06590878508605036),
 ('corporate', 0.14073566363485016),
 ('earnings', 0.18212774733566625),
 ('results', 0.14253270209079974),
 ('and', -0.026995538860852136),
 ('considered', -0.00018286749771121526),
 ('policy', 0.3805815892016278),
 ('##makers', 0.0876919664585931),
 ('’', 0.04828180253509558),
 ('next', 0.031988018340632904),
 ('moves', 0.09709131699547752),
 ('to', 0.021208600691770996),
 ('support', 0.0759541095757481),
 ('the', -0.04476977634427086),
 ('still', -0.046166848338842194),
 ('virus', 0.12804904585084104),
 ('-', -0.073

The attributions are close to those for the predicted class, a good sign that the model is generalizing well for both of these related classes.

In [29]:
html = multiclass_explainer.visualize()

| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|
| Investment Banking & Brokerage | Investment Banking & Brokerage (0.19) | Asset Management & Custody Banks | 0.95 | [CLS] stocks ended a chop ##py session mixed as investors digest ##ed a host of corporate earnings results and considered policy ##makers ’ next moves to support the still virus - stricken economy . the s & p 500 shook off earlier declines to narrowly ek ##e out a record closing high . the dow ended a tick below its recent record closing level . [SEP] |
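"Close" can be made a little more concrete by comparing the two attribution vectors token by token. A sketch using only the first 20 tokens after `[CLS]` (the printed outputs are truncated), with values copied from the two `word_attributions` outputs and rounded to four decimals. Cosine similarity is our choice of metric, not something the library reports.

```python
import numpy as np

# First 20 tokens after [CLS] from the two word_attributions outputs above,
# rounded to four decimals.
TOKENS = ["stocks", "ended", "a", "chop", "##py", "session", "mixed", "as",
          "investors", "digest", "##ed", "a", "host", "of", "corporate",
          "earnings", "results", "and", "considered", "policy"]
INVESTMENT_BANKING = np.array([
    0.1285, 0.0678, 0.0108, -0.1741, -0.2869, 0.0295, -0.0558, -0.0976,
    0.2786, -0.2157, -0.3518, -0.0806, 0.0637, -0.0701, 0.0847,
    0.1487, 0.0985, -0.0531, -0.1620, 0.1891,
])
ASSET_MANAGEMENT = np.array([
    0.1443, 0.1690, -0.1258, -0.1704, -0.3670, 0.0247, 0.0618, -0.0170,
    0.3336, -0.0548, -0.1325, -0.1357, 0.0981, 0.0659, 0.1407,
    0.1821, 0.1425, -0.0270, -0.0002, 0.3806,
])


def cosine_similarity(u, v):
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))


sim = cosine_similarity(INVESTMENT_BANKING, ASSET_MANAGEMENT)
print(f"cosine similarity: {sim:.2f}")

# Tokens whose attribution shifts the most between the two target classes
diff = ASSET_MANAGEMENT - INVESTMENT_BANKING
for i in np.argsort(-np.abs(diff))[:3]:
    print(f"{TOKENS[i]:>10}: {INVESTMENT_BANKING[i]:+.4f} -> {ASSET_MANAGEMENT[i]:+.4f}")
```

A similarity near 1 would mean the model leans on largely the same tokens, with the same signs, for both classes; the largest per-token shifts show where the two explanations diverge.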


What if we get attributions for a class that makes no sense in this context, such as __Restaurants__?


In [30]:
attributions = multiclass_explainer(class_name="Restaurants")

There isn't much to this prediction. It is worth noting, however, that the subword "chop" (from "choppy") received a less negative attribution in this instance (about -0.11, versus about -0.17 for the predicted class), which seems plausible given the industry.

In [31]:
html = multiclass_explainer.visualize()

| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|
| Investment Banking & Brokerage | Investment Banking & Brokerage (0.00) | Restaurants | 0.91 | [CLS] stocks ended a chop ##py session mixed as investors digest ##ed a host of corporate earnings results and considered policy ##makers ’ next moves to support the still virus - stricken economy . the s & p 500 shook off earlier declines to narrowly ek ##e out a record closing high . the dow ended a tick below its recent record closing level . [SEP] |


In [32]:
attributions.word_attributions

[('[CLS]', 0.0),
 ('stocks', 0.047411473930637625),
 ('ended', 0.052023908212124034),
 ('a', 0.0752584762603034),
 ('chop', -0.10848856622080573),
 ('##py', -0.3458641358911959),
 ('session', -0.02992446427964216),
 ('mixed', -0.0775741971712423),
 ('as', 0.07272603452264179),
 ('investors', 0.1587678979745181),
 ('digest', -0.02524284557609148),
 ('##ed', -0.18750726770081383),
 ('a', 0.10784273295912807),
 ('host', 0.033808812046453585),
 ('of', 0.11717727435888481),
 ('corporate', 0.13452790063497175),
 ('earnings', 0.12521389593281765),
 ('results', 0.2907672513381314),
 ('and', 0.01861359443694477),
 ('considered', 0.13164531355744005),
 ('policy', 0.22353304274582514),
 ('##makers', 0.007924292060352996),
 ('’', 0.05086628826172883),
 ('next', -0.10541361590943575),
 ('moves', 0.002579234192110453),
 ('to', -0.06626001023997512),
 ('support', -0.1715419728384023),
 ('the', 0.12186656665947443),
 ('still', -0.19106584788842457),
 ('virus', 0.02286550522273487),
 ('-', -0.029420968