In [1]:
!pip install captum
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting captum
  Downloading captum-0.6.0-py3-none-any.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m17.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: captum
Successfully installed captum-0.6.0
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.29.1-py3-none-any.whl (7.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m38.9 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.14.1 (from transformers)
  Downloading huggingface_hub-0.14.1-py3-none-any.whl (224 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m25.3 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0

In [2]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertTokenizer, BertForSequenceClassification, AdamW, BertForQuestionAnswering, BertConfig
import nltk
import torch.nn as nn
from captum.attr import IntegratedGradients, InterpretableEmbeddingBase, TokenReferenceBase, visualization, configure_interpretable_embedding_layer, remove_interpretable_embedding_layer
from captum.attr import LayerConductance, LayerIntegratedGradients
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report

In [3]:
nltk.download('reuters')
from nltk.corpus import reuters

# Set the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the Reuters dataset
document_ids = reuters.fileids()
documents = [reuters.raw(doc_id) for doc_id in document_ids]
labels = [reuters.categories(doc_id)[0] for doc_id in document_ids]
label2idx = {label: idx for idx, label in enumerate(set(labels))}
encoded_labels = [label2idx[label] for label in labels]

# Define the BERT tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(label2idx))

# Tokenize the input documents
tokenized_inputs = tokenizer(documents, padding=True, truncation=True, return_tensors='pt')

# Prepare the dataset
input_ids = tokenized_inputs['input_ids']
attention_mask = tokenized_inputs['attention_mask']
labels_tensor = torch.tensor(encoded_labels)

# Split the dataset into training and testing
train_inputs, test_inputs, train_masks, test_masks, train_labels, test_labels = train_test_split(input_ids, attention_mask, labels_tensor, test_size=0.2, random_state=42)

train_dataset = TensorDataset(train_inputs, train_masks, train_labels)
test_dataset = TensorDataset(test_inputs, test_masks, test_labels)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=True)

# Set the model in training mode
model.train()
model.to(device)

# Define the optimizer
optimizer = AdamW(model.parameters(), lr=2e-5)

[nltk_data] Downloading package reuters to /root/nltk_data...


Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [4]:
# Training loop
for epoch in range(2):  # Run for 2 epochs
    epoch_loss = 0.0
    correct_predictions = 0
    total_predictions = 0
    
    for batch in train_loader:
        input_ids, attention_mask, labels = batch
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()

        # Calculate accuracy
        predicted_labels = outputs.logits.argmax(dim=1)
        correct_predictions += (predicted_labels == labels).sum().item()
        total_predictions += labels.size(0)

    accuracy = correct_predictions / total_predictions
    average_loss = epoch_loss / len(train_loader)
    
    print(f"Epoch {epoch + 1}/{2} - Loss: {average_loss:.4f} - Accuracy: {accuracy:.4f}")

# Save the trained model
model.save_pretrained("./saved_model")

Epoch 1/2 - Loss: 1.3894 - Accuracy: 0.7219
Epoch 2/2 - Loss: 0.5680 - Accuracy: 0.8837


In [6]:
# Set the model in evaluation mode
# Set the device (CPU or GPU)
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained("./saved_model")
# Set the model in evaluation mode
model.eval()
model.to(device)

# Assuming test_loader is your DataLoader object for test data
test_predictions = []
test_labels_list = []
for batch in test_loader:
    input_ids, attention_mask, labels = batch
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    labels = labels.to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
        predicted_label_idx = torch.argmax(probabilities, dim=1)
        test_predictions.extend(predicted_label_idx.tolist())
        test_labels_list.extend(labels.tolist())

# Calculate Accuracy
accuracy = accuracy_score(test_labels_list, test_predictions)
print(f"Accuracy: {accuracy}")

# Calculate Precision, Recall, F1-Score
precision, recall, f1, _ = precision_recall_fscore_support(test_labels_list, test_predictions, average='weighted')
print(f"Precision: {precision}\nRecall: {recall}\nF1-Score: {f1}")

Accuracy: 0.9026876737720111
Precision: 0.8804229537772347
Recall: 0.9026876737720111
F1-Score: 0.8881948142550192


  _warn_prf(average, modifier, msg_start, len(result))


In [7]:
unique_labels = np.unique(np.hstack([test_labels_list, test_predictions]))
idx2label = {idx: label for label, idx in label2idx.items()}
target_names = [idx2label[i] for i in unique_labels if i in idx2label]

# Print the classification report
print(classification_report(test_labels_list, test_predictions, target_names=target_names))


                 precision    recall  f1-score   support

       reserves       1.00      0.83      0.91        12
           earn       0.99      0.99      0.99       780
           lead       0.00      0.00      0.00         3
            gas       0.00      0.00      0.00         2
     castor-oil       0.00      0.00      0.00         2
           gold       0.91      0.94      0.92        31
   money-supply       0.86      0.97      0.91        32
strategic-metal       0.00      0.00      0.00         4
          nzdlr       0.00      0.00      0.00         1
           alum       0.38      0.86      0.52         7
           rand       0.00      0.00      0.00         1
            ipi       0.83      0.62      0.71         8
        nat-gas       0.71      0.50      0.59        10
           zinc       0.00      0.00      0.00         2
         lumber       0.00      0.00      0.00         2
          cocoa       1.00      1.00      1.00        17
         orange       0.00    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [8]:
for label, idx in label2idx.items():
    print(f"Label: {label}, Index: {idx}")

Label: reserves, Index: 0
Label: earn, Index: 1
Label: lead, Index: 2
Label: gas, Index: 3
Label: propane, Index: 4
Label: l-cattle, Index: 5
Label: castor-oil, Index: 6
Label: gold, Index: 7
Label: rice, Index: 8
Label: money-supply, Index: 9
Label: strategic-metal, Index: 10
Label: tea, Index: 11
Label: nzdlr, Index: 12
Label: alum, Index: 13
Label: rand, Index: 14
Label: ipi, Index: 15
Label: nat-gas, Index: 16
Label: zinc, Index: 17
Label: lumber, Index: 18
Label: palladium, Index: 19
Label: cocoa, Index: 20
Label: orange, Index: 21
Label: dmk, Index: 22
Label: nickel, Index: 23
Label: bop, Index: 24
Label: groundnut-oil, Index: 25
Label: palm-oil, Index: 26
Label: cotton, Index: 27
Label: groundnut, Index: 28
Label: cpu, Index: 29
Label: yen, Index: 30
Label: lei, Index: 31
Label: rape-oil, Index: 32
Label: soybean, Index: 33
Label: jobs, Index: 34
Label: iron-steel, Index: 35
Label: veg-oil, Index: 36
Label: naphtha, Index: 37
Label: instal-debt, Index: 38
Label: livestock, Index

In [15]:
# Set the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained("./saved_model")
# Set the model in evaluation mode
model.eval()
model.to(device)

# Define a new sentence
sentence = "Global stock markets have witnessed a substantial surge amidst renewed investor confidence"

# Tokenize the sentence
inputs = tokenizer(sentence, padding=True, truncation=True, return_tensors='pt')

# Get the inputs ready for the model
input_ids = inputs['input_ids'].to(device)
attention_mask = inputs['attention_mask'].to(device)

# Forward pass: get the outputs of the model for these inputs
with torch.no_grad():
    outputs = model(input_ids, attention_mask=attention_mask)

# The outputs are logits: apply the softmax function to get probabilities
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

# Get the label with the highest probability
predicted_label_idx = probabilities.argmax().item()

# Decode the predicted label
predicted_label = list(label2idx.keys())[list(label2idx.values()).index(predicted_label_idx)]

print(f"The predicted label is: {predicted_label} and its index is: {predicted_label_idx}")


The predicted label is: interest and its index is: 63


In [18]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained("./saved_model")

# We need to split forward pass into two part: 
# 1) embeddings computation
# 2) classification

def compute_bert_outputs(model_bert, embedding_output, attention_mask=None, head_mask=None):
    if attention_mask is None:
        attention_mask = torch.ones(embedding_output.shape[0], embedding_output.shape[1]).to(embedding_output)

    extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

    extended_attention_mask = extended_attention_mask.to(dtype=next(model_bert.parameters()).dtype) # fp16 compatibility
    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

    if head_mask is not None:
        if head_mask.dim() == 1:
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.expand(model_bert.config.num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
        head_mask = head_mask.to(dtype=next(model_bert.parameters()).dtype) # switch to fload if need + fp16 compatibility
    else:
        head_mask = [None] * model_bert.config.num_hidden_layers

    encoder_outputs = model_bert.encoder(embedding_output,
                                         extended_attention_mask,
                                         head_mask=head_mask)
    sequence_output = encoder_outputs[0]
    pooled_output = model_bert.pooler(sequence_output)
    outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
    return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)    


class BertModelWrapper(nn.Module):
    
    def __init__(self, model):
        super(BertModelWrapper, self).__init__()
        self.model = model
        
    def forward(self, embeddings):        
        outputs = compute_bert_outputs(self.model.bert, embeddings)
        pooled_output = outputs[1]
        pooled_output = self.model.dropout(pooled_output)
        logits = self.model.classifier(pooled_output)
        return torch.softmax(logits, dim=1)  # Return probabilities for all classes

    
bert_model_wrapper = BertModelWrapper(model)
ig = IntegratedGradients(bert_model_wrapper)

# accumalate couple samples in this array for visualization purposes
vis_data_records_ig = []

def interpret_sentence(model_wrapper, sentence, target_label_idx):

    model_wrapper.eval()
    model_wrapper.zero_grad()
    
    input_ids = torch.tensor([tokenizer.encode(sentence, add_special_tokens=True)])
    input_embedding = model_wrapper.model.bert.embeddings(input_ids)
    
    # predict
    preds = model_wrapper(input_embedding)
    pred_ind = preds.argmax().item()  # Get the index of the highest probability

    # compute attributions and approximation delta using integrated gradients
    attributions_ig, delta = ig.attribute(input_embedding, target=target_label_idx, n_steps=500, return_convergence_delta=True)

    print('pred: ', pred_ind, '(', '%.2f' % preds[0, pred_ind].item(), ')', ', delta: ', abs(delta))

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].numpy().tolist())    
    add_attributions_to_visualizer(attributions_ig, tokens, preds[0, pred_ind].item(), pred_ind, target_label_idx, delta, vis_data_records_ig)
    
    
