# 덧셈 뉴럴넷 만들기

- from: https://keras.io/examples/nlp/addition_rnn/
- with RNN (LSTM)
- (최대 3자리) 숫자 두 개를 더하는 작업
- 결과는 4자리 숫자 산출 가능

In [1]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers
import numpy as np

In [2]:
# Number of distinct addition problems to generate for the dataset.
DATASET_SIZE = 50_000

# Maximum number of digits in each operand.
DIGITS = 3
# Maximum length of input is 'int + int' (e.g., '345+678')
MAXLEN = DIGITS + 1 + DIGITS

# 숫자 표현 방법 정의

- One-hot encoding 사용
- 숫자 10개(0~9)와 + 기호, 공백까지 사용하여 총 12개의 문자 사용

In [3]:
class CharacterTable(object):
    """One-hot codec over a fixed alphabet of characters.

    The alphabet is the sorted set of unique characters passed to the
    constructor; a character's position in that sorted list is its
    one-hot column index.
    """

    def __init__(self, chars):
        """Build the sorted alphabet and both lookup tables.

        Args:
            chars: iterable of characters (duplicates are collapsed).
        """
        alphabet = sorted(set(chars))
        self.chars = alphabet
        self.char_indices = {ch: idx for idx, ch in enumerate(alphabet)}
        self.indices_char = dict(enumerate(alphabet))

    def encode(self, C, num_rows):
        """One-hot encode the string `C`.

        Args:
            C: string whose characters all belong to the alphabet.
            num_rows: number of rows in the result; rows beyond len(C)
                stay all-zero.

        Returns:
            Float array of shape (num_rows, len(self.chars)).
        """
        onehot = np.zeros((num_rows, len(self.chars)))
        for row, ch in enumerate(C):
            onehot[row, self.char_indices[ch]] = 1
        return onehot

    def decode(self, x, calc_argmax=True):
        """Turn one-hot / probability rows (or class indices) into a string.

        Args:
            x: 2-D array of per-character scores when `calc_argmax` is
                true, otherwise an iterable of class indices.
            calc_argmax: take the argmax over the last axis before lookup.

        Returns:
            The decoded string, one character per row / index.
        """
        class_ids = x.argmax(axis=-1) if calc_argmax else x
        return ''.join(self.indices_char[idx] for idx in class_ids)
    
# Alphabet: the ten digits, the plus sign, and a space used for padding.
chars = '0123456789+ '
# Shared codec used to one-hot encode questions/answers and decode predictions.
ctable = CharacterTable(chars)

In [4]:
# Show the sorted alphabet held by the table and the length of the raw `chars` string.
ctable.chars, len(chars)

([' ', '+', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], 12)

In [5]:
# Demo: one-hot encode a space-padded query; each row marks one character.
ctable.encode('345+32 ', MAXLEN)

