Skip to content

Commit

Permalink
Added antialias flag to interpolate (CUDA, bilinear and bicubic) (#70…
Browse files Browse the repository at this point in the history
…930)

Summary:
Description:
- Added antialias flag to interpolate (CUDA)
  - forward and backward for bicubic mode
  - added tests

Previous PR for CPU bilinear, pytorch/pytorch#65142
Previous PR for CPU bicubic, pytorch/pytorch#68819

### Benchmarks

<details>
<summary>
Bilinear forward pass, PIL, PTH CPU and PTH CUDA
</summary>

Code: https://gist.github.com/vfdev-5/b173761a567f2283b3c649c3c0574112

```

Torch version: 1.11.0a0+gitd032369
Torch config: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201402
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - CPU capability usage: AVX2
  - CUDA Runtime 11.1
  - NVCC architecture flags: -gencode;arch=compute_61,code=sm_61
  - CuDNN 8.0.5
  - Build settings: BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_PYTORCH_QNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.11.0, USE_CUDA=1, USE_CUDNN=1, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=0, USE_OPENMP=ON, USE_ROCM=OFF,

Num threads: 8
[----------------------------------- Downsampling (bilinear): torch.Size([1, 3, 906, 438]) -> (320, 196) -----------------------------------]
                                                  |  Reference, PIL 8.4.0, mode: RGB  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ----------------------------------------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |               2851.2              |            874.1          |            57.1
      channels_last non-contiguous torch.float32  |               2856.1              |           1155.8          |           130.6

Times are in microseconds (us).

[----------------------------------- Downsampling (bilinear): torch.Size([1, 3, 906, 438]) -> (460, 220) -----------------------------------]
                                                  |  Reference, PIL 8.4.0, mode: RGB  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ----------------------------------------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |               3705.9              |           1005.8          |            66.3
      channels_last non-contiguous torch.float32  |               3742.9              |           1332.8          |           143.5

Times are in microseconds (us).

[------------------------------------ Downsampling (bilinear): torch.Size([1, 3, 906, 438]) -> (120, 96) -----------------------------------]
                                                  |  Reference, PIL 8.4.0, mode: RGB  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ----------------------------------------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |               1768.0              |           725.2           |            77.9
      channels_last non-contiguous torch.float32  |               1753.7              |           942.5           |           144.0

Times are in microseconds (us).

[----------------------------------- Downsampling (bilinear): torch.Size([1, 3, 906, 438]) -> (1200, 196) ----------------------------------]
                                                  |  Reference, PIL 8.4.0, mode: RGB  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ----------------------------------------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |               9522.6              |           2593.8          |           157.8
      channels_last non-contiguous torch.float32  |               9513.5              |           3622.7          |           241.5

Times are in microseconds (us).

[----------------------------------- Downsampling (bilinear): torch.Size([1, 3, 906, 438]) -> (120, 1200) ----------------------------------]
                                                  |  Reference, PIL 8.4.0, mode: RGB  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ----------------------------------------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |               2240.1              |           565.5           |            93.3
      channels_last non-contiguous torch.float32  |               2244.2              |           972.7           |           170.8

Times are in microseconds (us).

[------------------------- Downsampling (bilinear): torch.Size([1, 1, 906, 438]) -> (320, 196) --------------------------]
                                 |  Reference, PIL 8.4.0, mode: F  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ---------------------------------------------------------------------------------------------------------------
       contiguous torch.float32  |              1441.3             |           386.1           |            22.3

Times are in microseconds (us).

[------------------------- Downsampling (bilinear): torch.Size([1, 1, 906, 438]) -> (460, 220) --------------------------]
                                 |  Reference, PIL 8.4.0, mode: F  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ---------------------------------------------------------------------------------------------------------------
       contiguous torch.float32  |              1815.2             |           376.8           |            27.8

Times are in microseconds (us).

[-------------------------- Downsampling (bilinear): torch.Size([1, 1, 906, 438]) -> (120, 96) --------------------------]
                                 |  Reference, PIL 8.4.0, mode: F  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ---------------------------------------------------------------------------------------------------------------
       contiguous torch.float32  |              962.3              |           400.0           |            29.4

Times are in microseconds (us).

[------------------------- Downsampling (bilinear): torch.Size([1, 1, 906, 438]) -> (1200, 196) -------------------------]
                                 |  Reference, PIL 8.4.0, mode: F  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ---------------------------------------------------------------------------------------------------------------
       contiguous torch.float32  |              4749.7             |           910.1           |            63.7

Times are in microseconds (us).

[------------------------- Downsampling (bilinear): torch.Size([1, 1, 906, 438]) -> (120, 1200) -------------------------]
                                 |  Reference, PIL 8.4.0, mode: F  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ---------------------------------------------------------------------------------------------------------------
       contiguous torch.float32  |              1098.1             |           272.0           |            36.4

Times are in microseconds (us).

```

