<a href="https://colab.research.google.com/github/martin-fabbri/colab-notebooks/blob/master/nlp/seq-to-seq/seq_to_seq_arithmetic.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


<img src="https://github.com/martin-fabbri/colab-notebooks/raw/master/nlp/seq-to-seq/images/seq_to_seq_arithmetics.png" width=600 alt="Seq-to-seq">

In [93]:
import numpy as np
from tqdm import tqdm
from tensorflow.keras import layers
from tensorflow.keras import Sequential

In [44]:
SEP = ' '
OPERATIONS = '+-'
DIGITS = '0123456789' 
CHARS = SEP + OPERATIONS + DIGITS
VOCAB_SIZE = len(CHARS)
# 'int + int' (e.g., '345+678')
chars = sorted(set(CHARS))
char_to_index = {c: i for i, c in enumerate(chars)}
index_to_char = {i: c for i, c in enumerate(chars)}

In [99]:
def encode(num_str, max_length):
  '''
  One hot encode <num_str>
  # Arguments
    num_rows: Number of rows in the returned one hot encoding. This is
              used to keep the # of rows for each data the same.
  '''
  x = np.zeros((max_length, VOCAB_SIZE))
  for i, c in enumerate(num_str):
    x[i, char_to_index[c]] = 1
  return x

In [105]:
encode('051+123', 3*2 + 1)

array([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]])

In [37]:
def decode(x, calc_argmax=True):
  if calc_argmax:
    x = x.argmax(axis=-1)
  return (''.join(index_to_char[i] for i in x)).strip()

In [43]:
decode(np.array([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]))

'51'

In [108]:
training_size = 50000
digits = 3
min_val = 0
max_val = 999
hidden_size = 128
batch_size = 128
SAMPLE_MAX_LENGHT = digits * 2 + 1
LABEL_MAX_LENGHT = digits + 2

In [111]:
def generate_sample():
  left_operant = np.random.randint(min_val, max_val)
  right_operant = np.random.randint(min_val, max_val)
  operation = np.random.choice(list(OPERATIONS))

  left_operant, right_operant, operation
  sample = f'{left_operant:0{digits}}{operation}{right_operant:0{digits}}'
  if operation == '+':
    label = f'{left_operant + right_operant:0{LABEL_MAX_LENGHT}}'
  else:
    label = f'{left_operant - right_operant:0{LABEL_MAX_LENGHT}}'
  return sample, label

In [112]:
pbar = tqdm(total=training_size)
samples = []
while len(samples) < training_size:
  sample = generate_sample()
  if sample not in samples:
    samples.append(sample)
    pbar.update(1)
pbar.close()
len(samples), samples[:5]


  0%|          | 0/50000 [00:00<?, ?it/s][A
  4%|▎         | 1834/50000 [00:00<00:02, 18332.76it/s][A
  6%|▌         | 2859/50000 [00:00<00:03, 14821.30it/s][A
  7%|▋         | 3633/50000 [00:00<00:03, 11628.09it/s][A
  9%|▊         | 4350/50000 [00:00<00:04, 9339.30it/s] [A
 10%|█         | 5044/50000 [00:00<00:05, 7543.30it/s][A
 11%|█▏        | 5686/50000 [00:00<00:06, 6534.88it/s][A
 13%|█▎        | 6285/50000 [00:00<00:07, 5615.97it/s][A
 14%|█▎        | 6831/50000 [00:00<00:08, 5166.74it/s][A
 15%|█▍        | 7344/50000 [00:01<00:08, 4754.90it/s][A
 16%|█▌        | 7824/50000 [00:01<00:09, 4365.78it/s][A
 17%|█▋        | 8270/50000 [00:01<00:10, 4034.08it/s][A
 17%|█▋        | 8685/50000 [00:01<00:10, 3778.61it/s][A
 18%|█▊        | 9074/50000 [00:01<00:11, 3526.32it/s][A
 19%|█▉        | 9438/50000 [00:01<00:12, 3347.83it/s][A
 20%|█▉        | 9783/50000 [00:01<00:12, 3219.05it/s][A
 20%|██        | 10113/50000 [00:01<00:13, 3047.01it/s][A
 21%|██        | 1042

(50000,
 [('568+158', '00726'),
  ('693+057', '00750'),
  ('384-248', '00136'),
  ('734-389', '00345'),
  ('806+559', '01365')])

In [115]:
X = [encode(sample, SAMPLE_MAX_LENGHT) for sample, _ in samples]
y = [encode(label, LABEL_MAX_LENGHT) for _, label in samples]

In [116]:
for l in y:
  print(decode(l))

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
00172
-0054
00648
01729
01139
00936
-0397
-0095
-0194
-0055
00215
00486
00808
00890
-0128
01783
01122
00371
01623
-0326
-0288
00576
00668
00677
01136
01023
00726
00670
00024
00892
00763
-0415
00136
-0601
00177
-0193
01547
-0213
01685
01398
00860
01411
01088
00104
-0181
00926
01068
01534
-0303
-0009
-0101
01162
00130
00201
01221
01031
01323
-0241
01217
01384
01534
01482
-0241
00084
00777
01106
00894
-0125
00820
-0489
-0008
01196
-0598
01083
01361
00301
00445
00671
-0699
00679
00603
-0276
01263
00253
-0814
00510
00481
-0251
00045
00390
-0252
00502
00657
01100
00928
00743
00410
01372
00275
01594
00865
00023
-0023
00082
00588
00758
-0606
00964
01431
01442
00452
00105
00152
00639
00535
00734
00674
01070
-0791
00330
00401
-0565
00522
00689
00941
00179
01096
00948
-0564
00248
00756
-0311
-0239
00092
01560
00877
00028
00136
-0063
00977
00874
01197
01125
00435
01556
-0039
00696
01550
00367
00784
-0140
01026
01032
00395
01190
01369