https://github.com/robinvanschaik/interpret-flair

In [30]:
from flair.models import TextClassifier
from flair.data import Sentence
import torch
import torch.nn as nn
from transformers import AutoTokenizer
from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization
from interpretation_package.flair_model_wrapper import ModelWrapper
from interpretation_package.interpret_flair import interpret_sentence, visualize_attributions
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [129]:
# redefine the function from interpret_flair, in order to control whether to norm the word_attributions before sum
# also prevent the attributions becomes Nan if the norm is 0.
def summarize_attributions(attributions, need_norm=True):
    """
    Helper function for calculating word attributions.
    Inputs:
    attributions_ig: integrated gradients attributions.
    Ouputs:
    word_attributions: the attributions score per token.
    attribution_score: the attribution score of the entire document w.r.t. ground label.
    """
    attributions = attributions.sum(dim=-1).squeeze(0)
    if need_norm==True:
        if torch.norm(attributions) != 0:
            attributions = attributions / torch.norm(attributions)
    attribution_score = attributions.sum()

    return attributions, attribution_score

In [132]:
# redefine the function from interpret_flair, take control of whether to add special tokens, 
#                      whether to add softmax layer in the attribution calculation or after
#                       whether to use all zero baseline or base_line filled with [PAD] token
def interpret_sentence2(flair_model_wrapper, lig, sentence, target_label, visualization_list, n_steps=100, estimation_method="gausslegendre",internal_batch_size=None, add_special_tokens=True, need_softmax = True, need_norm=True, pad_base_line=True):
    """
    We can visualise the attributions made by making use of Pytorch Captum.
    Inputs:
    flair_model_wrapper: class containing a customized forward function of Flair model.
    lig: the layer integrated gradient object.
    sentence: the Flair sentence-object we want to interpret.
    target_label: the ground truth class-label of the sentence.
    visualization_list: a list to store the visualization records in.
    """

    # Return the target index from the label dictionary.
    target_index = flair_model_wrapper.label_dictionary.get_idx_for_item(target_label)
#     target_index = target_index.to(device)
    
    # In order maintain consistency with Flair, we apply the same tokenization
    # steps.
    flair_sentence = Sentence(sentence)

    tokenized_sentence = flair_sentence.to_tokenized_string()

    tokenizer_max_length = flair_model_wrapper.tokenizer.model_max_length

    # This calculates the token input IDs tensor for the model.
    input_ids = flair_model_wrapper.tokenizer.encode(tokenized_sentence,
                                                     add_special_tokens=add_special_tokens,
                                                     max_length=tokenizer_max_length,
                                                     truncation=True,
                                                     return_tensors="pt")
    input_ids = input_ids.to(device)
    
    # Create a baseline by creating a tensor of equal length
    # containing the padding token tensor id.
    pad_token_id = flair_model_wrapper.tokenizer.pad_token_id

    ref_base_line = torch.full_like(input_ids, pad_token_id)

    # Convert back to tokens as the model requires.
    # As some words might get split up. e.g. Caroll to Carol l.
    all_tokens = flair_model_wrapper.tokenizer.convert_ids_to_tokens(input_ids[0])

    # The tokenizer in the model adds a special character
    # in front of every sentence.
    readable_tokens = [token.replace("▁", "") for token in all_tokens]

    # The input IDs are passed to the embedding layer of the model.
    # It is better to return the logits for Captum.
    # https://github.com/pytorch/captum/issues/355#issuecomment-619610044
    # Thus we calculate the softmax afterwards.
    # For now, I take the first dimension and run this sentence, per sentence.
    model_outputs = flair_model_wrapper(input_ids)
    if need_softmax == True:
        softmax = torch.nn.functional.softmax(model_outputs[0], dim=0)
        # Return the confidence and the class ID of the top predicted class.
        conf, idx = torch.max(softmax, 0)
    else:
        conf, idx = torch.max(model_outputs[0], 0)
    # Returns the probability.
    prediction_confidence = conf.item()

    # Returns the label name from the top prediction class.
    pred_label = flair_model_wrapper.label_dictionary.get_item_for_index(idx.item())
    if pad_base_line == True:
    # Calculate the attributions according to the LayerIntegratedGradients method.
        attributions_ig, delta = lig.attribute(input_ids,
                                               baselines=ref_base_line,
                                               n_steps=n_steps,
                                               return_convergence_delta=True,
                                               target=target_index,
                                               method=estimation_method,
                                               internal_batch_size=internal_batch_size)
    else:
        attributions_ig, delta = lig.attribute(input_ids,
                                               n_steps=n_steps,
                                               return_convergence_delta=True,
                                               target=target_index,
                                               method=estimation_method,
                                               internal_batch_size=internal_batch_size)
    convergence_delta = abs(delta)
    print('pred: ', idx.item(), '(', '%.2f' % conf.item(), ')', ', delta: ', convergence_delta)


    word_attributions, attribution_score = summarize_attributions(attributions_ig, need_norm=need_norm)


    visualization_list.append(
    visualization.VisualizationDataRecord(word_attributions=word_attributions,
                                pred_prob=prediction_confidence,
                                pred_class=pred_label,
                                true_class=target_label,
                                attr_class=target_label,
                                attr_score=attribution_score,
                                raw_input=readable_tokens,
                                convergence_score=delta)
                    )

    # Return these for the sanity checks.
    return readable_tokens, word_attributions, convergence_delta

