In [1]:
from utils.config import Config
from dataloader.docomo_datasource import DocomoDataSource
from models.transformer import Transformer

  return f(*args, **kwds)
  return f(*args, **kwds)


In [2]:
import tensorflow as tf
import numpy as np
import matplotlib as plt

In [3]:
#config = Config()
#ds = DocomoDataSource(config)
#ds.vocab_size

In [4]:
units = [256, 512]
layers = [2, 4, 6]
lrs = [0.001]
configs = []
for l in layers:
    for u in units:
        for lr in lrs:
            configs.append(Config(num_layers=l, num_units=u, learning_rate=lr, log_dir='./logs/transformer/'))

In [5]:
# Single configuration used by the training run below (log_dir './logs/ut/').
config = Config(
    num_layers=4,
    num_units=512,
    learning_rate=0.001,
    log_dir='./logs/ut/',
)

In [6]:
from datetime import datetime
from typing import Optional
import tensorflow as tf
from utils import transformer
from utils import utils


class UnivTransformer:
    '''Universal Transformer encoder-decoder built as a TensorFlow 1.x graph.

    Both the encoder and the decoder are a stack of
      * ``config.num_pre_layers`` independent blocks (scopes ``pre_blocks_<i>``),
      * ``config.num_layers`` recurrent applications of ONE weight-shared block
        (scope ``share_blocks`` in the encoder, ``shared_blocks`` in the decoder),
        each preceded by adding a position signal and a step signal,
      * ``config.num_post_layers`` independent blocks (scopes ``post_blocks_<i>``).

    After construction the instance exposes the placeholders ``is_training``,
    ``encoder_inputs``, ``decoder_inputs`` and ``decoder_targets``, the output
    ``decoder_logits``, the scalars ``loss`` and ``accuracy``, plus
    ``global_step`` and a ``saver`` restricted to variables under ``scope``.

    Config attributes read: max_length, vocab_size, num_units, num_heads,
    dropout_in_rate, num_pre_layers, num_layers, num_post_layers, checkpoint_path.
    '''

    def __init__(self, config, scope='universal_transformer', reuse: Optional[bool] = None,
                 mask_padding: bool = False) -> None:
        '''
        :param config: hyper-parameter object (see the class docstring for the attributes read).
        :param scope: name of the outermost variable scope; also selects the
            variables handled by ``saver``.
        :param reuse: forwarded to ``tf.variable_scope`` for the outermost scope.
        :param mask_padding: if True, ``loss`` and ``accuracy`` are averaged only over
            target positions whose id is not 0 (presumably the padding id -- confirm
            against the data source). If False (default, original behaviour) every
            position is averaged.
        '''
        self.config = config
        self.scope = scope
        self.mask_padding = mask_padding

        self.build_model(scope, reuse)
        self.init_global_step()
        self.init_saver()

    def init_saver(self):
        '''Create a Saver over the global variables whose name matches ``self.scope``.'''
        # NOTE(review): the global step is created by init_global_step() outside any
        # variable scope, so it is not matched by ``self.scope`` and is neither saved
        # nor restored by this saver -- confirm this is acceptable when resuming.
        self.saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.scope))

    def init_global_step(self):
        '''Fetch (or create) the graph's global step tensor.'''
        self.global_step = tf.train.get_or_create_global_step()

    def save(self, sess):
        '''Save a checkpoint to ``config.checkpoint_path``, suffixed with the global step.'''
        self.saver.save(sess, self.config.checkpoint_path, self.global_step)

    def load(self, sess: tf.Session, path: str) -> None:
        '''
        Load the model.
        :param sess: session
        :param path: path of the model; one of
            - a file path on S3 (s3://.../model.ckpt)
            - a local directory path: the directory is treated as a checkpoint
              directory and the latest checkpoint is loaded
            - a local file path: .../model.ckpt
        If ``utils.get_model_path`` yields nothing, no variables are restored.
        '''
        model_path = utils.get_model_path(path)
        if model_path:
            print("Loading model {} ...\n".format(model_path))
            self.saver.restore(sess, model_path)
            print("Model loaded")

    def build_model(self, scope, reuse: Optional[bool]):
        '''Create the placeholders, the encoder/decoder graph, and the loss/accuracy.'''
        with tf.variable_scope(scope, reuse=reuse):
            self.is_training = tf.placeholder(dtype=tf.bool, name='is_training')
            self.encoder_inputs = self._token_placeholder('encoder_inputs')
            self.decoder_targets = self._token_placeholder('decoder_targets')
            self.decoder_inputs = self._token_placeholder('decoder_inputs')

            encoded = self._encoder()
            self.decoder_logits = self._decoder(encoded, self.decoder_inputs)

        self._build_metrics()

    def _token_placeholder(self, name):
        '''int32 placeholder of shape [None, config.max_length].'''
        return tf.placeholder(dtype=tf.int32, shape=[None, self.config.max_length], name=name)

    def _build_metrics(self):
        '''Define ``self.loss`` (label-smoothed cross entropy) and ``self.accuracy``.'''
        decoder_targets_one_hot = tf.one_hot(self.decoder_targets, self.config.vocab_size)
        decoder_targets_smoothed = utils.label_smoothing(decoder_targets_one_hot)
        cross_ents = tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=self.decoder_logits,
            labels=decoder_targets_smoothed
        )
        predicted_ids = tf.to_int32(tf.argmax(self.decoder_logits, axis=2))
        correct = tf.to_float(tf.equal(predicted_ids, self.decoder_targets))

        if self.mask_padding:
            is_target = tf.to_float(tf.not_equal(self.decoder_targets, 0))
            # Clamp to >= 1 so an all-id-0 batch yields 0 instead of NaN.
            num_targets = tf.maximum(tf.reduce_sum(is_target), 1.0)
            self.loss = tf.reduce_sum(cross_ents * is_target) / num_targets
            self.accuracy = tf.reduce_sum(correct * is_target) / num_targets
        else:
            # Averages over every position, including id-0 ones.
            self.loss = tf.reduce_mean(cross_ents)
            self.accuracy = tf.reduce_mean(correct)

    def _add_step_signals(self, hidden, step):
        '''Add the position signal, then the signal for recurrence step ``step``.'''
        hidden += transformer.positional_encoding2(
            hidden,
            is_zero_pad=True
        )
        hidden += transformer.step_encoding(
            hidden,
            step,
            is_zero_pad=True
        )
        return hidden

    def _encoder_block(self, hidden):
        '''One encoder block: non-causal self-attention followed by a feed-forward layer.

        Creates variables in the caller's current variable scope.
        '''
        hidden = transformer.multihead_attention(
            queries=hidden,
            keys=hidden,
            is_training=self.is_training,
            dropout_rate=self.config.dropout_in_rate,
            num_units=self.config.num_units,
            num_heads=self.config.num_heads,
            is_causality=False
        )
        return transformer.feedforward(
            hidden,
            num_units=[4*self.config.num_units, self.config.num_units],
            scope='hier_feedforward'
        )

    def _decoder_block(self, hidden, memory):
        '''One decoder block: causal self-attention, attention over ``memory``, feed-forward.

        Creates variables in the caller's current variable scope.
        '''
        hidden = transformer.multihead_attention(
            queries=hidden,
            keys=hidden,
            is_training=self.is_training,
            dropout_rate=self.config.dropout_in_rate,
            num_units=self.config.num_units,
            num_heads=self.config.num_heads,
            is_causality=True,
            scope='self_attention'
        )
        hidden = transformer.multihead_attention(
            queries=hidden,
            keys=memory,
            is_training=self.is_training,
            dropout_rate=self.config.dropout_in_rate,
            num_units=self.config.num_units,
            num_heads=self.config.num_heads,
            is_causality=False,
            scope='vanilla_attention'
        )
        return transformer.feedforward(
            hidden,
            num_units=[4*self.config.num_units, self.config.num_units]
        )

    def _encoder(self):
        '''Embed ``self.encoder_inputs`` and run the encoder stack; returns its output.'''
        with tf.variable_scope('encoder'):
            hidden = transformer.embedding(
                self.encoder_inputs,
                self.config.vocab_size,
                self.config.num_units,
                is_scale=True,
                scope='enc_embed'
            )
            hidden += transformer.positional_encoding(
                self.encoder_inputs,
                num_units=self.config.num_units,
                is_zero_pad=True,
            )
            hidden = tf.layers.dropout(
                hidden,
                rate=self.config.dropout_in_rate,
                training=self.is_training
            )

            for i in range(self.config.num_pre_layers):
                with tf.variable_scope('pre_blocks_{}'.format(i)):
                    hidden = self._encoder_block(hidden)

            for i in range(self.config.num_layers):
                # Same block applied num_layers times: AUTO_REUSE makes steps i > 0
                # reuse the weights created at step 0 instead of failing with
                # "Variable ... already exists" (or silently not sharing).
                with tf.variable_scope('share_blocks', reuse=tf.AUTO_REUSE):
                    hidden = self._add_step_signals(hidden, i)
                    hidden = self._encoder_block(hidden)

            for i in range(self.config.num_post_layers):
                with tf.variable_scope('post_blocks_{}'.format(i)):
                    hidden = self._encoder_block(hidden)

            return hidden

    def _decoder(self, hier_encoder_inputs_embedded, decoder_inputs):
        '''Decode ``decoder_inputs`` while attending to the encoder output.

        :param hier_encoder_inputs_embedded: encoder output (result of ``_encoder``).
        :param decoder_inputs: int32 token-id tensor (the ``decoder_inputs`` placeholder).
        :return: logits whose last dimension is ``config.vocab_size``.
        '''
        with tf.variable_scope('decoder'):
            hidden = transformer.embedding(
                decoder_inputs,
                vocab_size=self.config.vocab_size,
                num_units=self.config.num_units,
                is_scale=True,
                scope='dec_embed'
            )
            hidden += transformer.positional_encoding(
                decoder_inputs,
                num_units=self.config.num_units,
                is_zero_pad=True,
            )
            # NOTE(review): unlike _encoder, no dropout is applied to the decoder
            # embeddings -- confirm this asymmetry is intended.

            for i in range(self.config.num_pre_layers):
                with tf.variable_scope('pre_blocks_{}'.format(i)):
                    hidden = self._decoder_block(hidden, hier_encoder_inputs_embedded)

            for i in range(self.config.num_layers):
                # Weight-shared recurrent block; see the matching comment in _encoder.
                with tf.variable_scope('shared_blocks', reuse=tf.AUTO_REUSE):
                    hidden = self._add_step_signals(hidden, i)
                    hidden = self._decoder_block(hidden, hier_encoder_inputs_embedded)

            for i in range(self.config.num_post_layers):
                with tf.variable_scope('post_blocks_{}'.format(i)):
                    hidden = self._decoder_block(hidden, hier_encoder_inputs_embedded)

            decoder_logits = tf.layers.dense(hidden, self.config.vocab_size)

            return decoder_logits


In [7]:
num_epochs: int = 20  # passes over the training data for the run() call below

In [8]:
def run(configs, gpu_index, num_epochs=100, model_cls=Transformer, scope='transformer'):
    '''Train one model per config, each in its own fresh graph, logging summaries.

    :param configs: iterable of Config objects; each is trained independently.
    :param gpu_index: GPU index used for device placement ('/gpu:<index>'); soft
        placement is enabled so ops without a GPU kernel fall back elsewhere.
    :param num_epochs: number of passes over the (re-shuffled) training data.
    :param model_cls: model class, instantiated as ``model_cls(config, scope)``.
        Defaults to ``Transformer``; e.g. pass ``UnivTransformer`` with
        ``scope='universal_transformer'`` to train the Universal Transformer.
    :param scope: variable-scope name passed to ``model_cls``.
    '''
    for config in configs:
        with tf.Graph().as_default():
            with tf.device('/gpu:{}'.format(gpu_index)):
                ds = DocomoDataSource(config)
                model = model_cls(config, scope)

                global_step = tf.train.get_or_create_global_step()
                optimizer = tf.train.AdamOptimizer(config.learning_rate)
                train_op = optimizer.minimize(model.loss, global_step=global_step)

                with tf.name_scope('summary'):
                    tf.summary.scalar('loss', model.loss)
                    tf.summary.scalar('acc', model.accuracy)
                    merged_summary = tf.summary.merge_all()

                tf_config = tf.ConfigProto(
                    allow_soft_placement=True,
                    gpu_options=tf.GPUOptions(allow_growth=True)
                )
                with tf.Session(config=tf_config) as sess:
                    writer = tf.summary.FileWriter(model.config.to_log_dir(), sess.graph)
                    try:
                        sess.run(tf.global_variables_initializer())
                        for epoch in range(num_epochs):
                            ds.shuffle()
                            batches = ds.feed_dict(model, model.config.batch_size, is_transformer=True)
                            for fd in batches:
                                fd[model.is_training] = True
                                # loss/accuracy are evaluated through merged_summary.
                                _, step, smr = sess.run(
                                    [train_op, global_step, merged_summary], feed_dict=fd
                                )
                                writer.add_summary(smr, step)
                            print('epoch {}/{} finished.'.format(epoch+1, num_epochs))
                    finally:
                        # Flush buffered events so the last summaries are not lost.
                        writer.close()

In [9]:
# NOTE(review): as called here, run() builds `Transformer(config, 'transformer')`
# from models.transformer -- not the UnivTransformer defined above -- although this
# config logs to './logs/ut/'. Confirm which model this experiment should train.
run(configs=[config], gpu_index=0, num_epochs=num_epochs)

epoch 1/20 finished.
epoch 2/20 finished.
epoch 3/20 finished.
epoch 4/20 finished.
epoch 5/20 finished.
epoch 6/20 finished.
epoch 7/20 finished.
epoch 8/20 finished.
epoch 9/20 finished.
epoch 10/20 finished.
epoch 11/20 finished.
epoch 12/20 finished.
epoch 13/20 finished.
epoch 14/20 finished.
epoch 15/20 finished.
epoch 16/20 finished.
epoch 17/20 finished.
epoch 18/20 finished.
epoch 19/20 finished.
epoch 20/20 finished.
