-
Notifications
You must be signed in to change notification settings - Fork 22
/
rnnt_loss.py
1442 lines (1292 loc) · 57.4 KB
/
rnnt_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2021 Xiaomi Corp. (author: Daniel Povey, Wei Kang)
#
# See ../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from torch import Tensor
from typing import Optional, Tuple, Union
from .mutual_information import mutual_information_recursion
def fix_for_boundary(px: Tensor, boundary: Optional[Tensor] = None) -> Tensor:
    """
    Insert -inf's into `px` in appropriate places if `boundary` is not None.

    If boundary is None and rnnt_type == "regular", px[:,:,-1] is already
    -infinity (the caller appends that column).  If boundary is specified,
    we additionally need px[b,:,boundary[b,3]] to be -infinity, i.e. no
    symbol may be emitted on the frame just past each sequence's end frame.

    Args:
      px: a Tensor of shape [B][S][T+1] (this function is only
          called if rnnt_type == "regular", see other docs for `rnnt_type`).
          px is modified in-place and returned.
      boundary: None, or a LongTensor of shape [B][4] containing
          [s_begin, t_begin, s_end, t_end] per batch element; only t_end
          (column 3) is used here.
    Returns:
      `px` itself, modified in place when `boundary` is given.
    """
    if boundary is None:
        return px
    B, S, _ = px.shape
    # One end-frame index per batch element, broadcast over every symbol
    # position so that the whole column px[b, :, t_end[b]] is overwritten.
    # A trailing size of 1 suffices: every (b, s) row gets a single write.
    t_end = boundary[:, 3].reshape(B, 1, 1).expand(B, S, 1)
    return px.scatter_(dim=2, index=t_end, value=float("-inf"))
