-
Notifications
You must be signed in to change notification settings - Fork 46
/
Subpixel.py
86 lines (79 loc) · 3.5 KB
/
Subpixel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from keras import backend as K
from keras.layers import Conv2D
"""
Subpixel Layer as a child class of Conv2D. This layer accepts all normal
arguments, with the exception of dilation_rate(). The argument r indicates
the upsampling factor, which is applied to the normal output of Conv2D.
The output of this layer will have the same number of channels as the
indicated filter field, and thus works for grayscale, color, or as a
hidden layer.
Arguments:
*see Keras Docs for Conv2D args, noting that dilation_rate() is removed*
r: upscaling factor, which is applied to the output of normal Conv2D
A test is included, which performs super-resolution on the Cifar10 dataset.
Since these images are small, only a scale factor of 2 is used. Test images
are saved in the directory 'test_output/'. This test runs for 5 epochs,
which can be altered in line 132. You can run this test by using the
following commands:
mkdir test_output
python keras_subpixel.py
"""
class Subpixel(Conv2D):
    """Conv2D followed by a sub-pixel (phase-shift) upsampling step.

    The underlying convolution is built with ``r * r * filters`` output
    channels; ``call`` then rearranges those channels into an output whose
    two spatial axes are ``r`` times larger and whose last axis has
    ``filters`` channels.

    The rearrangement treats the convolution output as a 4-D tensor with the
    channel axis last, and needs the two spatial sizes to be statically known
    (they are used as Python loop bounds).

    Arguments:
        filters: number of channels in the *upsampled* output.
        kernel_size: forwarded to Conv2D unchanged.
        r: upscaling factor applied to both spatial axes.
        Remaining arguments are forwarded to Conv2D unchanged.
    """

    def __init__(self,
                 filters,
                 kernel_size,
                 r,
                 padding='valid',
                 data_format=None,
                 strides=(1, 1),
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        # The convolution must emit r*r channels per requested filter so that
        # _phase_shift can fold them into an r-times larger spatial grid.
        super(Subpixel, self).__init__(
            filters=r * r * filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs)
        self.r = r

    def _phase_shift(self, I):
        """Rearrange ``[bsize, a, b, c]`` into ``[bsize, a*r, b*r, c/(r*r)]``."""
        r = self.r
        _, a, b, c = I.get_shape().as_list()
        # The static batch size may be undefined (None), so read it from the
        # dynamic shape instead.
        bsize = K.shape(I)[0]
        X = K.reshape(I, [bsize, a, b, c // (r * r), r, r])  # bsize, a, b, c/(r*r), r, r
        X = K.permute_dimensions(X, (0, 1, 2, 5, 4, 3))      # bsize, a, b, r, r, c/(r*r)
        # Keras backend does not support tf.split, so in future versions this
        # could be nicer.
        X = [X[:, i, :, :, :, :] for i in range(a)]  # a x [bsize, b, r, r, c/(r*r)]
        X = K.concatenate(X, 2)                      # bsize, b, a*r, r, c/(r*r)
        X = [X[:, i, :, :, :] for i in range(b)]     # b x [bsize, a*r, r, c/(r*r)]
        X = K.concatenate(X, 2)                      # bsize, a*r, b*r, c/(r*r)
        return X

    def call(self, inputs):
        """Run the convolution, then upsample its output by phase shift."""
        return self._phase_shift(super(Subpixel, self).call(inputs))

    def compute_output_shape(self, input_shape):
        """Return the Conv2D output shape with both spatial axes scaled by r.

        The last axis is divided by ``r * r`` to match ``_phase_shift``.
        """
        unshifted = super(Subpixel, self).compute_output_shape(input_shape)
        r = self.r
        return (unshifted[0],
                r * unshifted[1],
                r * unshifted[2],
                unshifted[3] // (r * r))

    def get_config(self):
        """Return a config in terms of this layer's own constructor arguments.

        ``filters`` is reported as the user-facing value (the internal
        convolution holds ``r * r`` times as many), and ``r`` is included.
        """
        # Deliberately start the lookup *after* Conv2D in the MRO, as the
        # original code did. NOTE(review): presumably that parent config
        # still carries 'rank' and 'dilation_rate' - confirm against the
        # installed Keras version.
        config = super(Conv2D, self).get_config()
        # Tolerate parents that do not report these keys.
        config.pop('rank', None)
        config.pop('dilation_rate', None)
        # Parenthesised on purpose: `filters / r * r` evaluates left to right
        # and would leave the inflated internal filter count in the config,
        # so a layer rebuilt from it would be inflated by r*r a second time.
        config['filters'] = config['filters'] // (self.r * self.r)
        config['r'] = self.r
        return config