In [1]:
import pickle
import json
import tensorflow as tf
import numpy as np

In [2]:
with open('session.pkl', 'rb') as fopen:
    data = pickle.load(fopen)
data.keys()

dict_keys(['train_X', 'test_X', 'train_Y', 'test_Y'])

In [3]:
train_X = data['train_X']
test_X = data['test_X']
train_Y = data['train_Y']
test_Y = data['test_Y']

In [4]:
with open('dictionary-entities.json') as fopen:
    dictionary = json.load(fopen)
dictionary.keys()

dict_keys(['word2idx', 'idx2word', 'tag2idx', 'idx2tag', 'char2idx'])

In [5]:
word2idx = dictionary['word2idx']
idx2word = {int(k): v for k, v in dictionary['idx2word'].items()}
tag2idx = dictionary['tag2idx']
idx2tag = {int(k): v for k, v in dictionary['idx2tag'].items()}
char2idx = dictionary['char2idx']

In [6]:
list(zip([idx2word[d] for d in train_X[-1]], [idx2tag[d] for d in train_Y[-1]]))

[('harahap', 'person'),
 (',', 'OTHER'),
 ('pedagang', 'OTHER'),
 ('ikan', 'OTHER'),
 ('kering', 'OTHER'),
 (',', 'OTHER'),
 ('kering', 'OTHER'),
 (',', 'OTHER'),
 ('para', 'OTHER'),
 ('pembeli', 'OTHER'),
 ('mayoritas', 'OTHER'),
 ('datang', 'OTHER'),
 ('dari', 'OTHER'),
 ('luar', 'OTHER'),
 ('kota', 'OTHER'),
 ('seperti', 'OTHER'),
 ('meranti', 'location'),
 (',', 'OTHER'),
 ('riau', 'location'),
 (',', 'OTHER'),
 ('jambi', 'location'),
 (',', 'OTHER'),
 ('dan', 'OTHER'),
 ('batam', 'location'),
 ('.', 'OTHER'),
 ('cabai', 'OTHER'),
 ('kering', 'OTHER'),
 (',', 'OTHER'),
 ('para', 'OTHER'),
 ('pembeli', 'OTHER'),
 ('mayoritas', 'OTHER'),
 ('datang', 'OTHER'),
 ('dari', 'OTHER'),
 ('luar', 'OTHER'),
 ('kota', 'OTHER'),
 ('seperti', 'OTHER'),
 ('chuah', 'location'),
 (',', 'OTHER'),
 ('riau', 'location'),
 (',', 'OTHER'),
 ('jambi', 'location'),
 (',', 'OTHER'),
 ('dan', 'OTHER'),
 ('batam', 'location'),
 ('.', 'OTHER'),
 ('cabai', 'OTHER'),
 ('para', 'OTHER'),
 ('pembeli', 'OTHER'),
 

In [7]:
list(zip([idx2word[d] for d in train_X[1]], [idx2tag[d] for d in train_Y[1]]))

[('politik', 'OTHER'),
 ('dari', 'OTHER'),
 ('Universitas', 'organization'),
 ('Gadjah', 'organization'),
 ('Mada', 'organization'),
 (',', 'OTHER'),
 ('Arie', 'person'),
 ('Sudjito', 'person'),
 (',', 'OTHER'),
 ('menilai,', 'OTHER'),
 ('keinginan', 'OTHER'),
 ('Ketua', 'OTHER'),
 ('Umum', 'OTHER'),
 ('Partai', 'organization'),
 ('Golkar', 'organization'),
 ('Aburizal', 'person'),
 ('Bakrie', 'person'),
 ('untuk', 'OTHER'),
 ('maju', 'OTHER'),
 ('kembali', 'OTHER'),
 ('sebagai', 'OTHER'),
 ('ketua', 'OTHER'),
 ('umum', 'OTHER'),
 ('merupakan', 'OTHER'),
 ('pemaksaan', 'OTHER'),
 ('kehendak.', 'OTHER'),
 ('Menurut', 'OTHER'),
 ('dia,', 'OTHER'),
 ('ada', 'OTHER'),
 ('kesan', 'OTHER'),
 ('bahwa', 'OTHER'),
 ('Aburizal', 'person'),
 ('menggunakan', 'OTHER'),
 ('segala', 'OTHER'),
 ('cara', 'OTHER'),
 ('untuk', 'OTHER'),
 ('memuluskan', 'OTHER'),
 ('jalannya', 'OTHER'),
 ('kembali', 'OTHER'),
 ('menduduki', 'OTHER'),
 ('Golkar', 'organization'),
 ('1.', 'OTHER'),
 ('Hal', 'OTHER'),
 ('ini

In [8]:
def generate_char_seq(batch):
    x = [[len(idx2word[i]) for i in k] for k in batch]
    maxlen = max([j for i in x for j in i])
    temp = np.zeros((batch.shape[0],batch.shape[1],maxlen),dtype=np.int32)
    for i in range(batch.shape[0]):
        for k in range(batch.shape[1]):
            for no, c in enumerate(idx2word[batch[i,k]]):
                temp[i,k,-1-no] = char2idx[c]
    return temp

In [9]:
generate_char_seq(data['train_X'][:10]).shape

(10, 50, 11)

In [10]:
class Model:
    def __init__(
        self,
        dim_word,
        dim_char,
        dropout,
        learning_rate,
        hidden_size_char,
        hidden_size_word,
        num_layers,
    ):
        def cells(size, reuse = False):
            return tf.contrib.rnn.DropoutWrapper(
                tf.nn.rnn_cell.LSTMCell(
                    size,
                    initializer = tf.orthogonal_initializer(),
                    reuse = reuse,
                ),
                output_keep_prob = dropout,
            )

        def bahdanau(embedded, size):
            attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
                num_units = hidden_size_word, memory = embedded
            )
            return tf.contrib.seq2seq.AttentionWrapper(
                cell = cells(hidden_size_word),
                attention_mechanism = attention_mechanism,
                attention_layer_size = hidden_size_word,
            )

        self.word_ids = tf.placeholder(tf.int32, shape = [None, None])
        self.char_ids = tf.placeholder(tf.int32, shape = [None, None, None])
        self.labels = tf.placeholder(tf.int32, shape = [None, None])
        self.maxlen = tf.shape(self.word_ids)[1]
        self.lengths = tf.count_nonzero(self.word_ids, 1)
        
        self.word_embeddings = tf.Variable(
            tf.truncated_normal(
                [len(word2idx), dim_word], stddev = 1.0 / np.sqrt(dim_word)
            )
        )
        self.char_embeddings = tf.Variable(
            tf.truncated_normal(
                [len(char2idx), dim_char], stddev = 1.0 / np.sqrt(dim_char)
            )
        )

        word_embedded = tf.nn.embedding_lookup(
            self.word_embeddings, self.word_ids
        )
        char_embedded = tf.nn.embedding_lookup(
            self.char_embeddings, self.char_ids
        )
        s = tf.shape(char_embedded)
        char_embedded = tf.reshape(
            char_embedded, shape = [s[0] * s[1], s[-2], dim_char]
        )
        
        for n in range(num_layers):
            (out_fw, out_bw), (
                state_fw,
                state_bw,
            ) = tf.nn.bidirectional_dynamic_rnn(
                cell_fw = cells(hidden_size_char),
                cell_bw = cells(hidden_size_char),
                inputs = char_embedded,
                dtype = tf.float32,
                scope = 'bidirectional_rnn_char_%d' % (n),
            )
            char_embedded = tf.concat((out_fw, out_bw), 2)
        output = tf.reshape(
            char_embedded[:, -1], shape = [s[0], s[1], 2 * hidden_size_char]
        )
        word_embedded = tf.concat([word_embedded, output], axis = -1)

        for n in range(num_layers):
            (out_fw, out_bw), (
                state_fw,
                state_bw,
            ) = tf.nn.bidirectional_dynamic_rnn(
                cell_fw = bahdanau(word_embedded, hidden_size_word),
                cell_bw = bahdanau(word_embedded, hidden_size_word),
                inputs = word_embedded,
                dtype = tf.float32,
                scope = 'bidirectional_rnn_word_%d' % (n),
            )
            word_embedded = tf.concat((out_fw, out_bw), 2)

        logits = tf.layers.dense(word_embedded, len(idx2tag))
        y_t = self.labels
        log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
            logits, y_t, self.lengths
        )
        self.cost = tf.reduce_mean(-log_likelihood)
        self.optimizer = tf.train.AdamOptimizer(
            learning_rate = learning_rate
        ).minimize(self.cost)
        mask = tf.sequence_mask(self.lengths, maxlen = self.maxlen)
        self.tags_seq, tags_score = tf.contrib.crf.crf_decode(
            logits, transition_params, self.lengths
        )
        self.tags_seq = tf.identity(self.tags_seq, name = 'logits')

        y_t = tf.cast(y_t, tf.int32)
        self.prediction = tf.boolean_mask(self.tags_seq, mask)
        mask_label = tf.boolean_mask(y_t, mask)
        correct_pred = tf.equal(self.prediction, mask_label)
        correct_index = tf.cast(correct_pred, tf.float32)
        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

In [11]:
tf.reset_default_graph()
sess = tf.InteractiveSession()

dim_word = 128
dim_char = 256
dropout = 0.8
learning_rate = 1e-3
hidden_size_char = 128
hidden_size_word = 128
num_layers = 2
batch_size = 64

model = Model(dim_word,dim_char,dropout,learning_rate,hidden_size_char,hidden_size_word,num_layers)
sess.run(tf.global_variables_initializer())

W0802 03:09:38.548239 140429853955904 deprecation.py:506] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/util/deprecation.py:507: calling count_nonzero (from tensorflow.python.ops.math_ops) with axis is deprecated and will be removed in a future version.
Instructions for updating:
reduction_indices is deprecated, use axis instead
W0802 03:09:39.054165 140429853955904 lazy_loader.py:50] 
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