In [37]:
classifier = TextClassifier.load('sentiment')

2021-01-14 23:54:47,901 loading file /home/joey/.flair/models/sentiment-en-mix-distillbert_3.1.pt


In [39]:
pad_token_id = flair_model_wrapper.tokenizer.pad_token_id

In [80]:
sentences = ["It's a great day.", 
             "It's absolutely not a great day!",
             "It's ABSOLUTELY not a great day!!!",
             "Today sucks",
             "I had a car accident today.",
             "I'm sorry to hear this happened.",
             "I can hear you clearly.",
             "He happened to be the president.",
            ]

# Interpret-flair function wrapper
## without special tokens (default)

In [64]:
flair_model_wrapper = ModelWrapper(classifier)
lig = LayerIntegratedGradients(flair_model_wrapper, flair_model_wrapper.model.embeddings)

In [65]:
target_label = flair_model_wrapper.label_dictionary.get_item_for_index(1)
print(target_label)

POSITIVE


In [127]:
# classifier output on the sentences
for sentence in sentences:
    s = Sentence(sentence)
    classifier.predict(s)
    print(s.labels[0])


POSITIVE (0.9938)
NEGATIVE (0.9997)
NEGATIVE (0.9994)
NEGATIVE (0.9998)
NEGATIVE (0.8022)
NEGATIVE (0.9993)
POSITIVE (0.9968)
POSITIVE (0.7145)


In [81]:
visualization_list = []
target_label = 'POSITIVE'
for sentence in sentences:
    interpret_sentence(flair_model_wrapper,
                        lig,
                        sentence,
                        target_label,
                        visualization_list,
                        n_steps=500,
                        estimation_method="gausslegendre",
                        internal_batch_size=3)

pred:  1 ( 0.99 ) , delta:  tensor([0.6045], device='cuda:0', dtype=torch.float64)
pred:  0 ( 1.00 ) , delta:  tensor([3.3334], device='cuda:0', dtype=torch.float64)
pred:  0 ( 1.00 ) , delta:  tensor([2.7723], device='cuda:0', dtype=torch.float64)
pred:  0 ( 0.69 ) , delta:  tensor([1.0796], device='cuda:0', dtype=torch.float64)
pred:  1 ( 0.95 ) , delta:  tensor([0.1376], device='cuda:0', dtype=torch.float64)
pred:  0 ( 0.97 ) , delta:  tensor([1.4901], device='cuda:0', dtype=torch.float64)
pred:  1 ( 0.83 ) , delta:  tensor([0.6609], device='cuda:0', dtype=torch.float64)
pred:  1 ( 0.51 ) , delta:  tensor([0.7495], device='cuda:0', dtype=torch.float64)


In [82]:
visualize_attributions(visualization_list)

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
POSITIVE,POSITIVE (0.99),POSITIVE,0.24,it ' s a great day .
,,,,
POSITIVE,NEGATIVE (1.00),POSITIVE,-0.89,it ' s absolutely not a great day !
,,,,
POSITIVE,NEGATIVE (1.00),POSITIVE,-1.12,it ' s absolutely not a great day ! ! !
,,,,
POSITIVE,NEGATIVE (0.69),POSITIVE,-0.74,today sucks
,,,,
POSITIVE,POSITIVE (0.95),POSITIVE,-0.06,i had a car accident today .
,,,,


