In [1]:
import os
import argparse
import codecs
import json
import random as rnd
import numpy as np
from collections import Counter, defaultdict
from itertools import chain, count
from six import string_types
import torch
import torchtext.data
import torchtext.vocab

import table
import table.IO
import opts
from tree import SCode

In [2]:
# Reserved vocabulary symbols. Each surface string is paired with an integer
# id, and that id equals the symbol's position in `special_token_list`.
UNK_WORD, UNK = '<unk>', 0    # unknown token
PAD_WORD, PAD = '<blank>', 1  # padding
BOS_WORD, BOS = '<s>', 2      # sequence start
EOS_WORD, EOS = '</s>', 3     # sequence end
SKP_WORD, SKP = '<sk>', 4     # layout "skip" marker
RIG_WORD, RIG = '<]>', 5      # layout right-bracket marker
LFT_WORD, LFT = '<[>', 6      # layout left-bracket marker

# Keep this order in sync with the integer ids above.
special_token_list = [
    UNK_WORD,
    PAD_WORD,
    BOS_WORD,
    EOS_WORD,
    SKP_WORD,
    RIG_WORD,
    LFT_WORD,
]

In [3]:
def get_parent_index(tk_list):
    """Return, for every token, the position of its enclosing open bracket.

    Positions are 1-based so that 0 can stand for "top level". A token that
    starts with '(' opens a new scope for the tokens after it; a token equal
    to ')' closes the innermost scope. Both kinds of token are themselves
    assigned the scope that was active when they were read.

    Args:
        tk_list: sequence of string tokens.
    Returns:
        A list one element longer than `tk_list`; the extra trailing 0 is the
        slot for the end-of-sequence token (</s>).
    """
    open_positions = [0]
    parents = []
    # Counting from 1 makes `position` the 1-based index of the token itself.
    for position, token in enumerate(tk_list, start=1):
        parents.append(open_positions[-1])
        if token.startswith('('):
            open_positions.append(position)
        elif token == ')':
            open_positions.pop()
    # slot for EOS (</s>)
    parents.append(0)
    return parents


def get_tgt_mask(lay_skip):
    """Build the per-position input mask for the target sequence.

    0: use layout encoding vectors; 1: use target word embeddings.

    Args:
        lay_skip: layout tokens, including the skip / right-bracket markers.
    Returns:
        A list of 0/1 flags, one longer than `lay_skip`: the leading 1 is the
        slot for the <s> token at the first position.
    """
    mask = [1]
    for token in lay_skip:
        mask.append(1 if token in (SKP_WORD, RIG_WORD) else 0)
    return mask


def get_lay_index(lay_skip):
    """Number the layout tokens of `lay_skip`, with a leading slot for <s>.

    Tokens that are not skip / right-bracket markers receive consecutive
    indices starting at 0; marker tokens receive 0 and do not advance the
    counter.

    Args:
        lay_skip: layout tokens, including the skip / right-bracket markers.
    Returns:
        A list of ints, one longer than `lay_skip` (the first entry, 0, is the
        slot for the <s> token at the first position).
    """
    indices = [0]
    next_slot = 0
    for token in lay_skip:
        if token not in (SKP_WORD, RIG_WORD):
            indices.append(next_slot)
            next_slot += 1
        else:
            indices.append(0)
    return indices


def get_tgt_loss(line, mask_target_loss):
    """Select the target tokens that are emitted at each position.

    Positions whose layout token is a skip / right-bracket marker always keep
    their target token. Every other position keeps its target token only when
    `mask_target_loss` is falsy; otherwise it is replaced by PAD_WORD.

    Args:
        line: mapping with parallel 'tgt' and 'lay_skip' token sequences.
        mask_target_loss: whether to blank out the non-marker positions.
    Returns:
        A list of tokens, as long as the shorter of the two sequences.
    """
    selected = []
    for tgt_token, layout_token in zip(line['tgt'], line['lay_skip']):
        blank_out = mask_target_loss and layout_token not in (SKP_WORD, RIG_WORD)
        selected.append(PAD_WORD if blank_out else tgt_token)
    return selected


