Hybrid autoregressive transducer (HAT) (#1244)
* removed workflow

* minor fix in nbest str representation

* initial commit for HAT loss

* add HAT loss

* remove unnecessary style changes

* fix style issue

* put hat option at end
desh2608 committed Dec 19, 2023
1 parent 2989b0b commit 7711d16
Showing 3 changed files with 282 additions and 9 deletions.
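Background for the diff below: HAT (https://arxiv.org/abs/2003.07705) replaces the single softmax that RNN-T places over all C joiner outputs with a Bernoulli distribution for the blank symbol plus a softmax over the C - 1 non-blank symbols, scaled by the remaining probability mass. A minimal sketch of that factorization (illustrative tensor only; it assumes blank is symbol 0, as the new code asserts):

import torch

# One joiner output vector; index 0 is the blank score.
logits = torch.randn(10)

# Blank: a Bernoulli on the blank score alone.
logp_blank = torch.nn.functional.logsigmoid(logits[0])

# Non-blank: softmax over the remaining scores, scaled by the leftover
# mass; logsigmoid(x) - x == log(1 - sigmoid(x)).
logp_nonblank = torch.log_softmax(logits[1:], dim=0) + logp_blank - logits[0]

# The factorized distribution still sums to 1.
print(logp_blank.exp() + logp_nonblank.exp().sum())  # ~= 1.0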
2 changes: 1 addition & 1 deletion k2/python/k2/nbest.py
@@ -93,7 +93,7 @@ def __init__(self,

def __str__(self):
s = 'Nbest('
-        s += f'num_seqs:{self.shape.dim0()}, '
+        s += f'num_seqs:{self.shape.dim0}, '
s += f'num_fsas:{self.fsa.shape[0]})'
return s

221 changes: 213 additions & 8 deletions k2/python/k2/rnnt_loss.py
@@ -989,6 +989,194 @@ def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor):
return torch.gather(src, 2, index)


def get_hat_logprobs_pruned(
logits: Tensor,
symbols: Tensor,
ranges: Tensor,
termination_symbol: int,
boundary: Tensor,
rnnt_type: str = "regular",
) -> Tuple[Tensor, Tensor]:
"""Construct px, py for mutual_information_recursion with pruned output.
This is a variant of get_rnnt_logprobs_pruned based on the Hybrid Autoregressive
Transducer (HAT) model proposed in https://arxiv.org/abs/2003.07705.
NOTE: We assume that the RNNT blank is the zeroth symbol.
Args:
logits:
The pruned output of joiner network, with shape (B, T, s_range, C)
symbols:
The symbol sequences, a LongTensor of shape [B][S], and elements in
{0..C-1}.
ranges:
A tensor containing the symbol ids for each frame that we want to keep.
It is a LongTensor of shape ``[B][T][s_range]``, where ``ranges[b,t,0]``
contains the begin symbol ``0 <= s <= S - s_range + 1``, such that
``logits[b,t,:,:]`` represents the logits with positions
``s, s + 1, ... s + s_range - 1``.
See docs in :func:`get_rnnt_prune_ranges` for more details of what
ranges contains.
termination_symbol:
the termination (blank) symbol, which must be 0 for the HAT loss
(the blank is assumed to be the zeroth symbol).
boundary:
an optional LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[0, 0, S, T]
if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
rnnt_type:
Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
`regular`: The regular rnnt, which takes you to the next frame only when
           you emit a blank (i.e., emitting a non-blank symbol does not
           advance the frame).
`modified`: A modified version of rnnt that takes you to the next frame
            whether you emit a blank or a non-blank symbol.
`constrained`: A version like the modified one that goes to the next
               frame when you emit a non-blank symbol, but this is done
               by "forcing" you to take the blank transition from the
               *next* context on the *current* frame, e.g. if we emit
               c given "a b" context, we are forced to emit "blank"
               given "b c" context on the current frame.
Returns:
(px, py) (the names are quite arbitrary)::
px: logprobs, of shape [B][S][T+1] if rnnt_type is regular,
[B][S][T] if rnnt_type is not regular.
py: logprobs, of shape [B][S+1][T]
in the recursion::
p[b,0,0] = 0.0
if rnnt_type == "regular":
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if rnnt_type != "regular":
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
.. where p[b][s][t] is the "joint score" of the pair of subsequences of
length s and t respectively. px[b][s][t] represents the probability of
extending the subsequences of length (s,t) by one in the s direction,
given the particular symbol, and py[b][s][t] represents the probability
of extending the subsequences of length (s,t) by one in the t direction,
i.e. of emitting the termination/next-frame symbol.
if `rnnt_type == "regular"`, px[:,:,T] equals -infinity, meaning on the
"one-past-the-last" frame we cannot emit any symbols.
This is simply a way of incorporating
the probability of the termination symbol on the last frame.
"""
# logits (B, T, s_range, C)
# symbols (B, S)
# ranges (B, T, s_range)
assert logits.ndim == 4, logits.shape
(B, T, s_range, C) = logits.shape
assert ranges.shape == (B, T, s_range), (ranges.shape, B, T, s_range)
(B, S) = symbols.shape
assert S >= 0, S
assert (
rnnt_type != "modified" or T >= S
), f"Modified transducer requires T >= S, but got T={T} and S={S}"
assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
assert termination_symbol == 0, f"Termination symbol must be 0, but got {termination_symbol}"

# For blank symbol, log-prob is log-sigmoid of the score
logp_b = torch.nn.functional.logsigmoid(logits[..., 0])

# For non-blank, we will compute the log-probs using log-softmax, for which we
# will need the following normalization factor.
nb_normalizers = torch.logsumexp(logits[..., 1:], dim=3)

# Additionally, to ensure that the probs of blank and non-blank sum to 1, we
# need to add the following term to the log-probs of non-blank symbols. This
# is equivalent to log(1 - sigmoid(logits[..., 0])).
nb_shift = logp_b - logits[..., 0]

symbols_with_terminal = torch.cat(
(
symbols,
torch.tensor(
[termination_symbol] * B,
dtype=torch.int64,
device=symbols.device,
).reshape((B, 1)),
),
dim=1,
)

# (B, T, s_range)
pruned_symbols = torch.gather(
symbols_with_terminal.unsqueeze(1).expand((B, T, S + 1)),
dim=2,
index=ranges,
)

# (B, T, s_range)
px = torch.gather(
logits, dim=3, index=pruned_symbols.reshape(B, T, s_range, 1)
).squeeze(-1)
px = px - nb_normalizers + nb_shift

# (B, T, S + 1) with indices beyond s_range in dim 2 filled with -inf
px = torch.cat(
(
px,
torch.full(
(B, T, S + 1 - s_range),
float("-inf"),
device=px.device,
dtype=px.dtype,
),
),
dim=2,
)

# (B, T, S) with indices outside the pruned range in dim 2 filled with -inf
px = _roll_by_shifts(px, ranges[:, :, 0])[:, :, :S]

px = px.permute((0, 2, 1))

if rnnt_type == "regular":
px = torch.cat(
(
px,
torch.full((B, S, 1), float("-inf"), device=px.device, dtype=px.dtype),
),
dim=2,
) # now [B][S][T+1]; index [:,:,T] is -inf.

py = logp_b.clone() # (B, T, s_range)
# py holds the blank log-probs. Blank is modeled as a Bernoulli, so the
# log-sigmoid above is already normalized; no normalizer or shift is needed.
# py denotes the horizontal arcs on the RNNT lattice (blank transitions).

# (B, T, S + 1) with indices beyond s_range in dim 2 filled with -inf
py = torch.cat(
(
py,
torch.full(
(B, T, S + 1 - s_range),
float("-inf"),
device=py.device,
dtype=py.dtype,
),
),
dim=2,
)

# (B, T, S + 1) with indices outside the pruned range in dim 2 filled with -inf
py = _roll_by_shifts(py, ranges[:, :, 0])
# (B, S + 1, T)
py = py.permute((0, 2, 1))

if rnnt_type == "regular":
px = fix_for_boundary(px, boundary)
elif rnnt_type == "constrained":
px += py[:, 1:, :]

return (px, py)


def get_rnnt_logprobs_pruned(
logits: Tensor,
symbols: Tensor,
@@ -1169,6 +1357,7 @@ def rnnt_loss_pruned(
rnnt_type: str = "regular",
delay_penalty: float = 0.0,
reduction: Optional[str] = "mean",
use_hat_loss: bool = False,
) -> Tensor:
"""A RNN-T loss with pruning, which uses the output of a pruned 'joiner'
network as input, i.e. a 4 dimensions tensor with shape (B, T, s_range, C),
@@ -1219,19 +1408,35 @@
`mean`: apply `torch.mean` over the batches.
`sum`: the output will be summed.
Default: `mean`
use_hat_loss:
If True, we compute the Hybrid Autoregressive Transducer (HAT) loss from
https://arxiv.org/abs/2003.07705. This is a variant of RNN-T that models
the blank distribution separately as a Bernoulli distribution, and the
non-blanks are modeled as a multinomial. This formulation may be useful
for performing internal LM estimation, as described in the paper.
Returns:
If reduction is `none`, returns a tensor of shape (B,), containing the
total RNN-T loss values for each sequence of the batch, otherwise a scalar
with the reduction applied.
"""
-    px, py = get_rnnt_logprobs_pruned(
-        logits=logits,
-        symbols=symbols,
-        ranges=ranges,
-        termination_symbol=termination_symbol,
-        boundary=boundary,
-        rnnt_type=rnnt_type,
-    )
+    if not use_hat_loss:
+        px, py = get_rnnt_logprobs_pruned(
+            logits=logits,
+            symbols=symbols,
+            ranges=ranges,
+            termination_symbol=termination_symbol,
+            boundary=boundary,
+            rnnt_type=rnnt_type,
+        )
+    else:
+        px, py = get_hat_logprobs_pruned(
+            logits=logits,
+            symbols=symbols,
+            ranges=ranges,
+            termination_symbol=termination_symbol,
+            boundary=boundary,
+            rnnt_type=rnnt_type,
+        )

if delay_penalty > 0.0:
B, S, T0 = px.shape
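The docstring above notes that this factorization may be useful for internal LM estimation. No ILM helper ships with this commit; the sketch below is only a hypothetical illustration of the idea from the HAT paper, assuming a joiner of the form tanh(am + lm) (as in the test that follows) with blank at index 0:

import torch

def hat_internal_lm_logprobs(lm: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper, not part of this commit: estimate the internal
    # LM by dropping the acoustic input to the joiner (am = 0) and
    # normalizing over the non-blank outputs only, per the HAT paper.
    logits = torch.tanh(lm)  # (B, S + 1, C)
    return torch.log_softmax(logits[..., 1:], dim=-1)  # (B, S + 1, C - 1)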
68 changes: 68 additions & 0 deletions k2/python/tests/rnnt_loss_test.py
@@ -856,6 +856,74 @@ def test_rnnt_loss_pruned_small_s_range(self):
), f"Pruned loss is inf for r={r}, S={S}, T={T}."
print(f"Pruned loss with range {r} : {pruned_loss}")

def test_hat_loss_pruned(self):
B = 4
T = 300
S = 50
C = 10

frames = torch.randint(S, T, (B,))
seq_length = torch.randint(3, S - 1, (B,))
T = torch.max(frames)
S = torch.max(seq_length)

am_ = torch.randn((B, T, C), dtype=torch.float64)
lm_ = torch.randn((B, S + 1, C), dtype=torch.float64)
symbols_ = torch.randint(1, C, (B, S))
terminal_symbol = 0

boundary_ = torch.zeros((B, 4), dtype=torch.int64)
boundary_[:, 2] = seq_length
boundary_[:, 3] = frames
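# each row of boundary is [begin_symbol, begin_frame, end_symbol, end_frame]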

for rnnt_type in ["regular", "modified", "constrained"]:
for device in self.devices:
# move the test inputs to the current device
am = am_.to(device)
lm = lm_.to(device)
symbols = symbols_.to(device)
boundary = boundary_.to(device)

# compute the simple loss with gradients; px_grad/py_grad define the pruning bounds
k2_simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=terminal_symbol,
boundary=boundary,
rnnt_type=rnnt_type,
return_grad=True,
reduction="none",
)

for r in range(2, 50, 5):
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=r,
)
# (B, T, r, C)
pruned_am, pruned_lm = k2.do_rnnt_pruning(
am=am, lm=lm, ranges=ranges
)

logits = pruned_am + pruned_lm
# nonlinear transform
logits = torch.tanh(logits)

pruned_loss = k2.rnnt_loss_pruned(
logits=logits,
symbols=symbols,
ranges=ranges,
termination_symbol=terminal_symbol,
boundary=boundary,
rnnt_type=rnnt_type,
reduction="none",
use_hat_loss=True,
)
print(f"Pruned HAT loss with range {r} : {pruned_loss}")

# Check that training with an empty reference does not cause a crash.
def _test_rnnt_loss_empty_reference(self):
B = 1

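For reference, the end-to-end call sequence exercised by the new test condenses to the sketch below (shapes and the tanh joiner follow the test above; every k2 function shown appears in this commit's test):

import torch
import k2

B, T, S, C = 2, 100, 20, 10
am = torch.randn(B, T, C, dtype=torch.float64)      # encoder projection
lm = torch.randn(B, S + 1, C, dtype=torch.float64)  # predictor projection
symbols = torch.randint(1, C, (B, S))
boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.int64)

# 1. The simple loss yields px/py gradients that define the pruning bounds.
_, (px_grad, py_grad) = k2.rnnt_loss_simple(
    lm=lm, am=am, symbols=symbols, termination_symbol=0,
    boundary=boundary, return_grad=True, reduction="none",
)

# 2. Keep s_range symbol positions per frame and prune the joiner inputs.
ranges = k2.get_rnnt_prune_ranges(
    px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=5,
)
pruned_am, pruned_lm = k2.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)

# 3. use_hat_loss=True selects the new HAT loss; blank must be symbol 0.
loss = k2.rnnt_loss_pruned(
    logits=torch.tanh(pruned_am + pruned_lm),
    symbols=symbols, ranges=ranges, termination_symbol=0,
    boundary=boundary, use_hat_loss=True,
)
print(loss)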