In [1]:
import string

import pickle as pkl

import numpy as np
import torch

from transformers import AutoModelForMultipleChoice, AutoTokenizer
from datasets import load_dataset

from src.utils_multiple_choice import convert_examples_to_features, InputExample

from src.bertviz.bertviz import head_view_question

In [2]:
def softmax(x, axis=None):
    """Compute softmax values for the scores in ``x``.

    Parameters
    ----------
    x : array_like
        Scores (logits). Anything ``np.asarray`` accepts, including plain
        Python lists.
    axis : int or None, optional
        Axis along which the scores are normalised. The default ``None``
        normalises over *all* elements at once (the historical behaviour of
        this helper). Pass ``axis=-1`` to treat every row of a 2-D input as
        an independent set of scores.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``x`` whose entries sum to 1 along
        ``axis`` (or over the whole array when ``axis`` is ``None``).
    """
    scores = np.asarray(x)
    # Subtracting the maximum does not change the result mathematically but
    # keeps np.exp from overflowing on large scores.
    e_x = np.exp(scores - np.max(scores, axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)

In [3]:
%%javascript
// Tell RequireJS where to load the d3 and jquery modules from.
// Both entries are protocol-relative URLs (no scheme), so the browser reuses
// the scheme of the page that hosts the notebook. RequireJS appends ".js".
// NOTE(review): presumably these are the libraries the bertviz head view
// loads later in the notebook - confirm against src/bertviz.
require.config({
  paths: {
      // d3 pinned to 3.4.8 on the cdnjs CDN.
      d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min',
      // jQuery pinned to 2.0.0 on the Google hosted-libraries CDN.
      jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
  }
});

<IPython.core.display.Javascript object>

In [5]:
# Multiple-choice model and its tokenizer, both read from the local assets
# directory (relative to the notebook's working directory).
model = AutoModelForMultipleChoice.from_pretrained("../assets/models/bb_race_m/")
tokenizer = AutoTokenizer.from_pretrained("../assets/models/bb_race_m")

# RACE dataset, "middle" configuration; keep a handle on its test split.
dataset = load_dataset("race", "middle")
test = dataset["test"]

# Labels as the feature converter receives them: the strings "0".."3".
label_list = [str(choice_idx) for choice_idx in range(4)]
# Predicted index -> answer letter, used when printing results.
label_map = dict(enumerate("ABCD"))
# Sequence length handed to convert_examples_to_features.
max_seq_length = 128

Reusing dataset race (/home/marcos/.cache/huggingface/datasets/race/middle/0.1.0/a7d1fac780e70c0e75bca35e9f2f8cfc1411edd18ffd6858ddce56f70dfb1e7c)


In [18]:
def predict(article, question, options, real_label=None, return_result=False):
    """Run the multiple-choice model on one (article, question, options) example.

    Parameters
    ----------
    article : str
        Passage text. It is paired with each of the four options as context.
    question : str
        Question text.
    options : sequence of str
        Candidate answers; exactly the first four entries are used (an
        ``IndexError`` is raised when fewer are supplied).
    real_label : str, optional
        Gold answer as a letter (``"A"``..``"D"``). It is converted to the
        index string stored on the ``InputExample``; when omitted the
        placeholder ``"0"`` is used. The label is not passed to the model.
    return_result : bool, optional
        When true, return the full model output (requested with
        ``output_attentions``, ``output_hidden_states`` and ``return_dict``)
        instead of the predicted index.

    Returns
    -------
    The model output object when ``return_result`` is true; otherwise the
    index (numpy integer) of the option with the highest logit.
    """
    examples = [InputExample(
        example_id="pred",
        question=question,
        contexts=[article] * 4,  # this is not efficient but convenient
        endings=[options[i] for i in range(4)],
        label=str(ord(real_label) - ord("A")) if real_label else "0"
    )]

    feature = convert_examples_to_features(
        examples,
        label_list,
        max_seq_length,
        tokenizer
    )[0]

    # Leading batch dimension of 1; the inner dimension holds the four choices.
    features = {
        'input_ids': torch.tensor([feature.input_ids]),
        'attention_mask': torch.tensor([feature.attention_mask]),
        'token_type_ids': torch.tensor([feature.token_type_ids]),
    }

    # Inference only: no autograd graph is needed. Inputs are passed by keyword
    # so the call does not depend on the positional order of a particular
    # architecture's forward() signature.
    with torch.no_grad():
        if return_result:
            return model(**features, output_attentions=True,
                         output_hidden_states=True, return_dict=True)
        # First output is the logits; [0] selects the single example in the batch.
        logits = model(**features)[0][0]

    # Pick the largest raw logit. Taking absolute values first would let a
    # strongly negative (i.e. least preferred) option win the argmax.
    return logits.cpu().numpy().argmax()

In [19]:
def _count_trailing(tokens, skippable):
    """Return how many consecutive tokens at the end of ``tokens`` are in ``skippable``."""
    count = 0
    for tok in reversed(tokens):
        if tok not in skippable:
            break
        count += 1
    return count


def show_head_view(ex):
    """Display the attention head view for the four options of one example.

    Parameters
    ----------
    ex : mapping
        Must provide ``'question'``, ``'article'`` and ``'options'`` (at least
        four entries); ``'answer'`` (a letter ``"A"``..``"D"``) is optional.

    Returns
    -------
    tuple of three dicts, each keyed by ``'a'``, ``'b'``, ``'c'``, ``'d'``:
        * attentions - one tensor per model layer, with a leading dimension of
          1 and both sequence axes cut to start at the first token whose
          token type id is 1;
        * tokens - the tokens from that same position to the end of the input;
        * option starts - offset inside those tokens at which the option text
          is taken to begin.

    The same three dicts are passed to ``head_view_question`` before returning
    (presumably this renders the bertviz view - see src/bertviz).
    """
    question = ex['question']
    article = ex['article']
    options = ex['options']
    answer = ex.get('answer', None)
    examples = [InputExample(
        example_id="pred",
        question=question,
        contexts=[article] * 4,  # this is not efficient but convenient
        endings=[options[i] for i in range(4)],
        label=str(ord(answer) - ord("A")) if answer else "0"
    )]

    feature = convert_examples_to_features(
        examples,
        label_list,
        max_seq_length,
        tokenizer
    )[0]

    features = {
        'input_ids': torch.tensor([feature.input_ids]),
        'attention_mask': torch.tensor([feature.attention_mask]),
        'token_type_ids': torch.tensor([feature.token_type_ids]),
    }

    # A single forward pass yields the attentions of all four choices; the
    # inputs are identical for every option, so there is no reason to repeat it.
    with torch.no_grad():
        attention = model(**features, output_attentions=True,
                          return_dict=True)['attentions']

    # Tokens that may follow the option text at the end of the sequence.
    # Padding is skipped too, so inputs shorter than max_seq_length do not
    # shift the computed option start into the padded tail.
    skippable = set(string.punctuation) | {"[SEP]"}
    pad_token = getattr(tokenizer, "pad_token", None)
    if pad_token is not None:
        skippable.add(pad_token)

    atts = {}
    tokens_ = {}
    option_starts = {}
    for idx, key in enumerate("abcd"):
        tokens = tokenizer.convert_ids_to_tokens(feature.input_ids[idx])
        # First position of the second segment (token type id 1).
        q_start = feature.token_type_ids[idx].index(1)
        segment_tokens = tokens[q_start:]
        idx_end = _count_trailing(segment_tokens, skippable)

        tokens_[key] = segment_tokens
        option_starts[key] = (
            len(segment_tokens) - len(tokenizer.tokenize(options[idx])) - idx_end
        )
        # att[idx] holds this choice's per-head attention for one layer;
        # unsqueeze(0) restores the leading dimension without hard-coding the
        # number of heads or the sequence length.
        atts[key] = [
            att[idx].unsqueeze(0)[:, :, q_start:, q_start:] for att in attention
        ]

    head_view_question(atts, tokens_, option_starts)

    return atts, tokens_, option_starts

### Example 1

In [20]:
# First test example; unpack the fields that the following cells reuse.
ex = test[0]
article, question, options, real_label = (
    ex[field] for field in ("article", "question", "options", "answer")
)

# Predicted option index for the unmodified example, printed as a letter.
result = predict(article, question, options, real_label)
print(
    f"Question: {question}",
    f"Options: {options}",
    f"Result: {label_map[result]}",
    sep="\n",
)

convert examples to features: 1it [00:00, 76.43it/s]


Question: A discipline leader is supposed to  _  .
Options: ['take care of the whole group', 'make sure that everybody finishes homework', 'make sure that nobody chats in class', 'collect all the homework and hand it in to teachers']
Result: C


In [21]:
# Same example again, this time keeping the complete model output
# (return_result=True) instead of the predicted index.
result = predict(
    article=article,
    question=question,
    options=options,
    real_label=real_label,
    return_result=True,
)

convert examples to features: 1it [00:00, 62.73it/s]


##### Modification A

In [22]:
# Modification A: rephrase the fill-in-the-blank question as a direct question
# and keep the original options.
new_question = "What is a discipline leader supposed to?"

result = predict(article, new_question, options, real_label)
print(
    f"Question: {new_question}",
    f"Options: {options}",
    f"Result: {label_map[result]}",
    sep="\n",
)

# Keep the modified example around in the same layout as a dataset row.
ex_mod_a = dict(
    article=article,
    question=new_question,
    options=options,
    answer=real_label,
)

convert examples to features: 1it [00:00, 83.05it/s]


Question: What is a discipline leader supposed to?
Options: ['take care of the whole group', 'make sure that everybody finishes homework', 'make sure that nobody chats in class', 'collect all the homework and hand it in to teachers']
Result: A


##### Modification B

In [23]:
# Modification B: shorten the question and move the "supposed to" phrasing
# into every option.
new_question = "What is a discipline leader?"
new_options = [
    "A person supposed to take care of the whole group",
    "A person supposed to make sure that everybody finished homework",
    "A person supposed to make sure that nobody chats in class",
    "A person supposed to collect all the homework and hand it in to teachers",
]

result = predict(article, new_question, new_options, real_label)
print(
    f"Question: {new_question}",
    f"Options: {new_options}",
    f"Result: {label_map[result]}",
    sep="\n",
)

# Keep the modified example around in the same layout as a dataset row.
ex_mod_b = dict(
    article=article,
    question=new_question,
    options=new_options,
    answer=real_label,
)

convert examples to features: 1it [00:00, 81.57it/s]


Question: What is a discipline leader?
Options: ['A person supposed to take care of the whole group', 'A person supposed to make sure that everybody finished homework', 'A person supposed to make sure that nobody chats in class', 'A person supposed to collect all the homework and hand it in to teachers']
Result: B


##### Modification C

In [24]:
# Modification C: same options as modification B, but the question now says
# "orderliness" instead of "discipline".
new_question = "What is an orderliness leader?"
new_options = [
    "A person supposed to take care of the whole group",
    "A person supposed to make sure that everybody finished homework",
    "A person supposed to make sure that nobody chats in class",
    "A person supposed to collect all the homework and hand it in to teachers",
]

result = predict(article, new_question, new_options, real_label)
print(
    f"Question: {new_question}",
    f"Options: {new_options}",
    f"Result: {label_map[result]}",
    sep="\n",
)

# Keep the modified example around in the same layout as a dataset row.
ex_mod_c = dict(
    article=article,
    question=new_question,
    options=new_options,
    answer=real_label,
)

convert examples to features: 1it [00:00, 84.44it/s]


Question: What is an orderliness leader?
Options: ['A person supposed to take care of the whole group', 'A person supposed to make sure that everybody finished homework', 'A person supposed to make sure that nobody chats in class', 'A person supposed to collect all the homework and hand it in to teachers']
Result: C


###### Visualization

In [33]:
# Show the head view for modification B and keep the per-option attentions,
# tokens and option start offsets for further inspection.
atts, tokens, option_start = show_head_view(ex=ex_mod_b)

convert examples to features: 1it [00:00, 64.03it/s]


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>