In [1]:
import sys
import os
sys.path.append(os.getcwd())
import time
import random
import math
import re
import json
import pickle
from typing import List, Tuple, Dict, Callable, Optional, Any, Sequence, Mapping, NamedTuple
from datetime import datetime

In [2]:
import tensorflow as tf
import numpy as np
import matplotlib as plt

In [3]:
from utils.config import Config
from data_loader.ptb_datasource import PTBDataSource
from utils import transformer

In [4]:
configs = []
configs.append(Config(num_layers=2, num_units=512, batch_size=128, log_dir='./logs/transformer/'))
for config in configs:
    print(config)

Config(num_units=512, num_layers=2, num_heads=8, num_outputs=10000, batch_size=128, max_length=50, dropout_in_rate=0.1, dropout_out_rate=0.2, learning_rate=0.001, grad_clip=5.0, is_layer_norm=False, data_path='./data/', log_dir='./logs/transformer/')


In [5]:
data = PTBDataSource(configs[0])
print(data.vocab_size)

10000


In [6]:
class TransformerEncoder:
    
    def __init__(self, config, vocab_size, reuse=None):
        self._config = config
        self.vocab_size = vocab_size
        
        self._create_placeholder()
        self._create_model(reuse)
        self.loss = self._create_loss()
        self.accuracy = self._create_accuracy()
        
    def _create_placeholder(self):
        self.is_training = tf.placeholder(shape=(), dtype=tf.bool, name='is_training')
        self.inputs_data = tf.placeholder(shape=[None, self._config.max_length], name='inputs_data', dtype=tf.int32)  # batch_size x max_length
        self.target_classes = tf.placeholder(shape=[None], name='targets_classes', dtype=tf.int32)  # batch_size
        
    def _create_model(self, reuse):
        with tf.variable_scope('transformer', reuse=reuse):
            self.global_step = tf.train.get_or_create_global_step()
            self.encoder_inputs = self.inputs_data
            encoded_queries = self.encode()
            self.outputs_logits = tf.layers.dense(encoded_queries, self._config.num_outputs, name='outputs_layer')
            self.predicted_classes = tf.to_int32(tf.argmax(self.outputs_logits, axis=-1))
            
    def _create_loss(self):
        target_classes_one_hot = tf.one_hot(self.target_classes, self._config.num_outputs)
        target_classes_smoothed = self._label_smoothing(target_classes_one_hot)
        cross_ent = tf.nn.softmax_cross_entropy_with_logits_v2(logits=self.outputs_logits, labels=target_classes_one_hot)
        return tf.reduce_mean(cross_ent)
    
    def _create_accuracy(self):
        correct = tf.equal(self.target_classes, self.predicted_classes)
        return tf.reduce_mean(tf.to_float(correct))
       
    def encode(self):
        with tf.variable_scope('encoder'):
            encoder_inputs_embedded = transformer.embedding(
                self.encoder_inputs,
                self._config.num_outputs,
                self._config.num_units,
                is_scale=True,
                scope='enc_embed'
            )
            encoder_inputs_embedded += transformer.positional_encoding(
                self.encoder_inputs,
                num_units=self._config.num_units,
                is_zero_pad=True,
            )
            encoder_inputs_embedded = tf.layers.dropout(
                encoder_inputs_embedded,
                rate=self._config.dropout_in_rate,
                training=self.is_training
            )

            sent_encoder_init_state = tf.get_variable(  # [1, hidden_units]
                'sent_encoder_init_state',
                dtype=tf.float32,
                shape=[1, self._config.num_units],
                initializer=tf.contrib.layers.xavier_initializer()
            )
            sent_encoder_inputs_embedded = tf.tile(
                tf.expand_dims(sent_encoder_init_state, 0),  # [1, 1, hidden_units]
                [tf.shape(encoder_inputs_embedded)[0], 1, 1]  # [batch_size, 1, hidden_units]
            )

            for i in range(self._config.num_layers):
                with tf.variable_scope('block_{}'.format(i)):
                    encoder_inputs_embedded = transformer.multihead_attention(
                        queries=encoder_inputs_embedded,
                        keys=encoder_inputs_embedded,
                        is_training=self.is_training,
                        dropout_rate=self._config.dropout_in_rate,
                        num_units=self._config.num_units,
                        num_heads=self._config.num_heads,
                        is_causality=False,
                        scope='self_attention'
                    )

                    sent_encoder_inputs_embedded = transformer.multihead_attention(
                        queries=sent_encoder_inputs_embedded,
                        keys=encoder_inputs_embedded,
                        is_training=self.is_training,
                        dropout_rate=self._config.dropout_in_rate,
                        num_units=self._config.num_units,
                        num_heads=self._config.num_heads,
                        is_causality=False,
                        scope='scale_dot_attention'
                    )
                    sent_encoder_inputs_embedded = transformer.feedforward(
                        sent_encoder_inputs_embedded,
                        num_units=[4*self._config.num_units, self._config.num_units],
                        scope='feedforward'
                    )
            sent_encoder_inputs_embedded = tf.squeeze(sent_encoder_inputs_embedded, 1)
            return sent_encoder_inputs_embedded
    
    def _label_smoothing(self, inputs, epsilon: float=0.1):
        feature_dim = inputs.get_shape().as_list()[-1]
        return (1-epsilon) * inputs + (epsilon / feature_dim)
    

