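This project uses OpenAI Triton to optimize the performance bottleneck of the Dynamic Dilated Pyramid Module, the core module of Hierarchical Dynamic Filtering Network for RGB-D Salient Object Detection (ECCV 2020).

In the conventional approach (see `unfold_impl.py`), applying local window extraction and computation at several dilation rates to a deep feature map (for example, a network that fuses Dilation = 1, 3, 5) typically relies on `nn.Unfold`, which materializes every sliding window of the feature map in memory. On deep features this approach has the following problems: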
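
- Large memory footprint: with Kernel Size = 5, for example, `nn.Unfold` inflates the feature map to $25\times$ its original size in memory. At large batch sizes or slightly higher resolutions, this can easily exceed 10 GB to 20 GB and exhaust the workstation's physical GPU memory.
- Multi-branch data movement: once the individual Dilation branches finish, their results are usually merged with `torch.cat`, which allocates a new tensor and copies every branch output into it, so scarce memory bandwidth is spent reading and rewriting the same data again.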
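
The $25\times$ figure follows directly from the shape of the tensor that `nn.Unfold` returns, `(N, C * k * k, L)`. The sketch below works through the arithmetic in plain Python. The tensor shape (batch 4, 64 channels, 256 x 256, fp32) is an assumed example rather than a setting taken from the project, and it assumes stride 1 with padding chosen so that `L = H * W`.

```python
def feature_bytes(batch, channels, height, width, bytes_per_elem=4):
    """Bytes of the original feature map."""
    return batch * channels * height * width * bytes_per_elem


def unfold_bytes(batch, channels, height, width, kernel_size, bytes_per_elem=4):
    """Bytes of the (N, C*k*k, L) tensor from nn.Unfold, assuming L = H * W."""
    taps = kernel_size * kernel_size      # values copied per window
    windows = height * width              # one window per output pixel
    return batch * channels * taps * windows * bytes_per_elem


MIB = 1024 ** 2
shape = dict(batch=4, channels=64, height=256, width=256)  # assumed example shape

base = feature_bytes(**shape)
unfolded = unfold_bytes(kernel_size=5, **shape)

print(base / MIB)           # 64.0   MiB for the feature map itself
print(unfolded / base)      # 25.0   inflation factor for a 5 x 5 kernel
print(3 * unfolded / MIB)   # 4800.0 MiB if three dilation branches are unfolded
```

These numbers cover only the forward intermediates of a single module; tensors saved for the backward pass and larger batches or resolutions push the total higher.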
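
`triton_impl.py` drops the built-in PyTorch operators and uses a parallel Triton kernel instead. A single thread block handles the three receptive fields Dilation = 1, 3, 5 together, performing the neighbor gathering, coordinate computation, and multiply-accumulate for each of them independently. The resulting pixels at the different depths are then written through low-level strided pointers (`tl.store` with strictly aligned stride offsets) straight to the downstream aggregation operator. This removes the memory reads and writes caused by the 25x-inflated intermediate matrices.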
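
The access pattern described above can be illustrated with a pure-Python reference. This is not the Triton kernel from `triton_impl.py`: it is a minimal single-channel sketch that assumes per-pixel filter taps, zero padding at the borders, and a hypothetical flat weight layout. What it shows is that every dilation branch is computed in one pass over the pixels and written into its own slice of a preallocated output, so neither an unfolded tensor nor a later concatenation is needed.

```python
def fused_dilated_filter(x, weights, height, width, k=3, dilations=(1, 3, 5)):
    """Reference for the fused multi-dilation gather (single channel).

    x:       flat row-major list of height * width input values
    weights: flat list; the tap for branch b at pixel (y, col) is stored at
             ((b * height + y) * width + col) * k * k + tap   (hypothetical layout)
    returns: flat list of len(dilations) * height * width values, where
             branch b owns the slice starting at b * height * width
    """
    r = k // 2
    plane = height * width
    out = [0.0] * (len(dilations) * plane)
    for y in range(height):
        for col in range(width):
            for b, d in enumerate(dilations):        # all branches in one pass
                acc = 0.0
                w_base = ((b * height + y) * width + col) * k * k
                for i in range(k):
                    for j in range(k):
                        yy = y + (i - r) * d         # dilated neighbor row
                        xx = col + (j - r) * d       # dilated neighbor column
                        if 0 <= yy < height and 0 <= xx < width:  # zero padding
                            acc += x[yy * width + xx] * weights[w_base + i * k + j]
                out[b * plane + y * width + col] = acc   # strided write, no concat
    return out


H = W = 11
image = [1.0] * (H * W)
all_ones = [1.0] * (3 * H * W * 9)
result = fused_dilated_filter(image, all_ones, H, W)

center = 5 * W + 5
print([result[b * H * W + center] for b in range(3)])  # [9.0, 9.0, 9.0]
print([result[b * H * W] for b in range(3)])           # [4.0, 4.0, 4.0] at the corner
```

In a GPU kernel the two outer loops would be distributed across program instances, and the final assignment corresponds to a store at `base pointer + branch offset + pixel offset`.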

We also wrote a validation and comparison framework that sweeps every combination of tensor dimensions.
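
As a rough picture of what such a sweep involves, the sketch below enumerates a configuration grid and applies an element-wise tolerance check. The grid values are the ones visible in the benchmark table; the function names and tolerances are hypothetical and are not taken from `bench.py`.

```python
from itertools import product

# Grid values as they appear in the benchmark table (names are hypothetical).
BATCHES = (1, 2, 4)
RESOLUTIONS = (128, 256)
AMP_MODES = ("none", "fp16", "bf16")
IMPLS = ("triton", "unfold")


def sweep():
    """Yield one configuration dict per combination of the grid values."""
    for batch, res, amp, impl in product(BATCHES, RESOLUTIONS, AMP_MODES, IMPLS):
        yield {"batch": batch, "res": res, "amp": amp, "impl": impl}


def is_aligned(reference, candidate, rtol, atol):
    """True when every element satisfies |c - r| <= atol + rtol * |r|."""
    return len(reference) == len(candidate) and all(
        abs(c - r) <= atol + rtol * abs(r) for r, c in zip(reference, candidate)
    )


configs = list(sweep())
print(len(configs))                                                   # 36
print(is_aligned([1.0, 2.0], [1.0005, 2.0], rtol=1e-3, atol=1e-5))    # True
print(is_aligned([1.0, 2.0], [1.1, 2.0], rtol=1e-3, atol=1e-5))       # False
```

Reduced-precision modes such as fp16 and bf16 generally need looser tolerances than fp32 when outputs or gradients are compared this way.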
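
    # Run this automated sweep script to cover all configurations; it also checks the correctness of the backward gradients
    python bench.py --backward

Across the batch, resolution, and precision settings measured below, the Triton implementation runs roughly 1.15x to 2x faster than the PyTorch native `nn.Unfold` implementation, approaching 2x in the heaviest fp32 (AMP none) configurations. The last column gives each row's speed relative to the Triton reference: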

| Batch | Dim | Res | AMP | Impl | Mean Time (ms) | Peak Mem (MB) | Aligned | Speed vs. Triton |
|---|---|---|---|---|---|---|---|---|
| 1 | 64 | 128 | none | triton | 1.07604 | 83.1743 | Ref | 1.00x |
| 1 | 64 | 128 | none | unfold | 1.23796 | 95.1743 | ✓ aligned | 0.87x |
| 1 | 64 | 128 | fp16 | triton | 1.55533 | 153.27 | Ref | 1.00x |
| 1 | 64 | 128 | fp16 | unfold | 2.42349 | 153.27 | ✓ aligned | 0.64x |
| 1 | 64 | 128 | bf16 | triton | 1.80216 | 153.269 | Ref | 1.00x |
| 1 | 64 | 128 | bf16 | unfold | 2.34165 | 153.269 | ✓ aligned | 0.77x |
| 1 | 64 | 256 | none | triton | 2.26978 | 323.425 | Ref | 1.00x |
| 1 | 64 | 256 | none | unfold | 2.80636 | 371.425 | ✓ aligned | 0.81x |
| 1 | 64 | 256 | fp16 | triton | 1.65957 | 265.199 | Ref | 1.00x |
| 1 | 64 | 256 | fp16 | unfold | 2.20246 | 313.199 | ✓ aligned | 0.75x |
| 1 | 64 | 256 | bf16 | triton | 1.6835 | 265.198 | Ref | 1.00x |
| 1 | 64 | 256 | bf16 | unfold | 2.1826 | 313.198 | ✓ aligned | 0.77x |
| 1 | 64 | 128 | none | triton | 1.58221 | 249.556 | Ref | 1.00x |
| 1 | 64 | 128 | none | unfold | 2.67495 | 321.556 | ✓ aligned | 0.59x |
| 1 | 64 | 128 | fp16 | triton | 1.46708 | 303.768 | Ref | 1.00x |
| 1 | 64 | 128 | fp16 | unfold | 2.04344 | 281.768 | ✓ aligned | 0.72x |
| 1 | 64 | 128 | bf16 | triton | 1.42362 | 303.767 | Ref | 1.00x |
| 1 | 64 | 128 | bf16 | unfold | 2.01173 | 281.767 | ✓ aligned | 0.71x |
| 1 | 64 | 256 | none | triton | 6.66185 | 993.556 | Ref | 1.00x |
| 1 | 64 | 256 | none | unfold | 12.026 | 1281.56 | ✓ aligned | 0.55x |
| 1 | 64 | 256 | fp16 | triton | 5.99051 | 1209.77 | Ref | 1.00x |
| 1 | 64 | 256 | fp16 | unfold | 9.73161 | 1121.77 | ✓ aligned | 0.62x |
| 1 | 64 | 256 | bf16 | triton | 5.99144 | 1209.77 | Ref | 1.00x |
| 1 | 64 | 256 | bf16 | unfold | 9.69384 | 1121.77 | ✓ aligned | 0.62x |
| 1 | 64 | 128 | none | triton | 3.66603 | 634.317 | Ref | 1.00x |
| 1 | 64 | 128 | none | unfold | 7.36078 | 834.317 | ✓ aligned | 0.50x |
| 1 | 64 | 128 | fp16 | triton | 3.56193 | 785.732 | Ref | 1.00x |
| 1 | 64 | 128 | fp16 | unfold | 6.0926 | 730.904 | ✓ aligned | 0.58x |
| 1 | 64 | 128 | bf16 | triton | 3.56308 | 785.731 | Ref | 1.00x |
| 1 | 64 | 128 | bf16 | unfold | 6.08826 | 730.903 | ✓ aligned | 0.59x |
| 1 | 64 | 256 | none | triton | 14.4004 | 2531.15 | Ref | 1.00x |
| 1 | 64 | 256 | none | unfold | 29.2564 | 3331.15 | ✓ aligned | 0.49x |
| 1 | 64 | 256 | fp16 | triton | 13.9935 | 3130.9 | Ref | 1.00x |
| 1 | 64 | 256 | fp16 | unfold | 24.1352 | 2914.9 | ✓ aligned | 0.58x |
| 1 | 64 | 256 | bf16 | triton | 13.971 | 3130.9 | Ref | 1.00x |
| 1 | 64 | 256 | bf16 | unfold | 24.1342 | 2914.9 | ✓ aligned | 0.58x |
| 2 | 64 | 128 | none | triton | 1.27239 | 193.175 | Ref | 1.00x |
| 2 | 64 | 128 | none | unfold | 1.73402 | 193.175 | ✓ aligned | 0.73x |
| 2 | 64 | 128 | fp16 | triton | 1.24348 | 177.316 | Ref | 1.00x |
| 2 | 64 | 128 | fp16 | unfold | 2.05371 | 177.316 | ✓ aligned | 0.61x |
| 2 | 64 | 128 | bf16 | triton | 1.69059 | 177.315 | Ref | 1.00x |
| 2 | 64 | 128 | bf16 | unfold | 2.29611 | 177.315 | ✓ aligned | 0.74x |
| 2 | 64 | 256 | none | triton | 5.20902 | 643.425 | Ref | 1.00x |
| 2 | 64 | 256 | none | unfold | 6.81913 | 739.425 | ✓ aligned | 0.76x |
| 2 | 64 | 256 | fp16 | triton | 3.7009 | 529.199 | Ref | 1.00x |
| 2 | 64 | 256 | fp16 | unfold | 4.96051 | 625.199 | ✓ aligned | 0.75x |
| 2 | 64 | 256 | bf16 | triton | 3.78571 | 529.198 | Ref | 1.00x |
| 2 | 64 | 256 | bf16 | unfold | 5.02962 | 625.198 | ✓ aligned | 0.75x |
| 2 | 64 | 128 | none | triton | 3.20137 | 497.556 | Ref | 1.00x |
| 2 | 64 | 128 | none | unfold | 5.74176 | 641.556 | ✓ aligned | 0.56x |
| 2 | 64 | 128 | fp16 | triton | 2.98321 | 605.768 | Ref | 1.00x |
| 2 | 64 | 128 | fp16 | unfold | 4.69186 | 561.768 | ✓ aligned | 0.64x |
| 2 | 64 | 128 | bf16 | triton | 2.93801 | 605.767 | Ref | 1.00x |
| 2 | 64 | 128 | bf16 | unfold | 4.69902 | 561.767 | ✓ aligned | 0.63x |
| 2 | 64 | 256 | none | triton | 12.6582 | 1985.56 | Ref | 1.00x |
| 2 | 64 | 256 | none | unfold | 23.5228 | 2561.56 | ✓ aligned | 0.54x |
| 2 | 64 | 256 | fp16 | triton | 11.7618 | 2417.77 | Ref | 1.00x |
| 2 | 64 | 256 | fp16 | unfold | 19.2712 | 2241.77 | ✓ aligned | 0.61x |
| 2 | 64 | 256 | bf16 | triton | 11.7394 | 2417.77 | Ref | 1.00x |
| 2 | 64 | 256 | bf16 | unfold | 19.2795 | 2241.77 | ✓ aligned | 0.61x |
| 2 | 64 | 128 | none | triton | 7.21617 | 1266.32 | Ref | 1.00x |
| 2 | 64 | 128 | none | unfold | 14.7464 | 1666.32 | ✓ aligned | 0.49x |
| 2 | 64 | 128 | fp16 | triton | 7.06767 | 1566.9 | Ref | 1.00x |
| 2 | 64 | 128 | fp16 | unfold | 12.2673 | 1458.9 | ✓ aligned | 0.58x |
| 2 | 64 | 128 | bf16 | triton | 7.04598 | 1566.9 | Ref | 1.00x |
| 2 | 64 | 128 | bf16 | unfold | 12.274 | 1458.9 | ✓ aligned | 0.57x |
| 2 | 64 | 256 | none | triton | 28.8071 | 5059.15 | Ref | 1.00x |
| 2 | 64 | 256 | none | unfold | 58.382 | 6659.15 | ✓ aligned | 0.49x |
| 2 | 64 | 256 | fp16 | triton | 28.1097 | 6258.9 | Ref | 1.00x |
| 2 | 64 | 256 | fp16 | unfold | 48.2942 | 5826.9 | ✓ aligned | 0.58x |
| 2 | 64 | 256 | bf16 | triton | 28.1075 | 6258.9 | Ref | 1.00x |
| 2 | 64 | 256 | bf16 | unfold | 48.2801 | 5826.9 | ✓ aligned | 0.58x |
| 4 | 64 | 128 | none | triton | 2.33753 | 323.174 | Ref | 1.00x |
| 4 | 64 | 128 | none | unfold | 3.0506 | 371.174 | ✓ aligned | 0.77x |
| 4 | 64 | 128 | fp16 | triton | 1.69218 | 265.199 | Ref | 1.00x |
| 4 | 64 | 128 | fp16 | unfold | 2.29615 | 313.199 | ✓ aligned | 0.74x |
| 4 | 64 | 128 | bf16 | triton | 1.91026 | 265.198 | Ref | 1.00x |
| 4 | 64 | 128 | bf16 | unfold | 2.64404 | 313.198 | ✓ aligned | 0.72x |
| 4 | 64 | 256 | none | triton | 9.80572 | 1283.17 | Ref | 1.00x |
| 4 | 64 | 256 | none | unfold | 13.6664 | 1475.17 | ✓ aligned | 0.72x |
| 4 | 64 | 256 | fp16 | triton | 7.40646 | 1057.2 | Ref | 1.00x |
| 4 | 64 | 256 | fp16 | unfold | 10.523 | 1249.2 | ✓ aligned | 0.70x |
| 4 | 64 | 256 | bf16 | triton | 7.4664 | 1057.2 | Ref | 1.00x |
| 4 | 64 | 256 | bf16 | unfold | 10.5942 | 1249.2 | ✓ aligned | 0.70x |
| 4 | 64 | 128 | none | triton | 6.41024 | 993.556 | Ref | 1.00x |
| 4 | 64 | 128 | none | unfold | 11.8664 | 1281.56 | ✓ aligned | 0.54x |
| 4 | 64 | 128 | fp16 | triton | 5.84712 | 1209.77 | Ref | 1.00x |
| 4 | 64 | 128 | fp16 | unfold | 9.63093 | 1121.77 | ✓ aligned | 0.61x |
| 4 | 64 | 128 | bf16 | triton | 5.85482 | 1209.77 | Ref | 1.00x |
| 4 | 64 | 128 | bf16 | unfold | 9.62217 | 1121.77 | ✓ aligned | 0.61x |
| 4 | 64 | 256 | none | triton | 25.4249 | 3969.56 | Ref | 1.00x |
| 4 | 64 | 256 | none | unfold | 47.263 | 5121.56 | ✓ aligned | 0.54x |
| 4 | 64 | 256 | fp16 | triton | 23.5271 | 4833.77 | Ref | 1.00x |
| 4 | 64 | 256 | fp16 | unfold | 38.714 | 4481.77 | ✓ aligned | 0.61x |
| 4 | 64 | 256 | bf16 | triton | 23.5103 | 4833.77 | Ref | 1.00x |
| 4 | 64 | 256 | bf16 | unfold | 38.6848 | 4481.77 | ✓ aligned | 0.61x |
| 4 | 64 | 128 | none | triton | 14.6556 | 2530.32 | Ref | 1.00x |
| 4 | 64 | 128 | none | unfold | 29.7304 | 3330.32 | ✓ aligned | 0.49x |
| 4 | 64 | 128 | fp16 | triton | 14.1473 | 3130.9 | Ref | 1.00x |
| 4 | 64 | 128 | fp16 | unfold | 24.5036 | 2914.9 | ✓ aligned | 0.58x |
| 4 | 64 | 128 | bf16 | triton | 14.0807 | 3131.73 | Ref | 1.00x |
| 4 | 64 | 128 | bf16 | unfold | 24.4947 | 2914.9 | ✓ aligned | 0.57x |
| 4 | 64 | 256 | none | triton | 59.5633 | 10114.3 | Ref | 1.00x |
| 4 | 64 | 256 | none | unfold | 118.181 | 13314.3 | ✓ aligned | 0.50x |
| 4 | 64 | 256 | fp16 | triton | 57.2382 | 12514.9 | Ref | 1.00x |
| 4 | 64 | 256 | fp16 | unfold | 97.0677 | 11650.9 | ✓ aligned | 0.59x |
| 4 | 64 | 256 | bf16 | triton | 57.1962 | 12514.9 | Ref | 1.00x |
| 4 | 64 | 256 | bf16 | unfold | 97.0336 | 11651.7 | ✓ aligned | 0.59x |
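
The last column of each row appears to be the Triton mean time divided by that row's mean time, which is consistent with the reported values. The sketch below recomputes it for two row pairs copied from the table; the helper name and the pair labels are illustrative only.

```python
def relative_speed(reference_ms, other_ms):
    """Speed of an implementation relative to the reference (1.0 = same speed)."""
    return reference_ms / other_ms


# (triton_ms, unfold_ms) pairs copied from the table above.
pairs = {
    "first row pair (batch 1, AMP none)": (1.07604, 1.23796),
    "last fp32 pair (batch 4, AMP none)": (59.5633, 118.181),
}

for label, (triton_ms, unfold_ms) in pairs.items():
    ratio = relative_speed(triton_ms, unfold_ms)   # value shown in the table
    speedup = unfold_ms / triton_ms                # how much faster Triton is
    print(f"{label}: {ratio:.2f}x relative speed, {speedup:.2f}x Triton speedup")
# first row pair (batch 1, AMP none): 0.87x relative speed, 1.15x Triton speedup
# last fp32 pair (batch 4, AMP none): 0.50x relative speed, 1.98x Triton speedup
```

Reading the column this way, a value of 0.50x for `unfold` means that the Triton implementation finished in about half the time.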

    @inproceedings{HDFNet-ECCV2020,
      author = {Youwei Pang and Lihe Zhang and Xiaoqi Zhao and Huchuan Lu},
      title = {Hierarchical Dynamic Filtering Network for RGB-D Salient Object Detection},
      booktitle = {ECCV},
      year = {2020}
    }