In [2]:
pip install transformers captum

Collecting captum
  Downloading captum-0.7.0-py3-none-any.whl.metadata (26 kB)
Downloading captum-0.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: captum
Successfully installed captum-0.7.0


In [3]:
from captum.attr import IntegratedGradients
from transformers import AutoTokenizer, LEDForConditionalGeneration
import torch
import numpy as np

class ExplainableSummarizer:
    model_name = "allenai/led-large-16384-arxiv"

    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = LEDForConditionalGeneration.from_pretrained(self.model_name)
        self.model.eval()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def model_forward(self, input_ids, attention_mask):
        """
        Custom forward function for Captum's Integrated Gradients.
        Ensures input_ids are cast to the correct type (LongTensor).
        """
        input_ids = input_ids.to(dtype=torch.long)  # Ensure LongTensor
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=input_ids)
        return outputs.logits

    def summarize_with_attention_scores(self, text):
      # Tokenize input
      inputs = self.tokenizer(text, return_tensors="pt", max_length=4096, truncation=True)
      input_ids = inputs["input_ids"].to(self.device)
      attention_mask = inputs["attention_mask"].to(self.device)

      # Use model.forward() to compute outputs with attentions
      outputs = self.model(
          input_ids=input_ids,
          attention_mask=attention_mask,
          output_attentions=True,
          return_dict=True
      )

      # Extract logits for summary generation
      logits = outputs.logits
      summary_ids = logits.argmax(dim=-1)

      # Decode summary
      summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)

      # Compute attention scores
      decoder_attentions = outputs.decoder_attentions  # Attention from decoder
      attention_weights = torch.stack(decoder_attentions).mean(dim=(0, 1))  # Average across layers and heads

      # Normalize attention scores
      attention_scores = attention_weights.squeeze().cpu().detach().numpy()
      attention_scores = attention_scores[:len(input_ids[0])]  # Trim to match input length
      attention_scores = attention_scores / np.max(attention_scores)  # Normalize to range [0, 1]

      # Print tokens and attention scores
      input_tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
      print("Tokens and Attention Scores:")
      for token, attention in zip(input_tokens, attention_scores):
        if np.isscalar(attention):
            print(f"Token: {token}, Attention Score: {attention:.4f}")
        else:
            print(f"Token: {token}, Attention Score: {attention.mean():.4f}")


      return summary



summarizer = ExplainableSummarizer()
text = r"Oil settles down after Israel agrees to ceasefire deal with Hezbollah\n  ', 'content': '\n\n      Business\n  \n\n\n\n      Business\n  \n\nAn aerial view shows oil tanks of Transneft oil pipeline operator at the crude oil terminal Kozmino on the shore of Nakhodka Bay near the port city of Nakhodka, Russia June 13, 2022. REUTERS/Tatiana Meel/File Photo\nHOUSTON :Oil prices settled lower on Tuesday, extending the previous day\'s losses in choppy trade after Israel agreed to a ceasefire deal with Hezbollah, reducing oil\'s risk premium. \nBrent crude futures settled down 20 cents, or 0.27 per cent, to $72.81 a barrel. U.S. West Texas Intermediate crude futures settled at $68.77 a barrel, down 17 cents, or 0.25 per cent. \nThe accord between Israel and armed group Hezbollah was expected to take effect on Wednesday, U.S. President Joe Biden said.\nIsraeli Prime Minister Benjamin Netanyahu said he was ready to implement a ceasefire and would respond forcefully to any violation by Hezbollah.\nOn Monday, oil prices fell more than $2 following multiple reports that the warring sides had agreed to terms of a ceasefire.\nA ceasefire could pressure crude oil prices because the U.S. administration would likely reduce sanctions on oil from Iran, a supporter of Hezbollah, StoneX analyst Alex Hodes said in a note. \nOPEC+ EYE OUTPUT HIKE DELAY\nBoth benchmarks briefly jumped more than $1 per barrel during the session.   \nWe popped and dropped around the time news came out of the resumption of OPEC talks, said Phil Flynn, senior analyst at Price Futures Group."
summary = summarizer.summarize_with_attention_scores(text)
print("\nGenerated Summary:\n", summary)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/27.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.29k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/772 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.84G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/207 [00:00<?, ?B/s]

Input ids are automatically padded from 388 to 1024 to be a multiple of `config.attention_window`: 1024


Tokens and Attention Scores:
Token: <s>, Attention Score: 0.0026
Token: Oil, Attention Score: 0.0026
Token: Ġsettles, Attention Score: 0.0026
Token: Ġdown, Attention Score: 0.0026
Token: Ġafter, Attention Score: 0.0026
Token: ĠIsrael, Attention Score: 0.0026
Token: Ġagrees, Attention Score: 0.0026
Token: Ġto, Attention Score: 0.0026
Token: Ġceasefire, Attention Score: 0.0026
Token: Ġdeal, Attention Score: 0.0026
Token: Ġwith, Attention Score: 0.0026
Token: ĠHezbollah, Attention Score: 0.0026
Token: \, Attention Score: 0.0026
Token: n, Attention Score: 0.0026
Token: Ġ, Attention Score: 0.0026
Token: Ġ',, Attention Score: 0.0026

Generated Summary:
  oil oil settles down after Israel agrees to a deal with Hezbollah\n  ', 'content': '\n\n\     \n  \n\n\n\n\     \n  \n n  aerial view shows oil tanks of Transneft oil pipeline operator at the crude oil terminal Kozmino on the shore of thehodka Bay near the port city of Nakhodka , Russia , 13, 2022. REUTERS/Tatiana Meel/File Photo\nHOUSTON : 