This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
529 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# A3C Implementation | ||
This is an attempt to implement the A3C algorithm from the paper *Asynchronous Methods for Deep Reinforcement Learning*.
|
||
Author: Junyuan Xie (@piiswrong) | ||
|
||
The algorithm should be mostly correct. However I cannot reproduce the result in the paper, possibly due to hyperparameter settings. If you can find a better set of parameters please propose a pull request. | ||
|
||
Note this is a generalization of the original algorithm since we use `batch_size` threads for each worker instead of the original 1 thread. | ||
|
||
## Usage | ||
Run `python a3c.py --batch-size=32 --gpus=0` to start training on GPU 0 with a batch size of 32.
|
||
Run `python launcher.py --gpus=0,1 -n 2 python a3c.py` to launch training on two GPUs (0 and 1), with two workers per GPU.
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,222 @@ | ||
import mxnet as mx | ||
import numpy as np | ||
import rl_data | ||
import sym | ||
import argparse | ||
import logging | ||
import os | ||
import gym | ||
from datetime import datetime | ||
import time | ||
|
||
# Command-line configuration. Note: parse_args() runs at import time, so this
# module is intended to be executed as a script.
# Fixed "Traing" -> "Training" typo in the user-facing description.
parser = argparse.ArgumentParser(description='Training A3C with OpenAI Gym')
parser.add_argument('--test', action='store_true', help='run testing', default=False)
parser.add_argument('--log-file', type=str, help='the name of log file')
parser.add_argument('--log-dir', type=str, default="./log", help='directory of the log file')
parser.add_argument('--model-prefix', type=str, help='the prefix of the model to load')
parser.add_argument('--save-model-prefix', type=str, help='the prefix of the model to save')
parser.add_argument('--load-epoch', type=int, help="load the model on an epoch using the model-prefix")

# Distributed-training settings.
parser.add_argument('--kv-store', type=str, default='device', help='the kvstore type')
parser.add_argument('--gpus', type=str, help='the gpus will be used, e.g "0,1,2,3"')

# Epoch / batch sizing.
parser.add_argument('--num-epochs', type=int, default=120, help='the number of training epochs')
parser.add_argument('--num-examples', type=int, default=1000000, help='the number of training examples')
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--input-length', type=int, default=4)

# Optimizer and A3C hyperparameters (t-max: rollout length, gamma: discount,
# beta: entropy-regularization weight).
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--wd', type=float, default=0)
parser.add_argument('--t-max', type=int, default=4)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--beta', type=float, default=0.08)

args = parser.parse_args()
|
||
def log_config(log_dir=None, log_file=None, prefix=None, rank=0):
    """Configure the root logger for this worker.

    Parameters
    ----------
    log_dir : str or None
        Directory for the log file; created if missing. When None, log only
        to the console.
    log_file : str or None
        Log file name; when None one is derived from `prefix` and a timestamp.
    prefix : str or None
        Prefix used to build a default log file name.
    rank : int
        Worker rank, embedded in every log line.
    """
    # Drop any handlers installed by a previous call so reconfiguration works.
    # (The original used `reload(logging)`, which is a NameError on Python 3;
    # this explicit reset behaves the same on Python 2 and 3.)
    root = logging.getLogger()
    for old_handler in list(root.handlers):
        root.removeHandler(old_handler)

    head = '%(asctime)-15s Node[' + str(rank) + '] %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    if log_dir:
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        if not log_file:
            # Default name: <prefix>_YYYY_MM_DD-HH_MM.log, with '/' sanitized
            # so a path-like prefix cannot escape log_dir.
            log_file = (prefix if prefix else '') + datetime.now().strftime('_%Y_%m_%d-%H_%M.log')
            log_file = log_file.replace('/', '-')
        log_file_full_name = os.path.join(log_dir, log_file)
        handler = logging.FileHandler(log_file_full_name, mode='w')
        handler.setFormatter(logging.Formatter(head))
        root.addHandler(handler)
    logging.info('start with arguments %s', args)
|
||
def train():
    """Train an A3C agent on Atari 'Breakout-v0'.

    All hyperparameters are read from the module-level `args`. Parameters
    are saved once per epoch when a save prefix is configured. Runs forever
    through `args.num_epochs` epochs of `num_examples // batch_size` steps.
    """
    # kvstore for gradient aggregation (single machine or distributed)
    kv = mx.kvstore.create(args.kv_store)

    # Per-rank checkpoint prefix so distributed workers do not clobber
    # each other's files.
    model_prefix = args.model_prefix
    if model_prefix is not None:
        model_prefix += "-%d" % (kv.rank)
    save_model_prefix = args.save_model_prefix
    if save_model_prefix is None:
        save_model_prefix = model_prefix

    log_config(args.log_dir, args.log_file, save_model_prefix, kv.rank)

    devs = mx.cpu() if args.gpus is None else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]

    # Integer division (`//`): epoch_size feeds range() below, so it must
    # stay an int on Python 3 as well.
    epoch_size = args.num_examples // args.batch_size

    if args.kv_store == 'dist_sync':
        epoch_size //= kv.num_workers

    # disable kvstore for single device (`==`, not `is`, for int comparison)
    if 'local' in kv.type and (
            args.gpus is None or len(args.gpus.split(',')) == 1):
        kv = None

    # module: policy/value network driven by the Gym data iterator
    dataiter = rl_data.GymDataIter('Breakout-v0', args.batch_size, args.input_length, web_viz=True)
    net = sym.get_symbol_atari(dataiter.act_dim)
    module = mx.mod.Module(net, data_names=[d[0] for d in dataiter.provide_data],
                           label_names=('policy_label', 'value_label'), context=devs)
    # grad_req='add' accumulates gradients over the t_max rollout steps;
    # they are cleared manually at the start of every update below.
    module.bind(data_shapes=dataiter.provide_data,
                label_shapes=[('policy_label', (args.batch_size,)), ('value_label', (args.batch_size, 1))],
                grad_req='add')

    # load model
    if args.load_epoch is not None:
        assert model_prefix is not None
        _, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, args.load_epoch)
    else:
        arg_params = aux_params = None

    # Small uniform init for the policy/value heads, Xavier everywhere else.
    init = mx.init.Mixed(['fc_value_weight|fc_policy_weight', '.*'],
                         [mx.init.Uniform(0.001), mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)])
    module.init_params(initializer=init,
                       arg_params=arg_params, aux_params=aux_params)

    # optimizer
    module.init_optimizer(kvstore=kv, optimizer='adam',
                          optimizer_params={'learning_rate': args.lr, 'wd': args.wd, 'epsilon': 1e-3})

    # logging
    np.set_printoptions(precision=3, suppress=True)

    T = 0  # total number of finished episodes across all environments
    dataiter.reset()
    score = np.zeros((args.batch_size, 1))        # running per-env episode scores
    final_score = np.zeros((args.batch_size, 1))  # score of each env's last finished episode
    for epoch in range(args.num_epochs):
        if save_model_prefix:
            module.save_params('%s-%04d.params' % (save_model_prefix, epoch))

        for _ in range(epoch_size // args.t_max):
            tic = time.time()
            # clear gradients accumulated via grad_req='add'
            for exe in module._exec_group.grad_arrays:
                for g in exe:
                    g[:] = 0

            # Roll out t_max environment steps, plus one extra forward pass
            # for the bootstrap value estimate V[t_max].
            S, A, V, r, D = [], [], [], [], []
            for t in range(args.t_max + 1):
                data = dataiter.data()
                module.forward(mx.io.DataBatch(data=data, label=None), is_train=False)
                act, _, val = module.get_outputs()
                V.append(val.asnumpy())
                if t < args.t_max:
                    act = act.asnumpy()
                    # sample actions from the policy distribution
                    act = [np.random.choice(dataiter.act_dim, p=act[i]) for i in range(act.shape[0])]
                    reward, done = dataiter.act(act)
                    S.append(data)
                    A.append(act)
                    r.append(reward.reshape((-1, 1)))
                    D.append(done.reshape((-1, 1)))

            err = 0
            R = V[args.t_max]  # bootstrap return from the final value estimate
            for i in reversed(range(args.t_max)):
                # n-step discounted return; finished episodes (D[i] == 1)
                # do not propagate future value.
                R = r[i] + args.gamma * (1 - D[i]) * R
                adv = np.tile(R - V[i], (1, dataiter.act_dim))

                batch = mx.io.DataBatch(data=S[i], label=[mx.nd.array(A[i]), mx.nd.array(R)])
                module.forward(batch, is_train=True)

                # entropy-regularization term for the policy-head gradient
                pi = module.get_outputs()[1]
                h = args.beta * (mx.nd.log(pi + 1e-6) + 1)
                module.backward([mx.nd.array(adv), h])

                print('pi', pi[0].asnumpy())
                print('h', h[0].asnumpy())
                err += (adv ** 2).mean()
                score += r[i]
                final_score *= (1 - D[i])
                final_score += score * D[i]
                score *= 1 - D[i]
                T += D[i].sum()

            module.update()
            logging.info('fps: %f err: %f score: %f final: %f T: %f' % (
                args.batch_size / (time.time() - tic), err / args.t_max,
                score.mean(), final_score.mean(), T))
            print(score.squeeze())
            print(final_score.squeeze())
|
||
def test():
    """Evaluate a saved model.

    Requires --model-prefix and --load-epoch; logs the average score per
    finished episode every 100 steps.
    """
    log_config()

    devs = mx.cpu() if args.gpus is None else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]

    # module
    # NOTE(review): `robo_data` is never imported in this file, so this line
    # raises NameError as written. It presumably should mirror the iterator
    # used in train() (rl_data.GymDataIter) — confirm the intended module
    # before relying on this code path.
    dataiter = robo_data.RobosimsDataIter('scenes', args.batch_size, args.input_length, web_viz=True)
    print(dataiter.provide_data)
    net = sym.get_symbol_thor(dataiter.act_dim)
    module = mx.mod.Module(net, data_names=[d[0] for d in dataiter.provide_data],
                           label_names=('policy_label', 'value_label'), context=devs)
    module.bind(data_shapes=dataiter.provide_data,
                label_shapes=[('policy_label', (args.batch_size,)), ('value_label', (args.batch_size, 1))],
                for_training=False)

    # load model
    assert args.load_epoch is not None
    assert args.model_prefix is not None
    module.load_params('%s-%04d.params' % (args.model_prefix, args.load_epoch))

    # Integer division so range() receives an int on Python 3.
    N = args.num_epochs * args.num_examples // args.batch_size

    R = 0          # sum of finished-episode scores
    T = 1e-20      # episode count; epsilon avoids division by zero below
    score = np.zeros((args.batch_size,))
    for t in range(N):
        dataiter.clear_history()
        data = dataiter.next()
        module.forward(data, is_train=False)
        act = module.get_outputs()[0].asnumpy()
        # sample actions from the policy distribution
        act = [np.random.choice(dataiter.act_dim, p=act[i]) for i in range(act.shape[0])]
        dataiter.act(act)
        time.sleep(0.05)
        _, reward, _, done = dataiter.history[0]
        T += done.sum()
        score += reward
        R += (done * score).sum()
        score *= (1 - done)

        if t % 100 == 0:
            logging.info('n %d score: %f T: %f' % (t, R / T, T))