W0802 03:09:39.055277 140429853955904 deprecation.py:323] From <ipython-input-10-c6dd067f10d4>:17: LSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.
Instructions for

In [12]:
string = 'KUALA LUMPUR: Sempena sambutan Aidilfitri minggu depan, Perdana Menteri Tun Dr Mahathir Mohamad dan Menteri Pengangkutan Anthony Loke Siew Fook menitipkan pesanan khas kepada orang ramai yang mahu pulang ke kampung halaman masing-masing. Dalam video pendek terbitan Jabatan Keselamatan Jalan Raya (JKJR) itu, Dr Mahathir menasihati mereka supaya berhenti berehat dan tidur sebentar  sekiranya mengantuk ketika memandu.'

import re

def entities_textcleaning(string, lowering = False):
    """
    use by entities recognition, pos recognition and dependency parsing
    """
    string = re.sub('[^A-Za-z0-9\-\/() ]+', ' ', string)
    string = re.sub(r'[ ]+', ' ', string).strip()
    original_string = string.split()
    if lowering:
        string = string.lower()
    string = [
        (original_string[no], word.title() if word.isupper() else word)
        for no, word in enumerate(string.split())
        if len(word)
    ]
    return [s[0] for s in string], [s[1] for s in string]

def char_str_idx(corpus, dic, UNK = 0):
    maxlen = max([len(i) for i in corpus])
    X = np.zeros((len(corpus), maxlen))
    for i in range(len(corpus)):
        for no, k in enumerate(corpus[i][:maxlen][::-1]):
            val = dic[k] if k in dic else UNK
            X[i, -1 - no] = val
    return X

