-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
954dfba
commit f5f17dc
Showing
77 changed files
with
1,315 additions
and
38 deletions.
There are no files selected for viewing
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from train import * | ||
from models.net_builder import * | ||
from basic_utils.env_wrapper import Vec_env_wrapper | ||
from models.agents import * | ||
from basic_utils.options import * | ||
from basic_utils.exploration_noise import * | ||
|
||
|
||
def train_CartPole_DQN(load_model=False, render=False, save_every=None, double=False, prioritized=False):
    """Train a ``Bayesian_DQN_Agent`` on the 'Acrobot-v1' environment.

    Note that, despite the function name, the environment created here is
    Acrobot-v1.

    Args:
        load_model: When true, restore the agent from
            "./save_model/<env.name>_<agent.name>" before training starts.
        render: Passed through unchanged to ``Mem_Trainer``.
        save_every: Passed through unchanged to ``Mem_Trainer``.
        double: Accepted for call compatibility; not read in this function.
        prioritized: When true, a ``PrioritizedReplayBuffer`` is used instead
            of the plain ``ReplayBuffer``.
    """
    # Fixed seeds for both torch and numpy.
    torch.manual_seed(8833)
    np.random.seed(8833)

    env = Vec_env_wrapper(name='Acrobot-v1', consec_frames=1, running_stat=False, seed=23333)
    act_space, obs_space = env.action_space, env.observation_space

    # Six networks with the same topology, created in the order the agent
    # takes them: net, mean, std, target, target mean, target std.
    q_nets = [MLPs_q(obs_space, act_space, net_topology_q_vec) for _ in range(6)]
    exploration = NoNoise_Exploration()

    if use_cuda:
        for q_net in q_nets:
            q_net.cuda()

    agent = Bayesian_DQN_Agent(
        *q_nets,
        alpha=1,
        beta=1e-4,
        gamma=0.95,
        lr=2e-3,
        scale=1e-4,
        update_target_every=1000,
        get_info=True,
    )

    # Both buffer flavours share capacity and batch size.
    buffer_kwargs = dict(memory_cap=10000, batch_size_q=64)
    if prioritized:
        memory = PrioritizedReplayBuffer(alpha=0.8, beta=0.6, **buffer_kwargs)
    else:
        memory = ReplayBuffer(**buffer_kwargs)

    if load_model:
        checkpoint = "./save_model/" + env.name + "_" + agent.name
        agent.load_model(checkpoint)

    trainer = Mem_Trainer(
        agent=agent,
        env=env,
        memory=memory,
        n_worker=1,
        step_num=1,
        rand_explore_len=1000,
        save_every=save_every,
        render=render,
        print_every=10,
        noise=exploration,
    )
    trainer.train()


if __name__ == '__main__':
    train_CartPole_DQN(save_every=5)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from train import * | ||
from models.net_builder import * | ||
from basic_utils.env_wrapper import Vec_env_wrapper | ||
from models.agents import * | ||
from basic_utils.options import * | ||
from basic_utils.exploration_noise import * | ||
|
||
|
||
def train_CartPole_DQN(load_model=False, render=False, save_every=None):
    """Train a ``DQN_Agent`` whose online network uses the dropout topology.

    The environment is 'Acrobot-v1' (the function name notwithstanding). The
    online network is built from ``net_topology_q_dropout_vec`` while the
    target network is built from ``net_topology_q_vec``.

    Args:
        load_model: When true, restore the agent from
            "./save_model/<env.name>_<agent.name>" before training starts.
        render: Passed through unchanged to ``Mem_Trainer``.
        save_every: Passed through unchanged to ``Mem_Trainer``.
    """
    torch.manual_seed(8833)

    env = Vec_env_wrapper(name='Acrobot-v1', consec_frames=1, running_stat=False, seed=23333)
    act_space, obs_space = env.action_space, env.observation_space

    online_net = MLPs_q(obs_space, act_space, net_topology_q_dropout_vec)
    # NOTE(review): presumably keeps dropout active while acting - confirm
    # against MLPs_q.
    online_net.train()
    frozen_net = MLPs_q(obs_space, act_space, net_topology_q_vec)
    exploration = NoNoise_Exploration()

    if use_cuda:
        for module in (online_net, frozen_net):
            module.cuda()

    agent = DQN_Agent(
        net=online_net,
        target_net=frozen_net,
        gamma=0.95,
        lr=1e-3,
        update_target_every=1000,
        get_info=True,
    )

    memory = ReplayBuffer(memory_cap=10000, batch_size_q=64)

    if load_model:
        checkpoint = "./save_model/" + env.name + "_" + agent.name
        agent.load_model(checkpoint)

    trainer = Mem_Trainer(
        agent=agent,
        env=env,
        memory=memory,
        n_worker=1,
        step_num=1,
        rand_explore_len=1000,
        save_every=save_every,
        render=render,
        print_every=10,
        noise=exploration,
    )
    trainer.train()