|
||
|
||
if __name__ == '__main__':
    # Dispatch on the --test flag: evaluate a saved model or run training.
    run = test if args.test else train
    run()
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
"""Submission job for local jobs.""" | ||
# pylint: disable=invalid-name | ||
from __future__ import absolute_import | ||
|
||
import sys | ||
import os | ||
import subprocess | ||
import logging | ||
from threading import Thread | ||
import argparse | ||
import signal | ||
|
||
sys.path.append(os.path.join(os.environ['HOME'], "mxnet/dmlc-core/tracker")) | ||
sys.path.append(os.path.join('/scratch', "mxnet/dmlc-core/tracker")) | ||
from dmlc_tracker import tracker | ||
|
||
# Shell wrapper that re-runs a command (spliced in at %s) until it exits 0,
# exporting the retry count as DMLC_NUM_ATTEMPT on each attempt.
# NOTE(review): defined but never referenced in this file — presumably meant
# for the non-Windows branch of exec_cmd; confirm before removing.
keepalive = """
nrep=0
rc=254
while [ $rc -ne 0 ];
do
export DMLC_NUM_ATTEMPT=$nrep
%s
rc=$?;
nrep=$((nrep+1));
done
"""
|
||
def exec_cmd(cmd, role, taskid, pass_env):
    """Run *cmd* in a shell with the DMLC_* environment set.

    Parameters
    ----------
    cmd : list of str
        Command tokens; joined with spaces and executed via the shell.
    role : str
        DMLC role ('worker' or 'server'), exported as DMLC_ROLE.
    taskid : int
        Task index, exported as DMLC_TASK_ID.
    pass_env : dict
        Extra environment variables to export (values are str()-converted).

    Returns None on success. On Windows a failing command is retried
    indefinitely; on POSIX a nonzero exit raises RuntimeError.
    """
    # A bare executable name that exists in the CWD gets a './' prefix so
    # the shell can find it (POSIX only). This mutates the caller's list,
    # matching the original behaviour.
    if '/' not in cmd[0] and os.path.exists(cmd[0]) and os.name != 'nt':
        cmd[0] = './' + cmd[0]
    cmd_line = ' '.join(cmd)

    env = os.environ.copy()
    for key, value in pass_env.items():
        env[key] = str(value)
    env['DMLC_TASK_ID'] = str(taskid)
    env['DMLC_ROLE'] = role
    env['DMLC_JOB_CLUSTER'] = 'local'

    attempt = 0
    while True:
        if os.name == 'nt':
            # Windows: record the attempt number and retry until success.
            env['DMLC_NUM_ATTEMPT'] = str(attempt)
            ret = subprocess.call(cmd_line, shell=True, env=env)
            if ret != 0:
                attempt += 1
                continue
        else:
            ret = subprocess.call(cmd_line, shell=True, executable='bash', env=env)
        if ret == 0:
            logging.debug('Thread %d exit with 0', taskid)
            return
        if os.name == 'nt':
            sys.exit(-1)
        raise RuntimeError('Get nonzero return code=%d' % ret)
