Skip to content

Commit

Permalink
Merge 9788879 into e51fecc
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnReid committed Oct 1, 2021
2 parents e51fecc + 9788879 commit b1d8463
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 4 deletions.
30 changes: 30 additions & 0 deletions tests/test_alignment_crf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import torch
import torch_struct
import warnings


def test_alignment_crf_shapes():
batch, N, M = 2, 4, 5
log_potentials = torch.rand(batch, N, M, 3)

if torch.cuda.is_available():
log_potentials = log_potentials.cuda()
else:
warnings.warn('Could not move log potentials to CUDA device. '
'Will not test marginals.')

dist = torch_struct.AlignmentCRF(log_potentials)
assert (batch, N, M, 3) == dist.argmax.shape
if torch.cuda.is_available():
assert (batch, N, M, 3) == dist.marginals.shape
assert (batch,) == dist.partition.shape

# Fail due to AttributeError: 'BandedMatrix' object has no attribute
# 'unsqueeze'
# assert (batch,) == dist.entropy.shape
# assert (9, batch, N, M, 3) == dist.sample([9]).shape

# Fails due to: RuntimeError: Expected condition, x and y to be on
# the same device, but condition is on cpu and x and y are on
# cuda:0 and cuda:0 respectively
# assert (8, batch,) == dist.topk(8).shape
9 changes: 5 additions & 4 deletions torch_struct/alignment.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import torch
from .helpers import _Struct
import math
import warnings

try:
import genbmm

except ImportError:
pass
warnings.warn('Could not import genbmm. '
'However, genbmm is only used for CUDA operations.')

from .semirings import LogSemiring
from .semirings.fast_semirings import broadcast
Expand Down Expand Up @@ -97,9 +100,7 @@ def _dp_scan(self, log_potentials, lengths=None, force_grad=False):
# Create finalizing paths.
point = (l + M) // 2

charta[1][:, b, point:, 1, ind, :, :, Mid] = semiring.one_(
charta[1][:, b, point:, 1, ind, :, :, Mid]
)
charta[1][:, b, point:, 1, ind, :, :, Mid] = charta[1][:, b, point:, 1, ind, :, :, Mid].fill_(0)

for b in range(lengths.shape[0]):
point = (lengths[b] + M) // 2
Expand Down

0 comments on commit b1d8463

Please sign in to comment.