# rnn.py
import tensorflow as tf
import tree
from keras_core.utils.nest import pack_sequence_as
def rnn(
    step_function,
    inputs,
    initial_states,
    go_backwards=False,
    mask=None,
    constants=None,
    unroll=False,
    input_length=None,
    time_major=False,
    zero_output_for_mask=False,
    return_all_outputs=True,
):
    """Iterates over the time dimension of a tensor.

    Args:
        step_function: RNN step function.
            Args:
                `input`: Tensor with shape `(samples, ...)` (no time
                    dimension), representing input for the batch of samples
                    at a certain time step.
                `states`: List of tensors.
            Returns:
                `output`: Tensor with shape `(samples, output_dim)`
                    (no time dimension).
                `new_states`: List of tensors, same length and shapes
                    as `states`. The first state in the list must be the
                    output tensor at the previous timestep.
        inputs: Tensor of temporal data of shape `(samples, time, ...)`
            (at least 3D), or nested tensors, and each of which has shape
            `(samples, time, ...)`.
        initial_states: Tensor with shape `(samples, state_size)`
            (no time dimension), containing the initial values for the states
            used in the step function. In the case that state_size is in a
            nested shape, the shape of initial_states will also follow the
            nested structure.
        go_backwards: Boolean. If `True`, do the iteration over the time
            dimension in reverse order and return the reversed sequence.
        mask: Binary tensor with shape `(samples, time)` or
            `(samples, time, 1)`, with a zero for every element that is
            masked. Non-boolean masks are cast to `bool`.
        constants: List of constant values passed at each step.
        unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.
        input_length: An integer or a 1-D Tensor, depending on whether
            the time dimension is fixed-length or not. In case of variable
            length input, it is used for masking in case there's no mask
            specified. When `None`, the static size of the time axis of the
            (first flattened) input is used. When `unroll=False` it also
            bounds the number of `while_loop` iterations.
        time_major: Boolean. If `True`, the inputs and outputs will be in shape
            `(timesteps, batch, ...)`, whereas in the False case, it will be
            `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
            efficient because it avoids transposes at the beginning and end of
            the RNN calculation. However, most TensorFlow data is batch-major,
            so by default this function accepts input and emits output in
            batch-major form.
        zero_output_for_mask: Boolean. If `True`, the output for masked
            timestep will be zeros, whereas in the `False` case, output from
            previous timestep is returned.
        return_all_outputs: Boolean. If `True`, return the recurrent outputs
            for all timesteps in the sequence. If `False`, only return the
            output for the last timestep (which consumes less memory).

    Returns:
        A tuple, `(last_output, outputs, new_states)`.
            - `last_output`: the latest output of the rnn,
                with shape `(samples, ...)`.
            - `outputs`:
                - If `return_all_outputs=True`: a tensor with shape
                  `(samples, time, ...)` where each entry `outputs[s, t]` is
                  the output of the step function at time `t` for sample `s`
                - Else, a tensor equal to `last_output` with shape
                  `(samples, 1, ...)`
            - `new_states`: list of tensors, latest states returned by
                the step function, of shape `(samples, ...)`.

    Raises:
        ValueError: if `unroll=True` but the number of timesteps is not
            statically known, or if a nested output/state is combined with
            unrolled masking.
    """
    if input_length is None:
        # Explicit `None` check: `input_length` may be a 1-D tensor whose
        # truth value is ambiguous (and illegal in graph mode). Inputs have
        # not been transposed yet, so the time axis depends on `time_major`.
        # Flatten first so nested inputs are supported as documented.
        time_axis = 0 if time_major else 1
        input_length = tree.flatten(inputs)[0].shape[time_axis]

    def swap_batch_timestep(input_t):
        # Swap the batch and timestep dim for the incoming tensor.
        axes = list(range(len(input_t.shape)))
        axes[0], axes[1] = 1, 0
        return tf.transpose(input_t, axes)

    # From here on every input (and the mask) is time-major.
    if not time_major:
        inputs = tree.map_structure(swap_batch_timestep, inputs)

    flattened_inputs = tree.flatten(inputs)
    time_steps = flattened_inputs[0].shape[0]
    time_steps_t = tf.shape(flattened_inputs[0])[0]

    for input_ in flattened_inputs:
        input_.shape.with_rank_at_least(3)

    if mask is not None:
        if mask.dtype != tf.bool:
            mask = tf.cast(mask, tf.bool)
        if len(mask.shape) == 2:
            mask = tf.expand_dims(mask, axis=-1)
        if not time_major:
            mask = swap_batch_timestep(mask)

    if constants is None:
        constants = []

    # tf.where needs its condition tensor to be the same shape as its two
    # result tensors, but in our case the condition (mask) tensor is
    # (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
    # So we need to broadcast the mask to match the shape of inputs.
    # That's what the tile call does, it just repeats the mask along its
    # second dimension n times.
    def _expand_mask(mask_t, input_t, fixed_dim=1):
        if tree.is_nested(mask_t):
            raise ValueError(
                f"mask_t is expected to be tensor, but got {mask_t}"
            )
        if tree.is_nested(input_t):
            raise ValueError(
                f"input_t is expected to be tensor, but got {input_t}"
            )
        rank_diff = len(input_t.shape) - len(mask_t.shape)
        for _ in range(rank_diff):
            mask_t = tf.expand_dims(mask_t, -1)
        multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:]
        return tf.tile(mask_t, multiples)

    if unroll:
        if not time_steps:
            raise ValueError("Unrolling requires a fixed number of timesteps.")
        states = tuple(initial_states)
        successive_states = []
        successive_outputs = []

        # Split every input tensor on the time dim (reversed when
        # go_backwards is True).
        def _process_single_input_t(input_t):
            input_t = tf.unstack(input_t)  # unstack for time_step dim
            if go_backwards:
                input_t.reverse()
            return input_t

        # One list of per-timestep slices per flattened input, kept in
        # `tree.flatten` order so `pack_sequence_as` can rebuild the original
        # structure (this also works for dict-structured inputs).
        processed_input = [
            _process_single_input_t(input_) for input_ in flattened_inputs
        ]

        def _get_input_tensor(time):
            inp = [t_[time] for t_ in processed_input]
            return pack_sequence_as(inputs, inp)

        if mask is not None:
            mask_list = tf.unstack(mask)
            if go_backwards:
                mask_list.reverse()

            for i in range(time_steps):
                inp = _get_input_tensor(i)
                mask_t = mask_list[i]
                output, new_states = step_function(
                    inp, tuple(states) + tuple(constants)
                )
                tiled_mask_t = _expand_mask(mask_t, output)

                # Masked steps carry over the previous output (zeros at t=0).
                if not successive_outputs:
                    prev_output = tf.zeros_like(output)
                else:
                    prev_output = successive_outputs[-1]

                output = tf.where(tiled_mask_t, output, prev_output)

                # Masked steps also keep the previous states.
                flat_states = tree.flatten(states)
                flat_new_states = tree.flatten(new_states)
                tiled_mask_t = tuple(
                    _expand_mask(mask_t, s) for s in flat_states
                )
                flat_final_states = tuple(
                    tf.where(m, s, ps)
                    for m, s, ps in zip(
                        tiled_mask_t, flat_new_states, flat_states
                    )
                )
                states = pack_sequence_as(states, flat_final_states)

                if return_all_outputs:
                    successive_outputs.append(output)
                    successive_states.append(states)
                else:
                    successive_outputs = [output]
                    successive_states = [states]
            last_output = successive_outputs[-1]
            new_states = successive_states[-1]
            outputs = tf.stack(successive_outputs)

            if zero_output_for_mask:
                last_output = tf.where(
                    _expand_mask(mask_list[-1], last_output),
                    last_output,
                    tf.zeros_like(last_output),
                )
                # The mask must line up with `outputs`: `mask_list` is in
                # iteration order (already reversed for go_backwards), and
                # `outputs` holds only the final step when
                # `return_all_outputs` is False.
                if return_all_outputs:
                    outputs_mask = tf.stack(mask_list)
                else:
                    outputs_mask = tf.expand_dims(mask_list[-1], 0)
                outputs = tf.where(
                    _expand_mask(outputs_mask, outputs, fixed_dim=2),
                    outputs,
                    tf.zeros_like(outputs),
                )

        else:  # mask is None
            for i in range(time_steps):
                inp = _get_input_tensor(i)
                output, states = step_function(
                    inp, tuple(states) + tuple(constants)
                )
                if return_all_outputs:
                    successive_outputs.append(output)
                    successive_states.append(states)
                else:
                    successive_outputs = [output]
                    successive_states = [states]
            last_output = successive_outputs[-1]
            new_states = successive_states[-1]
            outputs = tf.stack(successive_outputs)

    else:  # Unroll == False
        states = tuple(initial_states)

        # Create input tensor array, if the inputs is nested tensors, then it
        # will be flattened first, and tensor array will be created one per
        # flattened tensor.
        input_ta = tuple(
            tf.TensorArray(
                dtype=inp.dtype,
                size=time_steps_t,
                tensor_array_name=f"input_ta_{i}",
            )
            for i, inp in enumerate(flattened_inputs)
        )
        input_ta = tuple(
            ta.unstack(input_)
            if not go_backwards
            else ta.unstack(tf.reverse(input_, [0]))
            for ta, input_ in zip(input_ta, flattened_inputs)
        )

        # Get the time(0) input and compute the output for that, the output
        # will be used to determine the dtype of output tensor array. Don't
        # read from input_ta due to TensorArray clear_after_read default to
        # True.
        input_time_zero = pack_sequence_as(
            inputs, [inp[0] for inp in flattened_inputs]
        )
        # output_time_zero is used to determine the cell output shape and its
        # dtype. the value is discarded.
        output_time_zero, _ = step_function(
            input_time_zero, tuple(initial_states) + tuple(constants)
        )

        output_ta_size = time_steps_t if return_all_outputs else 1
        output_ta = tuple(
            tf.TensorArray(
                dtype=out.dtype,
                size=output_ta_size,
                element_shape=out.shape,
                tensor_array_name=f"output_ta_{i}",
            )
            for i, out in enumerate(tree.flatten(output_time_zero))
        )

        time = tf.constant(0, dtype="int32", name="time")

        if input_length is None:
            max_iterations = time_steps_t
        else:
            max_iterations = tf.reduce_max(input_length)

        while_loop_kwargs = {
            "cond": lambda time, *_: time < time_steps_t,
            "maximum_iterations": max_iterations,
            "parallel_iterations": 32,
            "swap_memory": True,
        }
        if mask is not None:
            if go_backwards:
                mask = tf.reverse(mask, [0])

            mask_ta = tf.TensorArray(
                dtype=tf.bool, size=time_steps_t, tensor_array_name="mask_ta"
            )
            mask_ta = mask_ta.unstack(mask)

            def masking_fn(time):
                return mask_ta.read(time)

            def compute_masked_output(mask_t, flat_out, flat_mask):
                tiled_mask_t = tuple(
                    _expand_mask(mask_t, o, fixed_dim=len(mask_t.shape))
                    for o in flat_out
                )
                return tuple(
                    tf.where(m, o, fm)
                    for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask)
                )

        elif isinstance(input_length, tf.Tensor):
            # Variable-length input without an explicit mask: derive the
            # per-step mask from the per-sample lengths.
            if go_backwards:
                max_len = tf.reduce_max(input_length, axis=0)
                rev_input_length = tf.subtract(max_len - 1, input_length)

                def masking_fn(time):
                    return tf.less(rev_input_length, time)

            else:

                def masking_fn(time):
                    return tf.greater(input_length, time)

            def compute_masked_output(mask_t, flat_out, flat_mask):
                return tuple(
                    tf.where(mask_t, o, zo)
                    for (o, zo) in zip(flat_out, flat_mask)
                )

        else:
            masking_fn = None

        if masking_fn is not None:
            # Mask for the T output will be base on the output of T - 1. In
            # the case T = 0, a zero filled tensor will be used.
            flat_zero_output = tuple(
                tf.zeros_like(o) for o in tree.flatten(output_time_zero)
            )

            def _step(time, output_ta_t, prev_output, *states):
                """RNN step function.

                Args:
                    time: Current timestep value.
                    output_ta_t: TensorArray.
                    prev_output: tuple of outputs from time - 1.
                    *states: List of states.

                Returns:
                    Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)`
                """
                current_input = tuple(ta.read(time) for ta in input_ta)
                # maybe set shape.
                current_input = pack_sequence_as(inputs, current_input)
                mask_t = masking_fn(time)
                output, new_states = step_function(
                    current_input, tuple(states) + tuple(constants)
                )
                # mask output
                flat_output = tree.flatten(output)
                flat_mask_output = (
                    flat_zero_output
                    if zero_output_for_mask
                    else tree.flatten(prev_output)
                )
                flat_new_output = compute_masked_output(
                    mask_t, flat_output, flat_mask_output
                )

                # mask states
                flat_state = tree.flatten(states)
                flat_new_state = tree.flatten(new_states)
                flat_final_state = compute_masked_output(
                    mask_t, flat_new_state, flat_state
                )
                new_states = pack_sequence_as(new_states, flat_final_state)

                ta_index_to_write = time if return_all_outputs else 0
                output_ta_t = tuple(
                    ta.write(ta_index_to_write, out)
                    for ta, out in zip(output_ta_t, flat_new_output)
                )

                return (time + 1, output_ta_t, tuple(flat_new_output)) + tuple(
                    new_states
                )

            final_outputs = tf.while_loop(
                body=_step,
                loop_vars=(time, output_ta, flat_zero_output) + states,
                **while_loop_kwargs,
            )
            # Skip final_outputs[2] which is the output for final timestep.
            new_states = final_outputs[3:]
        else:

            def _step(time, output_ta_t, *states):
                """RNN step function.

                Args:
                    time: Current timestep value.
                    output_ta_t: TensorArray.
                    *states: List of states.

                Returns:
                    Tuple: `(time + 1,output_ta_t) + tuple(new_states)`
                """
                current_input = tuple(ta.read(time) for ta in input_ta)
                current_input = pack_sequence_as(inputs, current_input)
                output, new_states = step_function(
                    current_input, tuple(states) + tuple(constants)
                )
                flat_new_state = tree.flatten(new_states)

                flat_output = tree.flatten(output)
                ta_index_to_write = time if return_all_outputs else 0
                output_ta_t = tuple(
                    ta.write(ta_index_to_write, out)
                    for ta, out in zip(output_ta_t, flat_output)
                )

                new_states = pack_sequence_as(initial_states, flat_new_state)
                return (time + 1, output_ta_t) + tuple(new_states)

            final_outputs = tf.while_loop(
                body=_step,
                loop_vars=(time, output_ta) + states,
                **while_loop_kwargs,
            )
            new_states = final_outputs[2:]

        output_ta = final_outputs[1]

        outputs = tuple(o.stack() for o in output_ta)
        last_output = tuple(o[-1] for o in outputs)

        outputs = pack_sequence_as(output_time_zero, outputs)
        last_output = pack_sequence_as(output_time_zero, last_output)

    # Restore the caller's batch-major layout.
    if not time_major:
        outputs = tree.map_structure(swap_batch_timestep, outputs)

    return last_output, outputs, new_states
