Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add alignment CRF test. Fix missing fill_() #109

Open
wants to merge 14 commits into
base: master
Choose a base branch
from

Conversation

JohnReid
Copy link
Contributor

@JohnReid JohnReid commented Oct 1, 2021

PR following up discussion here.

For the tests to pass I also had to update genbmm. See PR here.

Note that the tests only check the shape of the argmax and marginals. The values are not checked.

@JohnReid JohnReid mentioned this pull request Oct 1, 2021
@JohnReid
Copy link
Contributor Author

JohnReid commented Oct 1, 2021

I realised I hadn't tested many of the distribution properties. I've tried tests for few more but it looks like there are at least two more issues to resolve.

charta[1][:, b, point:, 1, ind, :, :, Mid] = semiring.one_(
charta[1][:, b, point:, 1, ind, :, :, Mid]
)
charta[1][:, b, point:, 1, ind, :, :, Mid] = charta[1][:, b, point:, 1, ind, :, :, Mid].fill_(0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately this is not going to work.

We need to call

init = torch.zeros(charta[1].shape).bool()
init[:, b, point:, 1, ind, :, :, Mid].fill_(True)
charta[1] = semiring.fill(charta[1], init, semiring.one)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(this should fix your other issues too)

Copy link
Contributor Author

@JohnReid JohnReid Oct 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, thanks for that. I have to admit I had just copied the code from the one_() method before it was removed in #105. My assumption was that it was the correct code.

Copy link
Contributor Author

@JohnReid JohnReid Oct 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still facing a few issues though. I fixed a few of them in the commits below but some remain. The main sticking point seems to be that the BandedMatrixs are not correctly dispatched to multiply rather than matmul() in semirings.py. The matmul implementation only works for standard tensors. This affects dist.entropy, dist.sample(), dist.topk() but not the partition, argmax, marginals.

I tried to fix this rather naively by overloading the classmethod matmul in some of the semirings but this broke the existing tests. I backed that out and am trying to understand how the code relates to the description in the torch struct paper so that I can make the correct fix.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants