<a href="https://colab.research.google.com/github/dwdb/dependency-parser/blob/master/fast_dpnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import pickle
from collections import defaultdict

import numpy as np
import tensorflow as tf

np.random.seed(7)

!nvidia-smi

Fri Jun 19 07:29:18 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.36.06    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P8     7W /  75W |      0MiB /  7611MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# 依存句法单词节点类
节点id、词性、头节点id、左右孩子节点（指向的节点）、左右依存关系

In [2]:
class Token(object):

    def __init__(self, token_id, word, pos, dep, head_id):
        self.token_id = token_id
        self.word = word.lower()
        self.pos = pos
        if head_id >= token_id:
            self.dep = 'L_' + dep
        else:
            self.dep = 'R_' + dep
        self.head_id = head_id
        self.left = []
        self.right = []

    def __repr__(self):
        return 'Token(token_id={token_id},word={word},head_id={head_id})'.format(
            **self.__dict__)


ROOT_TOKEN = Token(-1, '<root>', '<root>', '<root>', -1)
NULL_TOKEN = Token(-1, '<null>', '<null>', '<null>', -1)
UNK_TOKEN = Token(-1, '<unk>', '<unk>', '<unk>', -1)

# 依存句法transfer-reduce句子类
节点、缓冲、栈，通过get_next_input可获取当前状态下的单词、磁性、依存关系共计18+18+12个特征

In [3]:
class Sentence(object):
    def __init__(self, tokens):
        self.tokens = tokens
        self.buff = tokens.copy()
        self.stack = [ROOT_TOKEN]
        self.deps = []

    def update_by_action(self, action=None):
        if action is None:
            action = self.get_action()
        # 转移
        if action == 'shift':
            self.stack.append(self.buff.pop(0))
        # 左弧 stack1 <- stack0
        elif action.startswith('L_'):
            token = self.stack.pop(-2)
            token.dep = action
            self.deps.append((self.stack[-1].token_id, token.token_id, action))
            self.binary_insert(self.stack[-1].left, token)
        # 右弧 stack1 -> stack0
        elif action.startswith('R_'):
            token = self.stack.pop(-1)
            token.dep = action
            self.deps.append((self.stack[-1].token_id, token.token_id, action))
            self.binary_insert(self.stack[-1].right, token)
        else:
            raise ValueError('unknown state!')
        return action

    def get_action(self):
        if len(self.stack) < 2:
            return 'shift'
        t1, t0 = self.stack[-2:]
        # left arc
        if t1.head_id == t0.token_id:
            return t1.dep
        # right arc
        if t0.head_id == t1.token_id:
            if any(t0.token_id == t.head_id for t in self.buff):
                return 'shift'
            return t0.dep
        return 'shift'

    def get_next_input(self, word2id, pos2id, dep2id):
        """将conll dataset转为fast dependency parser dataset
        18 features of words and pos tags:
            The top 3 words on the stack and buffer:
                s1, s2, s3, b1, b2, b3
            The first and second leftmost/rightmost children of top 2 words on stack:
                lc1(s1), rc1(s1), lc2(s1), rc2(s1), lc1(s2), rc1(s2), lc2(s2), rc2(s2)
            The leftmost/rightmost of leftmost/rightmost of top 2 words on stack:
                lc1(lc1(s1)), rc1(rc1(s1)), lc1(lc1(s2)), rc1(rc1(s2))
        12 features of dependencies, that excluding those 6 (18-6=12) words on the stack/buffer
        """

        def pad_tokens(tokens, maxlen):
            tokens = tokens[:maxlen]
            if len(tokens) < maxlen:
                tokens += [NULL_TOKEN] * (maxlen - len(tokens))
            return tokens

        def get_children(token):
            lc1, lc2 = pad_tokens(token.left, 2)
            rc1, rc2 = pad_tokens(token.right, 2)
            llc1, = pad_tokens(lc1.left, 1)
            rrc1, = pad_tokens(rc1.right, 1)
            return [lc1, rc1, lc2, rc2, llc1, rrc1]

        # top 3 features on stack
        s1, s2, s3 = pad_tokens(self.stack[-1::-1], 3)
        # 18 features
        tokens = [s1, s2, s3] + pad_tokens(self.buff, 3) + get_children(s1) + get_children(s2)
        # word, pos tag and dependency indices
        input_word = [word2id.get(token.word, word2id[UNK_TOKEN.word]) for token in tokens]
        input_pos = [pos2id.get(token.pos, pos2id[UNK_TOKEN.pos]) for token in tokens]
        input_dep = [dep2id.get(token.dep, dep2id[UNK_TOKEN.dep]) for token in tokens[6:]]
        return input_word, input_pos, input_dep

    @staticmethod
    def binary_insert(array, value, key=lambda x: x.token_id):
        start, end = 0, len(array) - 1
        while start <= end:
            mid = int((start + end) / 2)
            if key(value) >= key(array[mid]):
                start = mid + 1
            else:
                end = mid - 1
        array.insert(start, value)

