
[Guides] Pre-processing of images for Xception model #20

Closed
swghosh opened this issue May 11, 2020 · 6 comments · Fixed by #26

Comments

@swghosh
Contributor

swghosh commented May 11, 2020

The Transfer Learning tutorial on the website uses the Xception model.

Our raw images have a variety of sizes. In addition, each pixel consists of 3 integer
values between 0 and 255 (RGB level values). This isn't a great fit for feeding a
neural network. We need to do 2 things:
- Standardize to a fixed image size. We pick 150x150.
- Normalize pixel values between 0 and 1. We'll do this using a `Rescaling` layer as
part of the model itself.

x = keras.layers.experimental.preprocessing.Rescaling(1.0 / 255.0)(x)  # Scale inputs to [0, 1]

The preprocessing applied when feeding data through the tf.data API, together with the new Rescaling layer, maps raw pixel values from [0, 255] to [0, 1]. However, keras_applications.xception.preprocess_input
(https://github.com/tensorflow/tensorflow/blob/476ec938b253a9479de09aab88dceec6f0a304ed/tensorflow/python/keras/applications/xception.py#L318-L320) delegates to preprocess_input with mode='tf' (https://github.com/tensorflow/tensorflow/blob/476ec938b253a9479de09aab88dceec6f0a304ed/tensorflow/python/keras/applications/imagenet_utils.py#L181-L184), which maps pixels to [-1, 1] instead of [0, 1].

When activations are computed from pre-trained weights (e.g. in a deep feature extraction step), a mismatch between the input range used during pre-training and the one used at inference can degrade the extracted features. Is the example affected?
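A minimal NumPy sketch of the two mappings being compared. The function names are hypothetical; they only reproduce the arithmetic of Rescaling(1.0 / 255.0) and of mode='tf' (x / 127.5 - 1), not the Keras implementations themselves.

```python
import numpy as np

# Hypothetical sample: the minimum, midpoint and maximum of the 8-bit pixel range.
pixels = np.array([0.0, 127.5, 255.0], dtype=np.float32)


def tutorial_scaling(x):
    """Arithmetic of Rescaling(1.0 / 255.0): maps [0, 255] to [0, 1]."""
    return x / 255.0


def xception_tf_mode(x):
    """Arithmetic of preprocess_input with mode='tf': maps [0, 255] to [-1, 1]."""
    return x / 127.5 - 1.0


print(tutorial_scaling(pixels))  # values 0.0, 0.5, 1.0
print(xception_tf_mode(pixels))  # values -1.0, 0.0, 1.0
```

The midpoint of the pixel range lands at 0.5 under the tutorial's scaling but at 0.0 under the scaling Xception expects, so the inputs are both compressed and shifted relative to what the pre-trained weights saw.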

swghosh changed the title from "[Guides] Preprocessing of images for Xception" to "[Guides] Pre-processing of images for Xception model" on May 11, 2020
@fchollet
Member

Yes, that seems like an issue. We should fix it.

@swghosh
Contributor Author

swghosh commented May 11, 2020

I'll be working on this and hope to get back to you with a PR soon.
Thanks for addressing the issue.

@swghosh
Contributor Author

swghosh commented May 12, 2020

It appears that the Rescaling layer used in the code only supports multiplication. I could use 1/127.5 instead of 1/255., but that would scale the data to [0, 2], which is not what we want.

To reach the required normalization range of [-1, 1], the inputs should be transformed as x /= 127.5; x -= 1 (or equivalent).

Should we replace the Rescaling layer with a new preprocess tf.function, or pass keras_applications.xception.preprocess_input (which performs exactly the operation above) to ds.map?

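import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE

@tf.function
def preprocess(x):
    x = tf.cast(x, tf.float32)  # ensure float math if images arrive as uint8
    x /= 127.5  # [0, 255] -> [0, 2]
    x -= 1.0    # [0, 2] -> [-1, 1]
    return x

ds = ds.map(lambda x, y: (preprocess(x), y), num_parallel_calls=AUTOTUNE)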

or

preprocess = tf.keras.applications.xception.preprocess_input
ds = ds.map(lambda x, y: (preprocess(x), y), AUTOTUNE)

That said, dropping the Rescaling layer from a documentation example could indirectly discourage users from adopting the new preprocessing layers. One workaround is to combine the Rescaling layer with a Lambda layer that shifts the scaled inputs into the right range.

Something like this:

# x has range [0, 255]
x = tf.keras.layers.experimental.preprocessing.Rescaling(1 / 127.5)(x)  # x has range [0, 2]
x = tf.keras.layers.Lambda(lambda inp: inp - 1.0)(x)  # x has range [-1, 1]
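The Rescaling + Lambda combination above amounts to "multiply by a scale, then add an offset". A minimal NumPy sketch of that arithmetic follows; the `rescale` helper is hypothetical and only mimics the math, not the Keras layers. It shows why multiplication alone lands in [0, 2] and why the extra offset of -1 recovers [-1, 1].

```python
import numpy as np


def rescale(x, scale, offset=0.0):
    """Hypothetical stand-in for scale-then-offset: x * scale + offset."""
    return x * scale + offset


# Hypothetical pixel values spanning the 8-bit range.
pixels = np.array([0.0, 63.75, 127.5, 255.0], dtype=np.float32)

# Multiplication only (the Rescaling layer on its own): [0, 255] -> [0, 2].
scaled_only = rescale(pixels, 1.0 / 127.5)

# Multiplication followed by an offset of -1 (Rescaling + Lambda): [0, 255] -> [-1, 1].
scaled_and_offset = rescale(pixels, 1.0 / 127.5, offset=-1.0)

print(scaled_only)
print(scaled_and_offset)
```

For reference, later TensorFlow releases added an `offset` argument to `Rescaling` with these scale-then-offset semantics, so `Rescaling(1.0 / 127.5, offset=-1.0)` expresses the same transform in a single layer.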

@fchollet
Member

We need preprocessing to be part of the model, as a best practice.

I will look into adding a new offset argument to Rescaling to support this use case.

For the time being I would recommend using a Normalization layer (from layers.experimental.preprocessing). It can do both scaling and offsetting. Just set the weights correctly.
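To illustrate what "setting the weights correctly" could mean, assume the Normalization layer computes (x - mean) / sqrt(variance) per channel. Since x / 127.5 - 1 = (x - 127.5) / 127.5, matching Xception's preprocessing would require mean = 127.5 and variance = 127.5² = 16256.25 for each of the three RGB channels. The sketch below only checks that arithmetic in NumPy; the `normalize` helper is hypothetical, and how the values are loaded into the actual layer (for example via set_weights) should be checked against the installed TensorFlow version.

```python
import numpy as np

# Per-channel statistics that make (x - mean) / sqrt(variance) equal x / 127.5 - 1.
mean = np.full(3, 127.5, dtype=np.float32)           # one value per RGB channel
variance = np.full(3, 127.5 ** 2, dtype=np.float32)  # 16256.25, so sqrt(variance) == 127.5


def normalize(x, mean, variance):
    """Assumed Normalization-layer arithmetic, broadcast over the last (channel) axis."""
    return (x - mean) / np.sqrt(variance)


# Hypothetical 2x2 RGB image with pixel values in [0, 255].
image = np.array(
    [[[0.0, 0.0, 0.0], [255.0, 255.0, 255.0]],
     [[127.5, 63.75, 191.25], [10.0, 20.0, 30.0]]],
    dtype=np.float32,
)

normalized = normalize(image, mean, variance)
reference = image / 127.5 - 1.0  # Xception's mode='tf' scaling
print(np.allclose(normalized, reference))
```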

@swghosh
Copy link
Contributor Author

swghosh commented May 12, 2020

Thanks @fchollet.

For now, I'll update the script to use the Normalization layer and submit a PR. I also hope to get back to you soon about the offset argument for Rescaling.

/cc: @tanzhenyu

> We need preprocessing to be part of the model, as a best practice.
>
> I will look into adding a new offset argument to Rescaling to support this use case.

We can discuss this further. Thanks.

@fchollet
Member

Let's move the discussion to the PR.
