In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [2]:
import os
from collections import Counter
from typing import NamedTuple, List, Dict, Tuple
import random
from datetime import datetime
import time
import math

In [14]:
from utils.config import Config
from data_loader.ptb_datasource import PTBDataSource

# config, data

In [15]:
# Hyper-parameter grid: one Config per (layers, units, learning rate) combination,
# enumerated with layers outermost and learning rate innermost.
units = [512, 1024]
layers = [4]
lrs = [0.001]
configs = [
    Config(num_layers=n_layers, num_units=n_units, learning_rate=rate)
    for n_layers in layers
    for n_units in units
    for rate in lrs
]

# model

In [9]:
class RNN:
    """Bidirectional multi-layer GRU encoder followed by a dense layer over the vocabulary.

    Constructing an instance adds the placeholders, the network, the loss and
    the accuracy op to the current default TensorFlow graph.
    """

    def __init__(self, config: Config, vocab_size):
        """Build the whole graph.

        Args:
            config: hyper-parameters. This class reads ``max_length``,
                ``embedding_size``, ``num_layers``, ``num_units``,
                ``dropout_in_rate`` and ``dropout_out_rate`` from it.
            vocab_size: number of rows of the embedding table and width of
                the output layer.
        """
        self.config = config
        self.vocab_size = vocab_size
        self._create_placeholder()
        self._create_model()
        self.loss = self._create_loss()
        self.accuracy = self._create_acc()

    def _create_placeholder(self):
        """Create the input tensors that callers feed."""
        # Gates dropout (see _keep_prob). It defaults to True so that a feed
        # dict which omits it keeps dropout switched on.
        # NOTE(review): evaluation feeds are expected to set this to False --
        # confirm against the data source.
        self.is_training = tf.placeholder_with_default(True, shape=(), name='is_training')
        self.inputs = tf.placeholder(shape=[None, self.config.max_length], dtype=tf.int32, name='inputs')
        self.inputs_length = tf.placeholder(shape=[None], dtype=tf.int32, name='inputs_length')
        self.target_ids = tf.placeholder(shape=[None], dtype=tf.int32, name='target_ids')

    def _create_model(self):
        """Wire embedding -> bidirectional GRU -> dense logits, and the argmax prediction."""
        self.global_step = tf.train.get_or_create_global_step()
        embedded_inputs = self._embedding(self.inputs)
        _, encoder_state = self._encode(embedded_inputs)
        self.outputs_logits = tf.layers.dense(encoder_state, self.vocab_size, name='outputs_layer')
        self.predicted_id = tf.to_int32(tf.argmax(self.outputs_logits, axis=-1))

    def _create_loss(self):
        """Label-smoothed softmax cross-entropy averaged over targets whose id is not 0."""
        is_target = tf.to_float(tf.not_equal(self.target_ids, 0))
        target_ids_one_hot = tf.one_hot(self.target_ids, self.vocab_size)
        target_ids_smoothed = self._label_smoothing(target_ids_one_hot)
        cross_ent = tf.nn.softmax_cross_entropy_with_logits_v2(logits=self.outputs_logits, labels=target_ids_smoothed)
        # The mask sums to a whole number, so clamping to 1 only changes the
        # all-masked batch: it yields a loss of 0 instead of 0/0 = NaN.
        num_targets = tf.maximum(tf.reduce_sum(is_target), 1.0)
        return tf.reduce_sum(cross_ent * is_target) / num_targets

    def _create_acc(self):
        """Fraction of examples whose predicted id equals the target id (no masking applied)."""
        return tf.reduce_mean(tf.to_float(tf.equal(self.target_ids, self.predicted_id)))

    def _embedding(self, inputs):
        """Look up ``inputs`` in a ``[vocab_size, embedding_size]`` float32 table."""
        lookup_table = tf.get_variable('lookup_table', shape=[self.vocab_size, self.config.embedding_size], dtype=tf.float32)
        embedded_inputs = tf.nn.embedding_lookup(lookup_table, inputs)
        return embedded_inputs

    def _encode(self, embedded_inputs):
        """Run the bidirectional encoder with the sizes and dropout rates from the config.

        Returns:
            ``(outputs, final_state)`` as produced by ``_bidirectional_cell``.
        """
        outputs, final_state = self._bidirectional_cell(
            embedded_inputs,
            self.config.num_layers,
            self.config.num_units,
            self.config.dropout_in_rate,
            self.config.dropout_out_rate
        )
        return outputs, final_state

    def _bidirectional_cell(self, inputs, num_layers, num_units, dropout_in_rate, dropout_out_rate):
        """Run forward and backward GRU stacks over ``inputs``.

        Returns:
            outputs: forward and backward outputs concatenated on the last axis.
            final_state: forward and backward final states summed element-wise,
                then the per-layer states concatenated on the last axis.
        """
        cell_fw = self._gru(num_layers, num_units, dropout_in_rate, dropout_out_rate, name='cell_fw')
        cell_bw = self._gru(num_layers, num_units, dropout_in_rate, dropout_out_rate, name='cell_bw')
        (fw_outputs, bw_outputs), (fw_state, bw_state) = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=cell_fw,
            cell_bw=cell_bw,
            inputs=inputs,
            sequence_length=self.inputs_length,
            dtype=tf.float32,
            scope='bidirectional_cells')
        outputs = tf.concat([fw_outputs, bw_outputs], axis=-1)
        # Stacking the two state tuples and reducing axis 0 adds the directions.
        final_state = tf.reduce_sum([fw_state, bw_state], axis=0)
        # Axis 0 now indexes layers; lay the layers side by side.
        final_state = tf.concat(tf.unstack(final_state, axis=0), axis=-1)
        return outputs, final_state

    def _keep_prob(self, dropout_rate: float):
        """Return a scalar keep probability: ``1 - dropout_rate`` while training, else 1.0."""
        return tf.cond(
            self.is_training,
            lambda: tf.constant(1.0 - dropout_rate, dtype=tf.float32),
            lambda: tf.constant(1.0, dtype=tf.float32))

    def _gru(self, num_layers: int, num_units: int, dropout_in_rate: float, dropout_out_rate: float, name: str):
        """Build a stack of ``num_layers`` ReLU GRU cells.

        Input dropout wraps the first layer and output dropout wraps the last
        layer. Both are active only while ``is_training`` is true.
        """
        input_keep_prob = self._keep_prob(dropout_in_rate)
        output_keep_prob = self._keep_prob(dropout_out_rate)
        cells = []
        for l in range(num_layers):
            cell = tf.nn.rnn_cell.GRUCell(num_units, tf.nn.relu, kernel_initializer=tf.contrib.layers.xavier_initializer(), name=name)
            if l == 0:
                cell = tf.nn.rnn_cell.DropoutWrapper(cell, input_keep_prob=input_keep_prob)
            if l == num_layers-1:
                cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=output_keep_prob)
            cells.append(cell)
        return tf.nn.rnn_cell.MultiRNNCell(cells)

    def _label_smoothing(self, inputs, epsilon: float=0.1):
        """Scale ``inputs`` by ``1 - epsilon`` and add ``epsilon / last_dim`` to every entry."""
        feature_dim = inputs.get_shape().as_list()[-1]
        return (1-epsilon) * inputs + (epsilon / feature_dim)