# Conll数据集类

In [7]:
class ConllDataset(object):

    @staticmethod
    def load(path):
        with open(path, encoding='utf8') as f:
            dataset, tokens = [], []
            for line in f.readlines():
                if line == '\n':
                    dataset.append(Sentence(tokens))
                    tokens = []
                else:
                    line = line.strip().split('\t')
                    token = Token(int(line[0]) - 1, line[1], line[4], line[7], int(line[6]) - 1)
                    tokens.append(token)
            if tokens:
                dataset.append(Sentence(tokens))
        return dataset

    def fit_transform(self, path, min_count=2, shuffle=True):
        dataset = self.load(path)
        # build vocabulary
        vocab = defaultdict(int)
        pos_tags, deps = set(), set()
        for sentence in dataset:
            for token in sentence.tokens:
                vocab[token.word] += 1
                pos_tags.add(token.pos)
                deps.add(token.dep)

        # create word dictionary
        vocab = {k for k, v in vocab.items() if v >= min_count}
        vocab.update((ROOT_TOKEN.word, NULL_TOKEN.word, UNK_TOKEN.word))
        self.word2id = dict(zip(sorted(vocab), range(len(vocab))))
        # create part-of-speech dictionary
        pos_tags.update((ROOT_TOKEN.pos, NULL_TOKEN.pos, UNK_TOKEN.pos))
        self.pos2id = dict(zip(sorted(pos_tags), range(len(pos_tags))))
        # create dependency and labelsdictionary
        deps.update((ROOT_TOKEN.dep, NULL_TOKEN.dep, UNK_TOKEN.dep))
        labels = deps.copy() | {'shift', UNK_TOKEN.dep}
        self.dep2id = dict(zip(sorted(deps), range(len(deps))))
        self.id2label = sorted(labels)
        self.label2id = dict(zip(self.id2label, range(len(self.id2label))))

        self._fit = True
        dataset = self.transform(path, dataset=dataset, shuffle=shuffle)
        return dataset

    def transform(self, path=None, dataset=None, shuffle=False):
        if not dataset and path:
            dataset = self.load(path)
        assert getattr(self, '_fit', None), 'Model must be fit before transform!'

        error_count = 0
        inputs = []
        for i, sentence in enumerate(dataset):
            if i % 5000 == 0:
                print('transforming at line %d' % i)
            while len(sentence.stack) > 1 or sentence.buff:
                input_word, input_pos, input_dep = sentence.get_next_input(
                    self.word2id, self.pos2id, self.dep2id)
                try:
                    output = sentence.update_by_action()
                    output = self.label2id.get(output, self.label2id[UNK_TOKEN.dep])
                except (ValueError, IndexError) as e:
                    error_count += 1
                    break
                inputs.append((input_word, input_pos, input_dep, output))

        # shuffle dataset
        if shuffle:
            np.random.shuffle(inputs)

        print('%s: error count:%d, total count:%d, total examples:%d' % (
            os.path.basename(str(path)), error_count, len(dataset), len(inputs)))
        # input_word, input_pos, input_dep, output
        return tuple(np.array(data, np.int32) for data in zip(*inputs))