In [15]:
from tqdm import tqdm
import time

for e in range(3):
    lasttime = time.time()
    train_acc, train_loss, test_acc, test_loss = 0, 0, 0, 0
    pbar = tqdm(
        range(0, train_X.shape[0], batch_size), desc = 'train minibatch loop'
    )
    for i in pbar:
        index = min(i + batch_size, train_X.shape[0])
        batch_x = train_X[i : index]
        batch_char = generate_char_seq(batch_x)
        batch_y = train_Y[i : index]
        acc, cost, _ = sess.run(
            [model.accuracy, model.cost, model.optimizer],
            feed_dict = {
                model.word_ids: batch_x,
                model.char_ids: batch_char,
                model.labels: batch_y
            },
        )
        assert not np.isnan(cost)
        train_loss += cost
        train_acc += acc
        pbar.set_postfix(cost = cost, accuracy = acc)
        
    pbar = tqdm(
        range(0, test_X.shape[0], batch_size), desc = 'test minibatch loop'
    )
    for i in pbar:
        index = min(i + batch_size, test_X.shape[0])
        batch_x = test_X[i : index]
        batch_char = generate_char_seq(batch_x)
        batch_y = test_Y[i : index]
        acc, cost = sess.run(
            [model.accuracy, model.cost],
            feed_dict = {
                model.word_ids: batch_x,
                model.char_ids: batch_char,
                model.labels: batch_y
            },
        )
        assert not np.isnan(cost)
        test_loss += cost
        test_acc += acc
        pbar.set_postfix(cost = cost, accuracy = acc)
    
    train_loss /= len(train_X) / batch_size
    train_acc /= len(train_X) / batch_size
    test_loss /= len(test_X) / batch_size
    test_acc /= len(test_X) / batch_size

    print('time taken:', time.time() - lasttime)
    print(
        'epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\n'
        % (e, train_loss, train_acc, test_loss, test_acc)
    )
    
    sequence = entities_textcleaning(string)[1]
    X_seq = char_str_idx([sequence], word2idx, 2)
    X_char_seq = generate_char_seq(X_seq)

    predicted = sess.run(model.tags_seq,
                feed_dict = {
                    model.word_ids: X_seq,
                    model.char_ids: X_char_seq,
                },
        )[0]

    for i in range(len(predicted)):
        print(sequence[i],idx2tag[predicted[i]])