def gru(
    inputs,
    initial_state,
    mask,
    kernel,
    recurrent_kernel,
    bias,
    activation,
    recurrent_activation,
    return_sequences=False,
    go_backwards=False,
    unroll=False,
    time_major=False,
    reset_after=True,
):
    """Run a GRU over `inputs` with the cuDNN kernel.

    Args:
        inputs: Input sequence tensor.
        initial_state: Initial hidden state.
        mask: Optional mask. It must pass `_do_rnn_inputs_support_cudnn`.
        kernel: Input kernel (a backend `Variable` is unwrapped to its value).
        recurrent_kernel: Recurrent kernel (unwrapped likewise).
        bias: Bias (unwrapped likewise); `None` disables the cuDNN path.
        activation: Must be a tanh variant for the cuDNN path.
        recurrent_activation: Must be a sigmoid variant for the cuDNN path.
        return_sequences: Whether to return the whole output sequence.
        go_backwards: Whether to process the sequence in reverse.
        unroll: Must be False for the cuDNN path.
        time_major: Whether inputs are time-major.
        reset_after: Must be True for the cuDNN path.

    Returns:
        The `(last_output, outputs, state)` tuple produced by `_cudnn_gru`.

    Raises:
        NotImplementedError: if the arguments or the mask are not supported
            by cuDNN, no GPU is visible, or the cuDNN op raises
            `InvalidArgumentError`/`NotFoundError`. Presumably the caller
            catches this and falls back to a generic implementation -- confirm.
    """
    # Check the arguments and GPU availability first. This is plain Python,
    # and failing here avoids building the mask-analysis ops below (whose
    # truth value cannot be taken in graph mode) when cuDNN is unusable.
    cudnn_supported = cudnn_ok(
        activation,
        recurrent_activation,
        unroll,
        use_bias=bias is not None,
        reset_after=reset_after,
    )
    if not cudnn_supported:
        raise NotImplementedError
    # `_do_rnn_inputs_support_cudnn` may return a scalar boolean tensor;
    # taking its truth value requires eager execution.
    if not _do_rnn_inputs_support_cudnn(mask, time_major):
        raise NotImplementedError

    from keras_core.backend.tensorflow import Variable

    if isinstance(kernel, Variable):
        kernel = kernel.value
    if isinstance(recurrent_kernel, Variable):
        recurrent_kernel = recurrent_kernel.value
    if isinstance(bias, Variable):
        bias = bias.value

    try:
        return _cudnn_gru(
            inputs,
            initial_state,
            kernel,
            recurrent_kernel,
            bias,
            mask,
            time_major,
            go_backwards,
            return_sequences,
        )
    except (tf.errors.InvalidArgumentError, tf.errors.NotFoundError) as e:
        # InvalidArgumentError: cuDNN op not found.
        # NotFoundError: no device found for the op.
        raise NotImplementedError from e
