-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
conv.py
3046 lines (2656 loc) · 136 KB
/
conv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2017 The Sonnet Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Implementation of convolutional Sonnet modules.
Classes defining convolutional operations, inheriting from `snt.Module`, with
easy weight sharing.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import collections.abc
import math
import numbers

# Dependency imports
import numpy as np
import six
from sonnet.python.modules import base
from sonnet.python.modules import util
import tensorflow.compat.v1 as tf
# Strings for TensorFlow convolution padding modes. See the following
# documentation for an explanation of VALID versus SAME:
# https://www.tensorflow.org/api_docs/python/tf/nn/convolution
SAME = "SAME"
VALID = "VALID"
# Extended padding modes, implemented below by explicitly padding the input
# with tf.pad before running a VALID convolution.
FULL = "FULL"
CAUSAL = "CAUSAL"
REVERSE_CAUSAL = "REVERSE_CAUSAL"
# Only these two are supported natively by the underlying convolution op.
CONV_OP_ALLOWED_PADDINGS = {SAME, VALID}
ALLOWED_PADDINGS = {
    SAME, VALID, FULL, CAUSAL, REVERSE_CAUSAL
}
# Fill modes accepted by tf.pad for the explicitly padded region.
CONSTANT_PADDING = "CONSTANT"
REFLECT_PADDING = "REFLECT"
SYMMETRIC_PADDING = "SYMMETRIC"
ALLOWED_PADDING_VALUES = {CONSTANT_PADDING, REFLECT_PADDING, SYMMETRIC_PADDING}
# Data format strings: N = batch, C = channel, W/H/D = spatial dims.
DATA_FORMAT_NCW = "NCW"
DATA_FORMAT_NWC = "NWC"
SUPPORTED_1D_DATA_FORMATS = {DATA_FORMAT_NCW, DATA_FORMAT_NWC}
DATA_FORMAT_NCHW = "NCHW"
DATA_FORMAT_NHWC = "NHWC"
SUPPORTED_2D_DATA_FORMATS = {DATA_FORMAT_NCHW, DATA_FORMAT_NHWC}
DATA_FORMAT_NDHWC = "NDHWC"
DATA_FORMAT_NCDHW = "NCDHW"
SUPPORTED_3D_DATA_FORMATS = {DATA_FORMAT_NDHWC, DATA_FORMAT_NCDHW}
def _default_transpose_size(input_shape, stride, kernel_shape=None,
                            padding=SAME):
  """Returns the maximal valid output shape for a transpose convolution.

  Several output shapes can map back to the same `input_shape` under a
  forward convolution; this picks the one that divides the stride evenly:

    SAME:  output_shape = input_shape * stride
    VALID: output_shape = input_shape * stride + kernel_shape - 1

  Args:
    input_shape: Sequence of sizes of each dimension of the input, excluding
      batch and channel dimensions.
    stride: Sequence or integer of kernel strides, excluding batch and channel
      dimension strides.
    kernel_shape: Sequence or integer of kernel sizes (required for VALID).
    padding: Padding algorithm, either `snt.SAME` or `snt.VALID`.

  Returns:
    A tuple of output sizes compatible with the strides, kernel shape and
    padding algorithm given.

  Raises:
    TypeError: If `input_shape` is empty/None.
  """
  if not input_shape:
    raise TypeError("input_shape is None; if using Sonnet, are you sure you "
                    "have connected the module to inputs?")
  num_dims = len(input_shape)
  strides = _fill_and_verify_parameter_shape(stride, num_dims, "stride")
  padding = _verify_conv_op_supported_padding(padding)
  output_shape = [size * s for size, s in zip(input_shape, strides)]
  if padding == VALID:
    kernel = _fill_and_verify_parameter_shape(kernel_shape, num_dims, "kernel")
    output_shape = [size + k - 1 for size, k in zip(output_shape, kernel)]
  return tuple(output_shape)
