# Setup

In [1]:
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

# Parameters for the model and dataset.
TRAINING_SIZE = 50000
DIGITS = 3
REVERSE = True

# Maximum length of input is 'int + int' (e.g., '345+678'). Maximum length of
# int is DIGITS.
MAXLEN = DIGITS + 1 + DIGITS

# Generate the data


In [2]:
class CharacterTable:
    """Given a set of characters:
    + Encode them to a one-hot integer representation
    + Decode the one-hot or integer representation to their character output
    + Decode a vector of probabilities to their character output
    """

    def __init__(self, chars):
        """Initialize character table.
        # Arguments
            chars: Characters that can appear in the input.
        """
        self.chars = sorted(set(chars))
        self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
        self.indices_char = dict((i, c) for i, c in enumerate(self.chars))

    def encode(self, C, num_rows):
        """One-hot encode given string C.
        # Arguments
            C: string, to be encoded.
            num_rows: Number of rows in the returned one-hot encoding. This is
                used to keep the # of rows for each data the same.
        """
        x = np.zeros((num_rows, len(self.chars)))
        for i, c in enumerate(C):
            x[i, self.char_indices[c]] = 1
        return x

    def decode(self, x, calc_argmax=True):
        """Decode the given vector or 2D array to their character output.
        # Arguments
            x: A vector or a 2D array of probabilities or one-hot representations;
                or a vector of character indices (used with `calc_argmax=False`).
            calc_argmax: Whether to find the character index with maximum
                probability, defaults to `True`.
        """
        if calc_argmax:
            x = x.argmax(axis=-1)
        return "".join(self.indices_char[x] for x in x)


# All the numbers, plus sign and space for padding.
chars = "0123456789+ "
ctable = CharacterTable(chars)

questions = []
expected = []
seen = set()
print("Generating data...")
while len(questions) < TRAINING_SIZE:
    f = lambda: int(
        "".join(
            np.random.choice(list("0123456789"))
            for i in range(np.random.randint(1, DIGITS + 1))
        )
    )
    a, b = f(), f()
    # Skip any addition questions we've already seen
    # Also skip any such that x+Y == Y+x (hence the sorting).
    key = tuple(sorted((a, b)))
    if key in seen:
        continue
    seen.add(key)
    # Pad the data with spaces such that it is always MAXLEN.
    q = "{}+{}".format(a, b)
    query = q + " " * (MAXLEN - len(q))
    ans = str(a + b)
    # Answers can be of maximum size DIGITS + 1.
    ans += " " * (DIGITS + 1 - len(ans))
    if REVERSE:
        # Reverse the query, e.g., '12+345  ' becomes '  543+21'. (Note the
        # space used for padding.)
        query = query[::-1]
    questions.append(query)
    expected.append(ans)
print("Total questions:", len(questions))

Generating data...
Total questions: 50000


# Vectorize the data


In [3]:
print("Vectorization...")
x = np.zeros((len(questions), MAXLEN, len(chars)), dtype=np.bool)
y = np.zeros((len(questions), DIGITS + 1, len(chars)), dtype=np.bool)
for i, sentence in enumerate(questions):
    x[i] = ctable.encode(sentence, MAXLEN)
for i, sentence in enumerate(expected):
    y[i] = ctable.encode(sentence, DIGITS + 1)

# Shuffle (x, y) in unison as the later parts of x will almost all be larger
# digits.
indices = np.arange(len(y))
np.random.shuffle(indices)
x = x[indices]
y = y[indices]

# Explicitly set apart 10% for validation data that we never train over.
split_at = len(x) - len(x) // 10
(x_train, x_val) = x[:split_at], x[split_at:]
(y_train, y_val) = y[:split_at], y[split_at:]

print("Training Data:")
print(x_train.shape)
print(y_train.shape)

print("Validation Data:")
print(x_val.shape)
print(y_val.shape)

Vectorization...


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  x = np.zeros((len(questions), MAXLEN, len(chars)), dtype=np.bool)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  y = np.zeros((len(questions), DIGITS + 1, len(chars)), dtype=np.bool)


Training Data:
(45000, 7, 12)
(45000, 4, 12)
Validation Data:
(5000, 7, 12)
(5000, 4, 12)


# Build the model


In [4]:
print("Build model...")
num_layers = 1  # Try to add more LSTM layers!