array([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [6]:
# Demo: one-hot encode a 4-character answer string.
ctable.encode('1620', DIGITS + 1)

array([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [7]:
# Demo: decoding exact one-hot rows recovers the original string.
ctable.decode(np.array([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
                        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
                        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]))

'1620'

In [8]:
# Demo: with soft (probability-like) rows, decode picks the highest-scoring
# character in each row, so the result is unchanged.
ctable.decode(np.array([[0., 0., 0., 0.9, 0., 0.1, 0., 0., 0., 0., 0., 0.],
                        [0., 0., 0., 0.45, 0., 0., 0., 0., 0.55, 0., 0., 0.],
                        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
                        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]))

'1620'

## Dataset 생성

In [9]:
def random_operand():
    """Return a random non-negative integer with 1 to DIGITS digits.

    The digit count is drawn uniformly first, so short numbers are
    over-represented compared with a uniform draw from 0..10**DIGITS - 1.
    """
    n_digits = np.random.randint(1, DIGITS + 1)
    # n_digits independent uniform digits are equivalent to a single
    # uniform draw from [0, 10**n_digits).
    return int(np.random.randint(0, 10 ** n_digits))


# a+b and b+a count as the same problem, so the number of distinct problems
# is bounded; without this check the loop below would never terminate.
max_unique = 10 ** DIGITS * (10 ** DIGITS + 1) // 2
if DATASET_SIZE > max_unique:
    raise ValueError(
        f"DATASET_SIZE={DATASET_SIZE} exceeds the {max_unique} unique "
        f"problems available with DIGITS={DIGITS}"
    )

questions = []  # question strings (model input)
expected = []   # answer strings (model target)
seen = set()    # operand pairs already used, for de-duplication

while len(questions) < DATASET_SIZE:
    a, b = random_operand(), random_operand()
    # Order-independent key so that a+b and b+a are treated as duplicates.
    key = tuple(sorted((a, b)))
    if key in seen:
        continue
    seen.add(key)

    # Left-align the question and pad with spaces up to MAXLEN.
    query = f"{a}+{b}".ljust(MAXLEN)
    # Right-align the answer in a field of DIGITS + 1 characters.
    ans = f"{a+b:{DIGITS+1}}"
    # Store the query reversed, e.g. '12+345 ' becomes ' 543+21'.
    questions.append(query[::-1])
    expected.append(ans)
print('Total addition questions:', len(questions))

Total addition questions: 50000


In [10]:
# Peek at the first (reversed, space-padded) question and its right-aligned answer.
questions[0], expected[0]

('   2+85', '  60')

In [11]:
# One-hot encode every question and answer into boolean tensors:
#   x: (num_samples, MAXLEN, num_chars), y: (num_samples, DIGITS + 1, num_chars).
# The builtin `bool` is used because the `np.bool` alias was deprecated in
# NumPy 1.20 and removed in 1.24.
x = np.zeros((len(questions), MAXLEN, len(chars)), dtype=bool)
y = np.zeros((len(questions), DIGITS + 1, len(chars)), dtype=bool)
for i, sentence in enumerate(questions):
    x[i] = ctable.encode(sentence, MAXLEN)
for i, sentence in enumerate(expected):
    y[i] = ctable.encode(sentence, DIGITS + 1)

# Shuffle (x, y) in unison by applying the same random permutation to both.
indices = np.arange(len(y))
np.random.shuffle(indices)
x = x[indices]
y = y[indices]

In [12]:
# Inspect the permutation that was applied to x and y.
indices

array([39348, 12860, 14717, ..., 38890,  6479, 38775])

## Model 정의

In [16]:
# Try replacing GRU, or SimpleRNN.
RNN = layers.LSTM
HIDDEN_SIZE = 128

model = Sequential()

# B x MAXLEN x CHARS -> B x HIDDENS
model.add(RNN(HIDDEN_SIZE, input_shape=(MAXLEN, len(chars))))

# B x HIDDENS -> B x DIGITS+1 x HIDDENS 
model.add(layers.RepeatVector(DIGITS + 1))

# same shape: B x DIGITS+1 x HIDDENS (return_sequences is true)
model.add(RNN(HIDDEN_SIZE, return_sequences=True))

# B x DIGITS+1 x HIDDENS -> B x DIGITS+1 x CHARS
model.add(layers.TimeDistributed(layers.Dense(len(chars), activation='softmax')))

model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lstm_2 (LSTM)                (None, 128)               72192     
_________________________________________________________________
repeat_vector_1 (RepeatVecto (None, 4, 128)            0         
_________________________________________________________________
lstm_3 (LSTM)                (None, 4, 128)            131584    
_________________________________________________________________
time_distributed_1 (TimeDist (None, 4, 12)             1548      
Total params: 205,324
Trainable params: 205,324
Non-trainable params: 0
_________________________________________________________________


In [17]:
%%time

REPEAT = 5              # number of train-then-inspect cycles
EPOCHS_PER_CYCLE = 10   # epochs trained in each cycle
SAMPLES_PER_CYCLE = 10  # predictions printed after each cycle

for r in range(REPEAT):
    # initial_epoch/epochs make the epoch counter continue across cycles
    # (1-10, 11-20, ...) instead of restarting at 1 every time.
    model.fit(x, y,
              batch_size=256,
              initial_epoch=r * EPOCHS_PER_CYCLE,
              epochs=(r+1) * EPOCHS_PER_CYCLE,
              validation_split=0.2)
    # Spot-check a few random samples. They are drawn from all of x, so they
    # may come from either the training or the validation portion.
    for _ in range(SAMPLES_PER_CYCLE):
        ind = np.random.randint(0, len(x))
        # Slice (not index) to keep the leading batch dimension of size 1.
        rowx, rowy = x[ind:ind + 1], y[ind:ind + 1]
        # Sequential.predict_classes was removed in TensorFlow 2.6; take the
        # argmax of the predicted probabilities instead.
        preds = model.predict(rowx, verbose=0).argmax(axis=-1)
        q = ctable.decode(rowx[0])
        correct = ctable.decode(rowy[0])
        guess = ctable.decode(preds[0], calc_argmax=False)
        # Questions are stored reversed, so flip them back for display.
        print(f'Q {q[::-1]} T {correct} {"==" if correct == guess else "!="} {guess}')


Train on 40000 samples, validate on 10000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Q 28+116  T  144 !=  135
Q 35+515  T  550 !=  539
Q 979+9   T  988 !=  974
Q 65+229  T  294 !=  297
Q 98+68   T  166 !=  164
Q 543+44  T  587 !=  581
Q 362+75  T  437 !=  411
Q 6+276   T  282 !=  279
Q 5+876   T  881 !=  884
Q 629+98  T  727 !=  747
Train on 40000 samples, validate on 10000 samples
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Q 110+193 T  303 !=  308
Q 178+1   T  179 !=  177
Q 665+509 T 1174 != 1177
Q 583+0   T  583 ==  583
Q 66+847  T  913 !=  914
Q 678+408 T 1086 != 1081
Q 17+28   T   45 !=   41
Q 4+983   T  987 !=  984
Q 243+998 T 1241 != 1231
Q 75+26   T  101 ==  101
Train on 40000 samples, validate on 10000 samples
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30

Q 268+66  T  334 ==  334
Q 646+2   T  648 ==  648
Q 484+754 T 1238 == 1238
Q 38+469  T  507 ==  507
Q 169+840 T 1009 == 1009
Q 103+8   T  111 ==  111
Q 28+500  T  528 ==  528
Q 47+202  T  249 ==  249
Q 881+136 T 1017 == 1017
Q 72+708  T  780 ==  780
Wall time: 3min 25s


## 개별 검증

In [51]:
# Hand-written queries for a manual check. Each is padded to MAXLEN and
# reversed before encoding, matching how the training questions were built.
# Builtin `bool` replaces the `np.bool` alias, which was removed in NumPy 1.24.
test_input = np.zeros((3, MAXLEN, len(chars)), dtype=bool)

test_input[0] = ctable.encode('123+321'[::-1], MAXLEN)  # 444
test_input[1] = ctable.encode('12+32  '[::-1], MAXLEN)  # 44
test_input[2] = ctable.encode('1+3    '[::-1], MAXLEN)  # 4

test_input

array([[[False, False, False,  True, False, False, False, False, False,
         False, False, False],
        [False, False, False, False,  True, False, False, False, False,
         False, False, False],
        [False, False, False, False, False,  True, False, False, False,
         False, False, False],
        [False,  True, False, False, False, False, False, False, False,
         False, False, False],
        [False, False, False, False, False,  True, False, False, False,
         False, False, False],
        [False, False, False, False,  True, False, False, False, False,
         False, False, False],
        [False, False, False,  True, False, False, False, False, False,
         False, False, False]],

       [[ True, False, False, False, False, False, False, False, False,
         False, False, False],
        [ True, False, False, False, False, False, False, False, False,
         False, False, False],
        [False, False, False, False,  True, False, False, False, False,

In [52]:
# Run the model on the hand-written queries and print each decoded answer.
res = model.predict(test_input)
for char_probs in res:
    answer = ctable.decode(char_probs)
    print(answer)

 444
  44
   4
