✨ Add PyTorch image classification example #13134
Nice!! Relevant for #13080
Thanks a lot for adding this! I left a few comments.
I'll review this PR in detail (thanks for working on this!). Regarding the fixtures for the tests, I've recently moved these files to the hf-internal-testing organization on the hub. This makes things clearer, as otherwise these fixture files are also downloaded when people do a …
run_image_classification.py
--output_dir {tmp_dir}
--model_name_or_path google/vit-base-patch16-224-in21k
--train_dir tests/fixtures/tests_samples/cats_and_dogs/
Can we add the cats and dogs as a dataset to the hub under the hf-internal-testing organization?
We can, but this is actually testing the fact that the script works on local image folders. I can issue another PR that directly pulls them down into an imagefolder-like cache dir to test. But I'll leave this as-is for now if it's not a big deal.
"value if set."
},
)
image_size: Optional[int] = field(default=224, metadata={"help": " The size (resolution) of each image."})
Will images be squared? Or is it just the smaller edge of the image that will be matched to this number (which torchvision's Resize does if you only provide an integer)?
You're correct, this value is passed directly to torchvision's Resize.
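For readers following this thread, here is a minimal sketch of the sizing rule being discussed: when Resize receives a single integer, the smaller edge is matched to that number and the aspect ratio is kept, so the result is generally not square. The helper name `resized_shape` is hypothetical and only mirrors the documented rule in plain Python; it is not torchvision's own code, and the truncating rounding is an assumption.

```python
def resized_shape(height: int, width: int, size: int) -> tuple[int, int]:
    """Approximate output (height, width) when resizing with a single int.

    Hypothetical helper: the smaller edge becomes `size`, the longer edge is
    scaled by the same factor (truncated to an int, which is an assumption).
    """
    if height <= width:
        # Height is the smaller edge, so it is matched to `size`.
        return size, int(size * width / height)
    # Width is the smaller edge, so it is matched to `size`.
    return int(size * height / width), size


if __name__ == "__main__":
    # A landscape image stays landscape rather than becoming square.
    print(resized_shape(480, 640, 224))
    # An already-square image comes out square.
    print(resized_shape(300, 300, 224))
```

If square outputs are wanted, passing an explicit (height, width) pair to Resize, or following the resize with a center crop, are the usual options.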
train_dir: Optional[str] = field(default=None, metadata={"help": "A folder containing the training data."})
validation_dir: Optional[str] = field(default=None, metadata={"help": "A folder containing the validation data."})
train_val_split: Optional[float] = field(
default=0.15, metadata={"help": "Percent to split off of train for validation."}
)
Hmm so you have a training dataset, a validation dataset and a test set. The validation_dir is actually the test set? And it only makes sense to add a train_val_split if you don't provide a validation dataset yourself?
There's no test set here. There's just train + validation. If the 'validation' key is not found in the dataset, we split a portion off of train and assign it to the 'validation' key of the dataset dict.
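The fallback described in this reply can be sketched roughly as follows. This is an illustration in plain Python rather than the code from the PR: the function name `ensure_validation_split`, the use of a plain dict of lists, and the shuffling with a fixed seed are all assumptions made for the example.

```python
import random


def ensure_validation_split(dataset: dict, train_val_split: float = 0.15, seed: int = 0) -> dict:
    """Return a dict that always has a 'validation' key.

    Hypothetical helper: if 'validation' is missing, a fraction of 'train'
    (train_val_split) is split off and stored under 'validation'.
    """
    if "validation" in dataset:
        # A validation set was provided, so train_val_split is ignored.
        return dataset

    samples = list(dataset["train"])
    random.Random(seed).shuffle(samples)  # shuffle a copy, leave the input untouched
    n_val = int(len(samples) * train_val_split)
    return {**dataset, "train": samples[n_val:], "validation": samples[:n_val]}


if __name__ == "__main__":
    data = {"train": [f"img_{i}.jpg" for i in range(20)]}
    split = ensure_validation_split(data, train_val_split=0.25)
    print(len(split["train"]), len(split["validation"]))
```

Under this reading, train_val_split only matters when no validation data is supplied, which matches the question raised above.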
Overall LGTM! Just some small comments.
Most importantly, it would be great to remove the fixtures files from this PR in favor of a HuggingFace dataset.
Last nit on my side: can we move the vision folder to be …
Ok, addressed most of the comments. Merging as-is for now. @NielsRogge I did not address these two items, however I can in future PRs (if need be):
What does this PR do?
Adds PyTorch image classification example. For now, it uses torchvision.datasets.ImageFolder to load local image folders (just like the Flax image classification example). In the future, we will switch to using the datasets package's image folder (once it exists).
Marking as draft for now as I'm still working through cleaning up changes I made from this example I wrote earlier that uses datasets instead.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.