## with special tokens

In [99]:
for sentence in sentences:
    s = Sentence(sentence)
    classifier.predict(s)
    print(s)

Sentence: "It 's a great day ."   [− Tokens: 6  − Sentence-Labels: {'label': [POSITIVE (0.9938)]}]
Sentence: "It 's absolutely not a great day !"   [− Tokens: 8  − Sentence-Labels: {'label': [NEGATIVE (0.9997)]}]
Sentence: "It 's ABSOLUTELY not a great day !! !"   [− Tokens: 9  − Sentence-Labels: {'label': [NEGATIVE (0.9994)]}]
Sentence: "Today sucks"   [− Tokens: 2  − Sentence-Labels: {'label': [NEGATIVE (0.9998)]}]
Sentence: "I had a car accident today ."   [− Tokens: 7  − Sentence-Labels: {'label': [NEGATIVE (0.8022)]}]
Sentence: "I 'm sorry to hear this happened ."   [− Tokens: 8  − Sentence-Labels: {'label': [NEGATIVE (0.9993)]}]
Sentence: "I can hear you clearly ."   [− Tokens: 6  − Sentence-Labels: {'label': [POSITIVE (0.9968)]}]
Sentence: "He happened to be the president ."   [− Tokens: 7  − Sentence-Labels: {'label': [POSITIVE (0.7145)]}]


In [111]:
visualization_list = []
target_label = 'POSITIVE'
for sentence in sentences:
    interpret_sentence2(flair_model_wrapper,
                        lig,
                        sentence,
                        target_label,
                        visualization_list,
                        n_steps=500,
                        estimation_method="gausslegendre",
                        add_special_tokens=True, 
                        need_softmax = True,
                        internal_batch_size=3)

pred:  1 ( 0.99 ) , delta:  tensor([0.7097], device='cuda:0', dtype=torch.float64)
pred:  0 ( 1.00 ) , delta:  tensor([3.1532], device='cuda:0', dtype=torch.float64)
pred:  0 ( 1.00 ) , delta:  tensor([3.0233], device='cuda:0', dtype=torch.float64)
pred:  0 ( 1.00 ) , delta:  tensor([3.9005], device='cuda:0', dtype=torch.float64)
pred:  0 ( 0.80 ) , delta:  tensor([1.1394], device='cuda:0', dtype=torch.float64)
pred:  0 ( 1.00 ) , delta:  tensor([2.6181], device='cuda:0', dtype=torch.float64)
pred:  1 ( 1.00 ) , delta:  tensor([1.1234], device='cuda:0', dtype=torch.float64)
pred:  1 ( 0.71 ) , delta:  tensor([0.6667], device='cuda:0', dtype=torch.float64)


In [112]:
visualize_attributions(visualization_list)

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
POSITIVE,POSITIVE (0.99),POSITIVE,0.77,[CLS] it ' s a great day . [SEP]
,,,,
POSITIVE,NEGATIVE (1.00),POSITIVE,-1.7,[CLS] it ' s absolutely not a great day ! [SEP]
,,,,
POSITIVE,NEGATIVE (1.00),POSITIVE,-1.47,[CLS] it ' s absolutely not a great day ! ! ! [SEP]
,,,,
POSITIVE,NEGATIVE (1.00),POSITIVE,-0.71,[CLS] today sucks [SEP]
,,,,
POSITIVE,NEGATIVE (0.80),POSITIVE,-0.77,[CLS] i had a car accident today . [SEP]
,,,,


# Self defined wrapper