train minibatch loop:  60%|█████▉    | 5474/9191 [1:11:56<55:55,  1.11it/s, accuracy=1, cost=0.00231]     IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

train minibatch loop: 100%|██████████| 9191/9191 [2:00:56<00:00,  1.24it/s, accuracy=1, cost=0.00317]   
test minibatch loop: 100%|██████████| 2298/2298 [13:16<00:00,  4.19it/s, accuracy=0.956, cost=16.7] 
train minibatch loop:   0%|          | 0/9191 [00:00<?, ?it/s]

time taken: 8052.7077124118805
epoch: 0, training loss: 0.189442, training acc: 0.998745, valid loss: 15.468634, valid acc: 0.942164

Kuala location
Lumpur location
Sempena OTHER
sambutan OTHER
Aidilfitri event
minggu time
depan time
Perdana person
Menteri person
Tun person
Dr person
Mahathir person
Mohamad person
dan OTHER
Menteri person
Pengangkutan person
Anthony person
Loke person
Siew person
Fook person
menitipkan person
pesanan OTHER
khas OTHER
kepada OTHER
orang quantity
ramai quantity
yang OTHER
mahu OTHER
pulang OTHER
ke OTHER
kampung OTHER
halaman OTHER
masing-masing OTHER
Dalam OTHER
video OTHER
pendek OTHER
terbitan OTHER
Jabatan organization
Keselamatan organization
Jalan organization
Raya organization
(Jkjr) person
itu OTHER
Dr person
Mahathir person
menasihati OTHER
mereka OTHER
supaya OTHER
berhenti OTHER
berehat OTHER
dan OTHER
tidur OTHER
sebentar OTHER
sekiranya OTHER
mengantuk person
ketika OTHER
memandu OTHER


train minibatch loop: 100%|██████████| 9191/9191 [1:59:51<00:00,  1.26it/s, accuracy=1, cost=0.00309]     
test minibatch loop: 100%|██████████| 2298/2298 [12:32<00:00,  2.62it/s, accuracy=0.944, cost=21]   


time taken: 7944.539978504181
epoch: 1, training loss: 0.155295, training acc: 0.998993, valid loss: 14.770646, valid acc: 0.945195



train minibatch loop:   0%|          | 0/9191 [00:00<?, ?it/s]

Kuala location
Lumpur location
Sempena OTHER
sambutan OTHER
Aidilfitri OTHER
minggu OTHER
depan OTHER
Perdana person
Menteri person
Tun person
Dr person
Mahathir person
Mohamad person
dan OTHER
Menteri person
Pengangkutan person
Anthony person
Loke person
Siew person
Fook person
menitipkan person
pesanan OTHER
khas OTHER
kepada OTHER
orang organization
ramai OTHER
yang OTHER
mahu OTHER
pulang OTHER
ke OTHER
kampung location
halaman location
masing-masing OTHER
Dalam OTHER
video OTHER
pendek OTHER
terbitan OTHER
Jabatan organization
Keselamatan organization
Jalan organization
Raya organization
(Jkjr) person
itu OTHER
Dr person
Mahathir person
menasihati OTHER
mereka OTHER
supaya OTHER
berhenti person
berehat person
dan OTHER
tidur OTHER
sebentar OTHER
sekiranya OTHER
mengantuk OTHER
ketika OTHER
memandu OTHER


train minibatch loop: 100%|██████████| 9191/9191 [1:54:50<00:00,  1.73it/s, accuracy=1, cost=0.000509]    
test minibatch loop: 100%|██████████| 2298/2298 [12:40<00:00,  3.47it/s, accuracy=0.956, cost=22.7] 

time taken: 7651.497144460678
epoch: 2, training loss: 0.116464, training acc: 0.999273, valid loss: 13.835564, valid acc: 0.948873