def _do_gru_arguments_support_cudnn(
    activation,
    recurrent_activation,
    unroll,
    use_bias,
    reset_after,
):
    """Tell whether the GRU hyper-parameters allow the cuDNN kernel.

    All of the following must hold:
    - `activation` is one of the tanh variants (Keras, TF or ops).
    - `recurrent_activation` is one of the sigmoid variants.
    - The loop is not unrolled.
    - A bias is used.
    - `reset_after` is truthy.

    Returns:
        The result of the short-circuiting `and` chain. It is falsy as soon
        as one condition fails; otherwise it is the value of `reset_after`.
    """
    from keras_core import activations
    from keras_core import ops

    tanh_variants = (activations.tanh, tf.tanh, ops.tanh)
    sigmoid_variants = (activations.sigmoid, tf.sigmoid, ops.sigmoid)
    return (
        activation in tanh_variants
        and recurrent_activation in sigmoid_variants
        and not unroll
        and use_bias
        and reset_after
    )
def _do_lstm_arguments_support_cudnn(
    activation,
    recurrent_activation,
    unroll,
    use_bias,
):
    """Tell whether the LSTM hyper-parameters allow the cuDNN kernel.

    All of the following must hold:
    - `activation` is one of the tanh variants (Keras, TF or ops).
    - `recurrent_activation` is one of the sigmoid variants.
    - The loop is not unrolled.
    - A bias is used.

    Returns:
        The result of the short-circuiting `and` chain. It is falsy as soon
        as one condition fails; otherwise it is the value of `use_bias`.
    """
    from keras_core import activations
    from keras_core import ops

    tanh_variants = (activations.tanh, tf.tanh, ops.tanh)
    sigmoid_variants = (activations.sigmoid, tf.sigmoid, ops.sigmoid)
    return (
        activation in tanh_variants
        and recurrent_activation in sigmoid_variants
        and not unroll
        and use_bias
    )
