In [1]:
import pickle
import json
import tensorflow as tf
import numpy as np

In [2]:
with open('session-entities.pkl', 'rb') as fopen:
    data = pickle.load(fopen)
data.keys()

dict_keys(['train_X', 'test_X', 'train_Y', 'test_Y'])

In [3]:
train_X = data['train_X']
test_X = data['test_X']
train_Y = data['train_Y']
test_Y = data['test_Y']

In [4]:
with open('dictionary-entities.json') as fopen:
    dictionary = json.load(fopen)
dictionary.keys()

dict_keys(['word2idx', 'idx2word', 'tag2idx', 'idx2tag', 'char2idx'])

In [5]:
word2idx = dictionary['word2idx']
idx2word = {int(k): v for k, v in dictionary['idx2word'].items()}
tag2idx = dictionary['tag2idx']
idx2tag = {int(k): v for k, v in dictionary['idx2tag'].items()}
char2idx = dictionary['char2idx']

In [6]:
list(zip([idx2word[d] for d in train_X[-1]], [idx2tag[d] for d in train_Y[-1]]))

[('harahap', 'person'),
 (',', 'OTHER'),
 ('pedagang', 'OTHER'),
 ('ikan', 'OTHER'),
 ('kering', 'OTHER'),
 (',', 'OTHER'),
 ('kering', 'OTHER'),
 (',', 'OTHER'),
 ('para', 'OTHER'),
 ('pembeli', 'OTHER'),
 ('mayoritas', 'OTHER'),
 ('datang', 'OTHER'),
 ('dari', 'OTHER'),
 ('luar', 'OTHER'),
 ('kota', 'OTHER'),
 ('seperti', 'OTHER'),
 ('meranti', 'location'),
 (',', 'OTHER'),
 ('riau', 'location'),
 (',', 'OTHER'),
 ('jambi', 'location'),
 (',', 'OTHER'),
 ('dan', 'OTHER'),
 ('batam', 'location'),
 ('.', 'OTHER'),
 ('cabai', 'OTHER'),
 ('kering', 'OTHER'),
 (',', 'OTHER'),
 ('para', 'OTHER'),
 ('pembeli', 'OTHER'),
 ('mayoritas', 'OTHER'),
 ('datang', 'OTHER'),
 ('dari', 'OTHER'),
 ('luar', 'OTHER'),
 ('kota', 'OTHER'),
 ('seperti', 'OTHER'),
 ('chuah', 'location'),
 (',', 'OTHER'),
 ('riau', 'location'),
 (',', 'OTHER'),
 ('jambi', 'location'),
 (',', 'OTHER'),
 ('dan', 'OTHER'),
 ('batam', 'location'),
 ('.', 'OTHER'),
 ('cabai', 'OTHER'),
 ('para', 'OTHER'),
 ('pembeli', 'OTHER'),
 

In [7]:
list(zip([idx2word[d] for d in train_X[1]], [idx2tag[d] for d in train_Y[1]]))

[('politik', 'OTHER'),
 ('dari', 'OTHER'),
 ('Universitas', 'organization'),
 ('Gadjah', 'organization'),
 ('Mada', 'organization'),
 (',', 'OTHER'),
 ('Arie', 'person'),
 ('Sudjito', 'person'),
 (',', 'OTHER'),
 ('menilai,', 'OTHER'),
 ('keinginan', 'OTHER'),
 ('Ketua', 'OTHER'),
 ('Umum', 'OTHER'),
 ('Partai', 'organization'),
 ('Golkar', 'organization'),
 ('Aburizal', 'person'),
 ('Bakrie', 'person'),
 ('untuk', 'OTHER'),
 ('maju', 'OTHER'),
 ('kembali', 'OTHER'),
 ('sebagai', 'OTHER'),
 ('ketua', 'OTHER'),
 ('umum', 'OTHER'),
 ('merupakan', 'OTHER'),
 ('pemaksaan', 'OTHER'),
 ('kehendak.', 'OTHER'),
 ('Menurut', 'OTHER'),
 ('dia,', 'OTHER'),
 ('ada', 'OTHER'),
 ('kesan', 'OTHER'),
 ('bahwa', 'OTHER'),
 ('Aburizal', 'person'),
 ('menggunakan', 'OTHER'),
 ('segala', 'OTHER'),
 ('cara', 'OTHER'),
 ('untuk', 'OTHER'),
 ('memuluskan', 'OTHER'),
 ('jalannya', 'OTHER'),
 ('kembali', 'OTHER'),
 ('menduduki', 'OTHER'),
 ('Golkar', 'organization'),
 ('1.', 'OTHER'),
 ('Hal', 'OTHER'),
 ('ini

In [8]:
def generate_char_seq(batch):
    x = [[len(idx2word[i]) for i in k] for k in batch]
    maxlen = max([j for i in x for j in i])
    temp = np.zeros((batch.shape[0],batch.shape[1],maxlen),dtype=np.int32)
    for i in range(batch.shape[0]):
        for k in range(batch.shape[1]):
            for no, c in enumerate(idx2word[batch[i,k]]):
                temp[i,k,-1-no] = char2idx[c]
    return temp

In [9]:
generate_char_seq(data['train_X'][:10]).shape

(10, 50, 11)

In [10]:
class Model:
    def __init__(
        self,
        dim_word,
        dim_char,
        dropout,
        learning_rate,
        hidden_size_char,
        hidden_size_word,
        num_layers,
    ):
        def cells(size, reuse = False):
            return tf.contrib.rnn.DropoutWrapper(
                tf.nn.rnn_cell.LSTMCell(
                    size,
                    initializer = tf.orthogonal_initializer(),
                    reuse = reuse,
                ),
                output_keep_prob = dropout,
            )

        self.word_ids = tf.placeholder(tf.int32, shape = [None, None])
        self.char_ids = tf.placeholder(tf.int32, shape = [None, None, None])
        self.labels = tf.placeholder(tf.int32, shape = [None, None])
        self.maxlen = tf.shape(self.word_ids)[1]
        self.lengths = tf.count_nonzero(self.word_ids, 1)
        
        self.word_embeddings = tf.Variable(
            tf.truncated_normal(
                [len(word2idx), dim_word], stddev = 1.0 / np.sqrt(dim_word)
            )
        )
        self.char_embeddings = tf.Variable(
            tf.truncated_normal(
                [len(char2idx), dim_char], stddev = 1.0 / np.sqrt(dim_char)
            )
        )

        word_embedded = tf.nn.embedding_lookup(
            self.word_embeddings, self.word_ids
        )
        char_embedded = tf.nn.embedding_lookup(
            self.char_embeddings, self.char_ids
        )
        s = tf.shape(char_embedded)
        char_embedded = tf.reshape(
            char_embedded, shape = [s[0] * s[1], s[-2], dim_char]
        )
        
        for n in range(num_layers):
            (out_fw, out_bw), (
                state_fw,
                state_bw,
            ) = tf.nn.bidirectional_dynamic_rnn(
                cell_fw = cells(hidden_size_char),
                cell_bw = cells(hidden_size_char),
                inputs = char_embedded,
                dtype = tf.float32,
                scope = 'bidirectional_rnn_char_%d' % (n),
            )
            char_embedded = tf.concat((out_fw, out_bw), 2)
        output = tf.reshape(
            char_embedded[:, -1], shape = [s[0], s[1], 2 * hidden_size_char]
        )
        word_embedded = tf.concat([word_embedded, output], axis = -1)

        for n in range(num_layers):
            (out_fw, out_bw), (
                state_fw,
                state_bw,
            ) = tf.nn.bidirectional_dynamic_rnn(
                cell_fw = cells(hidden_size_word),
                cell_bw = cells(hidden_size_word),
                inputs = word_embedded,
                dtype = tf.float32,
                scope = 'bidirectional_rnn_word_%d' % (n),
            )
            word_embedded = tf.concat((out_fw, out_bw), 2)

        logits = tf.layers.dense(word_embedded, len(idx2tag))
        y_t = self.labels
        log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
            logits, y_t, self.lengths
        )
        self.cost = tf.reduce_mean(-log_likelihood)
        self.optimizer = tf.train.AdamOptimizer(
            learning_rate = learning_rate
        ).minimize(self.cost)
        mask = tf.sequence_mask(self.lengths, maxlen = self.maxlen)
        self.tags_seq, tags_score = tf.contrib.crf.crf_decode(
            logits, transition_params, self.lengths
        )
        self.tags_seq = tf.identity(self.tags_seq, name = 'logits')

        y_t = tf.cast(y_t, tf.int32)
        self.prediction = tf.boolean_mask(self.tags_seq, mask)
        mask_label = tf.boolean_mask(y_t, mask)
        correct_pred = tf.equal(self.prediction, mask_label)
        correct_index = tf.cast(correct_pred, tf.float32)
        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

In [11]:
tf.reset_default_graph()
sess = tf.InteractiveSession()

dim_word = 128
dim_char = 256
dropout = 0.8
learning_rate = 1e-3
hidden_size_char = 128
hidden_size_word = 128
num_layers = 2
batch_size = 64

model = Model(dim_word,dim_char,dropout,learning_rate,hidden_size_char,hidden_size_word,num_layers)
sess.run(tf.global_variables_initializer())

W0805 00:21:47.352130 140268250425152 deprecation.py:506] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/util/deprecation.py:507: calling count_nonzero (from tensorflow.python.ops.math_ops) with axis is deprecated and will be removed in a future version.
Instructions for updating:
reduction_indices is deprecated, use axis instead
W0805 00:21:47.857108 140268250425152 lazy_loader.py:50] 
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

