Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
A3C example (#4516)
Browse files Browse the repository at this point in the history
  • Loading branch information
piiswrong committed Jan 4, 2017
1 parent fbb6885 commit f4b8317
Show file tree
Hide file tree
Showing 5 changed files with 529 additions and 0 deletions.
14 changes: 14 additions & 0 deletions example/reinforcement-learning/a3c/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# A3C Implementation
This is an attempt to implement the A3C algorithm in paper Asynchronous Methods for Deep Reinforcement Learning.

Author: Junyuan Xie (@piiswrong)

The algorithm should be mostly correct. However I cannot reproduce the result in the paper, possibly due to hyperparameter settings. If you can find a better set of parameters please propose a pull request.

Note this is a generalization of the original algorithm since we use `batch_size` threads for each worker instead of the original 1 thread.

## Usage
run `python a3c.py --batch-size=32 --gpus=0` to run training on gpu 0 with batch-size=32.

run `python launcher.py --gpus=0,1 -n 2 python a3c.py` to launch training on 2 gpus (0 and 1), each gpu has two workers.

222 changes: 222 additions & 0 deletions example/reinforcement-learning/a3c/a3c.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
import mxnet as mx
import numpy as np
import rl_data
import sym
import argparse
import logging
import os
import gym
from datetime import datetime
import time

parser = argparse.ArgumentParser(description='Traing A3C with OpenAI Gym')
parser.add_argument('--test', action='store_true', help='run testing', default=False)
parser.add_argument('--log-file', type=str, help='the name of log file')
parser.add_argument('--log-dir', type=str, default="./log", help='directory of the log file')
parser.add_argument('--model-prefix', type=str, help='the prefix of the model to load')
parser.add_argument('--save-model-prefix', type=str, help='the prefix of the model to save')
parser.add_argument('--load-epoch', type=int, help="load the model on an epoch using the model-prefix")

parser.add_argument('--kv-store', type=str, default='device', help='the kvstore type')
parser.add_argument('--gpus', type=str, help='the gpus will be used, e.g "0,1,2,3"')

parser.add_argument('--num-epochs', type=int, default=120, help='the number of training epochs')
parser.add_argument('--num-examples', type=int, default=1000000, help='the number of training examples')
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--input-length', type=int, default=4)

parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--wd', type=float, default=0)
parser.add_argument('--t-max', type=int, default=4)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--beta', type=float, default=0.08)

args = parser.parse_args()

def log_config(log_dir=None, log_file=None, prefix=None, rank=0):
reload(logging)
head = '%(asctime)-15s Node[' + str(rank) + '] %(message)s'
if log_dir:
logging.basicConfig(level=logging.DEBUG, format=head)
if not os.path.exists(log_dir):
os.makedirs(log_dir)
if not log_file:
log_file = (prefix if prefix else '') + datetime.now().strftime('_%Y_%m_%d-%H_%M.log')
log_file = log_file.replace('/', '-')
else:
log_file = log_file
log_file_full_name = os.path.join(log_dir, log_file)
handler = logging.FileHandler(log_file_full_name, mode='w')
formatter = logging.Formatter(head)
handler.setFormatter(formatter)
logging.getLogger().addHandler(handler)
logging.info('start with arguments %s', args)
else:
logging.basicConfig(level=logging.DEBUG, format=head)
logging.info('start with arguments %s', args)

def train():
# kvstore
kv = mx.kvstore.create(args.kv_store)

model_prefix = args.model_prefix
if model_prefix is not None:
model_prefix += "-%d" % (kv.rank)
save_model_prefix = args.save_model_prefix
if save_model_prefix is None:
save_model_prefix = model_prefix

log_config(args.log_dir, args.log_file, save_model_prefix, kv.rank)

devs = mx.cpu() if args.gpus is None else [
mx.gpu(int(i)) for i in args.gpus.split(',')]

epoch_size = args.num_examples / args.batch_size

if args.kv_store == 'dist_sync':
epoch_size /= kv.num_workers

# disable kvstore for single device
if 'local' in kv.type and (
args.gpus is None or len(args.gpus.split(',')) is 1):
kv = None

# module
dataiter = rl_data.GymDataIter('Breakout-v0', args.batch_size, args.input_length, web_viz=True)
net = sym.get_symbol_atari(dataiter.act_dim)
module = mx.mod.Module(net, data_names=[d[0] for d in dataiter.provide_data], label_names=('policy_label', 'value_label'), context=devs)
module.bind(data_shapes=dataiter.provide_data,
label_shapes=[('policy_label', (args.batch_size,)), ('value_label', (args.batch_size, 1))],
grad_req='add')

# load model

if args.load_epoch is not None:
assert model_prefix is not None
_, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, args.load_epoch)
else:
arg_params = aux_params = None

