# Make RNN learn to perform addition

改寫https://github.com/keras-team/keras/blob/master/examples/addition_rnn.py

手法為sequence to sequence learning

而目標為例如輸入535+61, 希望輸出為596

In [14]:
from keras.models import Model
from keras.layers import Dense, LSTM, Input, RepeatVector, TimeDistributed
from keras.activations import softmax
import numpy as np

In [15]:
class CharacterTable():
    def __init__(self, chars):
        """ 製造char->int以及int->char的對應
        @ Args:
            chars(str): 所有可能的輸入
        """
        self.chars = sorted(set(chars))
        self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
        self.indices_char = dict((i, c) for i, c in enumerate(self.chars))
    
    def encode(self, C, num_rows):
        """ 將給定的string C做one-hot encoding
        @ Args:
            C(str): 要做encoding的string
            num_rows(int): 即可能的輸入個數, 比如數字有10種可能, 此即為10
        @ returns:
            x(np.array): C的每個char經過encoding後的結果(num_rows x len(C))
        """
        x = np.zeros((num_rows, len(self.chars)))
        for i, c in enumerate(C):
            x[i, self.char_indices[c]] = 1
        return x
    
    def decode(self, x, calc_argmax=True):
        """ 將int轉換回char
        """
        if calc_argmax:
            x = x.argmax(axis=-1)
        return ''.join(self.indices_char[x] for x in list(x))

In [16]:
class colors:
    ok = '\033[92m'
    fail = '\033[91m'
    close = '\033[0m'

In [17]:
NB_TRAINING_SAMPLES = 50000 # number of addition questions
NB_DIGITS = 3 # input number in [0 - 999]
REVERSE = True # like data augmentation, we have 535+61 and then we'll train 61+535 too
MAX_STRING_LENGTH = NB_DIGITS*2 + 1 # for '345+678', then length is 3*2+1=7

In [18]:
all_chars = '0123456789+ ' # 所有char的可能(12個)
char_table = CharacterTable(all_chars) 

In [19]:
questions = [] # 所有問題, 如'123+456'
answers = [] # 答案, 如579
seen = set() # 出現過的問題

In [20]:
MAXIMUM = 10**NB_DIGITS-1
while len(questions) < NB_TRAINING_SAMPLES:
    a = np.random.randint(0, MAXIMUM)
    b = np.random.randint(0, MAXIMUM)
    q1 = str(a) + '+' + str(b)
    # 這種放法nn可以簡單的知道"非空白之後高機率是非空白", 若調換則NN要多學判斷數字斷點的方式
    q1 = ' '*(MAX_STRING_LENGTH - len(q1)) + q1
    q2 = str(b) + '+' + str(a)
    q2 = ' '*(MAX_STRING_LENGTH - len(q2)) + q2
    if q1 in seen:
        continue
    seen.add(q1)
    ans = str(eval(q1))
    questions.append(q1)
    # 類似前面的方式, 這樣擺的話NN知道"空白之後高機率是空白"
    # 不過我沒驗證  或許實際上沒差?
    answers.append(ans + ' '*(NB_DIGITS+1-len(ans)))
    
    if REVERSE:
        seen.add(q2)    
        ans = str(eval(q2))        
        questions.append(q2)            
        answers.append(ans + ' '*(NB_DIGITS+1-len(ans)) )    

* The shape of x:
[問題數, 問題最大長度(即RNN的輸入長度), 所有的字符數(即one-hot向量的長度)]

In [21]:
x = np.zeros(shape=(len(questions), MAX_STRING_LENGTH, len(all_chars)), dtype=np.bool)
# NB_DIGITS+1: 三位數+三位數最多變四位數
y = np.zeros(shape=(len(questions), NB_DIGITS+1, len(all_chars)), dtype=np.bool)

for i, expression in enumerate(questions):
    x[i] = char_table.encode(expression, MAX_STRING_LENGTH)
for i, answer in enumerate(answers):
    y[i] = char_table.encode(answer, NB_DIGITS+1)

In [22]:
random_order = np.arange(len(questions))
np.random.shuffle(random_order)

x = x[random_order]
y = y[random_order]

In [23]:
# train-test split
split_index = int(0.9*len(questions))
train_x, val_x = x[:split_index], x[split_index:]
train_y, val_y = y[:split_index], y[split_index:]

In [24]:
HIDDEN_SIZE = 128 # lstm中hidden units的數目
NB_LSTM_LAYERS_WITH_SEQ_OUTPUT = 1 # encoder後串接的lstm層數
RNN_MODEL = LSTM

In [25]:
# input shape即為MAX_STRING_LENGTH(此例為3+3+1=7) x encoding長度(所有char的數目, 此例為12)
model_input = Input(shape=train_x.shape[1:], name='input_layer')
# 此處lstm遞迴MAX_STRING_LENGTH次, 並且只輸出這整個字串的encoder結果, 並不輸出sequence
rnn = RNN_MODEL(units=HIDDEN_SIZE)(model_input)
# 將上面結果重複NB_DIGITS+1次後再丟到LSTM中, 並且輸出sequence
# 即我們最終的目標是預測一個最長可能有NB_DIGITS+1的數字(e.g. 600+500=1100)
rnn = RepeatVector(NB_DIGITS+1)(rnn)
for _ in range(NB_LSTM_LAYERS_WITH_SEQ_OUTPUT):
    rnn = RNN_MODEL(HIDDEN_SIZE, return_sequences=True)(rnn)
# 現在依照初始的架構, 此處的shape為NB_DIGITS x HIDDEN_SIZE
# TimeDistributed對NB_DIGITS+1個資料分別建立一個NN
# 也就是TimeDistributed分配的網路數=RNN遞迴次數=輸出的sequence長度=NB_DIGITS+1
# 此目的在於將每個time step的資訊分開, 如最後面的圖的decoder部分中Dense彼此沒有連接
# 若沒有這個步驟, 則所有time step的資訊將會混在一起
# 由於每個char都有12種可能, 因此Dense的hidden unit數目為len(all_chars)=12
prediction = TimeDistributed(Dense(len(all_chars), activation='softmax'))(rnn)
model = Model(inputs=[model_input], outputs=[prediction])
print(model.summary())

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_layer (InputLayer)     (None, 7, 12)             0         
_________________________________________________________________
lstm_3 (LSTM)                (None, 128)               72192     
_________________________________________________________________
repeat_vector_2 (RepeatVecto (None, 4, 128)            0         
_________________________________________________________________
lstm_4 (LSTM)                (None, 4, 128)            131584    
_________________________________________________________________
time_distributed_2 (TimeDist (None, 4, 12)             1548      
Total params: 205,324
Trainable params: 205,324
Non-trainable params: 0
_________________________________________________________________
None


In [26]:
BATCH_SIZE=32
EPOCHS=100
model.compile(optimizer='nadam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
for i in range(200):
    model.fit(x=train_x, y=train_y, validation_data=(val_x, val_y),
              batch_size=BATCH_SIZE, epochs=1)
    for i in range(10):
        ind = np.random.randint(0, len(val_x))
        rowx, rowy = val_x[np.array([ind])], val_y[np.array([ind])]
        preds = model.predict(rowx)
        q = char_table.decode(rowx[0])
        correct = char_table.decode(rowy[0])
        guess = char_table.decode(preds[0], calc_argmax=True)
        print('Q', q[::-1] if REVERSE else q, end=' ')
        print('T', correct, end=' ')
        if correct == guess:
            print(colors.ok + '☑' + colors.close, end=' ')
        else:
            print(colors.fail + '☒' + colors.close, end=' ')
        print(guess)

Train on 45000 samples, validate on 5000 samples
Epoch 1/1
Q 575+356 T 1228 [91m☒[0m 1227
Q 337+237 T 1465 [91m☒[0m 1402
Q 679+984 T 1465 [91m☒[0m 1577
Q 366+63  T 699  [91m☒[0m 700 
Q 629+693 T 1322 [91m☒[0m 1300
Q 446+221 T 766  [91m☒[0m 707 
Q 268+136 T 1493 [91m☒[0m 1472
Q 388+92  T 912  [91m☒[0m 900 
Q 299+159 T 1943 [91m☒[0m 1777
Q 6+476   T 680  [91m☒[0m 707 
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
Q 675+519 T 1491 [91m☒[0m 1499
Q 341+846 T 791  [91m☒[0m 794 
Q 132+91  T 250  [91m☒[0m 210 
Q 319+472 T 1187 [91m☒[0m 1177
Q 079+713 T 1287 [91m☒[0m 1299
Q 331+981 T 322  [91m☒[0m 214 
Q 291+516 T 807  [91m☒[0m 704 
Q 197+94  T 840  [91m☒[0m 844 
Q 035+122 T 751  [91m☒[0m 759 
Q 119+525 T 1436 [91m☒[0m 1446
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
Q 07+773  T 447  [92m☑[0m 447 
Q 161+767 T 928  [92m☑[0m 928 
Q 363+383 T 746  [91m☒[0m 748 
Q 796+898 T 1595 [91m☒[0m 1583
Q 617+657 T 1472 [91m☒[0m 147

Q 724+952 T 686  [92m☑[0m 686 
Q 515+571 T 690  [92m☑[0m 690 
Q 937+776 T 1416 [92m☑[0m 1416
Q 26+017  T 772  [91m☒[0m 872 
Q 736+729 T 1564 [92m☑[0m 1564
Q 902+374 T 682  [92m☑[0m 682 
Q 062+444 T 704  [92m☑[0m 704 
Q 356+081 T 833  [92m☑[0m 833 
Q 976+001 T 779  [92m☑[0m 779 
Q 713+635 T 853  [92m☑[0m 853 
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
Q 858+795 T 1455 [92m☑[0m 1455
Q 019+348 T 1753 [92m☑[0m 1753
Q 317+929 T 1642 [92m☑[0m 1642
Q 703+631 T 443  [92m☑[0m 443 
Q 4+179   T 975  [92m☑[0m 975 
Q 152+306 T 854  [92m☑[0m 854 
Q 239+676 T 1608 [92m☑[0m 1608
Q 131+735 T 668  [92m☑[0m 668 
Q 761+665 T 733  [92m☑[0m 733 
Q 836+034 T 1068 [92m☑[0m 1068
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
Q 957+439 T 1693 [92m☑[0m 1693
Q 485+783 T 971  [92m☑[0m 971 
Q 134+482 T 715  [92m☑[0m 715 
Q 534+98  T 524  [92m☑[0m 524 
Q 837+346 T 1381 [92m☑[0m 1381
Q 564+271 T 637  [92m☑[0m 637 
Q 712+448 T 1061 [92m☑

Q 916+538 T 1454 [92m☑[0m 1454
Q 059+44  T 994  [92m☑[0m 994 
Q 357+46  T 817  [92m☑[0m 817 
Q 346+32  T 666  [92m☑[0m 666 
Q 341+452 T 397  [92m☑[0m 397 
Q 8+51    T 23   [92m☑[0m 23  
Q 381+185 T 764  [92m☑[0m 764 
Q 001+469 T 1064 [92m☑[0m 1064
Q 479+389 T 1957 [92m☑[0m 1957
Q 852+541 T 403  [92m☑[0m 403 
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
Q 205+7   T 509  [92m☑[0m 509 
Q 437+899 T 1732 [92m☑[0m 1732
Q 67+374  T 549  [91m☒[0m 449 
Q 554+856 T 1113 [92m☑[0m 1113
Q 383+214 T 795  [92m☑[0m 795 
Q 414+548 T 1259 [92m☑[0m 1259
Q 685+585 T 1171 [92m☑[0m 1171
Q 283+689 T 1368 [92m☑[0m 1368
Q 067+664 T 1226 [92m☑[0m 1226
Q 116+24  T 653  [91m☒[0m 643 
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
Q 544+082 T 725  [92m☑[0m 725 
Q 427+849 T 1672 [92m☑[0m 1672
Q 813+906 T 927  [92m☑[0m 927 
Q 588+202 T 1087 [92m☑[0m 1087
Q 091+774 T 667  [92m☑[0m 667 
Q 994+822 T 727  [92m☑[0m 727 
Q 545+587 T 1330 [92m☑

Q 995+268 T 1461 [92m☑[0m 1461
Q 546+304 T 1048 [92m☑[0m 1048
Q 527+194 T 1216 [92m☑[0m 1216
Q 26+017  T 772  [92m☑[0m 772 
Q 25+279  T 1024 [92m☑[0m 1024
Q 674+014 T 886  [92m☑[0m 886 
Q 308+507 T 1508 [92m☑[0m 1508
Q 562+972 T 544  [92m☑[0m 544 
Q 764+336 T 1100 [92m☑[0m 1100
Q 448+966 T 1513 [92m☑[0m 1513
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
Q 933+998 T 1238 [92m☑[0m 1238
Q 76+433  T 401  [92m☑[0m 401 
Q 845+182 T 829  [92m☑[0m 829 
Q 826+831 T 766  [92m☑[0m 766 
Q 995+376 T 1272 [92m☑[0m 1272
Q 127+275 T 1293 [92m☑[0m 1293
Q 479+389 T 1957 [92m☑[0m 1957
Q 513+381 T 498  [92m☑[0m 498 
Q 354+113 T 764  [92m☑[0m 764 
Q 361+669 T 1129 [92m☑[0m 1129
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
Q 685+774 T 1063 [92m☑[0m 1063
Q 403+808 T 1112 [92m☑[0m 1112
Q 951+222 T 381  [92m☑[0m 381 
Q 891+481 T 382  [92m☑[0m 382 
Q 077+574 T 1245 [92m☑[0m 1245
Q 116+24  T 653  [92m☑[0m 653 
Q 814+817 T 1136 [92m☑

Q 062+823 T 588  [92m☑[0m 588 
Q 599+203 T 1297 [92m☑[0m 1297
Q 799+928 T 1826 [92m☑[0m 1826
Q 709+007 T 1607 [92m☑[0m 1607
Q 129+896 T 1619 [92m☑[0m 1619
Q 853+525 T 883  [92m☑[0m 883 
Q 702+289 T 1189 [92m☑[0m 1189
Q 388+603 T 1189 [92m☑[0m 1189
Q 907+838 T 1547 [92m☑[0m 1547
Q 832+527 T 963  [92m☑[0m 963 
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
Q 077+862 T 1038 [92m☑[0m 1038
Q 955+749 T 1506 [92m☑[0m 1506
Q 16+76   T 128  [92m☑[0m 128 
Q 619+932 T 1155 [92m☑[0m 1155
Q 577+689 T 1761 [92m☑[0m 1761
Q 702+047 T 947  [92m☑[0m 947 
Q 487+103 T 1085 [92m☑[0m 1085
Q 605+031 T 636  [92m☑[0m 636 
Q 422+642 T 470  [92m☑[0m 470 
Q 214+15  T 463  [92m☑[0m 463 
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
Q 069+7   T 967  [92m☑[0m 967 
Q 954+918 T 1278 [92m☑[0m 1278
Q 957+733 T 1096 [92m☑[0m 1096
Q 432+985 T 823  [92m☑[0m 823 
Q 806+476 T 1282 [92m☑[0m 1282
Q 924+622 T 655  [92m☑[0m 655 
Q 553+127 T 1076 [92m☑

Q 395+944 T 1042 [92m☑[0m 1042
Q 85+795  T 655  [92m☑[0m 655 
Q 941+277 T 921  [92m☑[0m 921 
Q 622+176 T 897  [92m☑[0m 897 
Q 413+183 T 695  [92m☑[0m 695 
Q 966+949 T 1618 [92m☑[0m 1618
Q 34+106  T 644  [92m☑[0m 644 
Q 945+024 T 969  [92m☑[0m 969 
Q 545+587 T 1330 [92m☑[0m 1330
Q 225+601 T 628  [92m☑[0m 628 
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
Q 384+412 T 697  [92m☑[0m 697 
Q 422+856 T 882  [92m☑[0m 882 
Q 681+237 T 918  [92m☑[0m 918 
Q 151+162 T 412  [92m☑[0m 412 
Q 36+591  T 258  [92m☑[0m 258 
Q 073+848 T 1218 [92m☑[0m 1218
Q 558+57  T 930  [92m☑[0m 930 
Q 591+568 T 1060 [92m☑[0m 1060
Q 923+098 T 1219 [92m☑[0m 1219
Q 302+487 T 987  [92m☑[0m 987 
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
Q 851+217 T 870  [92m☑[0m 870 
Q 747+496 T 1441 [92m☑[0m 1441
Q 93+228  T 861  [92m☑[0m 861 
Q 217+94  T 761  [92m☑[0m 761 
Q 454+706 T 1061 [92m☑[0m 1061
Q 527+811 T 843  [92m☑[0m 843 
Q 672+438 T 1110 [92m☑

KeyboardInterrupt: 

下圖為以上架構的詳解, 圖源來自

https://qiita.com/HotAllure/items/0045998971a48909853d

<img style="float: left;" src="pics/encoder.png" width="50%">
<img style="float: right;" src="pics/decoder.png" width="50%">