W0805 00:21:47.858026 140268250425152 deprecation.py:323] From <ipython-input-10-9add79807405>:17: LSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.
Instructions for

In [12]:
string = 'KUALA LUMPUR: Sempena sambutan Aidilfitri minggu depan, Perdana Menteri Tun Dr Mahathir Mohamad dan Menteri Pengangkutan Anthony Loke Siew Fook menitipkan pesanan khas kepada orang ramai yang mahu pulang ke kampung halaman masing-masing. Dalam video pendek terbitan Jabatan Keselamatan Jalan Raya (JKJR) itu, Dr Mahathir menasihati mereka supaya berhenti berehat dan tidur sebentar  sekiranya mengantuk ketika memandu.'

import re

def entities_textcleaning(string, lowering = False):
    """
    use by entities recognition, pos recognition and dependency parsing
    """
    string = re.sub('[^A-Za-z0-9\-\/() ]+', ' ', string)
    string = re.sub(r'[ ]+', ' ', string).strip()
    original_string = string.split()
    if lowering:
        string = string.lower()
    string = [
        (original_string[no], word.title() if word.isupper() else word)
        for no, word in enumerate(string.split())
        if len(word)
    ]
    return [s[0] for s in string], [s[1] for s in string]

def char_str_idx(corpus, dic, UNK = 0):
    maxlen = max([len(i) for i in corpus])
    X = np.zeros((len(corpus), maxlen))
    for i in range(len(corpus)):
        for no, k in enumerate(corpus[i][:maxlen][::-1]):
            val = dic[k] if k in dic else UNK
            X[i, -1 - no] = val
    return X

