Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Nov 26, 2019
1 parent 398ba1a commit a9b4e37
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 22 deletions.
1 change: 0 additions & 1 deletion torch_struct/linearchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def _dp_scan(self, log_potentials, lengths=None, force_grad=False):
c.data[:, (~mask).view(-1)] = semiring.zero
c[:] = semiring.sum(torch.stack([c, lp], dim=-1))


# Scan
for n in range(1, log_N + 1):
chart = semiring.matmul(chart[:, :, 1::2], chart[:, :, 0::2])
Expand Down
24 changes: 4 additions & 20 deletions torch_struct/semimarkov.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ def _dp(self, log_potentials, lengths=None, force_grad=False):

# Init.
semiring.one_(init.data[:, :, :, 0, 0].diagonal(0, -2, -1))
# for k in range(1, K - 1):
# semiring.one_(init[:, :, : , k - 1, k].diagonal(0, -2, -1))

# Length mask
big = torch.zeros(
Expand All @@ -50,33 +48,19 @@ def _dp(self, log_potentials, lengths=None, force_grad=False):
device=log_potentials.device,
)
big[:, :, : N - 1] = log_potentials
c = init[:, :, :].view(ssize, batch * bin_N, K -1, K-1, C, C)
c = init[:, :, :].view(ssize, batch * bin_N, K - 1, K - 1, C, C)
lp = big[:, :, :].view(ssize, batch * bin_N, K, C, C)
mask = torch.arange(bin_N).view(1, bin_N).expand(batch, bin_N)
mask = mask >= (lengths - 1).view(batch, 1)
lp.data[:, mask.view(-1)] = semiring.zero
c.data[:, (~mask).view(-1)] = semiring.zero
c[:, :, :K-1, 0] = semiring.sum(torch.stack([c[:, :, :K-1, 0],
lp[:, :, 1:K]], dim=-1))

# for b in range(lengths.shape[0]):
# end = lengths[b] - 1
# for k in range(1, K - 1):
# semiring.one_(init[:, b, : end - (k - 1), k - 1, k].diagonal(0, -2, -1))
c[:, :, : K - 1, 0] = semiring.sum(
torch.stack([c[:, :, : K - 1, 0], lp[:, :, 1:K]], dim=-1)
)

# ks = torch.arange(1, K - 1)
for k in range(1, K - 1):
semiring.one_(init[:, :, : (k - 1), k - 1, k].diagonal(0, -2, -1))

# init[:, :, :end, : (K - 1), 0] = log_potentials[:, :, :end, 1:K]

# init[:, :, :N-1, : (K - 1), 0] = log_potentials[:, :, :N-1, 1:K]
# for b in range(lengths.shape[0]):
# end = lengths[b] - 1
# semiring.one_(init[:, b, end:, 0, 0].diagonal(0, 2, 3))
# init[:, b, :end, : (K - 1), 0] = log_potentials[:, b, :end, 1:K]
# for k in range(1, K - 1):
# semiring.one_(init[:, b, : end - (k - 1), k - 1, k].diagonal(0, 2, 3))
K_1 = K - 1

# Order n, n-1
Expand Down
1 change: 1 addition & 0 deletions torch_struct/semirings/semirings.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ class EntropySemiring(Semiring):
* Parameter estimation for probabilistic finite-state transducers :cite:`eisner2002parameter`
* First-and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first`
"""
zero = 0

@staticmethod
def size():
Expand Down
3 changes: 2 additions & 1 deletion torch_struct/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ def test_cky(data):
@settings(max_examples=50, deadline=None)
def test_generic_a(data):
model = data.draw(
sampled_from([SemiMarkov]) #Alignment , LinearChain, SemiMarkov, CKY, CKY_CRF, DepTree])
sampled_from(
[SemiMarkov, Alignment , LinearChain, SemiMarkov, CKY, CKY_CRF, DepTree])
)
semiring = data.draw(sampled_from([LogSemiring, MaxSemiring]))
struct = model(semiring)
Expand Down

0 comments on commit a9b4e37

Please sign in to comment.