In [None]:
import transformers
import torch
import json
from tqdm import tqdm
from transformers import BertTokenizer
import pandas as pd
from utils import QADataset, collate_fn, preprocess
from transformers import BertForQuestionAnswering
from torch.utils.data import DataLoader

In [None]:
PRETRAINED_MODEL_NAME = "bert-base-chinese" 
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)

In [3]:
valid = preprocess("data/dev.json", "test")
validset = QADataset(valid, "test", tokenizer=tokenizer)

In [4]:
BATCH_SIZE = 8
validloader = DataLoader(validset, batch_size=BATCH_SIZE, 
                         collate_fn=collate_fn)

In [5]:
def clean_word(sentence):
    stop_words = ['##', '[UNK]', '[CLS]']
    for s in stop_words:
        sentence = sentence.replace(s, '')
    if(sentence):
        if(sentence[0]=="《" and sentence[-1] != "》"):
            sentence +=  "》"
        if(sentence[0]!="《" and sentence[-1] == "》"):
            sentence = "《"+sentence
    return sentence

def predict(model, dataloader, device):
    predictions = None
    correct = 0
    total = 0
    total_loss = 0
    predictions = {}
    with torch.no_grad():
        for data in tqdm(dataloader):
            ids = data[5]
            data = [t.to(device) for t in data[:-1]]
            tokens_tensors, segments_tensors, masks_tensors = data[:3]
            start_scores, end_scores = model(input_ids=tokens_tensors, 
                            token_type_ids=segments_tensors, 
                            attention_mask=masks_tensors)
            start_scores = torch.nn.functional.softmax(start_scores)
            end_scores = torch.nn.functional.softmax(end_scores)
            for i in range(len(data[0])):
                id = ids[i]
                all_tokens = tokenizer.convert_ids_to_tokens(tokens_tensors[i])
                start, end = torch.argmax(start_scores[i]), torch.argmax(end_scores[i])
                scorer = start_scores[i]+end_scores[i]
                retry = 1
                while(end - start >= 20 or start > end or (not start and end)):
                    retry += 1
                    _, starts = torch.topk(start_scores[i], retry)
                    _, ends = torch.topk(end_scores[i], retry)
                    if start == 0:
                        start = starts[retry-1]
                    else:
                        if(start_scores[i][start]+end_scores[i][ends[retry-1]] > start_scores[i][starts[retry-1]]+end_scores[i][end]):
                            end = ends[retry-1]
                        else:
                            start = starts[retry-1]
                is_answerable = torch.sigmoid(torch.tensor([scorer[0], (start_scores[i][start] + end_scores[i][end])]))
                if(is_answerable[0] > is_answerable[1]):
                    predictions[id] = ''
                else:
                    temp = ''.join(all_tokens[start:end])
                    predictions[id] = clean_word(temp)

    return predictions

In [6]:
##TODO model.load()
PRETRAINED_MODEL_NAME = "bert-base-chinese"
NUM_LABELS = 2
model = BertForQuestionAnswering.from_pretrained(
    PRETRAINED_MODEL_NAME)

model.load_state_dict(torch.load('./save/QA-final.pkl'))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.cuda()

In [7]:
predictions = predict(model, validloader,device)
json.dump(predictions, open('./output/test.json', "w"))

100%|██████████| 630/630 [01:51<00:00,  5.67it/s]


In [9]:
output = json.dump(predictions, open('./output/qa_1.json', "w"))

In [10]:
predictions

{'ee807557a886e8775242e561575818737c77880dae59e91a698598b0': '歐洲',
 '314e58ad0bd4ac6c0c73ad87f83a36560d6d78bc028864356b88f519': '梵語',
 '6ee8f450ae49a8c36ed1efb1cc5d15d92fd6ada456a1739934ca7d93': '語文學家',
 '2bb5949c9612ff127b31b86a2746f359f756ad2697f1cd4a32e831ff': '',
 'e9d2ec04fad375aaf771faac1760d9dd4c040c7127e9ffcaaa581719': '當梵語發展成俗語之後',
 '4aaf75c36bf4cca749b5e35193a3ddf25dd158abdc8cba249b1d8316': '19世紀晚期',
 '338808b7cd565984fb4803c04cfb518439b23542ff8ec99e7f7a288c': '婆羅米文',
 '5b42e3f2daa3d7c938cedd7c87e0325195c081a4c49d8f28327dd4bc': '無頭化變體',
 'c347b913dc64da71f19a717d2b2e69c74b4d0a9397cb671ec5116861': '1991年',
 '19b0c3ea2d437c2e7c8bebe5a4c44fe0d824d871158aedd63c7573ce': '《蘇達爾摩》',
 '1be6494c75a6a875b38a35d469b7ec037a45228dd22dfd9ca0dac876': '馬祖列島',
 '9323b2b379191eb21e32a578239fa25fee3626ae5b8d7698bf127a08': '交通部觀光局',
 '294cf4c5da3ae75e7ac0658e56631d3cf3cfa952a3062fc9b1f7d029': '2000年6月',
 '3ea85bac73d5a4c90dbb39967af088dc720da8a1f330437c40e18d74': '白犬列島',
 '627fa2952d2f2c9a26a031d

BertConfig {
  "_num_labels": 2,
  "architectures": null,
  "attention_probs_dropout_prob": 0.1,
  "bad_words_ids": null,
  "bos_token_id": null,
  "decoder_start_token_id": null,
  "do_sample": false,
  "early_stopping": false,
  "eos_token_id": null,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": false,
  "is_encoder_decoder": false,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1
  },
  "layer_norm_eps": 1e-12,
  "length_penalty": 1.0,
  "max_length": 20,
  "max_position_embeddings": 512,
  "min_length": 0,
  "model_type": "bert",
  "no_repeat_ngram_size": 0,
  "num_attention_heads": 12,
  "num_beams": 1,
  "num_hidden_layers": 12,
  "num_return_sequences": 1,
  "output_attentions": false,
  "output_hidden_states": false,
  "output_past": true,
  "pad_token_id": 0,
  "prefix": null,
  "pruned_