In [1]:
import json
from time import time
import gc
import random
from tqdm import tqdm
import tensorflow as tf
from torch import nn

from multiprocessing import Pool
import multiprocessing
from fairseq.models.roberta import RobertaModel
import torch
from glob import glob
import numpy as np

from tokenizer.roberta import RobertaTokenizer

# Windowing limits forwarded to tokenizer.merge_cq by work(); work() also
# asserts that every feature's id sequence fits within max_seq_length.
max_seq_length   = 512
max_query_length = 128
doc_stride       = 128


def get_tokenizer():
    """Build a new RobertaTokenizer from the ``roberta.large`` config directory."""
    return RobertaTokenizer(config_dir='roberta.large')


# Module-level tokenizer used by work(); `tk` and `tokenizer` name the same
# object. Pool workers replace both via init().
tokenizer = get_tokenizer()
tk = tokenizer


def init():
    """Pool initializer: give each worker process its own tokenizer.

    Rebinds the module globals ``tokenizer`` and ``tk`` (used by ``work``) to a
    tokenizer freshly built by ``get_tokenizer``.
    """
    global tokenizer, tk
    tokenizer = tk = get_tokenizer()
    

def char_anchors_to_tok_pos(r):
    """Map a feature's character-level answer anchors to token positions.

    Args:
        r: feature object exposing ``char_anchors``, ``char_to_tok_offset``,
            ``all_doc_tokens`` and ``all_text_tokens``.

    Returns:
        ``(start, end)`` token indices, inclusive. Returns ``(0, 0)`` when
        ``r.char_anchors`` does not hold exactly two anchors. ``end`` is
        extended over any immediately following tokens whose entry in
        ``all_text_tokens`` is the empty string.
    """
    if len(r.char_anchors) != 2:
        return 0, 0
    a, b = r.char_anchors
    start = r.char_to_tok_offset[a]
    end = r.char_to_tok_offset[b]
    # The loop reads all_text_tokens, but the result must also be a valid
    # index into all_doc_tokens, so bound by both lengths. Bounding by
    # all_doc_tokens alone raised IndexError whenever all_text_tokens was
    # shorter.
    limit = min(len(r.all_doc_tokens), len(r.all_text_tokens))
    while end + 1 < limit and r.all_text_tokens[end + 1] == '':
        end += 1
    return start, end

def read(dat):
    """Decode one marshalled record produced by ``work``.

    Args:
        dat: bytes holding a single ``marshal``-encoded record.

    Returns:
        ``(uid, inp, start, end, p_mask, unanswerable)``. ``inp`` is an
        ``int32`` array decoded from the stored ``uint16`` ids. ``p_mask`` is a
        ``float32`` array of 0.0/1.0 decoded from the stored booleans.
    """
    uid, inp, start, end, p_mask, unanswerable = marshal.loads(dat)
    inp = np.frombuffer(inp, dtype=np.uint16).astype(np.int32)
    # np.bool was removed in NumPy 1.24; np.bool_ is the same 1-byte bool dtype.
    p_mask = np.frombuffer(p_mask, dtype=np.bool_).astype(np.float32)
    return uid, inp, start, end, p_mask, unanswerable

def fread(f):
    """Read and decode the next marshalled record from binary file ``f``.

    Same record layout and return value as ``read``. Each call consumes
    exactly one record, so back-to-back records can be read sequentially.
    """
    uid, inp, start, end, p_mask, unanswerable = marshal.load(f)
    inp = np.frombuffer(inp, dtype=np.uint16).astype(np.int32)
    # np.bool was removed in NumPy 1.24; np.bool_ is the same 1-byte bool dtype.
    p_mask = np.frombuffer(p_mask, dtype=np.bool_).astype(np.float32)
    return uid, inp, start, end, p_mask, unanswerable
            
