
Commit

add comment
merrymercy committed Dec 2, 2017
1 parent 384c669 commit c0fcdb6
Showing 2 changed files with 84 additions and 2 deletions.
2 changes: 1 addition & 1 deletion examples/show_battle_game.py
@@ -1,6 +1,6 @@
"""
Interactive game; Pygame is required.
-Act like a general and dispatch your soilders.
+Act like a general and dispatch your soldiers.
"""

import os
84 changes: 83 additions & 1 deletion python/magent/builtin/mx_model/dqn.py
@@ -10,10 +10,45 @@

class DeepQNetwork(MXBaseModel):
def __init__(self, env, handle, name,
-batch_size=64, reward_decay=0.99, learning_rate=1e-4,
+batch_size=64, learning_rate=1e-4, reward_decay=0.99,
train_freq=1, target_update=2000, memory_size=2 ** 20, eval_obs=None,
use_dueling=True, use_double=True, infer_batch_size=8192,
custom_view_space=None, custom_feature_space=None, num_gpu=1):
"""init a model
Parameters
----------
env: Environment
environment
handle: Handle (ctypes.c_int32)
handle of this group, can be got by env.get_handles
name: str
name of this model
learning_rate: float
batch_size: int
reward_decay: float
reward_decay in TD
train_freq: int
mean training times of a sample
target_update: int
target will update every target_update batches
memory_size: int
weight of entropy loss in total loss
eval_obs: numpy array
evaluation set of observation
use_dueling: bool
whether use dueling q network
use_double: bool
whether use double q network
num_gpu: int
number of gpu
infer_batch_size: int
batch size while inferring actions
custom_feature_space: tuple
customized feature space
custom_view_space: tuple
customized feature space
"""
MXBaseModel.__init__(self, env, handle, name, "mxdqn")
# ======================== set config ========================
self.env = env
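The constructor documented above is normally driven from a training script. A minimal usage sketch, assuming a "battle"-style GridWorld config and the module layout of this repository (the config name and map size are illustrative, not part of this commit):

```python
# Hedged usage sketch: construct a DQN model for one agent group.
# The "battle" config and map_size are assumptions taken from the examples/ scripts.
import magent
from magent.builtin.mx_model import DeepQNetwork

env = magent.GridWorld("battle", map_size=125)
handles = env.get_handles()            # one handle per agent group

model = DeepQNetwork(
    env, handles[0], "battle-dqn",
    batch_size=64, learning_rate=1e-4, reward_decay=0.99,
    train_freq=1, target_update=2000, memory_size=2 ** 20,
    use_dueling=True, use_double=True)
```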
@@ -101,6 +136,14 @@ def __init__(self, env, handle, name,
# mx.viz.plot_network(self.loss).view()

def _create_network(self, input_view, input_feature, use_conv=True):
"""define computation graph of network
Parameters
----------
input_view: mx.symbol
input_feature: mx.symbol
the input tensor
"""
kernel_num = [32, 32]
hidden_size = [256]

@@ -140,6 +183,24 @@ def _create_network(self, input_view, input_feature, use_conv=True):
return qvalues
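Since `use_dueling` defaults to True, the Q-values returned above are presumably produced by a dueling head, i.e. Q(s, a) = V(s) + A(s, a) - mean_a A(s, a). A minimal sketch of such a head in MXNet symbols (layer names are illustrative; this is not the exact graph built by `_create_network`):

```python
import mxnet as mx

def dueling_head(hidden, num_actions):
    """Illustrative dueling head: Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)."""
    value = mx.sym.FullyConnected(data=hidden, num_hidden=1, name="value")
    advantage = mx.sym.FullyConnected(data=hidden, num_hidden=num_actions, name="advantage")
    # subtract the mean advantage so that V and A are identifiable
    advantage = mx.sym.broadcast_sub(
        advantage, mx.sym.mean(advantage, axis=1, keepdims=True))
    return mx.sym.broadcast_add(value, advantage)
```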

def infer_action(self, raw_obs, ids, policy="e_greedy", eps=0):
"""infer action for a batch of agents
Parameters
----------
raw_obs: tuple(numpy array, numpy array)
raw observation of agents tuple(views, features)
ids: numpy array
ids of agents
policy: str
can be eps-greedy or greedy
eps: float
used when policy is eps-greedy
Returns
-------
acts: numpy array of int32
actions for agents
"""
view, feature = raw_obs[0], raw_obs[1]

if policy == 'e_greedy':
@@ -171,6 +232,7 @@ def infer_action(self, raw_obs, ids, policy="e_greedy", eps=0):
return ret.astype(np.int32)
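The 'e_greedy' policy mentioned in the docstring is standard epsilon-greedy selection over predicted Q-values. A self-contained numpy sketch of that rule (an illustration, not the batched implementation used in this method):

```python
import numpy as np

def eps_greedy(qvalues, eps, rng=np.random):
    """Argmax actions, each replaced by a uniform random action with probability eps.
    qvalues has shape (n_agents, n_actions)."""
    n, n_actions = qvalues.shape
    acts = qvalues.argmax(axis=1)
    explore = rng.random_sample(n) < eps
    acts[explore] = rng.randint(n_actions, size=int(explore.sum()))
    return acts.astype(np.int32)
```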

def _calc_target(self, next_view, next_feature, rewards, terminal):
"""calculate target value"""
n = len(rewards)

data_batch = mx.io.DataBatch(data=[mx.nd.array(next_view), mx.nd.array(next_feature)])
@@ -191,6 +253,7 @@ def _calc_target(self, next_view, next_feature, rewards, terminal):
return target
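For context, the target computed here is the usual one-step TD target, and with `use_double=True` the next action is presumably chosen by the online network and evaluated by the target network. A hedged numpy sketch (`gamma` corresponds to `reward_decay`; the argument names are illustrative):

```python
import numpy as np

def td_target(rewards, terminal, q_next_online, q_next_target, gamma, use_double=True):
    """One-step targets r + gamma * Q_target(s', a*), masked at terminal states.
    q_next_* have shape (n, n_actions); a* comes from the online net when use_double."""
    if use_double:
        best = q_next_online.argmax(axis=1)
        next_value = q_next_target[np.arange(len(best)), best]
    else:
        next_value = q_next_target.max(axis=1)
    return rewards + gamma * (1.0 - terminal.astype(np.float32)) * next_value
```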

def _add_to_replay_buffer(self, sample_buffer):
"""add samples in sample_buffer to replay buffer"""
n = 0
for episode in sample_buffer.episodes():
v, f, a, r = episode.views, episode.features, episode.actions, episode.rewards
@@ -217,6 +280,22 @@ def _add_to_replay_buffer(self, sample_buffer):
return n
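The replay buffer that these samples go into is bounded by `memory_size`, so it behaves like a fixed-capacity store that overwrites its oldest entries. A minimal sketch of that storage pattern (an illustration of the idea, not the buffer class used in this file):

```python
import numpy as np

class RingReplay:
    """Fixed-capacity replay storage that overwrites the oldest entries."""
    def __init__(self, capacity, view_shape, feature_shape):
        self.capacity = capacity
        self.views = np.zeros((capacity,) + view_shape, dtype=np.float32)
        self.features = np.zeros((capacity,) + feature_shape, dtype=np.float32)
        self.actions = np.zeros(capacity, dtype=np.int32)
        self.rewards = np.zeros(capacity, dtype=np.float32)
        self.head, self.size = 0, 0

    def add(self, v, f, a, r):
        i = self.head
        self.views[i], self.features[i] = v, f
        self.actions[i], self.rewards[i] = a, r
        self.head = (self.head + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)
```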

def train(self, sample_buffer, print_every=1000):
""" add new samples in sample_buffer to replay buffer and train
Parameters
----------
sample_buffer: magent.utility.EpisodesBuffer
buffer contains samples
print_every: int
print log every print_every batches
Returns
-------
loss: float
bellman residual loss
value: float
estimated state value
"""
add_num = self._add_to_replay_buffer(sample_buffer)
batch_size = self.batch_size
total_loss = 0
@@ -274,6 +353,7 @@ def train(self, sample_buffer, print_every=1000):
return total_loss / ct if ct != 0 else 0, self._eval(batch_target)
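In the example training scripts, this method is typically called once per round after collecting experience with `infer_action`. A hedged outline of that loop (the EpisodesBuffer methods and env calls are based on the rest of this repository, not on this file, and may differ in detail):

```python
from magent.utility import EpisodesBuffer

def play_one_round(env, handle, model, eps=0.1, max_steps=550):
    """Collect one round of experience for one group, then train on it.
    Assumes env/handle/model were created as in the constructor sketch above."""
    sample_buffer = EpisodesBuffer(capacity=5000)
    env.reset()
    for _ in range(max_steps):
        obs = env.get_observation(handle)
        ids = env.get_agent_id(handle)
        acts = model.infer_action(obs, ids, policy="e_greedy", eps=eps)
        env.set_action(handle, acts)
        done = env.step()
        rewards = env.get_reward(handle)
        sample_buffer.record_step(ids, obs, acts, rewards, env.get_alive(handle))
        env.clear_dead()
        if done:
            break
    # one call adds the collected samples to the replay buffer and runs the updates
    return model.train(sample_buffer, print_every=500)
```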

def _reset_bind_size(self, new_size):
"""reset batch size"""
if self.bind_size == new_size:
return
else:
@@ -291,10 +371,12 @@ def _reshape(model, is_target):
_reshape(self.target_model, True)

def _copy_network(self, dest, source):
"""copy to target network"""
arg_params, aux_params = source.get_params()
dest.set_params(arg_params, aux_params)

def _eval(self, target):
"""evaluate estimated q value"""
if self.eval_obs is None:
return np.mean(target)
else:
