In [None]:
import pandas as pd
import numpy as np
import re

SEED = 42


from transformers import (DataCollatorWithPadding, Trainer, TrainingArguments,
                          LongformerTokenizer, LongformerForSequenceClassification,
                          LongformerConfig)

from transformers.models.longformer.modeling_longformer import create_position_ids_from_input_ids

from datasets import Dataset, DatasetDict

from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error

import torch
from torch.utils.data import DataLoader
assert torch.cuda.is_available(), 'GPU not found. You should fix this.'

In [None]:
from captum.attr import Saliency, LayerIntegratedGradients, IntegratedGradients
from captum.attr import visualization as viz

## Data

In [None]:
def get_datadict(score_to_predict):
    
    scores = {
        'Overall',
        'Cohesion',
        'Syntax',
        'Vocabulary',
        'Phraseology',
        'Grammar',
        'Conventions'
    }
    
    columns_to_remove = scores.symmetric_difference([score_to_predict])
    
    dd = (DatasetDict
          .load_from_disk('../data/ellipse.hf')
          .remove_columns(columns_to_remove)
          .rename_column(score_to_predict, 'label')
         )
    
    return dd

In [None]:
score_to_predict = 'Grammar'

dd = get_datadict(score_to_predict)
dd

## Testbed

In [None]:
model_chkpt = '../bin/checkpoint-284/'
model = LongformerForSequenceClassification.from_pretrained(model_chkpt, num_labels=1).cuda()
tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')

Prepare a sample item

In [37]:
def forward_func(input_embedding, attention_mask, global_attention_mask):
    return model(
        inputs_embeds=input_embedding,
        attention_mask=attention_mask,
        global_attention_mask=global_attention_mask,
    ).logits

In [None]:
def predict(input_ids, position_ids, attention_mask):
    return model(
        input_ids,
        attention_mask=attention_mask,
        # global_attention_mask=global_attention_mask,
        # token_type_ids=token_type_ids,
        position_ids=position_ids,
    ).logits

In [76]:
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

In [73]:
sample = dd['train'].shuffle()[0]
text_id = sample.pop('text_id')
true_score = sample.pop('label')    
print(true_score)

sample = {k: torch.tensor(v).unsqueeze(0).cuda() for k, v in sample.items()}

manual_embed = model.longformer.embeddings(
    input_ids=sample['input_ids'],
)


sample['global_attention_mask'] = torch.zeros_like(sample['input_ids'])
sample['global_attention_mask'][:, 0] = 1

tokens = [
    t.replace('Ġ', '')
    for t in 
    tokenizer.convert_ids_to_tokens(
        sample['input_ids'][0].detach().tolist()
    )
]

print(tokens[:25])

with torch.no_grad():    
    print(model(sample['input_ids']))
    
    print(
        forward_func(
            manual_embed,
            attention_mask=sample['attention_mask'],
            global_attention_mask=sample['global_attention_mask'],
        )
    )

3.5
['<s>', 'I', 'Believe', 'people', 'who', 'have', 'positive', 'attitude', "'s", 'towards', 'life', ',', 'are', 'one', 'of', 'the', 'biggest', 'keys', 'to', 'success', 'in', 'life', '.', 'You', 'will']
LongformerSequenceClassifierOutput(loss=None, logits=tensor([[3.5902]], device='cuda:0'), hidden_states=None, attentions=None, global_attentions=None)
tensor([[3.5470]], device='cuda:0')


In [74]:
saliency = Saliency(forward_func)

attribution = saliency.attribute(inputs=manual_embed,
                                 additional_forward_args=(
                                     sample['attention_mask'],
                                     sample['global_attention_mask'],
                                 ),
                                 abs=False
                                )

In [77]:
attribution_sum = summarize_attributions(attribution)

# storing couple samples in an array for visualization purposes
position_vis = viz.VisualizationDataRecord(
    attribution_sum, # token attributions
    torch.max(torch.softmax(pred_score[0], dim=0)), # pred_prob
    round(pred_score.logits.item(), 2), # pred_class
    true_score, # true_class
    None, # attr_class
    attribution_sum.sum(), # attr_score
    tokens, # raw_input_ids
    None # convergence score
)

print('\033[1m', 'Visualizations', '\033[0m')
viz.visualize_text([position_vis])

NameError: name 'pred_score' is not defined

In [None]:
from torch.utils.data import DataLoader

ds = dd.with_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'label'])

dataloader = DataLoader(ds['dev'].shuffle().select(range(10)), batch_size=1)

### Saliency with Interpretable Embedding layer

In [28]:
for batch in dataloader:
    labels = batch.pop('label')
            
    batch['position_ids'] = create_position_ids_from_input_ids(batch['input_ids'], model.config.pad_token_id)
    
    batch = {k: v.cuda() for k, v in batch.items()}

    tokens = tokenizer.convert_ids_to_tokens(batch['input_ids'][0].detach().tolist())
    
    tokens = [t.replace('Ġ', '') for t in tokens]
    
    input_emb = interpretable_emb.indices_to_embeddings(batch['input_ids'])
    
    with torch.no_grad():
        score = predict(input_emb)
        
    saliency = Saliency(predict)
    
    attribution, delta = saliency.attribute(inputs=input_emb,
                                            additional_forward_args=(
                                                batch['position_ids'],
                                                batch['attention_mask']),
                                            # return_convergence_delta=True
                                           )
    
    attribution_sum = summarize_attributions(attribution)
    
    
    # storing couple samples in an array for visualization purposes
    position_vis = viz.VisualizationDataRecord(
        attribution_sum,
        torch.max(torch.softmax(score[0], dim=0)),
        score,
        score,
        str(0),
        attribution_sum.sum(),       
        tokens,
        delta)

    print('\033[1m', 'Visualizations', '\033[0m')
    viz.visualize_text([position_vis])

RuntimeError: The size of tensor a (478) must match the size of tensor b (768) at non-singleton dimension 2

### Layer Integrated Gradients

In [None]:
for batch in dataloader:
    labels = batch.pop('label')
        
    # outputs = model(**batch, output_hidden_states=True)
    
    batch['position_ids'] = create_position_ids_from_input_ids(batch['input_ids'], model.config.pad_token_id)
    
    batch = {k: v.cuda() for k, v in batch.items()}
    
    with torch.no_grad():
        score = model(**batch).logits
    
    lig = LayerIntegratedGradients(predict, model.longformer.embeddings)
    
    tokens = tokenizer.convert_ids_to_tokens(batch['input_ids'][0].detach().tolist())
    
    tokens = [t.replace('Ġ', '') for t in tokens]
    
    attribution, delta = lig.attribute(inputs=batch['input_ids'],
                                       additional_forward_args=(
                                                                batch['position_ids'],
                                                                batch['attention_mask']),
                                       return_convergence_delta=True)
    
    attribution_sum = summarize_attributions(attribution)
    
    
    # storing couple samples in an array for visualization purposes
    position_vis = viz.VisualizationDataRecord(
        attribution_sum,
        torch.max(torch.softmax(score[0], dim=0)),
        score,
        score,
        str(0),
        attribution_sum.sum(),       
        tokens,
        delta)

    print('\033[1m', 'Visualizations', '\033[0m')
    viz.visualize_text([position_vis])

### Code from Kaggle

https://www.kaggle.com/code/rhtsingh/interpreting-text-models-with-bert-on-tpu

In [None]:
def process(text, label):
    input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
    token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
    position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
    attention_mask = construct_attention_mask(input_ids)

    indices = input_ids[0].detach().tolist()
    all_tokens = tokenizer.convert_ids_to_tokens(indices)
    
    if label == 0:
        lig = LayerIntegratedGradients(custom_forward_0, model.bert.embeddings)
    elif label == 1:
        lig = LayerIntegratedGradients(custom_forward_1, model.bert.embeddings)
    elif label == 2:
        lig = LayerIntegratedGradients(custom_forward_2, model.bert.embeddings)
    
    attributions_main, delta_main = lig.attribute(inputs=input_ids,
                                                  baselines=ref_input_ids,
                                                  n_steps = 150,
                                                  additional_forward_args=(token_type_ids, attention_mask),
                                                  return_convergence_delta=True)
    
    score = predict(input_ids, token_type_ids, attention_mask)
    attributions_main = attributions_main.cpu()
    delta_main = delta_main.cpu()
    score = score.cpu()
    add_attributions_to_visualizer(attributions_main, delta_main, text, score, label, all_tokens)
    
def add_attributions_to_visualizer(attributions, delta, text, score, label, all_tokens):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    attributions = attributions.cpu()

    score_vis.append(
        viz.VisualizationDataRecord(
            attributions,
            torch.softmax(score, dim = 1)[0][label],
            torch.argmax(torch.softmax(score, dim = 1)[0]),
            label,
            text,
            attributions.sum(),
            all_tokens,
            delta
        )
    ) 