In [1]:
%load_ext autoreload
%autoreload 2
import tensorflow_datasets as tfds
import tensorflow as tf

import time
import numpy as np
import matplotlib.pyplot as plt
import os
import shutil
from functools import partial
from project import *

## Setup input pipeline

In [2]:
BUFFER_SIZE = 20000
BATCH_SIZE = 64
MAX_LENGTH = 200

In [3]:
train_examples, test_examples, dataset_info = get_datasets()
encoder = dataset_info.features['text'].encoder

In [4]:
train_dataset, test_dataset = transform_datasets(train_examples, test_examples, 
                                                 encoder=encoder,
                                                 batch_size=BATCH_SIZE, 
                                                 max_length=MAX_LENGTH, 
                                                 buffer_size=BUFFER_SIZE)

## Set hyperparameters

To keep this example small and relatively fast, the values for *num_layers, d_model, and dff* have been reduced. 

The values used in the base model of transformer were; *num_layers=6*, *d_model = 512*, *dff = 2048*. See the [paper](https://arxiv.org/abs/1706.03762) for all the other versions of the transformer.

Note: By changing the values below, you can get the model that achieved state of the art on many tasks.

In [5]:
num_layers = 4
d_model = 128
dff = 512
num_heads = 8

input_vocab_size = encoder.vocab_size
dropout_rate = 0.1

## Optimizer

In [6]:
learning_rate = CustomSchedule(d_model)

optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, 
                                     epsilon=1e-9)

## Loss and metrics

In [7]:
loss_function = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [8]:
train_loss = tf.keras.metrics.BinaryCrossentropy(name='train_loss', from_logits=True)
train_accuracy = tf.keras.metrics.BinaryAccuracy(name='train_accuracy')

## Training and checkpointing

In [9]:
transformer = TransformerEncoderClassifier(num_layers, d_model, num_heads, dff,
                          input_vocab_size,
                          pe_input=input_vocab_size, 
                          rate=dropout_rate)

In [10]:
ckpt_manager = create_checkpoint_manager(transformer, optimizer)

In [11]:
EPOCHS = 10

In [12]:
# The @tf.function trace-compiles train_step into a TF graph for faster
# execution. The function specializes to the precise shape of the argument
# tensors. To avoid re-tracing due to the variable sequence lengths or variable
# batch sizes (the last batch is smaller), use input_signature to specify
# more generic shapes.

# train_step_signature = [
#     tf.TensorSpec(shape=(None, None), dtype=tf.int64),
#     tf.TensorSpec(shape=(None,), dtype=tf.int64),
# ]

# @tf.function #(input_signature=train_step_signature)
def train_step(inp, tar):
    enc_padding_mask = create_padding_mask(inp)
    with tf.GradientTape() as tape:
        prediction_logits, _ = transformer(inp, True, enc_padding_mask)
        
        loss = loss_function(tar, prediction_logits)

    gradients = tape.gradient(loss, transformer.trainable_variables)    
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

    train_loss(tar, prediction_logits)
    
    predictions = tf.sigmoid(prediction_logits)
    train_accuracy(tar, predictions)

## Training Loop

In [13]:
for epoch in range(EPOCHS):
    start = time.time()
  
    train_loss.reset_states()
    train_accuracy.reset_states()

    # inp -> review, tar -> sentiment
    for (batch, (inp, tar)) in enumerate(train_dataset):
        train_step(inp, tar)

        if batch % 25 == 0:
            print (f'Epoch {epoch + 1} Batch {batch} Loss {train_loss.result():.4f} Accuracy {train_accuracy.result():.4f}')
      
    if (epoch + 1) % 5 == 0:
        ckpt_save_path = ckpt_manager.save()
        print (f'Saving checkpoint for epoch {epoch+1} at {ckpt_save_path}')
    
    print (f'Epoch {epoch + 1} Loss {train_loss.result():.4f} Accuracy {train_accuracy.result():.4f}')

    print (f'Time taken for 1 epoch: {time.time() - start} secs\n')