</details>

<details>
<summary>
Bicubic forward pass, PIL, PTH CPU and PTH CUDA
</summary>

Code: https://gist.github.com/vfdev-5/b173761a567f2283b3c649c3c0574112

```

Torch version: 1.11.0a0+gitd032369
Torch config: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201402
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - CPU capability usage: AVX2
  - CUDA Runtime 11.1
  - NVCC architecture flags: -gencode;arch=compute_61,code=sm_61
  - CuDNN 8.0.5
  - Build settings: BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_PYTORCH_QNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.11.0, USE_CUDA=1, USE_CUDNN=1, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=0, USE_OPENMP=ON, USE_ROCM=OFF,

Num threads: 8
[------------------------------------ Downsampling (bicubic): torch.Size([1, 3, 906, 438]) -> (320, 196) -----------------------------------]
                                                  |  Reference, PIL 8.4.0, mode: RGB  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ----------------------------------------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |               4522.4              |           1406.7          |           170.3
      channels_last non-contiguous torch.float32  |               4530.0              |           1435.4          |           242.2

Times are in microseconds (us).

[------------------------------------ Downsampling (bicubic): torch.Size([1, 3, 906, 438]) -> (460, 220) -----------------------------------]
                                                  |  Reference, PIL 8.4.0, mode: RGB  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ----------------------------------------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |               5726.4              |           1628.6          |           164.0
      channels_last non-contiguous torch.float32  |               5722.6              |           1665.6          |           234.7

Times are in microseconds (us).

[------------------------------------ Downsampling (bicubic): torch.Size([1, 3, 906, 438]) -> (120, 96) ------------------------------------]
                                                  |  Reference, PIL 8.4.0, mode: RGB  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ----------------------------------------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |               2909.1              |           1461.5          |           276.9
      channels_last non-contiguous torch.float32  |               2892.9              |           1458.7          |           345.1

Times are in microseconds (us).

[----------------------------------- Downsampling (bicubic): torch.Size([1, 3, 906, 438]) -> (1200, 196) -----------------------------------]
                                                  |  Reference, PIL 8.4.0, mode: RGB  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ----------------------------------------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |              14699.2              |           4283.9          |           407.1
      channels_last non-contiguous torch.float32  |              14711.3              |           4321.1          |           477.0

Times are in microseconds (us).

[----------------------------------- Downsampling (bicubic): torch.Size([1, 3, 906, 438]) -> (120, 1200) -----------------------------------]
                                                  |  Reference, PIL 8.4.0, mode: RGB  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ----------------------------------------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |               3467.0              |           980.0           |           339.2
      channels_last non-contiguous torch.float32  |               3465.2              |           982.3           |           407.8

Times are in microseconds (us).

[-------------------------- Downsampling (bicubic): torch.Size([1, 1, 906, 438]) -> (320, 196) --------------------------]
                                 |  Reference, PIL 8.4.0, mode: F  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ---------------------------------------------------------------------------------------------------------------
       contiguous torch.float32  |              2396.7             |           877.8           |            68.1

Times are in microseconds (us).

[-------------------------- Downsampling (bicubic): torch.Size([1, 1, 906, 438]) -> (460, 220) --------------------------]
                                 |  Reference, PIL 8.4.0, mode: F  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ---------------------------------------------------------------------------------------------------------------
       contiguous torch.float32  |              3068.2             |           777.3           |            64.7

Times are in microseconds (us).

[-------------------------- Downsampling (bicubic): torch.Size([1, 1, 906, 438]) -> (120, 96) ---------------------------]
                                 |  Reference, PIL 8.4.0, mode: F  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ---------------------------------------------------------------------------------------------------------------
       contiguous torch.float32  |              1540.2             |           829.3           |           100.4

Times are in microseconds (us).

[------------------------- Downsampling (bicubic): torch.Size([1, 1, 906, 438]) -> (1200, 196) --------------------------]
                                 |  Reference, PIL 8.4.0, mode: F  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ---------------------------------------------------------------------------------------------------------------
       contiguous torch.float32  |              7919.5             |           1467.8          |           151.6

Times are in microseconds (us).

[------------------------- Downsampling (bicubic): torch.Size([1, 1, 906, 438]) -> (120, 1200) --------------------------]
                                 |  Reference, PIL 8.4.0, mode: F  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ---------------------------------------------------------------------------------------------------------------
       contiguous torch.float32  |              1695.7             |           631.2           |           117.7

Times are in microseconds (us).

```

