<a href="https://colab.research.google.com/github/data-better/ASL/blob/master/12%EA%B0%95_transformer_text_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Transformer implementation
  - https://github.com/suyash/transformer

In [None]:
import os

from absl import app, flags, logging
import tensorflow as tf
from tensorflow.keras import Model  # pylint: disable=import-error
from tensorflow.keras.callbacks import TensorBoard  # pylint: disable=import-error
from tensorflow.keras.layers import Dense, GlobalAveragePooling1D, Input  # pylint: disable=import-error
import tensorflow_datasets as tfds

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Add, Dense, Dropout, Embedding, Layer, LayerNormalization, Multiply, Permute, Reshape  # pylint: disable=import-error


class Transformer:
    """Full encoder-decoder model assembled from `Encoder`, `Decoder` and a Dense head.

    Args:
        num_layers: Number of layers used in both the encoder and the decoder.
        d_model: Model width forwarded to the encoder and the decoder.
        num_heads: Number of attention heads forwarded to both stacks.
        d_ff: Feed-forward width forwarded to both stacks.
        input_vocab_size: Vocabulary size of the encoder embedding.
        target_vocab_size: Vocabulary size of the decoder embedding and the
            number of units of the final Dense layer.
        dropout_rate: Dropout rate forwarded to both stacks.
        scope: Prefix for the name of every layer created here.
    """

    def __init__(self,
                 num_layers,
                 d_model,
                 num_heads,
                 d_ff,
                 input_vocab_size,
                 target_vocab_size,
                 dropout_rate,
                 scope="transformer"):
        # Hyper-parameters that the encoder and the decoder have in common.
        shared = dict(num_layers=num_layers,
                      d_model=d_model,
                      num_heads=num_heads,
                      d_ff=d_ff,
                      dropout_rate=dropout_rate)

        self.encoder = Encoder(vocab_size=input_vocab_size,
                               scope=f"{scope}/encoder",
                               **shared)
        self.decoder = Decoder(vocab_size=target_vocab_size,
                               scope=f"{scope}/decoder",
                               **shared)

        # No activation: the head emits one raw score per target token.
        self.final_layer = Dense(target_vocab_size,
                                 activation=None,
                                 name=f"{scope}/dense")

        self.padding_mask = PaddingMask(name=f"{scope}/padding_mask")
        self.lookahead_mask = PaddingAndLookaheadMask(
            name=f"{scope}/lookahead_mask")

    def __call__(self, inputs, target):
        """Run the encoder on `inputs` and the decoder on `target`.

        Returns:
            A 4-tuple: the Dense head output, the encoder attention weights,
            the decoder self-attention weights and the encoder-decoder
            attention weights (the last three as returned by the sub-stacks).
        """
        source_mask = self.padding_mask(inputs)
        target_mask = self.lookahead_mask(target)

        encoded, enc_attention = self.encoder(inputs, source_mask)
        decoded, dec_attention, enc_dec_attention = self.decoder(
            target, encoded, target_mask, source_mask)

        logits = self.final_layer(decoded)
        return logits, enc_attention, dec_attention, enc_dec_attention


