Differentiating Through Marginals of Dependency CRF #72

Open

mbp28 opened this issue May 24, 2020 · 3 comments

mbp28 commented May 24, 2020

Hi,

I tried using the DependencyCRF in a learning setting that required me to differentiate through the marginals. This turned out to be quite difficult: the gradients computed through the marginals tended to have high variance and to be larger than I would expect (though I haven't yet dug into the Eisner algorithm in detail).

I wonder whether this is a property of the Eisner algorithm or whether it might hint at a bug.
Below is a minimal example showing that the maximum gradient returned for the arc scores can be quite large, even though the scores themselves are on a reasonable scale.

import torch
from torch_struct import DependencyCRF

torch.manual_seed(99)

maxlen = 50
# Random arc scores for a single sentence of length maxlen.
vals = torch.randn((1, maxlen, maxlen), requires_grad=True)
grad_output = torch.rand(1, maxlen, maxlen)

dist = DependencyCRF(vals)
marginals = dist.marginals
# Backpropagate an upstream gradient through the marginals to the arc scores.
marginals.backward(grad_output)

print(vals.max().item())         # 3.5494842529296875
print(marginals.max().item())    # 0.8289076089859009
print(grad_output.max().item())  # 0.9995625615119934
print(vals.grad.max().item())    # 19.625778198242188
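
For reference, the kind of learning setting I mean looks roughly like this, with a loss placed directly on the marginals so gradients must flow through the marginal computation back to the arc scores (the target below is just a hypothetical placeholder):

import torch
from torch_struct import DependencyCRF

maxlen = 20
vals = torch.randn((1, maxlen, maxlen), requires_grad=True)
target = torch.rand(1, maxlen, maxlen)  # hypothetical supervision on the marginals

dist = DependencyCRF(vals)
loss = torch.nn.functional.mse_loss(dist.marginals, target)
loss.backward()  # gradients flow through the marginals into vals.grad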
srush (Collaborator) commented Jul 12, 2020

Hi, sorry for the long delay here. I'm going to add some tests to make sure it is returning the right values. I don't have a good sense of whether this is a bug, underflow, or correct behavior in this case.

sustcsonglin (Contributor) commented

I am pretty sure the reason is the "Chart" class: one should set cache=False if one wants to reuse the computation graph.
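
For illustration, the double-backward structure can also be written out explicitly, computing the marginals by hand as the gradient of the log-partition function with respect to the arc scores (the standard exponential-family identity), with create_graph=True so that a loss on the marginals remains differentiable. This is only a sketch; the loss below is a hypothetical placeholder:

import torch
from torch_struct import DependencyCRF

maxlen = 20
vals = torch.randn((1, maxlen, maxlen), requires_grad=True)
dist = DependencyCRF(vals)

# Edge marginals as the gradient of log Z with respect to the arc scores;
# create_graph=True keeps the marginals attached to the graph.
logZ = dist.partition
marginals, = torch.autograd.grad(logZ.sum(), vals, create_graph=True)

loss = (marginals ** 2).sum()  # hypothetical downstream loss on the marginals
loss.backward()                # second backward pass, through the marginals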

srush (Collaborator) commented Aug 10, 2020

That sounds right. I will turn off the chart cache by default.

Also, the backward-on-marginals approach now works with FastLogSemiring.