def _do_rnn_inputs_support_cudnn(mask, time_major):
    """Tell whether `mask` can be handled by the cuDNN RNN kernels.

    Args:
        mask: Optional mask, `[batch, timestep]`, or `[timestep, batch]` when
            `time_major` is True.
        time_major: Whether `mask` is time-major.

    Returns:
        Python `True` when there is no mask. Otherwise a scalar boolean
        tensor. On ROCm builds it is true only if nothing is masked at all.
        On other builds it is true only if the mask is strictly right padded
        and no sequence is fully masked.
    """
    if tf.sysconfig.get_build_info()["is_rocm_build"]:
        return True if mask is None else tf.reduce_all(mask)
    if mask is None:
        return True
    # The helpers below expect a batch-major `[batch, timestep]` mask.
    batch_major_mask = tf.transpose(mask) if time_major else mask
    right_padded = _is_sequence_right_padded(batch_major_mask)
    fully_masked = _has_fully_masked_sequence(batch_major_mask)
    return tf.logical_and(right_padded, tf.logical_not(fully_masked))
def _is_sequence_right_padded(mask):
    """Check whether the mask is strictly right padded.

    The cuDNN kernel uses the sequence-length parameter to skip the trailing
    timesteps. It therefore cannot reproduce the masked computation when
    data is left padded, or when masked steps appear in the middle of a
    sequence.

    Left padded data: [[False, False, True, True, True]].
    Right padded data: [[True, True, True, False, False]].
    Mixture of masked/unmasked data: [[True, False, True, False, False]].

    In the mixed example the RNN should only see the two True steps (index 0
    and 2). The False at index 1 must be skipped without affecting the
    internal states, which a sequence length alone cannot express.

    Args:
        mask: the Boolean tensor with shape [batch, timestep]

    Returns:
        boolean scalar tensor, whether the mask is strictly right padded.
    """
    num_steps = tf.shape(mask)[1]
    # A right-padded row is exactly its count of True values, packed left.
    valid_counts = tf.reduce_sum(tf.cast(mask, tf.int32), axis=1)
    packed_left = tf.sequence_mask(valid_counts, maxlen=num_steps)
    same_as_packed = tf.equal(
        tf.cast(mask, dtype="bool"),
        tf.cast(packed_left, dtype="bool"),
    )
    return tf.reduce_all(same_as_packed)
