diff --git a/tests/test_alignment_crf.py b/tests/test_alignment_crf.py index a9a6701..7865fb6 100644 --- a/tests/test_alignment_crf.py +++ b/tests/test_alignment_crf.py @@ -3,8 +3,8 @@ import warnings -def test_alignment_crf(): - batch, N, M = 1, 4, 5 +def test_alignment_crf_shapes(): + batch, N, M = 2, 4, 5 log_potentials = torch.rand(batch, N, M, 3) if torch.cuda.is_available(): @@ -14,6 +14,17 @@ def test_alignment_crf(): 'Will not test marginals.') dist = torch_struct.AlignmentCRF(log_potentials) - assert (N, M, 3) == dist.argmax[0].shape + assert (batch, N, M, 3) == dist.argmax.shape if torch.cuda.is_available(): - assert (N, M, 3) == dist.marginals[0].shape + assert (batch, N, M, 3) == dist.marginals.shape + assert (batch,) == dist.partition.shape + + # Fail due to AttributeError: 'BandedMatrix' object has no attribute + # 'unsqueeze' + # assert (batch,) == dist.entropy.shape + # assert (9, batch, N, M, 3) == dist.sample([9]).shape + + # Fails due to: RuntimeError: Expected condition, x and y to be on + # the same device, but condition is on cpu and x and y are on + # cuda:0 and cuda:0 respectively + # assert (8, batch,) == dist.topk(8).shape