Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sasha committed Oct 14, 2019
1 parent ddc2a9f commit d803af6
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
1 change: 0 additions & 1 deletion torch_struct/cky_crf.py
Expand Up @@ -7,7 +7,6 @@
class CKY_CRF(_Struct):
def _dp(self, scores, lengths=None, force_grad=False):
semiring = self.semiring
ssize = semiring.size()
batch, N, _, NT = scores.shape
scores = semiring.convert(scores)
if lengths is None:
Expand Down
19 changes: 10 additions & 9 deletions torch_struct/semimarkov.py
Expand Up @@ -22,28 +22,29 @@ def _dp(self, edge, lengths=None, force_grad=False):
semiring = self.semiring
ssize = semiring.size()
edge, batch, N, K, C, lengths = self._check_potentials(edge, lengths)
edge.requires_grad_(True)

spans = self._make_chart(N - 1, (batch, K, C, C), edge, force_grad)
alpha = self._make_chart(N, (batch, K, C), edge, force_grad)
alpha = self._make_chart(1, (batch, N, K, C), edge, force_grad)[0]
beta = self._make_chart(N, (batch, C), edge, force_grad)
semiring.one_(beta[0].data)
for n in range(1, N):
spans[n - 1][:] = semiring.times(
beta[n - 1].view(ssize, batch, 1, 1, C),
edge[:, :, n - 1].view(ssize, batch, K, C, C),
alpha[:, :, n - 1] = semiring.sum(
semiring.times(
beta[n - 1].view(ssize, batch, 1, 1, C),
edge[:, :, n - 1].view(ssize, batch, K, C, C),
)
)
alpha[n - 1][:] = semiring.sum(spans[n - 1])
t = max(n - K, -1)
f1 = torch.arange(n - 1, t, -1)
f2 = torch.arange(1, len(f1) + 1)
beta[n][:] = semiring.sum(
torch.stack([alpha[a][:, :, b] for a, b in zip(f1, f2)], dim=1), dim=1
torch.stack([alpha[:, :, a, b] for a, b in zip(f1, f2)], dim=-1)
)
v = semiring.sum(
torch.stack([beta[l - 1][:, i] for i, l in enumerate(lengths)], dim=1),
dim=2,
)
return v, spans, beta
return v, [edge], beta

@staticmethod
def _rand():
Expand All @@ -54,7 +55,7 @@ def _rand():
return torch.rand(b, N, K, C, C), (b.item(), (N + 1).item())

def _arrange_marginals(self, marg):
return torch.stack(marg, dim=2)
return self.semiring.unconvert(marg[0])

@staticmethod
def to_parts(sequence, extra, lengths=None):
Expand Down

0 comments on commit d803af6

Please sign in to comment.