In [None]:
# First we predict answer spans with the QA model; then a separate classifier decides which questions are unanswerable,
# because we found that relying on the QA model's [CLS] output identifies unanswerable questions with low accuracy.

# For the classifier we referred to https://github.com/ThaddeusSegura/BERT_on_SQuAD/blob/master/SE_classification.ipynb

In [None]:
from tqdm.auto import tqdm  # for showing progress bar
from datasets import load_dataset
import json
import pandas as pd
import numpy as np
import torch
from transformers import BertForQuestionAnswering, BertForSequenceClassification
from transformers import BertTokenizerFast, BertTokenizer
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

# Pick the compute device: CUDA first, then Apple MPS, otherwise CPU.
# `torch.backends.mps` does not exist on older PyTorch builds, so it is looked up
# defensively instead of being accessed directly.
mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    print("Use cuda device:", torch.cuda.get_device_name(0))
elif mps_backend is not None and mps_backend.is_available():
    device = torch.device("mps")
    print("use mps")
else:
    device = torch.device('cpu')
    print("use cpu")

# Batch size of the QA DataLoader. It only controls how many examples are moved to the
# device at once; the QA loop further down still feeds them to the model one by one.
batch_size_qa = 128

# Name of the JSON file the predictions are written to.
output_name = "output-dev.json"

# Dataset to run predictions on.
SQuAD = pd.read_csv('dev-v2.0-combined-use.csv')

# Fine-tuned QA checkpoint, loaded strictly from local files.
qa_checkpoint_dir = './bert_qa_pt_3/'
model_qa = BertForQuestionAnswering.from_pretrained(qa_checkpoint_dir, local_files_only=True)

# Tokenizer used for the QA model.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
#tokenizer = BertTokenizerFast.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

In [None]:
def prep_data(dataset):
    """Extract the columns the tokenizer needs from a DataFrame as plain lists.

    Args:
        dataset: DataFrame with 'question', 'context' and 'id' columns.

    Returns:
        dict with keys 'question', 'context' and 'id', each mapping to the
        corresponding column converted to a Python list (row order preserved).
    """
    wanted_columns = ('question', 'context', 'id')
    return {column: dataset[column].values.tolist() for column in wanted_columns}

# Tokenize every (context, question) pair for the QA model and build, per example:
#   input_ids_list       - tensors returned by the tokenizer (one leading batch dimension each)
#   attention_mask_list  - matching attention masks
#   type_ids_list        - hand-built segment ids as plain Python lists
#   question_id_list     - question ids, aligned by position with the lists above
dataset = prep_data(SQuAD)
input_ids_list = []
attention_mask_list = []
type_ids_list = []

for dataset_question, dataset_context in zip(dataset['question'], dataset['context']):
    # One tokenizer call yields everything needed: the ids, the attention mask, and
    # (via the ids) the position of the first [SEP]. No max_length is passed, so
    # padding/truncation use the tokenizer's own maximum length.
    tokenized = tokenizer(text=dataset_context,
                          text_pair=dataset_question,
                          add_special_tokens=True,    # Add `[CLS]` and `[SEP]`
                          truncation=True,
                          return_attention_mask=True,  # Construct attn. masks.
                          padding='max_length',
                          return_tensors='pt')

    # Plain list of token ids for this single example (drop the batch dimension).
    token_ids = tokenized['input_ids'][0].tolist()

    # Segment a = everything up to and including the first [SEP] (the context side).
    sep_index = token_ids.index(tokenizer.sep_token_id)
    segment_a_no = sep_index + 1

    # Segment b = the rest of the sequence. Padding positions are counted as
    # segment b as well; the attention mask is what hides them from the model.
    segment_b_no = len(token_ids) - segment_a_no

    # 0 for segment a, 1 for segment b.
    segment_ids = [0] * segment_a_no + [1] * segment_b_no
    assert len(token_ids) == len(segment_ids)

    type_ids_list.append(segment_ids)
    input_ids_list.append(tokenized['input_ids'])
    attention_mask_list.append(tokenized['attention_mask'])

# Assigned in one step (not appended) so re-running never grows the list, and no
# loop variable named `id` shadows the builtin.
question_id_list = list(dataset['id'])