def get_rnnt_logprobs(
    lm: Tensor,
    am: Tensor,
    symbols: Tensor,
    termination_symbol: int,
    rnnt_type: str = "regular",
    boundary: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
    """
    Convert the simple RNN-T problem (joiner network is plain addition) into
    the compact (px, py) form consumed, together with boundaries, by
    mutual_information_recursion().

    The unnormalized logits for context position s and frame t are
    ``lm[b, s] + am[b, t]``.  This function is used by rnnt_loss_simple(),
    but may also be called directly.

    Args:
      lm:
        Language-model part of the unnormalized symbol log-probs, added to
        the acoustic part before normalizing.  Shape ``[B][S+1][C]``, where
        B is the batch size, S the maximum symbol-sequence length (possibly
        including an EOS symbol) and C the vocabulary size, including the
        termination/next-frame symbol.  ``lm[b][s]`` scores the next symbol
        given the symbols strictly before position s; row S is still needed
        because the termination symbol may be emitted there.
      am:
        Acoustic-model part of the unnormalized symbol log-probs, added to
        the language-model part before normalizing.  Shape ``[B][T][C]``,
        where T is the maximum number of frames.
      symbols:
        LongTensor of shape ``[B][S]`` with the symbol at each position.
      termination_symbol:
        Index of the termination symbol; must be in ``{0..C-1}``.
      rnnt_type:
        Which RNN-T path topology to use: ``regular``, ``modified`` or
        ``constrained``.
          ``regular``: only emitting a blank moves to the next frame;
            emitting a symbol stays on the current frame.
          ``modified``: emitting either a blank or a non-blank symbol moves
            to the next frame.
          ``constrained``: like ``modified``, but the move to the next frame
            after a non-blank symbol is realised by *forcing* the blank
            transition from the *next* context on the *current* frame, e.g.
            after emitting c in context "a b" we are forced to emit blank
            in context "b c" on the current frame.
      boundary:
        Optional LongTensor of shape ``[B, 4]`` with rows
        ``[begin_symbol, begin_frame, end_symbol, end_frame]``; treated as
        ``[0, 0, S, T]`` when not supplied.  begin_symbol and begin_frame
        will most likely be zero.  In this function it only affects the
        ``regular`` case.

    Returns:
      ``(px, py)`` (the names are arbitrary):
        px: log-probs of emitting the reference symbol, shape
            ``[B][S][T+1]`` for ``regular`` and ``[B][S][T]`` otherwise.
            For ``constrained`` the forced-blank log-prob is already added.
        py: log-probs of emitting the termination symbol, shape
            ``[B][S+1][T]``.

      They are used in the recursion::

        p[b,0,0] = 0.0
        if rnnt_type == "regular":
            p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                               p[b,s,t-1] + py[b,s,t-1])
        if rnnt_type != "regular":
            p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
                               p[b,s,t-1] + py[b,s,t-1])

      where p[b][s][t] is the joint score of the symbol prefix of length s
      and the frame prefix of length t.  px[b][s][t] extends the pair by one
      in the s direction (emitting the reference symbol) and py[b][s][t]
      extends it by one in the t direction (emitting the termination /
      next-frame symbol).  For ``regular``, px[:,:,T] is -infinity: no
      symbol can be emitted on the "one-past-the-last" frame, which is how
      the termination probability on the last frame is incorporated.
    """
    assert lm.ndim == 3, lm.ndim
    assert am.ndim == 3, am.ndim
    assert lm.shape[0] == am.shape[0], (lm.shape[0], am.shape[0])
    assert lm.shape[2] == am.shape[2], (lm.shape[2], am.shape[2])

    B, T, C = am.shape
    S = lm.shape[1] - 1
    assert symbols.shape == (B, S), symbols.shape
    assert S >= 1, S
    assert T >= S, (T, S)
    assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type

    # Shift each row by its maximum before exponentiating so that exp()
    # neither overflows nor underflows; the shift is added back below.
    am_shift = torch.max(am, dim=2, keepdim=True).values  # [B][T][1]
    lm_shift = torch.max(lm, dim=2, keepdim=True).values  # [B][S+1][1]
    am_exp = (am - am_shift).exp()
    lm_exp = (lm - lm_shift).exp()

    # Log of the softmax denominator for every (s, t) pair.  `tiny` keeps
    # log() finite if a sum underflows to zero.
    pairwise_sum = torch.matmul(lm_exp, am_exp.transpose(1, 2))
    log_norm = (pairwise_sum + torch.finfo(am_exp.dtype).tiny).log()
    # Undo the shifts applied above.
    log_norm = log_norm + lm_shift + am_shift.transpose(1, 2)  # [B][S+1][T]

    # Acoustic score of the reference symbol for every (s, t).
    sym_index = symbols.reshape(B, S, 1, 1).expand(B, S, T, 1)
    px_am = torch.gather(
        am.unsqueeze(1).expand(B, S, T, C), dim=3, index=sym_index
    ).squeeze(-1)  # [B][S][T]
    if rnnt_type == "regular":
        # Extra frame T on which no symbol may be emitted.
        neg_inf_col = torch.full(
            (B, S, 1), float("-inf"), device=px_am.device, dtype=px_am.dtype
        )
        px_am = torch.cat((px_am, neg_inf_col), dim=2)  # [B][S][T+1]

    # Language-model score of the reference symbol, broadcast over frames.
    px_lm = torch.gather(
        lm[:, :S], dim=2, index=symbols.unsqueeze(-1)
    )  # [B][S][1]
    px = px_am + px_lm
    # The -inf column (regular case) is left as is.
    px[:, :, :T] -= log_norm[:, :S, :]

    # Termination-symbol score for every (s, t).
    py_am = am[:, :, termination_symbol].unsqueeze(1)  # [B][1][T]
    py_lm = lm[:, :, termination_symbol].unsqueeze(2)  # [B][S+1][1]
    py = py_am + py_lm - log_norm  # [B][S+1][T]

    if rnnt_type == "constrained":
        # Add the forced blank from the next context on the current frame.
        px += py[:, 1:, :]
    elif rnnt_type == "regular":
        px = fix_for_boundary(px, boundary)
    return px, py
def rnnt_loss_simple(
    lm: Tensor,
    am: Tensor,
    symbols: Tensor,
    termination_symbol: int,
    boundary: Optional[Tensor] = None,
    rnnt_type: str = "regular",
    delay_penalty: float = 0.0,
    reduction: Optional[str] = "mean",
    return_grad: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]]:
    """RNN-T loss for the simple case where the 'joiner' network is just
    addition of `lm` and `am`.

    Args:
      lm:
        Language-model part of the unnormalized symbol log-probs, shape
        (B, S+1, C), i.e. batch, symbol_seq_len+1, num_classes.
      am:
        Acoustic-model part of the unnormalized symbol log-probs, shape
        (B, T, C), i.e. batch, frame, num_classes.
      symbols:
        The symbol sequences, a LongTensor of shape [B][S] with elements in
        {0..C-1}.
      termination_symbol:
        The termination symbol, with 0 <= termination_symbol < C.
      boundary:
        Optional LongTensor of shape [B, 4] with rows
        [begin_symbol, begin_frame, end_symbol, end_frame]; treated as
        [0, 0, S, T] when not supplied.  begin_symbol and begin_frame will
        most likely be zero.
      rnnt_type:
        The RNN-T path topology: `regular`, `modified` or `constrained`.
          `regular`: only a blank moves to the next frame; emitting a symbol
            stays on the current frame.
          `modified`: emitting a blank or a non-blank symbol moves to the
            next frame.
          `constrained`: like `modified`, but the move after a non-blank
            symbol is realised by forcing the blank transition from the
            *next* context on the *current* frame, e.g. after emitting c in
            context "a b" we must emit blank in context "b c" on the
            current frame.
      delay_penalty:
        A constant that penalizes symbol delay.  This may be needed when
        training with time masking, so that the masking does not encourage
        the network to delay symbols.  Only applied when > 0.
        See https://github.com/k2-fsa/k2/issues/955 for more details.
      reduction:
        The reduction applied to the per-sequence losses: `none`, `mean` or
        `sum`.
          `none`: no reduction.
          `mean`: `torch.mean` over the batch.
          `sum`: sum over the batch.
        Default: `mean`.  Any other value raises ValueError.
      return_grad:
        Whether to also return the grads of px and py.  These grads stand
        for occupation probabilities; they are the output of the backward
        pass with a `fake gradient`, equal to what you would get from
        `torch.autograd.grad((-loss.sum()), [px, py])` with the loss taken
        with reduction "none".  This is useful for implementing the pruned
        version of the RNN-T loss.

    Returns:
      If return_grad is False: a tensor of shape (B,) with the RNN-T loss of
      each sequence when reduction is "none", otherwise a scalar with the
      reduction applied.
      If return_grad is True: a tuple (loss, (px_grad, py_grad)), where the
      grads come from the backward pass with the `fake gradient` described
      above.
    """
    px, py = get_rnnt_logprobs(
        lm=lm,
        am=am,
        symbols=symbols,
        termination_symbol=termination_symbol,
        boundary=boundary,
        rnnt_type=rnnt_type,
    )

    if delay_penalty > 0.0:
        B, _, num_cols = px.shape
        # In the regular case px carries one extra (-inf) frame column.
        num_frames = num_cols - 1 if rnnt_type == "regular" else num_cols
        if boundary is not None:
            center = (boundary[:, 3] - 1) / 2
        else:
            center = torch.tensor(
                (num_frames - 1) / 2, dtype=px.dtype, device=px.device,
            ).expand(B, 1, 1)
        frame_idx = torch.arange(num_cols, device=px.device).reshape(
            1, 1, num_cols
        )
        # Frames before the middle of the utterance get a bonus and frames
        # after it get a penalty, proportional to the distance from the middle.
        penalty = (center.reshape(B, 1, 1) - frame_idx) * delay_penalty
        px += penalty.to(px.dtype)

    result = mutual_information_recursion(
        px=px, py=py, boundary=boundary, return_grad=return_grad
    )
    if return_grad:
        negated_loss = result[0]
    else:
        negated_loss = result

    if reduction == "none":
        loss = -negated_loss
    elif reduction == "mean":
        loss = -torch.mean(negated_loss)
    elif reduction == "sum":
        loss = -torch.sum(negated_loss)
    else:
        raise ValueError(
            f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
        )

    if return_grad:
        return (loss, result[1])
    return loss