In [7]:
num_epoch = 1000

In [8]:
# with tf.device('/device:GPU:1'):
#     datasource = PTBDataSource(config)
#     model = TransformerEncoder(config, data.vocab_size)
#     tf_config = tf.ConfigProto(
#             allow_soft_placement=True,
#             log_device_placement=True,
#             gpu_options=tf.GPUOptions(
#                 allow_growth=True,
#             ))
#     with tf.Session(config=tf_config) as sess:
#         sess.run(tf.global_variables_initializer())
#         sents = sess.run([model.t1, model.t2], data.feed_test_transformer(model))
#         print(sents)

In [9]:
# test = data.feed_test_transformer(model)
# print(test[model.inputs_data][11])
# print(test[model.target_classes][10])

In [10]:
def start(config):
    with tf.device('/device:GPU:1'):
        now = datetime.now()
        logdir = now.strftime("%Y%m%d-%H%M%S") + "/"
        datasource = PTBDataSource(config)
        model = TransformerEncoder(config, data.vocab_size, reuse=False)

        optimizer = tf.train.AdamOptimizer(config.learning_rate)
        train_vars = tf.trainable_variables()
        gradients = tf.global_norm(tf.gradients(model.loss, train_vars))
        train_op = optimizer.minimize(model.loss, global_step=model.global_step)
        with tf.name_scope('training'):
            s_loss = tf.summary.scalar('loss', model.loss)
            s_grad = tf.summary.scalar('gradient', gradients)
            s_acc = tf.summary.scalar('accuracy', model.accuracy)
            s_training = tf.summary.merge([s_loss, s_grad, s_acc])
        with tf.name_scope('test'):
            test_s_acc = tf.summary.scalar('accuracy', model.accuracy)
           
        tf_config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=True,
            gpu_options=tf.GPUOptions(
                allow_growth=True,
            ))
        with tf.Session(config=tf_config) as sess:
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            #saver.restore(sess, config.to_ckpt_path() + '-433615')
            writer = tf.summary.FileWriter(config.to_log_dir() + '/' + logdir, sess.graph)
            for i in range(num_epoch):
                start = time.time()
                datasource.shuffle()
                batch_list = datasource.feed_dict_transformer(model)
                losses = []
                accuracies = []
                for fd in batch_list:
                    loss, acc, _, smr_training, step = sess.run([model.loss, model.accuracy, train_op, s_training, model.global_step], feed_dict=fd)
                    losses.append(loss)
                    accuracies.append(acc)
                    writer.add_summary(smr_training, step)
                    if step % 100 == 0:
                        #print('loss: {:.3f}, acc: {:.3f}'.format(loss, acc))
                        inference(sess, model, datasource, writer, test_s_acc, step)
                elapsed = time.time() - start
                print('epoch {}/{} finished, {} step, elapsed {} sec. average loss: {:.3f}, average accuracy: {:.3f}'.format(i+1, num_epoch, step, elapsed, np.average(losses), np.average(accuracies)))
                saver.save(sess, config.to_ckpt_path(), global_step=step)