In [10]:
# Number of epochs start() runs for each configuration.
num_epoch = 200

In [11]:
def start():
    """Train one RNN for the configuration held in the module-level ``config``.

    Builds a fresh graph, trains for up to ``num_epoch`` epochs with Adam and
    global-norm gradient clipping, writes loss/accuracy summaries every step,
    runs ``inference`` every 100 batches, and saves a checkpoint after each
    epoch. Training stops early when an epoch's average loss is NaN.

    Reads the globals ``config`` and ``num_epoch``; takes no arguments and
    returns nothing.
    """
    with tf.Graph().as_default():
        # A timestamped sub-directory keeps the event files of each run apart.
        logdir = datetime.now().strftime("%Y%m%d-%H%M%S") + "/"

        datasource = PTBDataSource(config)

        rnn = RNN(config, datasource.vocab_size)
        optimizer = tf.train.AdamOptimizer(config.learning_rate)
        train_vars = tf.trainable_variables()
        gradients = tf.gradients(rnn.loss, train_vars)
        clipped_gradients, _ = tf.clip_by_global_norm(gradients, config.grad_clip)
        train_op = optimizer.apply_gradients(zip(clipped_gradients, train_vars), global_step=rnn.global_step)
        with tf.name_scope('training'):
            s_loss = tf.summary.scalar('loss', rnn.loss)
            s_acc = tf.summary.scalar('accuracy', rnn.accuracy)
        with tf.name_scope('test'):
            test_s_acc = tf.summary.scalar('accuracy', rnn.accuracy)

        with tf.Session() as sess:
            saver = tf.train.Saver()
            writer = tf.summary.FileWriter(config.to_log_dir() + '/' + logdir, sess.graph)
            try:
                sess.run(tf.global_variables_initializer())
                for i in range(num_epoch):
                    epoch_started = time.time()
                    datasource.shuffle()
                    batch_list = datasource.feed_dict_list(rnn)
                    losses = []
                    accuracies = []
                    for (j, fd) in enumerate(batch_list):
                        loss, acc, _, smr_loss, smr_acc, step = sess.run([rnn.loss, rnn.accuracy, train_op, s_loss, s_acc, rnn.global_step], feed_dict=fd)
                        losses.append(loss)
                        accuracies.append(acc)
                        writer.add_summary(smr_loss, step)
                        writer.add_summary(smr_acc, step)
                        if j % 100 == 0:
                            inference(sess, rnn, datasource, writer, test_s_acc, step)
                    elapsed = time.time() - epoch_started
                    average_loss = np.average(losses)
                    print('epoch {}/{} finished, {} step, elapsed {} sec. average loss: {:.3f}, average accuracy: {:.3f}'.format(i+1, num_epoch, step, elapsed, average_loss, np.average(accuracies)))
                    # Stop training this configuration once the loss is NaN;
                    # the diverged weights are not checkpointed.
                    if math.isnan(average_loss):
                        print('loss is nan')
                        break
                    saver.save(sess, config.to_ckpt_path(), global_step=step)
            finally:
                # Flush pending events and release the file even when training
                # fails or stops early; start() runs once per configuration.
                writer.close()