def _has_fully_masked_sequence(mask):
    """Return a scalar boolean tensor: does any row of `mask` hold no True?

    The cuDNN kernel errors out on fully masked sequences, so such inputs are
    routed to the standard kernel instead.

    Args:
        mask: Mask of shape `[batch, timestep]`; cast to bool.
    """
    bool_mask = tf.cast(mask, dtype="bool")
    # A sequence is fully masked exactly when none of its steps is True.
    has_valid_step = tf.reduce_any(bool_mask, axis=1)
    return tf.reduce_any(tf.logical_not(has_valid_step))
def _standardize_cudnn_weights(weights, biases, shape, transpose_weights=False):
    """Pack Keras kernels and biases into a single cuDNN parameter tensor.

    Keras kernels are laid out differently from the cuDNN format, e.g.:

    ```
          Keras                 cuDNN
        [[0, 1, 2],  <--->  [[0, 2, 4],
         [3, 4, 5]]          [1, 3, 5]]
    ```

    Set `transpose_weights=True` to transpose each kernel before flattening.

    Args:
        weights: list of weights for the kernels and recurrent kernels.
        biases: list of biases for individual gates.
        shape: the shape every weight and bias is reshaped to before
            concatenation.
        transpose_weights: boolean, whether to transpose the weights.

    Returns:
        The weights followed by the biases, reshaped to `shape` and
        concatenated along axis 0, ready to be fed to cuDNN ops as `params`.
    """
    flat_weights = []
    for weight in weights:
        if transpose_weights:
            weight = tf.transpose(weight)
        flat_weights.append(tf.reshape(weight, shape))
    flat_biases = [tf.reshape(b, shape) for b in biases]
    return tf.concat(flat_weights + flat_biases, axis=0)
def _compute_sequence_length_from_mask(mask, time_major):
    """Count the unmasked (True) steps of every sequence in `mask`.

    Example:
        a = [[True, True, False, False],
             [True, True, True, False]]
    is a (2, 4) batch-major mask, and the result is the 1-D tensor [2, 3].
    The mask must be right padded for the counts to be sequence lengths;
    `_is_sequence_right_padded()` can check this.

    Args:
        mask: Boolean tensor with shape [batch, timestep], or
            [timestep, batch] if time_major=True.
        time_major: Boolean, whether the mask is time major or batch major.

    Returns:
        sequence_length: 1D int32 tensor.
    """
    step_counts = tf.cast(mask, tf.int32)
    return tf.reduce_sum(step_counts, axis=0 if time_major else 1)
def _is_gpu_available():
    """Return True when TensorFlow lists at least one logical GPU device."""
    gpu_devices = tf.config.list_logical_devices("GPU")
    return len(gpu_devices) > 0
