-
Notifications
You must be signed in to change notification settings - Fork 34
/
layers.py
1129 lines (993 loc) · 43.3 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""Core classes for the KerasLMU package."""
import warnings
import keras
import numpy as np
import tensorflow as tf
from packaging import version
# pylint: disable=ungrouped-imports
tf_version = version.parse(tf.__version__)
if tf_version < version.parse("2.9.0rc0"):
from keras.layers.recurrent import DropoutRNNCellMixin
elif tf_version < version.parse("2.13.0rc0"):
from keras.layers.rnn.dropout_rnn_cell_mixin import DropoutRNNCellMixin
elif tf_version < version.parse("2.16.0rc0"):
from keras.src.layers.rnn.dropout_rnn_cell_mixin import DropoutRNNCellMixin
else:
from keras.src.layers.rnn.dropout_rnn_cell import (
DropoutRNNCell as DropoutRNNCellMixin,
)
if tf_version < version.parse("2.8.0rc0"):
from tensorflow.keras.layers import Layer as BaseRandomLayer
elif tf_version < version.parse("2.13.0rc0"):
from keras.engine.base_layer import BaseRandomLayer
elif tf_version < version.parse("2.16.0rc0"):
from keras.src.engine.base_layer import BaseRandomLayer
else:
from keras.layers import Layer as BaseRandomLayer
@tf.keras.utils.register_keras_serializable("keras-lmu")
class LMUCell(
    DropoutRNNCellMixin, BaseRandomLayer
):  # pylint: disable=too-many-ancestors
    """
    Implementation of LMU cell (to be used within Keras RNN wrapper).

    In general, the LMU cell consists of two parts: a memory component (decomposing
    the input signal using Legendre polynomials as a basis), and a hidden component
    (learning nonlinear mappings from the memory component). [1]_ [2]_

    This class processes one step within the whole time sequence input. Use the ``LMU``
    class to create a recurrent Keras layer to process the whole sequence. Calling
    ``LMU()`` is equivalent to doing ``RNN(LMUCell())``.

    Parameters
    ----------
    memory_d : int
        Dimensionality of input to memory component.
    order : int
        The number of degrees in the transfer function of the LTI system used to
        represent the sliding window of history. This parameter sets the number of
        Legendre polynomials used to orthogonally represent the sliding window.
    theta : float
        The number of timesteps in the sliding window that is represented using the
        LTI system. In this context, the sliding window represents a dynamic range of
        data, of fixed size, that will be used to predict the value at the next time
        step. If this value is smaller than the size of the input sequence, only that
        number of steps will be represented at the time of prediction, however the
        entire sequence will still be processed in order for information to be
        projected to and from the hidden layer. If ``trainable_theta`` is enabled, then
        theta will be updated during the course of training.
    hidden_cell : ``keras.layers.Layer``
        Keras Layer/RNNCell implementing the hidden component. If ``None``, the
        memory state itself is the cell output, and ``hidden_to_memory`` and
        ``input_to_hidden`` must both be False.
    trainable_theta : bool
        If True, theta is learnt over the course of training. Otherwise, it is kept
        constant.
    hidden_to_memory : bool
        If True, connect the output of the hidden component back to the memory
        component (default False).
    memory_to_memory : bool
        If True, add a learnable recurrent connection (in addition to the static
        Legendre system) to the memory component (default False).
    input_to_hidden : bool
        If True, connect the input directly to the hidden component (in addition to
        the connection from the memory component) (default False).
    discretizer : str
        The method used to discretize the A and B matrices of the LMU. Current
        options are "zoh" (short for Zero Order Hold) and "euler".
        "zoh" is more accurate, but training will be slower than "euler" if
        ``trainable_theta=True``. Note that a larger theta is needed when discretizing
        using "euler" (a value that is larger than ``4*order`` is recommended).
    kernel_initializer : ``tf.initializers.Initializer``
        Initializer for weights from input to memory/hidden component. If ``None``,
        no weights will be used, and the input size must match the memory/hidden size.
    recurrent_initializer : ``tf.initializers.Initializer``
        Initializer for ``memory_to_memory`` weights (if that connection is enabled).
    kernel_regularizer : ``keras.regularizers.Regularizer``
        Regularizer for weights from input to memory/hidden component.
    recurrent_regularizer : ``keras.regularizers.Regularizer``
        Regularizer for ``memory_to_memory`` weights (if that connection is enabled).
    use_bias : bool
        If True, the memory component includes a bias term.
    bias_initializer : ``tf.initializers.Initializer``
        Initializer for the memory component bias term. Only used if ``use_bias=True``.
    bias_regularizer : ``keras.regularizers.Regularizer``
        Regularizer for the memory component bias term. Only used if ``use_bias=True``.
    dropout : float
        Dropout rate on input connections.
    recurrent_dropout : float
        Dropout rate on ``memory_to_memory`` connection.
    seed : int
        Seed for the ``keras.random.SeedGenerator`` used by the dropout masks
        (created only on the Keras 3 code path, TensorFlow >= 2.16). Always
        stored in the layer config.

    Raises
    ------
    ValueError
        If ``discretizer`` is not "zoh" or "euler", or if ``hidden_cell`` is
        ``None`` while ``hidden_to_memory`` or ``input_to_hidden`` is True.

    References
    ----------
    .. [1] Voelker and Eliasmith (2018). Improving spiking dynamical
       networks: Accurate delays, higher-order synapses, and time cells.
       Neural Computation, 30(3): 569-609.
    .. [2] Voelker and Eliasmith. "Methods and systems for implementing
       dynamic neural networks." U.S. Patent Application No. 15/243,223.
       Filing date: 2016-08-22.
    """

    def __init__(
        self,
        memory_d,
        order,
        theta,
        hidden_cell,
        trainable_theta=False,
        hidden_to_memory=False,
        memory_to_memory=False,
        input_to_hidden=False,
        discretizer="zoh",
        kernel_initializer="glorot_uniform",
        recurrent_initializer="orthogonal",
        kernel_regularizer=None,
        recurrent_regularizer=None,
        use_bias=False,
        bias_initializer="zeros",
        bias_regularizer=None,
        dropout=0,
        recurrent_dropout=0,
        seed=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.memory_d = memory_d
        self.order = order
        # the constructor value is kept separately; the ``theta`` property
        # reports the (possibly trained) value once the cell is built
        self._init_theta = theta
        self.hidden_cell = hidden_cell
        self.trainable_theta = trainable_theta
        self.hidden_to_memory = hidden_to_memory
        self.memory_to_memory = memory_to_memory
        self.input_to_hidden = input_to_hidden
        self.discretizer = discretizer
        self.kernel_initializer = kernel_initializer
        self.recurrent_initializer = recurrent_initializer
        self.kernel_regularizer = kernel_regularizer
        self.recurrent_regularizer = recurrent_regularizer
        self.use_bias = use_bias
        self.bias_initializer = bias_initializer
        self.bias_regularizer = bias_regularizer
        self.dropout = dropout
        self.recurrent_dropout = recurrent_dropout
        self.seed = seed
        # ``get_recurrent_dropout_mask`` reads ``self.seed_generator`` whenever
        # ``tf_version >= 2.16.0rc0`` (the same threshold used for the Keras 3
        # imports), so the generator must be created under that same threshold;
        # comparing against "2.16.0" left TF 2.16 release candidates without it.
        if tf_version >= version.parse("2.16.0rc0"):
            self.seed_generator = keras.random.SeedGenerator(seed)

        # weights/matrices are created in ``build``
        self.kernel = None
        self.recurrent_kernel = None
        self.bias = None
        self.theta_inv = None
        self.A = None
        self.B = None

        if self.discretizer not in ("zoh", "euler"):
            raise ValueError(
                f"discretizer must be 'zoh' or 'euler' (got '{self.discretizer}')"
            )

        if self.hidden_cell is None:
            # Without a hidden cell there is no hidden output to feed back
            # (``call`` reads ``h[0]`` for hidden_to_memory), and ``output_size``
            # below is fixed to the memory size, so concatenating the raw input
            # onto the output (input_to_hidden) would not match it.
            for conn in ("hidden_to_memory", "input_to_hidden"):
                if getattr(self, conn):
                    raise ValueError(f"{conn} must be False if hidden_cell is None")

            self.hidden_output_size = self.memory_d * self.order
            self.hidden_state_size = []
        elif hasattr(self.hidden_cell, "state_size"):
            # recurrent hidden cell: reuse its own output/state sizes
            self.hidden_output_size = self.hidden_cell.output_size
            self.hidden_state_size = self.hidden_cell.state_size
        else:
            # feedforward hidden layer: its output is carried as the hidden state
            # TODO: support layers that don't have the `units` attribute
            self.hidden_output_size = self.hidden_cell.units
            self.hidden_state_size = [self.hidden_cell.units]

        # state = [flattened memory (memory_d * order)] + hidden cell states
        self.state_size = [self.memory_d * self.order] + tf.nest.flatten(
            self.hidden_state_size
        )
        self.output_size = self.hidden_output_size

    @property
    def theta(self):
        """
        Value of the ``theta`` parameter.

        If ``trainable_theta=True`` this returns the trained value, not the
        initial value passed in to the constructor.
        """

        if self.built:
            return 1 / self.theta_inv.numpy()

        return self._init_theta

    def _gen_AB(self):
        """Generates A and B matrices."""

        # compute analog A/B matrices:
        #   A[i, j] = (2i + 1) * (-1             if i < j
        #                          (-1)^(i-j+1)  otherwise)
        #   B[i]    = (2i + 1) * (-1)^i
        Q = np.arange(self.order, dtype=np.float64)
        R = (2 * Q + 1)[:, None]
        j, i = np.meshgrid(Q, Q)

        A = np.where(i < j, -1, (-1.0) ** (i - j + 1)) * R
        B = (-1.0) ** Q[:, None] * R

        # discretize matrices (stored transposed, for right-multiplication in .call)
        if self.discretizer == "zoh":
            # save the un-discretized matrices for use in .call
            self._base_A = tf.constant(A.T, dtype=self.dtype)
            self._base_B = tf.constant(B.T, dtype=self.dtype)

            self.A, self.B = LMUCell._cont2discrete_zoh(
                self._base_A / self._init_theta, self._base_B / self._init_theta
            )
        else:
            # euler: with a fixed theta the scaling and identity are baked in;
            # with a trainable theta ``call`` applies theta_inv and the identity
            if not self.trainable_theta:
                A = A / self._init_theta + np.eye(self.order)
                B = B / self._init_theta

            self.A = tf.constant(A.T, dtype=self.dtype)
            self.B = tf.constant(B.T, dtype=self.dtype)

    @staticmethod
    def _cont2discrete_zoh(A, B):
        """
        Function to discretize A and B matrices using Zero Order Hold method.

        Functionally equivalent to
        ``scipy.signal.cont2discrete((A.T, B.T, _, _), method="zoh", dt=1.0)``
        (but implemented in TensorFlow so that it is differentiable).

        Note that this accepts and returns matrices that are transposed from the
        standard linear system implementation (as that makes it easier to use in
        `.call`).
        """

        # combine A/B and pad to make square matrix
        em_upper = tf.concat([A, B], axis=0)  # pylint: disable=no-value-for-parameter
        em = tf.pad(em_upper, [(0, 0), (0, B.shape[0])])

        # compute matrix exponential
        ms = tf.linalg.expm(em)

        # slice A/B back out of combined matrix
        discrt_A = ms[: A.shape[0], : A.shape[1]]
        discrt_B = ms[A.shape[0] :, : A.shape[1]]

        return discrt_A, discrt_B

    def build(self, input_shape):
        """
        Builds the cell.

        Notes
        -----
        This method should not be called manually; rather, use the implicit layer
        callable behaviour (like ``my_layer(inputs)``), which will apply this method
        with some additional bookkeeping.
        """

        super().build(input_shape)

        # the memory input is the cell input, plus the hidden output if fed back
        enc_d = input_shape[-1]
        if self.hidden_to_memory:
            enc_d += self.hidden_output_size

        if self.kernel_initializer is not None:
            self.kernel = self.add_weight(
                name="kernel",
                shape=(enc_d, self.memory_d),
                initializer=self.kernel_initializer,
                regularizer=self.kernel_regularizer,
            )
        elif enc_d != self.memory_d:
            raise ValueError(
                f"For LMUCells with no input kernel, the input dimension ({enc_d})"
                f" must equal `memory_d` ({self.memory_d})."
            )

        if self.use_bias:
            self.bias = self.add_weight(
                name="bias",
                shape=(self.memory_d,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
            )

        # when using euler, 1/theta results in better gradients for the memory
        # update since you are multiplying 1/theta, as compared to dividing theta
        if self.trainable_theta:
            self.theta_inv = self.add_weight(
                name="theta_inv",
                shape=(),
                initializer=tf.initializers.constant(1 / self._init_theta),
                constraint=keras.constraints.NonNeg(),
            )
        else:
            self.theta_inv = tf.constant(1 / self._init_theta, dtype=self.dtype)

        if self.memory_to_memory:
            self.recurrent_kernel = self.add_weight(
                name="recurrent_kernel",
                shape=(self.memory_d * self.order, self.memory_d),
                initializer=self.recurrent_initializer,
                regularizer=self.recurrent_regularizer,
            )
        else:
            self.recurrent_kernel = None

        if self.hidden_cell is not None and not self.hidden_cell.built:
            # the hidden component sees the flattened memory (+ raw input)
            hidden_input_d = self.memory_d * self.order
            if self.input_to_hidden:
                hidden_input_d += input_shape[-1]
            with tf.name_scope(self.hidden_cell.name):
                self.hidden_cell.build((input_shape[0], hidden_input_d))

        # generate A and B matrices
        self._gen_AB()

    def call(self, inputs, states, training=False):  # noqa: C901
        """
        Apply this cell to inputs.

        Notes
        -----
        This method should not be called manually; rather, use the implicit layer
        callable behaviour (like ``my_layer(inputs)``), which will apply this method
        with some additional bookkeeping.
        """

        states = tf.nest.flatten(states)

        # state for the LMU memory
        m = states[0]

        # state for the hidden cell
        h = states[1:]

        # compute memory input
        u = (
            tf.concat((inputs, h[0]), axis=1)  # pylint: disable=no-value-for-parameter
            if self.hidden_to_memory
            else inputs
        )
        if training and self.dropout > 0:
            u *= self.get_dropout_mask(u)
        if self.kernel is not None:
            u = tf.matmul(u, self.kernel, name="kernel_matmul")
        if self.bias is not None:
            u = u + self.bias

        if self.memory_to_memory:
            if training and self.recurrent_dropout > 0:
                # note: we don't apply dropout to the memory input, only
                # the recurrent kernel
                rec_m = m * self.get_recurrent_dropout_mask(m)
            else:
                rec_m = m

            u = u + tf.matmul(
                rec_m, self.recurrent_kernel, name="recurrent_kernel_matmul"
            )

        # separate memory/order dimensions
        m = tf.reshape(m, (-1, self.memory_d, self.order))
        u = tf.expand_dims(u, -1)

        # update memory
        if self.discretizer == "zoh" and self.trainable_theta:
            # apply updated theta and re-discretize
            A, B = LMUCell._cont2discrete_zoh(
                self._base_A * self.theta_inv, self._base_B * self.theta_inv
            )
        else:
            A, B = self.A, self.B

        _m = tf.matmul(m, A) + tf.matmul(u, B)

        if self.discretizer == "euler" and self.trainable_theta:
            # apply updated theta. this is the same as scaling A/B by theta, but it's
            # more efficient to do it this way.
            # note that when computing this way the A matrix does not
            # include the identity matrix along the diagonal (since we don't want to
            # scale that part by theta), which is why we do += instead of =
            m += _m * self.theta_inv
        else:
            m = _m

        # re-combine memory/order dimensions
        m = tf.reshape(m, (-1, self.memory_d * self.order))

        # apply hidden cell
        h_in = (
            tf.concat((m, inputs), axis=1)  # pylint: disable=no-value-for-parameter
            if self.input_to_hidden
            else m
        )

        if self.hidden_cell is None:
            o = h_in
            h = []
        elif hasattr(self.hidden_cell, "state_size"):
            o, h = self.hidden_cell(h_in, h, training=training)
        else:
            o = self.hidden_cell(h_in, training=training)
            h = [o]

        return o, [m] + h

    def get_dropout_mask(self, step_input):
        """Get dropout mask for cell input."""

        if tf_version < version.parse("2.16.0rc0"):
            return super().get_dropout_mask_for_cell(step_input, True, count=1)
        return super().get_dropout_mask(step_input)

    def get_recurrent_dropout_mask(self, step_input):
        """Get dropout mask for recurrent input."""

        if tf_version < version.parse("2.16.0rc0"):
            return super().get_recurrent_dropout_mask_for_cell(
                step_input, True, count=1
            )

        # This is copied from DropoutRNNCell.get_recurrent_dropout_mask, with the
        # change noted below in order to fix a bug.
        # See https://github.com/keras-team/keras/issues/19395
        if not hasattr(self, "_recurrent_dropout_mask"):
            self._recurrent_dropout_mask = None
        if self._recurrent_dropout_mask is None and self.recurrent_dropout > 0:
            ones = keras.ops.ones_like(step_input)
            self._recurrent_dropout_mask = keras.src.backend.random.dropout(
                ones,
                # --- START DIFF ---
                # rate=self.dropout,
                rate=self.recurrent_dropout,
                # --- END DIFF ---
                seed=self.seed_generator,
            )
        return self._recurrent_dropout_mask

    def reset_dropout_mask(self):
        """Reset dropout mask for memory and hidden components."""
        super().reset_dropout_mask()
        if isinstance(self.hidden_cell, DropoutRNNCellMixin):
            self.hidden_cell.reset_dropout_mask()

    def reset_recurrent_dropout_mask(self):
        """Reset recurrent dropout mask for memory and hidden components."""
        super().reset_recurrent_dropout_mask()
        if isinstance(self.hidden_cell, DropoutRNNCellMixin):
            self.hidden_cell.reset_recurrent_dropout_mask()

    def get_config(self):
        """Return config of layer (for serialization during model saving/loading)."""

        config = super().get_config()
        config.update(
            {
                "memory_d": self.memory_d,
                "order": self.order,
                "theta": self._init_theta,
                "hidden_cell": keras.layers.serialize(self.hidden_cell),
                "trainable_theta": self.trainable_theta,
                "hidden_to_memory": self.hidden_to_memory,
                "memory_to_memory": self.memory_to_memory,
                "input_to_hidden": self.input_to_hidden,
                "discretizer": self.discretizer,
                "kernel_initializer": self.kernel_initializer,
                "recurrent_initializer": self.recurrent_initializer,
                "kernel_regularizer": self.kernel_regularizer,
                "recurrent_regularizer": self.recurrent_regularizer,
                "use_bias": self.use_bias,
                "bias_initializer": self.bias_initializer,
                "bias_regularizer": self.bias_regularizer,
                "dropout": self.dropout,
                "recurrent_dropout": self.recurrent_dropout,
                "seed": self.seed,
            }
        )

        return config

    @classmethod
    def from_config(cls, config):
        """Load model from serialized config."""

        # work on a shallow copy so the caller's dict keeps the serialized
        # hidden cell (and can be deserialized again) instead of a live layer
        config = dict(config)
        config["hidden_cell"] = (
            None
            if config["hidden_cell"] is None
            else keras.layers.deserialize(config["hidden_cell"])
        )
        return super().from_config(config)
@tf.keras.utils.register_keras_serializable("keras-lmu")
class LMU(keras.layers.Layer):  # pylint: disable=too-many-ancestors,abstract-method
    """
    A layer of trainable low-dimensional delay systems.

    Each unit buffers its encoded input
    by internally representing a low-dimensional
    (i.e., compressed) version of the sliding window.

    Nonlinear decodings of this representation,
    expressed by the A and B matrices, provide
    computations across the window, such as its
    derivative, energy, median value, etc ([1]_, [2]_).
    Note that these decoder matrices can span across
    all of the units of an input sequence.

    When ``hidden_to_memory``, ``memory_to_memory`` and ``trainable_theta`` are
    all False, the layer is implemented with ``LMUFeedforward``; otherwise it
    wraps an ``LMUCell`` in ``keras.layers.RNN``.

    Parameters
    ----------
    memory_d : int
        Dimensionality of input to memory component.
    order : int
        The number of degrees in the transfer function of the LTI system used to
        represent the sliding window of history. This parameter sets the number of
        Legendre polynomials used to orthogonally represent the sliding window.
    theta : float
        The number of timesteps in the sliding window that is represented using the
        LTI system. In this context, the sliding window represents a dynamic range of
        data, of fixed size, that will be used to predict the value at the next time
        step. If this value is smaller than the size of the input sequence, only that
        number of steps will be represented at the time of prediction, however the
        entire sequence will still be processed in order for information to be
        projected to and from the hidden layer. If ``trainable_theta`` is enabled, then
        theta will be updated during the course of training.
    hidden_cell : ``keras.layers.Layer``
        Keras Layer/RNNCell implementing the hidden component.
    trainable_theta : bool
        If True, theta is learnt over the course of training. Otherwise, it is kept
        constant.
    hidden_to_memory : bool
        If True, connect the output of the hidden component back to the memory
        component (default False).
    memory_to_memory : bool
        If True, add a learnable recurrent connection (in addition to the static
        Legendre system) to the memory component (default False).
    input_to_hidden : bool
        If True, connect the input directly to the hidden component (in addition to
        the connection from the memory component) (default False).
    discretizer : str
        The method used to discretize the A and B matrices of the LMU. Current
        options are "zoh" (short for Zero Order Hold) and "euler".
        "zoh" is more accurate, but training will be slower than "euler" if
        ``trainable_theta=True``. Note that a larger theta is needed when discretizing
        using "euler" (a value that is larger than ``4*order`` is recommended).
    kernel_initializer : ``tf.initializers.Initializer``
        Initializer for weights from input to memory/hidden component. If ``None``,
        no weights will be used, and the input size must match the memory/hidden size.
    recurrent_initializer : ``tf.initializers.Initializer``
        Initializer for ``memory_to_memory`` weights (if that connection is enabled).
    kernel_regularizer : ``keras.regularizers.Regularizer``
        Regularizer for weights from input to memory/hidden component.
    recurrent_regularizer : ``keras.regularizers.Regularizer``
        Regularizer for ``memory_to_memory`` weights (if that connection is enabled).
    use_bias : bool
        If True, the memory component includes a bias term.
    bias_initializer : ``tf.initializers.Initializer``
        Initializer for the memory component bias term. Only used if ``use_bias=True``.
    bias_regularizer : ``keras.regularizers.Regularizer``
        Regularizer for the memory component bias term. Only used if ``use_bias=True``.
    dropout : float
        Dropout rate on input connections.
    recurrent_dropout : float
        Dropout rate on ``memory_to_memory`` connection.
    return_sequences : bool, optional
        If True, return the full output sequence. Otherwise, return just the last
        output in the output sequence.

    References
    ----------
    .. [1] Voelker and Eliasmith (2018). Improving spiking dynamical
       networks: Accurate delays, higher-order synapses, and time cells.
       Neural Computation, 30(3): 569-609.
    .. [2] Voelker and Eliasmith. "Methods and systems for implementing
       dynamic neural networks." U.S. Patent Application No. 15/243,223.
       Filing date: 2016-08-22.
    """

    def __init__(
        self,
        memory_d,
        order,
        theta,
        hidden_cell,
        trainable_theta=False,
        hidden_to_memory=False,
        memory_to_memory=False,
        input_to_hidden=False,
        discretizer="zoh",
        kernel_initializer="glorot_uniform",
        recurrent_initializer="orthogonal",
        kernel_regularizer=None,
        recurrent_regularizer=None,
        use_bias=False,
        bias_initializer="zeros",
        bias_regularizer=None,
        dropout=0,
        recurrent_dropout=0,
        return_sequences=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.memory_d = memory_d
        self.order = order
        self._init_theta = theta
        self.hidden_cell = hidden_cell
        self.trainable_theta = trainable_theta
        self.hidden_to_memory = hidden_to_memory
        self.memory_to_memory = memory_to_memory
        self.input_to_hidden = input_to_hidden
        self.discretizer = discretizer
        self.kernel_initializer = kernel_initializer
        self.recurrent_initializer = recurrent_initializer
        self.kernel_regularizer = kernel_regularizer
        self.recurrent_regularizer = recurrent_regularizer
        self.use_bias = use_bias
        self.bias_initializer = bias_initializer
        self.bias_regularizer = bias_regularizer
        self.dropout = dropout
        self.recurrent_dropout = recurrent_dropout
        self.return_sequences = return_sequences
        # the concrete implementation (LMUFeedforward or RNN(LMUCell)) is
        # chosen and created in ``build``
        self.layer = None

    @property
    def theta(self):
        """
        Value of the ``theta`` parameter.

        If ``trainable_theta=True`` this returns the trained value, not the
        initial value passed in to the constructor.
        """

        if self.built:
            return (
                self.layer.theta
                if isinstance(self.layer, LMUFeedforward)
                else self.layer.cell.theta
            )

        return self._init_theta

    def build(self, input_shape):
        """
        Builds the layer.

        Notes
        -----
        This method should not be called manually; rather, use the implicit layer
        callable behaviour (like ``my_layer(inputs)``), which will apply this method
        with some additional bookkeeping.
        """

        super().build(input_shape)

        if (
            not self.hidden_to_memory
            and not self.memory_to_memory
            and not self.trainable_theta
        ):
            # no recurrence in the memory and a fixed theta: the memory can be
            # computed with the feedforward implementation
            self.layer = LMUFeedforward(
                memory_d=self.memory_d,
                order=self.order,
                theta=self._init_theta,
                hidden_cell=self.hidden_cell,
                input_to_hidden=self.input_to_hidden,
                discretizer=self.discretizer,
                kernel_initializer=self.kernel_initializer,
                kernel_regularizer=self.kernel_regularizer,
                use_bias=self.use_bias,
                bias_initializer=self.bias_initializer,
                bias_regularizer=self.bias_regularizer,
                dropout=self.dropout,
                return_sequences=self.return_sequences,
                dtype=self.dtype,
            )
        else:
            self.layer = keras.layers.RNN(
                LMUCell(
                    memory_d=self.memory_d,
                    order=self.order,
                    theta=self._init_theta,
                    hidden_cell=self.hidden_cell,
                    trainable_theta=self.trainable_theta,
                    hidden_to_memory=self.hidden_to_memory,
                    memory_to_memory=self.memory_to_memory,
                    input_to_hidden=self.input_to_hidden,
                    discretizer=self.discretizer,
                    kernel_initializer=self.kernel_initializer,
                    recurrent_initializer=self.recurrent_initializer,
                    kernel_regularizer=self.kernel_regularizer,
                    recurrent_regularizer=self.recurrent_regularizer,
                    use_bias=self.use_bias,
                    bias_initializer=self.bias_initializer,
                    bias_regularizer=self.bias_regularizer,
                    dropout=self.dropout,
                    recurrent_dropout=self.recurrent_dropout,
                    dtype=self.dtype,
                ),
                return_sequences=self.return_sequences,
                dtype=self.dtype,
            )

        self.layer.build(input_shape)

    def call(self, inputs, training=False):
        """
        Apply this layer to inputs.

        Notes
        -----
        This method should not be called manually; rather, use the implicit layer
        callable behaviour (like ``my_layer(inputs)``), which will apply this method
        with some additional bookkeeping.
        """

        return self.layer.call(inputs, training=training)

    def get_config(self):
        """Return config of layer (for serialization during model saving/loading)."""

        config = super().get_config()
        config.update(
            {
                "memory_d": self.memory_d,
                "order": self.order,
                "theta": self._init_theta,
                "hidden_cell": keras.layers.serialize(self.hidden_cell),
                "trainable_theta": self.trainable_theta,
                "hidden_to_memory": self.hidden_to_memory,
                "memory_to_memory": self.memory_to_memory,
                "input_to_hidden": self.input_to_hidden,
                "discretizer": self.discretizer,
                "kernel_initializer": self.kernel_initializer,
                "recurrent_initializer": self.recurrent_initializer,
                "kernel_regularizer": self.kernel_regularizer,
                "recurrent_regularizer": self.recurrent_regularizer,
                "use_bias": self.use_bias,
                "bias_initializer": self.bias_initializer,
                "bias_regularizer": self.bias_regularizer,
                "dropout": self.dropout,
                "recurrent_dropout": self.recurrent_dropout,
                "return_sequences": self.return_sequences,
            }
        )

        return config

    @classmethod
    def from_config(cls, config):
        """Load model from serialized config."""

        # work on a shallow copy so the caller's dict keeps the serialized
        # hidden cell (and can be deserialized again) instead of a live layer
        config = dict(config)
        config["hidden_cell"] = (
            None
            if config["hidden_cell"] is None
            else keras.layers.deserialize(config["hidden_cell"])
        )
        return super().from_config(config)
@tf.keras.utils.register_keras_serializable("keras-lmu")
class LMUFeedforward(
keras.layers.Layer
): # pylint: disable=too-many-ancestors,abstract-method
"""
Layer class for the feedforward variant of the LMU.
This class assumes no recurrent connections are desired in the memory component.
Produces the output of the delay system by evaluating the convolution of the input
sequence with the impulse response from the LMU cell.
Parameters
----------
memory_d : int
Dimensionality of input to memory component.
order : int
The number of degrees in the transfer function of the LTI system used to
represent the sliding window of history. This parameter sets the number of
Legendre polynomials used to orthogonally represent the sliding window.
theta : float
The number of timesteps in the sliding window that is represented using the
LTI system. In this context, the sliding window represents a dynamic range of
data, of fixed size, that will be used to predict the value at the next time
step. If this value is smaller than the size of the input sequence, only that
number of steps will be represented at the time of prediction, however the
entire sequence will still be processed in order for information to be
projected to and from the hidden layer.
hidden_cell : ``keras.layers.Layer``
Keras Layer implementing the hidden component.
input_to_hidden : bool
If True, connect the input directly to the hidden component (in addition to
the connection from the memory component) (default False).
discretizer : str
The method used to discretize the A and B matrices of the LMU. Current
options are "zoh" (short for Zero Order Hold) and "euler".
"zoh" is more accurate, but training will be slower than "euler" if
``trainable_theta=True``. Note that a larger theta is needed when discretizing
using "euler" (a value that is larger than ``4*order`` is recommended).
kernel_initializer : ``tf.initializers.Initializer``
Initializer for weights from input to memory/hidden component. If ``None``,
no weights will be used, and the input size must match the memory/hidden size.
kernel_regularizer : ``keras.regularizers.Regularizer``
Regularizer for weights from input to memory/hidden component.
use_bias : bool
If True, the memory component includes a bias term.
bias_initializer : ``tf.initializers.Initializer``
Initializer for the memory component bias term. Only used if ``use_bias=True``.
bias_regularizer : ``keras.regularizers.Regularizer``
Regularizer for the memory component bias term. Only used if ``use_bias=True``.
dropout : float
Dropout rate on input connections.
return_sequences : bool, optional
If True, return the full output sequence. Otherwise, return just the last
output in the output sequence.
conv_mode : "fft" or "raw"
The method for performing the inpulse response convolution. "fft" uses FFT
convolution (default). "raw" uses explicit convolution, which may be faster
for particular models on particular hardware.
truncate_ir : float
The portion of the impulse response to truncate when using "raw"
convolution (see ``conv_mode``). This is an approximate upper bound on the error
relative to the exact implementation. Smaller ``theta`` values result in more
truncated elements for a given value of ``truncate_ir``, improving efficiency.
"""
def __init__(
    self,
    memory_d,
    order,
    theta,
    hidden_cell,
    input_to_hidden=False,
    discretizer="zoh",
    kernel_initializer="glorot_uniform",
    kernel_regularizer=None,
    use_bias=False,
    bias_initializer="zeros",
    bias_regularizer=None,
    dropout=0,
    return_sequences=False,
    conv_mode="fft",
    truncate_ir=1e-4,
    **kwargs,
):
    """
    Store the configuration and create the helper delay layer.

    Raises
    ------
    ValueError
        If ``conv_mode`` is not "fft" or "raw" (case-insensitive).
    """
    super().__init__(**kwargs)

    # Normalise case *before* validating so that e.g. "FFT" is accepted;
    # previously the check ran first, which made the later ``.lower()`` a no-op
    # for every value that could pass it. Non-string values are left as-is so
    # they still fail with the ValueError below.
    mode = conv_mode.lower() if isinstance(conv_mode, str) else conv_mode
    if mode not in ("fft", "raw"):
        raise ValueError(f"Unrecognized conv mode '{conv_mode}'")

    self.memory_d = memory_d
    self.order = order
    self.theta = theta
    self.hidden_cell = hidden_cell
    self.input_to_hidden = input_to_hidden
    self.discretizer = discretizer
    self.kernel_initializer = kernel_initializer
    self.kernel_regularizer = kernel_regularizer
    self.use_bias = use_bias
    self.bias_initializer = bias_initializer
    self.bias_regularizer = bias_regularizer
    self.dropout = dropout
    self.return_sequences = return_sequences
    self.conv_mode = mode
    self.truncate_ir = truncate_ir

    # create a standard LMUCell to generate the impulse response during `build`
    # (single memory dimension, no kernel and no hidden cell, so its output is
    # just the raw memory state)
    self.delay_layer = keras.layers.RNN(
        LMUCell(
            memory_d=1,
            order=order,
            theta=theta,
            hidden_cell=None,
            trainable_theta=False,
            input_to_hidden=False,
            hidden_to_memory=False,
            memory_to_memory=False,
            discretizer=discretizer,
            kernel_initializer=None,
            trainable=False,
            dtype=self.dtype,
        ),
        return_sequences=True,
        dtype=self.dtype,
    )

    # created in ``build``
    self.impulse_response = None
    self.kernel = None
    self.bias = None
    self.dropout_layer = None
def build(self, input_shape):  # noqa: C901
    """
    Builds the layer.

    Notes
    -----
    This method should not be called manually; rather, use the implicit layer
    callable behaviour (like ``my_layer(inputs)``), which will apply this method
    with some additional bookkeeping.
    """

    super().build(input_shape)

    enc_d = input_shape[-1]
    seq_len = input_shape[1]

    if seq_len is None:
        theta_factor = 5
        warnings.warn(
            f"Approximating unknown impulse length with {theta_factor}*theta; "
            "setting a fixed sequence length on inputs will remove the need for "
            "approximation"
        )
        # ``theta`` may be a float, but ``tf.eye`` below needs an integer number
        # of rows; round up so the whole approximated window is covered
        # (identical to the previous value for integer theta)
        impulse_len = int(np.ceil(self.theta * theta_factor))
    else:
        impulse_len = seq_len

    # unit impulse of shape (1, impulse_len, 1): a 1 at the first timestep
    impulse = tf.reshape(tf.eye(impulse_len, 1), (1, -1, 1))

    # memory response to the impulse, shape (impulse_len, order)
    self.impulse_response = tf.squeeze(
        self.delay_layer(impulse, training=False), axis=0
    )

    if self.conv_mode == "fft":
        # precomputed spectrum of the impulse response; presumably zero-padded to
        # 2 * seq_len so that the FFT product gives linear rather than circular
        # convolution (``call`` is not visible here -- confirm)
        self.impulse_response_fft = (
            None
            if seq_len is None
            else tf.signal.rfft(
                tf.transpose(self.impulse_response),
                fft_length=[2 * seq_len],
            )
        )
    else:
        if self.truncate_ir is not None:
            assert self.impulse_response.shape == (impulse_len, self.order)

            # fraction of the total |response| remaining from each step onwards
            cumsum = tf.math.cumsum(
                tf.math.abs(self.impulse_response), axis=0, reverse=True
            )
            cumsum = cumsum / cumsum[0]
            # drop the trailing steps whose remaining mass is below the
            # threshold for every order
            to_drop = tf.reduce_all(cumsum < self.truncate_ir, axis=-1)

            if to_drop[-1]:
                cutoff = tf.where(to_drop)[0, -1]
                self.impulse_response = self.impulse_response[:cutoff]

        # reshape to (len, 1, 1, order) and reverse time; presumably used as a
        # convolution filter in ``call`` (not visible here)
        self.impulse_response = tf.reshape(
            self.impulse_response,
            (self.impulse_response.shape[0], 1, 1, self.order),
        )
        self.impulse_response = self.impulse_response[::-1, :, :, :]

    if self.kernel_initializer is not None:
        self.kernel = self.add_weight(
            name="kernel",
            shape=(input_shape[-1], self.memory_d),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
        )
    else:
        self.kernel = None
        if enc_d != self.memory_d:
            raise ValueError(
                f"For LMUCells with no input kernel, the input dimension ({enc_d})"
                f" must equal `memory_d` ({self.memory_d})."
            )

    if self.use_bias:
        self.bias = self.add_weight(
            name="bias",
            shape=(self.memory_d,),
            initializer=self.bias_initializer,
            regularizer=self.bias_regularizer,
        )

    if self.hidden_cell is not None and not self.hidden_cell.built:
        # the hidden component sees the flattened memory (+ raw input)
        hidden_input_d = self.memory_d * self.order
        if self.input_to_hidden:
            hidden_input_d += input_shape[-1]
        with tf.name_scope(self.hidden_cell.name):
            self.hidden_cell.build((input_shape[0], hidden_input_d))

    if self.dropout:
        # noise_shape has size 1 on the time axis, so one dropout mask is
        # shared by every timestep of a sequence
        self.dropout_layer = keras.layers.Dropout(
            self.dropout, noise_shape=(input_shape[0], 1) + tuple(input_shape[2:])
        )
        self.dropout_layer.build(input_shape)
def call(self, inputs, training=False):
"""