In [39]:
import keras.backend.tensorflow_backend as K
from keras.layers import Input, Lambda, Concatenate, RepeatVector, Reshape, Permute
from keras.models import Model

In [2]:
import numpy as np

# Encoder Sequence 10, Decoder Sequence 8, Latent Dim 2 가정

In [4]:
enc_seq_length = 10
dec_seq_length = 8
latent_dim = 2


enc_input = np.random.randint(10, size=(1, enc_seq_length, latent_dim))
dec_input = np.random.randint(10, size=(1, dec_seq_length, latent_dim))

In [6]:
print('enc_input')
print(enc_input)
print('')
print('dec_input')
print(dec_input)

enc_input
[[[2 5]
  [2 8]
  [7 0]
  [7 8]
  [7 4]
  [5 2]
  [6 7]
  [1 1]
  [2 4]
  [1 5]]]

dec_input
[[[3 6]
  [1 3]
  [2 5]
  [7 4]
  [0 9]
  [9 5]
  [0 1]
  [5 7]]]


# 1. 테스트 모델 생성 Repeat Element + Dot Product

In [9]:
def RepeatVectorLayer(rep, axis):
    return Lambda(lambda x: K.repeat_elements(K.expand_dims(x, axis), rep, axis))

In [13]:

def get_repeat_element_dot_product_model():
    encoder_inputs = Input(shape=(enc_seq_length, latent_dim))
    decoder_inputs = Input(shape=(dec_seq_length, latent_dim))

    # 최소공배수만큼으로 만들고 서로 곱해준다
    repeat_d_layer = RepeatVectorLayer(rep=enc_seq_length, axis=2)
    repeat_d = repeat_d_layer(decoder_inputs)

    repeat_e_layer = RepeatVectorLayer(rep=dec_seq_length, axis=1)
    repeat_e = repeat_e_layer(encoder_inputs)

    concat = Concatenate()([repeat_d, repeat_e])

    model = Model(inputs=[encoder_inputs, decoder_inputs], outputs=concat)
    model.compile(loss='mse', optimizer='adam')
    return model

In [23]:
model = get_repeat_element_dot_product_model()
a = model.predict([enc_input, dec_input])
print(a)