if __name__ == '__main__':
    train_CartPole_DQN(save_every=5)
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from train import * | ||
from models.net_builder import * | ||
from basic_utils.env_wrapper import Vec_env_wrapper | ||
from models.agents import * | ||
from basic_utils.options import * | ||
from basic_utils.exploration_noise import * | ||
|
||
|
||
def train_CartPole_DQN(load_model=False, render=False, save_every=None):
    """Train a ``DQN_Agent`` with epsilon-greedy exploration on 'Acrobot-v1'.

    The function name mentions CartPole, but the environment created here is
    Acrobot-v1. The exploration schedule is configured with identical initial
    and final epsilon values (0.05).

    Args:
        load_model: When true, restore the agent from
            "./save_model/<env.name>_<agent.name>" before training starts.
        render: Passed through unchanged to ``Mem_Trainer``.
        save_every: Passed through unchanged to ``Mem_Trainer``.
    """
    torch.manual_seed(8833)

    env = Vec_env_wrapper(name='Acrobot-v1', consec_frames=1, running_stat=False, seed=23333)
    act_space, obs_space = env.action_space, env.observation_space

    # Online network first, then its target copy; both share one topology.
    online_net, frozen_net = (
        MLPs_q(obs_space, act_space, net_topology_q_vec) for _ in range(2)
    )

    exploration = EpsilonGreedy_Exploration(
        action_n=act_space.n,
        explore_len=10000,
        init_epsilon=0.05,
        final_epsilon=0.05,
    )

    if use_cuda:
        for module in (online_net, frozen_net):
            module.cuda()

    agent = DQN_Agent(
        net=online_net,
        target_net=frozen_net,
        gamma=0.95,
        lr=1e-3,
        update_target_every=1000,
        get_info=True,
    )

    memory = ReplayBuffer(memory_cap=10000, batch_size_q=64)

    if load_model:
        checkpoint = "./save_model/" + env.name + "_" + agent.name
        agent.load_model(checkpoint)

    trainer = Mem_Trainer(
        agent=agent,
        env=env,
        memory=memory,
        n_worker=1,
        step_num=1,
        rand_explore_len=1000,
        save_every=save_every,
        render=render,
        print_every=10,
        noise=exploration,
    )
    trainer.train()