</details>

<details>
<summary>
Bilinear backward pass, PTH CPU and PTH CUDA
</summary>

Code: https://gist.github.com/vfdev-5/b173761a567f2283b3c649c3c0574112

```
- Measure only backward op

Torch version: 1.11.0a0+gitd032369
Torch config: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201402
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - CPU capability usage: AVX2
  - CUDA Runtime 11.1
  - NVCC architecture flags: -gencode;arch=compute_61,code=sm_61
  - CuDNN 8.0.5
  - Build settings: BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_PYTORCH_QNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.11.0, USE_CUDA=1, USE_CUDNN=1, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=0, USE_OPENMP=ON, USE_ROCM=OFF,

Num threads: 8
[------------- Downsampling backward (bilinear): torch.Size([1, 3, 906, 438]) -> (320, 196) ------------]
                                                  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ----------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |           4686.8          |           215.7
      channels_last non-contiguous torch.float32  |           5101.1          |           220.5

Times are in microseconds (us).

[------------- Downsampling backward (bilinear): torch.Size([1, 3, 906, 438]) -> (460, 220) ------------]
                                                  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ----------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |           6011.2          |           204.4
      channels_last non-contiguous torch.float32  |           6396.0          |           210.0

Times are in microseconds (us).

[------------- Downsampling backward (bilinear): torch.Size([1, 3, 906, 438]) -> (120, 96) -------------]
                                                  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ----------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |           2035.6          |           250.2
      channels_last non-contiguous torch.float32  |           1589.6          |           252.5

Times are in microseconds (us).

[------------ Downsampling backward (bilinear): torch.Size([1, 3, 906, 438]) -> (1200, 196) ------------]
                                                  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ----------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |          11392.5          |           256.5
      channels_last non-contiguous torch.float32  |          11640.2          |           263.9

Times are in microseconds (us).

[------------ Downsampling backward (bilinear): torch.Size([1, 3, 906, 438]) -> (120, 1200) ------------]
                                                  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ----------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |          11769.6          |           465.9
      channels_last non-contiguous torch.float32  |          12407.0          |           474.4

Times are in microseconds (us).

[---- Downsampling backward (bilinear): torch.Size([1, 1, 906, 438]) -> (320, 196) ----]
                                 |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: -----------------------------------------------------------------------------
       contiguous torch.float32  |           3931.0          |           133.3

Times are in microseconds (us).

[---- Downsampling backward (bilinear): torch.Size([1, 1, 906, 438]) -> (460, 220) ----]
                                 |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: -----------------------------------------------------------------------------
       contiguous torch.float32  |           5594.8          |           133.9

Times are in microseconds (us).

[---- Downsampling backward (bilinear): torch.Size([1, 1, 906, 438]) -> (120, 96) -----]
                                 |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: -----------------------------------------------------------------------------
       contiguous torch.float32  |           1272.6          |           133.0

Times are in microseconds (us).

[--- Downsampling backward (bilinear): torch.Size([1, 1, 906, 438]) -> (1200, 196) ----]
                                 |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: -----------------------------------------------------------------------------
       contiguous torch.float32  |          10618.1          |           134.0

Times are in microseconds (us).

[--- Downsampling backward (bilinear): torch.Size([1, 1, 906, 438]) -> (120, 1200) ----]
                                 |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: -----------------------------------------------------------------------------
       contiguous torch.float32  |          11082.2          |           154.6

Times are in microseconds (us).

```