model = keras.Sequential()
# "Encode" the input sequence using a LSTM, producing an output of size 128.
# Note: In a situation where your input sequences have a variable length,
# use input_shape=(None, num_feature).
model.add(layers.LSTM(128, input_shape=(MAXLEN, len(chars))))
# As the decoder RNN's input, repeatedly provide with the last output of
# RNN for each time step. Repeat 'DIGITS + 1' times as that's the maximum
# length of output, e.g., when DIGITS=3, max output is 999+999=1998.
model.add(layers.RepeatVector(DIGITS + 1))
# The decoder RNN could be multiple layers stacked or a single layer.
for _ in range(num_layers):
    # By setting return_sequences to True, return not only the last output but
    # all the outputs so far in the form of (num_samples, timesteps,
    # output_dim). This is necessary as TimeDistributed in the below expects
    # the first dimension to be the timesteps.
    model.add(layers.LSTM(128, return_sequences=True))

# Apply a dense layer to the every temporal slice of an input. For each of step
# of the output sequence, decide which character should be chosen.
model.add(layers.Dense(len(chars), activation="softmax"))
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.summary()

Build model...
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 lstm (LSTM)                 (None, 128)               72192     
                                                                 
 repeat_vector (RepeatVector  (None, 4, 128)           0         
 )                                                               
                                                                 
 lstm_1 (LSTM)               (None, 4, 128)            131584    
                                                                 
 dense (Dense)               (None, 4, 12)             1548      
                                                                 
Total params: 205,324
Trainable params: 205,324
Non-trainable params: 0
_________________________________________________________________


# Train the model


In [5]:
epochs = 30
batch_size = 32


# Train the model each generation and show predictions against the validation
# dataset.
for epoch in range(1, epochs):
    print()
    print("Iteration", epoch)
    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=1,
        validation_data=(x_val, y_val),
    )
    # Select 10 samples from the validation set at random so we can visualize
    # errors.
    for i in range(10):
        ind = np.random.randint(0, len(x_val))
        rowx, rowy = x_val[np.array([ind])], y_val[np.array([ind])]
        preds = np.argmax(model.predict(rowx), axis=-1)
        q = ctable.decode(rowx[0])
        correct = ctable.decode(rowy[0])
        guess = ctable.decode(preds[0], calc_argmax=False)
        print("Q", q[::-1] if REVERSE else q, end=" ")
        print("T", correct, end=" ")
        if correct == guess:
            print("☑ " + guess)
        else:
            print("☒ " + guess)


Iteration 1
Q 580+731 T 1311 ☒ 1270
Q 878+657 T 1535 ☒ 1587
Q 537+54  T 591  ☒ 557 
Q 357+1   T 358  ☒ 547 
Q 804+1   T 805  ☒ 158 
Q 14+732  T 746  ☒ 597 
Q 396+660 T 1056 ☒ 100 
Q 3+303   T 306  ☒ 344 
Q 25+41   T 66   ☒ 14  
Q 69+839  T 908  ☒ 907 

Iteration 2
Q 243+955 T 1198 ☒ 1299
Q 951+12  T 963  ☒ 969 
Q 391+6   T 397  ☒ 390 
Q 66+858  T 924  ☒ 941 
Q 463+957 T 1420 ☒ 1400
Q 288+909 T 1197 ☒ 1299
Q 824+31  T 855  ☒ 858 
Q 819+876 T 1695 ☒ 1801
Q 358+49  T 407  ☒ 411 
Q 8+599   T 607  ☒ 599 

Iteration 3
Q 753+267 T 1020 ☒ 1012
Q 633+82  T 715  ☒ 711 
Q 15+1    T 16   ☑ 16  
Q 179+6   T 185  ☒ 181 
Q 23+572  T 595  ☒ 695 
Q 13+497  T 510  ☒ 515 
Q 615+599 T 1214 ☒ 1275
Q 330+628 T 958  ☒ 964 
Q 256+94  T 350  ☒ 345 
Q 186+68  T 254  ☒ 255 

Iteration 4
Q 773+895 T 1668 ☒ 1678
Q 12+407  T 419  ☒ 417 
Q 35+37   T 72   ☒ 76  
Q 277+774 T 1051 ☒ 1052
Q 954+447 T 1401 ☒ 1419
Q 29+464  T 493  ☒ 499 
Q 97+942  T 1039 ☒ 1041
Q 35+28   T 63   ☒ 66  
Q 97+61   T 158  ☒ 155 
Q 339+332 T 

Q 56+12   T 68   ☑ 68  
Q 918+1   T 919  ☑ 919 
Q 16+772  T 788  ☑ 788 
Q 681+858 T 1539 ☑ 1539
Q 310+6   T 316  ☑ 316 
Q 654+46  T 700  ☒ 600 
Q 637+620 T 1257 ☑ 1257
Q 88+83   T 171  ☑ 171 
Q 484+18  T 502  ☑ 502 
Q 848+615 T 1463 ☑ 1463

