add back dp_standard
da03 committed Oct 14, 2021
1 parent 4705eca commit 1fb6fc8
Showing 2 changed files with 37 additions and 31 deletions.
6 changes: 6 additions & 0 deletions tests/test_algorithms.py
@@ -519,9 +519,12 @@ def test_hsmm(model_test, semiring):
partition2 = algorithms[model_test][1].enumerate(semiring, edge)[0]
# third way: dp using edge scores computed from init/transitions/emission
partition3 = algorithms[model_test][0](semiring).logpartition(edge)[0]
# fourth way: dp_standard using edge scores computed from init/transitions/emission
partition4 = algorithms[model_test][0](semiring)._dp_standard(edge)[0]

assert torch.isclose(partition1, partition2).all()
assert torch.isclose(partition2, partition3).all()
assert torch.isclose(partition3, partition4).all()


@given(data())
@@ -542,3 +545,6 @@ def test_batching_lengths(model_test, semiring, data):
lengths_b = lengths[b:(b + 1)]
partition_b = model(semiring).logpartition(vals_b, lengths=lengths_b)[0][0]
assert torch.isclose(partition[b], partition_b).all()
# test _dp_standard
partition_dp_standard = model(semiring)._dp_standard(vals, lengths=lengths)[0][0]
assert torch.isclose(partition, partition_dp_standard).all()
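
As a standalone check, the restored path can also be compared against the vectorized one directly. A minimal sketch, assuming torch_struct's (batch, N-1, K, C, C) edge-potential layout and the torch_struct.semimarkov import path; this snippet is illustrative and not part of the commit:

import torch
from torch_struct import LogSemiring
from torch_struct.semimarkov import SemiMarkov

batch, N, K, C = 2, 6, 3, 4
edge = torch.randn(batch, N - 1, K, C, C)

model = SemiMarkov(LogSemiring)
fast = model.logpartition(edge)[0]    # vectorized scan
slow = model._dp_standard(edge)[0]    # plain left-to-right DP
assert torch.isclose(fast, slow).all()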
62 changes: 31 additions & 31 deletions torch_struct/semimarkov.py
@@ -86,37 +86,37 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
v = semiring.sum(semiring.sum(final[:, :, 0, :, 0, :].contiguous()))
return v, [log_potentials]

# def _dp_standard(self, edge, lengths=None, force_grad=False):
# semiring = self.semiring
# ssize = semiring.size()
# edge, batch, N, K, C, lengths = self._check_potentials(edge, lengths)
# edge.requires_grad_(True)

# # Init
# # All paths starting at N of len K
# alpha = self._make_chart(1, (batch, N, K, C), edge, force_grad)[0]

# # All paths finishing at N with label C
# beta = self._make_chart(N, (batch, C), edge, force_grad)
# semiring.one_(beta[0].data)

# # Main.
# for n in range(1, N):
# alpha[:, :, n - 1] = semiring.dot(
# beta[n - 1].view(ssize, batch, 1, 1, C),
# edge[:, :, n - 1].view(ssize, batch, K, C, C),
# )

# t = max(n - K, -1)
# f1 = torch.arange(n - 1, t, -1)
# f2 = torch.arange(1, len(f1) + 1)
# beta[n][:] = semiring.sum(
# torch.stack([alpha[:, :, a, b] for a, b in zip(f1, f2)], dim=-1)
# )
# v = semiring.sum(
# torch.stack([beta[l - 1][:, i] for i, l in enumerate(lengths)], dim=1)
# )
# return v, [edge], beta
def _dp_standard(self, edge, lengths=None, force_grad=False):
semiring = self.semiring
ssize = semiring.size()
edge, batch, N, K, C, lengths = self._check_potentials(edge, lengths)
edge.requires_grad_(True)

# Init
# All paths starting at N of len K
alpha = self._make_chart(1, (batch, N, K, C), edge, force_grad)[0]

# All paths finishing at N with label C
beta = self._make_chart(N, (batch, C), edge, force_grad)
beta[0] = semiring.fill(beta[0], torch.tensor(True).to(edge.device), semiring.one)

# Main.
for n in range(1, N):
alpha[:, :, n - 1] = semiring.dot(
beta[n - 1].view(ssize, batch, 1, 1, C),
edge[:, :, n - 1].view(ssize, batch, K, C, C),
)

t = max(n - K, -1)
f1 = torch.arange(n - 1, t, -1)
f2 = torch.arange(1, len(f1) + 1)
beta[n][:] = semiring.sum(
torch.stack([alpha[:, :, a, b] for a, b in zip(f1, f2)], dim=-1)
)
v = semiring.sum(
torch.stack([beta[l - 1][:, i] for i, l in enumerate(lengths)], dim=1)
)
return v, [edge], beta

@staticmethod
def to_parts(sequence, extra, lengths=None):
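To make the recursion in _dp_standard concrete, here is an illustrative plain-torch, log-space reference (no semiring abstraction) for the same quantity. Variable names mirror the method above, but the function and its name are assumptions for exposition, not part of torch_struct:

import torch

# edge[b, n, k, c, c2] scores a length-k segment starting at position n
# with label c, given previous label c2 (assumed (batch, N-1, K, C, C)).
def semimarkov_logpartition_reference(edge):
    batch, N1, K, C, _ = edge.shape
    N = N1 + 1
    # beta[n][b, c]: log-sum of all paths ending at position n in label c
    beta = [torch.full((batch, C), float("-inf")) for _ in range(N)]
    beta[0] = torch.zeros(batch, C)  # semiring.one in log space
    # alpha[b, n, k, c]: beta[n] extended by one segment of length-index k
    alpha = torch.full((batch, N, K, C), float("-inf"))
    for n in range(1, N):
        # semiring.dot over the previous label: logsumexp(beta + edge)
        alpha[:, n - 1] = torch.logsumexp(
            beta[n - 1].view(batch, 1, 1, C) + edge[:, n - 1], dim=-1
        )
        # a segment of length-index k starting at n - k ends at n
        spans = [alpha[:, n - k, k] for k in range(1, min(n, K - 1) + 1)]
        beta[n] = torch.logsumexp(torch.stack(spans, dim=-1), dim=-1)
    # sum over final labels; all sequences assumed full length N here
    return torch.logsumexp(beta[N - 1], dim=-1)

Note that alpha[:, n - 1] is written at step n and read back at steps n through n + K - 2, which is why the method keeps the full chart rather than a sliding window.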
