In [3]:
import json
import torch
import os, sys, random, numpy, pickle
import argparse
import logging
from pprint import pprint
from qdgat.drop_reader import DropReader
from qdgat.drop_dataloader import DropBatchGen
from qdgat.network import QDGATNet
from qdgat.utils import AverageMeter
from datetime import datetime
from qdgat.optimizer import AdamW
from qdgat.utils import create_logger
from transformers import RobertaTokenizer, RobertaModel, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import DataLoader, RandomSampler
from qdgat.drop_dataloader import create_collate_fn
# from tqdm.notebook import tqdm
from tqdm import tqdm

In [4]:
# Configure the root logger so INFO-level records are echoed to stdout with a timestamp prefix.
logger = logging.getLogger()
logging.basicConfig(level = logging.INFO)
# basicConfig does nothing once the root logger already has handlers (e.g. when this cell
# is re-run), so set the level explicitly as well.
logger.setLevel(logging.INFO)
# 24-hour clock: the previous '%I' (12-hour) format had no AM/PM marker, so times were ambiguous.
formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %H:%M:%S')
# Remove the stdout handler added by an earlier run of this cell; otherwise every re-run
# adds another handler and each record gets printed once more.
for _stale_handler in [h for h in logger.handlers if h.get_name() == "notebook-stdout"]:
    logger.removeHandler(_stale_handler)
ch = logging.StreamHandler(sys.stdout)
ch.set_name("notebook-stdout")
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
logger.addHandler(ch)
# NOTE(review): if the root logger had no handlers beforehand, basicConfig also installed a
# stderr handler, so each record may appear twice (stderr default format + stdout timestamped).
# Consider dropping basicConfig if only the stdout output is wanted.

In [5]:
from collections import namedtuple

# Preprocessing configuration. It is wrapped in an immutable namedtuple so that fields read
# like argparse attributes (args.data_dir, args.output_dir, ...).
# NOTE(review): the two *_length_limit values are forwarded to DropReader; presumably they
# are token limits — confirm they match the limits used at training time.
args = dict(
    passage_length_limit=463,
    question_length_limit=46,
    data_dir="./raw_data",
    roberta_model="roberta-base",
    output_dir='./data',
)
args = namedtuple("Arguments", args)(**args)

In [6]:
# Load the tokenizer that belongs to the checkpoint named in the config cell.
tokenizer = RobertaTokenizer.from_pretrained(
    pretrained_model_name_or_path=args.roberta_model,
)

In [7]:
def preprocess_drop(args, data_path, tokenizer, mode):
    """Read and tokenize one DROP-format JSON file with ``DropReader``.

    Args:
        args: configuration object; ``passage_length_limit`` and
            ``question_length_limit`` are forwarded to ``DropReader``.
        data_path: path to the raw JSON file to read.
        tokenizer: tokenizer passed through to ``DropReader``.
        mode: when ``'train'``, a list of answer types is passed as
            ``skip_when_all_empty`` (presumably so ``DropReader`` drops examples
            where all of them are empty — confirm in ``DropReader``); any other
            value passes ``None``.

    Returns:
        Whatever ``DropReader._read`` returns for ``data_path``.

    Raises:
        FileNotFoundError: if ``data_path`` does not exist. This is a subclass of
            ``Exception``, so existing ``except Exception`` handlers still catch it.
    """
    if not os.path.exists(data_path):
        raise FileNotFoundError("Missing %s for preprocessing." % data_path)

    if mode == 'train':
        skip_when_all_empty = ["passage_span", "question_span", "addition_subtraction", "counting", "multi_span"]
    else:
        skip_when_all_empty = None

    reader = DropReader(
        tokenizer, args.passage_length_limit, args.question_length_limit,
        skip_when_all_empty=skip_when_all_empty,
    )
    # NOTE(review): this relies on DropReader's private `_read` method.
    return reader._read(data_path)

In [None]:
# Pre-process the pre-training split and cache the tokenized examples as a pickle.
# Create the output directory first: open(..., "wb") fails on a fresh checkout where it
# does not exist yet, and doing it up front avoids losing the (slow) preprocessing result.
os.makedirs(args.output_dir, exist_ok=True)
pretrained_data = preprocess_drop(args, os.path.join(args.data_dir, "pretrain_split.json"), tokenizer, 'train')
cache_fpath = os.path.join(args.output_dir, "pretrain.pkl")
with open(cache_fpath, "wb") as f:
    pickle.dump(pretrained_data, f)

In [10]:
# Pre-process each cross-validation fold and cache it as <output_dir>/validation/train_cv_<fold>.pkl.
# The 'validation' sub-directory is created here because open(..., "wb") will not create it.
cv_output_dir = os.path.join(args.output_dir, 'validation')
os.makedirs(cv_output_dir, exist_ok=True)
# NOTE(review): the fold count is hard-coded; it must match the cv_fold-*.json files in data_dir.
for fold in range(5):
    cv_data = preprocess_drop(args, os.path.join(args.data_dir, "cv_fold-%d.json" % fold), tokenizer, 'train')
    cache_fpath = os.path.join(cv_output_dir, "train_cv_%d.pkl" % fold)
    with open(cache_fpath, "wb") as f:
        pickle.dump(cv_data, f)

Reading file at %s ./raw_data/cv_fold-0.json
Reading the dataset
skip: 160b7985-d826-4ff6-8006-d8d4c77f96a2: How many days passed between John entering Perpignan and the treaty at Perpignan ?
skip: 67da715c-c057-4d09-85b7-3c6649b024b1: How many men were on each ship that left on July 8 , 1497 , on average ?
skip: 7b139d33-6d75-4fb1-aae6-5b84ec36edd2: How many days passed between the fleet leaving Libon and them making first contact ?
skip: 6b754ab9-9e48-4c6d-ae48-69d81bb71903: How many months did these events span for ?
skip: 326e4d80-25c9-440a-9bff-3603bfd9a345: How many years after the Dow was holding steady between 7,000 and 9,000 did the Dow pass 14,000 ?
skip: 09e97b73-afdf-483d-bb44-9569a34b0e27: Around how many months do these events span ?
skip: b619bf65-97f5-419d-a38b-36ce7dccc383: How many dollars would Johnson have earned in 2008 if he achieved all bonuses ?
skip: e39c2efe-9004-4225-80f0-3e7aa9fc2552: How many percent of people were not working at home ?
skip: e2d4beeb-2b95-