In [13]:
from tqdm import tqdm
import time

for e in range(6):
    lasttime = time.time()
    train_acc, train_loss, test_acc, test_loss = 0, 0, 0, 0
    pbar = tqdm(
        range(0, train_X.shape[0], batch_size), desc = 'train minibatch loop'
    )
    for i in pbar:
        index = min(i + batch_size, train_X.shape[0])
        batch_x = train_X[i : index]
        batch_char = generate_char_seq(batch_x)
        batch_y = train_Y[i : index]
        acc, cost, _ = sess.run(
            [model.accuracy, model.cost, model.optimizer],
            feed_dict = {
                model.word_ids: batch_x,
                model.char_ids: batch_char,
                model.labels: batch_y
            },
        )
        assert not np.isnan(cost)
        train_loss += cost
        train_acc += acc
        pbar.set_postfix(cost = cost, accuracy = acc)
        
    pbar = tqdm(
        range(0, test_X.shape[0], batch_size), desc = 'test minibatch loop'
    )
    for i in pbar:
        index = min(i + batch_size, test_X.shape[0])
        batch_x = test_X[i : index]
        batch_char = generate_char_seq(batch_x)
        batch_y = test_Y[i : index]
        acc, cost = sess.run(
            [model.accuracy, model.cost],
            feed_dict = {
                model.word_ids: batch_x,
                model.char_ids: batch_char,
                model.labels: batch_y
            },
        )
        assert not np.isnan(cost)
        test_loss += cost
        test_acc += acc
        pbar.set_postfix(cost = cost, accuracy = acc)
    
    train_loss /= len(train_X) / batch_size
    train_acc /= len(train_X) / batch_size
    test_loss /= len(test_X) / batch_size
    test_acc /= len(test_X) / batch_size

    print('time taken:', time.time() - lasttime)
    print(
        'epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\n'
        % (e, train_loss, train_acc, test_loss, test_acc)
    )
    
    sequence = entities_textcleaning(string)[1]
    X_seq = char_str_idx([sequence], word2idx, 2)
    X_char_seq = generate_char_seq(X_seq)

    predicted = sess.run(model.tags_seq,
                feed_dict = {
                    model.word_ids: X_seq,
                    model.char_ids: X_char_seq,
                },
        )[0]

    for i in range(len(predicted)):
        print(sequence[i],idx2tag[predicted[i]])