def _fill_shape(x, n):
"""Converts a dimension to a tuple of dimensions of a given size.
This is used to allow shorthand notation for various configuration parameters.
A user can provide either, for example, `2` or `[2, 2]` as a kernel shape, and
this function returns `(2, 2)` in both cases. Passing `[1, 2]` will return
`(1, 2)`.
Args:
x: An integer, tf.Dimension, or an iterable of them.
n: An integer, the size of the desired output list
Returns:
If `x` is an integer, a tuple of size `n` containing `n` copies of `x`.
If `x` is an iterable of integers or tf.Dimension of size `n`, it returns
`tuple(x)`.
Raises:
TypeError: If n is not a positive integer;
or if x is neither integer nor an iterable of size n.
"""
if not isinstance(n, numbers.Integral) or n < 1:
raise TypeError("n must be a positive integer")
if (isinstance(x, numbers.Integral) or isinstance(x, tf.Dimension)) and x > 0:
return (x,) * n
try:
if len(x) == n and all(v > 0 for v in x):
return tuple(x)
except TypeError:
pass
raise TypeError("x is {}, must be either a positive integer "
"or an iterable of positive integers of size {}"
.format(x, n))
def _fill_and_verify_parameter_shape(x, n, parameter_label):
  """Expands `x` into an `n`-D kernel shape, rewrapping type errors.

  Any `TypeError` from `_fill_shape` is converted into a Sonnet
  `IncompatibleShapeError` that names the offending parameter.
  """
  try:
    return _fill_shape(x, n)
  except TypeError as e:
    raise base.IncompatibleShapeError(
        "Invalid {} shape: {}".format(parameter_label, e))
def _verify_conv_op_supported_padding(padding):
  """Checks that `padding` is natively supported by the conv op.

  Args:
    padding: One of CONV_OP_ALLOWED_PADDINGS (SAME or VALID).

  Returns:
    The validated padding, unchanged.

  Raises:
    ValueError: If padding is not one of CONV_OP_ALLOWED_PADDINGS.
  """
  if padding in CONV_OP_ALLOWED_PADDINGS:
    return padding
  raise ValueError(
      "Padding must be member of '{}', not {}".format(
          CONV_OP_ALLOWED_PADDINGS, padding))
def _verify_padding_value(padding_value):
  """Checks that `padding_value` is a supported tf.pad fill mode.

  Args:
    padding_value: One of ALLOWED_PADDING_VALUES.

  Returns:
    The validated padding_value, unchanged.

  Raises:
    ValueError: If padding_value is not one of ALLOWED_PADDING_VALUES.
  """
  if padding_value in ALLOWED_PADDING_VALUES:
    return padding_value
  raise ValueError(
      "Padding must be member of '{}', not {}".format(
          ALLOWED_PADDING_VALUES, padding_value))
def _fill_and_verify_padding(padding, n):
  """Validates `padding` and broadcasts it to a tuple of size `n`.

  Args:
    padding: One of ALLOWED_PADDINGS, or an iterable of them.
    n: An integer, the size of the desired output list.

  Returns:
    An `n`-tuple: `n` copies of `padding` if it is a single allowed string,
    otherwise `tuple(padding)` when it is an iterable of allowed strings.

  Raises:
    TypeError: If n is not a positive integer; if padding is neither one of
      ALLOWED_PADDINGS nor an iterable of ALLOWED_PADDINGS of size n.
  """
  if not isinstance(n, numbers.Integral) or n < 1:
    raise TypeError("n must be a positive integer")
  if isinstance(padding, six.string_types) and padding in ALLOWED_PADDINGS:
    return (padding,) * n
  try:
    is_valid = (len(padding) == n and
                all(p in ALLOWED_PADDINGS for p in padding))
  except TypeError:
    is_valid = False
  if is_valid:
    return tuple(padding)
  raise TypeError("padding is {}, must be member of '{}' or an iterable of "
                  "these of size {}".format(padding, ALLOWED_PADDINGS, n))
def _padding_to_conv_op_padding(padding, padding_value):
  """Picks SAME or VALID for the underlying convolution op.

  Args:
    padding: A tuple of members of ALLOWED_PADDINGS, e.g. as returned from
      `_fill_and_verify_padding`.
    padding_value: A string of ALLOWED_PADDING_VALUES.

  Returns:
    One of CONV_OP_ALLOWED_PADDINGS: SAME when every axis wants SAME with
    zero-fill (the conv op can do that directly); otherwise VALID, since all
    other padding types are emulated by pre-padding the input before a VALID
    conv (SAME would additionally require cropping outputs in some cases).

  Raises:
    ValueError: If padding is not a tuple.
  """
  if not isinstance(padding, tuple):
    raise ValueError("padding should be a tuple.")
  uniform_zero_same = (padding_value == CONSTANT_PADDING and
                       all(p == SAME for p in padding))
  return SAME if uniform_zero_same else VALID
