Fix TorchVision channel preprocessing #3173

Merged: tgaddair merged 7 commits into master from fix-torchvision-channel-preprocessing on Mar 2, 2023

Collaborator

@geoffreyangus geoffreyangus commented Mar 1, 2023

This PR introduces automatic channel resizing for torchvision models to ensure that images fed into torchvision encoders always have 3 channels. Closes #3170.
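
As a rough illustration of what "always have 3 channels" can mean for a single image tensor, here is a minimal sketch. The helper name `to_three_channels` and the per-case rules (replicate a grayscale channel, drop an alpha channel) are assumptions made for this example; they are not taken from the PR's diff and may differ from what Ludwig actually does.

```python
import torch


def to_three_channels(img: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: coerce a (C, H, W) image tensor to 3 channels.

    Assumed rules (not confirmed by the PR):
      * 1 channel  -> treated as grayscale, replicated to 3 channels
      * 3 channels -> returned unchanged
      * 4 channels -> treated as RGBA, alpha channel dropped
    """
    if img.dim() != 3:
        raise ValueError(f"expected a (C, H, W) tensor, got shape {tuple(img.shape)}")

    num_channels = img.shape[0]
    if num_channels == 3:
        return img
    if num_channels == 1:
        # Repeat the single channel along the channel axis.
        return img.repeat(3, 1, 1)
    if num_channels == 4:
        # Keep the first three channels and discard the fourth (assumed alpha).
        return img[:3]
    raise ValueError(f"unsupported number of channels: {num_channels}")


gray = torch.rand(1, 8, 8)
rgba = torch.rand(4, 8, 8)
print(tuple(to_three_channels(gray).shape))
print(tuple(to_three_channels(rgba).shape))
```

Raising on unexpected channel counts, rather than guessing, keeps silent shape mismatches from reaching the encoder; a real implementation might choose a different policy.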

@w4nderlust
Collaborator

@jimthompson5802 FYI

@github-actions

github-actions bot commented Mar 1, 2023

Unit Test Results

6 files ±0 | 6 suites ±0 | duration 6h 14m 24s (+7m 44s)
 4 006 tests +25:  3 962 passed +25,  44 skipped ±0, 0 failed ±0
12 039 runs  +75: 11 898 passed +75, 141 skipped ±0, 0 failed ±0

Results for commit 55fff6a. ± Comparison against base commit 4f6d6ae.

♻️ This comment has been updated with latest results.

Collaborator

@tgaddair tgaddair left a comment

Very nice work on this!

@tgaddair tgaddair marked this pull request as ready for review March 1, 2023 22:11
@tgaddair tgaddair added bug Something isn't working release-0.7 Needs cherry-pick into 0.7 release branch labels Mar 1, 2023
Collaborator

@justinxzhao justinxzhao left a comment

Nice fast fix! Just one (optional) nit.

self,
metadata: TrainingSetMetadataDict,
torchvision_transform: Optional[torch.nn.Module] = None,
transform_metadata: Optional[Dict[str, Any]] = None,
Collaborator

nit: Consider breaking up `transform_metadata` into separate fields (`transform_height`, `transform_width`, `transform_num_channels`), or use a dataclass for stronger field guarantees and typing:

from dataclasses import dataclass


@dataclass
class ImageTransformMetadata:
    height: int
    width: int
    num_channels: int
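
To make the trade-off concrete, here is a small sketch of how such a dataclass could stand in for a `Dict[str, Any]`. The helper `metadata_from_dict` and the dictionary keys are hypothetical, chosen only for this example; the PR may name or populate these fields differently.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class ImageTransformMetadata:
    height: int
    width: int
    num_channels: int


def metadata_from_dict(raw: Dict[str, Any]) -> ImageTransformMetadata:
    """Hypothetical converter from a loosely typed dict (assumed key names)."""
    # A missing key fails here with KeyError, at construction time,
    # instead of surfacing later wherever the dict would have been read.
    return ImageTransformMetadata(
        height=int(raw["height"]),
        width=int(raw["width"]),
        num_channels=int(raw["num_channels"]),
    )


meta = metadata_from_dict({"height": 224, "width": 224, "num_channels": 3})
print(meta.num_channels)
```

With the dataclass, a misspelled field such as `meta.num_chanels` raises `AttributeError` and is visible to type checkers, whereas a misspelled dict key is only caught when that line runs.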

Collaborator Author

@geoffreyangus geoffreyangus Mar 1, 2023

Nice suggestion. I wasn't sure that TorchScript supported this, but apparently it does! pytorch/pytorch#72901

@tgaddair tgaddair merged commit 98ac8d2 into master Mar 2, 2023
@tgaddair tgaddair deleted the fix-torchvision-channel-preprocessing branch March 2, 2023 04:59
tgaddair pushed a commit that referenced this pull request Mar 2, 2023
tgaddair added a commit that referenced this pull request Mar 2, 2023
Co-authored-by: Geoffrey Angus <geoffrey@predibase.com>

Successfully merging this pull request may close these issues.

Clarification about TorchVision Pretrained Model Encoders usage
4 participants