</details>

<details>
<summary>
Bicubic backward pass, PTH CPU and PTH CUDA
</summary>

Code: https://gist.github.com/vfdev-5/b173761a567f2283b3c649c3c0574112

```
- Measure only backward op

Torch version: 1.11.0a0+gitd032369
Torch config: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201402
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - CPU capability usage: AVX2
  - CUDA Runtime 11.1
  - NVCC architecture flags: -gencode;arch=compute_61,code=sm_61
  - CuDNN 8.0.5
  - Build settings: BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_PYTORCH_QNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.11.0, USE_CUDA=1, USE_CUDNN=1, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=0, USE_OPENMP=ON, USE_ROCM=OFF,

Num threads: 8
[------------- Downsampling backward (bicubic): torch.Size([1, 3, 906, 438]) -> (320, 196) -------------]
                                                  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ----------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |           6791.2          |           618.9
      channels_last non-contiguous torch.float32  |           7125.2          |           622.9

Times are in microseconds (us).

[------------- Downsampling backward (bicubic): torch.Size([1, 3, 906, 438]) -> (460, 220) -------------]
                                                  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ----------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |           8806.2          |           600.3
      channels_last non-contiguous torch.float32  |           9167.6          |           607.5

Times are in microseconds (us).

[-------------- Downsampling backward (bicubic): torch.Size([1, 3, 906, 438]) -> (120, 96) -------------]
                                                  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ----------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |           3683.6          |           693.8
      channels_last non-contiguous torch.float32  |           3617.4          |           695.0

Times are in microseconds (us).

[------------- Downsampling backward (bicubic): torch.Size([1, 3, 906, 438]) -> (1200, 196) ------------]
                                                  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ----------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |          17548.2          |           779.4
      channels_last non-contiguous torch.float32  |          17966.2          |           786.5

Times are in microseconds (us).

[------------- Downsampling backward (bicubic): torch.Size([1, 3, 906, 438]) -> (120, 1200) ------------]
                                                  |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: ----------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |            28.4           |            1.6
      channels_last non-contiguous torch.float32  |            28.4           |            1.6

Times are in milliseconds (ms).

[---- Downsampling backward (bicubic): torch.Size([1, 1, 906, 438]) -> (320, 196) -----]
                                 |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: -----------------------------------------------------------------------------
       contiguous torch.float32  |           6266.1          |           208.5

Times are in microseconds (us).

[---- Downsampling backward (bicubic): torch.Size([1, 1, 906, 438]) -> (460, 220) -----]
                                 |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: -----------------------------------------------------------------------------
       contiguous torch.float32  |           8218.3          |           200.8

Times are in microseconds (us).

[----- Downsampling backward (bicubic): torch.Size([1, 1, 906, 438]) -> (120, 96) -----]
                                 |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: -----------------------------------------------------------------------------
       contiguous torch.float32  |           3458.9          |           231.9

Times are in microseconds (us).

[---- Downsampling backward (bicubic): torch.Size([1, 1, 906, 438]) -> (1200, 196) ----]
                                 |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: -----------------------------------------------------------------------------
       contiguous torch.float32  |          15729.3          |           261.6

Times are in microseconds (us).

[---- Downsampling backward (bicubic): torch.Size([1, 1, 906, 438]) -> (120, 1200) ----]
                                 |  1.11.0a0+gitd032369 cpu  |  1.11.0a0+gitd032369 cuda
8 threads: -----------------------------------------------------------------------------
       contiguous torch.float32  |          26279.8          |           547.0

Times are in microseconds (us).

```

</details>

Code is moved from torchvision: pytorch/vision#4211 and optimized

Pull Request resolved: pytorch/pytorch#70930

Reviewed By: zou3519

Differential Revision: D33817902

Pulled By: jbschlosser

fbshipit-source-id: d63a620f8972ff36b63841f0bc6c820466f58f69
  • Loading branch information