def _fill_and_one_pad_stride(stride, n, data_format=DATA_FORMAT_NHWC):
  """Expands `stride` to size `n` and 1-pads it for batch/channel dims.

  Args:
    stride: A positive integer or an iterable of them of length <= n (expanded
      to `n` spatial strides), or an iterable of length n + 2 (already a
      full-rank stride spec; returned unchanged for backwards compatibility).
    n: Number of spatial dimensions.
    data_format: Determines where the 1s for the batch and channel dims go:
      right after the batch dim for "NC..." formats, at both ends otherwise.

  Returns:
    A stride tuple of length n + 2 (or `stride` as given, if already full
    rank).

  Raises:
    ValueError: If `data_format` does not start with N with the channel dim
      immediately after it or at the end.
    base.IncompatibleShapeError: If `stride` cannot be interpreted as above.
  """
  # Bug fix: `collections.Iterable` was removed in Python 3.10; the ABC lives
  # in `collections.abc`.
  if isinstance(stride, numbers.Integral) or (
      isinstance(stride, collections.abc.Iterable) and len(stride) <= n):
    if data_format.startswith("NC"):
      return (1, 1,) + _fill_shape(stride, n)
    elif data_format.startswith("N") and data_format.endswith("C"):
      return (1,) + _fill_shape(stride, n) + (1,)
    else:
      raise ValueError(
          "Invalid data_format {:s}. Must start with N and have a channel dim "
          "either follow the N dim or come at the end".format(data_format))
  elif isinstance(stride, collections.abc.Iterable) and len(stride) == n + 2:
    return stride
  else:
    raise base.IncompatibleShapeError(
        "stride is {} ({}), must be either a positive integer or an iterable of"
        " positive integers of size {}".format(stride, type(stride), n))
def _verify_inputs(inputs, channel_index, data_format):
  """Checks rank, dtype and channel dim of `inputs` against `data_format`.

  Args:
    inputs: An input tensor provided by the user.
    channel_index: The index of the channel dimension.
    data_format: The format of the data in `inputs`.

  Raises:
    base.IncompatibleShapeError: If the rank of `inputs` doesn't match the
      length of `data_format`.
    base.UnderspecifiedError: If the channel dimension of `inputs` isn't
      statically known.
    TypeError: If the dtype of `inputs` is not compatible with one of
      `tf.float16`, `tf.bfloat16`, `tf.float32` or `tf.float64`.
  """
  input_shape = tuple(inputs.get_shape().as_list())
  # One character of data_format per tensor dimension.
  if len(input_shape) != len(data_format):
    raise base.IncompatibleShapeError((
        "Input Tensor must have rank {} corresponding to "
        "data_format {}, but instead was {} of rank {}.").format(
            len(data_format), data_format, input_shape, len(input_shape)))
  supported_dtypes = (tf.float16, tf.bfloat16, tf.float32, tf.float64)
  if not any(t.is_compatible_with(inputs.dtype) for t in supported_dtypes):
    raise TypeError(
        "Input must have dtype tf.float16, tf.bfloat16, tf.float32 or "
        "tf.float64, but dtype was {}".format(inputs.dtype))
  if input_shape[channel_index] is None:
    raise base.UnderspecifiedError(
        "Number of input channels must be known at module build time")
def create_weight_initializer(fan_in_shape, dtype=tf.float32):
  """Returns the default (truncated normal) initializer for conv weights."""
  # Scale stddev by 1/sqrt(fan_in) so pre-activations stay well-conditioned.
  fan_in = np.prod(fan_in_shape)
  return tf.truncated_normal_initializer(stddev=1 / math.sqrt(fan_in),
                                         dtype=dtype)
def create_bias_initializer(unused_bias_shape, dtype=tf.float32):
  """Returns the default (zeros) initializer for conv biases.

  The shape argument is ignored; only the dtype matters for zeros.
  """
  return tf.zeros_initializer(dtype=dtype)
