<a href="https://colab.research.google.com/github/ofzlo/NLP-tutorial/blob/main/8.%20RNN/NLP_08_04_understanding_simpleRNN_and_lstm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 08-04 케라스의 SimpleRNN과 LSTM 이해하기

## 1. 임의의 입력 생성하기

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import SimpleRNN, LSTM, Bidirectional

In [2]:
# Toy input: one sentence of 4 timesteps, each a 5-dimensional word vector.
train_X = [
    [0.1, 4.2, 1.5, 1.1, 2.8],  # timestep 1
    [1.0, 3.1, 2.5, 0.7, 1.1],  # timestep 2
    [0.3, 2.1, 1.5, 2.1, 0.1],  # timestep 3
    [2.2, 1.4, 0.5, 0.9, 1.1],  # timestep 4
]
print(np.array(train_X).shape)  # (timesteps, word vector dim)

(4, 5)


In [4]:
# RNN layers take a 3D tensor (batch_size, timesteps, input_dim), so the 4x5
# sentence is wrapped in one more list: a batch that holds a single sample,
# i.e. batch_size = 1. float32 matches the default Keras layer dtype.
train_X = np.array(
    [[
        [0.1, 4.2, 1.5, 1.1, 2.8],
        [1.0, 3.1, 2.5, 0.7, 1.1],
        [0.3, 2.1, 1.5, 2.1, 0.1],
        [2.2, 1.4, 0.5, 0.9, 1.1],
    ]],
    dtype=np.float32,
)
print(train_X.shape)

(1, 4, 5)


## 2. SimpleRNN 이해하기
- `return_sequences`, `return_state` 기본값 False   
  - `return_sequences`가 False인 경우에는 SimpleRNN은 마지막 시점의 은닉 상태만 출력   
  - `return_state`가 True일 경우에는 `return_sequences`의 True/False 여부와 상관없이 마지막 시점의 은닉 상태를 추가로 반환 (LSTM은 마지막 시점의 셀 상태까지 함께 반환)   
- 은닉 상태의 크기를 3으로 지정   
- 본 실습에서는 SimpleRNN을 매번 새로 선언하므로 가중치가 매번 무작위로 초기화됨. 따라서 출력되는 은닉 상태의 값은 이전 셀의 출력과 일관성이 없음.   
  - 출력값 자체보다 해당 값의 크기(shape)에 주목할 것   

`return_sequences`가 False인 경우(기본값)에는 SimpleRNN은 마지막 시점의 은닉 상태만 출력하므로 (1, 3) 크기의 텐서를 반환

In [5]:
# Both flags are off by default, so this is the same as
# SimpleRNN(3, return_sequences=False, return_state=False):
# only the hidden state of the last timestep comes back.
rnn = SimpleRNN(3)
hidden_state = rnn(train_X)

print(f'hidden state : {hidden_state}, shape: {hidden_state.shape}')

hidden state : [[-0.70272124  0.99516934  0.9120971 ]], shape: (1, 3)


`return_sequences`가 True인 경우 — 입력 데이터 크기: (1, 4, 5)   
- 두 번째 차원인 4가 시점(timesteps)에 해당하므로, 모든 시점에 대한 은닉 상태를 출력하여 (1, 4, 3) 크기의 텐서를 반환

In [6]:
# return_sequences=True: the hidden state of every timestep is returned,
# so the 4 timesteps of train_X yield a (1, 4, 3) tensor.
rnn = SimpleRNN(3, return_sequences=True)
hidden_states = rnn(train_X)

print(f'hidden states : {hidden_states}, shape: {hidden_states.shape}')

hidden states : [[[-0.99257535  0.84529537  0.9999421 ]
  [-0.9210455  -0.54077566  0.99943024]
  [-0.9752333  -0.8279698   0.06821049]
  [-0.9961943  -0.84069175  0.90067905]]], shape: (1, 4, 3)


In [7]:
# return_state=True adds a second output, the final hidden state; it equals
# the last row of the per-timestep hidden_states.
rnn = SimpleRNN(3, return_sequences=True, return_state=True)
(hidden_states, last_state) = rnn(train_X)

print(f'hidden states : {hidden_states}, shape: {hidden_states.shape}')
print(f'last hidden state : {last_state}, shape: {last_state.shape}')

hidden states : [[[-0.45682877  0.57765144  0.31072354]
  [-0.46510667 -0.1460671   0.14140047]
  [-0.88290656  0.03890514  0.75322276]
  [ 0.30626792  0.26000983 -0.06197071]]], shape: (1, 4, 3)
last hidden state : [[ 0.30626792  0.26000983 -0.06197071]], shape: (1, 3)


In [8]:
# With return_sequences=False the main output is already the last hidden
# state, so the extra state output from return_state=True is the same value.
rnn = SimpleRNN(3, return_sequences=False, return_state=True)
(hidden_state, last_state) = rnn(train_X)

print(f'hidden state : {hidden_state}, shape: {hidden_state.shape}')
print(f'last hidden state : {last_state}, shape: {last_state.shape}')

hidden state : [[-0.6224874  -0.72633815 -0.85004586]], shape: (1, 3)
last hidden state : [[-0.6224874  -0.72633815 -0.85004586]], shape: (1, 3)


## 3. LSTM 이해하기

LSTM은 `return_state`가 True일 경우 마지막 시점의 은닉 상태와 함께 마지막 시점의 셀 상태도 반환하므로, 출력이 3개가 됨.


In [9]:
# Unlike SimpleRNN, an LSTM with return_state=True returns three tensors:
# the output, the last hidden state and the last cell state.
lstm = LSTM(3, return_sequences=False, return_state=True)
(hidden_state, last_state, last_cell_state) = lstm(train_X)

print(f'hidden state : {hidden_state}, shape: {hidden_state.shape}')
print(f'last hidden state : {last_state}, shape: {last_state.shape}')
print(f'last cell state : {last_cell_state}, shape: {last_cell_state.shape}')

hidden state : [[-0.14960948 -0.11244339 -0.13921316]], shape: (1, 3)
last hidden state : [[-0.14960948 -0.11244339 -0.13921316]], shape: (1, 3)
last cell state : [[-0.3116475  -0.3421559  -0.64889914]], shape: (1, 3)


In [10]:
# Per-timestep hidden states (1, 4, 3) plus the final hidden and cell states;
# the final hidden state matches the last row of hidden_states.
lstm = LSTM(3, return_sequences=True, return_state=True)
(hidden_states, last_hidden_state, last_cell_state) = lstm(train_X)

print(f'hidden states : {hidden_states}, shape: {hidden_states.shape}')
print(f'last hidden state : {last_hidden_state}, shape: {last_hidden_state.shape}')
print(f'last cell state : {last_cell_state}, shape: {last_cell_state.shape}')

hidden states : [[[-0.03672146 -0.6156713   0.49378574]
  [-0.025212   -0.53800076  0.32965654]
  [-0.19833891 -0.12230901  0.14449033]
  [-0.2036786  -0.2562331   0.24774766]]], shape: (1, 4, 3)
last hidden state : [[-0.2036786  -0.2562331   0.24774766]], shape: (1, 3)
last cell state : [[-0.34576106 -0.6270748   0.48928905]], shape: (1, 3)


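## 4. Bidirectional(LSTM) 이해하기

출력값을 서로 비교하기 쉽도록 가중치(kernel, recurrent)는 0.1, 편향은 0으로 초기화함. 정방향과 역방향의 은닉 상태가 이어 붙여져 출력되므로 은닉 상태의 크기는 6이 됨. 정방향 LSTM의 마지막 은닉 상태는 마지막 시점 출력의 앞쪽 3개 값과 같고, 역방향 LSTM의 마지막 은닉 상태는 첫 번째 시점 출력의 뒤쪽 3개 값과 같음.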

In [11]:
# Constant (non-random) initializers make the Bidirectional LSTM cells below
# deterministic, so their printed states can be compared between cells.
k_init = tf.keras.initializers.Constant(0.1)  # passed as kernel_initializer
b_init = tf.keras.initializers.Constant(0)    # passed as bias_initializer
r_init = tf.keras.initializers.Constant(0.1)  # passed as recurrent_initializer

In [12]:
# Bidirectional LSTM, 3 units per direction, constant-initialized.
# With return_sequences=False the first output is a single hidden state:
# the forward and backward last hidden states concatenated (forward first),
# shape (1, 6). It is named hidden_state, as in the other
# return_sequences=False cells.
# return_state=True additionally returns (h, c) for each direction; the cell
# states forward_c / backward_c are unpacked but not printed here.
bilstm = Bidirectional(LSTM(3, return_sequences=False, return_state=True,
                            kernel_initializer=k_init, bias_initializer=b_init,
                            recurrent_initializer=r_init))
hidden_state, forward_h, forward_c, backward_h, backward_c = bilstm(train_X)

print('hidden state : {}, shape: {}'.format(hidden_state, hidden_state.shape))
print('forward state : {}, shape: {}'.format(forward_h, forward_h.shape))
print('backward state : {}, shape: {}'.format(backward_h, backward_h.shape))

hidden states : [[0.6303138 0.6303138 0.6303138 0.7038734 0.7038734 0.7038734]], shape: (1, 6)
forward state : [[0.6303138 0.6303138 0.6303138]], shape: (1, 3)
backward state : [[0.7038734 0.7038734 0.7038734]], shape: (1, 3)


In [13]:
# Same constant-initialized Bidirectional LSTM, now with return_sequences=True:
# the first output holds the concatenated forward/backward hidden states of
# every timestep, shape (1, 4, 6); the (h, c) pairs of both directions follow.
bilstm = Bidirectional(
    LSTM(
        3,
        return_sequences=True,
        return_state=True,
        kernel_initializer=k_init,
        bias_initializer=b_init,
        recurrent_initializer=r_init,
    )
)
hidden_states, forward_h, forward_c, backward_h, backward_c = bilstm(train_X)

In [15]:
# forward_h equals the first half of the LAST timestep in hidden_states;
# backward_h equals the second half of the FIRST timestep, because the
# backward LSTM reads the sequence from the end towards the start.
print(f'hidden states : {hidden_states}, shape: {hidden_states.shape}')
print(f'forward state : {forward_h}, shape: {forward_h.shape}')
print(f'backward state : {backward_h}, shape: {backward_h.shape}')

hidden states : [[[0.35906473 0.35906473 0.35906473 0.7038734  0.7038734  0.7038734 ]
  [0.55111325 0.55111325 0.55111325 0.58863586 0.58863586 0.58863586]
  [0.59115744 0.59115744 0.59115744 0.3951699  0.3951699  0.3951699 ]
  [0.6303138  0.6303138  0.6303138  0.21942244 0.21942244 0.21942244]]], shape: (1, 4, 6)
forward state : [[0.6303138 0.6303138 0.6303138]], shape: (1, 3)
backward state : [[0.7038734 0.7038734 0.7038734]], shape: (1, 3)