In [105]:
# redefine the ModelWrapper, adding the softmax inside
class ModelWrapper2(nn.Module):

    def __init__(self, flair_model, layers: str = "-1"):
        super(ModelWrapper2, self).__init__()

        # Pass the flair
        self.flair_model = flair_model
        
        # Shorthand for the actual PyTorch model.
        self.model = flair_model.document_embeddings.model

        if torch.cuda.is_available():
            self.device = torch.device("cuda:0")
        else:
            self.device = torch.device("cpu")
            
        self.model.eval()
        self.model.zero_grad()

        # Split the name to automatically grab the right tokenizer.
        self.model_name = flair_model.document_embeddings.get_names()[0].split('transformer-document-')[-1]
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        self.label_dictionary = self.flair_model.label_dictionary
        self.num_classes = len(self.flair_model.label_dictionary)
        self.embedding_length = self.flair_model.document_embeddings.embedding_length

        self.initial_cls_token = flair_model.document_embeddings.initial_cls_token

        if layers == 'all':
            # send mini-token through to check how many layers the model has
            hidden_states = self.model(torch.tensor([1], device=device).unsqueeze(0))[-1]
            self.layer_indexes = [int(x) for x in range(len(hidden_states))]
        else:
            self.layer_indexes = [int(x) for x in layers.split(",")]
        
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input_ids):
        # Run the input embeddings through all the layers.
        # Return the hidden states of the model.
        hidden_states = self.model(input_ids=input_ids)[-1]

        # BERT has an initial CLS token.
        # Meaning that the the first token contains the classification.
        # Other models have this as the top layer.
        index_of_CLS_token = 0 if self.initial_cls_token else input_ids.shape[1] -1

        # For batching we need to replace
        # [layer][0][index_of_CLS_token]
        # with [layer][i][index_of_CLS_token].
        cls_embeddings_all_layers = \
            [hidden_states[layer][0][index_of_CLS_token] for layer in self.layer_indexes]

        output_embeddings = torch.cat(cls_embeddings_all_layers)

        # https://github.com/pytorch/captum/issues/355#issuecomment-619610044
        # It's better to attribute the logits to the inputs.
        label_scores = self.flair_model.decoder(output_embeddings)

        # Captum expects [#examples, #classes] as size.
        # We do to this so we can specify the target class with multiclass
        # models.
        label_scores_resized = torch.reshape(label_scores, (1, self.num_classes))
        
        label_scores_softmaxed = self.softmax(label_scores_resized)
        return label_scores_softmaxed

In [106]:
flair_model_wrapper2 = ModelWrapper2(classifier)
lig2 = LayerIntegratedGradients(flair_model_wrapper2, flair_model_wrapper2.model.embeddings)

## without special tokens
The confidence level are the same with default wrapper without special token.

This one may not make sense

In [89]:
for sentence in sentences:
    s = Sentence(sentence)
    classifier.predict(s)
    print(s)

Sentence: "It 's a great day ."   [− Tokens: 6  − Sentence-Labels: {'label': [POSITIVE (0.9938)]}]
Sentence: "It 's absolutely not a great day !"   [− Tokens: 8  − Sentence-Labels: {'label': [NEGATIVE (0.9997)]}]
Sentence: "It 's ABSOLUTELY not a great day !! !"   [− Tokens: 9  − Sentence-Labels: {'label': [NEGATIVE (0.9994)]}]
Sentence: "Today sucks"   [− Tokens: 2  − Sentence-Labels: {'label': [NEGATIVE (0.9998)]}]
Sentence: "I had a car accident today ."   [− Tokens: 7  − Sentence-Labels: {'label': [NEGATIVE (0.8022)]}]
Sentence: "I 'm sorry to hear this happened ."   [− Tokens: 8  − Sentence-Labels: {'label': [NEGATIVE (0.9993)]}]
Sentence: "I can hear you clearly ."   [− Tokens: 6  − Sentence-Labels: {'label': [POSITIVE (0.9968)]}]
Sentence: "He happened to be the president ."   [− Tokens: 7  − Sentence-Labels: {'label': [POSITIVE (0.7145)]}]


In [113]:
visualization_list = []
target_label = 'POSITIVE'
for sentence in sentences:
    interpret_sentence2(flair_model_wrapper2,
                        lig2,
                        sentence,
                        target_label,
                        visualization_list,
                        n_steps=500,
                        estimation_method="gausslegendre",
                        add_special_tokens=False, 
                        need_softmax = False,
                        internal_batch_size=3)

