Commit a63302f

1. Fix randomness in distributed algorithms
2. Fix instability issue (nan) in TRPO
iffiX committed Jun 22, 2021
1 parent bdc23ba commit a63302f
Showing 7 changed files with 125 additions and 102 deletions.
5 changes: 4 additions & 1 deletion machin/frame/algorithms/trpo.py
@@ -241,7 +241,10 @@ def fvp(v):
 
         # usually 1e-15 is low enough
         if t.allclose(loss_grad, t.zeros_like(loss_grad), atol=1e-15):
-            default_logger.warning("TRPO detects zero gradient.")
+            default_logger.warning(
+                "TRPO detects zero gradient, update step skipped."
+            )
+            return 0, 0
 
         step_dir = self._conjugate_gradients(
             fvp,
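
Why the early return matters: the conjugate-gradient solve that follows divides by inner products derived from the gradient, so an all-zero loss_grad turns the very first step size into 0 / 0. The sketch below is my own minimal CG loop (not machin's _conjugate_gradients) illustrating the failure the guard prevents:

import torch as t

def conjugate_gradients(Avp, b, iterations=10):
    # Solves A x = b given only the matrix-vector product Avp(v) = A v.
    x = t.zeros_like(b)
    r = b.clone()  # residual starts at b because x = 0
    p = b.clone()  # initial search direction
    rdotr = r.dot(r)
    for _ in range(iterations):
        Ap = Avp(p)
        alpha = rdotr / p.dot(Ap)  # 0 / 0 = nan when b is all zeros
        x += alpha * p
        r -= alpha * Ap
        new_rdotr = r.dot(r)
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x

print(conjugate_gradients(lambda v: 2.0 * v, t.zeros(4)))  # tensor of nans

Returning (0, 0) on a zero gradient therefore keeps one degenerate batch from writing nan into the policy parameters.
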
5 changes: 3 additions & 2 deletions machin/model/algorithms/trpo.py
@@ -26,7 +26,8 @@ def sample(self, probability: t.tensor, action=None):
         Action log probability tensor of shape ``[batch, 1]``.
         """
         batch_size = probability.shape[0]
-        self.action_param = probability
+        # dx (xlnx) = lnx + 1, x must > 0
+        self.action_param = probability + 1e-6
         dist = Categorical(probs=probability)
         if action is None:
             action = dist.sample()
@@ -41,7 +42,7 @@ def get_kl(self, *args, **kwargs):
         self.forward(*args, **kwargs)
         action_prob1 = self.action_param
         action_prob0 = action_prob1.detach()
-        kl = action_prob0 * (t.log(action_prob0) - t.log(action_prob1))
+        kl = action_prob0 * (t.log(action_prob0 / action_prob1))
         return kl.sum(1, keepdim=True)
 
     def compare_kl(self, params: t.tensor, *args, **kwargs):
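
Both hunks attack the same failure mode: a categorical probability that is exactly zero. As the added comment notes, d/dx (x ln x) = ln x + 1, which diverges as x approaches 0, and at x = 0 the forward value 0 * log(0) is already nan in IEEE arithmetic. A quick demonstration (my own, not part of the commit) of the problem and of the effect of the 1e-6 shift:

import torch as t

# With a zero probability the forward pass alone is nan,
# because log(0) = -inf and 0 * -inf = nan.
x = t.tensor([0.0, 0.5], requires_grad=True)
print((x * t.log(x)).sum())  # tensor(nan, grad_fn=<SumBackward0>)

# After the epsilon shift, value and gradient both stay finite.
x = (t.tensor([0.0, 0.5]) + 1e-6).requires_grad_()
y = (x * t.log(x)).sum()
y.backward()
print(y, x.grad)  # gradient entries are log(x) + 1, about -12.8 at x = 1e-6

The get_kl change is algebraically neutral, since p0 * (log(p0) - log(p1)) = p0 * log(p0 / p1), but the ratio form evaluates p0 / p1 (close to 1 near the trust-region center) before taking the log, instead of subtracting two large negative logs when both probabilities are tiny.
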
20 changes: 17 additions & 3 deletions test/frame/algorithms/test_a3c.py
@@ -255,6 +255,8 @@ def test_config_init(rank):
     )
     @setup_world
     def test_full_train(rank, gae_lambda):
+        training_group = get_world().create_rpc_group("training", ["0", "1", "2"])
+
         c = TestA3C.c
         a3c = TestA3C.a3c("cpu", t.float32)
         a3c.set_sync(False)
@@ -264,11 +266,15 @@ def test_full_train(rank, gae_lambda):
         reward_fulfilled = Counter()
         smoother = Smooth()
         terminal = False
 
         env = c.env
-        env.seed(0)
+        env.seed(rank)
+
+        # make sure all things are initialized.
+        training_group.barrier()
+
+        # for cpu usage viewing
+        default_logger.info(f"{rank}, pid {os.getpid()}")
 
         while episode < c.max_episodes:
             episode.count()
@@ -314,8 +320,16 @@ def test_full_train(rank, gae_lambda):
                 reward_fulfilled.count()
                 if reward_fulfilled >= c.solved_repeat:
                     default_logger.info("Environment solved!")
-                    return True
+                    try:
+                        training_group.pair(f"solved", True)
+                    except KeyError:
+                        # already solved in another process
+                        pass
             else:
                 reward_fulfilled.reset()
 
+            training_group.barrier()
+            if training_group.is_paired("solved"):
+                return True
+
     raise RuntimeError("A3C Training failed.")
115 changes: 54 additions & 61 deletions test/frame/algorithms/test_apex.py
@@ -253,6 +253,8 @@ def test_config_init(rank):
     @run_multi(expected_results=[True, True, True], timeout=1800)
     @setup_world
     def test_full_train(rank):
+        training_group = get_world().create_rpc_group("training", ["0", "1", "2"])
+
         c = TestDQNApex.c
         dqn_apex = TestDQNApex.dqn_apex("cpu", t.float32)
         # perform manual syncing to decrease the number of rpc calls
@@ -263,19 +265,19 @@ def test_full_train(rank):
         reward_fulfilled = Counter()
         smoother = Smooth()
         terminal = False
 
         env = c.env
-        env.seed(0)
-        world = get_world()
-        all_group = world.create_rpc_group("all", ["0", "1", "2"])
-        all_group.pair(f"{rank}_running", True)
-
-        if rank in (0, 1):
-            while episode < c.max_episodes:
-                # wait for trainer to keep up
-                sleep(0.2)
-                episode.count()
+        env.seed(rank)
+
+        # make sure all things are initialized.
+        training_group.barrier()
+
+        # for cpu usage viewing
+        default_logger.info(f"{rank}, pid {os.getpid()}")
+
+        while episode < c.max_episodes:
+            episode.count()
+
+            if rank in (0, 1):
                 # batch size = 1
                 total_reward = 0
                 state = t.tensor(env.reset(), dtype=t.float32)
@@ -313,31 +315,28 @@ def test_full_train(rank):
                         rank, episode, smoother.value
                     )
                 )
 
                 if smoother.value > c.solved_reward:
                     reward_fulfilled.count()
                     if reward_fulfilled >= c.solved_repeat:
                         default_logger.info("Environment solved!")
-
-                        all_group.unpair(f"{rank}_running")
-                        while all_group.is_paired("0_running") or all_group.is_paired(
-                            "1_running"
-                        ):
-                            # wait for all workers to join
-                            sleep(1)
-                        # wait for trainer
-                        sleep(5)
-                        return True
+                        try:
+                            training_group.pair(f"solved", True)
+                        except KeyError:
+                            # already solved in another process
+                            pass
                 else:
                     reward_fulfilled.reset()
-        else:
-            # wait for some samples
-            while dqn_apex.replay_buffer.all_size() < 500:
-                sleep(0.1)
-            while all_group.is_paired("0_running") or all_group.is_paired("1_running"):
-                dqn_apex.update()
-                default_logger.info("Updated")
-            return True
+            else:
+                # wait for some samples
+                if episode.get() > 200:
+                    for _ in range(100):
+                        dqn_apex.update()
+                    default_logger.info("Updated 100 times.")
+
+            training_group.barrier()
+            if training_group.is_paired("solved"):
+                return True
 
     raise RuntimeError("DQN-Apex Training failed.")

@@ -584,6 +583,8 @@ def test_config_init(rank):
     @run_multi(expected_results=[True, True, True], timeout=1800)
     @setup_world
     def test_full_train(rank):
+        training_group = get_world().create_rpc_group("training", ["0", "1", "2"])
+
         c = TestDDPGApex.c
         ddpg_apex = TestDDPGApex.ddpg_apex("cpu", t.float32, discrete=True)
         # perform manual syncing to decrease the number of rpc calls
@@ -595,22 +596,19 @@ def test_full_train(rank):
         reward_fulfilled = Counter()
         smoother = Smooth()
         terminal = False
 
         env = c.env
-        env.seed(0)
-        world = get_world()
-        all_group = world.create_rpc_group("all", ["0", "1", "2"])
-        all_group.pair(f"{rank}_running", True)
+        env.seed(rank)
+
+        # make sure all things are initialized.
+        training_group.barrier()
 
         # for cpu usage viewing
         default_logger.info(f"{rank}, pid {os.getpid()}")
-        if rank == 0:
-            all_group.pair("episode", episode)
+
+        while episode < c.max_episodes:
+            episode.count()
 
-        if rank in (0, 1):
-            while episode < c.max_episodes:
-                # wait for trainer to keep up
-                sleep(0.2)
-                episode.count()
+            if rank in (0, 1):
                 # batch size = 1
                 total_reward = 0
                 state = t.tensor(env.reset(), dtype=t.float32)
@@ -651,27 +649,22 @@ def test_full_train(rank):
 
                 if smoother.value > c.solved_reward:
                     reward_fulfilled.count()
                     if reward_fulfilled >= c.solved_repeat:
                         default_logger.info("Environment solved!")
-
-                        all_group.unpair(f"{rank}_running")
-                        while all_group.is_paired("0_running") or all_group.is_paired(
-                            "1_running"
-                        ):
-                            # wait for all workers to join
-                            sleep(1)
-                        # wait for trainer
-                        sleep(5)
-                        return True
+                        try:
+                            training_group.pair(f"solved", True)
+                        except KeyError:
+                            # already solved in another process
+                            pass
                 else:
                     reward_fulfilled.reset()
-        else:
-            # wait for some samples
-            while ddpg_apex.replay_buffer.all_size() < 500:
-                sleep(0.1)
-            while all_group.is_paired("0_running") or all_group.is_paired("1_running"):
-                ddpg_apex.update()
-                default_logger.info(f"Updated")
-            return True
+            else:
+                # wait for some samples
+                if episode.get() > 200:
+                    for _ in range(100):
+                        ddpg_apex.update()
+                    default_logger.info("Updated 100 times.")
+
+            training_group.barrier()
+            if training_group.is_paired("solved"):
+                return True
 
     raise RuntimeError("DDPG-Apex Training failed.")
20 changes: 17 additions & 3 deletions test/frame/algorithms/test_ars.py
@@ -266,6 +266,8 @@ def test_config_init(_):
     @run_multi(expected_results=[True, True, True], timeout=1800)
     @setup_world
     def test_full_train(rank):
+        training_group = get_world().create_rpc_group("training", ["0", "1", "2"])
+
         c = TestARS.c
         ars = TestARS.ars("cpu", t.float32)
 
@@ -274,11 +276,15 @@ def test_full_train(rank):
         reward_fulfilled = Counter()
         smoother = Smooth()
         terminal = False
 
         env = c.env
-        env.seed(0)
+        env.seed(rank)
+
+        # for cpu usage viewing
+        default_logger.info(f"{rank}, pid {os.getpid()}")
+
+        # make sure all things are initialized.
+        training_group.barrier()
 
         while episode < c.max_episodes:
             episode.count()
@@ -312,8 +318,16 @@ def test_full_train(rank):
                 reward_fulfilled.count()
                 if reward_fulfilled >= c.solved_repeat:
                     default_logger.info("Environment solved!")
-                    raise SafeExit
+                    try:
+                        training_group.pair(f"solved", True)
+                    except KeyError:
+                        # already solved in another process
+                        pass
             else:
                 reward_fulfilled.reset()
 
+            training_group.barrier()
+            if training_group.is_paired("solved"):
+                return True
+
     raise RuntimeError("ARS Training failed.")
58 changes: 26 additions & 32 deletions test/frame/algorithms/test_impala.py
@@ -331,6 +331,8 @@ def test_config_init(rank):
     @run_multi(expected_results=[True, True, True], timeout=1800)
     @setup_world
     def test_full_train(rank):
+        training_group = get_world().create_rpc_group("training", ["0", "1", "2"])
+
         c = TestIMPALA.c
         impala = TestIMPALA.impala("cpu", t.float32)
 
@@ -342,22 +344,19 @@ def test_full_train(rank):
         reward_fulfilled = Counter()
         smoother = Smooth()
         terminal = False
 
         env = c.env
-        env.seed(0)
-        world = get_world()
-        all_group = world.create_rpc_group("all", ["0", "1", "2"])
-        all_group.pair(f"{rank}_running", True)
+        env.seed(rank)
+
+        # make sure all things are initialized.
+        training_group.barrier()
 
         # for cpu usage viewing
         default_logger.info(f"{rank}, pid {os.getpid()}")
-        if rank == 0:
-            all_group.pair("episode", episode)
+
+        while episode < c.max_episodes:
+            episode.count()
 
-        if rank in (0, 1):
-            while episode < c.max_episodes:
-                # wait for trainer to keep up
-                sleep(0.2)
-                episode.count()
+            if rank in (0, 1):
                 # batch size = 1
                 total_reward = 0
                 state = t.tensor(env.reset(), dtype=t.float32)
@@ -400,27 +399,22 @@ def test_full_train(rank):
                     reward_fulfilled.count()
                     if reward_fulfilled >= c.solved_repeat:
                         default_logger.info("Environment solved!")
-
-                        all_group.unpair(f"{rank}_running")
-                        while all_group.is_paired("0_running") or all_group.is_paired(
-                            "1_running"
-                        ):
-                            # wait for all workers to join
-                            sleep(1)
-                        # wait for trainer
-                        sleep(5)
-                        return True
+                        try:
+                            training_group.pair(f"solved", True)
+                        except KeyError:
+                            # already solved in another process
+                            pass
                 else:
                     reward_fulfilled.reset()
-        else:
-            # wait for some samples
-            # Note: the number of entries in buffer means "episodes"
-            # rather than steps here!
-            while impala.replay_buffer.all_size() < 5:
-                sleep(0.1)
-            while all_group.is_paired("0_running") or all_group.is_paired("1_running"):
-                impala.update()
-                default_logger.info("Updated")
-            return True
+            else:
+                # wait for some samples
+                if episode.get() > 200:
+                    for _ in range(100):
+                        impala.update()
+                    default_logger.info("Updated 100 times.")
+
+            training_group.barrier()
+            if training_group.is_paired("solved"):
+                return True
 
     raise RuntimeError("IMPALA Training failed.")