class Decoder:
    """Decoder stack: token embedding, positional encoding and decoder layers.

    Args:
        num_layers: Number of `DecoderLayer` blocks to stack.
        d_model: Embedding width; also forwarded to every `DecoderLayer`.
        num_heads: Number of attention heads forwarded to every layer.
        d_ff: Feed-forward width forwarded to every layer.
        vocab_size: `input_dim` of the token embedding.
        dropout_rate: Rate of the embedding dropout and of every layer.
        scope: Prefix for the name of every Keras layer created here.
    """

    def __init__(self,
                 num_layers,
                 d_model,
                 num_heads,
                 d_ff,
                 vocab_size,
                 dropout_rate,
                 scope="decoder"):
        self.d_model = d_model
        self.num_layers = num_layers
        self.scope = scope

        self.embedding = Embedding(input_dim=vocab_size,
                                   output_dim=d_model,
                                   name="%s/embedding" % scope)
        self.pos_encoding = PositionalEncoding(d_model,
                                               name="%s/positional_encoding" %
                                               scope)

        self.dec_layers = [
            DecoderLayer(d_model=d_model,
                         num_heads=num_heads,
                         d_ff=d_ff,
                         dropout_rate=dropout_rate,
                         scope="%s/decoder_layer_%d" % (scope, i))
            for i in range(num_layers)
        ]

        self.dropout = Dropout(dropout_rate, name="%s/dropout" % self.scope)

    def __call__(self, x, enc_output, lookahead_mask, padding_mask):
        """Apply the decoder to target token ids.

        Args:
            x: Input to the embedding layer (token ids).
            enc_output: Encoder output attended to by every decoder layer.
            lookahead_mask: Mask for the decoder self-attention.
            padding_mask: Mask for the encoder-decoder attention.

        Returns:
            A 3-tuple of the final hidden states and two dicts keyed
            "layer_<i>" holding the self-attention and the encoder-decoder
            attention weights of each layer.
        """
        x = self.embedding(x)
        # Scale by sqrt(d_model), as in "Attention Is All You Need" (sec. 3.4);
        # multiplying by d_model itself would drown out the positional signal.
        x = MultiplyConstant(self.d_model**0.5,
                             name="%s/multiply" % self.scope)(x)
        x = Add(name="%s/add" % self.scope)([x, self.pos_encoding(x)])
        # The dropout layer was constructed but never applied; the paper
        # applies it to the sum of embeddings and positional encodings.
        x = self.dropout(x)

        dec_attention_weights = {}
        enc_dec_attention_weights = {}

        for i, dec_layer in enumerate(self.dec_layers):
            x, dec_attention, enc_dec_attention = dec_layer(
                x, enc_output, lookahead_mask, padding_mask)

            dec_attention_weights["layer_%d" % i] = dec_attention
            enc_dec_attention_weights["layer_%d" % i] = enc_dec_attention

        return x, dec_attention_weights, enc_dec_attention_weights


class Encoder:
    """Encoder stack: token embedding, positional encoding and encoder layers.

    Args:
        num_layers: Number of `EncoderLayer` blocks to stack.
        d_model: Embedding width; also forwarded to every `EncoderLayer`.
        num_heads: Number of attention heads forwarded to every layer.
        d_ff: Feed-forward width forwarded to every layer.
        vocab_size: `input_dim` of the token embedding.
        dropout_rate: Rate of the embedding dropout and of every layer.
        scope: Prefix for the name of every Keras layer created here.
    """

    def __init__(self,
                 num_layers,
                 d_model,
                 num_heads,
                 d_ff,
                 vocab_size,
                 dropout_rate,
                 scope="encoder"):
        self.d_model = d_model
        self.num_layers = num_layers
        self.scope = scope

        self.embedding = Embedding(input_dim=vocab_size,
                                   output_dim=d_model,
                                   name="%s/embedding" % scope)
        self.pos_encoding = PositionalEncoding(d_model,
                                               name="%s/positional_encoding" %
                                               scope)

        self.enc_layers = [
            EncoderLayer(d_model=d_model,
                         num_heads=num_heads,
                         d_ff=d_ff,
                         dropout_rate=dropout_rate,
                         scope="%s/encoder_layer_%d" % (scope, i))
            for i in range(num_layers)
        ]

        self.dropout = Dropout(dropout_rate, name="%s/dropout" % self.scope)

    def __call__(self, x, padding_mask):
        """Apply the encoder to source token ids.

        Args:
            x: Input to the embedding layer (token ids).
            padding_mask: Mask passed to the self-attention of every layer.

        Returns:
            A 2-tuple of the final hidden states and a dict keyed "layer_<i>"
            holding the self-attention weights of each layer.
        """
        x = self.embedding(x)
        # Scale by sqrt(d_model), as in "Attention Is All You Need" (sec. 3.4);
        # multiplying by d_model itself would drown out the positional signal.
        x = MultiplyConstant(self.d_model**0.5,
                             name="%s/multiply" % self.scope)(x)
        x = Add(name="%s/add" % self.scope)([x, self.pos_encoding(x)])
        # The dropout layer was constructed but never applied; the paper
        # applies it to the sum of embeddings and positional encodings.
        x = self.dropout(x)

        enc_attention_weights = {}

        for i, enc_layer in enumerate(self.enc_layers):
            x, enc_attention = enc_layer(x, padding_mask)
            enc_attention_weights["layer_%d" % i] = enc_attention

        return x, enc_attention_weights


