Permalink
Switch branches/tags
Nothing to show
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
80 lines (67 sloc) 2.77 KB
import numpy as np
import tensorflow as tf
# ----------------------------------------------------------------------------
def SubPixel1D_v2(I, r):
    """One-dimensional subpixel upsampling layer.

    Based on https://github.com/Tetrachrome/subpixel/blob/master/subpixel.py

    Folds the channel axis of ``I`` into the width axis so that the channels
    of each position become consecutive output samples:
    (batch, width, channels) -> (batch, width * channels, 1).

    Args:
        I: rank-3 tensor laid out as (batch, width, channels).
        r: upsampling factor.  Kept for signature compatibility only: the
            channel count read from the static shape of ``I`` determines the
            output, as it always has (the previous implementation overwrote
            ``r`` with that channel count).

    Returns:
        A tensor of shape (batch, width * channels, 1).
    """
    with tf.name_scope('subpixel'):
        _, width, channels = I.get_shape().as_list()
        # The static batch dimension may be None, so read it at run time.
        batch = tf.shape(I)[0]
        # Splitting along the width axis, squeezing each slice and
        # concatenating the slices visits the elements in row-major order,
        # which is exactly what a reshape does.  Using reshape directly avoids
        # the tf.split / tf.concat argument-order change between TensorFlow
        # releases and the Python-2-only ``func_code`` introspection.
        if width is None or channels is None:
            # Unknown static size: let TensorFlow infer it at run time.
            out_width = -1
        else:
            out_width = width * channels
        return tf.reshape(I, (batch, out_width, 1))
def SubPixel1D(I, r):
    """One-dimensional subpixel upsampling layer.

    Delegates the interleaving to ``tf.batch_to_space_nd``.

    Args:
        I: tensor assumed to be laid out as (batch, width, r).
        r: upsampling factor, used as the block size of the
            batch-to-space rearrangement.

    Returns:
        The rearranged tensor; for a (batch, width, r) input this is
        (batch, r * width, 1).
    """
    # Swapping the first and last axes moves the channels into the "batch"
    # slot that batch_to_space_nd redistributes; the same permutation
    # undoes the swap afterwards.
    swap_outer_axes = [2, 1, 0]
    with tf.name_scope('subpixel'):
        channels_first = tf.transpose(I, swap_outer_axes)  # (r, w, b)
        interleaved = tf.batch_to_space_nd(
            channels_first, [r], [[0, 0]])  # (1, r*w, b)
        upsampled = tf.transpose(interleaved, swap_outer_axes)
    return upsampled
def SubPixel1D_multichan(I, r):
    """One-dimensional subpixel upsampling layer for multi-channel input.

    Delegates the interleaving to ``tf.batch_to_space_nd``.
    Works with multiple channels: (B, L, r*C) -> (B, r*L, C).

    Args:
        I: rank-3 tensor laid out as (batch, width, r * channels).  The size
            of the last axis must be statically known.
        r: upsampling factor; must divide the size of the last axis of ``I``.

    Returns:
        A tensor of shape (batch, r * width, channels).

    Raises:
        AssertionError: if the last axis of ``I`` is not statically known or
            is not a multiple of ``r``.
    """
    with tf.name_scope('subpixel'):
        # as_list() yields plain ints (or None), so the divisibility check
        # does not depend on TensorFlow's Dimension arithmetic.
        _, _, rc = I.get_shape().as_list()
        assert rc is not None and rc % r == 0, (
            'channel dimension %r must be a known multiple of r=%r' % (rc, r))
        X = tf.transpose(I, [2, 1, 0])  # (rc, w, b)
        X = tf.batch_to_space_nd(X, [r], [[0, 0]])  # (c, r*w, b)
        X = tf.transpose(X, [2, 1, 0])  # (b, r*w, c)
        return X
# ----------------------------------------------------------------------------
# demonstration
if __name__ == "__main__":
    # Small demonstration: upsample two toy batches by a factor of 2 and
    # print every batch element before and after the rearrangement.
    # Each print call takes one pre-formatted string so the output is the
    # same under Python 2 and Python 3.
    with tf.Session() as sess:
        # Single-channel case: 2 input channels collapse into 1 output channel.
        x = np.arange(2*4*2).reshape(2, 4, 2)
        X = tf.placeholder("float32", shape=(2, 4, 2), name="X")
        Y = SubPixel1D(X, 2)
        y = sess.run(Y, feed_dict={X: x})
        print('single-channel:')
        print('original, element 0 (2 channels): %s %s'
              % (x[0, :, 0], x[0, :, 1]))
        print('rescaled, element 0: %s' % (y[0, :, 0],))
        print('')
        print('original, element 1 (2 channels): %s %s'
              % (x[1, :, 0], x[1, :, 1]))
        print('rescaled, element 1: %s' % (y[1, :, 0],))
        print('')

        # Multichannel case: 4 input channels collapse into 2 output channels.
        x = np.arange(2*4*4).reshape(2, 4, 4)
        X = tf.placeholder("float32", shape=(2, 4, 4), name="X")
        Y = SubPixel1D(X, 2)
        y = sess.run(Y, feed_dict={X: x})
        print('multichannel:')
        print('original, element 0 (4 channels): %s %s %s %s'
              % (x[0, :, 0], x[0, :, 1], x[0, :, 2], x[0, :, 3]))
        print('rescaled, element 0: %s %s' % (y[0, :, 0], y[0, :, 1]))
        print('')
        print('original, element 1 (4 channels): %s %s %s %s'
              % (x[1, :, 0], x[1, :, 1], x[1, :, 2], x[1, :, 3]))
        print('rescaled, element 1: %s %s' % (y[1, :, 0], y[1, :, 1]))