In [12]:
def inference(sess, model, datasource, writer, s_acc, step):
    """Run the accuracy op on the test feed and record its summary at ``step``.

    Args:
        sess: session used to run the fetches.
        model: object exposing an ``accuracy`` op; also passed to the data source.
        datasource: provides ``feed_test_list(model)``, whose result is used as the feed dict.
        writer: summary writer receiving the evaluated summary.
        s_acc: summary op evaluated together with ``model.accuracy``.
        step: step value attached to the written summary.
    """
    with tf.name_scope('inference'):
        test_feed = datasource.feed_test_list(model)
        fetches = [model.accuracy, s_acc]
        _, accuracy_summary = sess.run(fetches, feed_dict=test_feed)
        writer.add_summary(accuracy_summary, step)

In [None]:
# Train every configuration in turn. start() takes no arguments: it reads the
# loop variable `config` as a module-level global.
for config in configs:
    start()

Instructions for updating:
seq_dim is deprecated, use seq_axis instead
Instructions for updating:
batch_dim is deprecated, use batch_axis instead
epoch 1/200 finished, 327 step, elapsed 63.05414271354675 sec. average loss: 7.291, average accuracy: 0.066
epoch 2/200 finished, 655 step, elapsed 62.784444093704224 sec. average loss: 6.763, average accuracy: 0.122
epoch 3/200 finished, 983 step, elapsed 64.81405925750732 sec. average loss: 6.470, average accuracy: 0.148
epoch 4/200 finished, 1311 step, elapsed 64.68730759620667 sec. average loss: 6.336, average accuracy: 0.161
epoch 5/200 finished, 1639 step, elapsed 66.68140697479248 sec. average loss: 6.196, average accuracy: 0.175
epoch 6/200 finished, 1967 step, elapsed 65.44995546340942 sec. average loss: 6.112, average accuracy: 0.184
epoch 7/200 finished, 2295 step, elapsed 65.29462242126465 sec. average loss: 6.018, average accuracy: 0.190
epoch 8/200 finished, 2623 step, elapsed 64.58694672584534 sec. average loss: 5.995, average 

