<a href="https://colab.research.google.com/github/hailusong/nlp-qa/blob/master/nlp_qa_albert_pretrained_captum.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

References:
1. [Captum on HuggingFace SQuAD](https://captum.ai/tutorials/Bert_SQUAD_Interpret)



## Environment Setup
To get the source code of a specific `transformers` release, check out its tag (this example assumes release 2.11.0):
```
git fetch -a --tags
git checkout tags/v2.11.0
```

Make sure the Linux `file` command is available.

In [5]:
!uname -a
!pip install wget
!apt-get install file

Linux c65bfaf37983 4.19.104+ #1 SMP Wed Feb 19 05:26:34 PST 2020 x86_64 x86_64 x86_64 GNU/Linux
Collecting wget
  Downloading https://files.pythonhosted.org/packages/47/6a/62e288da7bcda82b935ff0c6cfe542970f04e29c756b0e147251b2fb251f/wget-3.2.zip
Building wheels for collected packages: wget
  Building wheel for wget (setup.py) ... [?25l[?25hdone
  Created wheel for wget: filename=wget-3.2-cp36-none-any.whl size=9682 sha256=25fc53bbb5a0e5bae740faa9631fd65c28e583d389b448ab13cbfa9450422fb2
  Stored in directory: /root/.cache/pip/wheels/40/15/30/7d8f7cea2902b4db79e3fea550d7d7b85ecb27ef992b618f3f
Successfully built wget
Installing collected packages: wget
Successfully installed wget-3.2
Reading package lists... Done
Building dependency tree       
Reading state information... Done
The following package was automatically installed and is no longer required:
  libnvidia-common-440
Use 'apt autoremove' to remove it.
The following additional packages will be installed:
  libmagic-mgc libmagic1

In [None]:
try:
    import transformers
except ImportError as e:
    # install huggingface
    !pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/82/25/89050e69ed53c2a3b7f8c67844b3c8339c1192612ba89a172cf85b298948/transformers-3.0.1-py3-none-any.whl (757kB)
[K     |████████████████████████████████| 757kB 4.6MB/s 
[?25hCollecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 14.1MB/s 
Collecting sentencepiece!=0.1.92
[?25l  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
[K     |████████████████████████████████| 1.1MB 29.4MB/s 
Collecting tokenizers==0.8.0-rc4
[?25l  Downloading https://files.pythonhosted.org/packages/e8/bd/e5abec46af977c8a1375c1dca7cb1e5b3ec392ef279067af7f6bc50491a0/tokenizers-0.8.0rc4-cp36-cp36m-manylinux1_x86_64.whl (3.0MB

In [None]:
import torch
import transformers

# Report the library versions in use: torch first, then transformers.
for _lib in (torch, transformers):
    print(_lib.__version__)

In [None]:
# List the working directory, then replace any previous checkout with a fresh
# shallow (single-commit) clone of the companion repository.
!ls -al
!rm -rf nlp-qa
!git clone --depth 1 https://github.com/hailusong/nlp-qa.git

In [None]:
# Install the dependencies listed in the cloned repository's question-answering requirements file.
!pip install -r nlp-qa/question-answering/requirements.txt

### GPU Setup

In [None]:
import torch

# Select the compute device: the GPU when CUDA is usable, otherwise the CPU.
if not torch.cuda.is_available():
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")
else:
    device = torch.device("cuda")

    # Report how many GPUs are visible and the name of the first one.
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))

## Configuration

In [None]:
# Feature switches that gate the cells below.
MULTILABEL_INTENT = False  # ALBERT sequence-classification (CoLA) fine-tuning cells
QA_EXTRACTION = True       # SQuAD2 question-answering cells
CAPTUM = True              # Captum attribution set-up (only honoured together with QA_EXTRACTION)

## Download

### Pre-trained ALBERT v2 language model 
Ref:
- https://huggingface.co/transformers/pretrained_models.html
- [BERT Fine-Tuning Tutorial with PyTorch](https://mccormickml.com/2019/07/22/BERT-fine-tuning/#4-train-our-classification-model)

In [None]:
if MULTILABEL_INTENT:
  import torch
  from transformers import AlbertConfig, AlbertForSequenceClassification
  from transformers import AutoTokenizer, AutoModelForQuestionAnswering

  # ALBERT configuration fetched through torch.hub, switched to two output labels.
  config = torch.hub.load('huggingface/pytorch-transformers', 'config', 'albert-large-v2')
  config.num_labels = 2
  assert type(config) == AlbertConfig
  print(config)

  # Sequence-classification model built from that configuration
  # (torch.hub downloads the weights and caches them).
  model = torch.hub.load(
      'huggingface/pytorch-transformers',
      'modelForSequenceClassification',
      'albert-large-v2',
      config=config,
  )
  assert type(model) == AlbertForSequenceClassification

  # Place the model on the device selected in the GPU set-up cell.
  model.to(device)

  # Matching ALBERT tokenizer; its vocabulary is downloaded and cached as well.
  tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'albert-large-v2')

  # Per the original notes: the last positional argument may instead be a local directory
  # written by `save_pretrained(...)`, and extra keyword arguments such as
  # `output_attentions=True` update the configuration while loading.


In [None]:
if MULTILABEL_INTENT:
  # Log the name and shape of the last four named parameters of the model.
  print('\n==== Output Layer ====\n')

  params = list(model.named_parameters())
  for param_name, param_tensor in params[-4:]:
      shape_text = str(tuple(param_tensor.size()))
      print("{:<55} {:>12}".format(param_name, shape_text))

### Pre-trained ALBERT v2 model on SQuAD2
https://huggingface.co/elgeish/cs224n-squad2.0-albert-xxlarge-v1, including:
- config.json, 721.0B
- pytorch_model.bin, 849.2MB
- special_tokens_map.json, 156.0B
- spiece.model, 742.5KB
- tokenizer_config.json, 39.0B

In [None]:
if QA_EXTRACTION and CAPTUM:
  # Captum is only required when attribution is enabled, so it is installed on demand.
  !pip install captum
  # Visualisation helpers, attribution algorithms, and the hooks that swap a model's
  # embedding layer for an interpretable one (and restore it afterwards).
  from captum.attr import visualization as viz
  from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
  from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

In [None]:
if QA_EXTRACTION:
  from transformers import AutoTokenizer, AutoModelForQuestionAnswering

  # Tokenizer and QA model both come from the same SQuAD2-fine-tuned ALBERT
  # checkpoint; the tokenizer is loaded first, then the model.
  tokenizer, model = (
      loader.from_pretrained("elgeish/cs224n-squad2.0-albert-xxlarge-v1")
      for loader in (AutoTokenizer, AutoModelForQuestionAnswering)
  )

In [None]:
if QA_EXTRACTION:
  # Report the maximum input length the tokenizer is configured with.
  print('Tokenizer max input length {}'.format(tokenizer.model_max_length))

## Load Training Data

### Load ALBERT multi-label training data
Some example datasets:
-  The [Corpus of Linguistic Acceptability (CoLA)](https://nyu-mll.github.io/CoLA/) dataset for single sentence classification, which is about sentences labeled as grammatically correct or incorrect - part of GLUE benchmark
  * raw/in_domain_train.tsv (8551 lines), split further to 9:1 as train/validation
  * raw/in_domain_dev.tsv (527 lines)

In [None]:
if MULTILABEL_INTENT:
  import wget
  import os

  print('Downloading dataset...')

  # The URL for the dataset zip file (CoLA public release 1.1).
  url = 'https://nyu-mll.github.io/CoLA/cola_public_1.1.zip'

  # Download the file only when no local copy exists, so re-running the cell
  # does not fetch the archive again.
  if not os.path.exists('./cola_public_1.1.zip'):
      wget.download(url, './cola_public_1.1.zip')

In [None]:
if MULTILABEL_INTENT:
  # Unzip the dataset only when the extracted `cola_public/` directory is absent.
  # NOTE: relies on `os` having been imported by the download cell.
  if not os.path.exists('./cola_public/'):
      !unzip cola_public_1.1.zip

In [None]:
if MULTILABEL_INTENT:
  import pandas as pd

  # Load the tab-separated training split; the file has no header row, so the
  # four column names are supplied here.
  df = pd.read_csv(
      "./cola_public/raw/in_domain_train.tsv",
      delimiter='\t',
      header=None,
      names=['sentence_source', 'label', 'label_notes', 'sentence'],
  )

  # Report the number of sentences.
  print('Number of training sentences: {:,}\n'.format(df.shape[0]))

  # Show 10 random rows. A bare expression nested inside an `if` block is NOT
  # auto-rendered by the notebook (only a top-level last expression is), so the
  # sample has to be passed to `display` (provided by the IPython kernel) explicitly.
  display(df.sample(10))

In [None]:
if MULTILABEL_INTENT:
  # Pull the raw sentence strings and their labels out of the dataframe as arrays.
  sentences, labels = df["sentence"].values, df["label"].values

In [None]:
# if MULTILABEL_INTENT:
#   from transformers import BertTokenizer

#   # Load the BERT tokenizer.
#   print('Loading BERT tokenizer...')
#   tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

In [None]:
if MULTILABEL_INTENT:
  # Token count of every sentence once the special tokens have been added.
  token_counts = [
      len(tokenizer.encode(sent, add_special_tokens=True))
      for sent in sentences
  ]

  # Longest encoded sentence (0 when there are no sentences at all).
  max_len = max(token_counts, default=0)

  print('Max sentence length: ', max_len)

In [None]:
if MULTILABEL_INTENT:
  # Tokenize all of the sentences and map the tokens to their word IDs.
  input_ids = []
  attention_masks = []

  for sent in sentences:
      # `encode_plus` tokenizes the sentence, adds the model's special tokens,
      # maps tokens to ids, pads/truncates to `max_length` and builds the
      # attention mask that separates real tokens from padding.
      #
      # `padding='max_length'` + `truncation=True` is the explicit spelling of
      # what the deprecated `pad_to_max_length=True` (plus implicit truncation)
      # used to request; it avoids the deprecation / "truncation was not
      # explicitly activated" warnings of transformers 3.x and later.
      encoded_dict = tokenizer.encode_plus(
          sent,                          # Sentence to encode.
          add_special_tokens=True,       # Add the special start / separator tokens.
          max_length=64,                 # Fixed length of every encoded sentence.
          padding='max_length',          # Pad shorter sentences up to max_length.
          truncation=True,               # Cut longer sentences down to max_length.
          return_attention_mask=True,    # Construct attention masks.
          return_tensors='pt',           # Return PyTorch tensors.
      )

      # Collect the encoded sentence and its attention mask.
      input_ids.append(encoded_dict['input_ids'])
      attention_masks.append(encoded_dict['attention_mask'])

  # Stack the per-sentence (1, 64) tensors into single tensors and convert the labels.
  input_ids = torch.cat(input_ids, dim=0)
  attention_masks = torch.cat(attention_masks, dim=0)
  labels = torch.tensor(labels)

  # Print sentence 0, now as a list of IDs.
  print('Original: ', sentences[0])
  print('Token IDs:', input_ids[0])
  print('Labels:', labels[0])
  print('Attention Masks:', attention_masks[0])


#### Label structure expected by [AlbertForSequenceClassification](https://huggingface.co/transformers/model_doc/albert.html#albertforsequenceclassification)
labels (torch.**LongTensor** of **shape (batch_size,)**, optional, defaults to None)
- Labels for computing the sequence classification/regression loss. Indices should be in [0, ..., config.num_labels - 1]. 
- If config.num_labels == 1, a regression loss is computed (Mean-Square loss). If config.num_labels > 1, a classification loss is computed (Cross-Entropy).

In [None]:
 # `labels` only exists when the multi-label cells ran, so this check must be
 # guarded by the same flag; unguarded it raises NameError with the default
 # configuration (MULTILABEL_INTENT = False).
 if MULTILABEL_INTENT:
   # Label data is grammar correctness: one int64 class id (0 or 1) per sentence.
   # Tensor-level dtype/min/max avoid iterating the tensor element by element
   # with Python's built-in min()/max().
   assert labels.dtype == torch.int64
   assert labels.min() == 0 and labels.max() == 1

In [None]:
if MULTILABEL_INTENT:
  from torch.utils.data import TensorDataset, random_split

  # Bundle ids, masks and labels so that they are split together.
  dataset = TensorDataset(input_ids, attention_masks, labels)

  # 90% of the samples go to training; whatever remains is used for validation.
  train_size = int(0.9 * len(dataset))
  val_size = len(dataset) - train_size

  # Random partition into the two subsets.
  train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

  print(f'{train_size:>5,} training samples')
  print(f'{val_size:>5,} validation samples')

In [None]:
if MULTILABEL_INTENT:
  from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

  # Batch size shared by training and validation. For fine-tuning BERT on a
  # specific task, the authors recommend a batch size of 16 or 32.
  batch_size = 32

  # Training batches are drawn in random order.
  train_dataloader = DataLoader(
      train_dataset,
      sampler=RandomSampler(train_dataset),
      batch_size=batch_size,
  )

  # Validation order does not matter, so the samples are read sequentially.
  validation_dataloader = DataLoader(
      val_dataset,
      sampler=SequentialSampler(val_dataset),
      batch_size=batch_size,
  )

## Fine-tuning

### Fine-tune the ALBERT multi-label classification

In [None]:
if MULTILABEL_INTENT:
  from transformers import AdamW

  # Optimizer set-up. This AdamW is the HuggingFace implementation (as opposed
  # to PyTorch's); the original author believed the 'W' refers to the
  # "Weight Decay fix".
  #   lr:  args.learning_rate - default is 5e-5, this notebook uses 2e-5
  #   eps: args.adam_epsilon  - default is 1e-8
  optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)

In [None]:
if MULTILABEL_INTENT:
  from transformers import get_linear_schedule_with_warmup

  # Number of passes over the training set. The BERT authors recommend 2 to 4;
  # 4 is used here, although that may over-fit the training data.
  epochs = 4

  # One scheduler step is taken per batch, so the schedule spans
  # (batches per epoch) x (epochs) steps - not the number of training samples.
  total_steps = epochs * len(train_dataloader)

  # Linear learning-rate schedule without warm-up (0 is the run_glue.py default).
  scheduler = get_linear_schedule_with_warmup(
      optimizer,
      num_warmup_steps=0,
      num_training_steps=total_steps,
  )

### Training Loop

In [None]:
if MULTILABEL_INTENT:
  import numpy as np

  # Function to calculate the accuracy of our predictions vs labels
  def flat_accuracy(preds, labels):
      """Return the fraction of rows whose arg-max prediction equals the label.

      preds:  array of per-class scores; the arg-max is taken along axis 1.
      labels: array of expected class ids; it is flattened before comparing.
      """
      predicted = np.argmax(preds, axis=1).flatten()
      expected = labels.flatten()
      hits = predicted == expected
      return np.sum(hits) / len(expected)

In [None]:
if MULTILABEL_INTENT:
  import time
  import datetime

  def format_time(elapsed):
      """Format a duration given in seconds as an `h:mm:ss` string.

      The value is rounded to the nearest whole second first; the text is the
      string form of the corresponding `datetime.timedelta`.
      """
      whole_seconds = int(round(elapsed))
      duration = datetime.timedelta(seconds=whole_seconds)
      return str(duration)

In [None]:
if MULTILABEL_INTENT:
  import random
  import numpy as np

  # This training code is based on the `run_glue.py` script here:
  # https://github.com/huggingface/transformers/blob/5bfcd0485ece086ebcbed2d008813037968a9e58/examples/run_glue.py#L128

  # Seed every random number generator in play to make the run reproducible.
  seed_val = 42

  random.seed(seed_val)
  np.random.seed(seed_val)
  torch.manual_seed(seed_val)
  torch.cuda.manual_seed_all(seed_val)

  # One dict per epoch: training/validation loss, validation accuracy, timings.
  training_stats = []

  # Measure the total training time for the whole run.
  total_t0 = time.time()

  for epoch_i in range(0, epochs):

      # ========================================
      #               Training
      # ========================================
      # Perform one full pass over the training set.

      print("")
      print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
      print('Training...')

      # Measure how long the training epoch takes.
      t0 = time.time()

      # Reset the total loss for this epoch.
      total_train_loss = 0

      # `train()` only switches the *mode* (dropout etc. behave differently in
      # training vs. evaluation); it does not perform any training by itself.
      model.train()

      for step, batch in enumerate(train_dataloader):
          # Each batch holds three tensors: input ids, attention masks, labels.
          # The message has to follow a comma *outside* any parentheses:
          # `assert(cond, msg)` asserts a non-empty tuple and can never fail.
          assert len(batch) == 3, 'each batch contains 3 tensors: ids, masks and labels'

          # Progress update every 40 batches.
          if step % 40 == 0 and not step == 0:
              # Elapsed time since the epoch started, formatted as h:mm:ss.
              elapsed = format_time(time.time() - t0)
              print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))

          # Unpack the batch and copy each tensor to the selected device:
          #   [0]: input ids, [1]: attention masks, [2]: labels
          b_input_ids = batch[0].to(device)
          b_input_mask = batch[1].to(device)
          b_labels = batch[2].to(device)

          # Clear previously accumulated gradients; PyTorch does not do this
          # automatically before a backward pass.
          model.zero_grad()

          # Forward pass. Because labels are supplied, the first two outputs are
          # the loss and the logits (model outputs prior to activation).
          # They are read by index instead of tuple-unpacking so the code works
          # both with plain-tuple outputs and with the dict-like output objects
          # returned by newer transformers releases (unpacking those yields keys).
          outputs = model(b_input_ids,
                          token_type_ids=None,
                          attention_mask=b_input_mask,
                          labels=b_labels)
          loss, logits = outputs[0], outputs[1]

          # `loss` is a single-value tensor; `.item()` extracts the Python number
          # so the average can be computed at the end of the epoch.
          total_train_loss += loss.item()

          # Backward pass to calculate the gradients.
          loss.backward()

          # Clip the gradient norm to 1.0 to help prevent "exploding gradients".
          torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

          # Update the parameters from the computed gradients, then advance the
          # learning-rate schedule by one step.
          optimizer.step()
          scheduler.step()

      # Calculate the average loss over all of the batches.
      avg_train_loss = total_train_loss / len(train_dataloader)

      # Measure how long this epoch took.
      training_time = format_time(time.time() - t0)

      print("")
      print("  Average training loss: {0:.2f}".format(avg_train_loss))
      print("  Training epcoh took: {:}".format(training_time))

      # ========================================
      #               Validation
      # ========================================
      # After each training epoch, measure performance on the validation set.

      print("")
      print("Running Validation...")

      t0 = time.time()

      # Evaluation mode: the dropout layers behave differently during evaluation.
      model.eval()

      # Tracking variables
      total_eval_accuracy = 0
      total_eval_loss = 0
      nb_eval_steps = 0

      for batch in validation_dataloader:

          # Same batch layout as in training: ids, masks, labels.
          b_input_ids = batch[0].to(device)
          b_input_mask = batch[1].to(device)
          b_labels = batch[2].to(device)

          # No compute graph is needed for evaluation (it only serves backprop).
          with torch.no_grad():
              # token_type_ids are the "segment ids" that differentiate
              # sentence 1 and 2 in two-sentence tasks; not used here.
              outputs = model(b_input_ids,
                              token_type_ids=None,
                              attention_mask=b_input_mask,
                              labels=b_labels)
              # Index access for the same tuple / output-object reason as above.
              loss, logits = outputs[0], outputs[1]

          # Accumulate the validation loss.
          total_eval_loss += loss.item()

          # Move logits and labels to CPU as numpy arrays.
          logits = logits.detach().cpu().numpy()
          label_ids = b_labels.to('cpu').numpy()

          # Accumulate the per-batch accuracy over all batches.
          total_eval_accuracy += flat_accuracy(logits, label_ids)

      # Report the final accuracy for this validation run (mean of batch accuracies).
      avg_val_accuracy = total_eval_accuracy / len(validation_dataloader)
      print("  Accuracy: {0:.2f}".format(avg_val_accuracy))

      # Calculate the average loss over all of the batches.
      avg_val_loss = total_eval_loss / len(validation_dataloader)

      # Measure how long the validation run took.
      validation_time = format_time(time.time() - t0)

      print("  Validation Loss: {0:.2f}".format(avg_val_loss))
      print("  Validation took: {:}".format(validation_time))

      # Record all statistics from this epoch.
      training_stats.append(
          {
              'epoch': epoch_i + 1,
              'Training Loss': avg_train_loss,
              'Valid. Loss': avg_val_loss,
              'Valid. Accur.': avg_val_accuracy,
              'Training Time': training_time,
              'Validation Time': validation_time
          }
      )

  print("")
  print("Training complete!")

  print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_t0)))

In [None]:
if MULTILABEL_INTENT:
  import pandas as pd

  # Display floats with two decimal places. The fully-qualified option key is
  # used because the short form 'precision' is ambiguous / no longer accepted
  # by newer pandas releases.
  pd.set_option('display.precision', 2)

  # One row per epoch, built from the statistics collected by the training loop.
  df_stats = pd.DataFrame(data=training_stats)

  # Use the 'epoch' as the row index.
  df_stats = df_stats.set_index('epoch')

  # Display the table.
  print(df_stats)

## Captum for QA Inference

In [None]:
if QA_EXTRACTION and CAPTUM:
  # Special-token ids used when building inputs and their baselines.
  ref_token_id, sep_token_id, cls_token_id = (
      tokenizer.pad_token_id,  # padding token: used for generating the token reference
      tokenizer.sep_token_id,  # separator between question and text, also appended after the text
      tokenizer.cls_token_id,  # token prepended to the concatenated question-text sequence
  )

We define a set of helper functions for **constructing references / baselines** for 
- word tokens
- token types and 
- position ids

We also provide separate helper functions for each sub-(word)-embedding of the `BertEmbeddings` layer; they construct
- the sub-(word)-embeddings and ...
- corresponding baselines / references

In [None]:
if QA_EXTRACTION and CAPTUM:
  def construct_input_ref_pair(question, text, ref_token_id, sep_token_id, cls_token_id):
      """Build the model input ids and the matching baseline ids for a QA pair.

      The input is laid out as `[CLS] question [SEP] text [SEP]`. The baseline
      keeps the special tokens in place and replaces every question/text token
      with `ref_token_id`.

      Returns a 3-tuple:
        - input ids as a (1, seq_len) tensor on `device`
        - baseline ids as a (1, seq_len) tensor on `device`
        - the index of the first separator token inside the sequence
      """
      question_ids = tokenizer.encode(question, add_special_tokens=False)
      text_ids = tokenizer.encode(text, add_special_tokens=False)

      # construct input token ids
      input_ids = [cls_token_id] + question_ids + [sep_token_id] + text_ids + [sep_token_id]

      # construct reference token ids
      ref_input_ids = [cls_token_id] + [ref_token_id] * len(question_ids) + [sep_token_id] + \
          [ref_token_id] * len(text_ids) + [sep_token_id]

      # Position 0 is [CLS] and the question occupies 1..len(question_ids), so the
      # first [SEP] sits at len(question_ids) + 1. Returning len(question_ids)
      # (as before) made the token-type helper assign segment 1 to that [SEP],
      # whereas the notebook's own `segments_ids` construction puts it in segment 0.
      sep_index = len(question_ids) + 1

      return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), sep_index

  def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
      """Build segment (token-type) ids and an all-zero baseline of the same shape.

      Positions up to and including `sep_ind` get type 0, later positions type 1.
      Both tensors have one row and `input_ids.size(1)` columns and live on `device`.
      """
      seq_len = input_ids.size(1)
      segment_row = [int(position > sep_ind) for position in range(seq_len)]
      token_type_ids = torch.tensor([segment_row], device=device)
      ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)
      return token_type_ids, ref_token_type_ids

  def construct_input_ref_pos_id_pair(input_ids):
      """Build position ids 0..seq_len-1 and an all-zero baseline, both shaped like `input_ids`."""
      seq_length = input_ids.size(1)

      # Real positions: 0, 1, 2, ... broadcast to the shape of the input.
      position_ids = (
          torch.arange(seq_length, dtype=torch.long, device=device)
          .unsqueeze(0)
          .expand_as(input_ids)
      )

      # Baseline: every position set to 0 (a random permutation via
      # `torch.randperm(seq_length, device=device)` would be another option).
      ref_position_ids = (
          torch.zeros(seq_length, dtype=torch.long, device=device)
          .unsqueeze(0)
          .expand_as(input_ids)
      )
      return position_ids, ref_position_ids
      
  def construct_attention_mask(input_ids):
      """Return a mask of ones with the same shape, dtype and device as `input_ids`."""
      return torch.full_like(input_ids, 1)

  def construct_bert_sub_embedding(input_ids, ref_input_ids,
                                    token_type_ids, ref_token_type_ids,
                                    position_ids, ref_position_ids):
      """Embed inputs and baselines separately for each of the three sub-embeddings.

      Uses the module-level `interpretable_embedding1/2/3` objects (expected to be
      configured elsewhere in the notebook) for word, token-type and position ids.

      Returns three `(input_embedding, reference_embedding)` pairs, in that order.
      """
      word_pair = (
          interpretable_embedding1.indices_to_embeddings(input_ids),
          interpretable_embedding1.indices_to_embeddings(ref_input_ids),
      )
      token_type_pair = (
          interpretable_embedding2.indices_to_embeddings(token_type_ids),
          interpretable_embedding2.indices_to_embeddings(ref_token_type_ids),
      )
      position_pair = (
          interpretable_embedding3.indices_to_embeddings(position_ids),
          interpretable_embedding3.indices_to_embeddings(ref_position_ids),
      )
      return word_pair, token_type_pair, position_pair
      
  def construct_whole_bert_embeddings(input_ids, ref_input_ids,
                                      token_type_ids=None, ref_token_type_ids=None,
                                      position_ids=None, ref_position_ids=None):
      """Embed inputs and baselines through the whole embedding layer.

      Uses the module-level `interpretable_embedding` object (expected to be
      configured elsewhere in the notebook).

      NOTE(review): `ref_token_type_ids` and `ref_position_ids` are accepted but
      not used - the baseline is embedded with the *input* token types and
      positions, so only the word ids differ. Confirm this is intended.

      Returns `(input_embeddings, ref_input_embeddings)`.
      """
      shared_kwargs = dict(token_type_ids=token_type_ids, position_ids=position_ids)
      input_embeddings = interpretable_embedding.indices_to_embeddings(input_ids, **shared_kwargs)
      ref_input_embeddings = interpretable_embedding.indices_to_embeddings(ref_input_ids, **shared_kwargs)
      return input_embeddings, ref_input_embeddings

