<a href="https://colab.research.google.com/github/martin-fabbri/colab-notebooks/blob/master/nlp/seq-to-seq/seq_to_seq_arithmetic.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


<img src="https://github.com/martin-fabbri/colab-notebooks/raw/master/nlp/seq-to-seq/images/seq_to_seq_arithmetics.png" width=600 alt="Seq-to-seq">

In [1]:
import numpy as np
from tqdm import tqdm
from tensorflow.keras import layers
from tensorflow.keras import Sequential

In [2]:
SEP = ' '
OPERATIONS = '+-'
DIGITS = '0123456789' 
CHARS = SEP + OPERATIONS + DIGITS
VOCAB_SIZE = len(CHARS)
# 'int + int' (e.g., '345+678')
chars = sorted(set(CHARS))
char_to_index = {c: i for i, c in enumerate(chars)}
index_to_char = {i: c for i, c in enumerate(chars)}

In [3]:
def encode(num_str, max_length):
  '''
  One hot encode <num_str>
  # Arguments
    num_rows: Number of rows in the returned one hot encoding. This is
              used to keep the # of rows for each data the same.
  '''
  x = np.zeros((max_length, VOCAB_SIZE))
  for i, c in enumerate(num_str):
    x[i, char_to_index[c]] = 1
  return x

In [4]:
encode('051+123', 3*2 + 1)

array([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]])

In [5]:
def decode(x, calc_argmax=True):
  if calc_argmax:
    x = x.argmax(axis=-1)
  return (''.join(index_to_char[i] for i in x)).strip()

In [6]:
decode(np.array([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]))

'51'

In [24]:
training_size = 50000
digits = 3
min_val = 0
max_val = 999
hidden_size = 128
BATCH_SIZE = 128
SAMPLE_MAX_LENGHT = digits * 2 + 1
LABEL_MAX_LENGHT = digits + 2

In [8]:
def generate_sample():
  left_operant = np.random.randint(min_val, max_val)
  right_operant = np.random.randint(min_val, max_val)
  operation = np.random.choice(list(OPERATIONS))

  left_operant, right_operant, operation
  sample = f'{left_operant:0{digits}}{operation}{right_operant:0{digits}}'
  if operation == '+':
    label = f'{left_operant + right_operant:0{LABEL_MAX_LENGHT}}'
  else:
    label = f'{left_operant - right_operant:0{LABEL_MAX_LENGHT}}'
  return sample, label

In [9]:
pbar = tqdm(total=training_size)
samples = []
while len(samples) < training_size:
  sample = generate_sample()
  if sample not in samples:
    samples.append(sample)
    pbar.update(1)
pbar.close()
len(samples), samples[:5]

100%|██████████| 50000/50000 [00:44<00:00, 1136.11it/s]


(50000,
 [('699-916', '-0217'),
  ('179-939', '-0760'),
  ('881-519', '00362'),
  ('079+445', '00524'),
  ('812-723', '00089')])

In [15]:
X = np.array([encode(sample, SAMPLE_MAX_LENGHT) for sample, _ in samples])
y = np.array([encode(label, LABEL_MAX_LENGHT) for _, label in samples])

## Split training and validation sets. Separate 10% for validation.

In [19]:
split_at = training_size - training_size // 10
(X_train, y_train) = X[:split_at], y[:split_at]
(X_val, y_val) = X[split_at:], y[split_at:]

In [21]:
model = Sequential([
  layers.LSTM(hidden_size, input_shape=(SAMPLE_MAX_LENGHT, len(chars))),
  layers.RepeatVector(LABEL_MAX_LENGHT),
  layers.LSTM(hidden_size, return_sequences=True),
  layers.TimeDistributed(layers.Dense(len(chars), activation='softmax'))
])

In [23]:
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lstm (LSTM)                  (None, 128)               72704     
_________________________________________________________________
repeat_vector (RepeatVector) (None, 5, 128)            0         
_________________________________________________________________
lstm_1 (LSTM)                (None, 5, 128)            131584    
_________________________________________________________________
time_distributed (TimeDistri (None, 5, 13)             1677      
Total params: 205,965
Trainable params: 205,965
Non-trainable params: 0
_________________________________________________________________


In [28]:
for iteration in range(1, 200):
    print()
    print('-' * 50)
    print('Iteration', iteration)
    model.fit(
        X_train, 
        y_train,
        batch_size=BATCH_SIZE,
        epochs=1,
        validation_data=(X_val, y_val)
    )
    # Select 10 samples from the validation set at random so we can visualize
    # errors.
    for i in range(10):
        ind = np.random.randint(0, len(X_val))
        rowx, rowy = X_val[np.array([ind])], y_val[np.array([ind])]
        preds = model.predict_classes(rowx, verbose=0)
        q = decode(rowx[0])
        correct = decode(rowy[0])
        guess = decode(preds[0], calc_argmax=False)
        print('Q', q, end=' ')
        print('T', correct, end=' ')
        if correct == guess:
            print('☑', end=' ')
        else:
            print('☒', end=' ')
        print(guess)


--------------------------------------------------
Iteration 1
Instructions for updating:
Please use instead:* `np.argmax(model.predict(x), axis=-1)`,   if your model does multi-class classification   (e.g. if it uses a `softmax` last-layer activation).* `(model.predict(x) > 0.5).astype("int32")`,   if your model does binary classification   (e.g. if it uses a `sigmoid` last-layer activation).
Q 827-766 T 00061 ☒ 00060
Q 478+713 T 01191 ☒ 01175
Q 924+835 T 01759 ☒ 01732
Q 767-302 T 00465 ☒ 00420
Q 408-035 T 00373 ☒ 00320
Q 628-681 T -0053 ☒ -0062
Q 366-949 T -0583 ☒ -0650
Q 902-189 T 00713 ☒ 00600
Q 409-429 T -0020 ☒ -0002
Q 930-242 T 00688 ☒ 00600

--------------------------------------------------
Iteration 2
Q 275+350 T 00625 ☒ 00666
Q 203-246 T -0043 ☒ -0028
Q 738+641 T 01379 ☒ 01442
Q 351-203 T 00148 ☒ 00168
Q 384-276 T 00108 ☒ 00126
Q 252-115 T 00137 ☒ 00168
Q 346+851 T 01197 ☒ 01214
Q 785+913 T 01698 ☒ 01718
Q 539+148 T 00687 ☒ 00776
Q 925-001 T 00924 ☒ 00888

-----------------

KeyboardInterrupt: ignored