Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fx/tuning] tune performance on rotor with meta info. #1599

Merged
merged 4 commits into from Sep 15, 2022

Conversation

super-dainiu
Copy link
Contributor

@super-dainiu super-dainiu commented Sep 15, 2022

What's new?

  1. Change the backward memory estimation. (dependency calculations, free forward dependency)
  2. Remove decay in the memory budget.
  3. Fix error calculations in Normalizations.
  4. Remove Phase.LOSS in op level estimation. Use torch.autograd.backward() instead.

Tests

Resnet

Model mem_limit real_consumption train step time solver time
<function resnet18 at 0x7fc091b074c0> mem_limit: None real memory consumption: 2818.724 MB train step time: 111.869 MS  
<function resnet18 at 0x7fc091b074c0> mem_limit: 900.0 MB real memory consumption: 900.000 MB train step time: inf MS solver time: 0.000 MS
<function resnet18 at 0x7fc091b074c0> mem_limit: 1140.0 MB real memory consumption: 1140.000 MB train step time: inf MS solver time: 0.000 MS
<function resnet18 at 0x7fc091b074c0> mem_limit: 1380.0 MB real memory consumption: 1380.000 MB train step time: inf MS solver time: 0.000 MS
<function resnet18 at 0x7fc091b074c0> mem_limit: 1620.0 MB real memory consumption: 1620.000 MB train step time: inf MS solver time: 0.000 MS
<function resnet18 at 0x7fc091b074c0> mem_limit: 1860.0 MB real memory consumption: 1838.223 MB train step time: 135.848 MS solver time: 305.186 MS
<function resnet18 at 0x7fc091b074c0> mem_limit: 2100.0 MB real memory consumption: 1838.223 MB train step time: 134.772 MS solver time: 315.092 MS
<function resnet18 at 0x7fc091b074c0> mem_limit: 2340.0 MB real memory consumption: 2230.223 MB train step time: 129.707 MS solver time: 312.411 MS
<function resnet18 at 0x7fc091b074c0> mem_limit: 2580.0 MB real memory consumption: 2230.223 MB train step time: 135.638 MS solver time: 341.167 MS
<function resnet18 at 0x7fc091b074c0> mem_limit: 2820.0 MB real memory consumption: 2621.724 MB train step time: 128.655 MS solver time: 323.381 MS
<function resnet18 at 0x7fc091b074c0> mem_limit: 3060.0 MB real memory consumption: 2817.724 MB train step time: 114.653 MS solver time: 312.263 MS
<function resnet18 at 0x7fc091b074c0> mem_limit: 3300.0 MB real memory consumption: 2817.724 MB train step time: 128.851 MS solver time: 310.696 MS
<function resnet34 at 0x7fd473b0b700> mem_limit: None real memory consumption: 4083.226 MB train step time: 191.019 MS  
<function resnet34 at 0x7fd473b0b700> mem_limit: 1300.0 MB real memory consumption: 1300.000 MB train step time: inf MS solver time: 0.000 MS
<function resnet34 at 0x7fd473b0b700> mem_limit: 1650.0 MB real memory consumption: 1650.000 MB train step time: inf MS solver time: 0.000 MS
<function resnet34 at 0x7fd473b0b700> mem_limit: 2000.0 MB real memory consumption: 1924.219 MB train step time: 241.586 MS solver time: 880.371 MS
<function resnet34 at 0x7fd473b0b700> mem_limit: 2350.0 MB real memory consumption: 2718.221 MB train step time: 229.666 MS solver time: 888.187 MS
<function resnet34 at 0x7fd473b0b700> mem_limit: 2700.0 MB real memory consumption: 2672.223 MB train step time: 218.032 MS solver time: 941.039 MS
<function resnet34 at 0x7fd473b0b700> mem_limit: 3050.0 MB real memory consumption: 2910.722 MB train step time: 212.783 MS solver time: 887.642 MS
<function resnet34 at 0x7fd473b0b700> mem_limit: 3400.0 MB real memory consumption: 3110.725 MB train step time: 213.888 MS solver time: 896.155 MS
<function resnet34 at 0x7fd473b0b700> mem_limit: 3750.0 MB real memory consumption: 3502.725 MB train step time: 216.258 MS solver time: 908.215 MS
<function resnet34 at 0x7fd473b0b700> mem_limit: 4100.0 MB real memory consumption: 3894.226 MB train step time: 204.810 MS solver time: 938.413 MS
<function resnet34 at 0x7fd473b0b700> mem_limit: 4450.0 MB real memory consumption: 4090.226 MB train step time: 199.014 MS solver time: 906.662 MS
<function resnet34 at 0x7fd473b0b700> mem_limit: 4800.0 MB real memory consumption: 4090.226 MB train step time: 199.648 MS solver time: 898.411 MS
<function resnet50 at 0x7fd473b0b8b0> mem_limit: None real memory consumption: 10707.678 MB train step time: 386.905 MS  
<function resnet50 at 0x7fd473b0b8b0> mem_limit: 3500.0 MB real memory consumption: 3454.336 MB train step time: 516.850 MS solver time: 961.755 MS
<function resnet50 at 0x7fd473b0b8b0> mem_limit: 4430.0 MB real memory consumption: 4323.330 MB train step time: 465.461 MS solver time: 1002.539 MS
<function resnet50 at 0x7fd473b0b8b0> mem_limit: 5360.0 MB real memory consumption: 5304.341 MB train step time: 453.650 MS solver time: 966.002 MS
<function resnet50 at 0x7fd473b0b8b0> mem_limit: 6290.0 MB real memory consumption: 6185.351 MB train step time: 440.260 MS solver time: 976.177 MS
<function resnet50 at 0x7fd473b0b8b0> mem_limit: 7220.0 MB real memory consumption: 6971.352 MB train step time: 445.918 MS solver time: 1002.084 MS
<function resnet50 at 0x7fd473b0b8b0> mem_limit: 8150.0 MB real memory consumption: 7755.355 MB train step time: 426.824 MS solver time: 1005.916 MS
<function resnet50 at 0x7fd473b0b8b0> mem_limit: 9080.0 MB real memory consumption: 8833.359 MB train step time: 405.072 MS solver time: 1016.602 MS
<function resnet50 at 0x7fd473b0b8b0> mem_limit: 10010.0 MB real memory consumption: 9715.365 MB train step time: 390.732 MS solver time: 1025.270 MS
<function resnet50 at 0x7fd473b0b8b0> mem_limit: 10940.0 MB real memory consumption: 10499.366 MB train step time: 373.902 MS solver time: 1012.945 MS
<function resnet50 at 0x7fd473b0b8b0> mem_limit: 11870.0 MB real memory consumption: 10695.366 MB train step time: 382.681 MS solver time: 1015.283 MS
<function resnet50 at 0x7fd473b0b8b0> mem_limit: 12800.0 MB real memory consumption: 10695.366 MB train step time: 361.853 MS solver time: 991.407 MS
<function resnet101 at 0x7fd473b0ba60> mem_limit: None real memory consumption: 15769.488 MB train step time: 604.583 MS  
<function resnet101 at 0x7fd473b0ba60> mem_limit: 5200.0 MB real memory consumption: 5335.419 MB train step time: 778.706 MS solver time: 4149.994 MS
<function resnet101 at 0x7fd473b0ba60> mem_limit: 6570.0 MB real memory consumption: 6509.675 MB train step time: 722.648 MS solver time: 4199.405 MS
<function resnet101 at 0x7fd473b0ba60> mem_limit: 7940.0 MB real memory consumption: 7777.489 MB train step time: 716.253 MS solver time: 4198.619 MS
<function resnet101 at 0x7fd473b0ba60> mem_limit: 9310.0 MB real memory consumption: 8806.515 MB train step time: 690.975 MS solver time: 4235.528 MS
<function resnet101 at 0x7fd473b0ba60> mem_limit: 10680.0 MB real memory consumption: 10374.526 MB train step time: 683.163 MS solver time: 4239.577 MS
<function resnet101 at 0x7fd473b0ba60> mem_limit: 12050.0 MB real memory consumption: 11550.536 MB train step time: 640.606 MS solver time: 4179.048 MS
<function resnet101 at 0x7fd473b0ba60> mem_limit: 13420.0 MB real memory consumption: 12432.540 MB train step time: 643.900 MS solver time: 4248.472 MS
<function resnet101 at 0x7fd473b0ba60> mem_limit: 14790.0 MB real memory consumption: 13904.544 MB train step time: 619.255 MS solver time: 4331.583 MS
<function resnet101 at 0x7fd473b0ba60> mem_limit: 16160.0 MB real memory consumption: 15176.550 MB train step time: 620.196 MS solver time: 4342.992 MS
<function resnet101 at 0x7fd473b0ba60> mem_limit: 17530.0 MB real memory consumption: 15766.551 MB train step time: 608.358 MS solver time: 4232.965 MS
<function resnet101 at 0x7fd473b0ba60> mem_limit: 18900.0 MB real memory consumption: 15766.551 MB train step time: 594.879 MS solver time: 4333.046 MS
<function resnet152 at 0x7fd473b0bc10> mem_limit: None real memory consumption: 22002.603 MB train step time: 834.303 MS  
<function resnet152 at 0x7fd473b0bc10> mem_limit: 7300.0 MB real memory consumption: 7807.728 MB train step time: 1062.401 MS solver time: 11508.363 MS
<function resnet152 at 0x7fd473b0bc10> mem_limit: 9210.0 MB real memory consumption: 8817.452 MB train step time: 1030.236 MS solver time: 11649.769 MS
<function resnet152 at 0x7fd473b0bc10> mem_limit: 11120.0 MB real memory consumption: 10773.290 MB train step time: 993.877 MS solver time: 13641.920 MS
<function resnet152 at 0x7fd473b0bc10> mem_limit: 13030.0 MB real memory consumption: 12394.288 MB train step time: 975.113 MS solver time: 11768.664 MS
<function resnet152 at 0x7fd473b0bc10> mem_limit: 14940.0 MB real memory consumption: 14747.311 MB train step time: 955.909 MS solver time: 11839.893 MS
<function resnet152 at 0x7fd473b0bc10> mem_limit: 16850.0 MB real memory consumption: 16218.329 MB train step time: 925.745 MS solver time: 11916.446 MS
<function resnet152 at 0x7fd473b0bc10> mem_limit: 18760.0 MB real memory consumption: 18276.339 MB train step time: 898.198 MS solver time: 11576.954 MS
<function resnet152 at 0x7fd473b0bc10> mem_limit: 20670.0 MB real memory consumption: 19060.342 MB train step time: 874.035 MS solver time: 11497.850 MS
<function resnet152 at 0x7fd473b0bc10> mem_limit: 22580.0 MB real memory consumption: 21020.352 MB train step time: 845.534 MS solver time: 11645.797 MS
<function resnet152 at 0x7fd473b0bc10> mem_limit: 24490.0 MB real memory consumption: 22000.353 MB train step time: 780.503 MS solver time: 11448.699 MS
<function resnet152 at 0x7fd473b0bc10> mem_limit: 26400.0 MB real memory consumption: 22000.353 MB train step time: 833.828 MS solver time: 11561.679 MS

