Skip to content

Commit

Permalink
Add another property to test.
Browse files Browse the repository at this point in the history
  • Loading branch information
John Reid committed Oct 1, 2021
1 parent e65832e commit 9788879
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions tests/test_alignment_crf.py
Expand Up @@ -3,8 +3,8 @@
import warnings


def test_alignment_crf():
batch, N, M = 1, 4, 5
def test_alignment_crf_shapes():
batch, N, M = 2, 4, 5
log_potentials = torch.rand(batch, N, M, 3)

if torch.cuda.is_available():
Expand All @@ -14,6 +14,17 @@ def test_alignment_crf():
'Will not test marginals.')

dist = torch_struct.AlignmentCRF(log_potentials)
assert (N, M, 3) == dist.argmax[0].shape
assert (batch, N, M, 3) == dist.argmax.shape
if torch.cuda.is_available():
assert (N, M, 3) == dist.marginals[0].shape
assert (batch, N, M, 3) == dist.marginals.shape
assert (batch,) == dist.partition.shape

# Fail due to AttributeError: 'BandedMatrix' object has no attribute
# 'unsqueeze'
# assert (batch,) == dist.entropy.shape
# assert (9, batch, N, M, 3) == dist.sample([9]).shape

# Fails due to: RuntimeError: Expected condition, x and y to be on
# the same device, but condition is on cpu and x and y are on
# cuda:0 and cuda:0 respectively
# assert (8, batch,) == dist.topk(8).shape

0 comments on commit 9788879

Please sign in to comment.