@tf.function(autograph=False)
def _cudnn_gru(
    inputs,
    initial_state,
    kernel,
    recurrent_kernel,
    bias,
    mask,
    time_major,
    go_backwards,
    return_sequences,
):
    """GRU with cuDNN implementation which is only available for GPU.

    Args:
        inputs: Input sequence. It is transposed with `perm=(1, 0, 2)` below,
            so presumably rank 3: `(batch, time, features)`, or
            `(time, batch, features)` when `time_major` -- confirm.
        initial_state: Initial hidden state, presumably `(batch, units)`.
        kernel: Input kernel; split into 3 gate blocks along axis 1.
        recurrent_kernel: Recurrent kernel; split into 3 gate blocks along
            axis 1.
        bias: Flattened and split into 6 pieces (input and recurrent bias for
            each of the 3 gates).
        mask: Optional mask from which per-sample sequence lengths are
            counted. It must be right padded for those counts to be valid.
        time_major: Whether `inputs` (and `mask`) are time-major.
        go_backwards: Whether to process the sequence in reverse.
        return_sequences: Whether `outputs` holds every timestep.

    Returns:
        Tuple `(last_output, outputs, state)`. When `return_sequences` is
        False, `outputs` is `last_output` with a length-1 time axis inserted.
    """
    if mask is not None:
        sequence_lengths = _compute_sequence_length_from_mask(mask, time_major)
    else:
        sequence_lengths = None

    # Without sequence lengths the plain `CudnnRNN` op below is used, which
    # takes time-major input, so batch-major inputs are transposed here.
    # `CudnnRNNV3` takes a `time_major` flag instead.
    if not time_major and sequence_lengths is None:
        inputs = tf.transpose(inputs, perm=(1, 0, 2))
        seq_axis, batch_axis = (0, 1)
    else:
        seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
    # For init_h, cuDNN expects one more dim of num_layers before or after batch
    # dim for time major or batch major inputs respectively
    init_h = tf.expand_dims(initial_state, axis=seq_axis)

    weights = tf.split(kernel, 3, axis=1)
    weights += tf.split(recurrent_kernel, 3, axis=1)
    # Note that the bias was initialized as shape (2, 3 * units), flatten it to
    # (6 * units)
    bias = tf.split(tf.reshape(bias, [-1]), 6)

    if tf.sysconfig.get_build_info()["is_cuda_build"]:
        # Note that the gate order for cuDNN is different from the canonical
        # format. canonical format is [z, r, h], whereas cuDNN is [r, z, h].
        # The swap need to be done for kernel, recurrent_kernel, input_bias,
        # recurrent_bias.
        # z is update gate weights.
        # r is reset gate weights.
        # h is output gate weights.
        weights[0], weights[1] = weights[1], weights[0]
        weights[3], weights[4] = weights[4], weights[3]
        bias[0], bias[1] = bias[1], bias[0]
        bias[3], bias[4] = bias[4], bias[3]

    params = _standardize_cudnn_weights(
        weights=weights,
        biases=bias,
        shape=tf.constant([-1]),
        transpose_weights=True,
    )

    if sequence_lengths is not None:
        if go_backwards:
            # Three reversals are required. E.g.,
            # normal input = [1, 2, 3, 0, 0]  # where 0 need to be masked
            # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
            # output_from_cudnn = [6, 5, 4, 0, 0]
            # expected_output = [0, 0, 6, 5 ,4]
            inputs = tf.reverse_sequence(
                inputs,
                sequence_lengths,
                seq_axis=seq_axis,
                batch_axis=batch_axis,
            )
        # `input_c` is unused for GRU, so a dummy 0 is passed.
        outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3(
            input=inputs,
            input_h=init_h,
            input_c=0,
            params=params,
            is_training=True,
            rnn_mode="gru",
            sequence_lengths=sequence_lengths,
            time_major=time_major,
        )
        if go_backwards:
            outputs = tf.reverse_sequence(
                outputs,
                sequence_lengths,
                seq_axis=seq_axis,
                batch_axis=batch_axis,
            )
            outputs = tf.reverse(outputs, axis=[seq_axis])
    else:
        if go_backwards:
            # Reverse axis 0 since the input is already convert to time major.
            inputs = tf.reverse(inputs, axis=[0])
        outputs, h, _, _ = tf.raw_ops.CudnnRNN(
            input=inputs,
            input_h=init_h,
            input_c=0,
            params=params,
            is_training=True,
            rnn_mode="gru",
        )

    # Last timestep when `outputs` is time-major. In the batch-major
    # sequence-length case this indexes the batch axis instead, but it is
    # overwritten with `state` below.
    last_output = outputs[-1]
    if not time_major and sequence_lengths is None and return_sequences:
        # Undo the transpose applied to the inputs above.
        outputs = tf.transpose(outputs, perm=[1, 0, 2])
    # Drop the num_layers dim added to init_h.
    state = tf.squeeze(h, axis=seq_axis)

    # In the case of variable length input, the cudnn kernel will fill zeros for
    # the output, whereas the default keras behavior is to bring over the
    # previous output for t-1, so that in the return_sequence=False case, user
    # can quickly get the final effect output instead just 0s at the last
    # timestep. In order to mimic the default keras behavior, we copy the final
    # h state as the last_output, since it is numerically same as the output.
    if sequence_lengths is not None:
        last_output = state

    # Match CPU return format
    if not return_sequences:
        outputs = tf.expand_dims(last_output, axis=0 if time_major else 1)

    return (
        last_output,
        outputs,
        state,
    )