if __name__ == '__main__':
    train_CartPole_DQN(save_every=5)
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from train import * | ||
from models.net_builder import * | ||
from basic_utils.env_wrapper import Vec_env_wrapper | ||
from models.agents import * | ||
from basic_utils.options import * | ||
from basic_utils.exploration_noise import * | ||
|
||
|
||
def train_CartPole_DQN(load_model=False, render=False, save_every=None):
    """Train a ``DQN_Agent`` on 'Acrobot-v1' with ``NoNoise_Exploration``.

    The function name mentions CartPole, but the environment created here is
    Acrobot-v1. The agent is configured with ``update_target_every=500``.

    Args:
        load_model: When true, restore the agent from
            "./save_model/<env.name>_<agent.name>" before training starts.
        render: Passed through unchanged to ``Mem_Trainer``.
        save_every: Passed through unchanged to ``Mem_Trainer``.
    """
    torch.manual_seed(8833)

    env = Vec_env_wrapper(name='Acrobot-v1', consec_frames=1, running_stat=False, seed=23333)
    act_space, obs_space = env.action_space, env.observation_space

    # Online network first, then its target copy; both share one topology.
    online_net, frozen_net = (
        MLPs_q(obs_space, act_space, net_topology_q_vec) for _ in range(2)
    )
    exploration = NoNoise_Exploration()

    if use_cuda:
        for module in (online_net, frozen_net):
            module.cuda()

    agent = DQN_Agent(
        net=online_net,
        target_net=frozen_net,
        gamma=0.95,
        lr=1e-3,
        update_target_every=500,
        get_info=True,
    )

    memory = ReplayBuffer(memory_cap=10000, batch_size_q=64)

    if load_model:
        checkpoint = "./save_model/" + env.name + "_" + agent.name
        agent.load_model(checkpoint)

    trainer = Mem_Trainer(
        agent=agent,
        env=env,
        memory=memory,
        n_worker=1,
        step_num=1,
        rand_explore_len=1000,
        save_every=save_every,
        render=render,
        print_every=10,
        noise=exploration,
    )
    trainer.train()


if __name__ == '__main__':
    train_CartPole_DQN(save_every=5)
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
from train import * | ||
from models.net_builder import * | ||
from basic_utils.env_wrapper import Vec_env_wrapper | ||
from models.agents import * | ||
from basic_utils.options import * | ||
from basic_utils.exploration_noise import * | ||
|
||
|
||
def train_CartPole_DQN(load_model=False, render=False, save_every=None):
    """Train a ``Double_DQN_Agent`` with epsilon-greedy exploration.

    The environment is 'Acrobot-v1' (the function name notwithstanding). The
    exploration schedule is configured with identical initial and final
    epsilon values (0.05).

    Args:
        load_model: When true, restore the agent from
            "./save_model/<env.name>_<agent.name>" before training starts.
        render: Passed through unchanged to ``Mem_Trainer``.
        save_every: Passed through unchanged to ``Mem_Trainer``.
    """
    torch.manual_seed(8833)

    env = Vec_env_wrapper(name='Acrobot-v1', consec_frames=1, running_stat=False, seed=23333)
    act_space, obs_space = env.action_space, env.observation_space

    # Online network first, then its target copy; both share one topology.
    online_net, frozen_net = (
        MLPs_q(obs_space, act_space, net_topology_q_vec) for _ in range(2)
    )

    exploration = EpsilonGreedy_Exploration(
        action_n=act_space.n,
        explore_len=10000,
        init_epsilon=0.05,
        final_epsilon=0.05,
    )

    if use_cuda:
        for module in (online_net, frozen_net):
            module.cuda()

    agent = Double_DQN_Agent(
        net=online_net,
        target_net=frozen_net,
        gamma=0.95,
        lr=1e-3,
        update_target_every=1000,
        get_info=True,
    )

    memory = ReplayBuffer(memory_cap=10000, batch_size_q=64)

    if load_model:
        checkpoint = "./save_model/" + env.name + "_" + agent.name
        agent.load_model(checkpoint)

    trainer = Mem_Trainer(
        agent=agent,
        env=env,
        memory=memory,
        n_worker=1,
        step_num=1,
        rand_explore_len=1000,
        save_every=save_every,
        render=render,
        print_every=10,
        noise=exploration,
    )
    trainer.train()


