The model takes a sequence of numbers (e.g., [1, 2, 3]) as input and outputs the reversed sequence (e.g., [3, 2, 1]).

In [3]:
import numpy as np
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, Dense

# Generate the data
num_samples = 1000  # Number of sequences
input_length = 5     # Length of each input sequence
input_vocab_size = 10  # One-hot width; tokens are drawn from 1..9, so index 0 is never sampled
SEED = 42  # Fixed seed so the generated dataset is identical on every run

# A dedicated generator makes the dataset reproducible without reseeding
# (or otherwise touching) NumPy's global random state.
rng = np.random.default_rng(SEED)

# Generate random input sequences: integers in [1, input_vocab_size)
input_sequences = rng.integers(1, input_vocab_size, size=(num_samples, input_length))
target_sequences = np.flip(input_sequences, axis=1)  # Reverse along the time axis for the target

# One-hot encode the input and target sequences by indexing rows of an identity matrix
one_hot_table = np.eye(input_vocab_size)
input_onehot = one_hot_table[input_sequences]  # (num_samples, input_length, input_vocab_size)
target_onehot = one_hot_table[target_sequences]

# Prepare decoder input sequences (teacher forcing): the target shifted one
# step to the right. Step 0 keeps the all-zero vector from zeros_like, which
# serves as the <start> token, so no explicit assignment is needed for it.
decoder_input_sequences = np.zeros_like(target_onehot)
decoder_input_sequences[:, 1:, :] = target_onehot[:, :-1, :]

# Hyperparameters, named once so the encoder and decoder cannot drift apart
# and so the LSTM width is not confused with the (equal-valued) batch size.
latent_dim = 64   # Size of the LSTM hidden/cell state, shared by encoder and decoder
epochs = 20       # Number of passes over the training data
batch_size = 64   # Sequences per gradient update

# Define the Encoder
# Only the final hidden and cell states are kept; they summarize the input
# sequence and are handed to the decoder as its initial state.
encoder_inputs = Input(shape=(None, input_vocab_size))
encoder_lstm = LSTM(latent_dim, return_state=True)
encoder_outputs, state_h, state_c = encoder_lstm(encoder_inputs)
encoder_states = [state_h, state_c]

# Define the Decoder
# return_sequences=True yields one output per time step; return_state=True is
# not needed for training (the states are discarded below) but lets the same
# layer be reused step by step at inference time.
decoder_inputs = Input(shape=(None, input_vocab_size))
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states)
decoder_dense = Dense(input_vocab_size, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)  # Per-step distribution over the vocabulary

# Define the full model: (encoder input, decoder input) -> per-step token probabilities
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

# Compile the model
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Train the model with teacher forcing: the decoder is fed the shifted target
# sequence and is trained to predict the unshifted one.
model.fit(
    [input_onehot, decoder_input_sequences],
    target_onehot,
    epochs=epochs,
    batch_size=batch_size,
    verbose=1,
)

# Define the inference models (for prediction)
# Encoder model: maps an input sequence to the final [hidden, cell] states
encoder_model = Model(encoder_inputs, encoder_states)

# Decoder model
# The state inputs must match the width of the trained decoder LSTM, so the
# size is read from the layer instead of repeating a literal that could drift.
decoder_state_size = decoder_lstm.units
decoder_state_input_h = Input(shape=(decoder_state_size,))
decoder_state_input_c = Input(shape=(decoder_state_size,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]

# Reuse the trained decoder_lstm and decoder_dense layers (shared weights),
# but feed the states explicitly so decoding can proceed one step at a time.
# NOTE: state_h, state_c and decoder_outputs are rebound here to tensors of
# the inference graph.
decoder_lstm_outputs, state_h, state_c = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs
)
decoder_states = [state_h, state_c]
decoder_outputs = decoder_dense(decoder_lstm_outputs)

# Inputs: [decoder tokens, h, c] -> Outputs: [token probabilities, new h, new c]
decoder_model = Model(
    [decoder_inputs] + decoder_states_inputs, [decoder_outputs] + decoder_states
)

# Function to decode a sequence
def decode_sequence(input_seq, num_steps=None):
    """Greedily decode the model's prediction for one encoded input sequence.

    The input is passed through ``encoder_model`` to obtain the initial decoder
    states; ``decoder_model`` is then run one step at a time, feeding the
    one-hot encoding of each predicted token back in as the next input.

    Args:
        input_seq: One-hot encoded input handed to ``encoder_model.predict``.
            Only the first sample of the batch is decoded, and the start token
            is built with batch size 1, so a single-sequence batch is expected.
        num_steps: Number of tokens to generate. Defaults to the number of
            time steps in ``input_seq`` (axis 1), so the output is as long as
            the sequence that was passed in.

    Returns:
        list[int]: The predicted token indices, in generation order.
    """
    if num_steps is None:
        # Tie the output length to the actual input rather than a global constant.
        num_steps = np.shape(input_seq)[1]

    # Encode the input sequence to get the initial decoder states.
    # verbose=0 keeps predict() from printing a progress bar on every call.
    states = list(encoder_model.predict(input_seq, verbose=0))

    # Start decoding from the <start> token, represented as an all-zero vector
    target_seq = np.zeros((1, 1, input_vocab_size))

    decoded_sequence = []

    # Generate tokens one-by-one
    for _ in range(num_steps):
        # Predict the next-token distribution and the updated states
        output_tokens, h, c = decoder_model.predict([target_seq] + states, verbose=0)

        # Greedy choice: take the most probable token; cast to a plain int so
        # callers get Python integers rather than NumPy scalars.
        sampled_token_index = int(np.argmax(output_tokens[0, -1, :]))
        decoded_sequence.append(sampled_token_index)

        # The chosen token (one-hot) becomes the decoder input for the next step
        target_seq = np.zeros((1, 1, input_vocab_size))
        target_seq[0, 0, sampled_token_index] = 1

        # Carry the decoder's states forward
        states = [h, c]

    return decoded_sequence

# Try the trained model on one freshly drawn sequence
sample_shape = (1, input_length)  # a batch holding a single sequence
test_sequence = np.random.randint(low=1, high=input_vocab_size, size=sample_shape)

# One-hot encode by picking rows of the identity matrix
test_sequence_onehot = np.identity(input_vocab_size)[test_sequence]
decoded_output = decode_sequence(test_sequence_onehot)

# Show the drawn sequence next to the model's prediction
print("Input sequence:", test_sequence[0])
print("Decoded sequence:", decoded_output)


Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Input sequence: [8 4 5 4 5]
Decoded sequence: [5, 4, 5, 4, 8]