class DecoderLayer:
    """One decoder block: self-attention, encoder-decoder attention, feed-forward.

    Each of the three sub-layers is followed by dropout, a residual `Add`
    and a `LayerNormalization`.

    Args:
        d_model: Model width forwarded to the attention and feed-forward parts.
        num_heads: Number of heads of both attention sub-layers.
        d_ff: Hidden width of the feed-forward network.
        dropout_rate: Rate shared by the three dropout layers.
        scope: Prefix for the name of every Keras layer created here.
    """

    def __init__(self,
                 d_model,
                 num_heads,
                 d_ff,
                 dropout_rate,
                 scope="decoder_layer"):
        self.scope = scope

        self.mha1 = MultiHeadAttention(
            d_model, num_heads, scope=f"{scope}/multi_head_attention_1")
        self.mha2 = MultiHeadAttention(
            d_model, num_heads, scope=f"{scope}/multi_head_attention_2")
        self.ffn = PointwiseFeedForwardNetwork(
            d_model, d_ff, scope=f"{scope}/pointwise_feed_forward_network")

        self.layernorm1 = LayerNormalization(epsilon=1e-6,
                                             name=f"{scope}/layer_norm_1")
        self.layernorm2 = LayerNormalization(epsilon=1e-6,
                                             name=f"{scope}/layer_norm_2")
        self.layernorm3 = LayerNormalization(epsilon=1e-6,
                                             name=f"{scope}/layer_norm_3")

        self.dropout1 = Dropout(dropout_rate, name=f"{scope}/dropout_1")
        self.dropout2 = Dropout(dropout_rate, name=f"{scope}/dropout_2")
        self.dropout3 = Dropout(dropout_rate, name=f"{scope}/dropout_3")

    def __call__(self, x, enc_output, lookahead_mask, padding_mask):
        """Run the block and return `(output, self_attention, enc_dec_attention)`."""
        # Sub-layer 1: self-attention over the decoder input.
        self_attended, dec_dec_attention = self.mha1(x, x, x, lookahead_mask)
        self_attended = self.dropout1(self_attended)
        after_self = self.layernorm1(
            Add(name=f"{self.scope}/add_1")([x, self_attended]))

        # Sub-layer 2: queries from the decoder, keys/values from the encoder.
        cross_attended, enc_dec_attention = self.mha2(after_self, enc_output,
                                                      enc_output, padding_mask)
        cross_attended = self.dropout2(cross_attended)
        after_cross = self.layernorm2(
            Add(name=f"{self.scope}/add_2")([after_self, cross_attended]))

        # Sub-layer 3: position-wise feed-forward network.
        fed_forward = self.dropout3(self.ffn(after_cross))
        result = self.layernorm3(
            Add(name=f"{self.scope}/add_3")([after_cross, fed_forward]))

        return result, dec_dec_attention, enc_dec_attention