def get_rnnt_logprobs_joint(
    logits: Tensor,
    symbols: Tensor,
    termination_symbol: int,
    boundary: Optional[Tensor] = None,
    rnnt_type: str = "regular",
) -> Tuple[Tensor, Tensor]:
    """Convert the RNN-T problem with a full joiner output into the compact
    (px, py) form consumed, together with boundaries, by
    mutual_information_recursion().  Called from rnnt_loss().

    Args:
      logits:
        The output of the joiner network, shape (B, T, S + 1, C), i.e.
        batch, time_seq_len, symbol_seq_len+1, num_classes.
      symbols:
        LongTensor of shape [B][S] with the symbol at each position.
      termination_symbol:
        Index of the termination symbol; must be in {0..C-1}.
      boundary:
        Optional LongTensor of shape [B, 4] with rows
        [begin_symbol, begin_frame, end_symbol, end_frame]; treated as
        [0, 0, S, T] when not supplied.  begin_symbol and begin_frame will
        most likely be zero.  In this function it only affects the
        `regular` case.
      rnnt_type:
        The RNN-T path topology: `regular`, `modified` or `constrained`.
          `regular`: only a blank moves to the next frame; emitting a symbol
            stays on the current frame.
          `modified`: emitting a blank or a non-blank symbol moves to the
            next frame.
          `constrained`: like `modified`, but the move after a non-blank
            symbol is realised by forcing the blank transition from the
            *next* context on the *current* frame, e.g. after emitting c in
            context "a b" we must emit blank in context "b c" on the
            current frame.

    Returns:
      (px, py) (the names are arbitrary)::

        px: log-probs of the reference symbol, shape [B][S][T+1] for
            `regular` and [B][S][T] otherwise (with the forced-blank
            log-prob already added for `constrained`).
        py: log-probs of the termination symbol, shape [B][S+1][T].

      They are used in the recursion::

        p[b,0,0] = 0.0
        if rnnt_type == "regular":
            p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                               p[b,s,t-1] + py[b,s,t-1])
        if rnnt_type != "regular":
            p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
                               p[b,s,t-1] + py[b,s,t-1])

      where p[b][s][t] is the joint score of the symbol prefix of length s
      and the frame prefix of length t.  px extends the pair by one in the
      s direction and py by one in the t direction (emitting the
      termination / next-frame symbol).  For `regular`, px[:,:,T] is
      -infinity: nothing can be emitted on the "one-past-the-last" frame,
      which is how the termination probability on the last frame is
      incorporated.
    """
    assert logits.ndim == 4, logits.ndim
    B, T, S1, C = logits.shape
    S = S1 - 1
    assert symbols.shape == (B, S), symbols.shape
    assert S >= 1, S
    assert T >= S, (T, S)
    assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type

    # Log of the softmax denominator, rearranged to [B][S+1][T].
    log_norm = torch.logsumexp(logits, dim=3).permute((0, 2, 1))

    # Logit of the reference symbol for each (t, s < S), rearranged to
    # [B][S][T].
    sym_index = symbols.reshape(B, 1, S, 1).expand(B, T, S, 1)
    px = (
        torch.gather(logits, dim=3, index=sym_index)
        .squeeze(-1)
        .permute((0, 2, 1))
    )
    if rnnt_type == "regular":
        # Extra frame T on which no symbol may be emitted.
        neg_inf_col = torch.full(
            (B, S, 1), float("-inf"), device=px.device, dtype=px.dtype
        )
        px = torch.cat((px, neg_inf_col), dim=2)  # [B][S][T+1]
    px[:, :, :T] -= log_norm[:, :S, :]

    # Termination-symbol log-probs, [B][S+1][T].
    py = logits[:, :, :, termination_symbol].permute((0, 2, 1)) - log_norm

    if rnnt_type == "constrained":
        # Add the forced blank from the next context on the current frame.
        px += py[:, 1:, :]
    elif rnnt_type == "regular":
        px = fix_for_boundary(px, boundary)
    return px, py
