[Enhance] Random generator refactor #1459
Conversation
if torch_version_geq(1, 10) and "cuda" in str(device):
    pytest.skip("AssertionError: Tensor-likes are not close!")
To investigate.
@edgarriba We need to take flake8 down to < 4.0, check this: PyCQA/flake8#1419
Can we improve the interface in this PR to reuse parameters between calls? I think that's something users continuously ask for.
@edgarriba Please approve and merge. Tests failed due to flake8 bugs.
What's the exact issue with flake8 here? We can pin the version eventually, but it's better to find the root cause.
This is the issue. It will probably be fixed in the next flake8 version.
@shijianjian then downgrade the flake8 version
A few things:
- for strings we should stick to double quotes `""`
- coverage drops quite a bit in this PR
- also try to reduce the DeepSource errors (there are too many)
- it would be great if the random samplers could be customized
To apply the exact augmentation again, you may take advantage of the previous parameter state:
>>> input = torch.randn(1, 3, 32, 32)
>>> aug = RandomPerspective(0.5, p=1.)
>>> (aug(input) == aug(input, params=aug._params)).all()
How is it working right now? Are the parameters generated before knowing the input shape?
No, it requires inputs of the same shape. But this code decouples that logic, which allows easier refactoring later.
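The parameter-replay idea above can be sketched in pure Python. This is a hypothetical minimal class (the names `RandomShift`, `generate_parameters`, and the plain-float input are illustrative, not Kornia's actual API): parameters are sampled on the first call and can be replayed by passing them back via `params`.

```python
import random

class RandomShift:
    """Hypothetical sketch of the parameter-reuse pattern."""

    def __init__(self, max_shift):
        self.max_shift = max_shift
        self._params = None

    def generate_parameters(self):
        # Sample a fresh random shift each time this is called.
        return {"shift": random.uniform(-self.max_shift, self.max_shift)}

    def __call__(self, x, params=None):
        if params is None:
            params = self.generate_parameters()
        self._params = params  # store so the caller can replay the transform
        return x + params["shift"]

aug = RandomShift(0.5)
out1 = aug(1.0)                      # samples new parameters
out2 = aug(1.0, params=aug._params)  # replays the exact same parameters
assert out1 == out2
```

Because sampling is isolated in `generate_parameters`, the transform itself is deterministic given `params`, which is what makes replaying an augmentation on same-shaped inputs possible.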
else:
    raise TypeError(f"Unsupported type: {type(self.kernel_size)}")

self.angle_sampler = Uniform(angle[0], angle[1], validate_args=False)
Is it possible to customise the different samplers? Not only here.
Yes, you can override the `make_samplers` function to do so.
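A rough pure-Python sketch of that override hook (class and method names here are illustrative stand-ins, not Kornia's exact classes): the base class builds all its samplers in one overridable method, so a subclass can swap in a different distribution.

```python
import random

class BaseAugmentation:
    """Hypothetical base class: all samplers are built in make_samplers."""

    def __init__(self):
        self.make_samplers()

    def make_samplers(self):
        # Default sampler: uniform angle in [-30, 30] degrees.
        self.angle_sampler = lambda: random.uniform(-30.0, 30.0)

    def sample(self):
        return self.angle_sampler()

class GaussianAngleAugmentation(BaseAugmentation):
    def make_samplers(self):
        # Override: draw the angle from a Gaussian instead of a uniform.
        self.angle_sampler = lambda: random.gauss(0.0, 10.0)

base_angle = BaseAugmentation().sample()
custom_angle = GaussianAngleAugmentation().sample()
```

Keeping sampler construction in a single hook means customization never requires touching the forward/apply logic.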
Worth mentioning that in the docs or adding an example (or in a later PR).
This PR only refactors the whole code base and allows easier customization later. Some ideas remain as backlog:
Changes
This PR refactored the random generator module, mainly by moving all parameter validation and distribution construction into `__init__`, which reduces unnecessary object re-initialization.

Benchmark
Here we benchmark Torchvision's default CPU augmentations against Kornia's GPU augmentations on a Google Colab K80 GPU across different batch sizes (unit = ms; `bs` means batch size). The results record the mean per-sample augmentation speed; Kornia achieves better results at larger batch sizes.