
How can I test the model? #12

Closed
Yangshell opened this issue Aug 8, 2018 · 8 comments
Labels: help wanted

Comments

@Yangshell

No description provided.

@ogroth
Owner

ogroth commented Aug 8, 2018

Hi Yangshell, you can run a trained model by setting up an estimator:

```python
gqn_model = tf.estimator.Estimator(model_fn=gqn_draw_model_fn, model_dir='your/model/directory')
```

Obviously, 'your/model/directory' needs to point to the directory where a snapshot is stored (e.g. one of the model directories downloaded from the README).
Then you can call gqn_model.predict(.), feeding context images and poses as well as a query pose. The data_provider test case shows how to run the DataReader as a standalone object providing the correct data.
The predict(.) function returns a dictionary with the predicted mean image and its variance.
For a better understanding of the model's inputs and outputs, please refer to gqn_draw_model_fn.

@Yangshell
Author

Can I get both the image the model predicts and the real image?

@ogroth
Owner

ogroth commented Aug 8, 2018

The predict(.) function returns a dictionary with the predicted mean image. The ground-truth image is fetched by the DataReader as the remaining element of the data tuple, see here.
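For concreteness, here is a pure-Python mock of the kind of nested tuple the DataReader yields. The tuple names mirror the repo's `TaskData`/`Query`/`Context` structures, but the exact field order and the array shapes below are illustrative assumptions, not the repo's actual definitions:

```python
from collections import namedtuple

import numpy as np

# Illustrative mock of the data tuples (names follow gqn_tfr_provider;
# field order and shapes here are assumptions for the sketch only).
Context = namedtuple('Context', ['frames', 'cameras'])
Query = namedtuple('Query', ['context', 'query_camera'])
TaskData = namedtuple('TaskData', ['query', 'target'])

context = Context(frames=np.zeros((5, 64, 64, 3), dtype=np.float32),
                  cameras=np.zeros((5, 7), dtype=np.float32))
task = TaskData(query=Query(context=context,
                            query_camera=np.zeros(7, dtype=np.float32)),
                target=np.zeros((64, 64, 3), dtype=np.float32))

# task.query feeds the network; task.target is the ground-truth image
# that you compare the predicted mean image against.
print(task.target.shape)  # (64, 64, 3)
```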

@Yangshell
Author

Hello, I use:

```python
with tf.train.SingularMonitoredSession() as sess:
    d = sess.run(data)
```

```python
import tensorflow as tf
from gqn.gqn_model import gqn_draw_model_fn
from gqn.gqn_params import _DEFAULTS

gqn_model = tf.estimator.Estimator(model_fn=gqn_draw_model_fn, model_dir='/Users/yangshell/Downloads/rooms_ring_debug/gqn_pool_draw2', params=_DEFAULTS)
result = gqn_model.predict(d)
```

The `result` I get is a generator, not a dict. What is my mistake?
Thank you for your patience!

@stefanwayon
Collaborator

Hi Yangshell,

A few notes on your code to help you get going with this:

The _DEFAULTS constant in gqn.gqn_params is meant for internal use (hence the underscore at the beginning). The actual default parameters are gqn.gqn_params.PARAMS.

If you take a look at the training script, where the Estimator is configured, the params attribute for the estimator is a dict { "gqn_params": PARAMS, "debug": FLAGS.debug }. You can set "debug": False.

Estimator.predict works similarly to Estimator.train [code example] or Estimator.evaluate, meaning you need to pass in an input function. We provide an input function that works with the GQN datasets: data_provider.gqn_tfr_provider.gqn_input_fn.

The result of Estimator.predict is indeed a generator. This is a Python object that you can iterate over (i.e. do things like `for prediction in ...`).
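To see why a generator is what you should expect, here is a plain-Python illustration; `fake_predict` and its dict contents are made up for the example and are not the real Estimator API:

```python
# A generator produces its values lazily, one at a time, as you iterate.
# Estimator.predict behaves the same way: each iteration yields one
# prediction dict. (fake_predict is a stand-in, not the real API.)
def fake_predict():
    for i in range(3):
        yield {'predicted_mean': i, 'predicted_variance': 0.1 * i}

result = fake_predict()
print(type(result).__name__)  # 'generator' -- not a dict
first = next(result)          # pull the first prediction dict out of it
print(sorted(first.keys()))   # ['predicted_mean', 'predicted_variance']
```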

Putting all that together, your code will end up looking something like:

```python
import tensorflow as tf

from gqn.gqn_model import gqn_draw_model_fn
from gqn.gqn_params import PARAMS
from data_provider.gqn_tfr_provider import gqn_input_fn

MODEL_DIR = '/Users/yangshell/Downloads/rooms_ring_debug/gqn_pool_draw2'
DATA_DIR = '/tmp/data/gqn-dataset'
DATASET = 'rooms_ring_camera'

estimator = tf.estimator.Estimator(
    model_fn=gqn_draw_model_fn,
    model_dir=MODEL_DIR,
    params={'gqn_params': PARAMS, 'debug': False})

input_fn = lambda mode: gqn_input_fn(
    dataset=DATASET,
    context_size=PARAMS.CONTEXT_SIZE,
    root=DATA_DIR,
    mode=mode)

for prediction in estimator.predict(input_fn=input_fn):
    # prediction is the dict @ogroth was mentioning
    print(prediction['predicted_mean'])      # this is probably what you want to look at
    print(prediction['predicted_variance'])  # or use this to sample a noisy image
```

If you already have a data_provider.gqn_tfr_provider.TaskData object with, say, numpy images as your input to the network, you could write a custom input_fn that maps that into tensors using something like tf.contrib.framework.nest, and then predict using that input function.
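Here is a minimal sketch of such a custom input function, assuming TensorFlow's `tf.data` API and in-memory numpy arrays. `make_numpy_input_fn` and the feature keys are hypothetical names for this example, not part of the repo; adapt them to whatever your model_fn actually expects:

```python
import numpy as np
import tensorflow as tf

def make_numpy_input_fn(context_frames, context_cameras, query_camera):
    """Wrap in-memory numpy arrays into an Estimator-style input_fn.

    This function and its feature keys are hypothetical; adapt them to
    the feature names your model_fn actually consumes.
    """
    def input_fn():
        # A single dataset element holding the whole context plus the
        # query pose, converted to float32 tensors.
        return tf.data.Dataset.from_tensors({
            'context_frames': context_frames.astype(np.float32),    # [S, H, W, 3]
            'context_cameras': context_cameras.astype(np.float32),  # [S, 7]
            'query_camera': query_camera.astype(np.float32),        # [7]
        })
    return input_fn

# Example usage with dummy arrays (5 context views of 64x64 RGB images):
input_fn = make_numpy_input_fn(np.zeros((5, 64, 64, 3)),
                               np.zeros((5, 7)),
                               np.zeros(7))
```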

Let us know how this goes. :-)

Best,
Ștefan

@stefanwayon added the help wanted label and removed the question label on Aug 9, 2018
@Yangshell
Author

I have got the test process working, but I found a problem in the image result.
This is the true image:
query0
This is the predicted image:
test0
I used the model "gqn_pool_draw12". You can see the wall is rendered well, but the model did not predict the blue cylinder.

@ogroth
Owner

ogroth commented Aug 14, 2018

Hey Yangshell,
This is not unexpected and happens occasionally when the model seems to be "unsure" about the geometry of objects. In such cases, it falls back to predicting the room's geometry. This behaviour might be mitigated by training the model longer (we've trained for ~200K iterations; the paper reported ~2M iterations) or by feeding different context views.
But if you find more odd behavior, please feel free to open a new issue and post your findings. We've just started to experiment with the model ourselves and sharing failure cases like this can be super helpful for other people working with the code. :)

@tarunsharma1

tarunsharma1 commented Jan 23, 2019

Can someone provide a clean implementation of how to use the network to predict an image from the test set? It seems like people have got it to work, but no one has posted clean code that still works. This would be super helpful. Thanks!

@Yangshell

EDIT: this works
#17 (comment)