Iteration 11
Q 226+780 T 1006 ☑ 1006
Q 2+808   T 810  ☑ 810 
Q 930+3   T 933  ☑ 933 
Q 1+548   T 549  ☑ 549 
Q 166+4   T 170  ☑ 170 
Q 320+592 T 912  ☑ 912 
Q 85+30   T 115  ☑ 115 
Q 62+799  T 861  ☑ 861 
Q 852+944 T 1796 ☑ 1796
Q 885+38  T 923  ☑ 923 

Iteration 12
Q 930+3   T 933  ☑ 933 
Q 787+780 T 1567 ☑ 1567
Q 371+667 T 1038 ☑ 1038
Q 593+49  T 642  ☑ 642 
Q 749+58  T 807  ☒ 707 
Q 12+81   T 93   ☑ 93  
Q 490+412 T 902  ☑ 902 
Q 726+48  T 774  ☑ 774 
Q 42+834  T 876  ☑ 876 
Q 35+37   T 72   ☑ 72  

Iteration 13
Q 318+73  T 391  ☑ 391 
Q 3+100   T 103  ☑ 103 
Q 42+455  T 497  ☑ 497 
Q 460+8   T 468  ☑ 468 
Q 101+754 T 855  ☑ 855 
Q 773+67  T 840  ☑ 840 
Q 3+949   T 952  ☑ 952 
Q 71+996  T 1067 ☑ 1067
Q 3+303   T 306  ☑ 306 
Q 26+89   T 115  ☑ 115

Q 7+15    T 22   ☑ 22  
Q 739+617 T 1356 ☑ 1356
Q 157+664 T 821  ☑ 821 
Q 50+569  T 619  ☑ 619 
Q 703+829 T 1532 ☑ 1532
Q 69+296  T 365  ☑ 365 
Q 379+390 T 769  ☑ 769 
Q 441+499 T 940  ☑ 940 
Q 3+846   T 849  ☑ 849 
Q 455+35  T 490  ☑ 490 

Iteration 20
Q 716+953 T 1669 ☑ 1669
Q 737+742 T 1479 ☑ 1479
Q 820+575 T 1395 ☑ 1395
Q 133+454 T 587  ☑ 587 
Q 91+194  T 285  ☑ 285 
Q 332+0   T 332  ☑ 332 
Q 147+46  T 193  ☑ 193 
Q 426+45  T 471  ☑ 471 
Q 640+986 T 1626 ☑ 1626
Q 15+456  T 471  ☑ 471 

Iteration 21
Q 970+54  T 1024 ☑ 1024
Q 327+98  T 425  ☑ 425 
Q 156+384 T 540  ☑ 540 
Q 543+482 T 1025 ☑ 1025
Q 567+952 T 1519 ☒ 1509
Q 537+54  T 591  ☑ 591 
Q 924+45  T 969  ☑ 969 
Q 56+868  T 924  ☑ 924 
Q 1+484   T 485  ☑ 485 
Q 98+825  T 923  ☑ 923 

Iteration 22
Q 825+588 T 1413 ☒ 1412
Q 299+235 T 534  ☑ 534 
Q 11+26   T 37   ☑ 37  
Q 416+572 T 988  ☒ 887 
Q 654+790 T 1444 ☒ 1344
Q 0+801   T 801  ☑ 801 
Q 658+6   T 664  ☒ 674 
Q 639+94  T 733  ☑ 733 
Q 70+668  T 738  ☑ 738 
Q 30+209  T 239  ☑ 239

Q 51+20   T 71   ☑ 71  
Q 758+800 T 1558 ☑ 1558
Q 630+729 T 1359 ☒ 1259
Q 841+0   T 841  ☑ 841 
Q 623+5   T 628  ☑ 628 
Q 807+1   T 808  ☑ 808 
Q 12+962  T 974  ☑ 974 
Q 36+99   T 135  ☑ 135 
Q 339+53  T 392  ☑ 392 
Q 682+123 T 805  ☑ 805 

Iteration 29
Q 150+387 T 537  ☑ 537 
Q 26+455  T 481  ☑ 481 
Q 182+841 T 1023 ☒ 122 
Q 595+711 T 1306 ☑ 1306
Q 2+373   T 375  ☑ 375 
Q 522+651 T 1173 ☑ 1173
Q 658+6   T 664  ☑ 664 
Q 667+950 T 1617 ☑ 1617
Q 4+535   T 539  ☑ 539 
Q 29+78   T 107  ☑ 107 