pred:  1 ( 0.99 ) , delta:  tensor([0.0137], device='cuda:0', dtype=torch.float64)
pred:  0 ( 1.00 ) , delta:  tensor([0.6364], device='cuda:0', dtype=torch.float64)
pred:  0 ( 1.00 ) , delta:  tensor([0.6351], device='cuda:0', dtype=torch.float64)
pred:  0 ( 0.69 ) , delta:  tensor([0.4448], device='cuda:0', dtype=torch.float64)
pred:  1 ( 0.95 ) , delta:  tensor([0.0138], device='cuda:0', dtype=torch.float64)
pred:  0 ( 0.97 ) , delta:  tensor([0.6173], device='cuda:0', dtype=torch.float64)
pred:  1 ( 0.83 ) , delta:  tensor([0.0899], device='cuda:0', dtype=torch.float64)
pred:  1 ( 0.51 ) , delta:  tensor([0.3035], device='cuda:0', dtype=torch.float64)


In [114]:
visualize_attributions(visualization_list)

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
POSITIVE,POSITIVE (0.99),POSITIVE,0.04,it ' s a great day .
,,,,
POSITIVE,NEGATIVE (1.00),POSITIVE,-1.84,it ' s absolutely not a great day !
,,,,
POSITIVE,NEGATIVE (1.00),POSITIVE,-2.31,it ' s absolutely not a great day ! ! !
,,,,
POSITIVE,NEGATIVE (0.69),POSITIVE,-0.75,today sucks
,,,,
POSITIVE,POSITIVE (0.95),POSITIVE,-0.02,i had a car accident today .
,,,,


## with special tokens
Comparing to default wrapper with special tokens. The attribution score improve a little, but still not in range(-1,1)

In [89]:
for sentence in sentences:
    s = Sentence(sentence)
    classifier.predict(s)
    print(s)

Sentence: "It 's a great day ."   [− Tokens: 6  − Sentence-Labels: {'label': [POSITIVE (0.9938)]}]
Sentence: "It 's absolutely not a great day !"   [− Tokens: 8  − Sentence-Labels: {'label': [NEGATIVE (0.9997)]}]
Sentence: "It 's ABSOLUTELY not a great day !! !"   [− Tokens: 9  − Sentence-Labels: {'label': [NEGATIVE (0.9994)]}]
Sentence: "Today sucks"   [− Tokens: 2  − Sentence-Labels: {'label': [NEGATIVE (0.9998)]}]
Sentence: "I had a car accident today ."   [− Tokens: 7  − Sentence-Labels: {'label': [NEGATIVE (0.8022)]}]
Sentence: "I 'm sorry to hear this happened ."   [− Tokens: 8  − Sentence-Labels: {'label': [NEGATIVE (0.9993)]}]
Sentence: "I can hear you clearly ."   [− Tokens: 6  − Sentence-Labels: {'label': [POSITIVE (0.9968)]}]
Sentence: "He happened to be the president ."   [− Tokens: 7  − Sentence-Labels: {'label': [POSITIVE (0.7145)]}]


In [115]:
visualization_list = []
target_label = 'POSITIVE'
for sentence in sentences:
    interpret_sentence2(flair_model_wrapper2,
                        lig2,
                        sentence,
                        target_label,
                        visualization_list,
                        n_steps=500,
                        estimation_method="gausslegendre",
                        add_special_tokens=True, 
                        need_softmax = False,
                        internal_batch_size=3)

pred:  1 ( 0.99 ) , delta:  tensor([0.0186], device='cuda:0', dtype=torch.float64)
pred:  0 ( 1.00 ) , delta:  tensor([0.6323], device='cuda:0', dtype=torch.float64)
pred:  0 ( 1.00 ) , delta:  tensor([0.6466], device='cuda:0', dtype=torch.float64)
pred:  0 ( 1.00 ) , delta:  tensor([0.6507], device='cuda:0', dtype=torch.float64)
pred:  0 ( 0.80 ) , delta:  tensor([0.5074], device='cuda:0', dtype=torch.float64)
pred:  0 ( 1.00 ) , delta:  tensor([0.6349], device='cuda:0', dtype=torch.float64)
pred:  1 ( 1.00 ) , delta:  tensor([0.0314], device='cuda:0', dtype=torch.float64)
pred:  1 ( 0.71 ) , delta:  tensor([0.1624], device='cuda:0', dtype=torch.float64)


