In [1]:
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import tensorflow as tf

from core.data import make_batches, positional_encoding, tokenizers
from core.transformer import Transformer, MiniTransformer
from core.scheduler import CustomSchedule
from core.callbacks import get_callbacks

In [2]:
# Hyperparameters — model architecture (all passed to MiniTransformer below)
num_layers = 4
num_heads = 8
d_model = 128       # also feeds the CustomSchedule learning-rate schedule
dff = 512           # presumably the inner width of the feed-forward sublayers
dropout_rate = 0.1

# Hyperparameters — run configuration
EPOCHS = 20         # NOTE(review): not used by the evaluation cells in this notebook
BATCHSIZE = 64
exp_path = "./experiments/train"  # directory that contains train_model.h5

In [3]:
# Load the TFDS `ted_hrlr_translate/pt_to_en` dataset as (input, target) pairs,
# then batch its test split with the project's make_batches helper.
examples, metadata = tfds.load(
    'ted_hrlr_translate/pt_to_en',
    as_supervised=True,
    with_info=True,
)
test_examples = examples['test']
test_batches = make_batches(test_examples, batchsize=BATCHSIZE)

In [4]:
# Optimizer: Adam whose learning rate follows the project's CustomSchedule(d_model).
learning_rate = CustomSchedule(d_model)
optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,
                                     epsilon=1e-9)

# Loss function on raw logits. reduction='none' keeps the unreduced losses;
# presumably MiniTransformer masks padding before averaging — confirm in core.transformer.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True,
                                                            reduction='none')

transformer = MiniTransformer(num_layers=num_layers,
                              d_model=d_model,
                              num_heads=num_heads,
                              dff=dff,
                              input_vocab_size=tokenizers.pt.get_vocab_size(),
                              target_vocab_size=tokenizers.en.get_vocab_size(),
                              # NOTE(review): presumably the maximum positional-encoding
                              # length; consider moving this to the config cell.
                              pe_input=1000,
                              # Use the config-cell value instead of a duplicated literal,
                              # so changing dropout_rate actually takes effect.
                              rate=dropout_rate)
transformer.compile(optimizer=optimizer,
                    loss_function=loss_object,
                    metrics=['accuracy', loss_object])

# NOTE(review): `model(BATCHSIZE)` is a project method; presumably it builds the
# model's variables so the .h5 weights can be loaded afterwards — confirm.
transformer.model(BATCHSIZE)

# Baseline evaluation on one test batch. This runs BEFORE load_weights (next cell),
# so the reported numbers reflect the freshly initialised, untrained model.
transformer.evaluate(test_batches.take(1))



[0.0, 9.199488639831543]

In [5]:
# Restore the trained weights saved under the experiment directory.
# (Despite its name, `expdir` holds the path of the .h5 weights file.)
expdir = f"{exp_path}/train_model.h5"
transformer.load_weights(expdir)

In [6]:
# Run inference on a single test batch. `predict` comes from the project's
# MiniTransformer; its result is indexed as [0] (predictions) and [1]
# (references) in the following cell — confirm the contract in core.transformer.
prediction = transformer.predict(test_batches.take(1))

In [17]:
# The project's predict() appears to return (predictions, references) in that order.
y_pred, y_true = prediction[0], prediction[1]

In [19]:
# Print each reference translation, then the model's prediction, then a separator.
# NOTE(review): references still print as b"..." after .decode('utf-8'), which
# suggests they were stringified bytes (e.g. str(b'...')) upstream in predict().
for predicted, reference in zip(y_pred, y_true):
    print(reference.decode('utf-8'))
    print(predicted.decode('utf-8'))
    print('-' * 10)

b"and maybe the hardest thing of all is to figure out that what other people think and feel is n ' t actually exactly like what we think and feel ."
and i it problem thing that the ? , do out how the we people are of how that that ' t going going what . they ' . what .
----------
b"that ' s what you should have said . right ? why is this ?"
and ' s what we see do done . ? ? it ?
----------
b'what we have here is one exploit file .'
and i do to is that ofings
----------
b"it ' s called a microscope ."
and ' s a the p .
----------
b'he had 498 people to prepare his dinner every night .'
and ' to24 pounds in read to own for day .
----------
b'translate that for me .'
ands is a is
----------
b'it is caused by prematurity and genetic conditions .'
and ' a by actionrine and building variation .
----------
b"it ' s for the number needed to treat ."
and ' s a a first of . be the
----------
b"bl : right ? so we ' ve got our observations . we ' ve got our data ."
andoo : thank ? , ' re got to fi