def rnnt_loss(
    logits: Tensor,
    symbols: Tensor,
    termination_symbol: int,
    boundary: Optional[Tensor] = None,
    rnnt_type: str = "regular",
    delay_penalty: float = 0.0,
    reduction: Optional[str] = "mean",
) -> Tensor:
    """The standard RNN-T loss, computed from the 4-dimensional output of a
    'joiner' network.

    Args:
      logits:
        The output of the joiner network, shape (B, T, S + 1, C), i.e.
        batch, time_seq_len, symbol_seq_len+1, num_classes.
      symbols:
        The symbol sequences, a LongTensor of shape [B][S] with elements in
        {0..C-1}.
      termination_symbol:
        The termination symbol, with 0 <= termination_symbol < C.
      boundary:
        Optional LongTensor of shape [B, 4] with rows
        [begin_symbol, begin_frame, end_symbol, end_frame]; treated as
        [0, 0, S, T] when not supplied.  begin_symbol and begin_frame will
        most likely be zero.
      rnnt_type:
        The RNN-T path topology: `regular`, `modified` or `constrained`.
          `regular`: only a blank moves to the next frame; emitting a symbol
            stays on the current frame.
          `modified`: emitting a blank or a non-blank symbol moves to the
            next frame.
          `constrained`: like `modified`, but the move after a non-blank
            symbol is realised by forcing the blank transition from the
            *next* context on the *current* frame, e.g. after emitting c in
            context "a b" we must emit blank in context "b c" on the
            current frame.
      delay_penalty:
        A constant that penalizes symbol delay.  This may be needed when
        training with time masking, so that the masking does not encourage
        the network to delay symbols.  Only applied when > 0.
        See https://github.com/k2-fsa/k2/issues/955 for more details.
      reduction:
        The reduction applied to the per-sequence losses: `none`, `mean` or
        `sum`.
          `none`: no reduction.
          `mean`: `torch.mean` over the batch.
          `sum`: sum over the batch.
        Default: `mean`.  Any other value raises ValueError.

    Returns:
      If reduction is `none`, a tensor of shape (B,) with the RNN-T loss of
      each sequence in the batch; otherwise a scalar with the reduction
      applied.
    """
    px, py = get_rnnt_logprobs_joint(
        logits=logits,
        symbols=symbols,
        termination_symbol=termination_symbol,
        boundary=boundary,
        rnnt_type=rnnt_type,
    )

    if delay_penalty > 0.0:
        B, _, num_cols = px.shape
        # In the regular case px carries one extra (-inf) frame column.
        num_frames = num_cols - 1 if rnnt_type == "regular" else num_cols
        if boundary is not None:
            center = (boundary[:, 3] - 1) / 2
        else:
            center = torch.tensor(
                (num_frames - 1) / 2, dtype=px.dtype, device=px.device,
            ).expand(B, 1, 1)
        frame_idx = torch.arange(num_cols, device=px.device).reshape(
            1, 1, num_cols
        )
        # Frames before the middle of the utterance get a bonus and frames
        # after it get a penalty, proportional to the distance from the middle.
        penalty = (center.reshape(B, 1, 1) - frame_idx) * delay_penalty
        px += penalty.to(px.dtype)

    negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
    if reduction == "none":
        return -negated_loss
    if reduction == "mean":
        return -torch.mean(negated_loss)
    if reduction == "sum":
        return -torch.sum(negated_loss)
    raise ValueError(
        f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
    )