if __name__ == '__main__':
    train_CartPole_DQN(save_every=5)
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from train import * | ||
from models.net_builder import * | ||
from basic_utils.env_wrapper import Vec_env_wrapper | ||
from models.agents import * | ||
from basic_utils.options import * | ||
from basic_utils.exploration_noise import * | ||
|
||
|
||
def train_CartPole_DQN(load_model=False, render=False, save_every=None):
    """Train a ``DQN_Agent`` backed by a ``PrioritizedReplayBuffer``.

    The environment is 'Acrobot-v1' (the function name notwithstanding).
    Exploration is epsilon-greedy with identical initial and final epsilon
    values (0.05); the agent learning rate here is 5e-3.

    Args:
        load_model: When true, restore the agent from
            "./save_model/<env.name>_<agent.name>" before training starts.
        render: Passed through unchanged to ``Mem_Trainer``.
        save_every: Passed through unchanged to ``Mem_Trainer``.
    """
    torch.manual_seed(8833)

    env = Vec_env_wrapper(name='Acrobot-v1', consec_frames=1, running_stat=False, seed=23333)
    act_space, obs_space = env.action_space, env.observation_space

    # Online network first, then its target copy; both share one topology.
    online_net, frozen_net = (
        MLPs_q(obs_space, act_space, net_topology_q_vec) for _ in range(2)
    )

    exploration = EpsilonGreedy_Exploration(
        action_n=act_space.n,
        explore_len=10000,
        init_epsilon=0.05,
        final_epsilon=0.05,
    )

    if use_cuda:
        for module in (online_net, frozen_net):
            module.cuda()

    agent = DQN_Agent(
        net=online_net,
        target_net=frozen_net,
        gamma=0.95,
        lr=5e-3,
        update_target_every=1000,
        get_info=True,
    )

    memory = PrioritizedReplayBuffer(
        memory_cap=10000,
        batch_size_q=64,
        alpha=0.8,
        beta=0.6,
    )

    if load_model:
        checkpoint = "./save_model/" + env.name + "_" + agent.name
        agent.load_model(checkpoint)

    trainer = Mem_Trainer(
        agent=agent,
        env=env,
        memory=memory,
        n_worker=1,
        step_num=1,
        rand_explore_len=1000,
        save_every=save_every,
        render=render,
        print_every=10,
        noise=exploration,
    )
    trainer.train()


if __name__ == '__main__':
    train_CartPole_DQN(save_every=5)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import numpy as np | ||
import pylab | ||
|
||
# Plot smoothed score curves for several agents on one figure.
#
# Each ``.npy`` file in ``names`` is loaded as a sequence of per-episode
# scores; ``distri`` holds the legend label for the file at the same index.

length = 50        # maximum number of episodes drawn per curve
MEAN_LENGTH = 10   # size of the trailing window used for mean / std

distri = ['DQN No Exploration', 'DQN Epsilon Greedy', 'Bayesian Thompson Sampling', 'Bayesian Dropout', 'Double DQN', 'Prioritized DQN']

pylab.title('Acrobot-v1')
pylab.xlabel("episode")
pylab.ylabel("score")
c = []  # line handles, in the same order as ``distri``

names = ['DQN_no_noise.npy', 'DQN_eg.npy', 'Bayesian_DQN_ts.npy', 'Bayesian_Dropout.npy', 'Double_DQN.npy', 'Prioritized_DQN.npy']
for dis in names:
    # Work on a float copy: the loop below writes window means back into the
    # array, and an integer-typed score file would silently truncate them.
    a = np.asarray(np.load(dis), dtype=float)
    b = []
    upper = []
    lower = []

    # The loop bound already limits every curve to ``length`` points.
    for i in range(min(a.shape[0], length)):
        # First MEAN_LENGTH episodes: everything seen so far, current one
        # included. Afterwards: the MEAN_LENGTH episodes before the current.
        if i < MEAN_LENGTH:
            window = a[0:i + 1]
        else:
            window = a[i - MEAN_LENGTH: i]
        mean = np.mean(window)
        std = np.std(window)

        # Replace a score lying more than three standard deviations below
        # the window mean, so it does not drag down later windows.
        if mean - a[i] > 3 * std:
            a[i] = mean

        b.append(mean)
        upper.append(mean + 0.5 * std)
        lower.append(mean - 0.5 * std)

    c += pylab.plot(b)
    # Shaded band of +/- half a standard deviation around the mean curve.
    pylab.fill_between(range(len(b)), lower, upper, alpha=0.3)

pylab.legend(c, distri, loc=2, fontsize='x-small')
pylab.show()
Oops, something went wrong.