Add Bayes NN
hzxsnczpku committed Dec 29, 2017
1 parent 16d7c3e commit 954dfba
Showing 8 changed files with 257 additions and 6 deletions.
11 changes: 11 additions & 0 deletions basic_utils/exploration_noise.py
@@ -52,6 +52,17 @@ def noise(self):
return self.state


class NoNoise_Exploration:
def __init__(self):
self.extra_info = []

def process_action(self, a):
return np.argmax(a), {}

def reset(self):
pass


class EpsilonGreedy_Exploration:
"""
The epsilon greedy noise.
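A minimal usage sketch of the new NoNoise_Exploration class, assuming the module is importable as basic_utils.exploration_noise and that numpy is imported as np, as elsewhere in the file. process_action simply returns the greedy action and an empty info dict; the Bayesian DQN example further down pairs it with weight sampling, so exploration comes from the sampled Q-network rather than from action-space noise.

import numpy as np

from basic_utils.exploration_noise import NoNoise_Exploration

noise = NoNoise_Exploration()
q_values = np.array([0.1, 0.7, 0.2])    # Q-value estimates for three actions
action, info = noise.process_action(q_values)
print(action, info)                     # 1 {}  -- greedy action, no extra info
noise.reset()                           # no-op, kept for interface compatibility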
17 changes: 17 additions & 0 deletions basic_utils/utils.py
@@ -292,6 +292,23 @@ def set_flat_params_to(model, flat_params):
prev_ind += flat_size


def get_flat_grads_from(model):
"""
Get the flattened gradients from the model.
Args:
model: the model from which the gradients are collected
Return:
flat_grads: the flattened gradients
"""
grads = []
for param in model.parameters():
grads.append(param.grad.data.view(-1))
flat_grads = torch.cat(grads)
return flat_grads


def set_flat_grads_to(model, flat_grads):
"""
Set the flattened gradients to the model.
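A short sketch of how get_flat_grads_from pairs with set_flat_grads_to (whose body is truncated above and is assumed to mirror set_flat_params_to, copying slices of the flat vector back into each parameter's .grad). The toy model and loss are only illustrative, and the snippet assumes a PyTorch version where plain tensors carry autograd:

import torch
import torch.nn as nn

from basic_utils.utils import get_flat_grads_from, set_flat_grads_to

model = nn.Linear(4, 2)
loss = model(torch.ones(1, 4)).pow(2).sum()   # toy loss, just to populate .grad
model.zero_grad()
loss.backward()

flat_grads = get_flat_grads_from(model)       # single 1-D tensor of all gradients
set_flat_grads_to(model, 0.5 * flat_grads)    # write transformed gradients back

This read-transform-write pattern is what Bayesian_Q_Optimizer below relies on: backpropagate through the sampled network, read its flat gradient, transform it, and write the result into the mean and std networks before their Adam steps.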
69 changes: 69 additions & 0 deletions examples/CartPole/train_CartPole_BDQN.py
@@ -0,0 +1,69 @@
from train import *
from models.net_builder import *
from basic_utils.env_wrapper import Vec_env_wrapper
from models.agents import *
from basic_utils.options import *
from basic_utils.exploration_noise import *


def train_CartPole_DQN(load_model=False, render=False, save_every=None, double=False, prioritized=False):
env = Vec_env_wrapper(name='CartPole-v1', consec_frames=1, running_stat=True)
action_space = env.action_space
observation_space = env.observation_space

net = MLPs_q(observation_space, action_space, net_topology_q_vec)
mean_net = MLPs_q(observation_space, action_space, net_topology_q_vec)
std_net = MLPs_q(observation_space, action_space, net_topology_q_vec)
target_net = MLPs_q(observation_space, action_space, net_topology_q_vec)
target_mean_net = MLPs_q(observation_space, action_space, net_topology_q_vec)
target_std_net = MLPs_q(observation_space, action_space, net_topology_q_vec)
noise = NoNoise_Exploration()

if use_cuda:
net.cuda()
mean_net.cuda()
std_net.cuda()
target_net.cuda()
target_mean_net.cuda()
target_std_net.cuda()

agent = Bayesian_DQN_Agent(net,
mean_net,
std_net,
target_net,
target_mean_net,
target_std_net,
alpha=1,
beta=0,
gamma=0.95,
lr=1e-3,
scale=1e-3,
update_target_every=500,
get_info=True)

if prioritized:
memory = PrioritizedReplayBuffer(memory_cap=2000,
batch_size_q=64,
alpha=0.8,
beta=0.6)
else:
memory = ReplayBuffer(memory_cap=2000, batch_size_q=64)

if load_model:
agent.load_model("./save_model/" + env.name + "_" + agent.name)

t = Mem_Trainer(agent=agent,
env=env,
memory=memory,
n_worker=1,
step_num=1,
rand_explore_len=1000,
save_every=save_every,
render=render,
print_every=50,
noise=noise)
t.train()


if __name__ == '__main__':
train_CartPole_DQN()
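
The script mirrors train_CartPole_DQN.py but builds three online networks (a point-estimate net plus mean and std networks) together with their target copies, and swaps epsilon-greedy exploration for NoNoise_Exploration, leaving exploration to the weight sampling inside QValueFunction_Bayesian. Note that the function keeps the train_CartPole_DQN name and an unused double flag. A hedged invocation sketch, assuming the script is run with train.py and the basic_utils/ and models/ packages on the import path, as the star imports at the top imply:

from train_CartPole_BDQN import train_CartPole_DQN

train_CartPole_DQN(load_model=False, render=False, save_every=100, prioritized=True)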
7 changes: 6 additions & 1 deletion examples/CartPole/train_CartPole_DQN.py
@@ -25,7 +25,12 @@ def train_CartPole_DQN(load_model=False, render=False, save_every=None, double=F
if double:
agent = Double_DQN_Agent(net=net, target_net=target_net, gamma=0.95)
else:
agent = DQN_Agent(net=net, target_net=target_net, gamma=0.95)
agent = DQN_Agent(net=net,
target_net=target_net,
gamma=0.95,
lr=1e-3,
update_target_every=500,
get_info=True)

if prioritized:
memory = PrioritizedReplayBuffer(memory_cap=2000, batch_size_q=64)
9 changes: 5 additions & 4 deletions examples/CartPole/train_CartPole_TRPO.py
@@ -6,10 +6,11 @@


def train_CartPole_TRPO(load_model=False, render=False, save_every=None):
env = Vec_env_wrapper(name='CartPole-v1', consec_frames=1, running_stat=True)
torch.manual_seed(2)
env = Vec_env_wrapper(name='MountainCarContinuous-v0', consec_frames=1, running_stat=True)
ob_space = env.observation_space

probtype = Categorical(env.action_space)
probtype = DiagGauss(env.action_space)

pol_net = MLPs_pol(ob_space, net_topology_pol_vec, probtype.output_layers)
v_net = MLPs_v(ob_space, net_topology_v_vec)
@@ -25,7 +26,7 @@ def train_CartPole_TRPO(load_model=False, render=False, save_every=None):
epochs_v=10,
gamma=0.99,
cg_iters=10,
max_kl=0.003,
max_kl=0.01,
batch_size=256,
cg_damping=1e-3,
get_info=True)
@@ -36,7 +37,7 @@ def train_CartPole_TRPO(load_model=False, render=False, save_every=None):
t = Path_Trainer(agent,
env,
n_worker=10,
path_num=10,
path_num=1,
save_every=save_every,
render=render,
action_repeat=1,
45 changes: 44 additions & 1 deletion models/agents.py
@@ -15,6 +15,7 @@ class BasicAgent:
"""
This is the abstract class of the agent.
"""

def act(self, ob_no):
"""
Get the action given the observation.
@@ -402,6 +403,48 @@ def __init__(self,
Value_Based_Agent.__init__(self, baseline=baseline, gamma=gamma, double=False)


# ================================================================
# Bayesian Deep Q Learning
# ================================================================
class Bayesian_DQN_Agent(Value_Based_Agent):
def __init__(self,
net,
mean_net,
std_net,
target_net,
target_mean_net,
target_std_net,
alpha=1,
beta=1e-4,
gamma=0.99,
lr=1e-3,
scale=1e-3,
update_target_every=500,
get_info=True):
optimizer = Bayesian_Q_Optimizer(net=net,
mean_net=mean_net,
std_net=std_net,
lr=lr,
alpha=alpha,
beta=beta,
scale=scale,
get_data=get_info)

baseline = QValueFunction_Bayesian(net=net,
mean_net=mean_net,
std_net=std_net,
target_net=target_net,
target_mean_net=target_mean_net,
target_std_net=target_std_net,
optimizer=optimizer,
scale=scale,
tau=0.01,
update_target_every=update_target_every)

self.name = 'Bayesian_DQN_Agent'
Value_Based_Agent.__init__(self, baseline=baseline, gamma=gamma, double=False)


# ================================================================
# Double Deep Q Learning
# ================================================================
@@ -486,4 +529,4 @@ def __init__(self,
update_target_every=update_target_every)

self.name = 'DDPG_Agent'
Deterministic_Policy_Based_Agent.__init__(self, policy=policy, baseline=baseline, gamma=gamma)
Deterministic_Policy_Based_Agent.__init__(self, policy=policy, baseline=baseline, gamma=gamma)
47 changes: 47 additions & 0 deletions models/baselines.py
@@ -80,3 +80,50 @@ def load_model(self, name):
net = torch.load(name + "_baseline.pkl")
self.net.load_state_dict(net.state_dict())
self.target_net.load_state_dict(net.state_dict())


class QValueFunction_Bayesian:
def __init__(self, net, mean_net, std_net, target_net, target_mean_net, target_std_net, optimizer, tau=0.01,
scale=1e-3, update_target_every=None):
self.net = net
self.mean_net = mean_net
self.std_net = std_net
self.target_net = target_net
self.scale = scale
self.target_mean_net = target_mean_net
self.target_std_net = target_std_net
self.optimizer = optimizer
self.target_updater_mean = Target_updater(self.mean_net, self.target_mean_net, tau, update_target_every)
self.target_updater_std = Target_updater(self.std_net, self.target_std_net, tau, update_target_every)

def predict(self, ob_no, target=False):
observations = turn_into_cuda(np_to_var(np.array(ob_no)))
if not target:
mean = get_flat_params_from(self.mean_net)
std = torch.log(1 + torch.exp(get_flat_params_from(self.std_net)))
sample_weight = mean + torch.randn(mean.size()) * std * self.scale
set_flat_params_to(self.net, sample_weight)
return self.net(observations).data.cpu().numpy()
else:
mean = get_flat_params_from(self.target_mean_net)
std = torch.log(1 + torch.exp(get_flat_params_from(self.target_std_net)))
sample_weight = mean + torch.randn(mean.size()) * std * self.scale
set_flat_params_to(self.target_net, sample_weight)
return self.target_net(observations).data.cpu().numpy()

def act(self, ob_no):
return self.predict(ob_no)

def fit(self, paths):
stat = self.optimizer(paths)
self.target_updater_mean.update()
self.target_updater_std.update()
return stat

def save_model(self, name):
torch.save(self.net, name + "_baseline.pkl")

def load_model(self, name):
net = torch.load(name + "_baseline.pkl")
self.net.load_state_dict(net.state_dict())
self.target_net.load_state_dict(net.state_dict())
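
Written out, the sampling performed in QValueFunction_Bayesian.predict (and reused by act) is, with mu the flattened parameters of mean_net, rho those of std_net, and scale the constructor argument:

    sigma = log(1 + exp(rho))                      # softplus keeps sigma positive
    w     = mu + scale * eps * sigma,   eps ~ N(0, I)

Each call draws a fresh w and loads it into net (or into target_net when target=True) before the forward pass, so greedy action selection on these Q-values behaves like Thompson sampling over Q-functions; this is presumably why the BDQN training script pairs the agent with NoNoise_Exploration.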
58 changes: 58 additions & 0 deletions models/optimizers.py
@@ -502,6 +502,64 @@ def __call__(self, path):
return None, {"td_err": td_err.data.cpu().numpy()}


# ================================================================
# Bayesian Q-Learning Optimizer
# ================================================================
class Bayesian_Q_Optimizer(Optimizer):
def __init__(self, net, mean_net, std_net, lr, alpha, beta, scale, get_data=True):
self.net = net
self.mean_net = mean_net
self.std_net = std_net
self.alpha = alpha
self.beta = beta
self.scale = scale
self.optimizer_mean = optim.Adam(params=self.mean_net.parameters(), lr=lr)
self.optimizer_std = optim.Adam(params=self.std_net.parameters(), lr=lr)
self.get_data = get_data

def _derive_info(self, observations, y_targ, actions):
y_pred = self.net(observations).gather(1, actions.long())
explained_var = 1 - torch.var(y_targ - y_pred) / torch.var(y_targ)
loss = (y_targ - y_pred).pow(2).mean()
info = {'explained_var': explained_var.data[0], 'loss': loss.data[0]}
return info

def __call__(self, path):
observations = turn_into_cuda(path["observation"])
actions = turn_into_cuda(path["action"])
weights = turn_into_cuda(path["weights"]) if "weights" in path else None
y_targ = turn_into_cuda(path['y_targ'])

if self.get_data:
info_before = self._derive_info(observations, y_targ, actions)

mean = get_flat_params_from(self.mean_net)
rho = get_flat_params_from(self.std_net)
std = torch.log(1 + torch.exp(rho))
epsilon = torch.randn(mean.size()) * self.scale
sample_weight = epsilon * std + mean
set_flat_params_to(self.net, sample_weight)
td_err = torch.abs(self.net(observations).gather(1, actions.long()) - y_targ)
loss = (td_err.pow(2) * weights).sum() if weights is not None else td_err.pow(2).mean()

self.net.zero_grad()
loss.backward()
grad = get_flat_grads_from(self.net)
grad_mean = self.alpha * grad + self.beta * sample_weight - (mean - sample_weight) / (std.pow(2) + 1e-7)
grad_std = self.alpha * epsilon * grad + self.beta * epsilon * sample_weight - 1 / std + (sample_weight - mean).pow(2) / (std.pow(3) + 1e-7)
grad_std /= 1 + torch.exp(-rho)
set_flat_grads_to(self.mean_net, grad_mean)
set_flat_grads_to(self.std_net, grad_std)
self.optimizer_mean.step()
self.optimizer_std.step()

if self.get_data:
info_after = self._derive_info(observations, y_targ, actions)
return merge_before_after(info_before, info_after), {"td_err": td_err.data.cpu().numpy()}

return None, {"td_err": td_err.data.cpu().numpy()}


# ================================================================
# DDPG Optimizer
# ================================================================
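Written out, the update applied in Bayesian_Q_Optimizer.__call__ is the following, with L the (optionally importance-weighted) TD loss on the sampled network, g = dL/dw its flat gradient, sigma = log(1 + exp(rho)), eps the scaled Gaussian noise, and w = mu + eps * sigma the sampled weights (the 1e-7 terms in the code are only numerical guards):

    grad_mu  = alpha * g + beta * w - (mu - w) / sigma^2
    grad_rho = ( alpha * eps * g + beta * eps * w - 1/sigma + (w - mu)^2 / sigma^3 ) * sigmoid(rho)

where sigmoid(rho) = 1 / (1 + exp(-rho)) is d(sigma)/d(rho) for the softplus parameterisation. The two flattened gradients are written into mean_net and std_net with set_flat_grads_to and then consumed by their separate Adam optimizers. Reading the beta terms as a prior/weight-decay contribution and the remaining terms as derivatives of a Gaussian log-density in mu and sigma makes this a Bayes-by-Backprop-style reparameterised update, but that reading is an assumption, not something stated in the commit.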