[[[[3. 6. 2. 5.]
   [3. 6. 2. 8.]
   [3. 6. 7. 0.]
   [3. 6. 7. 8.]
   [3. 6. 7. 4.]
   [3. 6. 5. 2.]
   [3. 6. 6. 7.]
   [3. 6. 1. 1.]
   [3. 6. 2. 4.]
   [3. 6. 1. 5.]]

  [[1. 3. 2. 5.]
   [1. 3. 2. 8.]
   [1. 3. 7. 0.]
   [1. 3. 7. 8.]
   [1. 3. 7. 4.]
   [1. 3. 5. 2.]
   [1. 3. 6. 7.]
   [1. 3. 1. 1.]
   [1. 3. 2. 4.]
   [1. 3. 1. 5.]]

  [[2. 5. 2. 5.]
   [2. 5. 2. 8.]
   [2. 5. 7. 0.]
   [2. 5. 7. 8.]
   [2. 5. 7. 4.]
   [2. 5. 5. 2.]
   [2. 5. 6. 7.]
   [2. 5. 1. 1.]
   [2. 5. 2. 4.]
   [2. 5. 1. 5.]]

  [[7. 4. 2. 5.]
   [7. 4. 2. 8.]
   [7. 4. 7. 0.]
   [7. 4. 7. 8.]
   [7. 4. 7. 4.]
   [7. 4. 5. 2.]
   [7. 4. 6. 7.]
   [7. 4. 1. 1.]
   [7. 4. 2. 4.]
   [7. 4. 1. 5.]]

  [[0. 9. 2. 5.]
   [0. 9. 2. 8.]
   [0. 9. 7. 0.]
   [0. 9. 7. 8.]
   [0. 9. 7. 4.]
   [0. 9. 5. 2.]
   [0. 9. 6. 7.]
   [0. 9. 1. 1.]
   [0. 9. 2. 4.]
   [0. 9. 1. 5.]]

  [[9. 5. 2. 5.]
   [9. 5. 2. 8.]
   [9. 5. 7. 0.]
   [9. 5. 7. 8.]
   [9. 5. 7. 4.]
   [9. 5. 5. 2.]
   [9. 5. 6. 7.]
   [9. 5. 1. 1.]
   [

In [24]:
a.shape

(1, 8, 10, 4)

# 2. Repeat Vector + Reshape으로 위의 것을 대신할 수 있는지 테스트

In [83]:

def get_reshape_repeat_model():
    encoder_inputs = Input(shape=(enc_seq_length, latent_dim))
    decoder_inputs = Input(shape=(dec_seq_length, latent_dim))

    # 최소공배수만큼으로 만들고 서로 곱해준다
    reshaped_encoder_inputs = Reshape((enc_seq_length * latent_dim, ))(encoder_inputs)
    reshaped_decoder_inputs = Reshape((dec_seq_length * latent_dim, ))(decoder_inputs)
    enc_repeat_vector = RepeatVector(dec_seq_length)(reshaped_encoder_inputs)
    dec_repeat_vector = RepeatVector(enc_seq_length)(reshaped_decoder_inputs)
    print(enc_repeat_vector.get_shape())
    print(dec_repeat_vector.get_shape())
    reshape_enc_repeat_vector = Reshape((dec_seq_length, enc_seq_length, latent_dim))(enc_repeat_vector)
    reshape_dec_repeat_vector = Reshape((enc_seq_length, dec_seq_length, latent_dim))(dec_repeat_vector)
    reshape_dec_repeat_vector = Permute((2, 1, 3))(reshape_dec_repeat_vector)
    print(reshape_enc_repeat_vector.get_shape())
    print(reshape_dec_repeat_vector.get_shape())

#     concat = Lambda(lambda x: K.batch_dot(x[0], x[1], axes=[2, 1]))([reshape_enc_repeat_vector, reshape_dec_repeat_vector])
    concat = Concatenate()([reshape_dec_repeat_vector, reshape_enc_repeat_vector])

    model = Model(inputs=[encoder_inputs, decoder_inputs], outputs=concat)
    model.compile(loss='mse', optimizer='adam')
    return model

In [84]:
model = get_reshape_repeat_model()
b = model.predict([enc_input, dec_input])
print(b)

(?, 8, 20)
(?, 10, 16)
(?, 8, 10, 2)
(?, 8, 10, 2)
[[[[3. 6. 2. 5.]
   [3. 6. 2. 8.]
   [3. 6. 7. 0.]
   [3. 6. 7. 8.]
   [3. 6. 7. 4.]
   [3. 6. 5. 2.]
   [3. 6. 6. 7.]
   [3. 6. 1. 1.]
   [3. 6. 2. 4.]
   [3. 6. 1. 5.]]

  [[1. 3. 2. 5.]
   [1. 3. 2. 8.]
   [1. 3. 7. 0.]
   [1. 3. 7. 8.]
   [1. 3. 7. 4.]
   [1. 3. 5. 2.]
   [1. 3. 6. 7.]
   [1. 3. 1. 1.]
   [1. 3. 2. 4.]
   [1. 3. 1. 5.]]

  [[2. 5. 2. 5.]
   [2. 5. 2. 8.]
   [2. 5. 7. 0.]
   [2. 5. 7. 8.]
   [2. 5. 7. 4.]
   [2. 5. 5. 2.]
   [2. 5. 6. 7.]
   [2. 5. 1. 1.]
   [2. 5. 2. 4.]
   [2. 5. 1. 5.]]

  [[7. 4. 2. 5.]
   [7. 4. 2. 8.]
   [7. 4. 7. 0.]
   [7. 4. 7. 8.]
   [7. 4. 7. 4.]
   [7. 4. 5. 2.]
   [7. 4. 6. 7.]
   [7. 4. 1. 1.]
   [7. 4. 2. 4.]
   [7. 4. 1. 5.]]

  [[0. 9. 2. 5.]
   [0. 9. 2. 8.]
   [0. 9. 7. 0.]
   [0. 9. 7. 8.]
   [0. 9. 7. 4.]
   [0. 9. 5. 2.]
   [0. 9. 6. 7.]
   [0. 9. 1. 1.]
   [0. 9. 2. 4.]
   [0. 9. 1. 5.]]

  [[9. 5. 2. 5.]
   [9. 5. 2. 8.]
   [9. 5. 7. 0.]
   [9. 5. 7. 8.]
   [9. 5. 7. 4.]
   [

In [85]:
b.shape

(1, 8, 10, 4)

# Check Equality

In [86]:
np.equal(a, b)

array([[[[ True,  True,  True,  True],
         [ True,  True,  True,  True],
         [ True,  True,  True,  True],
         [ True,  True,  True,  True],
         [ True,  True,  True,  True],
         [ True,  True,  True,  True],
         [ True,  True,  True,  True],
         [ True,  True,  True,  True],
         [ True,  True,  True,  True],
         [ True,  True,  True,  True]],

        [[ True,  True,  True,  True],
         [ True,  True,  True,  True],
         [ True,  True,  True,  True],
         [ True,  True,  True,  True],
         [ True,  True,  True,  True],
         [ True,  True,  True,  True],
         [ True,  True,  True,  True],
         [ True,  True,  True,  True],
         [ True,  True,  True,  True],
         [ True,  True,  True,  True]],

        [[ True,  True,  True,  True],
         [ True,  True,  True,  True],
         [ True,  True,  True,  True],
         [ True,  True,  True,  True],
         [ True,  True,  True,  True],
         [ True,  Tru