Skip to content

Commit

Permalink
Merge e65832e into e51fecc
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnReid committed Oct 1, 2021
2 parents e51fecc + e65832e commit 4c67de1
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
19 changes: 19 additions & 0 deletions tests/test_alignment_crf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import torch
import torch_struct
import warnings


def test_alignment_crf():
batch, N, M = 1, 4, 5
log_potentials = torch.rand(batch, N, M, 3)

if torch.cuda.is_available():
log_potentials = log_potentials.cuda()
else:
warnings.warn('Could not move log potentials to CUDA device. '
'Will not test marginals.')

dist = torch_struct.AlignmentCRF(log_potentials)
assert (N, M, 3) == dist.argmax[0].shape
if torch.cuda.is_available():
assert (N, M, 3) == dist.marginals[0].shape
9 changes: 5 additions & 4 deletions torch_struct/alignment.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import torch
from .helpers import _Struct
import math
import warnings

try:
import genbmm

except ImportError:
pass
warnings.warn('Could not import genbmm. '
'However, genbmm is only used for CUDA operations.')

from .semirings import LogSemiring
from .semirings.fast_semirings import broadcast
Expand Down Expand Up @@ -97,9 +100,7 @@ def _dp_scan(self, log_potentials, lengths=None, force_grad=False):
# Create finalizing paths.
point = (l + M) // 2

charta[1][:, b, point:, 1, ind, :, :, Mid] = semiring.one_(
charta[1][:, b, point:, 1, ind, :, :, Mid]
)
charta[1][:, b, point:, 1, ind, :, :, Mid] = charta[1][:, b, point:, 1, ind, :, :, Mid].fill_(0)

for b in range(lengths.shape[0]):
point = (lengths[b] + M) // 2
Expand Down

0 comments on commit 4c67de1

Please sign in to comment.