In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow.keras.callbacks as cb
import time
import numpy as np
import matplotlib.pyplot as plt
import os
import shutil
from functools import partial
from project import *

In [2]:
BUFFER_SIZE = 20000
BATCH_SIZE = 64
MAX_LENGTH = 200

In [3]:
train_examples, test_examples, dataset_info = get_datasets()
encoder = dataset_info.features['text'].encoder
train_dataset, test_dataset = transform_datasets(train_examples, test_examples, 
                                                 encoder=encoder, 
                                                 batch_size=BATCH_SIZE,  
                                                 max_length=MAX_LENGTH, 
                                                 buffer_size=BUFFER_SIZE)

In [4]:
num_layers = 4
d_model = 128
dff = 512
num_heads = 8

input_vocab_size = encoder.vocab_size
dropout_rate = 0.1

In [5]:
loss_function = tf.keras.losses.BinaryCrossentropy(from_logits=True)
transformer = TransformerEncoderClassifier(num_layers, d_model, num_heads, dff, 
                                           input_vocab_size, 
                                           pe_input=input_vocab_size, 
                                           rate=dropout_rate)
transformer.compile(optimizer='adam', loss=loss_function, metrics=['accuracy'])

checkpoint = tf.train.latest_checkpoint("./checkpoints")
if checkpoint is not None:
    print("Loading previously trained model.")
    transformer.load_weights(checkpoint)

Loading previously trained model.


In [6]:
EPOCHS = 100

In [14]:
tb = cb.TensorBoard(histogram_freq=1, embeddings_freq=1)
csv = cb.CSVLogger('training.csv')
early = cb.EarlyStopping(monitor='val_accuracy', patience=5)
save = cb.ModelCheckpoint(filepath="checkpoints/train",
         monitor='val_accuracy',
         save_best_only=True,
         save_weights_only=True)

model_history = transformer.fit(train_dataset, 
                                validation_data=test_dataset,  
                                validation_freq=1,
                                shuffle=False,
                                callbacks=[early, save, tb, csv],
                                epochs=EPOCHS)

Epoch 1/100
Epoch 2/100


Epoch 3/100


Epoch 4/100


Epoch 5/100


Epoch 6/100


Epoch 7/100


Epoch 8/100


Epoch 9/100


Epoch 10/100


Epoch 11/100


Epoch 12/100


Epoch 13/100


Epoch 14/100


Epoch 15/100


Epoch 16/100


Epoch 17/100




In [None]:
model_history.history

## Example of Evaluating a model

In [12]:
eval_loss, eval_acc = transformer.evaluate(test_dataset)



In [13]:
print(f'Loss: {eval_loss}, Acc: {eval_acc}')

Loss: 2.6485911242283176, Acc: 0.8304972648620605


### Examples of sentiment analysis

In [9]:
sent = partial(sentiment, encoder=encoder, transformer=transformer)

In [10]:
sent("This was not a very good movie.")

Input: This was not a very good movie.
Predicted sentiment: pos


In [11]:
sent('We loved the movie and definitely recommend it!')

Input: We loved the movie and definitely recommend it!
Predicted sentiment: pos