In [None]:
# Sanity check: every segment-id sequence should have the same length,
# otherwise the sequences cannot be packed into one tensor.
distinct_lengths = {len(segment_ids) for segment_ids in type_ids_list}
message = "All are the same length" if len(distinct_lengths) == 1 else "They are not the same length!"
print(message)

In [None]:
# Question ids, aligned by position with the rows of the dataset.
# Assigned rather than appended: appending here would extend a list that has
# already been filled with the same ids and double its length.
question_id_list = list(dataset['id'])

# Convert the per-example containers into tensors.
input_ids_list = torch.stack(input_ids_list)
attention_mask_list = torch.stack(attention_mask_list)
# Build the segment ids directly as int64 instead of going through a float tensor.
type_ids_list = torch.tensor(type_ids_list, dtype=torch.int64)

# Pack (input ids, attention mask, segment ids) into a DataLoader that keeps dataset order.
data_3_elements = TensorDataset(input_ids_list, attention_mask_list, type_ids_list)
dataloader = DataLoader(data_3_elements, sampler=None, batch_size=batch_size_qa, shuffle=False)

In [None]:
# Move the QA model's parameters onto the selected device.
model_qa.to(device=device)

In [None]:
# Run the QA model over the dataset and decode one answer string per example.
# Examples are moved to the device batch by batch but fed to the model one at a
# time. `answers_list` ends up with exactly one entry per example, in dataset order.
# NOTE: `tokenizer` must still be the QA tokenizer here; the classifier section
# rebinds that name later, so this cell has to run before it.
model_qa.eval()  # inference only
answers_list = []
for batch in tqdm(dataloader):
    batch_input, batch_mask, batch_type = (v.to(device) for v in batch)

    for example_ids, example_mask, example_type in zip(batch_input, batch_mask, batch_type):
        # Present every tensor to the model as a (1, seq_len) batch, whatever
        # per-example shape the dataset stored it in.
        example_ids = example_ids.reshape(1, -1)
        example_mask = example_mask.reshape(1, -1)
        example_type = example_type.reshape(1, -1)

        # No gradients are needed for prediction; skipping them avoids building
        # (and keeping alive) an autograd graph for every example.
        with torch.no_grad():
            result = model_qa(input_ids=example_ids,
                              attention_mask=example_mask,
                              token_type_ids=example_type)

        answer_start = int(torch.argmax(result.start_logits))  # most likely start position
        answer_end = int(torch.argmax(result.end_logits))      # most likely end position

        # Default to an empty answer so that an inverted span (end before start)
        # never reuses the previous example's answer or hits an undefined name.
        answer = ""
        if answer_end >= answer_start:
            tokens = tokenizer.convert_ids_to_tokens(example_ids[0].tolist())
            answer = tokens[answer_start]
            for token in tokens[answer_start + 1:answer_end + 1]:
                if token.startswith("##"):
                    # Word-piece continuation: glue it onto the previous piece.
                    answer += token[2:]
                else:
                    answer += " " + token

        # An answer starting with "[CLS]" is deliberately NOT blanked here:
        # unanswerable questions are decided by the separate classifier.
        answers_list.append(answer)
        

In [None]:
# Drop the QA model so its memory can be reused by the classifier.
# Guarded so that re-running this cell does not raise NameError.
if "model_qa" in globals():
    del model_qa
# `del` only removes the Python reference; ask the CUDA caching allocator to hand
# its now-unused blocks back as well.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Question answering is complete; the classifier now decides which questions get an empty ("") answer.

In [None]:
# Load the dataset for the classifier. Rows are later paired with the QA answers
# purely by position, so this must be the same file (same row order) the QA step read.
classifier_csv_path = "dev-v2.0-combined-use.csv"
classifier_devset = pd.read_csv(classifier_csv_path)

In [None]:

# Load the fine-tuned answerable / unanswerable classifier.
classifier_checkpoint_dir = './bert_qa_classifier_pt_3/'
classifier_model = BertForSequenceClassification.from_pretrained(
    classifier_checkpoint_dir,
    num_labels=2,                  # 0: unanswerable, 1: answerable
    output_attentions=False,       # attention weights are not needed
    output_hidden_states=False,    # hidden states are not needed
)

