In [2]:
from keras.layers import Input, Dense, Convolution2D, Deconvolution2D, MaxPooling2D, UpSampling2D, Merge, LSTM
from keras.models import Model
import keras
from keras.utils.visualize_util import plot


Using TensorFlow backend.


In [3]:
# Network hyperparameters.

# Shape of one input frame, used as Input(shape=(X, Y, CHANNELS)).
IMAGE_SIZE_X, IMAGE_SIZE_Y, IMAGE_CHANNELS = 84, 84, 1

# Widths of the intermediate representations.
V_SIZE = 100      # units in the encoder's final Dense layer
A_MAP_SIZE = 100  # units in the action mapper's Dense layer
S_SIZE = 200      # LSTM units in the state predictor; also the decoder's input width

# Units in the decoder's Dense layer. generate_decoder reshapes this layer's
# output to (10, 10, 2), so the value has to stay equal to 10 * 10 * 2.
DENSE_SIZE = 200


In [4]:
def generate_encoder():
    """Build the 'encoder' model that turns one frame into a feature sequence.

    Layer stack, in order:
      1. Input of shape (IMAGE_SIZE_X, IMAGE_SIZE_Y, IMAGE_CHANNELS)
      2. Convolution2D with 4 filters of size 8x8, ReLU, 'same' border
      3. MaxPooling2D with a 16x16 pool, 'same' border
      4. Flatten
      5. Dense(V_SIZE) with ReLU
      6. Reshape to (V_SIZE, 1)

    Side effect: prints the model summary.

    Returns:
        The Keras Model named 'encoder'. It is returned uncompiled.
    """
    frame = Input(shape=(IMAGE_SIZE_X, IMAGE_SIZE_Y, IMAGE_CHANNELS))

    conv_out = Convolution2D(4, 8, 8, activation='relu', border_mode='same')(frame)
    pooled = MaxPooling2D((16, 16), border_mode='same')(conv_out)
    flat = keras.layers.core.Flatten()(pooled)

    features = Dense(V_SIZE, activation='relu')(flat)
    # Add a trailing axis of size 1 so the output matches the (V_SIZE, 1)
    # input that generate_state_predictor declares.
    feature_seq = keras.layers.core.Reshape((V_SIZE, 1))(features)

    model = Model(frame, feature_seq, name='encoder')
    model.summary()
    return model


encoder = generate_encoder()

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_1 (InputLayer)             (None, 84, 84, 1)     0                                            
____________________________________________________________________________________________________
convolution2d_1 (Convolution2D)  (None, 84, 84, 4)     260         input_1[0][0]                    
____________________________________________________________________________________________________
maxpooling2d_1 (MaxPooling2D)    (None, 6, 6, 4)       0           convolution2d_1[0][0]            
____________________________________________________________________________________________________
flatten_1 (Flatten)              (None, 144)           0           maxpooling2d_1[0][0]             
___________________________________________________________________________________________

In [5]:
def generate_action_mapper(input_size=6):
    """Build the 'action_mapper' model: a single ReLU Dense layer.

    Args:
        input_size: Width of the flat input vector. Defaults to 6, the value
            that was previously hard-coded.
            NOTE(review): an earlier commented-out draft built this input by
            concatenating an action tensor with a 1-wide "switch" input, so
            the 6 columns are presumably that concatenation -- confirm.

    Side effect: prints the model summary.

    Returns:
        The Keras Model named 'action_mapper', mapping (input_size,) to
        (A_MAP_SIZE,). It is returned uncompiled.
    """
    a_switch = Input(shape=(input_size,))

    action_map = Dense(A_MAP_SIZE, activation='relu')(a_switch)

    # Named explicitly, like the other sub-models in this notebook, so that it
    # does not show up as an auto-numbered "model_N" in summaries and plots.
    action_mapper = Model(a_switch, action_map, name='action_mapper')

    action_mapper.summary()
    return action_mapper


action_mapper = generate_action_mapper()

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_2 (InputLayer)             (None, 6)             0                                            
____________________________________________________________________________________________________
dense_2 (Dense)                  (None, 100)           700         input_2[0][0]                    
Total params: 700
____________________________________________________________________________________________________


In [6]:
def generate_state_predictor():
    """Build the 'state_predictor' model: an LSTM over the encoder's features.

    The input has shape (V_SIZE, 1) and is fed to a single LSTM layer with
    S_SIZE units; the model output has shape (S_SIZE,).

    Side effect: prints the model summary.

    Returns:
        The Keras Model named 'state_predictor'. It is returned uncompiled.
    """
    feature_seq = Input(shape=(V_SIZE, 1))
    state = LSTM(S_SIZE)(feature_seq)

    model = Model(feature_seq, state, name='state_predictor')
    model.summary()
    return model