def data_from_path(train_dir):
    """Yield ``(index, context, qas)`` for each paragraph in the matched JSON files.

    ``train_dir`` is expanded with ``glob``. Each file is opened with
    ``tf.gfile.Open`` and every ``data[*]['paragraphs'][*]`` entry is flattened
    out. ``index`` is a running counter that continues across files. A
    one-line summary per file is printed before that file's paragraphs are
    yielded.
    """
    offset = 0
    for path in glob(train_dir):
        with tf.gfile.Open(path, "r") as handle:
            paragraphs = []
            for article in json.load(handle)["data"]:
                paragraphs.extend(article['paragraphs'])

        short_name = path.rsplit('/', 1)[-1]
        print("%-40s : %s contexts" % (short_name, len(paragraphs)))

        for position, paragraph in enumerate(paragraphs, start=offset):
            yield position, paragraph['context'], paragraph['qas']
        offset += len(paragraphs)


def gen(paths):
    """Yield the ``(index, context, qas)`` triples produced by ``data_from_path``.

    Thin pass-through; ``paths`` is the glob pattern ``data_from_path`` expands.
    """
    yield from data_from_path(paths)
        
        
import marshal
def work(ss, debug=False):
    """Tokenize one context and its questions into marshalled records.

    Args:
        ss: ``(unique_index, context, qas, is_training, return_feature)``.
            The values are packed into one argument so this function can be
            passed directly to ``Pool.imap_unordered``.
        debug: forwarded to ``tokenizer.merge_cq``.

    Returns:
        A list with one entry per feature yielded by ``tokenizer.merge_cq``.
        Each entry is ``(record, no_ans)``, or ``(record, no_ans,
        r.serialize())`` when ``return_feature`` is true. ``record`` is a
        ``marshal`` blob of ``(uid, id_bytes, start, end, p_mask_bytes,
        unanswerable)`` that ``read``/``fread`` decode.
    """
    unique_index, context, qas, is_training, return_feature = ss

    rss = tokenizer.merge_cq(
        context,
        qas,
        max_seq_length=max_seq_length,
        max_query_length=max_query_length,
        doc_stride=doc_stride,
        unique_index=unique_index,
        is_training=is_training,
        debug=debug,
    )

    results = []
    # merge_cq returns groups of features (presumably one group per question —
    # TODO confirm); only the features themselves are needed here.
    for rs in rss:
        for r in rs:
            inp = tk.convert_tokens_to_ids(r.all_doc_tokens)
            start_position, end_position = char_anchors_to_tok_pos(r)
            # Fold the two components of r.unique_index into one integer.
            # NOTE(review): collides if the second component can reach 1000.
            uid = r.unique_index[0] * 1000 + r.unique_index[1]
            # char_anchors_to_tok_pos reports a missing answer as (0, 0).
            no_ans = start_position == 0

            assert 0 <= start_position < len(inp) and 0 <= end_position < len(inp)
            assert len(inp) <= max_seq_length

            record = marshal.dumps((
                uid,
                # NOTE(review): uint16 storage assumes every token id < 65536.
                np.array(inp, dtype=np.uint16).tobytes(),
                start_position,
                end_position,
                # np.bool was removed in NumPy 1.24; np.bool_ is the same dtype.
                np.array(r.p_mask, dtype=np.bool_).tobytes(),
                int(no_ans),
            ))

            if return_feature:
                results.append((record, no_ans, r.serialize()))
            else:
                results.append((record, no_ans))

    return results



def _collect_records(results, return_feature, t0):
    """Drain per-context ``work`` outputs into flat lists, printing progress every 2500 features.

    Returns ``(records, features, count, num_no_ans)``. ``features`` stays
    empty unless ``return_feature`` is true.
    """
    records = []
    features = []
    num_no_ans = 0
    count = 0
    for batch in results:
        for item in batch:
            if return_feature:
                record, no_ans, raw = item
                features.append(tk.from_bytes(raw))
            else:
                record, no_ans = item
            records.append(record)
            if no_ans:
                num_no_ans += 1
            count += 1
            if count % 2500 == 0:
                print('%d features (%d no ans) extracted (time: %.2f s)'
                      % (count, num_no_ans, time() - t0))
    return records, features, count, num_no_ans