In [11]:
def inference(sess, model, datasource, writer, s_acc, step):
    with tf.name_scope('inference'):
        test_list = datasource.feed_test_transformer(model)
        acc, smr_acc = sess.run([model.accuracy, s_acc], feed_dict=test_list)
        writer.add_summary(smr_acc, step)

In [12]:
for config in configs:
    start(config)

epoch 1/1000 finished, 327 step, elapsed 32.46992373466492 sec. average loss: 6.909, average accuracy: 0.069
epoch 2/1000 finished, 655 step, elapsed 30.660547971725464 sec. average loss: 6.678, average accuracy: 0.083
epoch 3/1000 finished, 983 step, elapsed 30.873374223709106 sec. average loss: 6.573, average accuracy: 0.087
epoch 4/1000 finished, 1311 step, elapsed 31.320802211761475 sec. average loss: 6.500, average accuracy: 0.088
epoch 5/1000 finished, 1639 step, elapsed 31.550824880599976 sec. average loss: 6.336, average accuracy: 0.102
epoch 6/1000 finished, 1967 step, elapsed 31.477858781814575 sec. average loss: 6.284, average accuracy: 0.102
epoch 7/1000 finished, 2295 step, elapsed 31.597407341003418 sec. average loss: 6.212, average accuracy: 0.110
epoch 8/1000 finished, 2623 step, elapsed 31.842812299728394 sec. average loss: 6.121, average accuracy: 0.114
epoch 9/1000 finished, 2951 step, elapsed 31.427751541137695 sec. average loss: 6.050, average accuracy: 0.119
epoch

epoch 79/1000 finished, 25911 step, elapsed 31.309062957763672 sec. average loss: 4.164, average accuracy: 0.253
epoch 80/1000 finished, 26239 step, elapsed 31.446746587753296 sec. average loss: 4.147, average accuracy: 0.256
epoch 81/1000 finished, 26567 step, elapsed 31.24166965484619 sec. average loss: 4.175, average accuracy: 0.253
epoch 82/1000 finished, 26895 step, elapsed 31.03608989715576 sec. average loss: 4.108, average accuracy: 0.257
epoch 83/1000 finished, 27223 step, elapsed 31.42402720451355 sec. average loss: 4.077, average accuracy: 0.262
epoch 84/1000 finished, 27551 step, elapsed 31.414864778518677 sec. average loss: 4.052, average accuracy: 0.268
epoch 85/1000 finished, 27879 step, elapsed 31.288513898849487 sec. average loss: 4.079, average accuracy: 0.261
epoch 86/1000 finished, 28207 step, elapsed 31.503530979156494 sec. average loss: 4.077, average accuracy: 0.260
epoch 87/1000 finished, 28535 step, elapsed 31.236156225204468 sec. average loss: 4.067, average ac

epoch 152/1000 finished, 49855 step, elapsed 31.123961448669434 sec. average loss: 3.469, average accuracy: 0.315
epoch 153/1000 finished, 50183 step, elapsed 31.199281930923462 sec. average loss: 3.396, average accuracy: 0.325
epoch 154/1000 finished, 50511 step, elapsed 31.419883966445923 sec. average loss: 3.372, average accuracy: 0.329
epoch 155/1000 finished, 50839 step, elapsed 31.484700202941895 sec. average loss: 3.413, average accuracy: 0.322
epoch 156/1000 finished, 51167 step, elapsed 31.126755714416504 sec. average loss: 3.399, average accuracy: 0.322
epoch 157/1000 finished, 51495 step, elapsed 31.20990538597107 sec. average loss: 3.375, average accuracy: 0.329
epoch 158/1000 finished, 51823 step, elapsed 31.42597222328186 sec. average loss: 3.396, average accuracy: 0.321
epoch 159/1000 finished, 52151 step, elapsed 31.366059064865112 sec. average loss: 3.308, average accuracy: 0.334
epoch 160/1000 finished, 52479 step, elapsed 30.900760650634766 sec. average loss: 3.357, 

