Question on formula of the continuous action #1

Open
dbsxdbsx opened this issue Mar 7, 2022 · 7 comments

dbsxdbsx commented Mar 7, 2022

First, thank you for the code accompanying the paper Discrete and Continuous Action Representation for Practical RL in Video Games.

Second, according to your code, all of the environments used in this project follow the 5th action-space architecture described in the paper: only ONE discrete action dimension, plus a continuous action space for EACH action in the discrete action space (if I am wrong, please tell me).

My question is about the loss formulas related to the continuous part:

# from calculating critic loss
min_qf_next_target = next_state_prob_d * (torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_prob_d * next_state_log_pi_c - alpha_d * next_state_log_pi_d)

# from calculating policy loss
policy_loss_c = (prob_d * (alpha * prob_d * log_pi_c - min_qf_pi)).sum(1).mean()

# from calculating temperature loss
alpha_loss = (-log_alpha * p_d * (p_d * lpi_c + target_entropy)).sum(1).mean()

In each of these formulas, you multiply the discrete distribution (next_state_prob_d, prob_d and p_d) with the continuous log-probabilities (next_state_log_pi_c, log_pi_c and lpi_c),
even though the same distribution already appears outside the parentheses.

Intuitively, I think there is no need to multiply the distribution with the corresponding continuous log-probabilities inside the parentheses.

I don't know whether I am wrong mathematically, so I am asking this question.

mch5048 commented Apr 25, 2022

Thanks for this awesome reimplementation! It really gave me a good intuition for adapting the paper for my work.

I agree with @dbsxdbsx about the weighting term in the continuous actor loss.
I found that removing the weighting term inside the parentheses does not harm performance.

Mathematically, the current implementation seems to "square" the weight on the entropy bonus term in the actor loss.
Thus, the effect of the entropy bonus is reduced (squaring a weight that is less than or equal to 1.0 shrinks it, e.g. 0.5 becomes 0.25).

Want to hear from the code author @nisheeth-golakiya about this.

Thank you!

mch5048 commented Apr 26, 2022

Oh, now I get the implementation.
I think the implementation of @nisheeth-golakiya is correct.

The prob_d or p_d term inside the parentheses is the weighting term used for computing the joint entropy in the paper.

The prob_d or p_d outside the parentheses is the probability of each discrete action, used to compute the 'expectation' in the actor loss or alpha loss.

Thus, the original implementation seems to be correct.

@dbsxdbsx you may verify this.

nisheeth-golakiya (Owner) commented

@dbsxdbsx Apologies for my late response.

@mch5048 and @dbsxdbsx I am glad that you're interested in my implementation!

@mch5048 has got it right.

To get the expected value of a quantity, we usually approximate it by sampling (recall that for continuous actions, we don't have an exact probability value for a given action). But for discrete actions, the probability of each action is readily available. Hence, instead of sampling, we multiply directly by the probability.
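To make this concrete, here is a tiny toy example (my own sketch, not code from this repo) of the two ways of taking an expectation:

import torch

probs = torch.tensor([0.2, 0.5, 0.3])   # pi_d(.|s): exact probabilities of 3 discrete actions
q_vals = torch.tensor([1.0, 2.0, 0.5])  # some per-action quantity, e.g. Q(s, d)

# Discrete case: the expectation is exact, no sampling needed.
exact_expectation = (probs * q_vals).sum()

# Continuous case: no exact probability per action, so we approximate by sampling.
dist = torch.distributions.Normal(0.0, 1.0)
samples = dist.sample((10000,))
mc_estimate = (samples ** 2).mean()     # Monte Carlo estimate of E[x^2] (true value: 1.0)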

The first section of the appendix of the paper reveals many implementation details.


In the following line,

min_qf_next_target = next_state_prob_d * (torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_prob_d * next_state_log_pi_c - alpha_d * next_state_log_pi_d)

alpha * next_state_prob_d * next_state_log_pi_c + alpha_d * next_state_log_pi_d corresponds to Eq. 2 of the paper.
next_state_prob_d outside the parentheses is there to get the expected value of the target (please note the .sum(1) here).
This line from the discrete-SAC implementation may make things clearer.
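In symbols, my reading of what this target computes (a transcription of the code, not a verbatim quote of the paper's equation) is:

$$
V(s') \approx \sum_{d} \pi_d(d \mid s')\Big[\min_{i=1,2} Q_i\big(s', a_c^{(d)}, d\big) - \alpha\, \pi_d(d \mid s')\log \pi_c\big(a_c^{(d)} \mid s', d\big) - \alpha_d \log \pi_d(d \mid s')\Big], \qquad a_c^{(d)} \sim \pi_c(\cdot \mid s', d)
$$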

For the policy loss of the continuous component,

policy_loss_c = (prob_d * (alpha * prob_d * log_pi_c - min_qf_pi)).sum(1).mean()

please have a look at the fourth point of the first section of the appendix, which says

... This update is essentially the same as in continuous SAC,
as in eq. 7 of (Haarnoja et al. 2018c), except that it is performed as a weighted average over all discrete actions ...

prob_d outside the parentheses computes this weighted average, which is effectively the expected value.
prob_d inside the parentheses is analogous to the entropy term of the continuous component. Another way to look at it: if you don't have prob_d inside the parentheses, you're treating every discrete action as equally likely, which is not the case.
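Written out, the weighted average that policy_loss_c implements (again my transcription of the code, not the paper's notation) is:

$$
J_{\pi_c} = \mathbb{E}_{s \sim \mathcal{D}}\Big[\sum_{d} \pi_d(d \mid s)\Big(\alpha\, \pi_d(d \mid s)\log \pi_c\big(a_c^{(d)} \mid s, d\big) - \min_{i} Q_i\big(s, a_c^{(d)}, d\big)\Big)\Big], \qquad a_c^{(d)} \sim \pi_c(\cdot \mid s, d)
$$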

Lastly, alpha_loss mimics the policy loss calculation.

Please let me know if you need more clarification.

nisheeth-golakiya (Owner) commented

Regarding the calculation of the discrete part of the policy loss, I would like to point out that the third point in the first section of the appendix says it is optimized by minimizing the KL-divergence between the distribution over discrete actions and the softmax of the Q-values with temperature alpha_d.
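For concreteness, a minimal sketch of what that KL-based update could look like (variable names and shapes are my own assumptions, not taken from either repo):

import torch
import torch.nn.functional as F

batch, n_discrete = 4, 3
prob_d = torch.softmax(torch.randn(batch, n_discrete), dim=1)  # pi_d(.|s)
min_qf_pi = torch.randn(batch, n_discrete)                     # min of the two Q heads, per discrete action
alpha_d = 0.2

# Target distribution: softmax of the Q-values with temperature alpha_d, treated as fixed.
target = F.softmax(min_qf_pi / alpha_d, dim=1).detach()

# KL(pi_d || softmax(Q / alpha_d)), averaged over the batch.
policy_loss_d = (prob_d * (torch.log(prob_d + 1e-8) - torch.log(target + 1e-8))).sum(1).mean()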

I have taken the liberty to modify this calculation according to this implementation found in discrete-SAC.

If you try the KL-divergence approach, do share the results :)

mch5048 commented Apr 27, 2022

Thanks for the clarification. I'll delve into the KL-divergence-minimization view of the algorithm soon.

dbsxdbsx (Author) commented Jun 5, 2022

@nisheeth-golakiya @mch5048, thanks for your attention and explanations on this topic.
Frankly, I still don't fully follow the implementation. Briefly, I have doubts about the continuous action part, but none about the discrete action part. Here is my reasoning in detail (if something is wrong, don't hesitate to point it out, thanks):

critic loss

First, let's take a look at the critic loss. Originally, the relevant part is this:

min_qf_next_target = next_state_prob_d * (torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_prob_d * next_state_log_pi_c - alpha_d * next_state_log_pi_d)
next_q_value = torch.Tensor(s_rewards).to(device) + (1 - torch.Tensor(s_dones).to(device)) * args.gamma * (min_qf_next_target.sum(1)).view(-1)

If the env only contains discrete actions, then no doubt it should be:

min_qf_next_target = next_state_prob_d * (torch.min(qf1_next_target, qf2_next_target) - alpha_d * next_state_log_pi_d)
next_q_value = torch.Tensor(s_rewards).to(device) + (1 - torch.Tensor(s_dones).to(device)) * args.gamma * (min_qf_next_target.sum(1)).view(-1)

Now let me put it in a more sensible way:

min_pred_next_state_score = torch.min(qf1_next_target, qf2_next_target)
d_entropy_bonus = -alpha_d * next_state_log_pi_d  # [batch_size, d_dim]; each element is the entropy bonus of the discrete action at that index
score_of_all_actions = min_pred_next_state_score + d_entropy_bonus
next_state_expected_score = (next_state_prob_d * score_of_all_actions).sum(1)  # sum(1) gives shape [batch_size,], i.e. an expectation scalar per batch element

next_q_value = torch.Tensor(s_rewards).to(device) + (1 - torch.Tensor(s_dones).to(device)) * args.gamma * next_state_expected_score.view(-1)

(Here I also changed some variable names; I hope that doesn't bother you.)
We know the target of the critic is exactly the soft state value of a specific state, in the context of maximum-entropy algorithms (like SAC or Soft Q-Learning). So, from my implementation, it is easy to see that the soft score of every sampleable discrete action is score_of_all_actions, which consists of the non-soft value min_pred_next_state_score plus the entropy bonus d_entropy_bonus, with alpha_d scaling it as a temperature.
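(As a reminder, the standard SAC soft state value I am referring to is

$$
V(s) = \mathbb{E}_{a \sim \pi(\cdot \mid s)}\big[Q(s, a) - \alpha \log \pi(a \mid s)\big],
$$

which in the pure discrete case can be computed exactly as the sum over actions weighted by prob_d.)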

At present, everything is fine, right?
Now, what if I want to give each discrete action another bonus? More precisely, what if we treat each continuous action as part of the corresponding sampleable discrete action, as if we were still choosing a discrete action at each state for the env?
Following this approach, even in a hybrid-action env we keep the simplicity of the pure discrete-action framework, and each continuous entropy bonus is still jointly tied to a specific discrete action, so nothing about the hybrid-action framework is broken:

min_pred_next_state_score = torch.min(qf1_next_target, qf2_next_target)
d_entropy_bonus = -alpha_d * next_state_log_pi_d  # [batch_size, d_dim]; entropy bonus of the discrete action at each index
c_entropy_bonus = -alpha * next_state_log_pi_c    # [batch_size, d_dim]; extra continuous entropy bonus within each discrete action
score_of_all_actions = min_pred_next_state_score + d_entropy_bonus + c_entropy_bonus  # per-action score predicted for the next state
next_state_expected_score = (next_state_prob_d * score_of_all_actions).sum(1)  # sum(1) gives shape [batch_size,], i.e. an expectation scalar per batch element

next_q_value = torch.Tensor(s_rewards).to(device) + (1 - torch.Tensor(s_dones).to(device)) * args.gamma * next_state_expected_score.view(-1)

Intuitively, c_entropy_bonus could carry any semantic meaning, or even be 0 with no meaning at all. Here, in SAC with hybrid actions, it is treated as an extra entropy bonus, as if we only needed to sample a single enumerable action like in a pure discrete-action env (even though it is actually a hybrid-action env).

Doesn't that make sense? A quick sanity check of the degenerate case is sketched below.
Just recall that in a pure continuous-action env the above needs no modification, except that d_entropy_bonus is 0 and the distribution next_state_prob_d becomes a scalar 1, since in a pure continuous-action env there is always only one (fake) discrete action to choose, with probability 100%.
So, in this way, the logic is consistent across pure discrete, pure continuous and hybrid action spaces: whatever form the action takes, we treat the sampled action as a single whole, and then make the sampled soft state-action score Q(s,a) as accurate as possible!
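Here is the sanity check mentioned above (a toy example with made-up numbers, just to show the degenerate pure-continuous case): with a single fake discrete action of probability 1, my target reduces to the ordinary continuous SAC target.

import torch

prob_d = torch.tensor([[1.0]])    # one fake discrete action, chosen with probability 100%
log_pi_d = torch.tensor([[0.0]])  # log(1) = 0, so the discrete entropy bonus vanishes
log_pi_c = torch.tensor([[-1.3]]) # log-prob of the sampled continuous action
min_q = torch.tensor([[2.0]])
alpha = alpha_d = 0.2

hybrid = (prob_d * (min_q - alpha_d * log_pi_d - alpha * log_pi_c)).sum(1)
pure_continuous = (min_q - alpha * log_pi_c).squeeze(1)
assert torch.allclose(hybrid, pure_continuous)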

Then, with this modification, I ran tests on the Platform env (I used this env for all the remaining experiments below) with the temperature alpha fixed at 0.2 for both the discrete and continuous parts. The result seems as good as the original.

And I mean no offense, but the original version seems complex, and I still don't follow the meaning of prob_d inside the parentheses.

actor loss

(Below I use "actor loss" instead of "policy loss", since it is more precise in the actor-critic framework.)
The original version:

actions_c, actions_d, log_pi_c, log_pi_d, prob_d = pg.get_action(s_obs, device)
qf1_pi = qf1.forward(s_obs, actions_c, device)
qf2_pi = qf2.forward(s_obs, actions_c, device)
min_qf_pi = torch.min(qf1_pi, qf2_pi)

policy_loss_d = (prob_d * (alpha_d * log_pi_d - min_qf_pi)).sum(1).mean()
policy_loss_c = (prob_d * (alpha * prob_d * log_pi_c - min_qf_pi)).sum(1).mean()
policy_loss = policy_loss_d + policy_loss_c

policy_optimizer.zero_grad()
policy_loss.backward()
policy_optimizer.step()

Again, we have no doubts about the pure discrete-action part; the question is about the jointly continuous part.
Recall again that in a pure continuous-action env the loss would be policy_loss_c = (alpha * log_pi_c - min_qf_pi).mean(); no sum(1) is needed since there is only one continuous component.

Suddenly, I realized that if we treat policy_loss_c as gradient ascent, SAC is exactly trying to maximize the soft state value, with min_qf_pi and log_pi_c reparametrized, since -policy_loss_c = soft state value = min_qf_pi - alpha * log_pi_c. I know that the original policy_loss_c formula comes from a KL divergence with the model to be trained on the left side. What a coincidence (or is it?)!
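In other words, writing the standard continuous SAC policy loss (as I understand it) and flipping its sign:

$$
J_{\pi} = \mathbb{E}_{s \sim \mathcal{D},\, a \sim \pi}\big[\alpha \log \pi(a \mid s) - Q(s, a)\big] = -\,\mathbb{E}\big[Q(s, a) - \alpha \log \pi(a \mid s)\big],
$$

so minimizing the actor loss is the same as maximizing the soft value at the sampled (reparametrized) action.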

Since maximizing the soft state value is not a bad interpretation of the actor loss, and the soft state value is exactly what we use to compute the target in the critic loss, I decided to put the discrete and continuous actions together in the same way. So the modified actor loss is:

d_entropy_bonus = -alpha_d * log_pi_d
c_entropy_bonus = -alpha * log_pi_c
policy_bonus = (prob_d * (min_qf_pi + d_entropy_bonus + c_entropy_bonus)).sum(1)
policy_loss = -policy_bonus.mean()  # average over the batch

From my test, still with fixed alpha, convergence seems slightly slower than with the original version.

Now, for both the actor and critic loss, I can't help asking myself why the original version and my new version both make SAC converge. My answer is that, to a certain degree, the sum(1) operator makes whatever happens before it somewhat untraceable; that is, sum(1) makes most of the outputs of the original version and of my version quite similar.

alpha loss

Following the hybrid-action SAC paper, since the alphas (and their target entropies) for the discrete and continuous parts are different, I think it is better to keep them as two separate losses, as in the original version. But I still have doubts about the continuous part, so I modified it from:

alpha_loss = (-log_alpha * p_d * (p_d * lpi_c + target_entropy)).sum(1).mean()
a_optimizer.zero_grad()
alpha_loss.backward()
a_optimizer.step()
alpha = log_alpha.exp().detach().cpu().item()

to:

alpha_loss = (-log_alpha * p_d * ( lpi_c + target_entropy)).sum(1).mean()
a_optimizer.zero_grad()
alpha_loss.backward()
a_optimizer.step()
alpha = log_alpha.exp().detach().item()

The alpha_loss for a pure continuous action can be written in a more intuitive way:

alpha_loss = log_alpha * (-lpi_c - target_entropy)
           = log_alpha * (alpha_entropy - target_entropy)

From this modified alpha loss, it is easy to see that, to keep alpha_entropy in every state from falling below target_entropy, log_alpha is adjusted accordingly (usually in the range [0,1]).
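For comparison, the temperature objective in standard continuous SAC (as I remember it from the SAC paper) is

$$
J(\alpha) = \mathbb{E}_{a \sim \pi}\big[-\alpha \log \pi(a \mid s) - \alpha \bar{\mathcal{H}}\big],
$$

whose minimization pushes alpha up when the policy entropy is below the target and down when it is above.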

And because each element of alpha_entropy corresponds to a specific discrete action,
to compute the continuous part over the whole discrete-action distribution we need the expectation over the discrete_action_dim. That is why p_d is essential outside the parentheses.

alpha_loss = log_alpha * p_d * (-lpi_c - target_entropy)
           = log_alpha * p_d * (alpha_entropy - target_entropy)

But what about the p_d inside the parentheses in the original version, i.e. log_alpha * p_d * (p_d * alpha_entropy - target_entropy)? Let's take an example with a discrete action dim of 2 in the hybrid setting:
Suppose alpha_entropy = [0.6, 0.2], p_d = [0.5, 0.5] and target_entropy = 0.4. For the first element of alpha_entropy:
in my version, log_alpha would say: "you are (sufficiently) larger than target_entropy, so I should decrease myself somewhat."
in the original version, log_alpha would say: "you are lower (0.6 * 0.5 = 0.3) than target_entropy, so I should increase myself somewhat."

While the outside p_d has no effect here (half and half), the p_d inside the parentheses does influence the direction in which log_alpha moves. But which direction is correct?
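Plugging the numbers above into both versions for the first discrete action (plain arithmetic, just to spell out the sign flip):

alpha_entropy, p_d, target_entropy = 0.6, 0.5, 0.4

mine     = p_d * (alpha_entropy - target_entropy)        # 0.5 * (0.6 - 0.4) = +0.10
original = p_d * (p_d * alpha_entropy - target_entropy)  # 0.5 * (0.3 - 0.4) = -0.05

The two terms have opposite signs, so the gradient on log_alpha points in opposite directions in the two versions.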

At least in my version, because I calculate the soft state value in this way:

# copied from the actor loss; similar insights apply to the critic loss part
d_entropy_bonus = -alpha_d * log_pi_d
c_entropy_bonus = -alpha * log_pi_c
policy_bonus = (prob_d * (min_qf_pi + d_entropy_bonus + c_entropy_bonus)).sum(1)

Here policy_bonus is exactly the soft state value, and log_alpha only affects alpha inside c_entropy_bonus. And since c_entropy_bonus ONLY accounts for the continuous entropy, each element along its second dimension ONLY accounts for PART of the total entropy of the single discrete action at that position.
So I think we should compare target_entropy with alpha_entropy, not with alpha_entropy*some_weight.

Finally, after testing with the modified alpha loss, the modified SAC still converges.

reference for a thorough understanding of SAC

mch5048 commented Jun 6, 2022

@dbsxdbsx Oh, your point seems convincing. I also found a work that formulates hybrid SAC very similarly to this paper, and it has an official code implementation. In that implementation, the formulation matches your reasoning exactly.

https://github.com/facebookresearch/hsd3