Densenet

Model mem_limit real_consumption train step time solver time
<function densenet121 at 0x7f2bc32e79d0> mem_limit: None real memory consumption: 16027.334 MB train step time: 171.072 MS
<function densenet121 at 0x7f2bc32e79d0> mem_limit: 5300.0 MB real memory consumption: 5300.000 MB train step time: inf MS solver time: 0.000 MS
<function densenet121 at 0x7f2bc32e79d0> mem_limit: 6690.0 MB real memory consumption: 6313.646 MB train step time: 208.779 MS solver time: 1582.991 MS
<function densenet121 at 0x7f2bc32e79d0> mem_limit: 8080.0 MB real memory consumption: 7292.647 MB train step time: 204.590 MS solver time: 1605.730 MS
<function densenet121 at 0x7f2bc32e79d0> mem_limit: 9470.0 MB real memory consumption: 9500.286 MB train step time: 195.120 MS solver time: 1625.514 MS
<function densenet121 at 0x7f2bc32e79d0> mem_limit: 10860.0 MB real memory consumption: 10728.410 MB train step time: 190.821 MS solver time: 1571.043 MS
<function densenet121 at 0x7f2bc32e79d0> mem_limit: 12250.0 MB real memory consumption: 11897.869 MB train step time: 186.108 MS solver time: 1640.221 MS
<function densenet121 at 0x7f2bc32e79d0> mem_limit: 13640.0 MB real memory consumption: 12093.869 MB train step time: 185.596 MS solver time: 1590.731 MS
<function densenet121 at 0x7f2bc32e79d0> mem_limit: 15030.0 MB real memory consumption: 14851.113 MB train step time: 175.736 MS solver time: 1633.711 MS
<function densenet121 at 0x7f2bc32e79d0> mem_limit: 16420.0 MB real memory consumption: 16027.334 MB train step time: 171.635 MS solver time: 1585.685 MS
<function densenet121 at 0x7f2bc32e79d0> mem_limit: 17810.0 MB real memory consumption: 16027.334 MB train step time: 171.443 MS solver time: 1645.066 MS
<function densenet121 at 0x7fd473b8ea60> mem_limit: 19200.0 MB real memory consumption: 16027.334 MB train step time: 171.485 MS solver time: 1604.234 MS
<function densenet161 at 0x7fca84e98c10> mem_limit: None real memory consumption: 29896.588 MB train step time: 314.716 MS  
<function densenet161 at 0x7fca84e98c10> mem_limit: 9900.0 MB real memory consumption: 9900.000 MB train step time: inf MS solver time: 0.000 MS
<function densenet161 at 0x7fca84e98c10> mem_limit: 12490.0 MB real memory consumption: 12664.061 MB train step time: 378.477 MS solver time: 2287.056 MS
<function densenet161 at 0x7fca84e98c10> mem_limit: 15080.0 MB real memory consumption: 15176.380 MB train step time: 370.280 MS solver time: 2256.386 MS
<function densenet161 at 0x7fca84e98c10> mem_limit: 17670.0 MB real memory consumption: 16648.384 MB train step time: 364.759 MS solver time: 2228.276 MS
<function densenet161 at 0x7fca84e98c10> mem_limit: 20260.0 MB real memory consumption: 20184.692 MB train step time: 349.334 MS solver time: 2290.359 MS
<function densenet161 at 0x7fca84e98c10> mem_limit: 22850.0 MB real memory consumption: 22533.892 MB train step time: 344.651 MS solver time: 2270.478 MS
<function densenet161 at 0x7fca84e98c10> mem_limit: 25440.0 MB real memory consumption: 24007.893 MB train step time: 337.447 MS solver time: 2241.896 MS
<function densenet161 at 0x7fca84e98c10> mem_limit: 28030.0 MB real memory consumption: 27536.146 MB train step time: 325.779 MS solver time: 2304.517 MS
<function densenet161 at 0x7fca84e98c10> mem_limit: 30620.0 MB real memory consumption: 29896.588 MB train step time: 316.225 MS solver time: 2310.656 MS
<function densenet161 at 0x7fca84e98c10> mem_limit: 33210.0 MB real memory consumption: 29896.588 MB train step time: 315.438 MS solver time: 2251.282 MS
<function densenet161 at 0x7fca84e98c10> mem_limit: 35800.0 MB real memory consumption: 29896.588 MB train step time: 315.640 MS solver time: 2289.473 MS
<function densenet169 at 0x7f986eeeedc0> mem_limit: None real memory consumption: 19385.591 MB train step time: 209.027 MS  
<function densenet169 at 0x7f986eeeedc0> mem_limit: 6400.0 MB real memory consumption: 6400.000 MB train step time: inf MS solver time: 0.000 MS
<function densenet169 at 0x7f986eeeedc0> mem_limit: 8080.0 MB real memory consumption: 8189.635 MB train step time: 249.236 MS solver time: 2259.276 MS
<function densenet169 at 0x7f986eeeedc0> mem_limit: 9760.0 MB real memory consumption: 9760.762 MB train step time: 245.508 MS solver time: 2259.876 MS
<function densenet169 at 0x7f986eeeedc0> mem_limit: 11440.0 MB real memory consumption: 10542.986 MB train step time: 241.924 MS solver time: 2251.288 MS
<function densenet169 at 0x7f986eeeedc0> mem_limit: 13120.0 MB real memory consumption: 13104.621 MB train step time: 232.036 MS solver time: 2299.600 MS
<function densenet169 at 0x7f986eeeedc0> mem_limit: 14800.0 MB real memory consumption: 14670.867 MB train step time: 226.386 MS solver time: 2305.503 MS
<function densenet169 at 0x7f986eeeedc0> mem_limit: 16480.0 MB real memory consumption: 15453.188 MB train step time: 223.063 MS solver time: 2266.478 MS
<function densenet169 at 0x7f986eeeedc0> mem_limit: 18160.0 MB real memory consumption: 17911.975 MB train step time: 214.395 MS solver time: 2222.649 MS
<function densenet169 at 0x7f986eeeedc0> mem_limit: 19840.0 MB real memory consumption: 19385.591 MB train step time: 208.736 MS solver time: 2318.853 MS
<function densenet169 at 0x7f986eeeedc0> mem_limit: 21520.0 MB real memory consumption: 19385.591 MB train step time: 209.141 MS solver time: 2297.614 MS
<function densenet169 at 0x7f986eeeedc0> mem_limit: 23200.0 MB real memory consumption: 19385.591 MB train step time: 209.131 MS solver time: 2258.708 MS
<function densenet201 at 0x7f986eeeef70> mem_limit: None real memory consumption: 24995.910 MB train step time: 264.470 MS  
<function densenet201 at 0x7f986eeeef70> mem_limit: 8300.0 MB real memory consumption: 8300.000 MB train step time: inf MS solver time: 0.000 MS
<function densenet201 at 0x7f986eeeef70> mem_limit: 10460.0 MB real memory consumption: 10460.000 MB train step time: inf MS solver time: 0.000 MS
<function densenet201 at 0x7f986eeeef70> mem_limit: 12620.0 MB real memory consumption: 12586.856 MB train step time: 338.658 MS solver time: 2886.305 MS
<function densenet201 at 0x7f986eeeef70> mem_limit: 14780.0 MB real memory consumption: 14785.701 MB train step time: 301.186 MS solver time: 2911.516 MS
<function densenet201 at 0x7f986eeeef70> mem_limit: 16940.0 MB real memory consumption: 16157.457 MB train step time: 298.102 MS solver time: 2925.236 MS
<function densenet201 at 0x7f986eeeef70> mem_limit: 19100.0 MB real memory consumption: 19005.433 MB train step time: 286.495 MS solver time: 2994.757 MS
<function densenet201 at 0x7f986eeeef70> mem_limit: 21260.0 MB real memory consumption: 20871.377 MB train step time: 279.330 MS solver time: 2896.458 MS
<function densenet201 at 0x7f986eeeef70> mem_limit: 23420.0 MB real memory consumption: 23233.352 MB train step time: 271.682 MS solver time: 2870.894 MS
<function densenet201 at 0x7f986eeeef70> mem_limit: 25580.0 MB real memory consumption: 24998.609 MB train step time: 264.336 MS solver time: 2875.226 MS
<function densenet201 at 0x7f986eeeef70> mem_limit: 27740.0 MB real memory consumption: 24998.609 MB train step time: 264.623 MS solver time: 2856.833 MS
<function densenet201 at 0x7f986eeeef70> mem_limit: 29900.0 MB real memory consumption: 24998.609 MB train step time: 264.331 MS solver time: 2857.020 MS