def __getstate__(self):
    """Return a copy of the instance attributes with `stoi` as a plain dict.

    The object itself is left untouched; only the returned state holds the
    converted `stoi`.
    """
    state = dict(self.__dict__)
    state['stoi'] = dict(self.stoi)
    return state


def __setstate__(self, state):
    """Restore instance attributes from `state` and rebuild `stoi`.

    After the attributes are copied in, `stoi` is replaced by a defaultdict
    holding the same entries, so that looking up a missing key yields 0.
    """
    vars(self).update(state)
    lookup = defaultdict(lambda: 0)
    lookup.update(self.stoi)
    self.stoi = lookup


# Install the two module-level hooks above on torchtext's Vocab class so that
# its state is exported with a plain-dict `stoi` and restored with a
# defaultdict `stoi`. This rebinds attributes on the imported class itself, so
# it applies to every Vocab instance in the process.
torchtext.vocab.Vocab.__getstate__ = __getstate__
torchtext.vocab.Vocab.__setstate__ = __setstate__


def filter_counter(freqs, min_freq):
    """Copy into a new Counter the entries of `freqs` that are frequent enough.

    Args:
        freqs: mapping from item to count.
        min_freq: smallest count to keep; `None` keeps every entry.
    Returns:
        A new `Counter`; `freqs` is not modified.
    """
    return Counter({
        item: freq
        for item, freq in freqs.items()
        if min_freq is None or freq >= min_freq
    })


def merge_vocabs(vocabs, min_freq=0, vocab_size=None):
    """
    Combine several vocabularies (assumed to come from disjoint documents)
    into one larger vocabulary.

    Args:
        vocabs: `torchtext.vocab.Vocab` vocabularies to be merged
        min_freq: `int` counts below this are dropped from each individual
            vocabulary before the counts are summed; the same value is also
            forwarded to the new `Vocab`.
        vocab_size: `int` the final vocabulary size. `None` for no limit.
    Return:
        `torchtext.vocab.Vocab`
    """
    combined_freqs = Counter()
    for single_vocab in vocabs:
        kept = filter_counter(single_vocab.freqs, min_freq)
        combined_freqs += kept
    # Hand the constructor its own copy of the special-token list.
    specials = list(special_token_list)
    return torchtext.vocab.Vocab(
        combined_freqs,
        specials=specials,
        max_size=vocab_size,
        min_freq=min_freq,
    )


def join_dicts(*args):
    """
    args: dictionaries with disjoint keys
    returns: a single new dictionary holding the union of their entries
    """
    merged = {}
    for mapping in args:
        merged.update(mapping.items())
    return merged

def _preprocess_json(js):
    """Add layout and target token sequences to one annotation record, in place.

    Builds an `SCode` from the record's 'token' and 'type' entries and writes
    three keys back into `js`: 'lay', 'lay_skip' and 'tgt' (an existing 'tgt'
    entry is overwritten). Nothing is returned.

    Raises:
        AssertionError: if `t.target()` and `js['lay_skip']` differ in length.
    """
    t = SCode((js['token'], js['type']))
    # Two layouts of the same code: without and with skip markers.
    js['lay'] = t.layout(add_skip=False)
    js['lay_skip'] = t.layout(add_skip=True)
    # The target sequence must align position-by-position with `lay_skip`.
    # NOTE(review): the failure message reads js['tgt'] before it is
    # overwritten below, so building the message raises KeyError if the record
    # has no 'tgt' entry yet — confirm the input always carries one.
    assert len(t.target()) == len(js['lay_skip']), (list(zip(t.target(), js['lay_skip'])), ' '.join(js['tgt']))
    js['tgt'] = t.target()

def read_anno_json(anno_path, max_examples=5):
    """Load annotation records from a JSON-lines file and preprocess them.

    Each non-blank line of the file is parsed as one JSON object. Every
    loaded record is then passed through `_preprocess_json`, which updates it
    in place.

    Args:
        anno_path: path of a UTF-8 encoded file with one JSON object per line.
        max_examples: maximum number of records to load. The default of 5
            keeps the small preview this notebook has always used; pass
            `None` to load the whole file.
    Returns:
        A list of preprocessed record dicts, in file order.
    Raises:
        json.JSONDecodeError (a ValueError): if a non-blank line that is read
            is not valid JSON.
    """
    js_list = []
    with codecs.open(anno_path, "r", "utf-8") as corpus_file:
        for line in corpus_file:
            # Stop reading as soon as the requested number of records is in.
            if max_examples is not None and len(js_list) >= max_examples:
                break
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            js_list.append(json.loads(line))
    for js in js_list:
        _preprocess_json(js)
    return js_list

