
Add check to '_add_norm' #3820

Merged: 3 commits merged into fastai:master on Oct 20, 2022

Conversation

@marib00 (Contributor) commented Oct 18, 2022

from fastai.vision.all import *

path = untar_data(URLs.MNIST_TINY)

def get_tiny_dataloaders():
    # Image-to-image task: both input and target are 1-channel (greyscale) images
    return DataBlock(
        blocks=(ImageBlock(cls=PILImageBW), ImageBlock(cls=PILImageBW)),
        get_items=get_image_files,
    ).dataloaders(path)

dls = get_tiny_dataloaders()
print(dls.one_batch()[0].shape)  # inputs have 1 channel...
learn = unet_learner(dls, resnet18, loss_func=mse, n_in=1, n_out=1)
print(dls.one_batch()[0].shape)  # now inputs have 3 channels!
learn.lr_find()  # <- will throw an exception

In the code above, unet_learner automatically adds a normalization transform using imagenet_stats. As a result, the dataloader starts producing 3-channel images, even though the first conv layer of the U-Net has only 1 input channel (because of n_in=1).
This results in an exception being thrown as soon as one tries to train. The same happens with vision_learner.
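
In the meantime, a possible workaround (a sketch only, assuming the normalize flag that unet_learner and vision_learner accept; not part of this PR) is to turn off the automatic normalization so the dataloader keeps producing 1-channel batches:

# Sketch of a workaround, continuing the snippet above (not part of this PR):
# normalize=False stops unet_learner from inserting the imagenet_stats Normalize
# transform, so the batches stay 1-channel and match n_in=1.
learn = unet_learner(dls, resnet18, loss_func=mse, n_in=1, n_out=1, normalize=False)
print(dls.one_batch()[0].shape)  # should still have 1 channel
learn.lr_find()                  # should no longer hit the channel mismatch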

I wasn't sure what the best way to address this is; in this PR I just throw a warning and don't add the normalization. Another possibility would be to absorb the normalization statistics into the weights and biases of the first conv layer, but this doesn't go well with zero-padding (it would behave as if you had padded with zeros yourself and then normalized those zeros too).

@marib00 requested a review from jph00 as a code owner on October 18, 2022 19:42
@review-notebook-app: Check out this pull request on ReviewNB to see visual diffs and provide feedback on Jupyter Notebooks.

@marib00 (Contributor, Author) commented Oct 18, 2022

BTW, my first PR ever, so apologies if I've done something silly 😳

@jph00 (Member) commented Oct 19, 2022

Congrats on your first PR! :D You didn't do anything silly at all -- or at least, not anything that I can see. Many thanks for your contribution.

@jph00 (Member) commented Oct 19, 2022

Looks like there's a CI failure - if you have a moment, perhaps you could try running the notebook with the failing test to see what the issue is?

@warner-benjamin (Collaborator) commented Oct 19, 2022

Congrats on the first PR. Since we were discussing this issue on the fastai discord, I thought I would chime in after looking at the PR.

I think the CI error is due to the new code preventing the automatic conversion of single-channel greyscale images to three-channel images when the user doesn't specify n_in=1 to unet_learner or vision_learner. With the default of n_in=3, we want this single-to-three-channel conversion to happen in the dataloader, since the pretrained model will still be expecting three-channel inputs.

I think the solution is to add n_in to '_add_norm' (and perhaps _timm_norm?) and then do a channel-consistency check between n_in and the pretrained stats, raising an error/warning if the two differ.
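
For illustration, a minimal sketch of what that check might look like (assuming the current _add_norm(dls, meta, pretrained) helper and fastai's Normalize.from_stats; the code that actually lands in this PR may differ):

import warnings
from fastai.vision.all import *

def _add_norm(dls, meta, pretrained, n_in=3):
    # Only pretrained models that ship normalization stats need the transform
    if not pretrained: return
    stats = meta.get('stats')
    if stats is None: return
    # Channel-consistency check: imagenet_stats has 3-channel means/stds, so if the
    # learner was built with a different n_in, skip the transform and warn instead
    if n_in != len(stats[0]):
        warnings.warn("Could not add a Normalize transform: n_in does not match the pretrained stats")
        return
    # Don't stack a second Normalize if the dataloaders already have one
    if not dls.after_batch.fs.filter(risinstance(Normalize)):
        dls.add_tfms([Normalize.from_stats(*stats)], 'after_batch')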

@marib00 (Contributor, Author) commented Oct 19, 2022

Right, so I have successfully broken things when trying to fix things... 🤣

I'll go with @warner-benjamin's suggestion and add n_in to _add_norm and _timm_norm (I need to check whether the latter is an issue to begin with).

What should the behaviour be in such a case though? Not applying normalization to the inputs of an ImageNet-pretrained model goes against 'common practice' (for lack of a better phrase), but I'm not sure it's actually a problem - the pixels are between 0 and 1 (so they don't take on crazy values like -1234), and the networks have normalisation layers everywhere...

On the other hand, inserting a 1-channel normalisation transform with stats based on, say, the 1st batch might not be the worst idea. I mean, why are we normalizing with ImageNet stats to begin with? If a pretrained model expects 0-mean, 1-stdev inputs, then using ImageNet stats to normalize, say, x-ray images quite likely doesn't give you what you want. Although I suppose you at least don't have to calculate the stats for your dataset yourself.
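
As a rough sketch of that '1st batch' idea (relying on fastai's Normalize computing mean/std from the first batch at setup time when no stats are passed; illustrative only, not the code in this PR):

# Sketch: let Normalize pick up mean/std from the data itself rather than imagenet_stats.
# Passing Normalize() with no stats makes fastai compute them from the first batch at setup.
dls = DataBlock(
    blocks=(ImageBlock(cls=PILImageBW), ImageBlock(cls=PILImageBW)),
    get_items=get_image_files,
    batch_tfms=[Normalize()],  # stats computed from the 1st batch (1 channel here)
).dataloaders(path)
# normalize=False keeps the learner from stacking imagenet_stats normalization on top
learn = unet_learner(dls, resnet18, loss_func=mse, n_in=1, n_out=1, normalize=False)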

So many questions! 🤯

@jph00 (Member) commented Oct 19, 2022

Since it sounds like we're not really sure what the desired behavior is, perhaps I should close this PR, and we can discuss it on the forums or discord, if that's OK?

Generally it's best to use the same normalisation constants used when training the pretrained model, which is why we use that as the default.
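
(For reference, for torchvision backbones the stats bundled with the pretrained model's metadata are imagenet_stats, so the default corresponds roughly to the following transform; a sketch using fastai's public API, not the code touched by this PR.)

# The Normalize transform that gets inserted by default for ImageNet-pretrained backbones
# (imagenet_stats is a ([mean_r, mean_g, mean_b], [std_r, std_g, std_b]) pair)
norm = Normalize.from_stats(*imagenet_stats)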

@jph00 closed this Oct 19, 2022
@jph00 reopened this Oct 19, 2022
@jph00 (Member) commented Oct 20, 2022

Actually on further consideration I think this PR as is can only be an improvement on the current situation. I wouldn't say I fully understand all the potential impacts though. What do you both think?

@marib00 (Contributor, Author) commented Oct 20, 2022

It definitely is an improvement. I have created a new thread on the forum (https://forums.fast.ai/t/add-check-to-add-norm-3820-pr/101344), as this whole normalisation business is definitely worth discussing. I will also run some experiments comparing normalisation using imagenet_stats vs the actual stats of the dataset being used (particularly when they differ from imagenet_stats, as with MNIST) vs no normalisation at all (with freezing/unfreezing of batch norm layers, which should be able to fix any issues with non-normalised inputs... maybe).

@jph00 (Member) commented Oct 20, 2022

OK cool I'll merge this now then, but feel free to do followup PRs if you have more ideas.

@jph00 merged commit 2ed8cce into fastai:master on Oct 20, 2022
@marib00 deleted the fix-learner-normalization-1ch branch on October 20, 2022 07:36