state_predictor = generate_state_predictor()

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_3 (InputLayer)             (None, 100, 1)        0                                            
____________________________________________________________________________________________________
lstm_1 (LSTM)                    (None, 200)           161600      input_3[0][0]                    
Total params: 161600
____________________________________________________________________________________________________


In [7]:
def generate_decoder():
    """Build the 'decoder' model that maps a state vector to a small feature map.

    Layer stack, in order:
      1. Input of shape (S_SIZE,)
      2. Dense(DENSE_SIZE) with ReLU
      3. Reshape to (10, 10, 2)
      4. Convolution2D with 4 filters of size 8x8, ReLU, 'same' border

    NOTE(review): the resulting output shape is (10, 10, 4), which is not the
    (IMAGE_SIZE_X, IMAGE_SIZE_Y, IMAGE_CHANNELS) shape of an input frame, even
    though the output tensor is called X_recon. Presumably upsampling layers
    are still to be added -- confirm before training against raw frames.

    Side effect: prints the model summary.

    Returns:
        The Keras Model named 'decoder'. It is returned uncompiled.

    Raises:
        ValueError: if DENSE_SIZE does not equal the number of elements in the
            reshape target, so the mismatch is reported here rather than from
            inside the Reshape layer.
    """
    # The Dense output is folded into this grid, so the element counts must match.
    grid_shape = (10, 10, 2)
    grid_size = grid_shape[0] * grid_shape[1] * grid_shape[2]
    if grid_size != DENSE_SIZE:
        raise ValueError(
            'DENSE_SIZE (%d) must equal the number of elements in the '
            'reshape target %r (%d)' % (DENSE_SIZE, grid_shape, grid_size))

    S = Input(shape=(S_SIZE,))
    h = Dense(DENSE_SIZE, activation='relu')(S)
    h = keras.layers.core.Reshape(grid_shape)(h)

    X_recon = Convolution2D(4, 8, 8, activation='relu', border_mode='same')(h)

    decoder = Model(S, X_recon, name='decoder')

    decoder.summary()
    return decoder


decoder = generate_decoder()

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_4 (InputLayer)             (None, 200)           0                                            
____________________________________________________________________________________________________
dense_3 (Dense)                  (None, 200)           40200       input_4[0][0]                    
____________________________________________________________________________________________________
reshape_2 (Reshape)              (None, 10, 10, 2)     0           dense_3[0][0]                    
____________________________________________________________________________________________________
convolution2d_2 (Convolution2D)  (None, 10, 10, 4)     516         reshape_2[0][0]                  
Total params: 40716
_______________________________________________________________________

In [8]:
def generate_screen_predictor(encoder, decoder, state_predictor):
    """Compose the three sub-models into a single 'screen_predictor' model.

    A fresh Input of shape (IMAGE_SIZE_X, IMAGE_SIZE_Y, IMAGE_CHANNELS) is
    passed through encoder, then state_predictor, then decoder.

    Args:
        encoder: model applied first, directly to the input frame.
        decoder: model applied last, to the state predictor's output.
        state_predictor: model applied second, to the encoder's output.

    Side effect: prints the model summary.

    Returns:
        The Keras Model named 'screen_predictor'. It is returned uncompiled.
    """
    frame = Input(shape=(IMAGE_SIZE_X, IMAGE_SIZE_Y, IMAGE_CHANNELS))

    # Innermost call runs first: encoder -> state_predictor -> decoder.
    predicted = decoder(state_predictor(encoder(frame)))

    model = Model(frame, predicted, name='screen_predictor')
    model.summary()
    return model


screen_predictor = generate_screen_predictor(encoder, decoder, state_predictor)



____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_5 (InputLayer)             (None, 84, 84, 1)     0                                            
____________________________________________________________________________________________________
encoder (Model)                  (None, 100, 1)        14760       input_5[0][0]                    
____________________________________________________________________________________________________
state_predictor (Model)          (None, 200)           161600      encoder[1][0]                    
____________________________________________________________________________________________________
decoder (Model)                  (None, 10, 10, 4)     40716       state_predictor[1][0]            
Total params: 217076
______________________________________________________________________

In [9]:
# Write a diagram of the composed model's layer graph to model.png.
plot(model=screen_predictor, to_file='model.png')