## Inference
Source: https://pytorch.org/hub/huggingface_pytorch-transformers/


In [None]:
if QA_EXTRACTION and not CAPTUM:
  # Interactive question answering without attribution.
  # Example input:
  #   fact = "Jim Henson was a puppeteer"
  #   q = "Who was Jim Henson ?"
  fact = input('Fact: ')
  q = input('Question: ')

  # Convert each string to a sequence of ids, including the model-specific
  # special tokens (sequence start, separator, ...).
  indexed_tokens_fact = tokenizer.encode(fact, add_special_tokens=True)
  indexed_tokens_q = tokenizer.encode(q, add_special_tokens=True)

  # "Pipeline way": question first, then the fact with its leading start token
  # dropped. Segment 0 covers the question part, segment 1 the fact part.
  # (The alternative, fact first and question second, is NOT what is used here.)
  fact_tail = indexed_tokens_fact[1:]
  indexed_tokens = indexed_tokens_q + fact_tail
  # Fixed: this previously referenced the undefined name `indexed_tokens_ctx`.
  segments_ids = [0] * len(indexed_tokens_q) + [1] * len(fact_tail)

  print(f'Length is {len(indexed_tokens)}, Indexed token: {indexed_tokens}')

  # Batch of one for the model.
  segments_tensors = torch.tensor([segments_ids])
  tokens_tensor = torch.tensor([indexed_tokens])

  # Predict the start and end position logits. The outputs are read by index so
  # this works with plain-tuple outputs as well as with the dict-like output
  # objects of newer transformers releases (tuple-unpacking those yields keys).
  with torch.no_grad():
      outputs = model(tokens_tensor, token_type_ids=segments_tensors)
  start_logits, end_logits = outputs[0], outputs[1]

  # Decode the span between the highest-scoring start and end positions.
  answer = tokenizer.decode(indexed_tokens[torch.argmax(start_logits):torch.argmax(end_logits)+1])
  print(f'The answer is: {answer}')

  # Passing `start_positions` / `end_positions` to the model additionally yields
  # the total loss (sum of the cross-entropy losses for start and end position);
  # put the model in train mode first if that is used for training.

In [None]:
if QA_EXTRACTION and CAPTUM:
  # Interactive question answering plus the inputs/baselines Captum needs.
  # Example input:
  #   fact = "Jim Henson was a puppeteer"
  #   q = "Who was Jim Henson ?"
  fact = input('Fact: ')
  q = input('Question: ')

  # Convert each string to a sequence of ids, including the model-specific
  # special tokens (sequence start, separator, ...).
  indexed_tokens_fact = tokenizer.encode(fact, add_special_tokens=True)
  indexed_tokens_q = tokenizer.encode(q, add_special_tokens=True)

  # "Pipeline way": question first, then the fact with its leading start token
  # dropped. Segment 0 covers the question part, segment 1 the fact part.
  fact_tail = indexed_tokens_fact[1:]
  indexed_tokens = indexed_tokens_q + fact_tail
  # Fixed: this previously referenced the undefined name `indexed_tokens_ctx`.
  segments_ids = [0] * len(indexed_tokens_q) + [1] * len(fact_tail)

  # Captum way: inputs and baselines for word ids, token types and positions.
  # Fixed: these calls previously used the undefined names `question`, `text`,
  # `input_ids` and `sep_id` instead of this cell's own `q`, `fact`,
  # `captum_input_ids` and `captum_sep_id`.
  captum_input_ids, captum_ref_input_ids, captum_sep_id = construct_input_ref_pair(q, fact, ref_token_id, sep_token_id, cls_token_id)
  captum_token_type_ids, captum_ref_token_type_ids = construct_input_ref_token_type_pair(captum_input_ids, captum_sep_id)
  captum_position_ids, captum_ref_position_ids = construct_input_ref_pos_id_pair(captum_input_ids)
  captum_attention_mask = construct_attention_mask(captum_input_ids)

  # Token strings of the Captum input.
  indices = captum_input_ids[0].detach().tolist()
  all_tokens = tokenizer.convert_ids_to_tokens(indices)

  print(f'Length is {len(indexed_tokens)}, Indexed token: {indexed_tokens}')

  # Batch of one for the model.
  segments_tensors = torch.tensor([segments_ids])
  tokens_tensor = torch.tensor([indexed_tokens])

  # Predict the start and end position logits. The outputs are read by index so
  # this works with plain-tuple outputs as well as with the dict-like output
  # objects of newer transformers releases (tuple-unpacking those yields keys).
  with torch.no_grad():
      outputs = model(tokens_tensor, token_type_ids=segments_tensors)
  start_logits, end_logits = outputs[0], outputs[1]

  # Decode the span between the highest-scoring start and end positions.
  answer = tokenizer.decode(indexed_tokens[torch.argmax(start_logits):torch.argmax(end_logits)+1])
  print(f'The answer is: {answer}')

  # Passing `start_positions` / `end_positions` to the model additionally yields
  # the total loss (sum of the cross-entropy losses for start and end position);
  # put the model in train mode first if that is used for training.