def add_attributions_to_visualizer(attributions, tokens, pred, pred_ind, label, delta, vis_data_records):
    attributions = attributions.sum(dim=2).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    attributions = attributions.detach().numpy()
    
    # storing couple samples in an array for visualization purposes
    vis_data_records.append(visualization.VisualizationDataRecord(
                            attributions,
                            pred,
                            pred_ind,
                            label,
                            "label",
                            attributions.sum(),       
                            tokens,
                            delta))    

In [19]:
interpret_sentence(bert_model_wrapper, sentence="Global stock markets have witnessed a substantial surge amidst renewed investor confidence", target_label_idx=9)
visualization.visualize_text(vis_data_records_ig)

pred:  63 ( 0.54 ) , delta:  tensor([0.0068], dtype=torch.float64)


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
9.0,63 (0.54),label,0.22,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
9.0,63 (0.54),label,0.22,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,


In [20]:
interpret_sentence(bert_model_wrapper, sentence="Global stock markets have witnessed a substantial surge amidst renewed investor confidence", target_label_idx=10)
visualization.visualize_text(vis_data_records_ig)

pred:  63 ( 0.54 ) , delta:  tensor([0.0013], dtype=torch.float64)


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
9.0,63 (0.54),label,0.22,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,
10.0,63 (0.54),label,-2.28,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
9.0,63 (0.54),label,0.22,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,
10.0,63 (0.54),label,-2.28,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,