epoch 225/1000 finished, 73799 step, elapsed 31.15962314605713 sec. average loss: 2.996, average accuracy: 0.371
epoch 226/1000 finished, 74127 step, elapsed 31.386252403259277 sec. average loss: 3.008, average accuracy: 0.367
epoch 227/1000 finished, 74455 step, elapsed 31.37447237968445 sec. average loss: 3.026, average accuracy: 0.365
epoch 228/1000 finished, 74783 step, elapsed 31.21468758583069 sec. average loss: 3.029, average accuracy: 0.364
epoch 229/1000 finished, 75111 step, elapsed 31.313480854034424 sec. average loss: 3.017, average accuracy: 0.363
epoch 230/1000 finished, 75439 step, elapsed 31.31938076019287 sec. average loss: 2.978, average accuracy: 0.371
epoch 231/1000 finished, 75767 step, elapsed 31.337427377700806 sec. average loss: 2.994, average accuracy: 0.370
epoch 232/1000 finished, 76095 step, elapsed 31.352169513702393 sec. average loss: 2.981, average accuracy: 0.367
epoch 233/1000 finished, 76423 step, elapsed 31.277212381362915 sec. average loss: 2.952, av

epoch 298/1000 finished, 97743 step, elapsed 31.209625720977783 sec. average loss: 2.724, average accuracy: 0.402
epoch 299/1000 finished, 98071 step, elapsed 31.281752824783325 sec. average loss: 2.731, average accuracy: 0.400
epoch 300/1000 finished, 98399 step, elapsed 31.170263528823853 sec. average loss: 2.734, average accuracy: 0.404
epoch 301/1000 finished, 98727 step, elapsed 31.35391592979431 sec. average loss: 2.730, average accuracy: 0.402
epoch 302/1000 finished, 99055 step, elapsed 31.044712781906128 sec. average loss: 2.754, average accuracy: 0.400
epoch 303/1000 finished, 99383 step, elapsed 31.29228401184082 sec. average loss: 2.705, average accuracy: 0.405
epoch 304/1000 finished, 99711 step, elapsed 31.231653213500977 sec. average loss: 2.721, average accuracy: 0.404
epoch 305/1000 finished, 100039 step, elapsed 31.22807812690735 sec. average loss: 2.713, average accuracy: 0.404
epoch 306/1000 finished, 100367 step, elapsed 31.385528326034546 sec. average loss: 2.697,

epoch 370/1000 finished, 121359 step, elapsed 31.393862009048462 sec. average loss: 2.539, average accuracy: 0.430
epoch 371/1000 finished, 121687 step, elapsed 31.27353549003601 sec. average loss: 2.550, average accuracy: 0.424
epoch 372/1000 finished, 122015 step, elapsed 31.37419557571411 sec. average loss: 2.573, average accuracy: 0.424
epoch 373/1000 finished, 122343 step, elapsed 31.449875593185425 sec. average loss: 2.534, average accuracy: 0.427
epoch 374/1000 finished, 122671 step, elapsed 31.21677827835083 sec. average loss: 2.525, average accuracy: 0.432
epoch 375/1000 finished, 122999 step, elapsed 31.218381881713867 sec. average loss: 2.501, average accuracy: 0.434
epoch 376/1000 finished, 123327 step, elapsed 31.24512028694153 sec. average loss: 2.496, average accuracy: 0.433
epoch 377/1000 finished, 123655 step, elapsed 31.6361346244812 sec. average loss: 2.540, average accuracy: 0.427
epoch 378/1000 finished, 123983 step, elapsed 31.291186571121216 sec. average loss: 2.

epoch 442/1000 finished, 144975 step, elapsed 31.271873235702515 sec. average loss: 2.387, average accuracy: 0.451
epoch 443/1000 finished, 145303 step, elapsed 31.366140604019165 sec. average loss: 2.386, average accuracy: 0.452
epoch 444/1000 finished, 145631 step, elapsed 31.18592643737793 sec. average loss: 2.360, average accuracy: 0.455
epoch 445/1000 finished, 145959 step, elapsed 31.438945293426514 sec. average loss: 2.382, average accuracy: 0.453
epoch 446/1000 finished, 146287 step, elapsed 31.33485245704651 sec. average loss: 2.407, average accuracy: 0.448
epoch 447/1000 finished, 146615 step, elapsed 31.492464065551758 sec. average loss: 2.413, average accuracy: 0.443
epoch 448/1000 finished, 146943 step, elapsed 31.255504846572876 sec. average loss: 2.352, average accuracy: 0.454
epoch 449/1000 finished, 147271 step, elapsed 31.350217819213867 sec. average loss: 2.311, average accuracy: 0.465
epoch 450/1000 finished, 147599 step, elapsed 31.17046046257019 sec. average loss:

epoch 514/1000 finished, 168591 step, elapsed 31.36019253730774 sec. average loss: 2.252, average accuracy: 0.476
epoch 515/1000 finished, 168919 step, elapsed 31.583788633346558 sec. average loss: 2.249, average accuracy: 0.471
epoch 516/1000 finished, 169247 step, elapsed 31.16168999671936 sec. average loss: 2.244, average accuracy: 0.475
epoch 517/1000 finished, 169575 step, elapsed 31.348962545394897 sec. average loss: 2.274, average accuracy: 0.467
epoch 518/1000 finished, 169903 step, elapsed 31.23127794265747 sec. average loss: 2.258, average accuracy: 0.472
epoch 519/1000 finished, 170231 step, elapsed 31.487110137939453 sec. average loss: 2.251, average accuracy: 0.471
epoch 520/1000 finished, 170559 step, elapsed 31.20898175239563 sec. average loss: 2.246, average accuracy: 0.472
epoch 521/1000 finished, 170887 step, elapsed 31.488622188568115 sec. average loss: 2.223, average accuracy: 0.479
epoch 522/1000 finished, 171215 step, elapsed 32.12162375450134 sec. average loss: 2

epoch 586/1000 finished, 192207 step, elapsed 31.395383834838867 sec. average loss: 2.167, average accuracy: 0.487
epoch 587/1000 finished, 192535 step, elapsed 31.340420722961426 sec. average loss: 2.146, average accuracy: 0.493
epoch 588/1000 finished, 192863 step, elapsed 31.460540056228638 sec. average loss: 2.183, average accuracy: 0.483
epoch 589/1000 finished, 193191 step, elapsed 31.064892292022705 sec. average loss: 2.183, average accuracy: 0.483
epoch 590/1000 finished, 193519 step, elapsed 31.315449714660645 sec. average loss: 2.123, average accuracy: 0.495
epoch 591/1000 finished, 193847 step, elapsed 31.228978633880615 sec. average loss: 2.137, average accuracy: 0.493
epoch 592/1000 finished, 194175 step, elapsed 31.47329616546631 sec. average loss: 2.149, average accuracy: 0.487
epoch 593/1000 finished, 194503 step, elapsed 31.30982995033264 sec. average loss: 2.135, average accuracy: 0.490
epoch 594/1000 finished, 194831 step, elapsed 31.626991510391235 sec. average loss

epoch 658/1000 finished, 215823 step, elapsed 31.328503370285034 sec. average loss: 2.073, average accuracy: 0.502
epoch 659/1000 finished, 216151 step, elapsed 31.24043083190918 sec. average loss: 2.049, average accuracy: 0.507
epoch 660/1000 finished, 216479 step, elapsed 31.34628939628601 sec. average loss: 2.071, average accuracy: 0.506
epoch 661/1000 finished, 216807 step, elapsed 31.52840781211853 sec. average loss: 2.078, average accuracy: 0.507
epoch 662/1000 finished, 217135 step, elapsed 31.158100843429565 sec. average loss: 2.042, average accuracy: 0.508
epoch 663/1000 finished, 217463 step, elapsed 31.417835474014282 sec. average loss: 2.070, average accuracy: 0.502
epoch 664/1000 finished, 217791 step, elapsed 31.202744722366333 sec. average loss: 2.097, average accuracy: 0.498
epoch 665/1000 finished, 218119 step, elapsed 31.44320821762085 sec. average loss: 2.059, average accuracy: 0.507
epoch 666/1000 finished, 218447 step, elapsed 31.211437225341797 sec. average loss: 