def _find_channel_index(data_format):
"""Returns the index of the channel dimension.
Args:
data_format: A string of characters corresponding to Tensor dimensionality.
Returns:
channel_index: An integer indicating the channel dimension.
Raises:
ValueError: If no channel dimension was found.
"""
for i, c in enumerate(data_format):
if c == "C":
return i
raise ValueError("data_format requires a channel dimension. Got: {}"
.format(data_format))
def _apply_bias(inputs, outputs, channel_index, data_format, output_channels,
                initializers, partitioners, regularizers):
  """Creates a bias variable and adds it to `outputs`.

  Args:
    inputs: A Tensor of shape `data_format` (used only for its dtype).
    outputs: A Tensor of shape `data_format` to which the bias is added.
    channel_index: The index of the channel dimension in `inputs`.
    data_format: Format of `inputs`.
    output_channels: Channel dimensionality for `outputs`.
    initializers: Optional dict containing ops to initialize the biases
      (with key 'b'); a zeros initializer is inserted if absent.
    partitioners: Optional dict containing partitioners to partition the
      biases (with key 'b').
    regularizers: Optional dict containing regularizers for the biases
      (with key 'b').

  Returns:
    b: The constructed bias variable.
    outputs: The `outputs` argument with the bias applied.
  """
  bias_shape = (output_channels,)
  if "b" not in initializers:
    initializers["b"] = create_bias_initializer(bias_shape,
                                                dtype=inputs.dtype)
  b = tf.get_variable("b",
                      shape=bias_shape,
                      dtype=inputs.dtype,
                      initializer=initializers["b"],
                      partitioner=partitioners.get("b", None),
                      regularizer=regularizers.get("b", None))
  if data_format in (DATA_FORMAT_NHWC, DATA_FORMAT_NCHW):
    # tf.nn.bias_add natively supports only these two formats.
    outputs = tf.nn.bias_add(outputs, b, data_format=data_format)
  else:
    # Otherwise broadcast the bias manually along the channel axis.
    broadcast_shape = [1] * len(data_format)
    broadcast_shape[channel_index] = output_channels
    outputs += tf.reshape(b, broadcast_shape)
  return b, outputs