In [21]:
interpret_sentence(bert_model_wrapper, sentence="Global stock markets have witnessed a substantial surge amidst renewed investor confidence", target_label_idx=63)
visualization.visualize_text(vis_data_records_ig)

pred:  63 ( 0.54 ) , delta:  tensor([0.0159], dtype=torch.float64)


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
9.0,63 (0.54),label,0.22,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,
10.0,63 (0.54),label,-2.28,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,
63.0,63 (0.54),label,1.38,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
9.0,63 (0.54),label,0.22,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,
10.0,63 (0.54),label,-2.28,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,
63.0,63 (0.54),label,1.38,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,


In [22]:
sentences = [
    "Oil giant Exxon Mobil sees record profits amid rising global demand.",
    "Federal Reserve raises interest rates to combat inflation.",
    "Starbucks sees a surge in coffee sales in the third quarter.",
    "Major tech companies are investing heavily in artificial intelligence.",
    "Strong dollar impacts the global trade negatively."
]

label_indices = [78, 18, 73, 8, 56]  

for i, sentence in enumerate(sentences):
    print(i)
    interpret_sentence(bert_model_wrapper, sentence=sentence, target_label_idx=label_indices[i])
    visualization.visualize_text(vis_data_records_ig)

0
pred:  78 ( 0.67 ) , delta:  tensor([0.0020], dtype=torch.float64)


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
9.0,63 (0.54),label,0.22,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,
10.0,63 (0.54),label,-2.28,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,
63.0,63 (0.54),label,1.38,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,
78.0,78 (0.67),label,0.93,[CLS] oil giant ex ##xon mob ##il sees record profits amid rising global demand . [SEP]
,,,,