train minibatch loop: 100%|██████████| 9191/9191 [1:12:15<00:00,  2.35it/s, accuracy=0.999, cost=0.235]
test minibatch loop: 100%|██████████| 2298/2298 [08:35<00:00,  4.46it/s, accuracy=0.676, cost=52.6]


time taken: 4850.963254451752
epoch: 0, training loss: 8.659316, training acc: 0.939020, valid loss: 25.182621, valid acc: 0.859950



train minibatch loop:   0%|          | 0/9191 [00:00<?, ?it/s]

Kuala location
Lumpur location
Sempena OTHER
sambutan OTHER
Aidilfitri OTHER
minggu time
depan OTHER
Perdana OTHER
Menteri OTHER
Tun person
Dr person
Mahathir person
Mohamad OTHER
dan OTHER
Menteri OTHER
Pengangkutan OTHER
Anthony OTHER
Loke OTHER
Siew person
Fook person
menitipkan OTHER
pesanan OTHER
khas OTHER
kepada OTHER
orang quantity
ramai quantity
yang OTHER
mahu OTHER
pulang OTHER
ke OTHER
kampung location
halaman location
masing-masing OTHER
Dalam OTHER
video OTHER
pendek OTHER
terbitan OTHER
Jabatan organization
Keselamatan organization
Jalan OTHER
Raya OTHER
(Jkjr) OTHER
itu OTHER
Dr OTHER
Mahathir person
menasihati OTHER
mereka OTHER
supaya OTHER
berhenti OTHER
berehat OTHER
dan OTHER
tidur OTHER
sebentar OTHER
sekiranya OTHER
mengantuk quantity
ketika OTHER
memandu OTHER


train minibatch loop: 100%|██████████| 9191/9191 [1:12:14<00:00,  2.73it/s, accuracy=1, cost=0.0409]    
test minibatch loop: 100%|██████████| 2298/2298 [08:36<00:00,  4.45it/s, accuracy=0.9, cost=26]     
train minibatch loop:   0%|          | 0/9191 [00:00<?, ?it/s]

time taken: 4850.45486998558
epoch: 1, training loss: 2.052917, training acc: 0.984684, valid loss: 20.512893, valid acc: 0.898136

Kuala location
Lumpur location
Sempena OTHER
sambutan OTHER
Aidilfitri time
minggu time
depan OTHER
Perdana OTHER
Menteri person
Tun person
Dr person
Mahathir person
Mohamad person
dan OTHER
Menteri person
Pengangkutan person
Anthony person
Loke person
Siew person
Fook person
menitipkan person
pesanan OTHER
khas OTHER
kepada OTHER
orang quantity
ramai quantity
yang OTHER
mahu OTHER
pulang OTHER
ke OTHER
kampung location
halaman location
masing-masing OTHER
Dalam OTHER
video OTHER
pendek OTHER
terbitan OTHER
Jabatan organization
Keselamatan organization
Jalan organization
Raya organization
(Jkjr) location
itu OTHER
Dr person
Mahathir person
menasihati OTHER
mereka OTHER
supaya OTHER
berhenti OTHER
berehat OTHER
dan OTHER
tidur OTHER
sebentar OTHER
sekiranya OTHER
mengantuk location
ketika OTHER
memandu OTHER


train minibatch loop: 100%|██████████| 9191/9191 [1:12:38<00:00,  2.92it/s, accuracy=1, cost=0.00787]   
test minibatch loop: 100%|██████████| 2298/2298 [08:22<00:00,  4.58it/s, accuracy=0.94, cost=21.3]  
train minibatch loop:   0%|          | 0/9191 [00:00<?, ?it/s]

time taken: 4860.948942899704
epoch: 2, training loss: 0.703229, training acc: 0.994912, valid loss: 14.662600, valid acc: 0.930679

Kuala location
Lumpur location
Sempena OTHER
sambutan OTHER
Aidilfitri time
minggu time
depan OTHER
Perdana person
Menteri person
Tun person
Dr person
Mahathir person
Mohamad person
dan OTHER
Menteri person
Pengangkutan person
Anthony person
Loke person
Siew person
Fook person
menitipkan person
pesanan OTHER
khas OTHER
kepada OTHER
orang quantity
ramai quantity
yang OTHER
mahu OTHER
pulang OTHER
ke OTHER
kampung location
halaman location
masing-masing OTHER
Dalam OTHER
video OTHER
pendek OTHER
terbitan OTHER
Jabatan organization
Keselamatan organization
Jalan organization
Raya organization
(Jkjr) location
itu OTHER
Dr person
Mahathir person
menasihati OTHER
mereka OTHER
supaya OTHER
berhenti OTHER
berehat OTHER
dan OTHER
tidur OTHER
sebentar OTHER
sekiranya OTHER
mengantuk location
ketika OTHER
memandu OTHER


train minibatch loop: 100%|██████████| 9191/9191 [1:12:51<00:00,  2.69it/s, accuracy=1, cost=0.0424]    
test minibatch loop: 100%|██████████| 2298/2298 [08:45<00:00,  4.37it/s, accuracy=0.936, cost=22.5] 
train minibatch loop:   0%|          | 0/9191 [00:00<?, ?it/s]

time taken: 4896.710217475891
epoch: 3, training loss: 0.304583, training acc: 0.997820, valid loss: 14.302636, valid acc: 0.943620

