In [1]:
from __future__ import absolute_import, division, print_function
# Must stay the first statement of this cell: `from __future__` imports raise a SyntaxError anywhere else.

import collections
import logging
from operator import index
import os
import random
from xml.dom.minidom import Document
from matplotlib.pyplot import title
import numpy as np
import pandas as pd
import csv

import torch
import torch.nn as nn
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, Dataset)
from torch.utils.data.distributed import DistributedSampler
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from tqdm import tqdm, trange
from random import random, randrange, randint, shuffle, choice, sample
from transformers import BertModel,BertConfig,BertTokenizer,BertForMaskedLM
from transformers import AdamW
from transformers.optimization import (
    get_constant_schedule,
    get_constant_schedule_with_warmup,
    get_linear_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
    get_cosine_with_hard_restarts_schedule_with_warmup,
    get_polynomial_decay_schedule_with_warmup,
)
from data_processor import RTE_Processing,CB_Processing,BoolQ_Processing,Json_File_Reader
import protum_args as args
import jieba
import re
import json
from sklearn.model_selection import KFold
from sklearn.metrics  import f1_score

In [2]:
class ProtumDataset(Dataset):
    """Map-style dataset that right-pads pre-tokenized examples to a fixed length.

    Each element of ``data`` must be a mapping with the keys ``'input_ids'``,
    ``'segment_ids'`` and ``'attention_ids'`` (equal-length sequences of ints),
    plus ``'prompt_positions'`` and ``'label_ids'`` (anything ``torch.tensor``
    accepts).

    Args:
        data: Indexable collection of example mappings, e.g. the list returned
            by a data processor.
        max_seq_length: Length that ``input_ids`` / ``segment_ids`` /
            ``attention_ids`` are right-padded to with ``0``. Longer examples
            are NOT truncated; they are returned at their original length.
    """

    def __init__(self, data, max_seq_length):
        self.max_seq_length = max_seq_length
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return one example as a tuple of ``torch.long`` tensors.

        Returns:
            ``(input_ids, segment_ids, attention_ids, prompt_positions, label_ids)``,
            where the first three are padded to ``max_seq_length``.

        Padding is applied to copies. The stored example is never modified, so
        repeated indexing, or sharing ``data`` between datasets with different
        ``max_seq_length`` values, always sees the original, unpadded lists.
        """
        data_item = self.data[index]
        # Copy the sequences. Extending the stored lists with `+=` would
        # mutate the caller's data in place.
        input_ids = list(data_item['input_ids'])
        segment_ids = list(data_item['segment_ids'])
        attention_ids = list(data_item['attention_ids'])
        prompt_positions = data_item['prompt_positions']
        label_ids = data_item['label_ids']

        assert len(input_ids) == len(segment_ids) == len(attention_ids), \
            'input_ids, segment_ids and attention_ids must have the same length'
        # A negative value (example longer than max_seq_length) makes `[0] * n`
        # empty, so such examples are neither padded nor truncated.
        padding = [0] * (self.max_seq_length - len(input_ids))
        input_ids += padding
        attention_ids += padding
        segment_ids += padding

        input_ids = torch.tensor(input_ids, dtype=torch.long)
        segment_ids = torch.tensor(segment_ids, dtype=torch.long)
        attention_ids = torch.tensor(attention_ids, dtype=torch.long)
        prompt_positions = torch.tensor(prompt_positions, dtype=torch.long)
        label_ids = torch.tensor(label_ids, dtype=torch.long)
        return input_ids, segment_ids, attention_ids, prompt_positions, label_ids

In [3]:
# Tokenizer and model config are loaded from the same Hugging Face checkpoint name.
tokenizer = BertTokenizer.from_pretrained(args.model_name_from_hugging_face)
# Token strings of `tokenizer.vocab`, in its iteration order; handed to the data processor.
vocab_list = [token for token in tokenizer.vocab]
config = BertConfig.from_pretrained(args.model_name_from_hugging_face)

In [4]:
# Read the RTE train and dev splits (in that order) from the paths configured in protum_args.
train_data, eval_data = [
    Json_File_Reader('RTE', split_path)
    for split_path in (args.train_data_path, args.dev_data_path)
]

In [10]:
# Turn the dev split into model inputs (not in pretraining mode), then wrap
# them in a dataset that pads each example to args.max_seq_length.
eval_processor = RTE_Processing(
    eval_data, tokenizer, args.max_seq_length, vocab_list, is_pretraining=False
)
eval_inputs = eval_processor.Creat_Input_For_PLMs()
EvalDataset = ProtumDataset(eval_inputs, args.max_seq_length)

RTE Data Processing: 277it [00:00, 592.61it/s]


In [9]:
# Inspect the configured padding target passed to ProtumDataset.
args.max_seq_length

128

In [13]:
# Inspect one evaluation example as returned by ProtumDataset:
# (input_ids, segment_ids, attention_ids, prompt_positions, label_ids) tensors.
EvalDataset[1]

(tensor([  101,  6355,   117,  1195,  1208,  1132, 15137,  1115,  2848, 25523,
          1132,  3196,  1147, 12949,  1222,  6946,   119, 20012,   118,  3989,
         10548,  1132,   182, 15012,  1916,  4946,  1190,  1195,  1169,  1435,
          1146,  1114,  1207,  2848, 25523,  1106,  2147,  1103,  1207,  9138,
           119, 22171,   131, 18757, 25857,  1465,  1110,  2183,  1103,  1594,
          1222,  2848, 25523,   136,  1109, 26018,   131,   103,   119,   102,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,  

In [14]:
# Inspect the raw processor output for the same example (a dict including 'tokens' and 'input_ids').
# NOTE(review): this dumps the whole example; consider displaying only a slice of each field.
eval_inputs[1]

{'tokens': ['[CLS]',
  'Yet',
  ',',
  'we',
  'now',
  'are',
  'discovering',
  'that',
  'anti',
  '##biotics',
  'are',
  'losing',
  'their',
  'effectiveness',
  'against',
  'illness',
  '.',
  'Disease',
  '-',
  'causing',
  'bacteria',
  'are',
  'm',
  '##uta',
  '##ting',
  'faster',
  'than',
  'we',
  'can',
  'come',
  'up',
  'with',
  'new',
  'anti',
  '##biotics',
  'to',
  'fight',
  'the',
  'new',
  'variations',
  '.',
  'Question',
  ':',
  'Ba',
  '##cter',
  '##ia',
  'is',
  'winning',
  'the',
  'war',
  'against',
  'anti',
  '##biotics',
  '?',
  'The',
  'Answer',
  ':',
  '[MASK]',
  '.',
  '[SEP]'],
 'input_ids': [101,
  6355,
  117,
  1195,
  1208,
  1132,
  15137,
  1115,
  2848,
  25523,
  1132,
  3196,
  1147,
  12949,
  1222,
  6946,
  119,
  20012,
  118,
  3989,
  10548,
  1132,
  182,
  15012,
  1916,
  4946,
  1190,
  1195,
  1169,
  1435,
  1146,
  1114,
  1207,
  2848,
  25523,
  1106,
  2147,
  1103,
  1207,
  9138,
  119,
  22171,
  131,
  