update to new version of SAC with no state-value function (#272)

cpnota committed Apr 19, 2022
1 parent d3a537a commit 9762e7b
Showing 3 changed files with 45 additions and 61 deletions.
50 changes: 25 additions & 25 deletions all/agents/sac.py
@@ -20,7 +20,6 @@ class SAC(Agent):
         policy (DeterministicPolicy): An Approximation of a deterministic policy.
         q1 (QContinuous): An Approximation of the continuous action Q-function.
         q2 (QContinuous): An Approximation of the continuous action Q-function.
-        v (VNetwork): An Approximation of the state-value function.
         replay_buffer (ReplayBuffer): The experience replay buffer.
         discount_factor (float): Discount factor for future rewards.
         entropy_target (float): The desired entropy of the policy. Usually -env.action_space.shape[0]
@@ -32,9 +31,8 @@ class SAC(Agent):
 
     def __init__(self,
                  policy,
-                 q_1,
-                 q_2,
-                 v,
+                 q1,
+                 q2,
                  replay_buffer,
                  discount_factor=0.99,
                  entropy_target=-2.,
@@ -47,9 +45,8 @@ def __init__(self,
                  ):
         # objects
         self.policy = policy
-        self.v = v
-        self.q_1 = q_1
-        self.q_2 = q_2
+        self.q1 = q1
+        self.q2 = q2
         self.replay_buffer = replay_buffer
         self.logger = logger
         # hyperparameters
@@ -78,34 +75,37 @@ def _train(self):
             (states, actions, rewards, next_states, _) = self.replay_buffer.sample(self.minibatch_size)
 
             # compute targets for Q and V
-            _actions, _log_probs = self.policy.no_grad(states)
-            q_targets = rewards + self.discount_factor * self.v.target(next_states)
-            v_targets = torch.min(
-                self.q_1.target(states, _actions),
-                self.q_2.target(states, _actions),
-            ) - self.temperature * _log_probs
+            next_actions, next_log_probs = self.policy.no_grad(next_states)
+            q_targets = rewards + self.discount_factor * (torch.min(
+                self.q1.target(next_states, next_actions),
+                self.q2.target(next_states, next_actions),
+            ) - self.temperature * next_log_probs)
 
             # update Q and V-functions
-            self.q_1.reinforce(mse_loss(self.q_1(states, actions), q_targets))
-            self.q_2.reinforce(mse_loss(self.q_2(states, actions), q_targets))
-            self.v.reinforce(mse_loss(self.v(states), v_targets))
+            q1_loss = mse_loss(self.q1(states, actions), q_targets)
+            self.q1.reinforce(q1_loss)
+            q2_loss = mse_loss(self.q2(states, actions), q_targets)
+            self.q2.reinforce(q2_loss)
 
             # update policy
-            _actions2, _log_probs2 = self.policy(states)
-            loss = (-self.q_1(states, _actions2) + self.temperature * _log_probs2).mean()
+            new_actions, new_log_probs = self.policy(states)
+            q_values = self.q1(states, new_actions)
+            loss = -(q_values - self.temperature * new_log_probs).mean()
             self.policy.reinforce(loss)
-            self.q_1.zero_grad()
+            self.q1.zero_grad()
 
             # adjust temperature
-            temperature_grad = (_log_probs + self.entropy_target).mean()
+            temperature_grad = (new_log_probs + self.entropy_target).mean() * self.temperature
             self.temperature = max(0, self.temperature + self.lr_temperature * temperature_grad.detach())
 
             # additional debugging info
-            self.logger.add_loss('entropy', -_log_probs.mean())
-            self.logger.add_loss('v_mean', v_targets.mean())
-            self.logger.add_loss('r_mean', rewards.mean())
-            self.logger.add_loss('temperature_grad', temperature_grad)
-            self.logger.add_loss('temperature', self.temperature)
+            self.logger.add_info('entropy', -new_log_probs.mean())
+            self.logger.add_info('q_values', q_values.mean())
+            self.logger.add_loss('rewards', rewards.mean())
+            self.logger.add_info('normalized_q1_error', q1_loss / q_targets.var())
+            self.logger.add_info('normalized_q2_error', q2_loss / q_targets.var())
+            self.logger.add_info('temperature', self.temperature)
+            self.logger.add_info('temperature_grad', temperature_grad)
 
     def _should_train(self):
         self._frames_seen += 1
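The heart of this change is the new q_targets computation above: with the state-value network gone, the bootstrap target is built directly from the clipped double-Q estimate at the next state, minus the entropy penalty. The standalone sketch below mirrors that calculation outside the agent class; the function name, the toy callables, and the tensor shapes are illustrative assumptions, not library code.

    import torch

    def soft_q_targets(rewards, next_states, policy, q1_target, q2_target,
                       discount_factor=0.99, temperature=0.2):
        # Sample next actions and their log-probabilities from the current policy.
        next_actions, next_log_probs = policy(next_states)
        # Clipped double-Q: take the elementwise minimum of the two target critics,
        # then subtract the entropy penalty before bootstrapping.
        min_q = torch.min(
            q1_target(next_states, next_actions),
            q2_target(next_states, next_actions),
        )
        return rewards + discount_factor * (min_q - temperature * next_log_probs)

    # Toy check with stand-in callables and random tensors (illustrative only).
    policy = lambda s: (torch.tanh(s), -torch.ones(s.shape[0]))
    q = lambda s, a: (s * a).sum(dim=1)
    targets = soft_q_targets(torch.zeros(4), torch.randn(4, 2), policy, q, q)
    print(targets.shape)  # torch.Size([4])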
4 changes: 2 additions & 2 deletions all/logging/experiment.py
@@ -42,7 +42,7 @@ def add_loss(self, name, value, step="frame"):
         self._add_scalar("loss/" + name, value, step)
 
     def add_eval(self, name, value, step="frame"):
-        self._add_scalar("loss/" + name, value, step)
+        self._add_scalar("eval/" + name, value, step)
 
     def add_info(self, name, value, step="frame"):
         self._add_scalar("info/" + name, value, step)
@@ -108,7 +108,7 @@ def add_loss(self, name, value, step="frame"):
         self._add_scalar("loss/" + name, value, step)
 
     def add_eval(self, name, value, step="frame"):
-        self._add_scalar("loss/" + name, value, step)
+        self._add_scalar("eval/" + name, value, step)
 
     def add_info(self, name, value, step="frame"):
         self._add_scalar("info/" + name, value, step)
52 changes: 18 additions & 34 deletions all/presets/continuous/sac.py
@@ -9,15 +9,14 @@
 from all.memory import ExperienceReplayBuffer
 from all.presets.builder import PresetBuilder
 from all.presets.preset import Preset
-from all.presets.continuous.models import fc_q, fc_v, fc_soft_policy
+from all.presets.continuous.models import fc_q, fc_soft_policy
 
 
 default_hyperparameters = {
     # Common settings
     "discount_factor": 0.98,
     # Adam optimizer settings
     "lr_q": 1e-3,
-    "lr_v": 1e-3,
     "lr_pi": 1e-4,
     # Training settings
     "minibatch_size": 100,
@@ -33,7 +32,6 @@
     # Model construction
     "q1_model_constructor": fc_q,
     "q2_model_constructor": fc_q,
-    "v_model_constructor": fc_v,
     "policy_model_constructor": fc_soft_policy
 }
 
@@ -49,7 +47,6 @@ class SACContinuousPreset(Preset):
     Keyword Args:
         lr_q (float): Learning rate for the Q networks.
-        lr_v (float): Learning rate for the state-value networks.
         lr_pi (float): Learning rate for the policy network.
         minibatch_size (int): Number of experiences to sample in each training update.
         update_frequency (int): Number of timesteps per training update.
@@ -67,50 +64,38 @@ class SACContinuousPreset(Preset):
 
     def __init__(self, env, name, device, **hyperparameters):
         super().__init__(name, device, hyperparameters)
-        self.q_1_model = hyperparameters["q1_model_constructor"](env).to(device)
-        self.q_2_model = hyperparameters["q2_model_constructor"](env).to(device)
-        self.v_model = hyperparameters["v_model_constructor"](env).to(device)
+        self.q1_model = hyperparameters["q1_model_constructor"](env).to(device)
+        self.q2_model = hyperparameters["q2_model_constructor"](env).to(device)
         self.policy_model = hyperparameters["policy_model_constructor"](env).to(device)
         self.action_space = env.action_space
 
     def agent(self, logger=DummyLogger(), train_steps=float('inf')):
         n_updates = (train_steps - self.hyperparameters["replay_start_size"]) / self.hyperparameters["update_frequency"]
 
-        q_1_optimizer = Adam(self.q_1_model.parameters(), lr=self.hyperparameters["lr_q"])
-        q_1 = QContinuous(
-            self.q_1_model,
-            q_1_optimizer,
+        q1_optimizer = Adam(self.q1_model.parameters(), lr=self.hyperparameters["lr_q"])
+        q1 = QContinuous(
+            self.q1_model,
+            q1_optimizer,
             scheduler=CosineAnnealingLR(
-                q_1_optimizer,
-                n_updates
-            ),
-            logger=logger,
-            name='q_1'
-        )
-
-        q_2_optimizer = Adam(self.q_2_model.parameters(), lr=self.hyperparameters["lr_q"])
-        q_2 = QContinuous(
-            self.q_2_model,
-            q_2_optimizer,
-            scheduler=CosineAnnealingLR(
-                q_2_optimizer,
+                q1_optimizer,
                 n_updates
             ),
+            target=PolyakTarget(self.hyperparameters["polyak_rate"]),
             logger=logger,
-            name='q_2'
+            name='q1'
         )
 
-        v_optimizer = Adam(self.v_model.parameters(), lr=self.hyperparameters["lr_v"])
-        v = VNetwork(
-            self.v_model,
-            v_optimizer,
+        q2_optimizer = Adam(self.q2_model.parameters(), lr=self.hyperparameters["lr_q"])
+        q2 = QContinuous(
+            self.q2_model,
+            q2_optimizer,
             scheduler=CosineAnnealingLR(
-                v_optimizer,
+                q2_optimizer,
                 n_updates
             ),
             target=PolyakTarget(self.hyperparameters["polyak_rate"]),
             logger=logger,
-            name='v',
+            name='q2'
         )
 
         policy_optimizer = Adam(self.policy_model.parameters(), lr=self.hyperparameters["lr_pi"])
@@ -132,9 +117,8 @@ def agent(self, logger=DummyLogger(), train_steps=float('inf')):
 
         return TimeFeature(SAC(
             policy,
-            q_1,
-            q_2,
-            v,
+            q1,
+            q2,
             replay_buffer,
             temperature_initial=self.hyperparameters["temperature_initial"],
             entropy_target=(-self.action_space.shape[0] * self.hyperparameters["entropy_target_scaling"]),
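Based on the constructor and agent() signatures shown in this file, the preset can be used roughly as sketched below. This is a hypothetical usage example, not part of the diff; the make_env() helper and the hyperparameter override are assumptions.

    from all.presets.continuous.sac import SACContinuousPreset, default_hyperparameters

    # make_env() stands in for however you construct an ALL continuous environment wrapper.
    env = make_env()
    hyperparameters = {**default_hyperparameters, "lr_q": 3e-4}  # assumed override for illustration
    preset = SACContinuousPreset(env, 'sac', 'cuda', **hyperparameters)
    agent = preset.agent(train_steps=1_000_000)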