# save model
checkpoint = None if save_model_prefix is None else mx.callback.do_checkpoint(save_model_prefix)

init = mx.init.Mixed(['fc_value_weight|fc_policy_weight', '.*'],
[mx.init.Uniform(0.001), mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)])
module.init_params(initializer=init,
arg_params=arg_params, aux_params=aux_params)

# optimizer
module.init_optimizer(kvstore=kv, optimizer='adam',
optimizer_params={'learning_rate': args.lr, 'wd': args.wd, 'epsilon': 1e-3})

# logging
np.set_printoptions(precision=3, suppress=True)

T = 0
dataiter.reset()
score = np.zeros((args.batch_size, 1))
final_score = np.zeros((args.batch_size, 1))
for epoch in range(args.num_epochs):
if save_model_prefix:
module.save_params('%s-%04d.params'%(save_model_prefix, epoch))


for _ in range(epoch_size/args.t_max):
tic = time.time()
# clear gradients
for exe in module._exec_group.grad_arrays:
for g in exe:
g[:] = 0

S, A, V, r, D = [], [], [], [], []
for t in range(args.t_max + 1):
data = dataiter.data()
module.forward(mx.io.DataBatch(data=data, label=None), is_train=False)
act, _, val = module.get_outputs()
V.append(val.asnumpy())
if t < args.t_max:
act = act.asnumpy()
act = [np.random.choice(dataiter.act_dim, p=act[i]) for i in range(act.shape[0])]
reward, done = dataiter.act(act)
S.append(data)
A.append(act)
r.append(reward.reshape((-1, 1)))
D.append(done.reshape((-1, 1)))

err = 0
R = V[args.t_max]
for i in reversed(range(args.t_max)):
R = r[i] + args.gamma * (1 - D[i]) * R
adv = np.tile(R - V[i], (1, dataiter.act_dim))

batch = mx.io.DataBatch(data=S[i], label=[mx.nd.array(A[i]), mx.nd.array(R)])
module.forward(batch, is_train=True)

pi = module.get_outputs()[1]
h = args.beta*(mx.nd.log(pi+1e-6)+1)
module.backward([mx.nd.array(adv), h])

print 'pi', pi[0].asnumpy()
print 'h', h[0].asnumpy()
err += (adv**2).mean()
score += r[i]
final_score *= (1-D[i])
final_score += score * D[i]
score *= 1-D[i]
T += D[i].sum()

module.update()
logging.info('fps: %f err: %f score: %f final: %f T: %f'%(args.batch_size/(time.time()-tic), err/args.t_max, score.mean(), final_score.mean(), T))
print score.squeeze()
print final_score.squeeze()

def test():
log_config()

devs = mx.cpu() if args.gpus is None else [
mx.gpu(int(i)) for i in args.gpus.split(',')]

# module
dataiter = robo_data.RobosimsDataIter('scenes', args.batch_size, args.input_length, web_viz=True)
print dataiter.provide_data
net = sym.get_symbol_thor(dataiter.act_dim)
module = mx.mod.Module(net, data_names=[d[0] for d in dataiter.provide_data], label_names=('policy_label', 'value_label'), context=devs)
module.bind(data_shapes=dataiter.provide_data,
label_shapes=[('policy_label', (args.batch_size,)), ('value_label', (args.batch_size, 1))],
for_training=False)

# load model
assert args.load_epoch is not None
assert args.model_prefix is not None
module.load_params('%s-%04d.params'%(args.model_prefix, args.load_epoch))

N = args.num_epochs * args.num_examples / args.batch_size

R = 0
T = 1e-20
score = np.zeros((args.batch_size,))
for t in range(N):
dataiter.clear_history()
data = dataiter.next()
module.forward(data, is_train=False)
act = module.get_outputs()[0].asnumpy()
act = [np.random.choice(dataiter.act_dim, p=act[i]) for i in range(act.shape[0])]
dataiter.act(act)
time.sleep(0.05)
_, reward, _, done = dataiter.history[0]
T += done.sum()
score += reward
R += (done*score).sum()
score *= (1-done)