Kuala location
Lumpur location
Sempena OTHER
sambutan OTHER
Aidilfitri event
minggu time
depan time
Perdana person
Menteri person
Tun person
Dr person
Mahathir person
Mohamad person
dan OTHER
Menteri person
Pengangkutan person
Anthony person
Loke person
Siew person
Fook person
menitipkan person
pesanan OTHER
khas OTHER
kepada OTHER
orang OTHER
ramai OTHER
yang OTHER
mahu OTHER
pulang OTHER
ke OTHER
kampung location
halaman location
masing-masing OTHER
Dalam OTHER
video OTHER
pendek OTHER
terbitan OTHER
Jabatan organization
Keselamatan organization
Jalan organization
Raya organization
(Jkjr) location
itu OTHER
Dr person
Mahathir person
menasihati OTHER
mereka OTHER
supaya OTHER
berhenti OTHER
berehat organization
dan OTHER
tidur OTHER
sebentar OTHER
sekiranya OTHER
mengantuk location
ketika OTHER
memandu OTHER


train minibatch loop: 100%|██████████| 9191/9191 [1:12:49<00:00,  2.64it/s, accuracy=1, cost=0.000336]  
test minibatch loop: 100%|██████████| 2298/2298 [08:38<00:00,  4.43it/s, accuracy=0.94, cost=24.1]  
train minibatch loop:   0%|          | 0/9191 [00:00<?, ?it/s]

time taken: 4888.042459726334
epoch: 4, training loss: 0.149868, training acc: 0.998946, valid loss: 15.276245, valid acc: 0.947501

Kuala location
Lumpur location
Sempena OTHER
sambutan OTHER
Aidilfitri event
minggu time
depan time
Perdana person
Menteri person
Tun person
Dr person
Mahathir person
Mohamad person
dan OTHER
Menteri person
Pengangkutan person
Anthony person
Loke person
Siew person
Fook person
menitipkan person
pesanan OTHER
khas OTHER
kepada OTHER
orang OTHER
ramai OTHER
yang OTHER
mahu OTHER
pulang OTHER
ke OTHER
kampung location
halaman location
masing-masing OTHER
Dalam OTHER
video OTHER
pendek OTHER
terbitan OTHER
Jabatan organization
Keselamatan organization
Jalan organization
Raya organization
(Jkjr) location
itu OTHER
Dr person
Mahathir person
menasihati OTHER
mereka OTHER
supaya OTHER
berhenti OTHER
berehat OTHER
dan OTHER
tidur OTHER
sebentar OTHER
sekiranya OTHER
mengantuk location
ketika OTHER
memandu OTHER


train minibatch loop: 100%|██████████| 9191/9191 [1:12:40<00:00,  2.62it/s, accuracy=1, cost=0.00242]   
test minibatch loop: 100%|██████████| 2298/2298 [08:39<00:00,  4.42it/s, accuracy=0.944, cost=13.5] 

time taken: 4880.330358505249
epoch: 5, training loss: 0.096976, training acc: 0.999330, valid loss: 13.203072, valid acc: 0.951456

Kuala location
Lumpur location
Sempena OTHER
sambutan OTHER
Aidilfitri event
minggu time
depan time
Perdana person
Menteri person
Tun person
Dr person
Mahathir person
Mohamad person
dan OTHER
Menteri person
Pengangkutan person
Anthony person
Loke person
Siew person
Fook person
menitipkan person
pesanan OTHER
khas OTHER
kepada OTHER
orang OTHER
ramai OTHER
yang OTHER
mahu OTHER
pulang OTHER
ke OTHER
kampung location
halaman location
masing-masing OTHER
Dalam OTHER
video OTHER
pendek OTHER
terbitan OTHER
Jabatan organization
Keselamatan organization
Jalan organization
Raya organization
(Jkjr) location
itu OTHER
Dr person
Mahathir person
menasihati OTHER
mereka OTHER
supaya OTHER
berhenti OTHER
berehat OTHER
dan OTHER
tidur OTHER
sebentar OTHER
sekiranya OTHER
mengantuk location
ketika OTHER
memandu OTHER