class EncoderLayer:
    """One encoder block: self-attention followed by a feed-forward network.

    Both sub-layers are followed by dropout, a residual `Add` and a
    `LayerNormalization`.

    Args:
        d_model: Model width forwarded to the attention and feed-forward parts.
        num_heads: Number of heads of the self-attention sub-layer.
        d_ff: Hidden width of the feed-forward network.
        dropout_rate: Rate shared by the two dropout layers.
        scope: Prefix for the name of every Keras layer created here.
    """

    def __init__(self,
                 d_model,
                 num_heads,
                 d_ff,
                 dropout_rate,
                 scope="encoder_layer"):
        self.scope = scope

        self.mha1 = MultiHeadAttention(
            d_model, num_heads, scope=f"{scope}/multi_head_attention_1")
        self.ffn = PointwiseFeedForwardNetwork(
            d_model, d_ff, scope=f"{scope}/pointwise_feed_forward_network")

        self.layernorm1 = LayerNormalization(epsilon=1e-6,
                                             name=f"{scope}/layer_norm_1")
        self.layernorm2 = LayerNormalization(epsilon=1e-6,
                                             name=f"{scope}/layer_norm_2")

        self.dropout1 = Dropout(dropout_rate, name=f"{scope}/dropout_1")
        self.dropout2 = Dropout(dropout_rate, name=f"{scope}/dropout_2")

    def __call__(self, x, padding_mask):
        """Run the block and return `(output, self_attention_weights)`."""
        # Sub-layer 1: self-attention over the encoder input.
        attended, enc_enc_attention = self.mha1(x, x, x, padding_mask)
        attended = self.dropout1(attended)
        after_attention = self.layernorm1(
            Add(name=f"{self.scope}/add_1")([x, attended]))

        # Sub-layer 2: position-wise feed-forward network.
        fed_forward = self.dropout2(self.ffn(after_attention))
        result = self.layernorm2(
            Add(name=f"{self.scope}/add_2")([after_attention, fed_forward]))

        return result, enc_enc_attention


class PointwiseFeedForwardNetwork:
    """Two stacked Dense layers: `d_ff` units with ReLU, then `d_model` linear units.

    Args:
        d_model: Number of units of the second (output) Dense layer.
        d_ff: Number of units of the first (hidden) Dense layer.
        scope: Prefix for the names of the two Dense layers.
    """

    def __init__(self, d_model, d_ff, scope="pointwise_feed_forward_network"):
        self.dense_1 = Dense(d_ff, activation="relu", name=f"{scope}/dense_1")
        self.dense_2 = Dense(d_model, activation=None, name=f"{scope}/dense_2")

    def __call__(self, x):
        """Apply the hidden layer, then the output layer, to `x`."""
        hidden = self.dense_1(x)
        return self.dense_2(hidden)


