# Testing a full (encoder-decoder) transformer implementation

In [None]:
import sys
from transformers import AutoConfig, AutoTokenizer
import tensorflow as tf
from tensorflow.keras import Input, Model

sys.path.append('../modules/')

from transformer import Transformer

Load the model configuration and the encoder/decoder tokenizers.

In [None]:
model_ckpt = 'distilbert-base-uncased'

# Hyperparameters for the transformer are taken from this pretrained config.
config = AutoConfig.from_pretrained(model_ckpt)

# One tokenizer per side of the model, both loaded from the same checkpoint.
# NOTE(review): the decoder text below is Italian while this checkpoint's
# vocabulary is English/uncased; a dedicated (or multilingual) tokenizer for
# the decoder may be preferable — still an open question.
tokenizer_encoder, tokenizer_decoder = (
    AutoTokenizer.from_pretrained(model_ckpt) for _ in range(2)
)

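Define some example text. We'll work with machine translation: English source sentences (encoder inputs) paired line by line with their Italian translations (decoder inputs).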

In [None]:
# Parallel example corpus: English lines for the encoder and their Italian
# translations for the decoder (same index = same sentence pair).
text_encoder = [
    "Six o’clock the siren kicks",
    "him from a dream",
    "Tries to shake it off but it just won’t stop",
    "Can’t find the strength",
    "but he’s got promises to keep",
    "And wood to chop before he sleeps"
]

text_decoder = [
    "Alle sei del mattino le sirene",
    "lo cacciano fuori da un sogno",
    "Provano a scuoterlo ma non si vuole fermare",
    "Non riesce a trovare la forza",
    "ma ha delle promesse da mantenere",
    "E ha della legna da tagliare prima di dormire"
]


def _to_input_ids(tokenizer, sentences):
    """Tokenize ``sentences`` as one padded TF batch and return its ``input_ids``."""
    encoded = tokenizer(sentences, padding=True, return_tensors='tf')
    return encoded['input_ids']


input_ids_encoder = _to_input_ids(tokenizer_encoder, text_encoder)
input_ids_decoder = _to_input_ids(tokenizer_decoder, text_decoder)

print(input_ids_encoder.shape, input_ids_decoder.shape)

Test a full (encoder-decoder) transformer model.

In [None]:
# Instantiate the encoder-decoder Transformer (imported from ../modules/transformer)
# using the configuration loaded above.
trnsf = Transformer(config=config)

Run a forward pass on the tokenized examples.

In [None]:
# Eager call of the layer with a two-element list: [encoder ids, decoder ids].
# NOTE(review): no attention masks are passed, so padding tokens are not masked
# unless Transformer derives masks internally — confirm.
trnsf([input_ids_encoder, input_ids_decoder])

## Wrap the transformer in a Keras `Model` object

Build a `Model` object.

In [None]:
# Symbolic Keras inputs mirroring the tokenized example batches: the batch
# axis is dropped and the (padded) sequence length of each batch is reused.
# The dtype is taken from the token-id tensors themselves instead of relying
# on the Input default (float32), so the functional model receives integer ids
# exactly as the eager call above does, with no implicit float cast.
input_encoder = Input(
    shape=input_ids_encoder.shape[1:],
    dtype=input_ids_encoder.dtype,
)
input_decoder = Input(
    shape=input_ids_decoder.shape[1:],
    dtype=input_ids_decoder.dtype,
)

inputs = [input_encoder, input_decoder]
# Calling the existing `trnsf` layer on the symbolic inputs reuses its weights.
outputs = trnsf(inputs)

transformer_model = Model(
    inputs=inputs,
    outputs=outputs
)

Compile the model.

In [None]:
# Optimizer and loss are arbitrary: this only checks that a single training
# epoch runs end to end on fake target data.
transformer_model.compile(loss='mse', optimizer='rmsprop')

In [None]:
# Report the size of the wrapped model.
n_params = transformer_model.count_params()
print('N parameters:', n_params)

Generate fake target data and fit the model.

In [None]:
# The untrained model's output is used only as a template: every value is
# replaced by 1, giving dummy targets with the right shape for `fit`.
example_inputs = [input_ids_encoder, input_ids_decoder]
fake_targets = tf.ones_like(transformer_model(example_inputs))

transformer_model.fit(x=example_inputs, y=fake_targets, epochs=1)

Generate output with the model again, now that it has been trained for one epoch.

In [None]:
# Forward pass with the wrapped model after the one-epoch fit; as the cell's
# last expression, its result is what the notebook displays.
transformer_model([input_ids_encoder, input_ids_decoder])