@@ -50,7 +54,7 @@ def _is_sink() -> bool:
bool
"""

return not sum([v for _, v in deps.items()])
return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a simple example here to show the different between new linearize and older version?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[input] 15 15 15
[conv1] 78 78 78
[bn1, relu] 78 78 78
[maxpool] 20 78 59
[layer1_0_conv1, layer1_0_bn1, layer1_0_relu, layer1_0_conv2, layer1_0_bn2, add, layer1_0_relu_1] 20 78 59
[layer1_1_conv1, layer1_1_bn1, layer1_1_relu, layer1_1_conv2, layer1_1_bn2, add_1, layer1_1_relu_1] 20 78 39
[layer2_0_conv1, layer2_0_bn1, layer2_0_relu, layer2_0_conv2, layer2_0_bn2, layer2_0_downsample_0, layer2_0_downsample_1, add_2, layer2_0_relu_1] 10 49 30
[layer2_1_conv1, layer2_1_bn1, layer2_1_relu, layer2_1_conv2, layer2_1_bn2, add_3, layer2_1_relu_1] 10 39 20
[layer3_0_conv1, layer3_0_bn1, layer3_0_relu, layer3_0_conv2, layer3_0_bn2, layer3_0_downsample_0, layer3_0_downsample_1, add_4, layer3_0_relu_1] 5 25 15
[layer3_1_conv1, layer3_1_bn1, layer3_1_relu, layer3_1_conv2, layer3_1_bn2, add_5, layer3_1_relu_1] 5 20 10
[layer4_0_conv1, layer4_0_bn1, layer4_0_relu, layer4_0_conv2, layer4_0_bn2, layer4_0_downsample_0, layer4_0_downsample_1, add_6, layer4_0_relu_1] 3 13 12
[layer4_1_conv1, layer4_1_bn1, layer4_1_relu, layer4_1_conv2, layer4_1_bn2, add_7, layer4_1_relu_1] 0 8 3
[avgpool] 0 0 1
[flatten] 1 1 1
[fc] 1 1 0
[input] 15 15 15
[conv1] 78 78 78
[bn1] 0 0 0
[relu] 78 78 78
[maxpool] 20 78 78
[layer1_0_conv1, layer1_0_bn1, layer1_0_relu, layer1_0_conv2, layer1_0_bn2, add] 0 58 0
[layer1_0_relu_1] 20 20 78
[layer1_1_conv1, layer1_1_bn1, layer1_1_relu, layer1_1_conv2, layer1_1_bn2, add_1] 0 58 0
[layer1_1_relu_1] 20 20 49
[layer2_0_conv1, layer2_0_bn1, layer2_0_relu, layer2_0_conv2, layer2_0_bn2, layer2_0_downsample_0, layer2_0_downsample_1, add_2] 0 39 0
[layer2_0_relu_1] 10 10 39
[layer2_1_conv1, layer2_1_bn1, layer2_1_relu, layer2_1_conv2, layer2_1_bn2, add_3] 0 29 0
[layer2_1_relu_1] 10 10 25
[layer3_0_conv1, layer3_0_bn1, layer3_0_relu, layer3_0_conv2, layer3_0_bn2, layer3_0_downsample_0, layer3_0_downsample_1, add_4] 0 20 0
[layer3_0_relu_1] 5 5 20
[layer3_1_conv1, layer3_1_bn1, layer3_1_relu, layer3_1_conv2, layer3_1_bn2, add_5] 0 15 0
[layer3_1_relu_1] 5 5 13
[layer4_0_conv1, layer4_0_bn1, layer4_0_relu, layer4_0_conv2, layer4_0_bn2, layer4_0_downsample_0, layer4_0_downsample_1, add_6] 0 10 0
[layer4_0_relu_1] 3 3 12
[layer4_1_conv1, layer4_1_bn1, layer4_1_relu, layer4_1_conv2, layer4_1_bn2, add_7] 0 8 0
[layer4_1_relu_1] 0 0 3
[avgpool] 0 0 1
[flatten] 1 1 1
[fc] 1 1 0

Copy link
Contributor

@Cypher30 Cypher30 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great Work!

@@ -381,7 +326,7 @@ def solver_rotor(gm: ColoGraphModule,
mem_limit: int,
mem_slots: int = 500,
cnode: List[str] = None,
eps: float = 0.02) -> ColoGraphModule:
eps: float = 0.0) -> ColoGraphModule:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should eps be a very small but non-zero value? e.g. 1e-6

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so 0.0 means no memory decay?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the default setting is 0.0

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And the memory decay is calculated by $M(1 - \epsilon)$, maybe the variable name is not that appropriate?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually the eps will be something around 0.05 or less, 1e-6 is too small as the memory will be discretized.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So actually decay is unnecessary if i can estimate the memory accurately.
This can be removed in future if I have tested performance of all models

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think eps is ok, you can just provide the equation for memory decay in line 338 to explain how eps affect memory decay.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this option could be provided for user as we might not be able to catch up with all the models in reality, so there might be some cases our meta info provides bad estimations. With this option the user might be able to tune the solver if necessary.

@super-dainiu super-dainiu merged commit cd5cf2b into hpcaitech:main Sep 15, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants