Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sasha committed Oct 4, 2019
1 parent 060b91f commit 529afe7
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torch_struct/test_algorithms.py
Expand Up @@ -70,7 +70,7 @@ def test_entropy(data):
alpha = struct.sum(vals)

log_z = model(LogSemiring).sum(vals)
_, log_probs = model(LogSemiring).enumerate(vals)
log_probs = model(LogSemiring).enumerate(vals)[1]
log_probs = torch.stack(log_probs, dim=1) - log_z
print(log_probs.shape, log_z.shape, log_probs.exp().sum(1))
entropy = -log_probs.mul(log_probs.exp()).sum(1).squeeze(0)
Expand All @@ -86,7 +86,7 @@ def test_generic_a(data):
struct = model(semiring)
vals, (batch, N) = model._rand()
alpha = struct.sum(vals)
count, _ = struct.enumerate(vals)
count = struct.enumerate(vals)[0]
assert alpha.shape[0] == batch
assert count.shape[0] == batch
assert alpha.shape == count.shape
Expand Down

0 comments on commit 529afe7

Please sign in to comment.