class MultiHeadAttention:
    """Multi-head attention assembled from Dense, Reshape and Permute layers.

    Args:
        d_model: Width of the q/k/v projections and of the output projection;
            must be divisible by `num_heads`.
        num_heads: Number of heads the projected width is split into.
        scope: Prefix for the name of every Keras layer created here.
    """

    def __init__(self, d_model, num_heads, scope="multi_head_attention"):
        assert d_model % num_heads == 0

        # Target shape (excluding batch) that splits the last axis into heads.
        per_head_shape = (-1, num_heads, d_model // num_heads)
        # Swaps the two axes that follow the batch axis.
        swap_first_two = (2, 1, 3)

        self.wq = Dense(d_model, name=f"{scope}/dense_q")
        self.wk = Dense(d_model, name=f"{scope}/dense_k")
        self.wv = Dense(d_model, name=f"{scope}/dense_v")

        self.reshapeq = Reshape(per_head_shape, name=f"{scope}/reshape_q")
        self.reshapek = Reshape(per_head_shape, name=f"{scope}/reshape_k")
        self.reshapev = Reshape(per_head_shape, name=f"{scope}/reshape_v")

        self.transposeq = Permute(swap_first_two, name=f"{scope}/transpose_q")
        self.transposek = Permute(swap_first_two, name=f"{scope}/transpose_k")
        self.transposev = Permute(swap_first_two, name=f"{scope}/transpose_v")

        self.reshape_output = Reshape((-1, d_model),
                                      name=f"{scope}/reshape_output")

        self.transpose_output = Permute(swap_first_two,
                                        name=f"{scope}/transpose_output")

        self.dense = Dense(d_model, name=f"{scope}/dense")

        self.attention = Attention(name=f"{scope}/attention")

    def __call__(self, q, k, v, mask):
        """Project q/k/v, attend per head and merge the heads again.

        Returns:
            A 2-tuple of the output of the final Dense projection and the
            attention weights produced by the `Attention` layer.
        """
        # Each stage is applied to q, k and v (in that order) before the next
        # stage starts.
        stages = (
            (self.wq, self.wk, self.wv),
            (self.reshapeq, self.reshapek, self.reshapev),
            (self.transposeq, self.transposek, self.transposev),
        )
        tensors = [q, k, v]
        for stage in stages:
            tensors = [layer(tensor) for layer, tensor in zip(stage, tensors)]
        q, k, v = tensors

        attended, attention_weights = self.attention([q, k, v, mask])

        # Undo the head split: swap the axes back, then flatten the heads.
        merged = self.reshape_output(self.transpose_output(attended))
        return self.dense(merged), attention_weights


class Attention(Layer):
    """Scaled dot-product attention with an additive mask."""

    def call(self, inputs):
        """Compute attention for `inputs = [q, k, v, mask]`.

        Returns:
            A 2-tuple `(output, attention_weights)`.
        """
        query, key, value, mask = inputs

        # Scale the q·kᵀ scores by the square root of the size of k's last axis.
        depth = tf.cast(tf.shape(key)[-1], tf.float32)
        logits = tf.matmul(query, key, transpose_b=True) / tf.math.sqrt(depth)

        # Non-zero mask entries push the logit towards -1e9 so that the
        # softmax assigns (almost) no weight to those positions.
        logits = logits + mask * -1e9

        attention_weights = tf.nn.softmax(logits, axis=-1)
        return tf.matmul(attention_weights, value), attention_weights


class PositionalEncoding(Layer):
    """Layer that builds a sinusoidal positional encoding for its input.

    Args:
        d_model: Number of encoding dimensions to generate.
    """

    def __init__(self, d_model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.d_model = d_model

    def call(self, inputs):
        """Return the encoding for as many positions as axis 1 of `inputs` has.

        Only the shape of `inputs` is used. The result has a leading axis of
        size 1, with all sine terms first and all cosine terms after them.
        """
        seq_len = tf.shape(inputs)[1]

        positions = tf.range(seq_len)[:, tf.newaxis]
        dim_index = tf.range(self.d_model)[tf.newaxis, :]

        # Dimensions 2i and 2i+1 share the exponent 2i / d_model.
        exponent = tf.cast((2 * (dim_index // 2)) / self.d_model, tf.float32)
        angle_rates = 1 / tf.pow(10000.0, exponent)
        angle_rads = tf.cast(positions, tf.float32) * angle_rates

        even_columns = angle_rads[:, 0::2]
        odd_columns = angle_rads[:, 1::2]
        encoding = tf.concat([tf.sin(even_columns), tf.cos(odd_columns)],
                             axis=-1)

        return tf.cast(encoding[tf.newaxis, ...], tf.float32)

    def get_config(self):
        """Return the base layer config extended with `d_model`."""
        config = dict(super().get_config())
        config["d_model"] = self.d_model
        return config


class MultiplyConstant(Layer):
    """Layer that multiplies its input by a fixed constant.

    Args:
        c: The constant factor.
    """

    def __init__(self, c, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.c = c

    def call(self, inputs):
        """Return `inputs` scaled by the stored constant."""
        scaled = inputs * self.c
        return scaled

    def get_config(self):
        """Return the base layer config extended with `c`."""
        config = dict(super().get_config())
        config["c"] = self.c
        return config


class PaddingMask(Layer):
    """Layer that marks the zero entries of its input with 1.0 (others 0.0)."""

    def call(self, inputs):
        """Return the float mask with two singleton axes inserted after axis 0."""
        is_padding = tf.math.equal(inputs, 0)
        mask = tf.cast(is_padding, tf.float32)
        return mask[:, tf.newaxis, tf.newaxis, :]


class PaddingAndLookaheadMask(Layer):
    """Layer combining a look-ahead (strict upper-triangular) mask with a padding mask."""

    def call(self, inputs):
        """Return the element-wise maximum of the look-ahead and padding masks."""
        length = tf.shape(inputs)[1]

        # band_part(..., -1, 0) keeps the lower triangle including the
        # diagonal, so the complement is 1 strictly above the diagonal.
        lower_triangle = tf.linalg.band_part(tf.ones((length, length)), -1, 0)
        look_ahead = 1 - lower_triangle

        # 1.0 wherever the input is 0, with two singleton axes after axis 0.
        padding = tf.cast(tf.math.equal(inputs, 0), tf.float32)
        padding = padding[:, tf.newaxis, tf.newaxis, :]

        return tf.maximum(look_ahead, padding)


In [None]:
def main(_):
    """Train an encoder-based binary classifier on `imdb_reviews/subwords8k`.

    Loads the dataset with TFDS, drops examples whose first dimension is not
    shorter than `--max_len`, batches with padding, builds
    `Encoder -> GlobalAveragePooling1D -> Dense(1, sigmoid)` and trains it
    either with `Model.fit` or with the custom `train` loop, depending on
    `--use_custom_training_loop`. The model is finally saved under
    `<job-dir>/model`.

    Args:
        _: Unused positional argument supplied by `absl.app.run`.
    """
    cfg = flags.FLAGS
    job_dir = cfg["job-dir"].value

    data, info = tfds.load("imdb_reviews/subwords8k",
                           with_info=True,
                           as_supervised=True,
                           data_dir=cfg.tfds_data_dir)

    train_data, test_data = data[tfds.Split.TRAIN], data[tfds.Split.TEST]

    # `Dataset.output_shapes` is deprecated; TensorFlow's deprecation notice
    # names `tf.compat.v1.data.get_output_shapes(dataset)` as the replacement.
    train_data = train_data.filter(
        lambda x, y: tf.shape(x)[0] < flags.FLAGS.max_len)
    train_data = train_data \
        .padded_batch(cfg.batch_size,
                      tf.compat.v1.data.get_output_shapes(train_data)) \
        .shuffle(cfg.shuffle_buffer_size) \
        .repeat()

    test_data = test_data.filter(
        lambda x, y: tf.shape(x)[0] < flags.FLAGS.max_len)
    test_data = test_data \
        .padded_batch(cfg.batch_size,
                      tf.compat.v1.data.get_output_shapes(test_data))

    vocab_size = info.features["text"].encoder.vocab_size

    inp = Input((None, ), dtype="int32", name="inp")
    mask = PaddingMask()(inp)
    net, enc_enc_attention_weights = Encoder(
        num_layers=cfg.num_layers,
        d_model=cfg.d_model,
        num_heads=cfg.num_heads,
        d_ff=cfg.d_ff,
        vocab_size=vocab_size,
        dropout_rate=cfg.dropout_rate)(inp, mask)
    # Average over the sequence axis, then predict a single probability.
    net = GlobalAveragePooling1D()(net)
    net = Dense(1, activation="sigmoid")(net)

    learning_rate = CustomSchedule(cfg.d_model)
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    loss_object = tf.keras.losses.BinaryCrossentropy()

    if cfg.use_custom_training_loop:
        # The custom loop expects the model to also return attention weights.
        model = Model(inputs=inp, outputs=[net, enc_enc_attention_weights])
        model.summary()

        train(train_data=train_data,
              validation_data=test_data,
              model=model,
              loss_object=loss_object,
              optimizer=optimizer,
              max_steps=cfg.epochs * cfg.steps_per_epoch,
              save_summary_steps=cfg.steps_per_epoch,
              validation_steps=cfg.validation_steps,
              job_dir=job_dir)
    else:
        model = Model(inputs=inp, outputs=net)
        model.summary()

        model.compile(optimizer=optimizer,
                      loss=loss_object,
                      metrics=[tf.keras.metrics.BinaryAccuracy()])

        model.fit(train_data,
                  epochs=cfg.epochs,
                  steps_per_epoch=cfg.steps_per_epoch,
                  validation_data=test_data,
                  validation_steps=cfg.validation_steps,
                  callbacks=[
                      TensorBoard(log_dir=job_dir),
                  ])

    model.save(os.path.join(job_dir, "model"))

In [None]:
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Warm-up learning-rate schedule.

    The rate is `rsqrt(d_model) * min(rsqrt(step), step * warmup_steps**-1.5)`.

    Args:
        d_model: Model width used to scale the learning rate.
        warmup_steps: Constant used in the `step * warmup_steps**-1.5` term.
    """

    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()

        # Keep the value exactly as given so get_config() needs no tensor
        # evaluation and returns a plain Python value.
        self._d_model_config = d_model
        self.d_model = tf.cast(d_model, tf.float32)

        self.warmup_steps = warmup_steps

    def __call__(self, step):
        """Return the learning rate for the given optimizer step."""
        # The optimizer may pass an integer step; rsqrt only accepts floats.
        step = tf.cast(step, tf.float32)

        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps**-1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

    def get_config(self):
        """Return the constructor arguments needed to rebuild the schedule."""
        return {
            "d_model": self._d_model_config,
            "warmup_steps": self.warmup_steps,
        }


In [None]:
@tf.function
def train_step(inp, tar, model, loss_object, optimizer, loss_mean, acc):
    """Run one optimisation step and update the running metrics.

    Args:
        inp: Batch of model inputs.
        tar: Batch of targets; an axis is appended before computing the loss.
        model: Callable returning a 2-tuple whose first item is the prediction.
        loss_object: Loss callable taking `y_true` and `y_pred`.
        optimizer: Optimizer whose `apply_gradients` is called.
        loss_mean: Metric updated with the batch loss.
        acc: Metric updated with the targets and predictions.
    """
    labels = tf.expand_dims(tar, 1)

    with tf.GradientTape() as tape:
        predictions, _ = model(inp, training=True)
        batch_loss = loss_object(y_true=labels, y_pred=predictions)

    variables = model.trainable_variables
    grads = tape.gradient(batch_loss, variables)
    optimizer.apply_gradients(zip(grads, variables))

    loss_mean(batch_loss)
    acc(y_true=tar, y_pred=predictions)


In [None]:
def train(train_data, validation_data, model, loss_object, optimizer,
          max_steps, save_summary_steps, validation_steps, job_dir):
    """Custom training loop with periodic logging and validation.

    Every `save_summary_steps` steps (including step 0) the running training
    loss/accuracy are logged and written as summaries, the metrics are reset,
    and up to `validation_steps` batches of `validation_data` are evaluated.

    Args:
        train_data: Iterable of `(inputs, outputs)` training batches.
        validation_data: Iterable of `(x, y_true)` validation batches.
        model: Model returning a 2-tuple whose first item is the prediction.
        loss_object: Loss callable taking `y_true` and `y_pred`.
        optimizer: Optimizer forwarded to `train_step`.
        max_steps: Step index after which training stops.
        save_summary_steps: Interval, in steps, between summaries/validation.
        validation_steps: Maximum number of validation batches per evaluation.
        job_dir: Directory the summary file writer writes to.
    """
    loss_mean = tf.keras.metrics.Mean()
    acc = tf.keras.metrics.BinaryAccuracy()

    with tf.summary.create_file_writer(job_dir).as_default():  # pylint: disable=not-context-manager
        for step, (inputs, outputs) in enumerate(train_data):
            train_step(inputs,
                       outputs,
                       model=model,
                       loss_object=loss_object,
                       optimizer=optimizer,
                       loss_mean=loss_mean,
                       acc=acc)

            if step % save_summary_steps == 0:
                logging.info("Step: %d: Loss: %f, Accuracy: %f", step,
                             loss_mean.result(), acc.result())
                tf.summary.scalar("Train Loss", loss_mean.result(), step=step)
                tf.summary.scalar("Train Accuracy", acc.result(), step=step)

                # The same metric objects are reused for validation below.
                loss_mean.reset_states()
                acc.reset_states()

                for current_validation_step, (
                        x, y_true) in enumerate(validation_data):
                    # Check the budget BEFORE evaluating the batch; checking
                    # afterwards evaluated validation_steps + 1 batches.
                    if current_validation_step >= validation_steps:
                        break

                    y_pred, _ = model(x, training=False)
                    loss = loss_object(y_true=tf.expand_dims(y_true, 1),
                                       y_pred=y_pred)
                    loss_mean(loss)
                    acc(y_true, y_pred)

                logging.info(
                    "Step: %d, validation_loss: %f, validation accuracy: %f",
                    step, loss_mean.result(), acc.result())
                tf.summary.scalar("Validation Loss",
                                  loss_mean.result(),
                                  step=step)
                tf.summary.scalar("Validation Accuracy",
                                  acc.result(),
                                  step=step)
                loss_mean.reset_states()
                acc.reset_states()

            # NOTE(review): the step with index max_steps is still executed,
            # which is what triggers the final summary when max_steps is a
            # multiple of save_summary_steps - left unchanged on purpose.
            if step >= max_steps:
                break

In [None]:
# Rebind `flags` to the flags module reachable through `absl.app`, and keep a
# module-level shortcut to the parsed flag values.
flags = app.flags
FLAGS = flags.FLAGS

# Presumably absorbs the `-f <file>` argument a notebook kernel is started
# with, so that flag parsing does not reject it - TODO confirm.
flags.DEFINE_string('f', '', 'kernel')

# Model hyper-parameters.
flags.DEFINE_integer("d_model", 128, "d_model")
flags.DEFINE_integer("d_ff", 512, "d_ff")
flags.DEFINE_integer("num_layers", 2, "num_layers")
flags.DEFINE_integer("num_heads", 8, "num_heads")
flags.DEFINE_float("dropout_rate", 0.1, "dropout_rate")

# Training and input-pipeline settings.
flags.DEFINE_integer("epochs", 50, "epochs")
flags.DEFINE_integer("steps_per_epoch", 250, "steps_per_epoch")
flags.DEFINE_integer("max_len", 500, "max_len")
flags.DEFINE_integer("batch_size", 64, "batch_size")
flags.DEFINE_integer("shuffle_buffer_size", 500, "shuffle_buffer_size")
flags.DEFINE_integer("validation_steps", 50, "validation_steps")
flags.DEFINE_boolean("use_custom_training_loop", False,
                     "use_custom_training_loop")

# Locations of the dataset cache and of the training outputs.
flags.DEFINE_string("tfds_data_dir", "~/tensorflow_datasets",
                    "tfds_data_dir")
flags.DEFINE_string("job-dir", "runs/text_classification", "job-dir")

# Parse the flags and hand control to main().
app.run(main)

I0811 01:52:24.725407 140363197581184 dataset_builder.py:184] Overwrite dataset info from restored data version.
I0811 01:52:24.856547 140363197581184 dataset_builder.py:253] Reusing dataset imdb_reviews (/root/tensorflow_datasets/imdb_reviews/subwords8k/0.1.0)
I0811 01:52:24.857880 140363197581184 dataset_builder.py:399] Constructing tf.data.Dataset for split None, from /root/tensorflow_datasets/imdb_reviews/subwords8k/0.1.0
W0811 01:52:25.019160 140363197581184 deprecation.py:323] From <ipython-input-3-1476b29ac834>:11: DatasetV1.output_shapes (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.data.get_output_shapes(dataset)`.
W0811 01:52:25.128939 140363197581184 deprecation.py:506] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/initializers.py:119: calling RandomUniform.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a fu

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
inp (InputLayer)                [(None, None)]       0                                            
__________________________________________________________________________________________________
encoder/embedding (Embedding)   (None, None, 128)    1047680     inp[0][0]                        
__________________________________________________________________________________________________
encoder/multiply (MultiplyConst (None, None, 128)    0           encoder/embedding[0][0]          
__________________________________________________________________________________________________
encoder/positional_encoding (Po (1, None, 128)       0           encoder/multiply[0][0]           
______________________________________________________________________________________________

KeyboardInterrupt: ignored

In [None]:
!ls ../