def cudnn_ok(
    activation,
    recurrent_activation,
    unroll,
    use_bias,
    reset_after=None,
):
    """Tell whether the cuDNN kernel can be used for these RNN arguments.

    `reset_after` selects which checks run. When it is given, the GRU checks
    are used; when it is left as `None`, the LSTM checks are used. In this
    module, `gru` passes it and `lstm` does not.

    Returns:
        A falsy value when the arguments are unsupported. Otherwise, whether a
        logical GPU device is visible.
    """
    shared_args = dict(
        activation=activation,
        recurrent_activation=recurrent_activation,
        unroll=unroll,
        use_bias=use_bias,
    )
    if reset_after is not None:
        args_supported = _do_gru_arguments_support_cudnn(
            reset_after=reset_after, **shared_args
        )
    else:
        args_supported = _do_lstm_arguments_support_cudnn(**shared_args)
    return args_supported and _is_gpu_available()
def lstm(
    inputs,
    initial_state_h,
    initial_state_c,
    mask,
    kernel,
    recurrent_kernel,
    bias,
    activation,
    recurrent_activation,
    return_sequences=False,
    go_backwards=False,
    unroll=False,
    time_major=False,
):
    """Run an LSTM over `inputs` with the cuDNN kernel.

    Args:
        inputs: Input sequence tensor.
        initial_state_h: Initial hidden state.
        initial_state_c: Initial cell state.
        mask: Optional mask. It must pass `_do_rnn_inputs_support_cudnn`.
        kernel: Input kernel (a backend `Variable` is unwrapped to its value).
        recurrent_kernel: Recurrent kernel (unwrapped likewise).
        bias: Bias (unwrapped likewise); `None` disables the cuDNN path.
        activation: Must be a tanh variant for the cuDNN path.
        recurrent_activation: Must be a sigmoid variant for the cuDNN path.
        return_sequences: Whether to return the whole output sequence.
        go_backwards: Whether to process the sequence in reverse.
        unroll: Must be False for the cuDNN path.
        time_major: Whether inputs are time-major.

    Returns:
        The `(last_output, outputs, [h, c])` tuple produced by `_cudnn_lstm`.

    Raises:
        NotImplementedError: if the arguments or the mask are not supported
            by cuDNN, no GPU is visible, or the cuDNN op raises
            `InvalidArgumentError`/`NotFoundError`. Presumably the caller
            catches this and falls back to a generic implementation -- confirm.
    """
    # Check the arguments and GPU availability first. This is plain Python,
    # and failing here avoids building the mask-analysis ops below (whose
    # truth value cannot be taken in graph mode) when cuDNN is unusable.
    cudnn_supported = cudnn_ok(
        activation, recurrent_activation, unroll, use_bias=bias is not None
    )
    if not cudnn_supported:
        raise NotImplementedError
    # `_do_rnn_inputs_support_cudnn` may return a scalar boolean tensor;
    # taking its truth value requires eager execution.
    if not _do_rnn_inputs_support_cudnn(mask, time_major):
        raise NotImplementedError

    from keras_core.backend.tensorflow import Variable

    if isinstance(kernel, Variable):
        kernel = kernel.value
    if isinstance(recurrent_kernel, Variable):
        recurrent_kernel = recurrent_kernel.value
    if isinstance(bias, Variable):
        bias = bias.value

    try:
        return _cudnn_lstm(
            inputs,
            initial_state_h,
            initial_state_c,
            kernel,
            recurrent_kernel,
            bias,
            mask,
            time_major,
            go_backwards,
            return_sequences,
        )
    except (tf.errors.InvalidArgumentError, tf.errors.NotFoundError) as e:
        # InvalidArgumentError: cuDNN op not found.
        # NotFoundError: no device found for the op.
        raise NotImplementedError from e