Epoch 1 Batch 0 Loss 13.4215 Accuracy 0.4844
Epoch 1 Batch 25 Loss 10.1580 Accuracy 0.4868
Epoch 1 Batch 50 Loss 8.2323 Accuracy 0.4979
Epoch 1 Batch 75 Loss 7.1530 Accuracy 0.5074
Epoch 1 Batch 100 Loss 6.5393 Accuracy 0.5138
Epoch 1 Batch 125 Loss 6.0016 Accuracy 0.5179
Epoch 1 Loss 5.8854 Accuracy 0.5184
Time taken for 1 epoch: 22.80799150466919 secs

Epoch 2 Batch 0 Loss 8.7036 Accuracy 0.5625
Epoch 2 Batch 25 Loss 5.4558 Accuracy 0.5210
Epoch 2 Batch 50 Loss 4.4372 Accuracy 0.5316
Epoch 2 Batch 75 Loss 4.2776 Accuracy 0.5380
Epoch 2 Batch 100 Loss 3.8788 Accuracy 0.5452
Epoch 2 Batch 125 Loss 3.6768 Accuracy 0.5506
Epoch 2 Loss 3.6767 Accuracy 0.5501
Time taken for 1 epoch: 21.312196254730225 secs

Epoch 3 Batch 0 Loss 2.9321 Accuracy 0.5938
Epoch 3 Batch 25 Loss 4.0851 Accuracy 0.5487
Epoch 3 Batch 50 Loss 4.5067 Accuracy 0.5392
Epoch 3 Batch 75 Loss 3.7021 Accuracy 0.5633
Epoch 3 Batch 100 Loss 3.3017 Accuracy 0.5738
Epoch 3 Batch 125 Loss 3.1820 Accuracy 0.5755
Epoch 3 Loss 3.1

## Testing loop

In [14]:
start = time.time()

test_loss = tf.keras.metrics.BinaryCrossentropy(name='test_loss', from_logits=True)
test_accuracy = tf.keras.metrics.BinaryAccuracy(name='test_accuracy')

test_loss.reset_states()
test_accuracy.reset_states()

# inp -> review, tar -> sentiment
for (batch, (inp, tar)) in enumerate(test_dataset):
    enc_padding_mask = create_padding_mask(inp)
    prediction_logits, _ = transformer(inp, True, enc_padding_mask)

    loss = loss_function(tar, prediction_logits)
    
    test_loss(tar, prediction_logits)
    
    predictions = tf.sigmoid(prediction_logits)
    test_accuracy(tar, predictions)
    
    if batch % 25 == 0:
        print (f'Batch {batch} Loss {test_loss.result():.4f} Accuracy {test_accuracy.result():.4f}')

print (f'Loss {test_loss.result():.4f} Accuracy {test_accuracy.result():.4f}')

print (f'Time taken for testing: {time.time() - start} secs\n')

Batch 0 Loss 0.6949 Accuracy 0.8906
Batch 25 Loss 1.4512 Accuracy 0.8395
Batch 50 Loss 1.3999 Accuracy 0.8474
Batch 75 Loss 1.4042 Accuracy 0.8458
Batch 100 Loss 1.4445 Accuracy 0.8447
Batch 125 Loss 1.4350 Accuracy 0.8470
Loss 1.4326 Accuracy 0.8472
Time taken for testing: 9.522772550582886 secs



### Examples of sentiment analysis

In [18]:
sent = partial(sentiment, encoder=encoder, transformer=transformer)

In [19]:
sent("This was not a very good movie. It dragged on for far too long and we couldn't wait to leave the theater.")

Input: This was not a very good movie. It dragged on for far too long and we couldn't wait to leave the theater.
Predicted sentiment: pos


In [23]:
neg_review = 'We loved the movie and definitely recommend it!'

In [24]:
sent(neg_review)

Input: We loved the movie and definitely recommend it!
Predicted sentiment: pos