In [11]:
# Load and preprocess the annotation records that the cells below inspect.
# NOTE(review): `test_anno` is not defined in any cell shown in this notebook,
# so this cell relies on earlier kernel state; define the path in a config
# cell so the notebook survives Restart & Run All.
js_list = read_anno_json(test_anno)

In [16]:
# Fields printed for every record, in display order.
_DISPLAY_FIELDS = (
    'src', 'token', 'type', 'lay', 'lay_index', 'lay_parent_index',
    'copy_to_tgt', 'copy_to_ext', 'tgt_mask', 'tgt', 'tgt_copy_ext',
    'tgt_parent_index', 'tgt_loss',
)


def _tgt_copy_ext(line):
    """Keep target tokens that also occur in the source; replace the rest with UNK_WORD."""
    src_set = set(line['src'])
    return [tk_tgt if tk_tgt in src_set else UNK_WORD for tk_tgt in line['tgt']]


def _field_value(js, field):
    """Compute the value displayed for one named field of one record."""
    if field in ('src', 'lay', 'token', 'type'):
        # Stored directly on the record.
        return js[field]
    if field in ('copy_to_tgt', 'copy_to_ext'):
        # Both copy fields display the source tokens.
        return js['src']
    if field == 'tgt':
        # Same rule as the target loss with masking switched on: positions
        # that are not skip / right-bracket markers show PAD_WORD.
        return get_tgt_loss(js, True)
    if field == 'tgt_copy_ext':
        return _tgt_copy_ext(js)
    if field == 'tgt_loss':
        return get_tgt_loss(js, False)
    if field == 'tgt_mask':
        return get_tgt_mask(js['lay_skip'])
    if field == 'lay_index':
        return get_lay_index(js['lay_skip'])
    if field == 'lay_parent_index':
        return get_parent_index(js['lay'])
    if field == 'tgt_parent_index':
        return get_parent_index(js['tgt'])
    raise NotImplementedError


def process_data(js_list):
    """Print every derived field of each record for visual inspection.

    For each record a dashed separator line is printed, followed by one
    "name:  value" line per field in a fixed order.

    Args:
        js_list: iterable of preprocessed record dicts; each must provide the
            'src', 'token', 'type', 'lay', 'lay_skip' and 'tgt' entries.
    Returns:
        None; the function only writes to standard output.
    """
    for js in js_list:
        print("\n" + "-" * 50)
        for field in _DISPLAY_FIELDS:
            print(field + ": ", _field_value(js, field))

In [17]:
# Print every derived field for each record loaded above.
process_data(js_list)


--------------------------------------------------
src:  ['send', 'a', 'signal', '`', 'signal.SIGUSR1', '`', 'to', 'the', 'current', 'process']
token:  ['os', '.', 'kill', '(', 'os', '.', 'getpid', '(', ')', ',', 'signal', '.', 'SIGUSR1', ')']
type:  ['KEYWORD', 'OP', 'KEYWORD', 'OP', 'KEYWORD', 'OP', 'KEYWORD', 'OP', 'OP', 'OP', 'KEYWORD', 'OP', 'KEYWORD', 'OP']
lay:  ['os', '.', 'kill', '(', 'os', '.', 'getpid', '(', ')', ',', 'signal', '.', 'SIGUSR1', ')']
lay_index:  [0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
lay_parent_index:  [0, 0, 0, 0, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 0]
copy_to_tgt:  ['send', 'a', 'signal', '`', 'signal.SIGUSR1', '`', 'to', 'the', 'current', 'process']
copy_to_ext:  ['send', 'a', 'signal', '`', 'signal.SIGUSR1', '`', 'to', 'the', 'current', 'process']
tgt_mask:  [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
tgt:  ['<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>',