
What is get_kl() doing in main.py? #2

Closed
jtoyama4 opened this issue Jun 19, 2017 · 12 comments

@jtoyama4

Hi. Thanks for publishing an implementation of TRPO.

I have a question about get_kl().

I thought get_kl() was supposed to calculate the KL divergence between the old policy and the new policy, but it seems to always return 0.

Also, I do not see a KL-constraining part in the parameter update process.

Is this code a modification of TRPO, or do I have some misunderstanding?

Thanks,

@ikostrikov
Owner

KL is used to compute the hessian (as in the original code).

@pxlong

pxlong commented Jul 5, 2017

Hi, I've tested the get_kl() function in various environments, but its return value is always 0. The following is an example of the get_kl() outputs in BipedalWalker-v2:

*** kl: Variable containing:
0 0 0 0
0 0 0 0
0 0 0 0

0 0 0 0
0 0 0 0
0 0 0 0
[torch.DoubleTensor of size 15160x4]

*** kl sum: Variable containing:
0
0
0

0
0
0
[torch.DoubleTensor of size 15160]

*** kl mean: Variable containing:
0
[torch.DoubleTensor of size 1]

*** kl grad: Variable containing:
1.00000e-34 *
2.4390
[torch.DoubleTensor of size 1]

*** kl grad grad: Variable containing:
1.00000e-34 *
2.4390
[torch.DoubleTensor of size 1]

('lagrange multiplier:', 2.509155592601353e-17, 'grad_norm:', 4.4529719813727475e-18)
fval before 2.0545072021432404e-13
a/e/r 2.1300610441917076e-13 5.018311185202932e-19 424457.7439662247
fval after -7.555384204846721e-15

Does this make sense?

@pxlong

pxlong commented Jul 5, 2017

@jtoyama4 Hi, have you figured out how get_kl() works?
I have the same question as you.

@jtoyama4
Author

jtoyama4 commented Jul 5, 2017

@pxlong I think what get_kl() does is provide the KL so that its gradient can be taken (for computing the Hessian), and the KL-constraining part somehow works through the ratio in def linesearch, but I really do not understand it theoretically.
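For reference, a TRPO-style backtracking line search typically looks something like the sketch below (a generic illustration with my own naming, not this repo's exact code; surrogate_loss, full_step and accept_ratio are placeholders): it shrinks the proposed step until the surrogate objective actually improves by a reasonable fraction of the first-order prediction, which is where the trust region is enforced in practice.

def linesearch_sketch(params, full_step, surrogate_loss, expected_improve,
                      max_backtracks=10, accept_ratio=0.1):
    # Generic backtracking line search in the TRPO spirit (illustrative only).
    # params: flat parameter vector of the current policy
    # full_step: step proposed by the conjugate-gradient / natural-gradient solve
    # surrogate_loss: callable mapping a flat parameter vector to the surrogate loss
    # expected_improve: first-order predicted improvement for the full step
    f0 = surrogate_loss(params)
    for i in range(max_backtracks):
        step_frac = 0.5 ** i
        new_params = params + step_frac * full_step
        actual_improve = f0 - surrogate_loss(new_params)
        expected = expected_improve * step_frac
        # Accept only if we realize a reasonable fraction of the predicted gain.
        if expected > 0 and actual_improve / expected > accept_ratio:
            return True, new_params
    return False, params  # otherwise fall back to the old parameters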

@pxlong

pxlong commented Jul 5, 2017

@jtoyama4, thanks for the quick reply.
But what get_kl() returns is always 0, so how can you get a valid/useful gradient of it?
I am a little confused.

@ikostrikov
Owner

ikostrikov commented Jul 5, 2017

@pxlong a simple example:

In this case we have something like f(x) = x_0^2 - x^2, so f(x_0) = 0 but f'(x_0) = -2x_0.

The function is not identically zero; it just happens to take the value 0 at one specific point.
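A minimal sketch of the same idea (mine, not from the repo), extended to the case that matters for the KL: a function can be exactly 0 at the evaluation point, with a zero first derivative there as well, and still have a nonzero second derivative, which is what the double backward in Fvp extracts.

import torch

theta0 = torch.tensor(3.0)
theta = torch.tensor(3.0, requires_grad=True)

f = (theta - theta0) ** 2                                 # value at theta0: 0
(g,) = torch.autograd.grad(f, theta, create_graph=True)   # first derivative: 0
(h,) = torch.autograd.grad(g, theta)                      # second derivative: 2
print(f.item(), g.item(), h.item())                       # 0.0 0.0 2.0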

@pxlong

pxlong commented Jul 5, 2017

@ikostrikov, thanks for your explanation.
But I've tested this implementation in various envs, such as BipedalWalker-v2, MountainCarContinuous-v0, and Pendulum-v0 (everything except Reacher-v1), and none of them gives reasonable results (i.e. the agent learns nothing during training).

To debug it, I added some print statements, as you can see below:

def Fvp(v):
    kl = get_kl()
    kl = kl.mean()
    print('*** kl mean: ', kl)

    grads = torch.autograd.grad(kl, model.parameters(), create_graph=True)
    flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])

    kl_v = (flat_grad_kl * Variable(v)).sum()
    print('*** kl_v: ', kl_v)
    grads = torch.autograd.grad(kl_v, model.parameters())
    flat_grad_grad_kl = torch.cat([grad.view(-1) for grad in grads]).data
    print('*** flat grad grad kl: ', flat_grad_grad_kl)

but all the tensors are zero (as shown in the second message of this issue).

I am wondering what causes the bad performance on these environments?

@ikostrikov
Owner

Check the hyperparams from the original implementation (modular rl), and also check estimates of how long convergence takes. The default hyperparams of this code are tuned specifically for Reacher-v1.

@pxlong

pxlong commented Jul 5, 2017

ok, thanks.

@ikostrikov
Owner

ikostrikov commented Oct 7, 2017

This part is used to compute the Hessian of the KL. The KL itself is 0 and the derivative of the KL is 0, but the Hessian is not.

This is the reason why we have to compute a second-order approximation of the KL term: its first-order approximation is equal to zero.
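In symbols (my paraphrase of this comment, not text from the thread), the expansion around the old parameters looks like this, with H the Hessian of the KL at theta_old:

% Second-order Taylor expansion of the KL around \theta_{\mathrm{old}}:
% the zeroth- and first-order terms vanish, so only the quadratic term survives.
D_{\mathrm{KL}}\left(\pi_{\theta_{\mathrm{old}}} \,\|\, \pi_{\theta}\right)
  \approx \underbrace{D_{\mathrm{KL}}\left(\pi_{\theta_{\mathrm{old}}} \,\|\, \pi_{\theta_{\mathrm{old}}}\right)}_{=\,0}
  + \underbrace{\nabla_{\theta} D_{\mathrm{KL}}\big|_{\theta=\theta_{\mathrm{old}}}^{\top}}_{=\,0}\,(\theta - \theta_{\mathrm{old}})
  + \tfrac{1}{2}\,(\theta - \theta_{\mathrm{old}})^{\top} H\,(\theta - \theta_{\mathrm{old}}),
\qquad H = \nabla_{\theta}^{2} D_{\mathrm{KL}}\big|_{\theta=\theta_{\mathrm{old}}}.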

@ikostrikov
Owner

@pxlong Sorry that it took me so long to fix the bug.

It didn't work because they recently changed the default argument values for some functions in PyTorch.

@josiahls

josiahls commented Feb 3, 2023

@pxlong And anyone else: I found that get_kl() is related to this statement from (Schulman et al., 2015) [TRPO] Trust Region Policy Optimization:

computing the Hessian of D_KL with respect to θ

It still feels non-intuitive to me, but I guess the goal is autodiff / Hessian calculation rather than getting an actual value out of get_kl().

Also, "For two univariate normal distributions p and q, the above simplifies to ..." has the math that looks directly related to the code.

The paper cites Numerical Optimization, so I guess I have some reading to do :)

*edit:
So if I understand correctly, get_kl() simply structures the KL for autograd. The actual Hessian computation is built in Fvp.
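To make that reading concrete, here is a rough, self-contained sketch (my own toy example, not the repo's code; the tiny linear policy and the names states, Fvp, damping are illustrative). The old distribution is a detached copy of the current one, so the KL value and its gradient are exactly 0 at the current parameters, yet differentiating twice still yields a nonzero Hessian-vector product:

import torch
import torch.nn as nn

# A tiny Gaussian policy: the network outputs the action mean, and log_std is a
# free parameter (roughly the shape used in TRPO-style continuous-control code).
policy = nn.Linear(4, 2)
log_std = nn.Parameter(torch.zeros(2))
params = list(policy.parameters()) + [log_std]
states = torch.randn(32, 4)

def get_kl():
    # KL(old || new) for diagonal Gaussians, where "old" is a detached copy of
    # the current outputs. Its value (and gradient) is 0 at the current
    # parameters, but the graph it builds still carries nonzero curvature.
    mean = policy(states)
    log_std_b = log_std.expand_as(mean)
    std = log_std_b.exp()
    mean0, log_std0, std0 = mean.detach(), log_std_b.detach(), std.detach()
    kl = log_std_b - log_std0 + (std0.pow(2) + (mean0 - mean).pow(2)) / (2.0 * std.pow(2)) - 0.5
    return kl.sum(1, keepdim=True)

def Fvp(v, damping=0.1):
    # Hessian-vector product of the mean KL via double backward, as discussed above.
    kl = get_kl().mean()
    grads = torch.autograd.grad(kl, params, create_graph=True)
    flat_grad_kl = torch.cat([g.contiguous().view(-1) for g in grads])
    kl_v = (flat_grad_kl * v).sum()
    grads2 = torch.autograd.grad(kl_v, params)
    flat_grad_grad_kl = torch.cat([g.contiguous().view(-1) for g in grads2])
    return flat_grad_grad_kl + damping * v

n_params = sum(p.numel() for p in params)
v = torch.randn(n_params)
print(get_kl().mean().item())  # 0.0: the KL value itself is always zero here
print(Fvp(v).norm().item())    # nonzero: the curvature information survives

The detach() is what keeps the "old" side fixed, so the value and gradient vanish at the current parameters and only the second-order term carries information; a conjugate-gradient solver only ever needs this product, never the raw KL value.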