In [116]:
visualize_attributions(visualization_list)

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
POSITIVE,POSITIVE (0.99),POSITIVE,0.16,[CLS] it ' s a great day . [SEP]
,,,,
POSITIVE,NEGATIVE (1.00),POSITIVE,-1.3,[CLS] it ' s absolutely not a great day ! [SEP]
,,,,
POSITIVE,NEGATIVE (1.00),POSITIVE,-1.22,[CLS] it ' s absolutely not a great day ! ! ! [SEP]
,,,,
POSITIVE,NEGATIVE (1.00),POSITIVE,-0.67,[CLS] today sucks [SEP]
,,,,
POSITIVE,NEGATIVE (0.80),POSITIVE,-1.22,[CLS] i had a car accident today . [SEP]
,,,,


## without normalization
The attribution score is in the right magnitude. But the baseline turns out to have non-neutral sentiment

In [131]:
visualization_list = []
target_label = 'POSITIVE'
for sentence in sentences:
    interpret_sentence2(flair_model_wrapper2,
                        lig2,
                        sentence,
                        target_label,
                        visualization_list,
                        n_steps=500,
                        estimation_method="gausslegendre",
                        add_special_tokens=True, 
                        need_softmax = False,
                        need_norm=False,
                        internal_batch_size=3)
visualize_attributions(visualization_list)

pred:  1 ( 0.99 ) , delta:  tensor([0.0186], device='cuda:0', dtype=torch.float64)
pred:  0 ( 1.00 ) , delta:  tensor([0.6323], device='cuda:0', dtype=torch.float64)
pred:  0 ( 1.00 ) , delta:  tensor([0.6466], device='cuda:0', dtype=torch.float64)
pred:  0 ( 1.00 ) , delta:  tensor([0.6507], device='cuda:0', dtype=torch.float64)
pred:  0 ( 0.80 ) , delta:  tensor([0.5074], device='cuda:0', dtype=torch.float64)
pred:  0 ( 1.00 ) , delta:  tensor([0.6349], device='cuda:0', dtype=torch.float64)
pred:  1 ( 1.00 ) , delta:  tensor([0.0314], device='cuda:0', dtype=torch.float64)
pred:  1 ( 0.71 ) , delta:  tensor([0.1624], device='cuda:0', dtype=torch.float64)


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
POSITIVE,POSITIVE (0.99),POSITIVE,0.02,[CLS] it ' s a great day . [SEP]
,,,,
POSITIVE,NEGATIVE (1.00),POSITIVE,-0.32,[CLS] it ' s absolutely not a great day ! [SEP]
,,,,
POSITIVE,NEGATIVE (1.00),POSITIVE,-0.31,[CLS] it ' s absolutely not a great day ! ! ! [SEP]
,,,,
POSITIVE,NEGATIVE (1.00),POSITIVE,-0.34,[CLS] today sucks [SEP]
,,,,
POSITIVE,NEGATIVE (0.80),POSITIVE,-0.25,[CLS] i had a car accident today . [SEP]
,,,,


## zero embedding as baseline
Instead of the default embedding filling with PAD token

In [133]:
visualization_list = []
target_label = 'POSITIVE'
for sentence in sentences:
    interpret_sentence2(flair_model_wrapper2,
                        lig2,
                        sentence,
                        target_label,
                        visualization_list,
                        n_steps=500,
                        estimation_method="gausslegendre",
                        add_special_tokens=True, 
                        need_softmax = False,
                        need_norm=False,
                        pad_base_line=False,
                        internal_batch_size=3)
visualize_attributions(visualization_list)

pred:  1 ( 0.99 ) , delta:  tensor([0.0186], device='cuda:0', dtype=torch.float64)
pred:  0 ( 1.00 ) , delta:  tensor([0.6323], device='cuda:0', dtype=torch.float64)
pred:  0 ( 1.00 ) , delta:  tensor([0.6466], device='cuda:0', dtype=torch.float64)
pred:  0 ( 1.00 ) , delta:  tensor([0.6507], device='cuda:0', dtype=torch.float64)
pred:  0 ( 0.80 ) , delta:  tensor([0.5074], device='cuda:0', dtype=torch.float64)
pred:  0 ( 1.00 ) , delta:  tensor([0.6349], device='cuda:0', dtype=torch.float64)
pred:  1 ( 1.00 ) , delta:  tensor([0.0314], device='cuda:0', dtype=torch.float64)
pred:  1 ( 0.71 ) , delta:  tensor([0.1624], device='cuda:0', dtype=torch.float64)


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
POSITIVE,POSITIVE (0.99),POSITIVE,0.02,[CLS] it ' s a great day . [SEP]
,,,,
POSITIVE,NEGATIVE (1.00),POSITIVE,-0.32,[CLS] it ' s absolutely not a great day ! [SEP]
,,,,
POSITIVE,NEGATIVE (1.00),POSITIVE,-0.31,[CLS] it ' s absolutely not a great day ! ! ! [SEP]
,,,,
POSITIVE,NEGATIVE (1.00),POSITIVE,-0.34,[CLS] today sucks [SEP]
,,,,
POSITIVE,NEGATIVE (0.80),POSITIVE,-0.25,[CLS] i had a car accident today . [SEP]
,,,,