conll_path = '/content/drive/My Drive/dependency parsing/data/conll'
conll = ConllDataset()
train_dataset = conll.fit_transform(os.path.join(conll_path, 'train.conll'))
valid_dataset = conll.transform(os.path.join(conll_path, 'dev.conll'))

output_path = './output/'
if not os.path.exists(output_path):
    os.mkdir(output_path)
pickle.dump(conll.word2id, open(os.path.join(output_path, 'word2id.pkl'), 'wb'))
pickle.dump(conll.pos2id, open(os.path.join(output_path, 'pos2id.pkl'), 'wb'))
pickle.dump(conll.dep2id, open(os.path.join(output_path, 'dep2id.pkl'), 'wb'))
pickle.dump(conll.label2id, open(os.path.join(output_path, 'label2id.pkl'), 'wb'))

transforming at line 0
transforming at line 5000
transforming at line 10000
transforming at line 15000
transforming at line 20000
transforming at line 25000
transforming at line 30000
transforming at line 35000
train.conll: error count:120, total count:39832, total examples:1899368
transforming at line 0
dev.conll: error count:5, total count:1700, total examples:80201


# 创建fast dependency model
单词、词性、依存关系存于同一字典，只需创建一个embedding变量即可，但是这样做不好

In [11]:
def build_model(vocab_size, pos_size, dep_size, embedding_size, n_classes):
    print('vocab_size:%d, pos_size:%d, dep_size:%d, n_classes:%d' % (
        vocab_size, pos_size, dep_size, n_classes))
    l2_regularizer = tf.keras.regularizers.l2(1e-5)
    # input layer
    input_word = tf.keras.layers.Input(shape=(18,))
    input_pos = tf.keras.layers.Input(shape=(18,))
    input_dep = tf.keras.layers.Input(shape=(12,))
    # embedding layer，initial weight range of [-0.01, 0.01]
    word_embedding = tf.keras.layers.Embedding(
        vocab_size, embedding_size,
        embeddings_initializer=tf.keras.initializers.RandomUniform(-0.01, 0.01),
        embeddings_regularizer=l2_regularizer)(input_word)
    pos_embedding = tf.keras.layers.Embedding(
        pos_size, embedding_size,
        embeddings_initializer=tf.keras.initializers.RandomUniform(-0.01, 0.01),
        embeddings_regularizer=l2_regularizer)(input_pos)
    dep_embedding = tf.keras.layers.Embedding(
        dep_size, embedding_size,
        embeddings_initializer=tf.keras.initializers.RandomUniform(-0.01, 0.01),
        embeddings_regularizer=l2_regularizer)(input_dep)
    # shape=(batch_size, 48, embedding_size)
    embedding = tf.concat((word_embedding, pos_embedding, dep_embedding), axis=1)
    embedding = tf.reshape(embedding, shape=(-1, embedding_size * 48))
    embedding = tf.keras.layers.Dropout(rate=0.4)(embedding)
    # dense layer
    dense1 = tf.keras.layers.Dense(
        units=100,
        kernel_regularizer=l2_regularizer,
        bias_regularizer=l2_regularizer)(embedding)
    dense1 = tf.pow(dense1, 3)
    dense1 = tf.keras.layers.Dropout(rate=0.4)(dense1)
    # dense layer
    outputs = tf.keras.layers.Dense(
        units=n_classes,
        kernel_regularizer=l2_regularizer,
        bias_regularizer=l2_regularizer)(dense1)
    model = tf.keras.Model(inputs=(input_word, input_pos, input_dep), outputs=outputs)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
    )
    model.summary()
    return model


model = build_model(vocab_size=len(conll.word2id), pos_size=len(conll.pos2id), 
                    dep_size=len(conll.dep2id), embedding_size=50, n_classes=len(conll.label2id))

vocab_size:21679, pos_size:48, dep_size:73, n_classes:74
Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_7 (InputLayer)            [(None, 18)]         0                                            
__________________________________________________________________________________________________
input_8 (InputLayer)            [(None, 18)]         0                                            
__________________________________________________________________________________________________
input_9 (InputLayer)            [(None, 12)]         0                                            
__________________________________________________________________________________________________
embedding_6 (Embedding)         (None, 18, 50)       1083950     input_7[0][0]                    
___________________________________