Kuala location
Lumpur location
Sempena OTHER
sambutan OTHER
Aidilfitri event
minggu OTHER
depan OTHER
Perdana organization
Menteri person
Tun person
Dr person
Mahathir person
Mohamad person
dan OTHER
Menteri person
Pengangkutan person
Anthony person
Loke person
Siew person
Fook person
menitipkan person
pesanan OTHER
khas OTHER
kepada OTHER
orang OTHER
ramai OTHER
yang OTHER
mahu OTHER
pulang OTHER
ke OTHER
kampung location
halaman OTHER
masing-masing OTHER
Dalam OTHER
video OTHER
pendek OTHER
terbitan OTHER
Jabatan organization
Keselamatan organization
Jalan organization
Raya organization
(Jkjr) person
itu OTHER
Dr person
Mahathir person
menasihati OTHER
mereka OTHER
supaya OTHER
berhenti OTHER
berehat OTHER
dan OTHER
tidur OTHER
sebentar OTHER
sekiranya OTHER
mengantuk organization
ketika OTHER
memandu OTHER





In [16]:
sequence = entities_textcleaning('mahathir suka akta 19977')[1]
X_seq = char_str_idx([sequence], word2idx, 2)
X_char_seq = generate_char_seq(X_seq)

predicted = sess.run(model.tags_seq,
            feed_dict = {
                model.word_ids: X_seq,
                model.char_ids: X_char_seq,
            },
    )[0]

for i in range(len(predicted)):
    print(sequence[i],idx2tag[predicted[i]])

mahathir person
suka OTHER
akta law
19977 law


In [17]:
def pred2label(pred):
    out = []
    for pred_i in pred:
        out_i = []
        for p in pred_i:
            out_i.append(idx2tag[p])
        out.append(out_i)
    return out

In [18]:
real_Y, predict_Y = [], []

pbar = tqdm(
    range(0, len(test_X), batch_size), desc = 'validation minibatch loop'
)
for i in pbar:
    batch_x = test_X[i : min(i + batch_size, test_X.shape[0])]
    batch_char = generate_char_seq(batch_x)
    batch_y = test_Y[i : min(i + batch_size, test_X.shape[0])]
    predicted = pred2label(sess.run(model.tags_seq,
            feed_dict = {
                model.word_ids: batch_x,
                model.char_ids: batch_char,
            },
    ))
    real = pred2label(batch_y)
    predict_Y.extend(predicted)
    real_Y.extend(real)

validation minibatch loop: 100%|██████████| 2298/2298 [12:14<00:00,  4.29it/s]


In [19]:
from sklearn.metrics import classification_report
print(classification_report(np.array(real_Y).ravel(), np.array(predict_Y).ravel(),
                           digits = 6))

              precision    recall  f1-score   support

       OTHER   0.974847  0.994647  0.984648   5160854
       event   0.984159  0.230737  0.373830    143787
         law   0.981267  0.869745  0.922146    146950
    location   0.790109  0.969399  0.870619    428869
organization   0.950195  0.736809  0.830007    694150
      person   0.894418  0.951801  0.922218    507960
    quantity   0.873435  0.996122  0.930753     88200
        time   0.830533  0.994663  0.905218    179880

    accuracy                       0.948443   7350650
   macro avg   0.909871  0.842990  0.842430   7350650
weighted avg   0.951745  0.948443  0.943289   7350650



In [20]:
saver = tf.train.Saver(tf.trainable_variables())
saver.save(sess, 'bahdanau/model.ckpt')

strings = ','.join(
    [
        n.name
        for n in tf.get_default_graph().as_graph_def().node
        if ('Variable' in n.op
        or 'Placeholder' in n.name
        or 'logits' in n.name
        or 'alphas' in n.name)
        and 'Adam' not in n.name
        and 'beta' not in n.name
        and 'OptimizeLoss' not in n.name
        and 'Global_Step' not in n.name
    ]
)
strings.split(',')