def _monotonic_lower_bound(x: torch.Tensor) -> torch.Tensor:
    """Return a non-decreasing lower bound of `x` along its last dimension.

    Each output element is the minimum of `x` from that position to the end
    of the row (a right-to-left running minimum):

        min_value = min(x[i], min_value)
        x[i] = min_value

    The input tensor is not modified.

    >>> import torch
    >>> x = torch.tensor([0, 2, 1, 3, 6, 5, 8], dtype=torch.int32)
    >>> _monotonic_lower_bound(x)
    tensor([0, 1, 1, 3, 5, 5, 8], dtype=torch.int32)
    >>> x
    tensor([0, 2, 1, 3, 6, 5, 8], dtype=torch.int32)
    >>> x = torch.tensor([[12, 18, 5, 4, 18, 17],
    ...                   [11, 14, 14, 3, 10, 4],
    ...                   [19, 3, 8, 13, 7, 19]], dtype=torch.int32)
    >>> _monotonic_lower_bound(x)
    tensor([[ 4,  4,  4,  4, 17, 17],
            [ 3,  3,  3,  3,  4,  4],
            [ 3,  3,  7,  7,  7, 19]], dtype=torch.int32)

    Args:
      x:
        The source tensor.
    Returns:
      A tensor of the same shape that is non-decreasing on the last
      dimension (i.e. satisfies `y[i] <= y[i+1]`) and is elementwise <= `x`.
    """
    # cummin scans left-to-right, so reverse, scan, then reverse back.
    running_min = torch.cummin(x.flip(-1), dim=-1).values
    return running_min.flip(-1)
def _adjust_pruning_lower_bound(
    s_begin: torch.Tensor, s_range: int
) -> torch.Tensor:
    """Adjust the pruning lower bounds `s_begin` so that, along the frame
    axis, they satisfy these constraints:

      - non-decreasing, i.e. s_begin[i] <= s_begin[i + 1];
      - starting with symbol 0 on the first frame;
      - s_begin[i + 1] - s_begin[i] < s_range, i.e. no symbol is skipped.

    Monotonicity is obtained with `_monotonic_lower_bound`, which walks the
    array in reverse and keeps `min_value = min(s_begin[i], min_value)`.

    The step constraint uses a change of variables.  With
    `slack[i] = i * (s_range - 1)`, define `u[i] = slack[i] - s_begin[i]`.
    If `u` is non-decreasing then:

        slack[i] - s_begin[i] <= slack[i + 1] - s_begin[i + 1]
        s_begin[i + 1] - s_begin[i] <= slack[i + 1] - slack[i]
        s_begin[i + 1] - s_begin[i] <= s_range - 1
        s_begin[i + 1] - s_begin[i] < s_range

    So we map `s_begin` to `u`, make `u` non-decreasing with
    `_monotonic_lower_bound`, and map back with `s_begin = slack - u`.
    The transformation alone does not guarantee that the first symbol is 0,
    so `u` is clamped to be >= 0 before mapping back.

    Args:
      s_begin:
        Tensor of shape (B, T) with the raw lower bounds.
      s_range:
        The step limit described above.
    Returns:
      The adjusted lower bounds, with the same shape as `s_begin`.
    """
    _, num_frames = s_begin.shape  # s_begin: (B, T)
    s_begin = _monotonic_lower_bound(s_begin)
    slack = (s_range - 1) * torch.arange(0, num_frames, device=s_begin.device)
    # u[i] = slack[i] - s_begin[i]; make it non-decreasing and non-negative.
    u = _monotonic_lower_bound(slack - s_begin)
    u = torch.clamp(u, min=0)
    # Invert the transformation to recover s_begin.
    return slack - u