In [14]:
sequence = entities_textcleaning('mahathir suka akta 19977')[1]
X_seq = char_str_idx([sequence], word2idx, 2)
X_char_seq = generate_char_seq(X_seq)

predicted = sess.run(model.tags_seq,
            feed_dict = {
                model.word_ids: X_seq,
                model.char_ids: X_char_seq,
            },
    )[0]

for i in range(len(predicted)):
    print(sequence[i],idx2tag[predicted[i]])

mahathir person
suka OTHER
akta law
19977 law


In [15]:
def pred2label(pred):
    out = []
    for pred_i in pred:
        out_i = []
        for p in pred_i:
            out_i.append(idx2tag[p])
        out.append(out_i)
    return out

In [16]:
real_Y, predict_Y = [], []

pbar = tqdm(
    range(0, len(test_X), batch_size), desc = 'validation minibatch loop'
)
for i in pbar:
    batch_x = test_X[i : min(i + batch_size, test_X.shape[0])]
    batch_char = generate_char_seq(batch_x)
    batch_y = test_Y[i : min(i + batch_size, test_X.shape[0])]
    predicted = pred2label(sess.run(model.tags_seq,
            feed_dict = {
                model.word_ids: batch_x,
                model.char_ids: batch_char,
            },
    ))
    real = pred2label(batch_y)
    predict_Y.extend(predicted)
    real_Y.extend(real)

validation minibatch loop: 100%|██████████| 2298/2298 [08:19<00:00,  4.60it/s]


In [17]:
from sklearn.metrics import classification_report
print(classification_report(np.array(real_Y).ravel(), np.array(predict_Y).ravel(),
                           digits = 6))

              precision    recall  f1-score   support

       OTHER   0.971261  0.995777  0.983367   5160854
       event   0.991603  0.181505  0.306844    143787
         law   0.989329  0.879490  0.931181    146950
    location   0.848477  0.960757  0.901133    428869
organization   0.960967  0.761301  0.849560    694150
      person   0.850705  0.969984  0.906437    507960
    quantity   0.996606  0.972120  0.984211     88200
        time   0.879509  0.986763  0.930054    179880

    accuracy                       0.951052   7350650
   macro avg   0.936057  0.838462  0.849098   7350650
weighted avg   0.953613  0.951052  0.945046   7350650



In [18]:
saver = tf.train.Saver(tf.trainable_variables())
saver.save(sess, 'concat/model.ckpt')

strings = ','.join(
    [
        n.name
        for n in tf.get_default_graph().as_graph_def().node
        if ('Variable' in n.op
        or 'Placeholder' in n.name
        or 'logits' in n.name
        or 'alphas' in n.name)
        and 'Adam' not in n.name
        and 'beta' not in n.name
        and 'OptimizeLoss' not in n.name
        and 'Global_Step' not in n.name
    ]
)
strings.split(',')

['Placeholder',
 'Placeholder_1',
 'Placeholder_2',
 'Variable',
 'Variable_1',
 'bidirectional_rnn_char_0/fw/lstm_cell/kernel',
 'bidirectional_rnn_char_0/fw/lstm_cell/bias',
 'bidirectional_rnn_char_0/bw/lstm_cell/kernel',
 'bidirectional_rnn_char_0/bw/lstm_cell/bias',
 'bidirectional_rnn_char_1/fw/lstm_cell/kernel',
 'bidirectional_rnn_char_1/fw/lstm_cell/bias',
 'bidirectional_rnn_char_1/bw/lstm_cell/kernel',
 'bidirectional_rnn_char_1/bw/lstm_cell/bias',
 'bidirectional_rnn_word_0/fw/lstm_cell/kernel',
 'bidirectional_rnn_word_0/fw/lstm_cell/bias',
 'bidirectional_rnn_word_0/bw/lstm_cell/kernel',
 'bidirectional_rnn_word_0/bw/lstm_cell/bias',
 'bidirectional_rnn_word_1/fw/lstm_cell/kernel',
 'bidirectional_rnn_word_1/fw/lstm_cell/bias',
 'bidirectional_rnn_word_1/bw/lstm_cell/kernel',
 'bidirectional_rnn_word_1/bw/lstm_cell/bias',
 'dense/kernel',
 'dense/bias',
 'transitions',
 'logits']