vfdev-5 authored and facebook-github-bot committed Jan 27, 2022
1 parent eacbd6b commit d358cfd
Show file tree
Hide file tree
Showing 4 changed files with 590 additions and 46 deletions.
110 changes: 108 additions & 2 deletions aten/src/ATen/native/cuda/UpSample.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ __device__ __forceinline__ static int nearest_neighbor_exact_compute_source_inde
// input_index = round(index_f32)
// Same as Pillow and Scikit-Image/Scipy ndi.zoom
const int src_index =
min(static_cast<int>(floorf((dst_index + 0.5) * scale)), input_size - 1);
min(static_cast<int>(floorf((dst_index + static_cast<float>(0.5)) * scale)), input_size - 1);
return src_index;
}

Expand All @@ -171,7 +171,7 @@ __device__ __forceinline__ static int nearest_neighbor_exact_bw_compute_source_i
int output_size) {
// Equivalent to Pillow and Scikit-Image/Scipy ndi.zoom
const int src_index =
min(static_cast<int>(ceilf(dst_index * scale - 0.5)), output_size);
min(static_cast<int>(ceilf(dst_index * scale - static_cast<float>(0.5))), output_size);
return src_index;
}

Expand Down Expand Up @@ -255,5 +255,111 @@ __device__ __forceinline__ static accscalar_t cubic_interp1d(
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
}

namespace upsample_antialias {

// taken from
// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
// src/libImaging/Resample.c#L20-L29
struct BilinearFilterFunctor {

template <typename accscalar_t>
__device__ accscalar_t operator()(accscalar_t x) const {
if (x < 0) {
x = -x;
}
if (x < 1) {
return 1 - x;
}
return 0;
}

static const int size = 2;
};

// taken from
// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
// src/libImaging/Resample.c#L46-L62
struct BicubicFilterFunctor {

template <typename accscalar_t>
__device__ accscalar_t operator()(accscalar_t x) const {
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
const accscalar_t a = -0.5;
if (x < 0) {
x = -x;
}
if (x < 1) {
return ((a + 2) * x - (a + 3)) * x * x + 1;
}
if (x < 2) {
return (((x - 5) * x + 8) * x - 4) * a;
}
return 0;
}

static const int size = 4;
};

template <typename accscalar_t>
__device__ __forceinline__ static void _compute_weights_span(
const int i,
const int input_size,
const accscalar_t scale,
const accscalar_t support,
int& xmin,
int& xsize,
accscalar_t& center) {
center = scale * (i + static_cast<accscalar_t>(0.5));
xmin = max(static_cast<int>(center - support + static_cast<accscalar_t>(0.5)), static_cast<int>(0));
xsize = min(static_cast<int>(center + support + static_cast<accscalar_t>(0.5)), input_size) - xmin;
}

template <typename scalar_t, typename accscalar_t, typename interp_filter_t>
__device__ __forceinline__ static void _compute_weights(
scalar_t* wt_ptr,
const accscalar_t scale,
int interp_size,
const interp_filter_t& interp_filter,
accscalar_t xmin_m_center,
int xsize) {

accscalar_t invscale = (scale >= 1.0) ? 1.0 / scale : 1.0;
accscalar_t total_w = 0.0;
int j = 0;
for (j = 0; j < xsize; j++) {
accscalar_t w = interp_filter((j + xmin_m_center + static_cast<accscalar_t>(0.5)) * invscale);
wt_ptr[j] = static_cast<scalar_t>(w);
total_w += w;
}
for (j = 0; j < xsize; j++) {
if (total_w != 0.0) {
wt_ptr[j] /= total_w;
}
}
for (; j < interp_size; j++) {
wt_ptr[j] = static_cast<scalar_t>(0.0);
}
}

template <typename scalar_t, typename accscalar_t>
__device__ __forceinline__ static accscalar_t interpolate_aa_single_dim(
const scalar_t* src,
const scalar_t* weights,
int size) {
scalar_t t = static_cast<accscalar_t>(*src);
scalar_t wts = static_cast<accscalar_t>(weights[0]);
accscalar_t output = t * wts;

int j = 1;
for (; j < size; j++) {
wts = static_cast<accscalar_t>(weights[j]);
t = static_cast<accscalar_t>(*(src + j));
output += t * wts;
}
return output;
}

}

} // namespace native
} // namespace at

0 comments on commit d358cfd

Please sign in to comment.