In [1]:
import os
os.chdir('/home')

from warnings import filterwarnings
filterwarnings('ignore', module='skimage')

import numpy as np
import gym

from IPython import display
import matplotlib.pyplot as plt
%matplotlib inline

from kerasgym.models import cnn_model_base, DDPGModel, DQNModel
from kerasgym.agents import Agent
from kerasgym.agents.process_state import downsample, rgb_to_binary
from kerasgym.agents.process_state import stack_consecutive, combine_consecutive
from kerasgym.agents.process_action import argmax
from kerasgym.agents.exploration import LinearDecay, ScopingPeriodic, EpsilonGreedy
from kerasgym.agents.exploration import graph_schedule

Using TensorFlow backend.


In [2]:
# env
env = gym.make('Breakout-v0')
env.reset()

# custom shape due to downsampling and stacking
shape = (84, 84, 4)

# model
base_config = {
    'in_shape': shape,
    'conv_layer_sizes': [32, 32, 32],
    'fc_layer_sizes': [32, 16, 8],
    'kernel_sizes': [(3,3)]*3,
    'strides': [(1,1)]*3,
    'activation': 'relu'
}
base_model = cnn_model_base(**base_config)

'''
ddpg_config = {
    'action_dim': 4,
    'actor_activation': 'softmax',
    'gamma': 0.9,
    'tau': 0.01,
    'actor_alpha': 1e-4,
    'critic_alpha': 1e-4
}
model = DDPGModel(base_model, **ddpg_config)
'''

dqn_config = {
    'action_dim': 4,
    'gamma': 0.9,
    'tau': 0.01,
    'alpha': 1e-4
}
model = DQNModel(base_model, **dqn_config)

In [10]:
schedule = ScopingPeriodic(start_value=0.9, amp=0.4, period=0.15, duration=1000000)
explorer = EpsilonGreedy(schedule)
buffer_size = 10000
batch_size = 32
repeated_actions = 4

agent = Agent(env,
              state_processing_fns=[downsample(shape), rgb_to_binary(),
                                    #combine_consecutive(fun='diff'), 
                                    stack_consecutive(4)],
              model=model, action_processing_fn=argmax(),
              explorer=explorer, buffer_size=buffer_size,
              batch_size=batch_size, repeated_actions=repeated_actions)
agent.reset()

In [11]:
agent.run_indefinitely()

End of episode 0. Keep running...
End of episode 1. Keep running...
End of episode 2. Keep running...
End of episode 3. Keep running...
End of episode 4. Keep running...
End of episode 5. Keep running...
End of episode 6. Keep running...
End of episode 7. Keep running...
End of episode 8. Keep running...
End of episode 9. Keep running...
End of episode 10. Keep running...
End of episode 11. Keep running...
End of episode 12. Keep running...
End of episode 13. Keep running...
End of episode 14. Keep running...
End of episode 15. Keep running...
End of episode 16. Keep running...
End of episode 17. Keep running...
End of episode 18. Keep running...
End of episode 19. Keep running...
End of episode 20. Keep running...
End of episode 21. Keep running...
End of episode 22. Keep running...
End of episode 23. Keep running...
End of episode 24. Keep running...
End of episode 25. Keep running...
End of episode 26. Keep running...
End of episode 27. Keep running...
End of episode 28. Keep runnin

End of episode 231. Keep running...
End of episode 232. Keep running...
End of episode 233. Keep running...
End of episode 234. Keep running...
End of episode 235. Keep running...
End of episode 236. Keep running...
End of episode 237. Keep running...
End of episode 238. Keep running...
End of episode 239. Keep running...
End of episode 240. Keep running...
End of episode 241. Keep running...
End of episode 242. Keep running...
End of episode 243. Keep running...
End of episode 244. Keep running...
End of episode 245. Keep running...
End of episode 246. Keep running...
End of episode 247. Keep running...
End of episode 248. Keep running...
End of episode 249. Keep running...
End of episode 250. Keep running...
End of episode 251. Keep running...
End of episode 252. Keep running...
End of episode 253. Keep running...
End of episode 254. Keep running...
End of episode 255. Keep running...
End of episode 256. Keep running...
End of episode 257. Keep running...
End of episode 258. Keep run

In [5]:
agent.explorer.schedule.get()

0.9950657332885031

In [6]:
state = agent.env_state
for sp in agent.state_processors:
    state = sp(state, agent.env)

In [7]:
agent.model.predict(state)

array([-0.02376402,  0.0150344 ,  0.00839505,  0.00988599], dtype=float32)

In [14]:
import pandas as pd
pd.Series(agent.model.predict(agent.replay_buffer.get_batch(5000)['states'], single=False).argmax(axis=1)).value_counts()

1    5000
dtype: int64