In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

In [2]:
import sys
sys.path.insert(0, "/home/husein/parsing/self-attentive-parser/src")
sys.path.append("/home/husein/parsing/self-attentive-parser")

In [4]:
import tensorflow as tf

In [5]:
from transformers import XLNetTokenizer
tokenizer = XLNetTokenizer.from_pretrained(
    'huseinzol05/xlnet-base-bahasa-cased', do_lower_case = False
)

In [6]:
import json

with open('vocab-xlnet-base.json') as fopen:
    data = json.load(fopen)
    
LABEL_VOCAB = data['label']
TAG_VOCAB = data['tag']

In [7]:
with tf.gfile.GFile('export/xlnet-base.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def)

In [9]:
input_ids = graph.get_tensor_by_name('import/input_ids:0')
word_end_mask = graph.get_tensor_by_name('import/word_end_mask:0')
charts = graph.get_tensor_by_name('import/charts:0')
tags = graph.get_tensor_by_name('import/tags:0')
sess = tf.InteractiveSession(graph = graph)

In [29]:
BERT_MAX_LEN = 512
import numpy as np
from parse_nk import BERT_TOKEN_MAPPING

def make_feed_dict_bert(sentences):
    all_input_ids = np.zeros((len(sentences), BERT_MAX_LEN), dtype=int)
    all_word_end_mask = np.zeros((len(sentences), BERT_MAX_LEN), dtype=int)

    subword_max_len = 0
    for snum, sentence in enumerate(sentences):
        tokens = []
        word_end_mask = []

        cleaned_words = []
        for word in sentence:
            word = BERT_TOKEN_MAPPING.get(word, word)
            if word == "n't" and cleaned_words:
                cleaned_words[-1] = cleaned_words[-1] + "n"
                word = "'t"
            cleaned_words.append(word)

        for word in cleaned_words:
            word_tokens = tokenizer.tokenize(word)
            for _ in range(len(word_tokens)):
                word_end_mask.append(0)
            word_end_mask[-1] = 1
            tokens.extend(word_tokens)
        tokens.append("<sep>")
        word_end_mask.append(1)
        tokens.append("<cls>")
        word_end_mask.append(1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [1] * len(input_ids)

        subword_max_len = max(subword_max_len, len(input_ids))

        all_input_ids[snum, :len(input_ids)] = input_ids
        all_word_end_mask[snum, :len(word_end_mask)] = word_end_mask

    all_input_ids = all_input_ids[:, :subword_max_len]
    all_word_end_mask = all_word_end_mask[:, :subword_max_len]
    return all_input_ids, all_word_end_mask

In [30]:
s = 'Saya sedang membaca buku tentang Perlembagaan'.split()
sentences = [s]
i, m = make_feed_dict_bert(sentences)
i, m

(array([[ 287,  461, 1524,  598,  454, 3809,    4,    3]]),
 array([[1, 1, 1, 1, 1, 1, 1, 1]]))

In [31]:
charts_val, tags_val = sess.run((charts, tags), {input_ids: i, word_end_mask: m})
charts_val, tags_val

(array([[[[ 0.        , -4.529982  , -3.3814292 , ..., -2.5720258 ,
           -2.4132185 , -2.6000128 ],
          [ 0.        , -1.617211  , -3.0199025 , ..., -1.5042725 ,
           -1.7649275 , -1.9249804 ],
          [ 0.        , -1.5820696 , -3.0010629 , ..., -1.7666583 ,
           -1.9988693 , -1.9891591 ],
          ...,
          [ 0.        , -2.5202546 , -5.1390266 , ..., -2.2531857 ,
           -2.4834492 , -2.132457  ],
          [ 0.        ,  0.48587704, -3.3051388 , ..., -2.5070767 ,
           -2.182617  , -2.33619   ],
          [ 0.        ,  0.        ,  0.        , ...,  0.        ,
            0.        ,  0.        ]],
 
         [[ 0.        , -3.5230384 , -1.2199852 , ..., -2.0789745 ,
           -2.1880505 , -2.3633444 ],
          [ 0.        , -4.529982  , -3.3814292 , ..., -2.5720258 ,
           -2.4132185 , -2.6000128 ],
          [ 0.        , -2.2948356 , -2.4096372 , ..., -2.3134325 ,
           -2.268665  , -2.1652873 ],
          ...,
          [ 0

In [32]:
for snum, sentence in enumerate(sentences):
    chart_size = len(sentence) + 1
    chart = charts_val[snum,:chart_size,:chart_size,:]

In [33]:
# !wget https://raw.githubusercontent.com/michaeljohns2/self-attentive-parser/michaeljohns2-support-tf2-patch/benepar/chart_decoder.pyx

In [34]:
import chart_decoder_py

In [35]:
chart_decoder_py.decode(chart)

(16.243681,
 array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5]),
 array([6, 1, 6, 2, 6, 3, 6, 4, 6, 5, 6]),
 array([1, 4, 5, 0, 5, 0, 0, 3, 2, 0, 3]))

In [36]:
import nltk
from nltk import Tree

In [37]:
PTB_TOKEN_ESCAPE = {u"(": u"-LRB-",
    u")": u"-RRB-",
    u"{": u"-LCB-",
    u"}": u"-RCB-",
    u"[": u"-LSB-",
    u"]": u"-RSB-"}


def make_nltk_tree(sentence, tags, score, p_i, p_j, p_label):

    # Python 2 doesn't support "nonlocal", so wrap idx in a list
    idx_cell = [-1]
    def make_tree():
        idx_cell[0] += 1
        idx = idx_cell[0]
        i, j, label_idx = p_i[idx], p_j[idx], p_label[idx]
        label = LABEL_VOCAB[label_idx]
        if (i + 1) >= j:
            word = sentence[i]
            tag = TAG_VOCAB[tags[i]]
            tag = PTB_TOKEN_ESCAPE.get(tag, tag)
            word = PTB_TOKEN_ESCAPE.get(word, word)
            tree = Tree(tag, [word])
            for sublabel in label[::-1]:
                tree = Tree(sublabel, [tree])
            return [tree]
        else:
            left_trees = make_tree()
            right_trees = make_tree()
            children = left_trees + right_trees
            if label:
                tree = Tree(label[-1], children)
                for sublabel in reversed(label[:-1]):
                    tree = Tree(sublabel, [tree])
                return [tree]
            else:
                return children

    tree = make_tree()[0]
    tree.score = score
    return tree

In [38]:
tree = make_nltk_tree(s, tags_val[0], *chart_decoder_py.decode(chart))
print(str(tree))

(S
  (NP-SBJ (<START> Saya))
  (VP
    (PRP sedang)
    (VP
      (MD membaca)
      (NP (VB buku))
      (PP (NN tentang) (NP (IN Perlembagaan))))))


In [39]:
def make_str_tree(sentence, tags, score, p_i, p_j, p_label):
    idx_cell = [-1]
    def make_str():
        idx_cell[0] += 1
        idx = idx_cell[0]
        i, j, label_idx = p_i[idx], p_j[idx], p_label[idx]
        label = LABEL_VOCAB[label_idx]
        if (i + 1) >= j:
            word = sentence[i]
            tag = TAG_VOCAB[tags[i]]
            tag = PTB_TOKEN_ESCAPE.get(tag, tag)
            word = PTB_TOKEN_ESCAPE.get(word, word)
            s = u"({} {})".format(tag, word)
        else:
            children = []
            while ((idx_cell[0] + 1) < len(p_i)
                and i <= p_i[idx_cell[0] + 1]
                and p_j[idx_cell[0] + 1] <= j):
                children.append(make_str())

            s = u" ".join(children)
            
        for sublabel in reversed(label):
            s = u"({} {})".format(sublabel, s)
        return s
    return make_str()

In [40]:
make_str_tree(s, tags_val[0], *chart_decoder_py.decode(chart))

'(S (NP-SBJ (<START> Saya)) (VP (PRP sedang) (VP (MD membaca) (NP (VB buku)) (PP (NN tentang) (NP (IN Perlembagaan))))))'