[Return to Top](#returnToTop)  
<a id = 'shakespeare'></a>

## 6. Shakespeare-to-Modern English Translation with Seq2Seq Transformers

In this example, we'll build a sequence-to-sequence Transformer-based model, which
we'll train on a Shakespearean English to Modern English machine translation task.

We'll need to:

- Use a `TransformerEncoder` layer, a `TransformerDecoder` layer,
and a `PositionalEmbedding` layer.
- Prepare data as sentence pairs for training a sequence-to-sequence machine translation model.
- Use the trained model to generate translations.

How does this differ from using T5 or M2M100?  This model is not pre-trained so it knows nothing about language when we start to train it.

This notebook is based on a [Keras notebook](https://github.com/keras-team/keras-io/blob/master/examples/nlp/ipynb/neural_machine_translation_with_transformer.ipynb) that translates English into Spanish. This lesson notebook reuses the transformer code.

 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/datasci-w266/2025-spring-main/blob/master/materials/lesson_notebooks/lesson_6_Machine_Translation_With_Transformer.ipynb)

## Setup

In [1]:
# We set the backend to TensorFlow. The code works with
# both `tensorflow` and `torch`. It does not work with JAX
# due to the behavior of `jax.numpy.tile` in a jit scope
# (used in `TransformerDecoder.get_causal_attention_mask()`):
# `tile` in JAX does not support a dynamic `reps` argument.
# You can make the code work in JAX by wrapping the
# inside of the `get_causal_attention_mask` method in
# a context manager:
# `with jax.ensure_compile_time_eval():`.
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import pathlib
import random
import string
import re
import numpy as np

import tensorflow.data as tf_data
import tensorflow.strings as tf_strings

import keras
from keras import layers
from keras import ops
from keras.layers import TextVectorization

Let's define a set of hyperparameters that we can use to configure this model.

In [2]:
BATCH_SIZE = 64
EPOCHS = 15  # This should be at least 10 for convergence
MAX_SEQUENCE_LENGTH = 40

#The size of our source and target language vocabularies
ORG_VOCAB_SIZE = 15000
MOD_VOCAB_SIZE = 15000

#define some hyperparameter values for our transformers
EMBED_DIM = 256
INTERMEDIATE_DIM = 2048
NUM_HEADS = 8

### 6.1 Downloading the data

The data includes aligned sentences from a number of plays by William Shakespeare.  The data was copied from this repo --[https://github.com/cocoxu/Shakespeare](https://github.com/cocoxu/Shakespeare) -- and consolidated into one file for easier handling.

You will need to grab a copy from our git repo and import it to your Google Drive.  From there you'll be able to easily load it into a Colab notebook.

In [3]:
#This cell will authenticate you and mount your Drive in the Colab.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
#Modify this path to the appropriate location in your Drive
text_file = 'drive/MyDrive/data/train_plays-org-mod.txt'

In [5]:
!ls

drive  sample_data


### 6.2 Parsing the data

Each line contains a Shakespearean sentence and its corresponding modern English translation.
The Shakespearean sentence is the *source sequence* and the modern English one is the *target sequence*.

In [6]:
with open(text_file) as f:
    lines = f.read().split("\n")[:-1]
text_pairs = []
for line in lines:
    org, mod = line.split("\t")
    mod = "[start] " + mod + " [end]"
    text_pairs.append((org, mod))

Here's what our sentence pairs look like:

In [7]:
for _ in range(5):
    print(random.choice(text_pairs))

('Help, ho!', '[start] Help! [end]')
('Ay, too gentle.', '[start] Yes, too gentle. [end]')
("Yet, 'tis the plague of great ones; Prerogatived are they less than the base; 'Tis destiny unshunnable, like death: Even then this forked plague is fated to us When we do quicken.", '[start] Still, it is the plague of great men, They have fewer choices than common men; It is an unshakeable destiny, like death. [end]')
('Speak softly.', '[start] Speak softly. [end]')
('I will do everything that thou wilt have me.', '[start] I’ll do everything she wants me to do. [end]')


Now, let's split the sentence pairs into a training set, a validation set,
and a test set.

In [8]:
random.shuffle(text_pairs)
num_val_samples = int(0.15 * len(text_pairs))
num_train_samples = len(text_pairs) - 2 * num_val_samples
train_pairs = text_pairs[:num_train_samples]
val_pairs = text_pairs[num_train_samples : num_train_samples + num_val_samples]
test_pairs = text_pairs[num_train_samples + num_val_samples :]

print(f"{len(text_pairs)} total pairs")
print(f"{len(train_pairs)} training pairs")
print(f"{len(val_pairs)} validation pairs")
print(f"{len(test_pairs)} test pairs")

19088 total pairs
13362 training pairs
2863 validation pairs
2863 test pairs


Note we have roughly 13,000 sentence pairs for training from scratch.  How does this compare with the number of sentences in other sentence pair corpora we've seen?

## Vectorizing the text data

We'll use two instances of the `TextVectorization` layer to vectorize the text
data (one for Shakespearean English and one for Modern English),
that is to say, to turn the original strings into integer sequences
where each integer represents the index of a word in a vocabulary.
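
To make the idea of "index of a word in a vocabulary" concrete, here is a minimal plain-Python sketch. It is not how `TextVectorization` is implemented; the helper names `build_vocab` and `vectorize` and the toy sentences are made up for illustration. Like the Keras layer in `"int"` mode, it reserves index 0 for padding and index 1 for unknown words.

```python
from collections import Counter


def build_vocab(texts, max_tokens):
    """Map the most frequent words to integer ids; 0 = padding, 1 = unknown."""
    counts = Counter(word for text in texts for word in text.lower().split())
    most_common = [word for word, _ in counts.most_common(max_tokens - 2)]
    return {word: idx for idx, word in enumerate(["", "[UNK]"] + most_common)}


def vectorize(text, vocab, sequence_length):
    """Turn a string into a fixed-length list of ids (truncate, then pad with 0)."""
    ids = [vocab.get(word, 1) for word in text.lower().split()]
    ids = ids[:sequence_length]
    return ids + [0] * (sequence_length - len(ids))


# Hypothetical toy corpus, only for illustration.
toy_texts = ["to be or not to be", "to thine own self be true"]
toy_vocab = build_vocab(toy_texts, max_tokens=6)
toy_ids = vectorize("to be or not", toy_vocab, sequence_length=6)
unknown_ids = vectorize("zounds", toy_vocab, sequence_length=3)
print(toy_ids, unknown_ids)
```

Words that fall outside the `max_tokens` budget, or that never appeared in the adapted texts, all collapse to the unknown id, which is one reason a small training set limits what the model can express.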

The Shakespearean layer will use the default string standardization (lowercase and strip punctuation characters)
and splitting scheme (split on whitespace). The modern English layer will use a custom standardization that also lowercases and strips punctuation but keeps the `[` and `]` characters, so that the `[start]` and `[end]` tokens survive.
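
The effect of keeping or stripping the square brackets can be checked without TensorFlow. The sketch below is a plain-Python stand-in that uses `re` instead of `tf_strings`; the function names are illustrative and not part of the notebook.

```python
import re
import string

# Same idea as the notebook's `strip_chars`: all punctuation except "[" and "]".
strip_chars = (string.punctuation + "¿").replace("[", "").replace("]", "")


def standardize_keep_brackets(text):
    """Lowercase and strip punctuation, but leave "[" and "]" in place."""
    return re.sub("[%s]" % re.escape(strip_chars), "", text.lower())


def standardize_strip_all(text):
    """Roughly what the default standardization does: lowercase, drop all punctuation."""
    return re.sub("[%s]" % re.escape(string.punctuation), "", text.lower())


sample = "[start] Yes, too gentle. [end]"
kept = standardize_keep_brackets(sample)
stripped = standardize_strip_all(sample)
print(kept)
print(stripped)
```

With the brackets stripped, `[start]` and `[end]` would become the ordinary words `start` and `end`, and the decoding loop later in the notebook could no longer recognize the end-of-sentence marker.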



In [9]:
strip_chars = string.punctuation + "¿"
strip_chars = strip_chars.replace("[", "")
strip_chars = strip_chars.replace("]", "")

# Both vocabularies use the same size, so a single value configures both layers.
vocab_size = MOD_VOCAB_SIZE
sequence_length = MAX_SEQUENCE_LENGTH
batch_size = BATCH_SIZE


def custom_standardization(input_string):
    lowercase = tf_strings.lower(input_string)
    return tf_strings.regex_replace(lowercase, "[%s]" % re.escape(strip_chars), "")


org_vectorization = TextVectorization(
    max_tokens=vocab_size,
    output_mode="int",
    output_sequence_length=sequence_length,
)

mod_vectorization = TextVectorization(
    max_tokens=vocab_size,
    output_mode="int",
    output_sequence_length=sequence_length + 1,
    standardize=custom_standardization,
)

train_org_texts = [pair[0] for pair in train_pairs]
train_mod_texts = [pair[1] for pair in train_pairs]
org_vectorization.adapt(train_org_texts)
mod_vectorization.adapt(train_mod_texts)

Next, we'll format our datasets.

At each training step, the model will seek to predict target words N+1 (and beyond)
using the source sentence and the target words 0 to N.

As such, the training dataset will yield a tuple `(inputs, targets)`, where:

- `inputs` is a dictionary with the keys `encoder_inputs` and `decoder_inputs`.
`encoder_inputs` is the vectorized source sentence and `decoder_inputs` is the target sentence "so far",
that is to say, the words 0 to N used to predict word N+1 (and beyond) in the target sentence.
- `targets` is the target sentence offset by one step:
it provides the next words in the target sentence -- what the model will try to predict.
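
The one-step offset between the decoder inputs and the targets can be sketched with a plain list. The token ids below are hypothetical; only the slicing mirrors what `format_dataset` does with `mod[:, :-1]` and `mod[:, 1:]`.

```python
# Hypothetical ids for "[start] yes too gentle [end]" followed by padding (0).
target_ids = [2, 15, 48, 301, 3, 0, 0]

decoder_inputs = target_ids[:-1]  # what the decoder sees: tokens 0..N
targets = target_ids[1:]          # what it must predict: tokens 1..N+1

for step, (seen, expected) in enumerate(zip(decoder_inputs, targets)):
    print(f"step {step}: last visible token {seen} -> predict {expected}")
```

Because one position is lost to the shift, the modern English vectorizer is configured with `output_sequence_length=sequence_length + 1`, so that both slices end up `sequence_length` tokens long.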

In [10]:
def format_dataset(org, mod):
    org = org_vectorization(org)
    mod = mod_vectorization(mod)
    return (
        {
            "encoder_inputs": org,
            "decoder_inputs": mod[:, :-1],
        },
        mod[:, 1:],
    )


def make_dataset(pairs):
    org_texts, mod_texts = zip(*pairs)
    org_texts = list(org_texts)
    mod_texts = list(mod_texts)
    dataset = tf_data.Dataset.from_tensor_slices((org_texts, mod_texts))
    dataset = dataset.batch(batch_size)
    dataset = dataset.map(format_dataset)
    return dataset.cache().shuffle(2048).prefetch(16)


train_ds = make_dataset(train_pairs)
val_ds = make_dataset(val_pairs)

Let's take a quick look at the sequence shapes
(we have batches of 64 pairs, and all sequences are 40 steps long):

In [11]:
for inputs, targets in train_ds.take(1):
    print(f'inputs["encoder_inputs"].shape: {inputs["encoder_inputs"].shape}')
    print(f'inputs["decoder_inputs"].shape: {inputs["decoder_inputs"].shape}')
    print(f"targets.shape: {targets.shape}")

inputs["encoder_inputs"].shape: (64, 40)
inputs["decoder_inputs"].shape: (64, 40)
targets.shape: (64, 40)


## Building the model

Our sequence-to-sequence Transformer consists of a `TransformerEncoder`
and a `TransformerDecoder` chained together. To make the model aware of word order,
we also use a `PositionalEmbedding` layer.
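
As a rough illustration of why a positional embedding makes the model aware of word order, the sketch below adds a per-position vector to a per-token vector. The tiny 2-dimensional tables are invented for the example; in the notebook both tables are learned `Embedding` layers of width `EMBED_DIM`.

```python
# Hypothetical 2-dimensional embedding tables (the real ones are learned).
token_table = {7: [0.5, -1.0], 9: [2.0, 0.0]}
position_table = [[0.0, 0.1], [0.2, 0.3], [0.4, 0.5]]


def embed(token_ids):
    """Return token embedding + position embedding for each position."""
    return [
        [t + p for t, p in zip(token_table[tok], position_table[pos])]
        for pos, tok in enumerate(token_ids)
    ]


# Token 7 appears at positions 0 and 2.
embedded = embed([7, 9, 7])
for pos, vector in enumerate(embedded):
    print(pos, vector)
```

The same token id yields different vectors at positions 0 and 2, which is the only signal the attention layers receive about where a word sits in the sentence.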

The source sequence will be passed to the `TransformerEncoder`,
which will produce a new representation of it.

In [12]:
import keras.ops as ops

#Define the encoder
class TransformerEncoder(layers.Layer):
    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(dense_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.supports_masking = True

    def call(self, inputs, mask=None):
        if mask is not None:
            padding_mask = ops.cast(mask[:, None, :], dtype="int32")
        else:
            padding_mask = None

        attention_output = self.attention(
            query=inputs, value=inputs, key=inputs, attention_mask=padding_mask
        )
        proj_input = self.layernorm_1(inputs + attention_output)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "dense_dim": self.dense_dim,
                "num_heads": self.num_heads,
            }
        )
        return config


class PositionalEmbedding(layers.Layer):
    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=embed_dim
        )
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

    def call(self, inputs):
        length = ops.shape(inputs)[-1]
        positions = ops.arange(0, length, 1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs, mask=None):
        return ops.not_equal(inputs, 0)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.sequence_length,
                "vocab_size": self.vocab_size,
                "embed_dim": self.embed_dim,
            }
        )
        return config




This new representation built by the encoder will then be passed
to the `TransformerDecoder`, together with the target sequence so far (target words 0 to N).
The `TransformerDecoder` will then seek to predict the next words in the target sequence (N+1 and beyond).


A key detail that makes this possible is causal masking
(see method `get_causal_attention_mask()` on the `TransformerDecoder`).
The `TransformerDecoder` sees the entire sequence at once, and thus we must make
sure that it only uses information from target tokens 0 to N when predicting token N+1
(otherwise, it could use information from the future, which would
result in a model that cannot be used at inference time).
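
The shape of the causal mask is easy to see in a small plain-Python sketch. It uses the same `i >= j` rule as `get_causal_attention_mask()` but leaves out the batch dimension and the Keras ops; the function name `causal_mask` is only illustrative.

```python
def causal_mask(sequence_length):
    """Row i marks which positions j token i may attend to (1 = allowed)."""
    return [
        [1 if i >= j else 0 for j in range(sequence_length)]
        for i in range(sequence_length)
    ]


mask = causal_mask(4)
for row in mask:
    print(row)
```

Each row exposes one more position than the row above it, so the prediction made at position N can only draw on tokens 0 to N.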

In [13]:
#Define the decoder
class TransformerDecoder(layers.Layer):
    def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(latent_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()
        self.supports_masking = True

    def call(self, inputs, mask=None):
        inputs, encoder_outputs = inputs
        causal_mask = self.get_causal_attention_mask(inputs)

        if mask is None:
            inputs_padding_mask, encoder_outputs_padding_mask = None, None
        else:
            inputs_padding_mask, encoder_outputs_padding_mask = mask

        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=causal_mask,
            query_mask=inputs_padding_mask,
        )
        out_1 = self.layernorm_1(inputs + attention_output_1)

        attention_output_2 = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
            query_mask=inputs_padding_mask,
            key_mask=encoder_outputs_padding_mask,
        )
        out_2 = self.layernorm_2(out_1 + attention_output_2)

        proj_output = self.dense_proj(out_2)
        return self.layernorm_3(out_2 + proj_output)

    def get_causal_attention_mask(self, inputs):
        input_shape = ops.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = ops.arange(sequence_length)[:, None]
        j = ops.arange(sequence_length)
        mask = ops.cast(i >= j, dtype="int32")
        mask = ops.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = ops.concatenate(
            [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])],
            axis=0,
        )
        return ops.tile(mask, mult)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "latent_dim": self.latent_dim,
                "num_heads": self.num_heads,
            }
        )
        return config

Next, we assemble the end-to-end model.

In [14]:
embed_dim = EMBED_DIM
latent_dim = INTERMEDIATE_DIM
num_heads = NUM_HEADS

#define the encoder
encoder_inputs = keras.Input(shape=(None,), dtype="int64", name="encoder_inputs")
x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(encoder_inputs)
encoder_outputs = TransformerEncoder(embed_dim, latent_dim, num_heads)(x)
encoder = keras.Model(encoder_inputs, encoder_outputs)

#define the decoder
decoder_inputs = keras.Input(shape=(None,), dtype="int64", name="decoder_inputs")
encoded_seq_inputs = keras.Input(shape=(None, embed_dim), name="decoder_state_inputs")
x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(decoder_inputs)
x = TransformerDecoder(embed_dim, latent_dim, num_heads)([x, encoder_outputs])
x = layers.Dropout(0.5)(x)
decoder_outputs = layers.Dense(vocab_size, activation="softmax")(x)
decoder = keras.Model([decoder_inputs, encoded_seq_inputs], decoder_outputs)


In [15]:
#connect the encoder and decoder together in sequence
transformer = keras.Model(
    {"encoder_inputs": encoder_inputs, "decoder_inputs": decoder_inputs},
    decoder_outputs,
    name="s2s_transformer",
)

In [16]:
#compile the model
transformer.compile(
    "rmsprop",
    loss=keras.losses.SparseCategoricalCrossentropy(ignore_class=0),
    metrics=["accuracy"],
)

In [17]:
transformer.summary()

## Training our model

We'll use accuracy as a quick way to monitor training progress on the validation data.
Note that machine translation typically uses BLEU scores as well as other metrics, rather than accuracy.
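
For intuition about what BLEU measures, here is a simplified sketch of its core ingredient, clipped n-gram precision. It is not a full BLEU implementation: real BLEU combines precisions for several n-gram sizes with a geometric mean and applies a brevity penalty. The reference sentence below is hypothetical and the function name is illustrative.

```python
from collections import Counter


def clipped_ngram_precision(candidate, reference, n=1):
    """Fraction of candidate n-grams that also occur in the reference (counts clipped)."""

    def ngrams(tokens):
        return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

    cand_counts = ngrams(candidate.split())
    ref_counts = ngrams(reference.split())
    total = sum(cand_counts.values())
    if total == 0:
        return 0.0
    # A candidate n-gram is credited at most as often as it appears in the reference.
    overlap = sum(min(count, ref_counts[gram]) for gram, count in cand_counts.items())
    return overlap / total


hypothesis = "we wish your peace"
reference = "we wish you peace"  # hypothetical reference translation
unigram_precision = clipped_ngram_precision(hypothesis, reference, n=1)
bigram_precision = clipped_ngram_precision(hypothesis, reference, n=2)
print(unigram_precision, bigram_precision)
```

Unlike per-token accuracy, this kind of score does not require the hypothesis and the reference to line up position by position, which is why it is better suited to judging translations.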

Here we only train for 15 epochs.

In [18]:
epochs = EPOCHS

transformer.fit(train_ds, epochs=epochs, validation_data=val_ds)

Epoch 1/15
209/209 ━━━━━━━━━━━━━━━━━━━━ 58s 176ms/step - accuracy: 0.0276 - loss: 6.7986 - val_accuracy: 0.0453 - val_loss: 5.6607
Epoch 2/15
209/209 ━━━━━━━━━━━━━━━━━━━━ 50s 88ms/step - accuracy: 0.0509 - loss: 5.4871 - val_accuracy: 0.0756 - val_loss: 4.9805
Epoch 3/15
209/209 ━━━━━━━━━━━━━━━━━━━━ 19s 92ms/step - accuracy: 0.0818 - loss: 4.7401 - val_accuracy: 0.0930 - val_loss: 4.5826
Epoch 4/15
209/209 ━━━━━━━━━━━━━━━━━━━━ 19s 91ms/step - accuracy: 0.1011 - loss: 4.2435 - val_accuracy: 0.1005 - val_loss: 4.3724
Epoch 5/15
209/209 ━━━━━━━━━━━━━━━━━━━━ 21s 93ms/step - accuracy: 0.1137 - loss: 3.8635 - val_accuracy: 0.1053 - val_loss: 4.2897
Epoch 6/15
209/209 ━━━━━━━━━━━━━━━━━━━━ 19s 93ms/step - accuracy: 0.1255 - loss: 3.5232 - val_accuracy: 0.1059 - val_loss: 4.2789
Epoch 7/15

<keras.src.callbacks.history.History at 0x7fcaa9299810>

## Decoding test sentences

Finally, let's demonstrate how to translate brand-new Shakespearean sentences.
We simply feed into the model the vectorized Shakespearean sentence
as well as the target token `"[start]"`, then we repeatedly generate the next token, until
we hit the token `"[end]"`.
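
Stripped of the model and the vectorization layers, the decoding loop described above is a greedy search. The sketch below shows only that control flow; `toy_table` and `toy_next_token` are hypothetical stand-ins for the trained transformer, which would return a score for every word in the vocabulary.

```python
def greedy_decode(next_token_fn, max_length=20):
    """Repeatedly append the highest-scoring next token until "[end]" or max_length."""
    tokens = ["[start]"]
    for _ in range(max_length):
        scores = next_token_fn(tokens)      # dict: candidate token -> score
        best = max(scores, key=scores.get)  # greedy choice (argmax)
        tokens.append(best)
        if best == "[end]":
            break
    return " ".join(tokens)


# Hypothetical stand-in for the model: scores depend only on the last token.
toy_table = {
    "[start]": {"speak": 0.9, "softly": 0.1},
    "speak": {"softly": 0.8, "[end]": 0.2},
    "softly": {"[end]": 0.7, "speak": 0.3},
}


def toy_next_token(tokens):
    return toy_table[tokens[-1]]


decoded = greedy_decode(toy_next_token)
print(decoded)
```

Greedy search commits to the single best token at every step. Alternatives such as beam search keep several candidate sentences alive, at the cost of more model calls.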

In [19]:
mod_vocab = mod_vectorization.get_vocabulary()
mod_index_lookup = dict(zip(range(len(mod_vocab)), mod_vocab))
max_decoded_sentence_length = 20


def decode_sequence(input_sentence):
    tokenized_input_sentence = org_vectorization([input_sentence])
    decoded_sentence = "[start]"
    for i in range(max_decoded_sentence_length):
        tokenized_target_sentence = mod_vectorization([decoded_sentence])[:, :-1]
        predictions = transformer(
            {
                "encoder_inputs": tokenized_input_sentence,
                "decoder_inputs": tokenized_target_sentence,
            }
        )

        # ops.argmax(predictions[0, i, :]) is not a concrete value for jax here
        sampled_token_index = ops.convert_to_numpy(
            ops.argmax(predictions[0, i, :])
        ).item(0)
        sampled_token = mod_index_lookup[sampled_token_index]
        decoded_sentence += " " + sampled_token

        if sampled_token == "[end]":
            break
    return decoded_sentence


test_org_texts = [pair[0] for pair in test_pairs]
for _ in range(10):
    input_sentence = random.choice(test_org_texts)
    translated = decode_sequence(input_sentence)
    print(input_sentence, translated)

Why, what’s a moveable? [start] why what’s this [end]
If you were civil and knew courtesy, You would not do me thus much injury. [start] if you were civil and knew me you would not you like this much me much and you [end]
This is a knavery of them to make me afeard. [start] this is a ring of them to make me afraid of them [end]
What man dare, I dare. [start] what man i would dare [end]
Tis new to thee. [start] it’s a new you [end]
We wish your peace. [start] we wish your peace [end]
Faith, like enough. [start] really like it [end]
God’s arm strike with us! [start] god’s arm let’s hit us [end]
I durst, my lord, to wager she is honest, Lay down my soul at stake. [start] i dare to my lord wager she is down [end]
And if your love Can labor ought in sad invention, Hang her an epitaph upon her tomb And sing it to her bones. [start] and if your love can hang on the imagination sad without her bedroom dog [end]


What things could we do to improve the output?

* add more sentence pairs
* ensure a good distribution over all the sentence lengths
* leverage techniques to better train both the encoder and decoder
* ???
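
As a starting point for checking the distribution over sentence lengths, one could bucket the source sentences by word count. This is a minimal sketch under simple assumptions (whitespace tokenization, a made-up helper name `length_histogram`); in the notebook it could be applied to `train_org_texts`.

```python
from collections import Counter


def length_histogram(sentences, bucket_size=5):
    """Count sentences per word-length bucket; keys are bucket starts (0, 5, 10, ...)."""
    buckets = Counter((len(s.split()) // bucket_size) * bucket_size for s in sentences)
    return dict(sorted(buckets.items()))


# A few source sentences taken from the sample pairs printed earlier.
samples = [
    "Help, ho!",
    "Ay, too gentle.",
    "Speak softly.",
    "I will do everything that thou wilt have me.",
]
histogram = length_histogram(samples)
print(histogram)
```

A histogram dominated by very short sentences would suggest that the model sees few long examples, which is consistent with the weaker translations of the longer test sentences above.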