['Placeholder',
 'Placeholder_1',
 'Placeholder_2',
 'Variable',
 'Variable_1',
 'bidirectional_rnn_char_0/fw/lstm_cell/kernel',
 'bidirectional_rnn_char_0/fw/lstm_cell/bias',
 'bidirectional_rnn_char_0/bw/lstm_cell/kernel',
 'bidirectional_rnn_char_0/bw/lstm_cell/bias',
 'bidirectional_rnn_char_1/fw/lstm_cell/kernel',
 'bidirectional_rnn_char_1/fw/lstm_cell/bias',
 'bidirectional_rnn_char_1/bw/lstm_cell/kernel',
 'bidirectional_rnn_char_1/bw/lstm_cell/bias',
 'memory_layer/kernel',
 'memory_layer_1/kernel',
 'bidirectional_rnn_word_0/fw/attention_wrapper/lstm_cell/kernel',
 'bidirectional_rnn_word_0/fw/attention_wrapper/lstm_cell/bias',
 'bidirectional_rnn_word_0/fw/attention_wrapper/bahdanau_attention/query_layer/kernel',
 'bidirectional_rnn_word_0/fw/attention_wrapper/bahdanau_attention/attention_v',
 'bidirectional_rnn_word_0/fw/attention_wrapper/attention_layer/kernel',
 'bidirectional_rnn_word_0/bw/attention_wrapper/lstm_cell/kernel',
 'bidirectional_rnn_word_0/bw/attention_wrapp

In [21]:
def freeze_graph(model_dir, output_node_names):

    if not tf.gfile.Exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exists. Please specify an export "
            'directory: %s' % model_dir
        )

    checkpoint = tf.train.get_checkpoint_state(model_dir)
    input_checkpoint = checkpoint.model_checkpoint_path

    absolute_model_dir = '/'.join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_dir + '/frozen_model.pb'
    clear_devices = True
    with tf.Session(graph = tf.Graph()) as sess:
        saver = tf.train.import_meta_graph(
            input_checkpoint + '.meta', clear_devices = clear_devices
        )
        saver.restore(sess, input_checkpoint)
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            tf.get_default_graph().as_graph_def(),
            output_node_names.split(','),
        )
        with tf.gfile.GFile(output_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
        print('%d ops in the final graph.' % len(output_graph_def.node))
        
def load_graph(frozen_graph_filename):
    with tf.gfile.GFile(frozen_graph_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def)
    return graph


In [22]:
freeze_graph('bahdanau', strings)

W0803 00:46:58.110622 140429853955904 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.
W0803 00:46:58.593833 140429853955904 deprecation.py:323] From <ipython-input-21-3d8392da7830>:23: convert_variables_to_constants (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.convert_variables_to_constants`
W0803 00:46:58.594611 140429853955904 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/graph_util_impl.py:270: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.com

1928 ops in the final graph.


In [23]:
g = load_graph('bahdanau/frozen_model.pb')

In [25]:
word_ids = g.get_tensor_by_name('import/Placeholder:0')
char_ids = g.get_tensor_by_name('import/Placeholder_1:0')
tags_seq = g.get_tensor_by_name('import/logits:0')
tags_state_fw = g.get_tensor_by_name('import/transitions:0')
tags_state_bw = g.get_tensor_by_name('import/Variable:0')
test_sess = tf.InteractiveSession(graph = g)
predicted = test_sess.run([tags_seq, tags_state_fw, tags_state_bw],
            feed_dict = {
                word_ids: X_seq,
                char_ids: X_char_seq,
            })



In [26]:
predicted

[array([[6, 2, 8, 8]], dtype=int32),
 array([[-0.08600868, -0.3673812 , -6.0727143 , -4.3111515 , -3.144992  ,
         -3.6502264 , -2.2521126 , -2.1055558 , -2.5405333 , -1.5567911 ],
        [-0.19181503, -0.69721836, -5.6127443 , -4.5875454 , -3.4643524 ,
         -3.7046876 , -2.973486  , -1.9980274 , -2.3059711 , -2.2200272 ],
        [-4.2048063 , -4.8778405 ,  1.7996664 , -1.4549167 , -1.4113306 ,
         -1.3110472 , -0.79679346, -1.2164339 , -1.6451806 , -1.3484511 ],
        [-3.3417375 , -4.240648  , -1.4909729 ,  1.7231508 , -2.080556  ,
         -2.6205368 , -1.1666468 , -3.0111673 , -5.055568  , -3.860059  ],
        [-2.7892244 , -2.660718  , -0.8998524 , -2.5417776 ,  2.0364769 ,
         -1.7014722 , -2.2487772 , -2.2349706 , -4.590649  , -2.542902  ],
        [-2.6509175 , -3.3361633 , -1.3775858 , -1.8805891 , -1.8057353 ,
          1.5868361 , -2.1673765 , -4.4093328 , -2.3444242 , -2.2622836 ],
        [-2.6717138 , -3.228594  , -0.9509716 , -3.610709  , -1.41448