## Modify the embeding layer directly
Another way to use zero embeddings as baseline

In [134]:
classifier2 = TextClassifier.load('sentiment')
# manually set the PAD token embedding as 0
classifier2.document_embeddings.model.embeddings.word_embeddings.weight[0,:] = 0

2021-01-15 03:52:57,319 loading file /home/joey/.flair/models/sentiment-en-mix-distillbert_3.1.pt


In [135]:
flair_model_wrapper3 = ModelWrapper2(classifier2)
lig3 = LayerIntegratedGradients(flair_model_wrapper3, flair_model_wrapper3.model.embeddings)

  if p.grad is not None:


In [137]:
visualization_list = []
target_label = 'POSITIVE'
for sentence in sentences:
    interpret_sentence2(flair_model_wrapper2,
                        lig2,
                        sentence,
                        target_label,
                        visualization_list,
                        n_steps=500,
                        estimation_method="gausslegendre",
                        add_special_tokens=True, 
                        need_softmax = False,
                        need_norm=False,
                        pad_base_line=True,
                        internal_batch_size=3)
visualize_attributions(visualization_list)

pred:  1 ( 0.99 ) , delta:  tensor([0.0186], device='cuda:0', dtype=torch.float64)
pred:  0 ( 1.00 ) , delta:  tensor([0.6323], device='cuda:0', dtype=torch.float64)
pred:  0 ( 1.00 ) , delta:  tensor([0.6466], device='cuda:0', dtype=torch.float64)
pred:  0 ( 1.00 ) , delta:  tensor([0.6507], device='cuda:0', dtype=torch.float64)
pred:  0 ( 0.80 ) , delta:  tensor([0.5074], device='cuda:0', dtype=torch.float64)
pred:  0 ( 1.00 ) , delta:  tensor([0.6349], device='cuda:0', dtype=torch.float64)
pred:  1 ( 1.00 ) , delta:  tensor([0.0314], device='cuda:0', dtype=torch.float64)
pred:  1 ( 0.71 ) , delta:  tensor([0.1624], device='cuda:0', dtype=torch.float64)


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
POSITIVE,POSITIVE (0.99),POSITIVE,0.02,[CLS] it ' s a great day . [SEP]
,,,,
POSITIVE,NEGATIVE (1.00),POSITIVE,-0.32,[CLS] it ' s absolutely not a great day ! [SEP]
,,,,
POSITIVE,NEGATIVE (1.00),POSITIVE,-0.31,[CLS] it ' s absolutely not a great day ! ! ! [SEP]
,,,,
POSITIVE,NEGATIVE (1.00),POSITIVE,-0.34,[CLS] today sucks [SEP]
,,,,
POSITIVE,NEGATIVE (0.80),POSITIVE,-0.25,[CLS] i had a car accident today . [SEP]
,,,,


# input baseline embeddings to classifier
Confirm the truth that the baseline doesn't have a neutral sentiment

In [143]:
for sentence in sentences:
    flair_sentence = Sentence(sentence)
    tokenized_sentence = flair_sentence.to_tokenized_string()
    input_ids = flair_model_wrapper2.tokenizer.encode(tokenized_sentence,
                                                         add_special_tokens=True,
                                                         max_length=tokenizer_max_length,
                                                         truncation=True,
                                                         return_tensors="pt")
    input_ids = input_ids.to(device)
    ref_base_line = torch.full_like(input_ids, pad_token_id)

    print(flair_model_wrapper2.forward(ref_base_line))