epoch 730/1000 finished, 239439 step, elapsed 31.260164737701416 sec. average loss: 2.005, average accuracy: 0.518
epoch 731/1000 finished, 239767 step, elapsed 31.394208192825317 sec. average loss: 2.002, average accuracy: 0.514
epoch 732/1000 finished, 240095 step, elapsed 31.329391956329346 sec. average loss: 1.961, average accuracy: 0.521
epoch 733/1000 finished, 240423 step, elapsed 31.549319744110107 sec. average loss: 1.992, average accuracy: 0.518
epoch 734/1000 finished, 240751 step, elapsed 31.328973293304443 sec. average loss: 1.968, average accuracy: 0.525
epoch 735/1000 finished, 241079 step, elapsed 31.38990831375122 sec. average loss: 1.981, average accuracy: 0.520
epoch 736/1000 finished, 241407 step, elapsed 31.244614601135254 sec. average loss: 1.967, average accuracy: 0.522
epoch 737/1000 finished, 241735 step, elapsed 31.31451678276062 sec. average loss: 1.999, average accuracy: 0.518
epoch 738/1000 finished, 242063 step, elapsed 31.289085865020752 sec. average loss

epoch 802/1000 finished, 263055 step, elapsed 31.174572467803955 sec. average loss: 1.920, average accuracy: 0.532
epoch 803/1000 finished, 263383 step, elapsed 31.324732780456543 sec. average loss: 1.934, average accuracy: 0.530
epoch 804/1000 finished, 263711 step, elapsed 31.486884593963623 sec. average loss: 1.924, average accuracy: 0.529
epoch 805/1000 finished, 264039 step, elapsed 31.359790563583374 sec. average loss: 1.904, average accuracy: 0.536
epoch 806/1000 finished, 264367 step, elapsed 31.419434785842896 sec. average loss: 1.899, average accuracy: 0.533
epoch 807/1000 finished, 264695 step, elapsed 31.288922786712646 sec. average loss: 1.929, average accuracy: 0.527
epoch 808/1000 finished, 265023 step, elapsed 31.52194881439209 sec. average loss: 1.936, average accuracy: 0.531
epoch 809/1000 finished, 265351 step, elapsed 31.254759073257446 sec. average loss: 1.891, average accuracy: 0.535
epoch 810/1000 finished, 265679 step, elapsed 31.486626625061035 sec. average los

epoch 874/1000 finished, 286671 step, elapsed 31.25973153114319 sec. average loss: 1.868, average accuracy: 0.538
epoch 875/1000 finished, 286999 step, elapsed 31.250394582748413 sec. average loss: 1.862, average accuracy: 0.543
epoch 876/1000 finished, 287327 step, elapsed 31.518730401992798 sec. average loss: 1.883, average accuracy: 0.538
epoch 877/1000 finished, 287655 step, elapsed 31.505981922149658 sec. average loss: 1.883, average accuracy: 0.536
epoch 878/1000 finished, 287983 step, elapsed 31.394367694854736 sec. average loss: 1.876, average accuracy: 0.537
epoch 879/1000 finished, 288311 step, elapsed 31.41842293739319 sec. average loss: 1.887, average accuracy: 0.539
epoch 880/1000 finished, 288639 step, elapsed 31.38597297668457 sec. average loss: 1.846, average accuracy: 0.545
epoch 881/1000 finished, 288967 step, elapsed 31.499877452850342 sec. average loss: 1.881, average accuracy: 0.534
epoch 882/1000 finished, 289295 step, elapsed 31.164665937423706 sec. average loss:

epoch 946/1000 finished, 310287 step, elapsed 31.545145273208618 sec. average loss: 1.805, average accuracy: 0.553
epoch 947/1000 finished, 310615 step, elapsed 31.56830596923828 sec. average loss: 1.803, average accuracy: 0.553
epoch 948/1000 finished, 310943 step, elapsed 31.3009135723114 sec. average loss: 1.835, average accuracy: 0.546
epoch 949/1000 finished, 311271 step, elapsed 31.525449514389038 sec. average loss: 1.806, average accuracy: 0.553
epoch 950/1000 finished, 311599 step, elapsed 31.23220992088318 sec. average loss: 1.849, average accuracy: 0.541
epoch 951/1000 finished, 311927 step, elapsed 31.393353939056396 sec. average loss: 1.822, average accuracy: 0.548
epoch 952/1000 finished, 312255 step, elapsed 31.45598340034485 sec. average loss: 1.802, average accuracy: 0.550
epoch 953/1000 finished, 312583 step, elapsed 31.537786960601807 sec. average loss: 1.854, average accuracy: 0.542
epoch 954/1000 finished, 312911 step, elapsed 31.74419140815735 sec. average loss: 1.