|
||
def submit(args):
    """Submit function of local jobs.

    Starts one daemon worker thread per (gpu, thread) pair plus the
    requested server threads, then hands control to the dmlc tracker.
    (The original placed this docstring after the first statement, where
    it is not a docstring at all.)
    """
    gpus = args.gpus.strip().split(',')

    def mthread_submit(nworker, nserver, envs):
        """
        Customized submit script that starts nworker worker jobs and
        nserver server jobs; each runs `exec_cmd` in a daemon thread.

        Parameters
        ----------
        nworker: number of worker processes to start up
        nserver: number of server nodes to start up
        envs: environment variables to be added to the starting programs
        """
        procs = {}
        for i, gpu in enumerate(gpus):
            for j in range(args.num_threads):
                # Key by the globally unique task id: the original keyed by
                # `i` only, so each inner iteration overwrote the previous
                # thread handle for that gpu.
                taskid = i * args.num_threads + j
                procs[taskid] = Thread(target=exec_cmd,
                                       args=(args.command + ['--gpus=%s' % gpu],
                                             'worker', taskid, envs))
                procs[taskid].setDaemon(True)
                procs[taskid].start()
        # Server task ids continue after the last worker id.
        for i in range(len(gpus) * args.num_threads,
                       len(gpus) * args.num_threads + nserver):
            procs[i] = Thread(target=exec_cmd, args=(args.command, 'server', i, envs))
            procs[i].setDaemon(True)
            procs[i].start()

    # call submit, with nworker, the commands to run each job and submit function
    tracker.submit(args.num_threads * len(gpus), args.num_servers,
                   fun_submit=mthread_submit,
                   pscmd=(' '.join(args.command)))
|
||
def signal_handler(signum, frame):
    """SIGINT handler: log the shutdown and exit with status 0.

    Parameters renamed from the original (`signal` shadowed the `signal`
    module); handlers are invoked positionally so callers are unaffected.
    Also fixed the 'luancher' typo in the log message.
    """
    logging.info('Stop launcher')
    sys.exit(0)
|
||
if __name__ == '__main__':
    # Build the launcher CLI; unrecognized flags are forwarded to the
    # launched command via parse_known_args below.
    parser = argparse.ArgumentParser(description='Launch a distributed job')
    # The gpu list is split on ',' below, so it must be comma-separated ids.
    parser.add_argument('--gpus', type=str, help='the gpus will be used, e.g "0,1,2,3"')
    parser.add_argument('-n', '--num-threads', required=True, type=int,
                        help = 'number of threads per gpu')
    parser.add_argument('-s', '--num-servers', type=int,
                        help = 'number of server nodes to be launched, \
                        in default it is equal to NUM_WORKERS')
    parser.add_argument('-H', '--hostfile', type=str,
                        help = 'the hostfile of slave machines which will run \
                        the job. Required for ssh and mpi launcher')
    parser.add_argument('--sync-dst-dir', type=str,
                        help = 'if specificed, it will sync the current \
                        directory into slave machines\'s SYNC_DST_DIR if ssh \
                        launcher is used')
    parser.add_argument('--launcher', type=str, default='local',
                        choices = ['local', 'ssh', 'mpi', 'sge', 'yarn'],
                        help = 'the launcher to use')
    parser.add_argument('command', nargs='+',
                        help = 'command for launching the program')
    args, unknown = parser.parse_known_args()
    # Forward unknown flags to the launched program.
    args.command += unknown
    # Default server count: one per worker thread across all gpus.
    if args.num_servers is None:
        args.num_servers = args.num_threads * len(args.gpus.strip().split(','))

    # Exit cleanly on Ctrl-C instead of dumping a traceback.
    signal.signal(signal.SIGINT, signal_handler)
    submit(args)
Oops, something went wrong.