
Commit

Merge 5cde02d into c885fba
erksch committed May 9, 2023
2 parents c885fba + 5cde02d commit c79beaa
Showing 4 changed files with 148 additions and 11 deletions.
5 changes: 5 additions & 0 deletions README.rst
@@ -82,3 +82,8 @@ Running linter

Run ``flake8`` in the project root directory. This will also run ``mypy``,
thanks to the ``flake8-mypy`` package.

TorchScript
-----------

The ``CRF`` module is compatible with TorchScript on PyTorch ``>=1.10.0``.
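A minimal usage sketch (``num_tags`` and the tensor shapes below are
illustrative)::

    import torch
    from torchcrf import CRF

    num_tags = 5
    crf = CRF(num_tags)  # sequence-first layout by default
    scripted = torch.jit.script(crf)

    # emissions: (seq_length, batch_size, num_tags); tags: (seq_length, batch_size)
    emissions = torch.randn(3, 2, num_tags)
    tags = torch.randint(num_tags, (3, 2))

    llh = scripted(emissions, tags)     # log likelihood, as in eager mode
    best = scripted.decode(emissions)   # List[List[int]]
    torch.jit.save(scripted, 'crf.pt')  # serialize the scripted module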
2 changes: 1 addition & 1 deletion requirements-test.txt
@@ -1,6 +1,6 @@
# This only installs PyTorch with a specific CUDA version which may not be
# compatible with yours. If so, install PyTorch with the correct CUDA version
# as instructed on https://pytorch.org/get-started/locally/
-torch
+torch>=1.10.0
pytest==3.2.5
pytest-cov==2.5.1
127 changes: 127 additions & 0 deletions tests/test_crf.py
@@ -467,3 +467,130 @@ def test_first_timestep_mask_is_not_all_on(self):
        with pytest.raises(ValueError) as excinfo:
            crf.decode(emissions, mask=mask)
        assert 'mask of the first timestep must all be on' in str(excinfo.value)

class TestTorchScript:
    def test_torch_scriptable(self):
        crf = make_crf()
        scripted_module = torch.jit.script(crf)
        assert hasattr(scripted_module, 'decode')

    def test_scripted_forward(self):
        # Test the default case
        crf = make_crf()
        crf_script = torch.jit.script(crf)
        seq_length, batch_size = 3, 2
        # shape: (seq_length, batch_size, num_tags)
        emissions = make_emissions(crf, seq_length, batch_size)
        # shape: (seq_length, batch_size)
        tags = make_tags(crf, seq_length, batch_size)
        # shape: (seq_length, batch_size)
        mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.uint8).transpose(0, 1)
        llh = crf(emissions, tags, mask=mask)
        llh_scripted = crf_script(emissions, tags, mask=mask)
        assert torch.equal(llh, llh_scripted), \
            f"scripted crf forward output {llh_scripted} " \
            f"does not match non-scripted forward output {llh}"

        # Test that scripted forward works without a mask
        llh_no_mask = crf(emissions, tags)
        llh_no_mask_script = crf_script(emissions, tags)
        assert torch.equal(llh_no_mask, llh_no_mask_script), \
            f"scripted crf forward output {llh_no_mask_script} " \
            f"does not match non-scripted forward output {llh_no_mask}"
        # No mask is equivalent to an all-ones mask
        llh_mask = crf(emissions, tags, mask=torch.ones_like(tags).byte())
        llh_mask_script = crf_script(emissions, tags, mask=torch.ones_like(tags).byte())
        assert torch.equal(llh_mask, llh_mask_script), \
            f"scripted crf forward output {llh_mask_script} " \
            f"does not match non-scripted forward output {llh_mask}"

        # Test scripted forward in a batched setting
        batch_size = 10
        # shape: (seq_length, batch_size, num_tags)
        emissions_batch = make_emissions(crf, batch_size=batch_size)
        # shape: (seq_length, batch_size)
        tags_batch = make_tags(crf, batch_size=batch_size)
        llh = crf(emissions_batch, tags_batch)
        llh_script = crf_script(emissions_batch, tags_batch)
        assert torch.equal(llh_script, llh), \
            f"scripted crf forward output {llh_script} " \
            f"does not match non-scripted forward output {llh}"

        # Test scripted forward with reduction 'none', 'mean', and 'token_mean'
        # shape: (seq_length, batch_size, num_tags)
        emissions = make_emissions(crf)
        # shape: (seq_length, batch_size)
        tags = make_tags(crf)
        llh = crf(emissions, tags, reduction='none')
        llh_script = crf_script(emissions, tags, reduction='none')
        assert torch.equal(llh_script, llh), \
            f"scripted crf forward output {llh_script} " \
            f"does not match non-scripted forward output {llh}"
        llh = crf(emissions, tags, reduction='mean')
        llh_script = crf_script(emissions, tags, reduction='mean')
        assert torch.equal(llh_script, llh), \
            f"scripted crf forward output {llh_script} " \
            f"does not match non-scripted forward output {llh}"

        mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.uint8).transpose(0, 1)
        llh = crf(emissions, tags, mask=mask, reduction='token_mean')
        llh_script = crf_script(emissions, tags, mask=mask, reduction='token_mean')
        assert torch.equal(llh_script, llh), \
            f"scripted crf forward output {llh_script} " \
            f"does not match non-scripted forward output {llh}"

        # Test scripted forward in batch-first mode
        crf_bf = make_crf(batch_first=True)
        # Copy parameter values from the non-batch-first CRF; requires_grad must be
        # False to avoid a runtime error from an in-place operation on a leaf variable
        crf_bf.start_transitions.requires_grad_(False).copy_(crf.start_transitions)
        crf_bf.end_transitions.requires_grad_(False).copy_(crf.end_transitions)
        crf_bf.transitions.requires_grad_(False).copy_(crf.transitions)
        crf_bf_script = torch.jit.script(crf_bf)
        # shape: (batch_size, seq_length, num_tags)
        emissions = emissions.transpose(0, 1)
        # shape: (batch_size, seq_length)
        tags = tags.transpose(0, 1)
        llh_bf = crf_bf(emissions, tags)
        llh_bf_script = crf_bf_script(emissions, tags)
        assert torch.equal(llh_bf_script, llh_bf), \
            f"scripted crf forward output {llh_bf_script} " \
            f"does not match non-scripted forward output {llh_bf}"

    def test_scripted_decode(self):
        # Test decoding with a mask
        crf = make_crf()
        crf_script = torch.jit.script(crf)

        seq_length, batch_size = 3, 2
        # shape: (seq_length, batch_size, num_tags)
        emissions = make_emissions(crf, seq_length, batch_size)
        # shape: (seq_length, batch_size)
        mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.uint8).transpose(0, 1)
        best_tags = crf.decode(emissions, mask=mask)
        best_tags_scripted = crf_script.decode(emissions, mask=mask)
        assert best_tags == best_tags_scripted, \
            f"scripted decode output {best_tags_scripted} " \
            f"does not match non-scripted output {best_tags}"

        # Test decoding without a mask
        best_tags_no_mask = crf.decode(emissions)
        best_tags_no_mask_scripted = crf_script.decode(emissions)
        assert best_tags_no_mask == best_tags_no_mask_scripted, \
            f"scripted decode output {best_tags_no_mask_scripted} " \
            f"does not match non-scripted output {best_tags_no_mask}"

        # Test batched decode
        batch_size, seq_length = 2, 3
        # shape: (seq_length, batch_size, num_tags)
        emissions_batched = make_emissions(crf, seq_length, batch_size)
        # shape: (seq_length, batch_size)
        mask_batched = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.uint8).transpose(0, 1)
        batched = crf.decode(emissions_batched, mask=mask_batched)
        batched_scripted = crf_script.decode(emissions_batched, mask=mask_batched)
        assert batched == batched_scripted, \
            f"scripted decode output {batched_scripted} " \
            f"does not match non-scripted output {batched}"

        # Test batch-first decode
        crf_bf = make_crf(batch_first=True)
        # Copy parameter values from the non-batch-first CRF; requires_grad must be
        # False to avoid a runtime error from an in-place operation on a leaf variable
        crf_bf.start_transitions.requires_grad_(False).copy_(crf.start_transitions)
        crf_bf.end_transitions.requires_grad_(False).copy_(crf.end_transitions)
        crf_bf.transitions.requires_grad_(False).copy_(crf.transitions)
        crf_bf_script = torch.jit.script(crf_bf)
        # shape: (batch_size, seq_length, num_tags)
        emissions = emissions.transpose(0, 1)
        best_tags_bf = crf_bf.decode(emissions)
        best_tags_bf_script = crf_bf_script.decode(emissions)
        assert best_tags_bf == best_tags_bf_script, \
            f"scripted decode output {best_tags_bf_script} " \
            f"does not match non-scripted decode output {best_tags_bf}"
25 changes: 15 additions & 10 deletions torchcrf/__init__.py
@@ -114,7 +114,9 @@ def forward(
        assert reduction == 'token_mean'
        return llh.sum() / mask.type_as(emissions).sum()

-    def decode(self, emissions: torch.Tensor,
+    @torch.jit.export
+    def decode(self,
+               emissions: torch.Tensor,
               mask: Optional[torch.ByteTensor] = None) -> List[List[int]]:
        """Find the most likely tag sequence using Viterbi algorithm.
@@ -151,16 +153,18 @@ def _validate(
                f'got {emissions.size(2)}')

        if tags is not None:
-            if emissions.shape[:2] != tags.shape:
+            if emissions.shape[0] != tags.shape[0] or emissions.shape[1] != tags.shape[1]:
                raise ValueError(
                    'the first two dimensions of emissions and tags must match, '
-                    f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')
+                    f'got {(emissions.shape[0], emissions.shape[1])} and {(tags.shape[0], tags.shape[1])}'
+                )

        if mask is not None:
-            if emissions.shape[:2] != mask.shape:
+            if emissions.shape[0] != mask.shape[0] or emissions.shape[1] != mask.shape[1]:
                raise ValueError(
                    'the first two dimensions of emissions and mask must match, '
-                    f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
+                    f'got {(emissions.shape[0], emissions.shape[1])} and {(mask.shape[0], mask.shape[1])}'
+                )
        no_empty_seq = not self.batch_first and mask[0].all()
        no_empty_seq_bf = self.batch_first and mask[:, 0].all()
        if not no_empty_seq and not no_empty_seq_bf:
@@ -173,7 +177,7 @@ def _compute_score(
        # tags: (seq_length, batch_size)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and tags.dim() == 2
-        assert emissions.shape[:2] == tags.shape
+        assert emissions.shape[0] == mask.shape[0] and emissions.shape[1] == mask.shape[1]
        assert emissions.size(2) == self.num_tags
        assert mask.shape == tags.shape
        assert mask[0].all()
@@ -270,7 +274,7 @@ def _viterbi_decode(self, emissions: torch.FloatTensor,
        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
-        history = []
+        history: List[torch.Tensor] = []

        # score is a tensor of size (batch_size, num_tags) where for every batch,
        # value at column j stores the score of the best tag sequence so far that ends
@@ -313,17 +317,18 @@ def _viterbi_decode(self, emissions: torch.FloatTensor,

        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
-        best_tags_list = []
+        best_tags_list: List[List[int]] = []

        for idx in range(batch_size):
            # Find the tag which maximizes the score at the last timestep; this is our best tag
            # for the last timestep
            _, best_last_tag = score[idx].max(dim=0)
-            best_tags = [best_last_tag.item()]
+            best_tags: List[int] = []
+            best_tags.append(best_last_tag.item())

            # We trace back where the best last tag comes from, append that to our best tag
            # sequence, and trace it back again, and so on
-            for hist in reversed(history[:seq_ends[idx]]):
+            for hist in history[:seq_ends[idx]][::-1]:
                best_last_tag = hist[idx][best_tags[-1]]
                best_tags.append(best_last_tag.item())
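The annotations and the ``[::-1]`` slice above work around two TorchScript constraints: an empty list literal is typed ``List[Tensor]`` unless annotated otherwise, and ``reversed()`` does not compile, while a negative-step list slice does. A standalone sketch of the same backtracking pattern (``backtrace`` and its arguments are illustrative, not the library's internals):

    from typing import List

    import torch

    @torch.jit.script
    def backtrace(history: List[torch.Tensor], last_tag: int) -> List[int]:
        # Without the annotation, `tags = []` would be inferred as List[Tensor].
        tags: List[int] = [last_tag]
        # reversed(history) is not scriptable; slice with a negative step instead.
        for hist in history[::-1]:
            tags.append(int(hist[tags[-1]]))
        return tags

    print(backtrace([torch.tensor([1, 0]), torch.tensor([0, 1])], 1))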