epoch 69/200 finished, 22631 step, elapsed 65.65715432167053 sec. average loss: 4.435, average accuracy: 0.345
epoch 70/200 finished, 22959 step, elapsed 65.47271466255188 sec. average loss: 4.390, average accuracy: 0.357
epoch 71/200 finished, 23287 step, elapsed 65.68262529373169 sec. average loss: 4.416, average accuracy: 0.351
epoch 72/200 finished, 23615 step, elapsed 65.07423663139343 sec. average loss: 4.370, average accuracy: 0.355
epoch 73/200 finished, 23943 step, elapsed 65.46194982528687 sec. average loss: 4.410, average accuracy: 0.351
epoch 76/200 finished, 24927 step, elapsed 64.86922359466553 sec. average loss: 4.303, average accuracy: 0.366
epoch 77/200 finished, 25255 step, elapsed 66.1532244682312 sec. average loss: 4.320, average accuracy: 0.365
epoch 78/200 finished, 25583 step, elapsed 65.49885034561157 sec. average loss: 4.285, average accuracy: 0.373
epoch 79/200 finished, 25911 step, elapsed 65.18831014633179 sec. average loss: 4.344, average accuracy: 0.361
ep

epoch 145/200 finished, 47559 step, elapsed 65.75812411308289 sec. average loss: 3.641, average accuracy: 0.477
epoch 146/200 finished, 47887 step, elapsed 66.68588399887085 sec. average loss: 3.628, average accuracy: 0.477
epoch 147/200 finished, 48215 step, elapsed 65.21528911590576 sec. average loss: 3.684, average accuracy: 0.472
epoch 148/200 finished, 48543 step, elapsed 65.27357459068298 sec. average loss: 3.628, average accuracy: 0.480
epoch 149/200 finished, 48871 step, elapsed 65.90938758850098 sec. average loss: 3.640, average accuracy: 0.478
epoch 150/200 finished, 49199 step, elapsed 64.8878698348999 sec. average loss: 3.642, average accuracy: 0.476
epoch 151/200 finished, 49527 step, elapsed 65.18003988265991 sec. average loss: 3.561, average accuracy: 0.492
epoch 152/200 finished, 49855 step, elapsed 66.3232204914093 sec. average loss: 3.627, average accuracy: 0.478
epoch 153/200 finished, 50183 step, elapsed 65.18017244338989 sec. average loss: 3.572, average accuracy: 

epoch 19/200 finished, 6231 step, elapsed 177.60658025741577 sec. average loss: 5.352, average accuracy: 0.252
epoch 20/200 finished, 6559 step, elapsed 176.48895263671875 sec. average loss: 5.350, average accuracy: 0.252
epoch 21/200 finished, 6887 step, elapsed 176.93118166923523 sec. average loss: 5.318, average accuracy: 0.255
epoch 22/200 finished, 7215 step, elapsed 175.54602932929993 sec. average loss: 5.291, average accuracy: 0.255
epoch 23/200 finished, 7543 step, elapsed 177.23572278022766 sec. average loss: 5.270, average accuracy: 0.255
epoch 24/200 finished, 7871 step, elapsed 173.88826298713684 sec. average loss: 5.220, average accuracy: 0.263
epoch 25/200 finished, 8199 step, elapsed 175.53213095664978 sec. average loss: 5.161, average accuracy: 0.265
epoch 26/200 finished, 8527 step, elapsed 174.46259355545044 sec. average loss: 5.144, average accuracy: 0.266
epoch 27/200 finished, 8855 step, elapsed 174.3110921382904 sec. average loss: 5.142, average accuracy: 0.266
ep