# NOTE(review): this rebinds `tokenizer` (previously the uncased QA tokenizer) to a
# cased one; confirm that the classifier checkpoint was trained with a cased vocabulary.
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

In [None]:
# Move the classifier's parameters onto the selected device.
classifier_model.to(device=device)

In [None]:
# Tokenize every (question, context) pair for the classifier, padded/truncated to 512 tokens.
input_ids_list, attention_mask_list = [], []

question_context_pairs = zip(classifier_devset['question'], classifier_devset['context'])
for question_text, context_text in question_context_pairs:
    encoding = tokenizer.encode_plus(
        question_text,
        context_text,
        max_length=512,
        padding='max_length',
        truncation=True,
        add_special_tokens=True,      # Add `[CLS]` and `[SEP]`
        return_attention_mask=True,   # Construct attn. masks.
    )
    input_ids_list.append(encoding['input_ids'])
    attention_mask_list.append(encoding['attention_mask'])

In [None]:
classifier_batch_size = 32

# Turn the tokenized lists into tensors and wrap them in a dataset of
# (input ids, attention mask) pairs.
input_ids = torch.tensor(input_ids_list)
attention_masks = torch.tensor(attention_mask_list)
data_2_elements = TensorDataset(input_ids, attention_masks)

# Inference-only loader: no sampler and no shuffling, so dataset order is kept.
dataloader = DataLoader(
    dataset=data_2_elements,
    batch_size=classifier_batch_size,
    shuffle=False,
    sampler=None,
)

In [None]:
# Run the classifier batch by batch and collect the raw logits (as NumPy arrays on
# the CPU), one array per batch. `tqdm` is the one imported at the top of the notebook.
is_answerable_logits_list = []
for batch_ids, batch_mask in tqdm(dataloader):
    with torch.no_grad():
        output = classifier_model(
            input_ids=batch_ids.to(device),
            token_type_ids=None,
            attention_mask=batch_mask.to(device),
        )
    batch_logits = output['logits'].detach().cpu().numpy()
    is_answerable_logits_list.append(batch_logits)

    


In [None]:

# Combine the QA answers with the classifier verdicts.
# Everything is paired by position: the i-th logits row, the i-th answer and the
# i-th question id belong to the same example.
all_output = []
id_counter = 0
for logits_batch in is_answerable_logits_list:
    for example_logits in logits_batch:
        predicted_label = int(np.argmax(example_logits))   # 0 false 1 true
        if predicted_label == 0:
            # Class 0 means unanswerable, so the answer is overwritten with ''.
            answers_list[id_counter] = ''

        all_output.append({question_id_list[id_counter]: answers_list[id_counter]})
        id_counter += 1
    

In [None]:
# Write the predictions as ONE JSON object of the form {question_id: answer}.
# `all_output` is a list of single-entry dicts; dumping it as-is would produce a
# JSON array of objects, so the entries are merged into one mapping first.
merged_output = {}
for entry in all_output:
    merged_output.update(entry)

with open(output_name, "w") as outfile:
    json.dump(merged_output, outfile)

In [None]:
def fix_json(path=None):
    """Rewrite a predictions file in place as a single JSON object.

    The file is expected to hold either a JSON array of dicts (which are merged
    into one mapping, later entries winning on duplicate keys) or an
    already-merged JSON object (which is written back unchanged). An empty file
    becomes ``{}``.

    The content is parsed as JSON instead of stripping bracket characters from
    the raw text, so '[', ']', '{' and '}' inside ids or answers are preserved.

    Args:
        path: File to rewrite. Defaults to the module-level ``output_name``.

    Raises:
        json.JSONDecodeError: if the file is neither empty nor valid JSON.
    """
    if path is None:
        path = output_name

    with open(path, "r") as infile:
        raw_text = infile.read()

    if raw_text.strip() == "":
        merged = {}
    else:
        data = json.loads(raw_text)
        if isinstance(data, dict):
            merged = data
        else:
            merged = {}
            for entry in data:
                merged.update(entry)

    with open(path, "w") as outfile:
        json.dump(merged, outfile)

In [None]:
# Rewrite the predictions file (`output_name`) in place so that it holds a single
# JSON object instead of the list that was dumped above.
fix_json()