# For more insight into how the pruning bounds are calculated, please read
# section 3.2 (Pruning bounds) of our Pruned RNN-T paper
# (https://arxiv.org/pdf/2206.13236.pdf)
def get_rnnt_prune_ranges(
    px_grad: torch.Tensor,
    py_grad: torch.Tensor,
    boundary: torch.Tensor,
    s_range: int,
) -> torch.Tensor:
    """Compute the pruning ranges for the RNN-T loss from the grads of px and
    py returned by mutual_information_recursion.

    For a sequence with T frames this produces a (T, s_range) tensor that
    lists, for every frame, which symbol positions are taken into
    consideration.  For example, for 10 frames with symbols `[A B C D E F]`
    and s_range equal to 3, one possible ranges tensor is::

        [[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2], [1, 2, 3],
         [1, 2, 3], [1, 2, 3], [3, 4, 5], [3, 4, 5], [3, 4, 5]]

    meaning `[A B C]` is considered at frames 0, 1, 2, 3, `[B C D]` at
    frames 4, 5, 6 and `[D E F]` at frames 7, 8, 9.

    Only a limited number of symbols is needed per frame because frames and
    symbols are aligned monotonically: a given frame can only plausibly
    emit symbols from a particular range.

    Note:
      For the generated ranges (batch size 1), ranges[:, 0] is
      non-decreasing, starts at 0 and satisfies
      `ranges[t+1, 0] - ranges[t, 0] < s_range` (`< 2` when px_grad has T
      columns, i.e. modified/constrained RNN-T), so no symbol is skipped.
      From the last real frame onward, the value before this adjustment is
      `max(0, boundary[b, 2] - s_range + 1)`, so that the range can reach
      the last symbol.

    Args:
      px_grad:
        The gradient of px, see docs in `mutual_information_recursion` for
        more details of px.
      py_grad:
        The gradient of py, see docs in `mutual_information_recursion` for
        more details of py.
      boundary:
        A LongTensor of shape [B, 4] with rows
        [begin_symbol, begin_frame, end_symbol, end_frame].
      s_range:
        How many symbols to keep for each frame.  Values above S are
        treated as S + 1.
    Returns:
      A tensor of shape (B, T, s_range) containing the indexes of the kept
      symbols for each frame.
    """
    B, S, T1 = px_grad.shape
    T = py_grad.shape[-1]
    assert T1 in [T, T + 1], T1
    S1 = S + 1
    assert py_grad.shape == (B, S + 1, T), py_grad.shape
    assert boundary.shape == (B, 4), boundary.shape

    assert S >= 1, S
    assert T >= S, (T, S)

    # s_range > S means nothing is pruned.  For indexing with the ranges to
    # work, s_range must not exceed ``S + 1``.
    if s_range > S:
        s_range = S + 1

    # px_grad has T columns for modified/constrained RNN-T and T + 1 for
    # the regular one.
    non_regular = T1 == T
    if non_regular:
        assert (
            s_range >= 1
        ), "Pruning range for modified RNN-T should be equal to or greater than 1, or no valid paths could survive pruning."
    else:
        assert (
            s_range >= 2
        ), "Pruning range for standard RNN-T should be equal to or greater than 2, or no valid paths could survive pruning."

    num_windows = S1 - s_range + 1
    b_stride, s_stride, t_stride = py_grad.stride()
    # Overlapping windows over the symbol axis without copying:
    # windows[b, k, j, t] == py_grad[b, k + j, t], achieved by giving the
    # window dim and the within-window dim the same stride.
    windows = torch.as_strided(
        py_grad,
        (B, num_windows, s_range, T),
        (b_stride, s_stride, s_stride, t_stride),
    )
    window_sum = windows.sum(dim=2)  # (B, num_windows, T)

    # Prepend a zero row so that row k lines up with px_grad[:, k - 1].
    zero_row = torch.zeros(
        (B, 1, T1), dtype=px_grad.dtype, device=px_grad.device
    )
    shifted_px_grad = torch.cat((zero_row, px_grad), dim=1)  # (B, S1, T1)
    score = window_sum - shifted_px_grad[:, :num_windows, :T]

    s_begin = torch.argmax(score, dim=1)  # (B, T)

    # From the last real frame onward (the last real frame and all padding
    # frames), use `len(symbols) - s_range + 1` so that the range reaches
    # the last symbol on the last real frame.  E.g. for lengths 3, 5, 6
    # (B = 3, T = 6) the mask is
    #   [[True, True, False, False, False, False],
    #    [True, True, True,  True,  False, False],
    #    [True, True, True,  True,  True,  False]]
    frame = torch.arange(0, T, device=px_grad.device).reshape(1, T).expand(B, T)
    before_last = frame < boundary[:, 3].reshape(B, 1) - 1
    # Clamping handles sequences with `len(symbols) < s_range`.
    tail_begin = torch.clamp(boundary[:, 2].reshape(B, 1) - s_range + 1, min=0)
    s_begin = torch.where(before_last, s_begin, tail_begin)

    # Enforce the constraints documented in `_adjust_pruning_lower_bound`.
    # Modified/constrained RNN-T emits at most one symbol per frame, so the
    # step limit becomes `s_begin[i + 1] - s_begin[i] < 2`.
    s_begin = _adjust_pruning_lower_bound(s_begin, 2 if non_regular else s_range)

    offsets = torch.arange(s_range, device=px_grad.device)
    return s_begin.reshape((B, T, 1)).expand((B, T, s_range)) + offsets