tensor([[0.0454, 0.9546]], device='cuda:0', grad_fn=<SoftmaxBackward>)
tensor([[0.0469, 0.9531]], device='cuda:0', grad_fn=<SoftmaxBackward>)
tensor([[0.0451, 0.9549]], device='cuda:0', grad_fn=<SoftmaxBackward>)
tensor([[0.0112, 0.9888]], device='cuda:0', grad_fn=<SoftmaxBackward>)
tensor([[0.0454, 0.9546]], device='cuda:0', grad_fn=<SoftmaxBackward>)
tensor([[0.0469, 0.9531]], device='cuda:0', grad_fn=<SoftmaxBackward>)
tensor([[0.0418, 0.9582]], device='cuda:0', grad_fn=<SoftmaxBackward>)
tensor([[0.0454, 0.9546]], device='cuda:0', grad_fn=<SoftmaxBackward>)


# paragraph input
Test with paragraph to see the how the tokens are added

In [123]:
sentences2 = ["The tedious honors Calculus class that he taught just before lunch was not the highlight of his day.  Not that he didn’t like the subject matter, math had always come easy to him, but attempting to convince a group of 11th grade students that the logic of derivatives was actually something that they needed to master in order to survive was another matter.", 
              "The boring calculus class was not the highlight of his day.",
              "President-elect Joseph R. Biden Jr. on Thursday proposed a $1.9 trillion rescue package to combat the economic downturn and the Covid-19 crisis, outlining the type of sweeping aid that Democrats have demanded for months and signaling the shift in the federal government’s pandemic response as Mr. Biden prepares to take office.",
              "He proposed a $1.9 trillion rescue package to combat the economic downturn"
             ]

In [124]:
visualization_list = []
target_label = 'POSITIVE'
for sentence in sentences2:
    interpret_sentence2(flair_model_wrapper2,
                        lig2,
                        sentence,
                        target_label,
                        visualization_list,
                        n_steps=500,
                        estimation_method="gausslegendre",
                        add_special_tokens=True, 
                        need_softmax = False,
                        internal_batch_size=3)

pred:  0 ( 1.00 ) , delta:  tensor([0.6416], device='cuda:0', dtype=torch.float64)
pred:  0 ( 1.00 ) , delta:  tensor([0.6392], device='cuda:0', dtype=torch.float64)
pred:  1 ( 0.96 ) , delta:  tensor([0.0066], device='cuda:0', dtype=torch.float64)
pred:  1 ( 0.64 ) , delta:  tensor([0.2201], device='cuda:0', dtype=torch.float64)


In [125]:
visualize_attributions(visualization_list)

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
POSITIVE,NEGATIVE (1.00),POSITIVE,-3.84,"[CLS] the ted ##ious honors calculus class that he taught just before lunch was not the highlight of his day . not that he did n ’ t like the subject matter , math had always come easy to him , but attempting to convince a group of 11th grade students that the logic of derivatives was actually something that they needed to master in order to survive was another matter . [SEP]"
,,,,
POSITIVE,NEGATIVE (1.00),POSITIVE,-2.23,[CLS] the boring calculus class was not the highlight of his day . [SEP]
,,,,
POSITIVE,POSITIVE (0.96),POSITIVE,0.12,"[CLS] president - elect joseph r . bid ##en jr . on thursday proposed a $ 1 . 9 trillion rescue package to combat the economic down ##turn and the co ##vid - 19 crisis , out ##lining the type of sweeping aid that democrats have demanded for months and signaling the shift in the federal government ’ s pan ##de ##mic response as mr . bid ##en prepares to take office . [SEP]"
,,,,
POSITIVE,POSITIVE (0.64),POSITIVE,-0.39,[CLS] he proposed a $ 1 . 9 trillion rescue package to combat the economic down ##turn [SEP]
,,,,


# ref

- The paper that proposes Integrated Gradients:  https://arxiv.org/pdf/1703.01365.pdf
- Captum tutorial about how to visual text attribution (cannot be used directly, since the tokenizer and the NN are in different class/type): https://captum.ai/tutorials/IMDB_TorchText_Interpret
- A repo explaining how to make flair models compatiable with Captum (There is a bug with the CUDA memory. It also doesn't include the softmax in attribution calculation): https://github.com/robinvanschaik/interpret-flair
- Captum documentation of layer-integrated-gradients, which is made for visualizing text: https://captum.ai/api/layer.html#layer-integrated-gradients