def generate_tfrecord(data_dir,
                      write_fn=None,
                      is_training=False,
                      return_feature=False,
                      parallel_process=False,
                      debug=False):
    """Tokenize every paragraph matched by ``data_dir`` into marshalled records.

    Args:
        data_dir: glob pattern of SQuAD-style JSON files (expanded by ``gen``).
        write_fn: output path. Records are shuffled and written back-to-back
            as raw marshal blobs. Required unless ``return_feature`` is true.
        is_training: forwarded to ``work``.
        return_feature: if true, nothing is written. Instead ``(records,
            features)`` is returned, where each feature is rebuilt with
            ``tk.from_bytes``.
        parallel_process: run ``work`` in a ``multiprocessing.Pool`` with
            one worker per CPU minus one (at least one).
        debug: forwarded to ``work`` in the sequential path only.

    Returns:
        ``(records, features)`` when ``return_feature`` is true, else ``None``.

    Raises:
        TypeError: if ``write_fn`` is missing while ``return_feature`` is false.
    """
    # Fail before the (long) tokenization pass rather than at open() afterwards.
    if not return_feature and write_fn is None:
        raise TypeError('write_fn is required when return_feature is False')

    jobs = ((i, c, q, is_training, return_feature) for i, c, q in gen(data_dir))

    if parallel_process:
        # cpu_count() - 1 is 0 on a single-core machine, which Pool rejects.
        n_workers = max(1, multiprocessing.cpu_count() - 1)
        # The context manager terminates the workers on exit. That is safe
        # here because every result is consumed inside the block.
        with Pool(n_workers, initializer=init) as pool:
            t0 = time()
            # NOTE(review): `debug` is not forwarded to pool workers.
            records, features, count, num_no_ans = _collect_records(
                pool.imap_unordered(work, jobs), return_feature, t0)
    else:
        t0 = time()
        records, features, count, num_no_ans = _collect_records(
            tqdm(work(e, debug=debug) for e in jobs), return_feature, t0)

    if not return_feature:
        random.shuffle(records)
        with open(write_fn, 'wb') as f:
            f.writelines(records)

    print('num has ans / num no ans : %d / %d' % (count - num_no_ans, num_no_ans))

    if return_feature:
        return records, features
    

train_dir = 'train-v2.0.json'

# Guard the job: the multiprocessing "spawn"/"forkserver" start methods
# re-import the main module in each worker, which would otherwise re-run it.
if __name__ == "__main__":
    generate_tfrecord(train_dir, 'qa_records_squad_q', is_training=True, parallel_process=True)

train-v2.0.json                          : 19035 contexts
2500 features (182 no ans) extracted (time: 1.38 s)
5000 features (357 no ans) extracted (time: 1.58 s)
7500 features (387 no ans) extracted (time: 1.75 s)
10000 features (1757 no ans) extracted (time: 1.91 s)
12500 features (2797 no ans) extracted (time: 2.05 s)
15000 features (3694 no ans) extracted (time: 2.19 s)
17500 features (4409 no ans) extracted (time: 2.33 s)
20000 features (5503 no ans) extracted (time: 2.48 s)
22500 features (6196 no ans) extracted (time: 2.62 s)
25000 features (6363 no ans) extracted (time: 2.79 s)
27500 features (6859 no ans) extracted (time: 2.95 s)
30000 features (6974 no ans) extracted (time: 3.16 s)
32500 features (7947 no ans) extracted (time: 3.34 s)
35000 features (9153 no ans) extracted (time: 3.50 s)
37500 features (10411 no ans) extracted (time: 3.67 s)
40000 features (11232 no ans) extracted (time: 3.87 s)
42500 features (11918 no ans) extracted (time: 4.24 s)
45000 features (13009 no an