In [1]:
import sys
sys.path.insert(0, "/home/husein/parsing/self-attentive-parser/src")
sys.path.append("/home/husein/parsing/self-attentive-parser")

In [2]:
import tensorflow as tf
from transformers import BertTokenizer

In [3]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

In [4]:
import json

with open('vocab.json') as fopen:
    data = json.load(fopen)
    
LABEL_VOCAB = data['label']
TAG_VOCAB = data['tag']

In [5]:
with tf.gfile.GFile('export/model.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def)

In [6]:
input_ids = graph.get_tensor_by_name('import/input_ids:0')
word_end_mask = graph.get_tensor_by_name('import/word_end_mask:0')
charts = graph.get_tensor_by_name('import/charts:0')
tags = graph.get_tensor_by_name('import/tags:0')
sess = tf.InteractiveSession(graph = graph)

In [7]:
BERT_MAX_LEN = 512
import numpy as np
from parse_nk import BERT_TOKEN_MAPPING

def make_feed_dict_bert(sentences):
    all_input_ids = np.zeros((len(sentences), BERT_MAX_LEN), dtype=int)
    all_word_end_mask = np.zeros((len(sentences), BERT_MAX_LEN), dtype=int)

    subword_max_len = 0
    for snum, sentence in enumerate(sentences):
        tokens = []
        word_end_mask = []

        tokens.append(u"[CLS]")
        word_end_mask.append(1)

        cleaned_words = []
        for word in sentence:
            word = BERT_TOKEN_MAPPING.get(word, word)
            # BERT is pre-trained with a tokenizer that doesn't split off
            # n't as its own token
            if word == u"n't" and cleaned_words:
                cleaned_words[-1] = cleaned_words[-1] + u"n"
                word = u"'t"
            cleaned_words.append(word)

        for word in cleaned_words:
            word_tokens = tokenizer.tokenize(word)
            if not word_tokens:
                # The tokenizer used in conjunction with the parser may not
                # align with BERT; in particular spaCy will create separate
                # tokens for whitespace when there is more than one space in
                # a row, and will sometimes separate out characters of
                # unicode category Mn (which BERT strips when do_lower_case
                # is enabled). Substituting UNK is not strictly correct, but
                # it's better than failing to return a valid parse.
                word_tokens = ["[UNK]"]
            for _ in range(len(word_tokens)):
                word_end_mask.append(0)
            word_end_mask[-1] = 1
            tokens.extend(word_tokens)
        tokens.append(u"[SEP]")
        word_end_mask.append(1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        subword_max_len = max(subword_max_len, len(input_ids))

        all_input_ids[snum, :len(input_ids)] = input_ids
        all_word_end_mask[snum, :len(word_end_mask)] = word_end_mask

    all_input_ids = all_input_ids[:, :subword_max_len]
    all_word_end_mask = all_word_end_mask[:, :subword_max_len]
    return all_input_ids, all_word_end_mask

In [8]:
sentences = ['i like to eat'.split()]
i, m = make_feed_dict_bert(sentences)

In [9]:
charts_val, tags_val = sess.run((charts, tags), {input_ids: i, word_end_mask: m})

In [10]:
for snum, sentence in enumerate(sentences):
    chart_size = len(sentence) + 1
    chart = charts_val[snum,:chart_size,:chart_size,:]

In [11]:
# !wget https://raw.githubusercontent.com/michaeljohns2/self-attentive-parser/michaeljohns2-support-tf2-patch/benepar/chart_decoder.pyx

In [12]:
import chart_decoder

In [13]:
chart_decoder.decode(chart)

(4.400097846984863,
 array([0, 0, 1, 1, 2, 2, 3]),
 array([4, 1, 4, 2, 4, 3, 4]),
 array([35,  3,  5,  0, 11,  0,  5]))

In [14]:
import nltk
from nltk import Tree

In [56]:
PTB_TOKEN_ESCAPE = {u"(": u"-LRB-",
    u")": u"-RRB-",
    u"{": u"-LCB-",
    u"}": u"-RCB-",
    u"[": u"-LSB-",
    u"]": u"-RSB-"}


def make_nltk_tree(sentence, tags, score, p_i, p_j, p_label):

    # Python 2 doesn't support "nonlocal", so wrap idx in a list
    idx_cell = [-1]
    def make_tree():
        idx_cell[0] += 1
        idx = idx_cell[0]
        i, j, label_idx = p_i[idx], p_j[idx], p_label[idx]
        label = LABEL_VOCAB[label_idx]
        if (i + 1) >= j:
            word = sentence[i]
            tag = TAG_VOCAB[tags[i]]
            tag = PTB_TOKEN_ESCAPE.get(tag, tag)
            word = PTB_TOKEN_ESCAPE.get(word, word)
            tree = Tree(tag, [word])
            for sublabel in label[::-1]:
                tree = Tree(sublabel, [tree])
            return [tree]
        else:
            left_trees = make_tree()
            right_trees = make_tree()
            children = left_trees + right_trees
            if label:
                tree = Tree(label[-1], children)
                for sublabel in reversed(label[:-1]):
                    tree = Tree(sublabel, [tree])
                return [tree]
            else:
                return children

    tree = make_tree()[0]
    tree.score = score
    return tree

In [57]:
tree = make_nltk_tree('i like to eat'.split(), tags_val[0], *chart_decoder.decode(chart))
print(str(tree))

(WHADJP
  (NP (<START> i))
  (VP (PRP like) (S (VP (VBP to) (VP (TO eat))))))


In [63]:
def make_str_tree(sentence, tags, score, p_i, p_j, p_label):
    idx_cell = [-1]
    def make_str():
        idx_cell[0] += 1
        idx = idx_cell[0]
        i, j, label_idx = p_i[idx], p_j[idx], p_label[idx]
        label = LABEL_VOCAB[label_idx]
        if (i + 1) >= j:
            word = sentence[i]
            tag = TAG_VOCAB[tags[i]]
            tag = PTB_TOKEN_ESCAPE.get(tag, tag)
            word = PTB_TOKEN_ESCAPE.get(word, word)
            s = u"({} {})".format(tag, word)
        else:
            children = []
            while ((idx_cell[0] + 1) < len(p_i)
                and i <= p_i[idx_cell[0] + 1]
                and p_j[idx_cell[0] + 1] <= j):
                children.append(make_str())

            s = u" ".join(children)
            
        for sublabel in reversed(label):
            s = u"({} {})".format(sublabel, s)
        return s
    return make_str()

In [64]:
make_str_tree('i like to eat'.split(), tags_val[0], *chart_decoder.decode(chart))

'(WHADJP (NP (<START> i)) (VP (PRP like) (S (VP (VBP to) (VP (TO eat))))))'