if t % 100 == 0:
logging.info('n %d score: %f T: %f'%(t, R/T, T))


if __name__ == '__main__':
if args.test:
test()
else:
train()


121 changes: 121 additions & 0 deletions example/reinforcement-learning/a3c/launcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""Submission job for local jobs."""
# pylint: disable=invalid-name
from __future__ import absolute_import

import sys
import os
import subprocess
import logging
from threading import Thread
import argparse
import signal

sys.path.append(os.path.join(os.environ['HOME'], "mxnet/dmlc-core/tracker"))
sys.path.append(os.path.join('/scratch', "mxnet/dmlc-core/tracker"))
from dmlc_tracker import tracker

keepalive = """
nrep=0
rc=254
while [ $rc -ne 0 ];
do
export DMLC_NUM_ATTEMPT=$nrep
%s
rc=$?;
nrep=$((nrep+1));
done
"""

def exec_cmd(cmd, role, taskid, pass_env):
"""Execute the command line command."""
if cmd[0].find('/') == -1 and os.path.exists(cmd[0]) and os.name != 'nt':
cmd[0] = './' + cmd[0]
cmd = ' '.join(cmd)
env = os.environ.copy()
for k, v in pass_env.items():
env[k] = str(v)

env['DMLC_TASK_ID'] = str(taskid)
env['DMLC_ROLE'] = role
env['DMLC_JOB_CLUSTER'] = 'local'

ntrial = 0
while True:
if os.name == 'nt':
env['DMLC_NUM_ATTEMPT'] = str(ntrial)
ret = subprocess.call(cmd, shell=True, env=env)
if ret != 0:
ntrial += 1
continue
else:
bash = cmd
ret = subprocess.call(bash, shell=True, executable='bash', env=env)
if ret == 0:
logging.debug('Thread %d exit with 0', taskid)
return
else:
if os.name == 'nt':
sys.exit(-1)
else:
raise RuntimeError('Get nonzero return code=%d' % ret)

def submit(args):
gpus = args.gpus.strip().split(',')
"""Submit function of local jobs."""
def mthread_submit(nworker, nserver, envs):
"""
customized submit script, that submit nslave jobs, each must contain args as parameter
note this can be a lambda function containing additional parameters in input
Parameters
----------
nworker: number of slave process to start up
nserver: number of server nodes to start up
envs: enviroment variables to be added to the starting programs
"""
procs = {}
for i, gpu in enumerate(gpus):
for j in range(args.num_threads):
procs[i] = Thread(target=exec_cmd, args=(args.command + ['--gpus=%s'%gpu], 'worker', i*args.num_threads+j, envs))
procs[i].setDaemon(True)
procs[i].start()
for i in range(len(gpus)*args.num_threads, len(gpus)*args.num_threads + nserver):
procs[i] = Thread(target=exec_cmd, args=(args.command, 'server', i, envs))
procs[i].setDaemon(True)
procs[i].start()

# call submit, with nslave, the commands to run each job and submit function
tracker.submit(args.num_threads*len(gpus), args.num_servers, fun_submit=mthread_submit,
pscmd=(' '.join(args.command)))

def signal_handler(signal, frame):
logging.info('Stop luancher')
sys.exit(0)

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Launch a distributed job')
parser.add_argument('--gpus', type=str, help='the gpus will be used, e.g "0,1,2,3"')
parser.add_argument('-n', '--num-threads', required=True, type=int,
help = 'number of threads per gpu')
parser.add_argument('-s', '--num-servers', type=int,
help = 'number of server nodes to be launched, \
in default it is equal to NUM_WORKERS')
parser.add_argument('-H', '--hostfile', type=str,
help = 'the hostfile of slave machines which will run \
the job. Required for ssh and mpi launcher')
parser.add_argument('--sync-dst-dir', type=str,
help = 'if specificed, it will sync the current \
directory into slave machines\'s SYNC_DST_DIR if ssh \
launcher is used')
parser.add_argument('--launcher', type=str, default='local',
choices = ['local', 'ssh', 'mpi', 'sge', 'yarn'],
help = 'the launcher to use')
parser.add_argument('command', nargs='+',
help = 'command for launching the program')
args, unknown = parser.parse_known_args()
args.command += unknown
if args.num_servers is None:
args.num_servers = args.num_threads * len(args.gpus.strip().split(','))

signal.signal(signal.SIGINT, signal_handler)
submit(args)
Loading

0 comments on commit f4b8317

Please sign in to comment.