In [44]:
from tqdm import tqdm
import numpy as np
import json

In [39]:
def process_conll(fname, encoding='utf-8'):
    """Parse a CoNLL-style NER file into per-document token/tag lists.

    A line containing ``-DOCSTART-`` opens a new document. Every other
    non-blank line is split on whitespace; its first field is taken as the
    token and its last field as the NER tag. Blank (or whitespace-only)
    lines are skipped, so sentence boundaries inside a document are not kept.

    Args:
        fname: Path of the file to read.
        encoding: Text encoding used to open the file (default ``'utf-8'``).

    Returns:
        dict mapping a 0-based document index to a dict with three parallel
        lists: ``'text'`` (tokens), ``'tags'`` (NER tags) and
        ``'global_attn_mask'`` (1 where the tag is not ``'O'``, else 0).
        An empty file yields an empty dict.

    Raises:
        AssertionError: if the three lists of a document are inconsistent.
    """
    with open(fname, encoding=encoding) as f:
        lines = f.readlines()

    processed_samples = {}
    i = -1

    for line in tqdm(lines):
        if '-DOCSTART-' in line:
            i += 1
            processed_samples[i] = {'text': [], 'tags': [], 'global_attn_mask': []}
            continue

        # split() without a separator drops the line terminator (whether or
        # not the last line has one), tolerates trailing spaces and tabs, and
        # returns [] for blank / whitespace-only lines.
        fields = line.split()
        if not fields:
            continue

        if i < 0:
            # Token lines before any -DOCSTART- header: open document 0
            # implicitly instead of failing with KeyError(-1).
            i = 0
            processed_samples[i] = {'text': [], 'tags': [], 'global_attn_mask': []}

        sample = processed_samples[i]
        ner_tag = fields[-1]
        sample['text'].append(fields[0])
        sample['tags'].append(ner_tag)
        sample['global_attn_mask'].append(0 if ner_tag == 'O' else 1)

    # sanity check (plain Python so that empty documents are handled too)
    for sample in processed_samples.values():
        assert len(sample['text']) == len(sample['tags']) == len(sample['global_attn_mask'])
        assert sum(sample['global_attn_mask']) == sum(tag != 'O' for tag in sample['tags'])

    return processed_samples

In [46]:
# Convert each CoNLL split (<split>.txt) into a JSON-lines file
# (<split>.jsonl) holding one processed document per line.
splits = ['train', 'test', 'valid']

for split in splits:
    fname = f"{split}.txt"
    processed_samples = process_conll(fname)

    with open(f"{split}.jsonl", 'w+') as fout:
        for k in processed_samples:
            # Serialise the document first, then emit it with its line break.
            record = json.dumps(processed_samples[k])
            fout.write(record + "\n")

100%|██████████| 219553/219553 [00:00<00:00, 476117.42it/s]
100%|██████████| 50349/50349 [00:00<00:00, 364879.95it/s]
100%|██████████| 55043/55043 [00:00<00:00, 496904.02it/s]
