Support for different radii for the deconvolution input tensor
Also append radius to the checkpoint directory name.
igv committed Sep 29, 2017
1 parent 17b5443 commit 81ab9ae
Showing 4 changed files with 26 additions and 21 deletions.
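The whole change hangs off one new number, the input radius r: the feature-extraction kernel becomes (2r+1) x (2r+1), the training sub-image sizes grow by the resulting padding 2r, and the deconvolution kernel becomes (2*r*scale+1) x (2*r*scale+1) instead of using the fixed radii [3, 5, 7]. A minimal sketch of that arithmetic (illustrative helper, not repo code):

def derived_sizes(radius, scale):
    padding = radius * 2                     # model.py: self.padding
    feature_kernel = radius * 2 + 1          # side of w1, formerly fixed at 5
    deconv_kernel = radius * scale * 2 + 1   # formerly [3, 5, 7][scale - 2] * 2 + 1
    image_size = [10, 7, 6][scale - 2] + padding
    return padding, feature_kernel, deconv_kernel, image_size

for scale in (2, 3, 4):
    print(scale, derived_sizes(2, scale))
# 2 (4, 5, 9, 14)
# 3 (4, 5, 13, 11)
# 4 (4, 5, 17, 10)

At the default radius of 2, the image sizes (14, 11, 10) and the 5x5 feature kernel reproduce the previously hard-coded values; the deconvolution kernel, however, comes out slightly larger than the old fixed radii implied (9/13/17 taps per side versus 7/11/15).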
main.py: 1 addition & 0 deletions
@@ -12,6 +12,7 @@
flags.DEFINE_integer("batch_size", 128, "The size of batch images [128]")
flags.DEFINE_integer("c_dim", 1, "Dimension of image color [1]")
flags.DEFINE_integer("scale", 2, "The size of scale factor for preprocessing input image [3]")
flags.DEFINE_integer("radius", 2, "Max radius of the deconvolution input tensor [2]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Name of checkpoint directory [checkpoint]")
flags.DEFINE_string("output_dir", "result", "Name of test output directory [result]")
flags.DEFINE_string("data_dir", "Train", "Name of data directory to train on [FastTrain]")
model.py: 8 additions & 7 deletions
@@ -30,13 +30,15 @@ def __init__(self, sess, config):
self.is_grayscale = (self.c_dim == 1)
self.epoch = config.epoch
self.scale = config.scale
+ self.radius = config.radius
self.batch_size = config.batch_size
self.threads = config.threads
self.distort = config.distort
self.params = config.params

+ self.padding = self.radius * 2
# Different image/label sub-sizes for different scaling factors x2, x3, x4
- scale_factors = [[14, 20], [11, 21], [10, 24]]
+ scale_factors = [[10 + self.padding, 20], [7 + self.padding, 21], [6 + self.padding, 24]]
self.image_size, self.label_size = scale_factors[self.scale - 2]
# Testing uses different strides to ensure sub-images line up correctly
if not self.train:
@@ -47,8 +49,6 @@ def __init__(self, sess, config):
# Different model layer counts and filter sizes for FSRCNN vs FSRCNN-s (fast), (d, s, m) in paper
model_params = [[56, 12, 4], [32, 8, 1]]
self.model_params = model_params[self.fast]

- self.deconv_radius = [3, 5, 7][self.scale - 2]

self.checkpoint_dir = config.checkpoint_dir
self.output_dir = config.output_dir
@@ -168,7 +168,8 @@ def model(self):
d, s, m = self.model_params

# Feature Extraction
- self.weights['w1'] = tf.get_variable('w1', initializer=tf.random_normal([5, 5, 1, d], stddev=0.0378, dtype=tf.float32))
+ size = self.radius * 2 + 1
+ self.weights['w1'] = tf.get_variable('w1', initializer=tf.random_normal([size, size, 1, d], stddev=0.0378, dtype=tf.float32))
self.biases['b1'] = tf.get_variable('b1', initializer=tf.zeros([d]))
conv = self.prelu(tf.nn.conv2d(self.images, self.weights['w1'], strides=[1,1,1,1], padding='VALID') + self.biases['b1'], 1)

@@ -195,7 +196,7 @@ def model(self):
conv = self.prelu(tf.nn.conv2d(conv, expand_weights, strides=[1,1,1,1], padding='SAME') + expand_biases, m + 3)

# Deconvolution
- deconv_size = self.deconv_radius * 2 + 1
+ deconv_size = self.radius * self.scale * 2 + 1
deconv_weights = tf.get_variable('w{}'.format(m + 4), initializer=tf.random_normal([deconv_size, deconv_size, 1, d], stddev=0.0001, dtype=tf.float32))
deconv_biases = tf.get_variable('b{}'.format(m + 4), initializer=tf.zeros([1]))
self.weights['w{}'.format(m + 4)], self.biases['b{}'.format(m + 4)] = deconv_weights, deconv_biases
@@ -219,7 +220,7 @@ def prelu(self, _x, i):
def save(self, checkpoint_dir, step):
model_name = "FSRCNN.model"
d, s, m = self.model_params
- model_dir = "%s_%s_%s-%s-%s" % ("fsrcnn", self.label_size, d, s, m)
+ model_dir = "%s_%s_%s-%s-%s_%s" % ("fsrcnn", self.label_size, d, s, m, "r"+str(self.radius))
checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

if not os.path.exists(checkpoint_dir):
@@ -232,7 +233,7 @@ def save(self, checkpoint_dir, step):
def load(self, checkpoint_dir):
print(" [*] Reading checkpoints...")
d, s, m = self.model_params
- model_dir = "%s_%s_%s-%s-%s" % ("fsrcnn", self.label_size, d, s, m)
+ model_dir = "%s_%s_%s-%s-%s_%s" % ("fsrcnn", self.label_size, d, s, m, "r"+str(self.radius))
checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
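Because the radius changes the weight shapes, checkpoints trained with different radii must not collide; save() and load() therefore embed an r<radius> suffix in the directory name. A quick sketch of the resulting layout (argument values are just an example):

import os

def checkpoint_path(base_dir, label_size, d, s, m, radius):
    # Mirrors the model_dir format string used in save()/load() above.
    model_dir = "%s_%s_%s-%s-%s_%s" % ("fsrcnn", label_size, d, s, m, "r" + str(radius))
    return os.path.join(base_dir, model_dir)

print(checkpoint_path("checkpoint", 20, 56, 12, 4, 2))
# checkpoint/fsrcnn_20_56-12-4_r2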
sort.py: 10 additions & 7 deletions
@@ -2,26 +2,29 @@

def main():
scale = 2
- radius = [3, 5, 7][scale-2]
+ radius = 2
+ size = radius * scale * 2 + 1
d = 64 #size of the feature layer

if len(sys.argv) == 2:
fname=sys.argv[1]
with open(fname) as f:
content = f.readlines()
content = [x.strip() for x in content]

+ x=list(reversed(range(scale)))
+ x=x[-1:]+x[:-1]
xy = []
- for i in range(0, scale):
- for j in range(0, scale):
+ for i in x:
+ for j in x:
xy.append([j, i])
- xy = list(reversed(xy))

m = []
for i in range(0, len(xy)):
xi, yi = xy[i]
- for x in range(xi, radius*2+1, scale):
- for y in range(yi, radius*2+1, scale):
- m.append(y + x*(radius*2+1))
+ for y in range(yi, size, scale):
+ for x in range(xi, size, scale):
+ m.append(y + x * size)
#print(m)
+ content = list(reversed(content))
sort = [content[m[l]].strip(",") for l in range(0, len(m))]
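sort.py builds an index list m that regroups the (2*r*scale+1)^2 deconvolution taps by sub-pixel phase, presumably so the single deconvolution can be consumed as per-phase convolutions (the usual sub-pixel trick). Whatever the radius, m should visit every tap exactly once; a small self-check replicating the loop structure above (illustrative, not part of the commit):

def build_index(radius, scale):
    size = radius * scale * 2 + 1
    x = list(reversed(range(scale)))
    x = x[-1:] + x[:-1]
    xy = [[j, i] for i in x for j in x]
    m = []
    for xi, yi in xy:
        for y in range(yi, size, scale):
            for col in range(xi, size, scale):  # renamed from x to avoid reusing the name
                m.append(y + col * size)
    return m, size

m, size = build_index(2, 2)
assert sorted(m) == list(range(size * size))  # every tap appears exactly once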
utils.py: 7 additions & 7 deletions
@@ -131,11 +131,11 @@ def modcrop(image, scale=3):

def train_input_worker(args):
image_data, config = args
- image_size, label_size, stride, scale, distort = config
+ image_size, label_size, stride, scale, in_padding, distort = config

single_input_sequence, single_label_sequence = [], []
padding = abs(image_size - label_size) // 2 # eg. for 3x: (21 - 11) / 2 = 5
- label_padding = abs((image_size - 4) - label_size) // 2 # eg. for 3x: (21 - (11 - 4)) / 2 = 7
+ label_padding = abs((image_size - in_padding) - label_size) // 2 # eg. for 3x: (21 - (11 - 4)) / 2 = 7

input_, label_ = preprocess(image_data, scale, distort=distort)

@@ -176,7 +176,7 @@ def thread_train_setup(config):
pool = Pool(config.threads)

# Distribute |images_per_thread| images across each worker process
- config_values = [config.image_size, config.label_size, config.stride, config.scale, config.distort]
+ config_values = [config.image_size, config.label_size, config.stride, config.scale, config.padding, config.distort]
images_per_thread = len(data) // config.threads
workers = []
for thread in range(config.threads):
@@ -213,14 +213,14 @@ def train_input_setup(config):
Read image files, make their sub-images, and save them as a h5 file format.
"""
sess = config.sess
- image_size, label_size, stride, scale = config.image_size, config.label_size, config.stride, config.scale
+ image_size, label_size, stride, scale, in_padding = config.image_size, config.label_size, config.stride, config.scale, config.padding

# Load data path
data = prepare_data(sess, dataset=config.data_dir)

sub_input_sequence, sub_label_sequence = [], []
padding = abs(image_size - label_size) // 2 # eg. for 3x: (21 - 11) / 2 = 5
- label_padding = abs((image_size - 4) - label_size) // 2 # eg. for 3x: (21 - (11 - 4)) / 2 = 7
+ label_padding = abs((image_size - in_padding) - label_size) // 2 # eg. for 3x: (21 - (11 - 4)) / 2 = 7

for i in range(len(data)):
input_, label_ = preprocess(data[i], scale, distort=config.distort)
@@ -253,14 +253,14 @@ def test_input_setup(config):
Read image files, make their sub-images, and save them as a h5 file format.
"""
sess = config.sess
- image_size, label_size, stride, scale = config.image_size, config.label_size, config.stride, config.scale
+ image_size, label_size, stride, scale, in_padding = config.image_size, config.label_size, config.stride, config.scale, config.padding

# Load data path
data = prepare_data(sess, dataset="Test")

sub_input_sequence, sub_label_sequence = [], []
padding = abs(image_size - label_size) // 2 # eg. (21 - 11) / 2 = 5
- label_padding = abs((image_size - 4) - label_size) // 2 # eg. for 3x: (21 - (11 - 4)) / 2 = 7
+ label_padding = abs((image_size - in_padding) - label_size) // 2 # eg. for 3x: (21 - (11 - 4)) / 2 = 7

pic_index = 2 # Index of image based on lexicographic order in data folder
input_, label_ = preprocess(data[pic_index], config.scale)
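In all three setup paths the literal 4 in the label_padding formula was presumably the shrinkage (5 - 1) of the old fixed 5x5 VALID convolution; it now arrives as in_padding = radius * 2 through the config tuple. A worked check at the default radius with the x3 sub-sizes, matching the comments above:

# Assumes radius = 2 (in_padding = 4) and the x3 sizes from model.py.
image_size, label_size, in_padding = 11, 21, 4
padding = abs(image_size - label_size) // 2                        # (21 - 11) // 2 = 5
label_padding = abs((image_size - in_padding) - label_size) // 2   # (21 - 7) // 2 = 7
print(padding, label_padding)  # 5 7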
