Skip to content

Commit

Permalink
batched implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
da03 committed Oct 14, 2021
1 parent 2f72847 commit 4705eca
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions torch_struct/semimarkov.py
Expand Up @@ -62,10 +62,12 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
torch.stack([c.data[:, :, : K - 1, 0], lp[:, :, 1:K]], dim=-1)
)
mask = torch.zeros(*init.shape, device=log_potentials.device).bool()
mask_length = torch.arange(bin_N).view(1, bin_N, 1).expand(batch, bin_N, C)
mask_length = mask_length.to(log_potentials.device)
for k in range(1, K - 1):
for b in range(batch):
end = lengths[b] - 1
mask[:, b, : end - (k - 1), k - 1, k].diagonal(0, -2, -1).fill_(True)
mask_length_k = mask_length < (lengths - 1 - (k - 1)).view(batch, 1, 1)
mask_length_k = semiring.convert(mask_length_k)
mask[:, :, :, k - 1, k].diagonal(0, -2, -1).masked_fill_(mask_length_k, True)
init = semiring.fill(init, mask, semiring.one)

K_1 = K - 1
Expand Down

0 comments on commit 4705eca

Please sign in to comment.