1
pred:  63 ( 0.84 ) , delta:  tensor([0.0002], dtype=torch.float64)


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
9.0,63 (0.54),label,0.22,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,
10.0,63 (0.54),label,-2.28,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,
63.0,63 (0.54),label,1.38,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,
78.0,78 (0.67),label,0.93,[CLS] oil giant ex ##xon mob ##il sees record profits amid rising global demand . [SEP]
,,,,
18.0,63 (0.84),label,-2.52,[CLS] federal reserve raises interest rates to combat inflation . [SEP]
,,,,


2
pred:  73 ( 0.75 ) , delta:  tensor([0.0141], dtype=torch.float64)


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
9.0,63 (0.54),label,0.22,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,
10.0,63 (0.54),label,-2.28,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,
63.0,63 (0.54),label,1.38,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,
78.0,78 (0.67),label,0.93,[CLS] oil giant ex ##xon mob ##il sees record profits amid rising global demand . [SEP]
,,,,
18.0,63 (0.84),label,-2.52,[CLS] federal reserve raises interest rates to combat inflation . [SEP]
,,,,


3
pred:  56 ( 0.17 ) , delta:  tensor([0.0008], dtype=torch.float64)


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
9.0,63 (0.54),label,0.22,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,
10.0,63 (0.54),label,-2.28,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,
63.0,63 (0.54),label,1.38,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,
78.0,78 (0.67),label,0.93,[CLS] oil giant ex ##xon mob ##il sees record profits amid rising global demand . [SEP]
,,,,
18.0,63 (0.84),label,-2.52,[CLS] federal reserve raises interest rates to combat inflation . [SEP]
,,,,


4
pred:  46 ( 0.53 ) , delta:  tensor([0.0276], dtype=torch.float64)


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
9.0,63 (0.54),label,0.22,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,
10.0,63 (0.54),label,-2.28,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,
63.0,63 (0.54),label,1.38,[CLS] global stock markets have witnessed a substantial surge amidst renewed investor confidence [SEP]
,,,,
78.0,78 (0.67),label,0.93,[CLS] oil giant ex ##xon mob ##il sees record profits amid rising global demand . [SEP]
,,,,
18.0,63 (0.84),label,-2.52,[CLS] federal reserve raises interest rates to combat inflation . [SEP]
,,,,
