How to build graph once and run multi images/masks with different resolutions? #194

Closed
c9o opened this issue Dec 13, 2018 · 18 comments

@c9o

c9o commented Dec 13, 2018

Really great work on DeepFill!
Is it possible to build the graph once and then feed it multiple images/masks with different resolutions?

In #12 (comment), a loop was created to process multiple images, but the graph was rebuilt for every input image.

In your code in #12 (comment), the graph only has to be built once, but every image/mask has to be resized to the same size before being fed to the graph.

Can we build the graph once and feed images/masks with different resolutions?

Thanks.

@JiahuiYu
Owner

Hey @jtetc. That is a very good question, and I appreciate that you searched the existing issues before asking.

In principle, our trained models support different image resolutions. You could try setting the input shape to (1, None, None, 3).

However, when I tried this last year, it failed. The major obstacle was the tf.extract_image_patches API in TensorFlow. I don't know whether that API currently supports inputs with None dimensions.

@c9o
Author

c9o commented Dec 13, 2018

Thanks for your instant response, @JiahuiYu.
Your advice is great. It works after changing the related resize ops to use the option dynamic=True.

Thanks again.

@JiahuiYu
Owner

@jtetc Great! Thanks for letting me know.

@ccc013

ccc013 commented Dec 21, 2018

Hi, @jtetc

> Thanks for your instant response @JiahuiYu
> Your advice is great. It works after changing related resize ops with option dynamic=True
>
> Thanks again.

Where are the resize ops you mentioned? Do you mean the resize in the resize_mask_like() function in inpaint_ops.py?

I want to set the input shape to (1, None, None, 3), but I get an error like this:

Traceback (most recent call last):
  File "finetune_imageinfer_compare.py", line 220, in <module>
    input_image_ph, output, sess = deepfill_model(args.checkpoint_dir, args.image_height, args.image_width)
  File "finetune_imageinfer_compare.py", line 166, in deepfill_model
    output = model.build_server_graph(input_image_ph, dynamic=True)
  File "/home/luocai/generative_inpainting/inpaint_model.py", line 315, in build_server_graph
    config=None)
  File "/home/luocai/generative_inpainting/inpaint_model.py", line 61, in build_inpaint_net
    mask_s = resize_mask_like(mask, x, dynamic=dynamic)
  File "/home/luocai/generative_inpainting/inpaint_ops.py", line 226, in resize_mask_like
    func=tf.image.resize_nearest_neighbor, dynamic=dynamic)
  File "/home/luocai/generative_inpainting/neuralgym_c/ops/layers.py", line 128, in resize
    align_corners=align_corners)
  File "/root/anaconda3/envs/tf15/lib/python3.6/site-packages/tensorflow/python/ops/gen_image_ops.py", line 1643, in resize_nearest_neighbor
    align_corners=align_corners, name=name)
  File "/root/anaconda3/envs/tf15/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 519, in _apply_op_helper
    repr(values), type(values).__name__))
TypeError: Expected int32 passed to parameter 'size' of op 'ResizeNearestNeighbor', got [None, None] of type 'list' instead.

Here is my test code. I want to load the model only once and test multiple images, but it fails:

sess_config = tf.ConfigProto()
sess_config.gpu_options.allow_growth = True
sess = tf.Session(config=sess_config)

model = InpaintCAModel()
input_image_ph = tf.placeholder(
    tf.float32, shape=(1, None, None, 3))
output = model.build_server_graph(input_image_ph, dynamic=True)
output = (output + 1.) * 127.5
output = tf.reverse(output, [-1])
output = tf.saturate_cast(output, tf.uint8)
vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
assign_ops = []
for var in vars_list:
    vname = var.name
    from_name = vname
    var_value = tf.contrib.framework.load_variable(
        checkpoint_dir, from_name)
    assign_ops.append(tf.assign(var, var_value))
sess.run(assign_ops)
print('Model loaded.')
.....

@JiahuiYu, can we build the graph once and feed images/masks with different resolutions?

Thanks
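
A note on the TypeError above: the traceback shows resize_mask_like passing [None, None] as the target size, which is what a static shape lookup returns when the placeholder is declared as (1, None, None, 3). The following is a minimal pure-Python sketch of that distinction, not the neuralgym implementation. The function names static_target_size and dynamic_target_size and the example shapes are hypothetical. In TensorFlow 1.x the static lookup corresponds to x.get_shape().as_list() and the run-time lookup to tf.shape(x).

```python
def static_target_size(static_shape, scale):
    """Compute [height, width] at graph-build time from a static NHWC shape."""
    _, height, width, _ = static_shape
    # Raises TypeError when height or width is None (unknown until run time).
    return [int(height * scale), int(width * scale)]


def dynamic_target_size(runtime_shape, scale):
    """Compute [height, width] from the shape of the array that is actually fed."""
    _, height, width, _ = runtime_shape
    return [int(height * scale), int(width * scale)]


placeholder_shape = (1, None, None, 3)  # what the graph knows when it is built
fed_shape = (1, 512, 680, 3)            # hypothetical shape of one fed image

try:
    static_target_size(placeholder_shape, 0.25)
    static_failed = False
except TypeError:
    static_failed = True

print("static lookup failed:", static_failed)
print("dynamic size:", dynamic_target_size(fed_shape, 0.25))
```

The point of the sketch is only that any size computed from the static shape is fixed (or unknown) at build time, while a size computed from the run-time shape can differ for every image.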

@JiahuiYu
Owner

@jtetc If possible, please help @ccc013. Thanks!

@c9o
Author

c9o commented Dec 21, 2018

@JiahuiYu With pleasure.
Hi @ccc013, I changed the related resize ops in inpaint_ops.py to use the option dynamic=True. If you check the function signature of resize in neuralgym.ops.layers, it should become clear.

@ccc013

ccc013 commented Dec 21, 2018

@JiahuiYu @jtetc Thanks for your fast response!

But when I set the option dynamic=True in the resize ops in inpaint_ops.py, I still get the same error:

Traceback (most recent call last):
  File "/root/anaconda3/envs/tf15/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 510, in _apply_op_helper
    preferred_dtype=default_dtype)
  File "/root/anaconda3/envs/tf15/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1022, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/root/anaconda3/envs/tf15/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py", line 233, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)
  File "/root/anaconda3/envs/tf15/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py", line 212, in constant
    value, dtype=dtype, shape=shape, verify_shape=verify_shape))
  File "/root/anaconda3/envs/tf15/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py", line 413, in make_tensor_proto
    _AssertCompatible(values, dtype)
  File "/root/anaconda3/envs/tf15/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py", line 328, in _AssertCompatible
    (dtype.name, repr(mismatch), type(mismatch).__name__))
TypeError: Expected int32, got None of type '_Message' instead.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "finetune_imageinfer_compare.py", line 220, in <module>
    input_image_ph, output, sess = deepfill_model(args.checkpoint_dir, args.image_height, args.image_width)
  File "finetune_imageinfer_compare.py", line 166, in deepfill_model
    output = model.build_server_graph(input_image_ph, dynamic=True)
  File "/home/luocai/generative_inpainting/inpaint_model.py", line 315, in build_server_graph
    config=None)
  File "/home/luocai/generative_inpainting/inpaint_model.py", line 61, in build_inpaint_net
    mask_s = resize_mask_like(mask, x, dynamic=dynamic)
  File "/home/luocai/generative_inpainting/inpaint_ops.py", line 227, in resize_mask_like
    func=tf.image.resize_nearest_neighbor, dynamic=dynamic)
  File "/home/luocai/generative_inpainting/neuralgym_c/ops/layers.py", line 128, in resize
    align_corners=align_corners)
  File "/root/anaconda3/envs/tf15/lib/python3.6/site-packages/tensorflow/python/ops/gen_image_ops.py", line 1643, in resize_nearest_neighbor
    align_corners=align_corners, name=name)
  File "/root/anaconda3/envs/tf15/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 519, in _apply_op_helper
    repr(values), type(values).__name__))
TypeError: Expected int32 passed to parameter 'size' of op 'ResizeNearestNeighbor', got [None, None] of type 'list' instead.

Here is my test code. I am using TensorFlow 1.5. @jtetc, is there something wrong in my code?

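# Imports assumed from the repository's test.py; adjust to your local layout.
# check_folder() and getMask() are the poster's own helpers (not shown).
import argparse
import os

import cv2
import numpy as np
import tensorflow as tf
import neuralgym as ng

from inpaint_model import InpaintCAModel


def deepfill_model(checkpoint_dir):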
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)
    model = InpaintCAModel()
    input_image_ph = tf.placeholder(
        tf.float32, shape=(1, None, None, 3))
    output = model.build_server_graph(input_image_ph, dynamic=True)
    output = (output + 1.) * 127.5
    output = tf.reverse(output, [-1])
    output = tf.saturate_cast(output, tf.uint8)
    vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    assign_ops = []
    for var in vars_list:
        vname = var.name
        from_name = vname
        var_value = tf.contrib.framework.load_variable(
            checkpoint_dir, from_name)
        assign_ops.append(tf.assign(var, var_value))
    sess.run(assign_ops)
    print('Model loaded.')

    return input_image_ph, output, sess

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--imagedir', default='', type=str,
                        help='The directory of images to be completed.')
    parser.add_argument('--outdir', default='', type=str,
                        help='Where to write output.')
    parser.add_argument('--checkpoint_dir', default='', type=str,
                        help='The directory of tensorflow checkpoint.')

    args = parser.parse_args()

    check_folder(args.outdir)
    output_image_dir = os.path.join(args.outdir, "images")
    if not os.path.exists(output_image_dir):
        os.makedirs(output_image_dir)

    ng.get_gpus(1)
    counts = 0
    # deepfill_model() takes only the checkpoint directory; the parser defines
    # no image_height/image_width arguments.
    input_image_ph, output, sess = deepfill_model(args.checkpoint_dir)
    for img in os.listdir(args.imagedir):
        print("img: ", img)

        fre_img, ex = os.path.splitext(img)
        img_outname_deepfill = fre_img + "_output_deepfill.jpg"

        img_path = os.path.join(args.imagedir, img)
        img_in_path = os.path.join(output_image_dir, img)
        img_out_deepfill_path = os.path.join(output_image_dir, img_outname_deepfill)
        print("img_out_deepfill_path: ", img_out_deepfill_path)
        image = cv2.imread(img_path)
        modelflag, mask = getMask(image)

        counts += 1
        assert image.shape == mask.shape
        h, w, _ = image.shape
        grid = 8
        print('before shape of image: {}'.format(image.shape))
        image_deep = image[:h // grid * grid, :w // grid * grid, :]
        mask_deep = mask[:h // grid * grid, :w // grid * grid, :]
        print('Shape of image: {}'.format(image_deep.shape))

        image_deep = np.expand_dims(image_deep, 0)
        mask_deep = np.expand_dims(mask_deep, 0)
        input_image = np.concatenate([image_deep, mask_deep], axis=2)
        print("input_image shape: ", input_image.shape)
        # run inference; the model was already loaded by deepfill_model()
        result = sess.run(output, feed_dict={input_image_ph: input_image})
        # logger.info('Processed: {}'.format(img_out_deepfill_path))
        cv2.imwrite(img_out_deepfill_path, result[0][:, :, ::-1])
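
The loop above crops each image and mask to a multiple of grid = 8 and then concatenates them along axis 2 (the width axis of an NHWC array). As a rough sketch of the resulting shapes only, with hypothetical helper names crop_to_multiple and network_input_shape and no actual image data:

```python
def crop_to_multiple(height, width, grid=8):
    """Largest (height, width) not exceeding the input that is divisible by grid."""
    return height // grid * grid, width // grid * grid


def network_input_shape(height, width, grid=8, channels=3):
    """Shape of the array fed to the placeholder for one image/mask pair."""
    h, w = crop_to_multiple(height, width, grid)
    # Image and mask sit side by side along the width axis, so width doubles.
    return (1, h, 2 * w, channels)


# Hypothetical image sizes, to show that the fed shape changes per image.
for height, width in [(517, 770), (256, 256), (1080, 1920)]:
    print((height, width), "->", network_input_shape(height, width))
```

Because the fed width is twice the cropped image width and both spatial dimensions vary per image, the placeholder has to leave both dimensions as None for a single graph to accept all of them.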

@c9o
Author

c9o commented Dec 21, 2018

@ccc013 I didn't use to_shape to resize; I used scale instead. Here is sample code for your reference:

@@ -174,10 +174,13 @@ def resize_mask_like(mask, x):
     Returns:
         tf.Tensor: resized mask
 
-    """
     mask_resize = resize(
         mask, to_shape=x.get_shape().as_list()[1:3],
         func=tf.image.resize_nearest_neighbor)
+    """
+    mask_resize =  resize(
+        mask, scale=1./4, dynamic=True
+    )
     return mask_resize

I am new to TensorFlow. I guess we cannot get the static shape of x because the input shape is (1, None, None, 3). Since x has already been downsampled twice at that point, I set scale=1./4. These changes work in my tests. Hi @JiahuiYu, is my change acceptable? Thanks!
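
To make the scale=1./4 reasoning concrete, here is a small worked check. It is a sketch under two assumptions that are not stated in the thread: that each of the two downsampling steps is a stride-2 convolution with SAME padding (output size ceil(n / 2)), and that a scale-based resize truncates size * scale to an integer. The function names are hypothetical.

```python
import math


def same_conv_output(size, stride):
    """Spatial output size of a strided convolution with SAME padding."""
    return math.ceil(size / stride)


def feature_size_after_downsampling(size, num_downsamples=2, stride=2):
    """Size of the feature map x after the assumed downsampling steps."""
    for _ in range(num_downsamples):
        size = same_conv_output(size, stride)
    return size


def scaled_size(size, scale=1. / 4):
    """Size of the mask after a scale-based resize that truncates to int."""
    return int(size * scale)


for size in (256, 512, 680, 682):
    print(size, feature_size_after_downsampling(size), scaled_size(size))
```

Under these assumptions the two sizes agree whenever the input size is divisible by 4, which the crop to a multiple of 8 in the test script guarantees. For a size such as 682 they would differ by one pixel, so the fixed scale is only safe together with that cropping.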

@JiahuiYu
Owner

Your modification should be fine.

@ccc013

ccc013 commented Dec 21, 2018

@jtetc Thank you very much!

I changed the code as you did, and I also changed the other code that calls the resize function, such as the gen_deconv and contextual_attention functions.

In particular, I changed line 250 in the contextual_attention function from:

# downscaling foreground option: downscaling both foreground and
# background for matching and use original background for reconstruction.
f = resize(f, scale=1./rate, func=tf.image.resize_nearest_neighbor)
b = resize(b, to_shape=[int(raw_int_bs[1]/rate), int(raw_int_bs[2]/rate)], func=tf.image.resize_nearest_neighbor)  # https://github.com/tensorflow/tensorflow/issues/11651

to

# downscaling foreground option: downscaling both foreground and
# background for matching and use original background for reconstruction.
f = resize(f, scale=1./rate, func=tf.image.resize_nearest_neighbor, dynamic=True)
b = resize(b, scale=1. / rate,
                   func=tf.image.resize_nearest_neighbor, dynamic=True)

If I don't change this code, I still get an error.

My test code works after changing the functions above!

Hi @JiahuiYu, is my change to contextual_attention above acceptable? Thanks!
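
For readers unfamiliar with what a scale-based nearest-neighbor resize does to f and b, here is a simplified pure-Python sketch on nested lists. It is not tf.image.resize_nearest_neighbor: it ignores the align_corners option and simply maps each output index back to the source index int(i / scale). The function name resize_nearest is hypothetical.

```python
def resize_nearest(grid, scale):
    """Nearest-neighbor resize of a 2-D list by a scale factor (simplified)."""
    height, width = len(grid), len(grid[0])
    # The target size is derived from the size of the data actually passed in.
    new_height, new_width = int(height * scale), int(width * scale)
    return [
        [
            grid[min(int(i / scale), height - 1)][min(int(j / scale), width - 1)]
            for j in range(new_width)
        ]
        for i in range(new_height)
    ]


# A 4x4 example downscaled with scale = 1 / rate for rate = 2.
example = [[row * 4 + col for col in range(4)] for row in range(4)]
rate = 2
for line in resize_nearest(example, 1. / rate):
    print(line)
```

Since the target size here comes from the array itself rather than from a precomputed to_shape, applying the same rate to foreground and background of equal size yields matching sizes for any input resolution.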

@c9o
Author

c9o commented Dec 21, 2018

@ccc013 I made the same modifications. They are the same issue.

@c9o
Author

c9o commented Dec 21, 2018

@ccc013 FYI

diff --git a/inpaint_ops.py b/inpaint_ops.py
index afdec3d..25624ea 100644
--- a/inpaint_ops.py
+++ b/inpaint_ops.py
@@ -63,7 +63,7 @@ def gen_deconv(x, cnum, name='upsample', padding='SAME', training=True):
 
     """
     with tf.variable_scope(name):
-        x = resize(x, func=tf.image.resize_nearest_neighbor)
+        x = resize(x, func=tf.image.resize_nearest_neighbor, dynamic=True)
         x = gen_conv(
             x, cnum, 3, 1, name=name+'_conv', padding=padding,
             training=training)
@@ -174,10 +174,13 @@ def resize_mask_like(mask, x):
     Returns:
         tf.Tensor: resized mask
 
-    """
     mask_resize = resize(
         mask, to_shape=x.get_shape().as_list()[1:3],
         func=tf.image.resize_nearest_neighbor)
+    """
+    mask_resize =  resize(
+        mask, scale=1./4, dynamic=True
+    )
     return mask_resize
 
 
@@ -246,10 +249,10 @@ def contextual_attention(f, b, mask=None, ksize=3, stride=1, rate=1,
     raw_w = tf.transpose(raw_w, [0, 2, 3, 4, 1])  # transpose to b*k*k*c*hw
     # downscaling foreground option: downscaling both foreground and
     # background for matching and use original background for reconstruction.
-    f = resize(f, scale=1./rate, func=tf.image.resize_nearest_neighbor)
-    b = resize(b, to_shape=[int(raw_int_bs[1]/rate), int(raw_int_bs[2]/rate)], func=tf.image.resize_nearest_neighbor)  # https://github.com/tensorflow/tensorflow/issues/11651
+    f = resize(f, scale=1./rate, func=tf.image.resize_nearest_neighbor, dynamic=True)
+    b = resize(b, scale=1./rate, func=tf.image.resize_nearest_neighbor, dynamic=True)
     if mask is not None:
-        mask = resize(mask, scale=1./rate, func=tf.image.resize_nearest_neighbor)
+        mask = resize(mask, scale=1./rate, func=tf.image.resize_nearest_neighbor, dynamic=True)
     fs = tf.shape(f)
     int_fs = f.get_shape().as_list()
     f_groups = tf.split(f, int_fs[0], axis=0)
@@ -320,7 +323,7 @@ def contextual_attention(f, b, mask=None, ksize=3, stride=1, rate=1,
     # # case2: visualize which pixels are attended
     # flow = highlight_flow_tf(offsets * tf.cast(mask, tf.int32))
     if rate != 1:
-        flow = resize(flow, scale=rate, func=tf.image.resize_nearest_neighbor)
+        flow = resize(flow, scale=rate, func=tf.image.resize_nearest_neighbor, dynamic=True)
     return y, flow
 

@JiahuiYu
Owner

@jtetc Thanks for your contribution!

@ccc013

ccc013 commented Dec 23, 2018

@jtetc Thanks!

@awsssix

awsssix commented Jan 24, 2019

@jtetc I have already trained my model. If I want to build the graph once when testing, do I need to retrain the model?

@c9o
Author

c9o commented Jan 24, 2019

> @jtetc I have trained my model already. If I want build graph once when testing, do I need to retrain the model?

@awsssix If you didn't modify the model, you don't need to retrain.
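
One way to sanity-check this before running inference: the loading loop posted earlier in this thread restores each graph variable from the checkpoint by name, so an existing checkpoint keeps working as long as every variable name in the modified graph is still present in it. The sketch below compares two plain lists of names. The helper names and the example variable names are hypothetical; in TensorFlow 1.x the real lists could come from something like tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) and tf.contrib.framework.list_variables(checkpoint_dir).

```python
def checkpoint_key(variable_name):
    """Strip the ':0' output suffix that TF1 appends to variable names."""
    return variable_name[:-2] if variable_name.endswith(':0') else variable_name


def missing_from_checkpoint(graph_variable_names, checkpoint_keys):
    """Graph variables that have no entry of the same name in the checkpoint."""
    wanted = {checkpoint_key(name) for name in graph_variable_names}
    return sorted(wanted - set(checkpoint_keys))


# Hypothetical names, for illustration only.
graph_names = ['inpaint_net/conv1/kernel:0', 'inpaint_net/conv1/bias:0']
ckpt_keys = ['inpaint_net/conv1/kernel', 'inpaint_net/conv1/bias']

print(missing_from_checkpoint(graph_names, ckpt_keys))
```

An empty result means every variable can be restored by name. A resize op that has no trainable parameters adds nothing to the graph's variable list, which would be consistent with the answer above.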

@awsssix

awsssix commented Jan 24, 2019

@jtetc Thanks for your reply.
I saw above that you modified contextual_attention and gen_deconv. Can I just test directly with those changes?

@c9o
Author

c9o commented Jan 24, 2019

> @jtetc Thanks for your replay.
> I saw above that you modified contextual_attention and gen_deconv, just test directly?

You can give it a try.
