SimpleUnet bug fixes#496
Conversation
max_pool, preventing NaN propagation into BatchNorm - Fix 1x1 ConvolutionPlan source/target grids in SimpleUNetDown and SimpleUNetUp to match the grid the data actually resides on - Remove unused imports Fix non-contiguous grad_output crash in convolution_plan.py by calling .contiguous() before passing to the C++ transpose convolution backward kernel. Add tests/unit/test_simple_unet.py with smoke tests covering forward pass, backward pass, batched grids, single-layer UNet, and reset_parameters. Signed-off-by: Jonathan Swartz <jonathan@jswartz.info>
There was a problem hiding this comment.
Pull request overview
This PR fixes several bugs in the SimpleUNet implementation and adds test coverage for it. The bugs include NaN propagation via -inf values post max-pooling entering BatchNorm, incorrect ConvolutionPlan source/target grid assignments in SimpleUNetDown and SimpleUNetUp, and a crash on non-contiguous grad_output in the convolution backward kernel. It also removes unused imports.
Changes:
fvdb/convolution_plan.py: Added.contiguous()call ongrad_outputbefore passing it to both the regular and transposed C++ convolution backward kernels, fixing a crash on non-contiguous gradient tensors.fvdb/nn/simple_unet.py: ReplacedGridBatch.max_poolwithfvnn.MaxPoolinSimpleUNetDown(which includes-infcleanup), fixedConvolutionPlansource/target grid arguments in bothSimpleUNetDownandSimpleUNetUpto match the grid where data actually resides after pooling, and removed several unused imports.tests/unit/test_simple_unet.py: New test file with smoke tests covering forward pass, backward pass, batched grids, single-layer UNet, andreset_parameters.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
fvdb/convolution_plan.py |
Adds .contiguous() to grad_output in the backward pass to fix crashes with non-contiguous gradient tensors |
fvdb/nn/simple_unet.py |
Uses fvnn.MaxPool for -inf cleanup, fixes ConvolutionPlan grid arguments, and removes unused imports |
tests/unit/test_simple_unet.py |
New smoke/integration tests for SimpleUNet covering key behaviors |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
You can also share your feedback on Copilot code review. Take the survey.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Jonathan Swartz <jonathan@jswartz.info>
harrism
left a comment
There was a problem hiding this comment.
Nice to have this fixed and tested.
SimpleUNetDown, preventing NaN propagation into BatchNormAdd tests/unit/test_simple_unet.py with smoke tests covering forward pass, backward pass, batched grids, single-layer UNet, and reset_parameters.