class _ConvND(base.AbstractModule):
"""N-dimensional convolution and dilated convolution module, including bias.
This acts as a light wrapper around the TensorFlow ops `tf.nn.convolution`
abstracting away variable creation and sharing.
"""
def __init__(self, output_channels, kernel_shape, stride=1, rate=1,
padding=SAME, use_bias=True, initializers=None,
partitioners=None, regularizers=None,
mask=None, data_format=DATA_FORMAT_NHWC,
padding_value=CONSTANT_PADDING, custom_getter=None,
name="conv_nd"):
"""Constructs a _ConvND module.
Args:
output_channels: Number of output channels. `output_channels` can be
either a number or a callable. In the latter case, since the function
invocation is deferred to graph construction time, the user must only
ensure that output_channels can be called, returning an integer,
when `build` is called.
kernel_shape: Sequence of kernel sizes (up to size N), or an integer.
`kernel_shape` will be expanded to define a kernel size in all
dimensions.
stride: Sequence of strides (up to size N), or an integer.
`stride` will be expanded to define stride in all dimensions.
rate: Sequence of dilation rates (of size N), or integer that is used to
define dilation rate in all dimensions. 1 corresponds to standard ND
convolution, `rate > 1` corresponds to dilated convolution. Cannot be
> 1 if any of `stride` is also > 1.
padding: Padding algorithm. Either `snt.SAME`, `snt.VALID`, `snt.FULL`,
`snt.CAUSAL`, `snt.REVERSE_CAUSAL`, or a sequence of these paddings
(up to size N).
* snt.SAME and snt.VALID are explained in the Tensorflow docs at
https://www.tensorflow.org/api_docs/python/tf/nn/convolution.
* snt.FULL pre- and post-pads with the maximum padding which does not
result in a convolution over just padded elements.
* snt.CAUSAL pre-pads to ensure that each output value only depends on
input values at the same or preceding indices ("no dependence on the
future").
* snt.REVERSE_CAUSAL post-pads to ensure that each output value only
depends on input values at the same or *greater* indices ("no
dependence on the past").
If you use the same padding for all dimensions, and it is one of SAME
or VALID, then this is supported directly by the underlying
convolution op. In all other cases, the input data will be padded
using tf.pad before calling the convolution op.
use_bias: Whether to include bias parameters. Default `True`.
initializers: Optional dict containing ops to initialize the filters (with
key 'w') or biases (with key 'b'). The default initializer for the
weights is a truncated normal initializer, which is commonly used
when the inputs are zero centered (see
https://arxiv.org/pdf/1502.03167v3.pdf). The default initializer for
the bias is a zero initializer.
partitioners: Optional dict containing partitioners to partition
weights (with key 'w') or biases (with key 'b'). As a default, no
partitioners are used.
regularizers: Optional dict containing regularizers for the filters
(with key 'w') and the biases (with key 'b'). As a default, no
regularizers are used. A regularizer should be a function that takes
a single `Tensor` as an input and returns a scalar `Tensor` output,
e.g. the L1 and L2 regularizers in `tf.contrib.layers`.
mask: A convertible to a ND tensor which is multiplied
component-wise with the weights (Optional).
data_format: The data format of the input.
padding_value: The type of padding to use, either "CONSTANT", "SYMMETRIC"
or "REFLECT", as supported by the underlying tf.pad
(https://www.tensorflow.org/api_docs/python/tf/pad). Can only be set
globally for all dimensions. Defaults to "CONSTANT" which will pad
with zeros, potentially directly via the underlying convolution op if
the padding is SAME or VALID for all dimensions.
custom_getter: Callable or dictionary of callables to use as
custom getters inside the module. If a dictionary, the keys
correspond to regexes to match variable names. See the
`tf.get_variable` documentation for information about the
custom_getter API.
name: Name of the module.
Raises:
base.IncompatibleShapeError: If the given kernel shape is not an integer;
or if the given kernel shape is not a sequence of two integers.
base.IncompatibleShapeError: If the given stride is not an integer; or if
the given stride is not a sequence of two integers.
base.IncompatibleShapeError: If the given rate is not an integer; or if
the given rate is not a sequence of two integers.
base.IncompatibleShapeError: If a mask is a TensorFlow Tensor with
a not fully defined shape.
base.NotSupportedError: If rate in any dimension and the stride in any
dimension are simultaneously > 1.
ValueError: If the given padding is not `snt.VALID`, `snt.SAME`,
`snt.FULL`, `snt.CAUSAL`, `snt.REVERSE_CAUSAL` or a sequence of these.
KeyError: If `initializers`, `partitioners` or `regularizers` contain any
keys other than 'w' or 'b'.
TypeError: If any of the given initializers, partitioners or regularizers
are not callable.
TypeError: If mask is given and it is not convertible to a Tensor.
ValueError: If the passed-in data_format doesn't have a channel dimension.
"""
super(_ConvND, self).__init__(custom_getter=custom_getter, name=name)
self._n = len(data_format) - 2
self._input_channels = None
self._output_channels = output_channels
self._kernel_shape = _fill_and_verify_parameter_shape(kernel_shape, self._n,
"kernel")
self._data_format = data_format
# The following is for backwards-compatibility from when we used to accept
# N-strides of the form [1, ..., 1].
if (isinstance(stride, collections.Sequence) and
len(stride) == len(data_format)):
self._stride = tuple(stride)[1:-1]
else:
self._stride = _fill_and_verify_parameter_shape(stride, self._n, "stride")
self._rate = _fill_and_verify_parameter_shape(rate, self._n, "rate")
if any(x > 1 for x in self._stride) and any(x > 1 for x in self._rate):
raise base.NotSupportedError("Cannot have stride > 1 with rate > 1")
self._padding = _fill_and_verify_padding(padding, self._n)
self._padding_value = _verify_padding_value(padding_value)
self._conv_op_padding = _padding_to_conv_op_padding(
self._padding, self._padding_value)
self._use_bias = use_bias
self.possible_keys = self.get_possible_initializer_keys(use_bias=use_bias)
self._initializers = util.check_initializers(
initializers, self.possible_keys)
self._partitioners = util.check_partitioners(
partitioners, self.possible_keys)
self._regularizers = util.check_regularizers(
regularizers, self.possible_keys)
if mask is not None:
if isinstance(mask, (tf.Tensor, list, tuple, np.ndarray)):
self._mask = tf.convert_to_tensor(mask)
if not (tf.float16.is_compatible_with(self._mask.dtype) or
tf.bfloat16.is_compatible_with(self._mask.dtype) or
tf.float32.is_compatible_with(self._mask.dtype) or
tf.float64.is_compatible_with(self._mask.dtype)):
raise TypeError(
"Mask needs to have dtype float16, bfloat16, float32 or float64")
if not self._mask.get_shape().is_fully_defined():
base.IncompatibleShapeError(
"Mask needs to have a statically defined shape")
else:
raise TypeError("Invalid type for mask: {}".format(type(mask)))
else:
self._mask = None
self._channel_index = _find_channel_index(self._data_format)
@classmethod
def get_possible_initializer_keys(cls, use_bias=True):
return {"w", "b"} if use_bias else {"w"}
def _build(self, inputs):
"""Connects the _ConvND module into the graph, with input Tensor `inputs`.
If this is not the first time the module has been connected to the graph,
the input Tensor provided here must have the same number of channels, in
order for the existing variables to be the correct size for the
multiplication; the batch size and input spatial dimensions may differ for
each connection.
Args:
inputs: A ND Tensor of the same rank as `data_format`, and either of types
`tf.float16`, `tf.bfloat16`, `tf.float32` or `tf.float64`.
Returns:
A ND Tensor of shape [batch_size, output_dim_1, output_dim_2, ...,
output_channels].
Raises:
ValueError: If connecting the module into the graph any time after the
first time and the inferred size of the input does not match previous
invocations.
base.IncompatibleShapeError: If the input tensor has the wrong number
of dimensions.
base.UnderspecifiedError: If the channel dimension of `inputs` isn't
defined.
base.IncompatibleShapeError: If a mask is present and its shape is
incompatible with the shape of the weights.
TypeError: If input Tensor dtype is not compatible with either
`tf.float16`, `tf.bfloat16`, `tf.float32` or `tf.float64`.
"""
_verify_inputs(inputs, self._channel_index, self._data_format)
self._input_shape = tuple(inputs.get_shape().as_list())
self._input_channels = self._input_shape[self._channel_index]
self._w = self._construct_w(inputs)
if self._mask is not None:
w = self._apply_mask()
else:
w = self._w
inputs = self._pad_input(inputs)
outputs = self._apply_conv(inputs, w)
if self._use_bias:
self._b, outputs = _apply_bias(
inputs, outputs, self._channel_index, self._data_format,
self.output_channels, self._initializers, self._partitioners,
self._regularizers)
return outputs
def _pad_input(self, inputs):
"""Pad input in case the desired padding type requires it.
VALID and SAME padding types are directly supported by tensorflow
convolution ops, so don't require us to pad input ourselves, at least
in cases where the same method is used for all dimensions.
Other padding types (FULL, CAUSAL, REVERSE_CAUSAL) aren't directly supported
by conv ops but can be implemented by using VALID and padding the input
appropriately ourselves.
If different padding types are used for different dimensions, we use VALID
but pad the input ourselves along any dimensions that require other padding
types.
Args:
inputs: A Tensor of shape `data_format` and of type `tf.float16`,
`tf.bfloat16`, `tf.float32` or `tf.float64`.
Returns:
inputs: The `inputs` argument that has had any required padding added.
"""
if all(p == self._conv_op_padding for p in self._padding):
# All axes require the same padding type that we're going to use for the
# underlying convolution op and we use the padding mode that is used by
# the convolution op, so nothing needs to be done:
return inputs
# In all other cases we use VALID as the underlying padding type, and for
# the axes which require something other than VALID, we pad inputs ourselves
# before the convolution.
assert self._conv_op_padding == VALID
def pad_amount(kernel_size, rate, padding):
"""Pre- and post-padding required for a particular axis before conv op."""
# The effective kernel size includes any holes/gaps introduced by the
# dilation rate. It's equal to kernel_size when rate == 1.
effective_kernel_size = int((kernel_size - 1) * rate + 1)
if padding == FULL:
return [effective_kernel_size - 1, effective_kernel_size - 1]
if padding == CAUSAL:
return [effective_kernel_size - 1, 0]
if padding == REVERSE_CAUSAL:
return [0, effective_kernel_size - 1]
if padding == SAME:
return [(effective_kernel_size - 1) // 2, effective_kernel_size // 2]
# padding == VALID
return [0, 0]
paddings = map(pad_amount, self._kernel_shape, self._rate, self._padding)
if self._data_format.startswith("NC"): # N, C, ...
paddings = [[0, 0], [0, 0]] + list(paddings)
else: # N, ..., C
paddings = [[0, 0]] + list(paddings) + [[0, 0]]
return tf.pad(inputs, paddings, mode=self._padding_value)
def _apply_conv(self, inputs, w):
"""Apply a convolution operation on `inputs` using variable `w`.
Args:
inputs: A Tensor of shape `data_format` and of type `tf.float16`,
`tf.bfloat16`, `tf.float32` or `tf.float64`.
w: A weight matrix of the same type as `inputs`.
Returns:
outputs: The result of the convolution operation on `inputs`.
"""
outputs = tf.nn.convolution(inputs, w, strides=self._stride,
padding=self._conv_op_padding,
dilation_rate=self._rate,
data_format=self._data_format)
return outputs
def _construct_w(self, inputs):
"""Construct the convolution weight matrix.
Figures out the shape of the weight matrix, initialize it, and return it.
Args:
inputs: A Tensor of shape `data_format` and of type `tf.float16`,
`tf.bfloat16`, `tf.float32` or `tf.float64`.
Returns:
w: A weight matrix of the same type as `inputs`.
"""
weight_shape = self._kernel_shape + (self._input_channels,
self.output_channels)
if "w" not in self._initializers:
self._initializers["w"] = create_weight_initializer(weight_shape[:-1],
dtype=inputs.dtype)
w = tf.get_variable("w",
shape=weight_shape,
dtype=inputs.dtype,
initializer=self._initializers["w"],
partitioner=self._partitioners.get("w", None),
regularizer=self._regularizers.get("w", None))
return w
def _apply_mask(self):
"""Applies the passed-in mask to the convolution matrix.
Returns:
w: A copy of the convolution matrix that has had the mask applied.
Raises:
base.IncompatibleShapeError: If the mask shape has more dimensions than
the weight matrix.
base.IncompatibleShapeError: If the mask and the weight matrix don't
match on shape.
"""
w = self._w
w_shape = w.get_shape()
mask_shape = self._mask.get_shape()
if mask_shape.ndims > w_shape.ndims:
raise base.IncompatibleShapeError(
"Invalid mask shape: {}. Max shape: {}".format(
mask_shape.ndims, len(self._data_format)
)
)
if mask_shape != w_shape[:mask_shape.ndims]:
raise base.IncompatibleShapeError(
"Invalid mask shape: {}. Weight shape: {}".format(
mask_shape, w_shape
)
)
# TF broadcasting is a bit fragile.
# Expand the shape of self._mask by one dim at a time to the right
# until the rank matches `weight_shape`.
while self._mask.get_shape().ndims < w_shape.ndims:
self._mask = tf.expand_dims(self._mask, -1)
# tf.Variable & tf.ResourceVariable don't support *=.
w = w * self._mask # pylint: disable=g-no-augmented-assignment
return w
@property
def output_channels(self):
"""Returns the number of output channels."""
if callable(self._output_channels):
self._output_channels = self._output_channels()
# Channel must be integer.
self._output_channels = int(self._output_channels)
return self._output_channels
@property
def kernel_shape(self):
"""Returns the kernel shape."""
return self._kernel_shape
@property
def stride(self):
"""Returns the stride."""
# Backwards compatibility with old stride format.
return _fill_and_one_pad_stride(self._stride, self._n, self._data_format)
@property
def rate(self):
"""Returns the dilation rate."""
return self._rate
@property
def padding(self):
"""Returns the padding algorithm used, if this is the same for all dims.
Use `.paddings` if you want a tuple with the padding algorithm used for each
dimension.
Returns:
The padding algorithm used, if this is the same for all dimensions.
Raises:
ValueError: If different padding algorithms are used for different
dimensions.
"""
# This is for backwards compatibility -- previously only a single
# padding setting was supported across all dimensions.
if all(p == self._padding[0] for p in self._padding):
return self._padding[0]
else:
raise ValueError("This layer uses different paddings for different "
"dimensions. Use .paddings if you want a tuple of "
"per-dimension padding settings.")
@property
def paddings(self):
"""Returns a tuple with the padding algorithm used for each dimension."""
return self._padding
@property
def conv_op_padding(self):
"""Returns the padding algorithm used for the underlying convolution op."""
return self._conv_op_padding
@property
def w(self):
"""Returns the Variable containing the weight matrix."""
self._ensure_is_connected()
return self._w
@property
def b(self):
"""Returns the Variable containing the bias.
Returns:
Variable object containing the bias, from the most recent __call__.
Raises:
base.NotConnectedError: If the module has not been connected to the graph
yet, meaning the variables do not exist.
AttributeError: If the module does not use bias.
"""
self._ensure_is_connected()
if not self._use_bias:
raise AttributeError(
"No bias Variable in Conv2D Module when `use_bias=False`.")
return self._b
@property
def has_bias(self):
"""Returns `True` if bias Variable is present in the module."""
return self._use_bias
@property
def initializers(self):
"""Returns the initializers dictionary."""
return self._initializers
@property
def partitioners(self):
"""Returns the partitioners dictionary."""
return self._partitioners
@property
def regularizers(self):
"""Returns the regularizers dictionary."""
return self._regularizers
@property
def mask(self):
"""Returns the mask."""
return self._mask
@property
def data_format(self):
"""Returns the data format."""
return self._data_format
# Implements Transposable interface.
@property
def input_shape(self):
"""Returns the input shape."""
self._ensure_is_connected()
return self._input_shape
@property
def input_channels(self):
"""Returns the number of input channels."""
if self._input_channels is None:
self._ensure_is_connected()
return self._input_channels
def clone(self, name=None):
"""Returns a cloned `_ConvND` module.
Args:
name: Optional string assigning name of cloned module. The default name
is constructed by appending "_clone" to `self.module_name`.
Returns:
A copy of the current class.
"""
if name is None:
name = self.module_name + "_clone"
return type(self)(output_channels=self.output_channels,
kernel_shape=self._kernel_shape,
stride=self._stride,
rate=self._rate,
padding=self._padding,
use_bias=self._use_bias,
initializers=self._initializers,
partitioners=self._partitioners,
regularizers=self._regularizers,
mask=self._mask,
data_format=self._data_format,
custom_getter=self._custom_getter,
name=name)
class _ConvNDTranspose(base.AbstractModule):
"""Spatial transposed / reverse / up ND convolution module, including bias.
This acts as a light wrapper around the TensorFlow `conv_nd_transpose` ops,
abstracting away variable creation and sharing.
"""
def __init__(self, output_channels, output_shape=None, kernel_shape=None,
stride=1, padding=SAME, use_bias=True, initializers=None,
partitioners=None, regularizers=None,
data_format=DATA_FORMAT_NHWC, custom_getter=None,
name="conv_nd_transpose"):
"""Constructs a `ConvNDTranspose module`. Support for N = (1, 2, 3).
See the following documentation for an explanation of VALID versus SAME
padding modes:
https://www.tensorflow.org/api_docs/python/tf/nn/convolution
Args:
output_channels: Number of output channels.
Can be either a number or a callable. In the latter case, since the
function invocation is deferred to graph construction time, the user
must only ensure `output_channels` can be called, returning an
integer, when build is called.
output_shape: Output shape of transpose convolution.
Can be either an iterable of integers or `Dimension`s, a
`TensorShape`, or a callable. In the latter case, since the function
invocation is deferred to graph construction time, the user must only
ensure that `output_shape` can be called, returning an iterable of
output shapes when `build` is called. Note that `output_shape` defines
the size of output signal domain, as opposed to the shape of the
output `Tensor`. If a None value is given, a default shape is
automatically calculated (see docstring of
`_default_transpose_size` function for more details).
kernel_shape: Sequence of kernel sizes (of size N), or integer that is
used to define kernel size in all dimensions.
stride: Sequence of kernel strides (of size N), or integer that is used
to define stride in all dimensions.
padding: Padding algorithm, either `snt.SAME` or `snt.VALID`.
use_bias: Whether to include bias parameters. Default `True`.
initializers: Optional dict containing ops to initialize the filters (with
key 'w') or biases (with key 'b').
partitioners: Optional dict containing partitioners to partition
weights (with key 'w') or biases (with key 'b'). As a default, no
partitioners are used.
regularizers: Optional dict containing regularizers for the filters
(with key 'w') and the biases (with key 'b'). As a default, no
regularizers are used. A regularizer should be a function that takes
a single `Tensor` as an input and returns a scalar `Tensor` output,
e.g. the L1 and L2 regularizers in `tf.contrib.layers`.
data_format: The data format of the input.
custom_getter: Callable or dictionary of callables to use as
custom getters inside the module. If a dictionary, the keys
correspond to regexes to match variable names. See the
`tf.get_variable` documentation for information about the
custom_getter API.
name: Name of the module.
Raises:
base.IncompatibleShapeError: If the given kernel shape is neither an
integer nor a sequence of two integers.
base.IncompatibleShapeError: If the given stride is neither an integer nor
a sequence of two or four integers.
ValueError: If the given padding is not `snt.VALID` or `snt.SAME`.
ValueError: If the given kernel_shape is `None`.
KeyError: If `initializers`, `partitioners` or `regularizers` contain any
keys other than 'w' or 'b'.
TypeError: If any of the given initializers, partitioners or regularizers
are not callable.
ValueError: If the passed-in data_format doesn't have a channel dimension.
"""
super(_ConvNDTranspose, self).__init__(custom_getter=custom_getter,
name=name)
self._data_format = data_format
self._n = len(self._data_format) - 2
if self._n > 3:
raise base.NotSupportedError(
"We only support (1, 2, 3) convolution transpose operations. "
"Received data format of: {}".format(self._data_format))
self._output_channels = output_channels
if output_shape is None:
self._output_shape = None
self._use_default_output_shape = True
else:
self._use_default_output_shape = False
if callable(output_shape):
self._output_shape = output_shape
else:
self._output_shape = _fill_and_verify_parameter_shape(output_shape,
self._n,
"output_shape")
if kernel_shape is None:
raise ValueError("`kernel_shape` cannot be None.")
self._kernel_shape = _fill_and_verify_parameter_shape(kernel_shape, self._n,
"kernel")
if (isinstance(stride, collections.Sequence) and