In [19]:
def freeze_graph(model_dir, output_node_names):

    if not tf.gfile.Exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exists. Please specify an export "
            'directory: %s' % model_dir
        )

    checkpoint = tf.train.get_checkpoint_state(model_dir)
    input_checkpoint = checkpoint.model_checkpoint_path

    absolute_model_dir = '/'.join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_dir + '/frozen_model.pb'
    clear_devices = True
    with tf.Session(graph = tf.Graph()) as sess:
        saver = tf.train.import_meta_graph(
            input_checkpoint + '.meta', clear_devices = clear_devices
        )
        saver.restore(sess, input_checkpoint)
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            tf.get_default_graph().as_graph_def(),
            output_node_names.split(','),
        )
        with tf.gfile.GFile(output_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
        print('%d ops in the final graph.' % len(output_graph_def.node))
        
def load_graph(frozen_graph_filename):
    with tf.gfile.GFile(frozen_graph_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def)
    return graph


In [20]:
freeze_graph('concat', strings)

W0805 08:38:52.328392 140268250425152 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.
W0805 08:38:52.601029 140268250425152 deprecation.py:323] From <ipython-input-19-3d8392da7830>:23: convert_variables_to_constants (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.convert_variables_to_constants`
W0805 08:38:52.601755 140268250425152 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/graph_util_impl.py:270: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.com

1532 ops in the final graph.


In [21]:
g = load_graph('concat/frozen_model.pb')

In [22]:
word_ids = g.get_tensor_by_name('import/Placeholder:0')
char_ids = g.get_tensor_by_name('import/Placeholder_1:0')
tags_seq = g.get_tensor_by_name('import/logits:0')
tags_state_fw = g.get_tensor_by_name('import/transitions:0')
tags_state_bw = g.get_tensor_by_name('import/Variable:0')
test_sess = tf.InteractiveSession(graph = g)
predicted = test_sess.run([tags_seq, tags_state_fw, tags_state_bw],
            feed_dict = {
                word_ids: X_seq,
                char_ids: X_char_seq,
            })



In [23]:
predicted

[array([[4, 2, 8, 8]], dtype=int32),
 array([[ 0.47884172, -0.16190596, -2.4069993 , -1.3115252 , -1.9645609 ,
         -0.20467333, -1.109368  , -0.4051585 , -0.73241717, -0.57735586],
        [-0.41838917, -0.40878856, -3.450508  , -1.0079795 , -1.126499  ,
         -0.8411751 , -0.4628218 , -0.53364605, -0.6538642 , -0.98395264],
        [-1.4484956 , -2.4792776 ,  1.0506705 , -1.0934587 , -0.9977262 ,
         -1.2000083 , -0.6880191 , -1.6287059 , -1.5211865 , -1.2802426 ],
        [-1.041553  , -1.6351454 , -1.4314326 ,  1.9525024 , -1.3669676 ,
         -1.4555649 , -0.41423967, -1.5732112 , -3.9457858 , -3.0300593 ],
        [-1.3480742 , -2.1231194 , -1.2117796 , -2.1732378 ,  1.8638    ,
         -0.6669971 , -1.952644  , -1.663421  , -3.3421414 , -1.807233  ],
        [-0.3863962 , -0.47493416, -1.2358948 , -1.161847  , -0.49586853,
          1.8206847 , -0.7389852 , -3.117508  , -1.6827074 , -1.134648  ],
        [-0.9731228 , -1.271768  , -0.81208587, -2.388461  , -0.62136