In [13]:
history = model.fit(train_dataset[:3], train_dataset[3],
    batch_size=2048, epochs=10, validation_data=(valid_dataset[:3], valid_dataset[3]))
model.save(os.path.join(output_path, 'checkpoint'))

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Assets written to: ./output/checkpoint/assets


In [25]:
def evaluate(sentences, output_path):
    # restore model
    model = tf.keras.models.load_model(os.path.join(output_path, 'checkpoint'))
    # load local dictionary
    label2id = pickle.load(open(os.path.join(output_path, 'label2id.pkl'), 'rb'))
    word2id = pickle.load(open(os.path.join(output_path, 'word2id.pkl'), 'rb'))
    pos2id = pickle.load(open(os.path.join(output_path, 'pos2id.pkl'), 'rb'))
    dep2id = pickle.load(open(os.path.join(output_path, 'dep2id.pkl'), 'rb'))
    id2label, _ = zip(*sorted(label2id.items(), key=lambda x:x[1]))

    uas = las = count = 0
    for i, sentence in enumerate(sentences):
        raw_deps = {}
        for token in sentence.tokens:
            raw_deps[(token.head_id, token.token_id)] = token.dep
            token.dep = NULL_TOKEN.dep
            token.head_id = -1
        count += len(sentence.tokens)
        while len(sentence.stack) > 1 or sentence.buff:
            input_word, input_pos, input_dep = sentence.get_next_input(word2id, pos2id, dep2id)
            input_word = np.array([input_word], dtype=np.int32)
            input_pos = np.array([input_pos], dtype=np.int32)
            input_dep = np.array([input_dep], dtype=np.int32)
            output = model.predict((input_word, input_pos, input_dep))
            action = id2label[np.argmax(output[0])]
            try:
                sentence.update_by_action(action)
            except (IndexError, ValueError):
                # print(len(sentence.tokens), len(sentence.deps))
                break

        for head, tail, action in sentence.deps:
            raw_action = raw_deps.get((head, tail))
            if raw_action is not None:
                uas += 1
                if raw_action == action:
                    las += 1
        if (i + 1) % 10 == 0:
            print('total sentence:%d, UAS:%.3f, LAS:%.3f' % (i + 1, uas / count, las / count))


sentences = ConllDataset.load(os.path.join(conll_path, 'dev.conll'))
evaluate(sentences, output_path)

total sentence:10, UAS:0.927, LAS:0.906
total sentence:20, UAS:0.897, LAS:0.870
total sentence:30, UAS:0.886, LAS:0.862
total sentence:40, UAS:0.884, LAS:0.861
total sentence:50, UAS:0.876, LAS:0.853
total sentence:60, UAS:0.866, LAS:0.844
total sentence:70, UAS:0.871, LAS:0.848
total sentence:80, UAS:0.861, LAS:0.836
total sentence:90, UAS:0.855, LAS:0.833
total sentence:100, UAS:0.853, LAS:0.830
total sentence:110, UAS:0.857, LAS:0.834
total sentence:120, UAS:0.851, LAS:0.828
total sentence:130, UAS:0.858, LAS:0.835
total sentence:140, UAS:0.863, LAS:0.841
total sentence:150, UAS:0.861, LAS:0.839
total sentence:160, UAS:0.862, LAS:0.840
total sentence:170, UAS:0.864, LAS:0.842
total sentence:180, UAS:0.864, LAS:0.842
total sentence:190, UAS:0.863, LAS:0.842
total sentence:200, UAS:0.860, LAS:0.839
total sentence:210, UAS:0.856, LAS:0.834
total sentence:220, UAS:0.859, LAS:0.837
total sentence:230, UAS:0.861, LAS:0.840
total sentence:240, UAS:0.863, LAS:0.842
total sentence:250, UAS:0

KeyboardInterrupt: ignored