def _cudnn_lstm(
    inputs,
    initial_state_h,
    initial_state_c,
    kernel,
    recurrent_kernel,
    bias,
    mask,
    time_major,
    go_backwards,
    return_sequences,
):
    """LSTM with cuDNN (or ROCm MIOpen) implementation, only usable on GPU.

    Args:
        inputs: Input sequence. It is transposed with `perm=(1, 0, 2)` below,
            so presumably rank 3: `(batch, time, features)`, or
            `(time, batch, features)` when `time_major` -- confirm.
        initial_state_h: Initial hidden state, presumably `(batch, units)`.
        initial_state_c: Initial cell state, presumably `(batch, units)`.
        kernel: Input kernel; split into 4 gate blocks along axis 1.
        recurrent_kernel: Recurrent kernel; split into 4 gate blocks along
            axis 1.
        bias: Gate biases; a zero input-bias of the same shape is prepended.
        mask: Optional mask from which per-sample sequence lengths are
            counted. It must be right padded for those counts to be valid.
        time_major: Whether `inputs` (and `mask`) are time-major.
        go_backwards: Whether to process the sequence in reverse.
        return_sequences: Whether `outputs` holds every timestep.

    Returns:
        Tuple `(last_output, outputs, [h, c])`. When `return_sequences` is
        False, `outputs` is `last_output` with a length-1 time axis inserted.
    """
    # NOTE(review): unlike `_cudnn_gru`, this function is not wrapped in
    # `tf.function(autograph=False)` -- confirm whether that is intentional.
    if mask is not None:
        sequence_lengths = _compute_sequence_length_from_mask(mask, time_major)
    else:
        sequence_lengths = None
    # Without sequence lengths the plain `CudnnRNN` op below is used, which
    # takes time-major input, so batch-major inputs are transposed here.
    # `CudnnRNNV3` takes a `time_major` flag instead.
    if not time_major and sequence_lengths is None:
        inputs = tf.transpose(inputs, perm=(1, 0, 2))
        seq_axis, batch_axis = (0, 1)
    else:
        seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
    # For init_h and init_c, cuDNN expects one more dim of num_layers before or
    # after batch dim for time major or batch major inputs respectively
    init_h = tf.expand_dims(initial_state_h, axis=seq_axis)
    init_c = tf.expand_dims(initial_state_c, axis=seq_axis)

    weights = tf.split(kernel, 4, axis=1)
    weights += tf.split(recurrent_kernel, 4, axis=1)
    # cuDNN has an extra set of bias for inputs, we disable them (setting to 0),
    # so that mathematically it is same as the canonical LSTM implementation.
    full_bias = tf.concat((tf.zeros_like(bias), bias), 0)

    if tf.sysconfig.get_build_info()["is_rocm_build"]:
        # ROCm MIOpen's weight sequence for LSTM is different from both
        # canonical and Cudnn format
        # MIOpen: [i, f, o, c] Cudnn/Canonical: [i, f, c, o]
        # i is input gate weights.
        # f is forget gate weights.
        # o is output gate weights.
        # c is cell gate weights.
        weights = [weights[x] for x in (0, 1, 3, 2, 4, 5, 7, 6)]
        # full_bias is a tensor of shape (8*n,)
        full_bias = tf.split(full_bias, 8, axis=0)
        full_bias = [full_bias[x] for x in (0, 1, 3, 2, 4, 5, 7, 6)]
        # From here `full_bias` is a list of 8 tensors. The `tf.split` below
        # converts it into one stacked tensor before splitting. The resulting
        # pieces are flattened by `_standardize_cudnn_weights`.

    params = _standardize_cudnn_weights(
        weights=weights,
        biases=tf.split(full_bias, 8),
        shape=tf.constant([-1]),
        transpose_weights=True,
    )

    if sequence_lengths is not None:
        if go_backwards:
            # Three reversals are required. E.g.,
            # normal input = [1, 2, 3, 0, 0]  # where 0 need to be masked
            # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
            # output_from_cudnn = [6, 5, 4, 0, 0]
            # expected_output = [0, 0, 6, 5 ,4]
            inputs = tf.reverse_sequence(
                inputs,
                sequence_lengths,
                seq_axis=seq_axis,
                batch_axis=batch_axis,
            )
        outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV3(
            input=inputs,
            input_h=init_h,
            input_c=init_c,
            params=params,
            is_training=True,
            rnn_mode="lstm",
            sequence_lengths=sequence_lengths,
            time_major=time_major,
        )
        if go_backwards:
            outputs = tf.reverse_sequence(
                outputs,
                sequence_lengths,
                seq_axis=seq_axis,
                batch_axis=batch_axis,
            )
            outputs = tf.reverse(outputs, axis=[seq_axis])
    else:
        # No mask: every sequence uses all timesteps, so the plain CudnnRNN
        # op (without a sequence-length input) is used.
        if go_backwards:
            # Reverse axis 0 since the input is already convert to time major.
            inputs = tf.reverse(inputs, axis=[0])
        outputs, h, c, _ = tf.raw_ops.CudnnRNN(
            input=inputs,
            input_h=init_h,
            input_c=init_c,
            params=params,
            is_training=True,
            rnn_mode="lstm",
        )

    # Last timestep when `outputs` is time-major. In the batch-major
    # sequence-length case this indexes the batch axis instead, but it is
    # overwritten with `h` below.
    last_output = outputs[-1]
    if not time_major and sequence_lengths is None and return_sequences:
        # Undo the transpose applied to the inputs above.
        outputs = tf.transpose(outputs, perm=[1, 0, 2])
    # Drop the num_layers dim added to init_h / init_c.
    h = tf.squeeze(h, axis=seq_axis)
    c = tf.squeeze(c, axis=seq_axis)

    # In the case of variable length input, the cudnn kernel will fill zeros for
    # the output, whereas the default keras behavior is to bring over the
    # previous output for t-1, so that in the return_sequence=False case, user
    # can quickly get the final effect output instead just 0s at the last
    # timestep. In order to mimic the default keras behavior, we copy the final
    # h state as the last_output, since it is numerically same as the output.
    if sequence_lengths is not None:
        last_output = h

    # Match CPU return format
    if not return_sequences:
        outputs = tf.expand_dims(last_output, axis=0 if time_major else 1)

    return (last_output, outputs, [h, c])