def do_rnnt_pruning(
    am: torch.Tensor, lm: torch.Tensor, ranges: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Prune the encoder output (am) and prediction-network output (lm)
    using ranges produced by `get_rnnt_prune_ranges`.

    Args:
      am:
        The encoder output, shape (B, T, C).
      lm:
        The prediction network output, shape (B, S + 1, C).
      ranges:
        Symbol indexes to keep for each frame, shape (B, T, s_range); see
        the docs of `get_rnnt_prune_ranges` for details.
    Returns:
      (am_pruned, lm_pruned), both of shape (B, T, s_range, C), where
      am_pruned[b, t, k] == am[b, t] and
      lm_pruned[b, t, k] == lm[b, ranges[b, t, k]].
    """
    assert ranges.shape[0] == am.shape[0], (ranges.shape[0], am.shape[0])
    assert ranges.shape[0] == lm.shape[0], (ranges.shape[0], lm.shape[0])
    assert am.shape[1] == ranges.shape[1], (am.shape[1], ranges.shape[1])
    B, T, s_range = ranges.shape
    B, S1, C = lm.shape

    # Every kept symbol slot on frame t sees the same encoder frame
    # (a broadcast view, no copy).
    am_pruned = am.unsqueeze(2).expand((B, T, s_range, C))

    # Select the requested context rows of lm separately for every frame.
    lm_per_frame = lm.unsqueeze(1).expand((B, T, S1, C))
    gather_index = ranges.reshape((B, T, s_range, 1)).expand((B, T, s_range, C))
    lm_pruned = torch.gather(lm_per_frame, dim=2, index=gather_index)
    return am_pruned, lm_pruned
def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor):
    """Roll each row of a 3-D tensor along its last dimension by its own
    shift amount.

    Row (b, t) is rolled right by shifts[b, t] positions, wrapping around,
    like `torch.roll` applied with a different shift per row.

    Example:
      >>> src = torch.arange(15).reshape((1,3,5))
      >>> src
      tensor([[[ 0,  1,  2,  3,  4],
               [ 5,  6,  7,  8,  9],
               [10, 11, 12, 13, 14]]])
      >>> shift = torch.tensor([[1, 2, 3]])
      >>> shift
      tensor([[1, 2, 3]])
      >>> _roll_by_shifts(src, shift)
      tensor([[[ 4,  0,  1,  2,  3],
               [ 8,  9,  5,  6,  7],
               [12, 13, 14, 10, 11]]])

    Args:
      src:
        A 3-D tensor of shape (B, T, S).
      shifts:
        A LongTensor of shape (B, T) with the shift for each row.
    Returns:
      The rolled tensor, same shape as `src`.
    """
    assert src.dim() == 3, src.dim()
    B, T, S = src.shape
    assert shifts.shape == (B, T), shifts.shape
    # Output position j of row (b, t) reads source position
    # (j - shifts[b, t]) mod S.
    columns = torch.arange(S, device=src.device).expand(B, T, S)
    source_pos = (columns - shifts.unsqueeze(-1)) % S
    return torch.gather(src, 2, source_pos)
def get_rnnt_logprobs_pruned(
    logits: Tensor,
    symbols: Tensor,
    ranges: Tensor,
    termination_symbol: int,
    boundary: Optional[Tensor] = None,
    rnnt_type: str = "regular",
) -> Tuple[Tensor, Tensor]:
    """Construct px, py for mutual_information_recursion with pruned output.

    Args:
      logits:
        The pruned output of joiner network, with shape (B, T, s_range, C)
      symbols:
        The symbol sequences, a LongTensor of shape [B][S], and elements in
        {0..C-1}.
      ranges:
        A tensor containing the symbol ids for each frame that we want to keep.
        It is a LongTensor of shape ``[B][T][s_range]``, where ``ranges[b,t,0]``
        contains the begin symbol ``0 <= s <= S - s_range + 1``, such that
        ``logits[b,t,:,:]`` represents the logits with positions
        ``s, s + 1, ... s + s_range - 1``. ``s_range`` must not exceed
        ``S + 1``.
        See docs in :func:`get_rnnt_prune_ranges` for more details of what
        ranges contains.
      termination_symbol:
        the termination symbol, with 0 <= termination_symbol < C
      boundary:
        an optional LongTensor of shape [B, 4] with elements interpreted as
        [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
        [0, 0, S, T] if boundary is not supplied (None).
        Most likely you will want begin_symbol and begin_frame to be zero.
        It is only used when rnnt_type is "regular".
      rnnt_type:
        Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
        `regular`: The regular rnnt that takes you to the next frame only if
                   emitting a blank (i.e., emitting a symbol does not take you
                   to the next frame). S may exceed T.
        `modified`: A modified version of rnnt that takes you to the next
                    frame when emitting either a blank or a non-blank symbol.
                    Requires T >= S.
        `constrained`: A version like the modified one that goes to the next
                       frame when you emit a non-blank symbol, but this is done
                       by "forcing" you to take the blank transition from the
                       *next* context on the *current* frame, e.g. if we emit
                       c given "a b" context, we are forced to emit "blank"
                       given "b c" context on the current frame.
                       Requires T >= S.
    Returns:
      (px, py) (the names are quite arbitrary)::

          px: logprobs, of shape [B][S][T+1] if rnnt_type is regular,
                                 [B][S][T] if rnnt_type is not regular.
          py: logprobs, of shape [B][S+1][T]

      in the recursion::

          p[b,0,0] = 0.0
          if rnnt_type == "regular":
             p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                                p[b,s,t-1] + py[b,s,t-1])
          if rnnt_type != "regular":
             p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
                                p[b,s,t-1] + py[b,s,t-1])

      .. where p[b][s][t] is the "joint score" of the pair of subsequences of
      length s and t respectively.  px[b][s][t] represents the probability of
      extending the subsequences of length (s,t) by one in the s direction,
      given the particular symbol, and py[b][s][t] represents the probability
      of extending the subsequences of length (s,t) by one in the t direction,
      i.e. of emitting the termination/next-frame symbol.

      if `rnnt_type == "regular"`, px[:,:,T] equals -infinity, meaning on the
      "one-past-the-last" frame we cannot emit any symbols.
      This is simply a way of incorporating
      the probability of the termination symbol on the last frame.
    """
    # logits (B, T, s_range, C)
    # symbols (B, S)
    # ranges (B, T, s_range)
    assert logits.ndim == 4, logits.ndim
    (B, T, s_range, C) = logits.shape
    assert ranges.shape == (B, T, s_range), ranges.shape
    assert symbols.ndim == 2 and symbols.shape[0] == B, (symbols.shape, B)
    S = symbols.shape[1]
    assert S >= 1, S
    assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
    # In the "modified" and "constrained" variants every emitted symbol also
    # consumes a frame, so at least S frames are needed. The "regular" variant
    # can emit several symbols on one frame, so S > T is valid there.
    assert rnnt_type == "regular" or T >= S, (
        f"rnnt_type={rnnt_type} requires T >= S, but got T={T} and S={S}"
    )
    # The -inf padding below widens dim 2 from s_range to S + 1.
    assert s_range <= S + 1, (s_range, S)
    # A negative id would silently index logits from the end.
    assert 0 <= termination_symbol < C, (termination_symbol, C)

    # (B, T, s_range)
    normalizers = torch.logsumexp(logits, dim=3)

    # (B, S + 1): append the termination symbol so that position S (one past
    # the last real symbol) also has a valid id to gather.
    symbols_with_terminal = torch.cat(
        (
            symbols,
            torch.full(
                (B, 1),
                termination_symbol,
                dtype=torch.int64,
                device=symbols.device,
            ),
        ),
        dim=1,
    )

    # (B, T, s_range): the symbol id at each kept position.
    pruned_symbols = torch.gather(
        symbols_with_terminal.unsqueeze(1).expand((B, T, S + 1)),
        dim=2,
        index=ranges,
    )

    # (B, T, s_range): normalized log-prob of emitting each kept symbol.
    px = torch.gather(
        logits, dim=3, index=pruned_symbols.reshape(B, T, s_range, 1)
    ).squeeze(-1)
    px = px - normalizers

    # (B, T, S + 1): pad with -inf, then roll each (b, t) row by
    # ranges[b, t, 0] so that entry s holds symbol position s; positions
    # outside the kept window end up as -inf.
    px = torch.cat(
        (
            px,
            torch.full(
                (B, T, S + 1 - s_range),
                float("-inf"),
                device=px.device,
                dtype=px.dtype,
            ),
        ),
        dim=2,
    )
    # (B, T, S): position S (termination) has no px entry.
    px = _roll_by_shifts(px, ranges[:, :, 0])[:, :, :S]

    # (B, S, T)
    px = px.permute((0, 2, 1))

    if rnnt_type == "regular":
        px = torch.cat(
            (
                px,
                torch.full(
                    (B, S, 1), float("-inf"), device=px.device, dtype=px.dtype
                ),
            ),
            dim=2,
        )  # now: [B][S][T+1], index [:,:,T] has -inf..

    # (B, T, s_range): normalized log-prob of the termination symbol.
    py = logits[:, :, :, termination_symbol].clone()
    py = py - normalizers

    # (B, T, S + 1): same pad-with--inf and roll as for px above.
    py = torch.cat(
        (
            py,
            torch.full(
                (B, T, S + 1 - s_range),
                float("-inf"),
                device=py.device,
                dtype=py.dtype,
            ),
        ),
        dim=2,
    )
    py = _roll_by_shifts(py, ranges[:, :, 0])

    # (B, S + 1, T)
    py = py.permute((0, 2, 1))

    if rnnt_type == "regular":
        px = fix_for_boundary(px, boundary)
    elif rnnt_type == "constrained":
        # Emitting symbol s forces the blank from the next context (s + 1) on
        # the current frame, so fold that blank's log-prob into px.
        px = px + py[:, 1:, :]

    return (px, py)
def rnnt_loss_pruned(
logits: Tensor,
symbols: Tensor,
ranges: Tensor,
termination_symbol: int,
boundary: Tensor = None,
rnnt_type: str = "regular",
delay_penalty: float = 0.0,
reduction: Optional[str] = "mean",
) -> Tensor:
"""A RNN-T loss with pruning, which uses the output of a pruned 'joiner'
network as input, i.e. a 4 dimensions tensor with shape (B, T, s_range, C),
s_range means the number of symbols kept for each frame.
Args:
logits:
The pruned output of joiner network, with shape (B, T, s_range, C),
i.e. batch, time_seq_len, prune_range, num_classes
symbols: