# Data exploration

In [2]:
from datasets import load_dataset
from transformers import pipeline
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer
from transformers import DistilBertTokenizerFast
import time
import numpy as np
import pandas as pd
import torch

In [3]:
split = "train"
cache_dir = "./data_cache"

dialogue_dataset = load_dataset(
    "doc2dial",
    name="dialogue_domain",  # this is the name of the dataset for the second subtask, dialog generation
    split=split,
    ignore_verifications=True,
    cache_dir=cache_dir,
)

Reusing dataset doc2dial (./data_cache/doc2dial/dialogue_domain/1.0.1/c15afdf53780a8d6ebea7aec05384432195b356f879aa53a4ee39b740d520642)


In [4]:
document_dataset = load_dataset(
    "doc2dial",
    name="document_domain",  # this is the name of the dataset for the second subtask, dialog generation
    split=split,
    ignore_verifications=True,
    cache_dir=cache_dir,
)

Reusing dataset doc2dial (./data_cache/doc2dial/document_domain/1.0.1/c15afdf53780a8d6ebea7aec05384432195b356f879aa53a4ee39b740d520642)


In [9]:
dialogue_dataset[0]

{'dial_id': '9f44c1539efe6f7e79b02eb1b413aa43',
 'doc_id': 'Top 5 DMV Mistakes and How to Avoid Them#3_0',
 'domain': 'dmv',
 'turns': [{'turn_id': 1,
   'role': 'user',
   'da': 'query_condition',
   'references': [{'sp_id': '4', 'label': 'precondition'}],
   'utterance': 'Hello, I forgot o update my address, can you help me with that?'},
  {'turn_id': 2,
   'role': 'agent',
   'da': 'respond_solution',
   'references': [{'sp_id': '6', 'label': 'solution'},
    {'sp_id': '7', 'label': 'solution'}],
   'utterance': 'hi, you have to report any change of address to DMV within 10 days after moving. You should do this both for the address associated with your license and all the addresses associated with all your vehicles.'},
  {'turn_id': 3,
   'role': 'user',
   'da': 'query_solution',
   'references': [{'sp_id': '56', 'label': 'solution'}],
   'utterance': 'Can I do my DMV transactions online?'},
  {'turn_id': 4,
   'role': 'agent',
   'da': 'respond_solution',
   'references': [{'sp_id

In [8]:
document_dataset[250]

{'domain': 'dmv',
 'doc_id': 'Top 5 DMV Mistakes and How to Avoid Them#3_0',
 'title': 'Top 5 DMV Mistakes and How to Avoid Them#3',
 'doc_text': 'Many DMV customers make easily avoidable mistakes that cause them significant problems, including encounters with law enforcement and impounded vehicles. Because we see customers make these mistakes over and over again , we are issuing this list of the top five DMV mistakes and how to avoid them. \n\n1. Forgetting to Update Address \nBy statute , you must report a change of address to DMV within ten days of moving. That is the case for the address associated with your license, as well as all the addresses associated with each registered vehicle, which may differ. It is not sufficient to only: write your new address on the back of your old license; tell the United States Postal Service; or inform the police officer writing you a ticket. If you fail to keep your address current , you will miss a suspension order and may be charged with operati

In [11]:
search_domain = 'dmv'
search_doc_id = 'Top 5 DMV Mistakes and How to Avoid Them#3_0'
search_id_sp = ['6','7']

def text_from_spans(search_domain, search_doc_id, search_id_sp, document_dataset):
    start = time.time()
    total_answer = ''
    for doc in document_dataset:
        if doc['domain'] == search_domain and doc['doc_id'] == search_doc_id:
            for span in doc['spans']:
                if span['id_sp'] in search_id_sp:
                    total_answer+=span['text_sp']
            break
    print(f"Time elapsed: {time.time() - start}")
    return total_answer

text_from_spans(search_domain, search_doc_id, search_id_sp, document_dataset)

Time elapsed: 0.33464717864990234


'By statute , you must report a change of address to DMV within ten days of moving. That is the case for the address associated with your license, as well as all the addresses associated with each registered vehicle, which may differ. '

## Creating the dataset

Steps:
- [X] Sliding windows from the Document
- [ ] Extract user utterance
- [ ] Extract Dialogue history

### Sliding windows from the Document

In [4]:
# Tokenizer
tokenizer = DistilBertTokenizerFast.from_pretrained('bert-base-uncased')

# Defining train_dict
train_dict = dict()
train_dict['train_document'] = []
train_dict['train_id_sp'] = []
train_dict['train_user_utterance'] = []
train_dict['train_doc_domain'] = []
train_dict['train_doc_id'] = []
train_dict['train_text_sp'] = []
train_dict['train_dial_id_turn_id'] = []     # necessary for evaluation
train_dict['train_start_pos'] = []     
train_dict['train_end_pos'] = []     
train_dict['train_start_tok'] = []     
train_dict['train_end_tok'] = []  

start = time.time()
for idx, dialogue in tqdm(enumerate(dialogue_dataset)):
    if idx == 10:
        break
    dial_id_turn_id = []       # running list of <dial_id>_<turn_id> for evaluation
    sp_id_list = []            # running list of spans per document
    user_utterance_list = []   # running list of user utterances per document
    
    for turn in dialogue['turns']:
        dial_id_turn_id.append(dialogue['dial_id'] + '_' + str(turn['turn_id']))
        if turn['role'] == 'user':
            # TURN UTTERANCE IS FLATTENED AND ONLY THE [INPUT_IDS] IS STORED
            turn['utterance'] = tokenizer(turn['utterance'], padding=True, truncation=True, return_tensors="pt")['input_ids'].view(-1)
            user_utterance_list.append(turn['utterance'])   # adding user utterance to user_utterance_list
        else:
            references = turn['references']
            ref_sp_id = []
            for ref in references:
                ref_sp_id.append(ref['sp_id'])
            sp_id_list.append(ref_sp_id)          # adding list of sp_ids per dialogue to list of sp_ids per document
    train_dict['train_id_sp'].append(sp_id_list)
    train_dict['train_user_utterance'].append(user_utterance_list)
    train_dict['train_doc_domain'].append(dialogue['domain'])
    train_dict['train_doc_id'].append(dialogue['doc_id'])
    train_dict['train_dial_id_turn_id'].append(dial_id_turn_id)
    
    for doc in document_dataset:
        if doc['doc_id'] == train_dict['train_doc_id'][-1]:
            # DOCUMENT TEXT IS NOT A TENSOR. PREVIOUSLY WE HAD tokenizer( )['index_ids'].view(-1)
            doc['doc_text'] = tokenizer(doc['doc_text'], padding=True, truncation=False, return_tensors="pt")
            train_dict['train_document'].append(doc['doc_text'])          # adding the total document text
            text_sp_2 = []            
            start_sp_list = []         # big start sp list
            end_sp_list = []           # big end sp list        
            start_tok_list = []         # big start token list
            end_tok_list = []           # big end token list     
            for train_spans_id in train_dict['train_id_sp'][-1]:    
                text_sp = ""         
                ref_start_pos_list = []
                ref_end_pos_list = []      
                for span in doc['spans']:                    
                    if span['id_sp'] in train_spans_id:
                        text_sp += span['text_sp']                        
                        ref_start_pos_list.append(span['start_sp'])
                        ref_end_pos_list.append(span['end_sp'])    
                start_pos = np.amin(ref_start_pos_list)
                start_sp_list.append(start_pos)
                # convert start_pos to start_token
                start_tok_pos = doc['doc_text'].char_to_token(start_pos)
                start_tok_list.append(start_tok_pos)
                # convert end_pos to end_token
                end_pos = np.amax(ref_end_pos_list)
                end_sp_list.append(end_pos)
                end_tok_pos = doc['doc_text'].char_to_token(end_pos)
                end_tok_list.append(end_tok_pos)
                text_sp_2.append(text_sp)
            train_dict['train_text_sp'].append(text_sp_2)
            train_dict['train_start_pos'].append(start_sp_list)
            train_dict['train_end_pos'].append(end_sp_list)
            train_dict['train_start_tok'].append(start_tok_list)
            train_dict['train_end_tok'].append(end_tok_list)
            break
end = time.time()
print(f'Total time: {end-start}')

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BertTokenizer'. 
The class this function is called from is 'DistilBertTokenizerFast'.
0it [00:00, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
10it [00:03,  3.33it/s]

Total time: 3.030385971069336





Results:

In [None]:
print('User utterances:')
print(train_dict['train_user_utterance'][0])

print('\nID Sp:')
print(train_dict['train_id_sp'][0])

print('\nDoc ID:')
print(train_dict['train_doc_id'][0])

print('\nDoc domain:')
print(train_dict['train_doc_domain'][0])

print('\nTrain text spans:')
print(train_dict['train_text_sp'][0])

print('\nDial_ID Turn_ID:')
print(train_dict['train_dial_id_turn_id'][0])

In [None]:
print('\nDoc text:')
print(train_dict['train_document'][0])

## Create a Dataframe out of the train_dictionary

In [None]:
data = pd.DataFrame(train_dict)

In [None]:
from datasets import load_metric

metric = load_metric("squad_v2")
print(metric.features) #this shows you what format the metric is expecting

prediction = {'id': <rc dataset is of shape dialid_turnid - this value has to match the answer>,
              'prediction_text': <your prediction>,
              'no_answer_probability': 0.0} #edwin said we can ignore this for task 1
reference = {'id': <see prediction>, 
              'answers': {
                  'text': [list of answer, best to use the ones from the rc dataset],                                       
                  'answer_start': [list of numbers of the answer star char again see rc dataset. ]}
            }

metric.add(prediction=prediction, reference=reference)
final_score = metric.compute()
final_score

# Testing model

In [None]:
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

In [11]:
question=train_dict['train_user_utterance'][0][0]
print(f'Decoded question: {tokenizer.decode(question)}')
# If already tokenized from dataset
text=train_dict['train_document'][0]    # tokenized text
# if simple text
#text='By statute , you must report a change of address to DMV within ten days of moving. That is the case for the address associated with your license, as well as all the addresses associated with each registered vehicle, which may differ.'
#text=tokenizer([text],  return_tensors="pt")['input_ids'].view(-1)

def text_mask(question, text):
    '''   
    text['input_ids'].view(-1)[1:] was on the line below where 'text' is now - 
    need to do this to text before sending it into this function
    
    input_ids: will be the question and the window of the document concat together
    segment_ids: is a mask that makes the two sentences distinct 1's for question 0 for document text
    '''
    input_ids=torch.cat((question, text), 0)
    SEP_token_id=102
    sep_idx = (input_ids == 102).nonzero(as_tuple=False)[0][0].item()
    num_seg_a = sep_idx+1
    num_seg_b = len(input_ids) - num_seg_a
    segment_ids = [0]*num_seg_a + [1]*num_seg_b
    assert len(segment_ids) == len(input_ids)
    return input_ids, segment_ids


Decoded question: [CLS] hello, i forgot o update my address, can you help me with that? [SEP]


Create mask for start and end positions. This way we only check the first token after the '.' as start positions, and the tokens before the '.' as end positions.

In [None]:
# token id for '.' = 1012
def mask_start_end(input_ids_trunc, segment_ids_trunc, mode):
    """Returns a mask for the start and end logits. 
    input_ids_trunc = tokens (tensor)
    segment_ids_trunc = mask (question / text)
    mode = "start" or 'end'
    return tensor
    """
    a = torch.where(input_ids_trunc == 1012, 1, 0)   # mask=1 for '.'
    a = a * torch.tensor(segment_ids_trunc)          # mask question - text
    if mode=='start':
        b = torch.cat((torch.tensor([0]),a),0)[:-1]     # move the 1s one position to the right
    else:
        b = torch.cat((a, torch.tensor([0])),0)[1:]
    assert len (a) == len(b)
    return b

def tensor_to_positive(tensor, mask):
    """ All the values need to be higher than 0, since 0s are values for the mask
    and we don't want to choose them when selecting the start or end token.
    Return torch.tensor """
    min_value = torch.amin(tensor) 
    tensor_positive = tensor + (mask * np.abs(min_value.detach().numpy()))
    return tensor_positive

In [None]:
input_ids_trunc = input_ids[:511]
input_ids_trunc = torch.cat((input_ids_trunc, torch.tensor([102])),0)
segment_ids_trunc = segment_ids[:512]

output = model(input_ids_trunc.view(1,-1), token_type_ids=torch.tensor([segment_ids_trunc]))
tokens = tokenizer.convert_ids_to_tokens(input_ids_trunc)   # table with token_id -> word
#tokens with highest start and end scores
mask_start = mask_start_end(input_ids_trunc, segment_ids_trunc, 'start')
start_logits_positive = tensor_to_positive(output.start_logits * mask_start, mask_start)
answer_start = torch.argmax(start_logits_positive)  # token index for the highest start token
max_start_prob = output.start_logits[0][answer_start].item()
mask_end = mask_start_end(input_ids_trunc, segment_ids_trunc, 'end')
end_logits_positive = tensor_to_positive(output.end_logits * mask_end, mask_end)
answer_end = torch.argmax(end_logits_positive)
max_end_prob = output.end_logits[0][answer_end].item()
sum_joint_prob = max_start_prob + max_end_prob
if answer_end >= answer_start:
    answer = " ".join(tokens[answer_start:answer_end+1])
else:
    print("I am unable to find the answer to this question. Can you please ask another question?")
    
print("\nQuestion:\n{}".format(tokenizer.decode(question)))
print("\nAnswer:\n{}.".format(answer))

### Output - For Report

This section of text shows that span [49][50][51][52] is what we return. However, the section belows is what the ground truth says. SPan [51] is highlighted in red (we return 51, ground truth doesn't contain it).

- 'About ten percent of customers visiting a DMV office do not bring what they need to complete their transaction, and have to come back a second time to finish their business. This can be as simple as not bringing sufficient funds to pay for a license renewal or not having the proof of auto insurance required to register a car. <font color='red'>Better yet ,</font> don t visit a DMV office at all, and see if your transaction can be performed online, like an address change, registration renewal, license renewal, replacing a lost title, paying a DRA or scheduling a road test. '
- 'About ten percent of customers visiting a DMV office do not bring what they need to complete their transaction, and have to come back a second time to finish their business. This can be as simple as not bringing sufficient funds to pay for a license renewal or not having the proof of auto insurance required to register a car. don t visit a DMV office at all, and see if your transaction can be performed online, like an address change, registration renewal, license renewal, replacing a lost title, paying a DRA or scheduling a road test. '

In [12]:
def add_sep_tokens(windows):
    tmp = []
    sep_token = 102
    for window in windows:
        end = len(window) - 1
        if window[end] != sep_token:
            # remove final value, add a SEP
            window = window[0:-1]
            tmp.append(torch.cat((window, torch.tensor([102])),0))
        else:
            tmp.append(window)
    return tmp

def sliding_windows(question, document, stride=256):
    # tokenized input_ids is the document - remove [CLS] before sending through
    windows = []
    model_tok_limit = 512  # model can take 512 tokens maximum
    start = 0
    end = model_tok_limit - len(question)
    doc_size = len(document)
        
    # handling edge case of documents smaller than models input (512 tokens)   
    if len(document) <= model_tok_limit:
        end = len(document)
    
    while(start <= doc_size):
        # print(start, end, doc_size)
        window = document[start:end]
        windows.append(window)
        
        if end == doc_size: 
            break
        
        start += stride
        # if there are less tokens than the slide amount
        if (doc_size - (start + stride)) < stride:
            end = doc_size
        else:
            end += stride
    
    windows = add_sep_tokens(windows)
    return windows

In [13]:
''' Sliding Window [1:] to remove the [CLS] that was put in by the tokenizer
    The Model likes '[CLS] Sentence1 [SEP] Sentence2 [SEP]' it doesn't need the [CLS]
'''

windows = sliding_windows(question, text['input_ids'][0][1:])

In [18]:
model_inputs = []

for window in windows:
    model_inputs.append(text_mask(question, window))

In [19]:
for model_input in model_inputs:
    print(tokenizer.decode(model_input[0]))

[CLS] hello, i forgot o update my address, can you help me with that? [SEP] many dmv customers make easily avoidable mistakes that cause them significant problems, including encounters with law enforcement and impounded vehicles. because we see customers make these mistakes over and over again, we are issuing this list of the top five dmv mistakes and how to avoid them. 1. forgetting to update address by statute, you must report a change of address to dmv within ten days of moving. that is the case for the address associated with your license, as well as all the addresses associated with each registered vehicle, which may differ. it is not sufficient to only : write your new address on the back of your old license ; tell the united states postal service ; or inform the police officer writing you a ticket. if you fail to keep your address current, you will miss a suspension order and may be charged with operating an unregistered vehicle and / or aggravated unlicensed operation, both mis