PPO
hzxsnczpku committed Feb 8, 2018
1 parent 3fcd220 commit b0764bf
Showing 2 changed files with 49 additions and 21 deletions.
56 changes: 43 additions & 13 deletions examples/CartPole/train_CartPole_PPO.py
@@ -5,7 +5,7 @@
from basic_utils.options import *


def train_CartPole_adapted_PPO(load_model=False, render=False, save_every=None):
def train_CartPole_adapted_PPO(load_model=False, render=False, save_every=None, gamma=0.99, lam=0.98):
# set seed
torch.manual_seed(2)

@@ -28,8 +28,6 @@ def train_CartPole_adapted_PPO(load_model=False, render=False, save_every=None):
probtype,
epochs_v=10,
epochs_p=10,
lam=0.98,
gamma=0.99,
kl_target=0.003,
lr_updater=9e-4,
lr_optimizer=1e-3,
@@ -39,10 +37,20 @@ def train_CartPole_adapted_PPO(load_model=False, render=False, save_every=None):
beta_range=(1 / 35.0, 35.0),
kl_cutoff_coeff=50.0,
get_info=True)

if load_model:
agent.load_model("./save_model/" + env.name + "_" + agent.name)

# set data processor
single_processors = [
Scale_Reward(1 - gamma),
Calculate_Return(gamma),
Predict_Value(agent.baseline),
Calculate_Generalized_Advantage(gamma, lam),
Extract_Item_By_Name(["observation", "action", "advantage", "return"]),
Concatenate_Paths()
]
processor = Ensemble(single_processors)

# set data generator
generator = Parallel_Path_Data_Generator(agent=agent,
env=env,
@@ -54,17 +62,23 @@ def train_CartPole_clip_PPO(load_model=False, render=False, save_every=None):
t = Path_Trainer(agent,
env,
data_generator=generator,
data_processor=processor,
save_every=save_every,
print_every=10)
t.train()


def train_CartPole_clip_PPO(load_model=False, render=False, save_every=None):
def train_CartPole_clip_PPO(load_model=False, render=False, save_every=None, gamma=0.99, lam=0.98):
# set seed
torch.manual_seed(2)

# set environment
env = Vec_env_wrapper(name='CartPole-v1', consec_frames=1, running_stat=True)
ob_space = env.observation_space

probtype = Categorical(env.action_space)

# set neural network
pol_net = MLPs_pol(ob_space, net_topology_pol_vec, probtype.output_layers)
v_net = MLPs_v(ob_space, net_topology_v_vec)
if use_cuda:
@@ -76,8 +90,6 @@ def train_CartPole_clip_PPO(load_model=False, render=False, save_every=None):
probtype,
epochs_v=10,
epochs_p=10,
lam=0.98,
gamma=0.99,
kl_target=0.003,
lr_updater=9e-4,
lr_optimizer=1e-3,
@@ -86,20 +98,38 @@ def train_CartPole_clip_PPO(load_model=False, render=False, save_every=None):
clip_range=(0.05, 0.3),
epsilon=0.2,
get_info=True)

if load_model:
agent.load_model("./save_model/" + env.name + "_" + agent.name)

# set data processor
single_processors = [
Scale_Reward(1 - gamma),
Calculate_Return(gamma),
Predict_Value(agent.baseline),
Calculate_Generalized_Advantage(gamma, lam),
Extract_Item_By_Name(["observation", "action", "advantage", "return"]),
Concatenate_Paths()
]
processor = Ensemble(single_processors)

# set data generator
generator = Parallel_Path_Data_Generator(agent=agent,
env=env,
n_worker=10,
path_num=1,
action_repeat=1,
render=render)

# set trainer
t = Path_Trainer(agent,
env,
n_worker=10,
path_num=10,
data_generator=generator,
data_processor=processor,
save_every=save_every,
render=render,
action_repeat=1,
print_every=10)
t.train()


if __name__ == '__main__':
train_CartPole_adapted_PPO()
# train_CartPole_adapted_PPO()
train_CartPole_clip_PPO()
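
For reference, a minimal standalone sketch of what a GAE step such as Calculate_Generalized_Advantage(gamma, lam) in the processor pipeline above is assumed to compute for one finished path. The function below is an illustration under that assumption, not this repository's implementation:

import numpy as np


def generalized_advantage(rewards, values, gamma=0.99, lam=0.98):
    """Illustrative GAE(gamma, lam) for one finished path (hypothetical helper).

    rewards: array of shape [T]
    values:  array of shape [T + 1]; value predictions with a trailing
             bootstrap value (use 0.0 after a terminal state).
    """
    T = len(rewards)
    advantages = np.zeros(T)
    gae = 0.0
    for t in reversed(range(T)):
        # TD residual: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        # A_t = delta_t + gamma * lam * A_{t+1}
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    # lambda-returns (advantages plus the baseline), commonly used as value targets
    returns = advantages + values[:-1]
    return advantages, returns
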
14 changes: 6 additions & 8 deletions models/agents.py
@@ -213,6 +213,8 @@ def __init__(self,
# Advantage Actor-Critic
# ================================================================
class A2C_Agent(Policy_Based_Agent):
name = 'A2C_Agent'

def __init__(self,
pol_net,
v_net,
@@ -238,22 +240,21 @@ def __init__(self,
policy = StochPolicy(net=pol_net, probtype=probtype, updater=updater)
baseline = ValueFunction(net=v_net, optimizer=optimizer)

self.name = 'A2C_Agent'
Policy_Based_Agent.__init__(self, baseline=baseline, policy=policy)


# ================================================================
# Proximal Policy Optimization
# ================================================================
class PPO_adapted_Agent(Policy_Based_Agent):
name = 'PPO_adapted_Agent'

def __init__(self,
pol_net,
v_net,
probtype,
epochs_v=10,
epochs_p=10,
lam=0.98,
gamma=0.99,
kl_target=0.003,
lr_updater=9e-4,
lr_optimizer=1e-3,
@@ -282,8 +283,7 @@ def __init__(self,
policy = StochPolicy(net=pol_net, probtype=probtype, updater=updater)
baseline = ValueFunction(net=v_net, optimizer=optimizer)

self.name = 'PPO_adapted_Agent'
Policy_Based_Agent.__init__(self, baseline=baseline, policy=policy, gamma=gamma, lam=lam)
Policy_Based_Agent.__init__(self, baseline=baseline, policy=policy)


class PPO_clip_Agent(Policy_Based_Agent):
@@ -293,8 +293,6 @@ def __init__(self,
probtype,
epochs_v=10,
epochs_p=10,
lam=0.98,
gamma=0.99,
kl_target=0.003,
lr_updater=9e-4,
lr_optimizer=1e-3,
@@ -321,7 +319,7 @@ def __init__(self,
baseline = ValueFunction(net=v_net, optimizer=optimizer)

self.name = 'PPO_clip_Agent'
Policy_Based_Agent.__init__(self, baseline=baseline, policy=policy, gamma=gamma, lam=lam)
Policy_Based_Agent.__init__(self, baseline=baseline, policy=policy)


# ================================================================
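
For reference, a hypothetical sketch of the two surrogate objectives the agents above are assumed to optimize: an adaptive-KL penalty (kl_target, kl_cutoff_coeff, with beta adapted inside beta_range between updates) for PPO_adapted_Agent, and a clipped probability ratio (epsilon) for PPO_clip_Agent. The function names and tensor arguments below are illustrative, not this repository's API:

import torch


def ppo_clip_surrogate(new_logp, old_logp, advantages, epsilon=0.2):
    """Hypothetical clipped PPO objective, returned as a loss to minimize."""
    ratio = torch.exp(new_logp - old_logp)                           # r_t(theta)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages
    return -torch.min(unclipped, clipped).mean()


def ppo_kl_penalty_surrogate(new_logp, old_logp, advantages, kl, beta,
                             kl_target=0.003, kl_cutoff_coeff=50.0):
    """Hypothetical adaptive-KL PPO objective.

    kl is the mean KL divergence between the old and new policies (a scalar),
    and beta is the current penalty coefficient, adapted between updates so
    that kl stays near kl_target (kept within beta_range in the agent above).
    """
    ratio = torch.exp(new_logp - old_logp)
    loss = -(ratio * advantages).mean() + beta * kl
    if kl > 2.0 * kl_target:
        # extra quadratic penalty when the KL runs far past the target
        loss = loss + kl_cutoff_coeff * (kl - 2.0 * kl_target) ** 2
    return loss
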
