
Improve TreeNodeElementId hash function #16459

Merged
xadupre merged 10 commits into microsoft:main from lhrios:luisrios/improve-tree-node-element-id-hash
Jul 25, 2023

Conversation

@lhrios
Contributor

@lhrios lhrios commented Jun 22, 2023

Description

This PR improves the TreeNodeElementId hash function by employing the Elegant Pairing function (http://szudzik.com/ElegantPairing.pdf). In a few words, the Elegant Pairing function maps two non-negative integers to a non-negative integer that is uniquely associated with that pair. This drastically reduces collisions and therefore the time required to create a session for a large tree ensemble model.

Motivation and Context

We use ONNX Runtime to serve our models as part of a Triton backend. We noticed that it was taking around 2 minutes to load a large tree ensemble model (around 5k trees with around 3 million nodes in total). After investigating the issue, it was clear that the TreeNodeElementId hash function was unable to map keys to buckets of a C++ unordered_map without a significant number of collisions (in some cases 700 items per bucket).

The following flame graph illustrates the improvement obtained by the proposed change (measured with the onnx_test_runner command).
flamegraph

Before

$> time ./onnx_test_runner -v ~/folder_with_model
result:
	Models: 1
	Total test cases: 0
		Succeeded: 0
		Not implemented: 0
		Failed: 0
	Stats by Operator type:
		Not implemented(0):
		Failed:
Failed Test Cases:

real	0m55.695s
user	0m52.919s
sys	0m0.760s

After

$> time ./onnx_test_runner -v ~/folder_with_model
result:
	Models: 1
	Total test cases: 0
		Succeeded: 0
		Not implemented: 0
		Failed: 0
	Stats by Operator type:
		Not implemented(0):
		Failed:
Failed Test Cases:

real	0m17.152s
user	0m14.318s
sys	0m0.619s

@lhrios lhrios marked this pull request as draft June 23, 2023 05:42
@lhrios lhrios marked this pull request as ready for review June 24, 2023 02:50
@lhrios
Contributor Author

lhrios commented Jun 24, 2023

@microsoft-github-policy-service agree company="Block"

@lhrios lhrios marked this pull request as draft June 27, 2023 19:29
@lhrios lhrios marked this pull request as ready for review June 27, 2023 19:30
u_int64_t combined_id;
// Use the absolute value, before casting, to reduce the chance of overflow if either value is negative.
// In the worst (and also unlikely) case, we will map 4 pairs of tree_id and node_id to the same value.
u_int64_t x = static_cast<u_int64_t>(std::abs(key.tree_id));
Contributor

Isn't the overflow exactly what you want here, though?

Contributor Author

Yes, that is a good point. I asked myself the same question. However, if we simply let the cast from signed to unsigned wrap around, the values spread less evenly across buckets. By using abs we create a more "predictable" scenario.

To test it, I created a small program that compares all three hash functions (the current one, Elegant Pairing with abs, and Elegant Pairing without abs). It generates hash values for each pair of integers inside a given interval. These were the results:

[-10000, 10000] Interval (mixing negative and non-negative numbers)

Elegant Pairing with "abs"

[-10000, 10000] [-10000, 10000] 1 1

ID's pair count: 400040001
distinct hash values produced: 100020001
collision count: 400040000

there are 100000000 buckets with 4 elements
there are 20000 buckets with 2 elements
there are 1 buckets with 1 elements

Elegant Pairing without "abs"

[-10000, 10000] [-10000, 10000] 0 1

ID's pair count: 400040001
distinct hash values produced: 100030001
collision count: 400030000

there are 33336668 buckets with 2 elements
there are 33336668 buckets with 3 elements
there are 13334668 buckets with 4 elements
there are 6667336 buckets with 5 elements
there are 3809907 buckets with 6 elements
...
there are 142 buckets with 279 elements
there are 141 buckets with 278 elements
there are 141 buckets with 280 elements
there are 141 buckets with 276 elements
there are 120 buckets with 282 elements
printed first and last 5 buckets with more than one element (still 272 remaining of 282)

Current implementation

[-10000, 10000] [-10000, 10000] 0 0

ID's pair count: 400040001
distinct hash values produced: 32768
collision count: 400040001

there are 16384 buckets with 7234 elements
there are 12288 buckets with 16384 elements
there are 3584 buckets with 19522 elements
there are 448 buckets with 19968 elements
there are 32 buckets with 19970 elements
there are 31 buckets with 20000 elements
there are 1 buckets with 20001 elements

[2^60, 2^60 + 5000] Interval (inevitable overflow)

Note that even in a scenario where overflows are inevitable (if ID's are big), Elegant Pairing is still able to better spread the values into different buckets.

Elegant Pairing with "abs"

[1152921504606846976, 1152921504606851976] [1152921504606846976, 1152921504606851976] 1 1

ID's pair count: 25010001
distinct hash values produced: 25010001
collision count: 0

there are 25010001 buckets with 1 elements

Current implementation

[1152921504606846976, 1152921504606851976] [1152921504606846976, 1152921504606851976] 0 0

ID's pair count: 25010001
distinct hash values produced: 8192
collision count: 25010001

there are 4096 buckets with 1810 elements
there are 3072 buckets with 4096 elements
there are 896 buckets with 4882 elements
there are 112 buckets with 4992 elements
there are 8 buckets with 4994 elements
there are 7 buckets with 5000 elements
there are 1 buckets with 5001 elements

Contributor Author

Source code used to produce the results reported above:

#include <algorithm>
#include <cstdio>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <unordered_map>
#include <vector>
#include <map>

using namespace std;

// g++ -O2 -std=c++20 elegant_pairing.cc

size_t elegant_pairing_hash(int64_t x0, int64_t y0, bool use_abs) {
  uint64_t combined_id;

  uint64_t x;
  uint64_t y;

  if (use_abs) {
    x = static_cast<uint64_t>(std::abs(x0));
    y = static_cast<uint64_t>(std::abs(y0));
  } else {
    x = static_cast<uint64_t>(x0);
    y = static_cast<uint64_t>(y0);
  }

  if (x >= y) {
    combined_id = x * x + (x + y);
  } else {
    combined_id = y * y + x;
  }
  return hash<uint64_t>()(combined_id);
}

size_t current_hash(int64_t x0, int64_t y0) {
  size_t h1 = hash<int64_t>()(x0);
  size_t h2 = hash<int64_t>()(y0);
  return h1 ^ h2;
}

int main(int argc, char** argv) {
  if (argc != 7) {
    printf("usage: x_begin x_end y_begin y_end use_elegant_pairing use_abs\n");
    return 1;
  }

  int64_t x_begin;
  int64_t x_end;
  int64_t y_begin;
  int64_t y_end;

  // Parse through "long long" temporaries so the types match "%lld" exactly.
  long long x_begin_in, x_end_in, y_begin_in, y_end_in;
  sscanf(argv[1], "%lld", &x_begin_in);
  sscanf(argv[2], "%lld", &x_end_in);
  sscanf(argv[3], "%lld", &y_begin_in);
  sscanf(argv[4], "%lld", &y_end_in);
  x_begin = x_begin_in;
  x_end = x_end_in;
  y_begin = y_begin_in;
  y_end = y_end_in;

  bool use_elegant_pairing_hash = strcmp(argv[5], "true") == 0;
  bool use_abs = strcmp(argv[6], "true") == 0;

  printf("[%lld, %lld] [%lld, %lld] %d %d\n\n",
         x_begin_in, x_end_in, y_begin_in, y_end_in,
         use_abs, use_elegant_pairing_hash);

  if (x_begin > x_end ) {
    swap(x_begin, x_end);
  }

  if (y_begin > y_end ) {
    swap(y_begin, y_end);
  }

  unordered_map<size_t, uint64_t> count_by_hash_value;
  uint64_t pair_count = 0;
  uint64_t collision_count = 0;

  for (int64_t x = x_begin; x <= x_end; x++) {
    for (int64_t y = y_begin; y <= y_end; y++) {
      std::size_t hash_value;
      if (use_elegant_pairing_hash) {
        hash_value = elegant_pairing_hash(x, y, use_abs);
      } else {
        hash_value = current_hash(x, y);
      }

      // operator[] value-initializes missing entries to 0.
      count_by_hash_value[hash_value]++;

      // printf("%zu (%zx) <- %lld %lld\n", hash_value, hash_value, (long long)x, (long long)y);

      pair_count++;
    }
  }

  uint64_t max = 0;
  map<size_t, uint64_t> bucket_count_by_size;

  for (const auto& entry : count_by_hash_value) {
    uint64_t bucket_size = entry.second;
    if (bucket_size > 1) {
      collision_count += bucket_size;
    }
    if (bucket_size > max) {
      max = bucket_size;
    }

    bucket_count_by_size[bucket_size]++;
  }
  printf("ID's pair count: %llu\n", (unsigned long long)pair_count);
  printf("distinct hash values produced: %zu\n", count_by_hash_value.size());
  printf("collision count: %llu\n", (unsigned long long)collision_count);
  printf("\n");

  vector<pair<size_t, uint64_t>> sorted_bucket_count_by_size;
  for (const auto& it : bucket_count_by_size) {
    sorted_bucket_count_by_size.push_back(it);
  }

  sort(sorted_bucket_count_by_size.begin(), sorted_bucket_count_by_size.end(),
       [](const auto& a, const auto& b) { return a.second > b.second; });

  constexpr size_t bucket_limit = 5;
  size_t bucket_count = 0;
  for (const auto& [bucket_size, count] : sorted_bucket_count_by_size) {
    if (bucket_count < bucket_limit || sorted_bucket_count_by_size.size() - bucket_limit <= bucket_count) {
      if (bucket_limit * 2 < bucket_count_by_size.size() && sorted_bucket_count_by_size.size() - bucket_limit == bucket_count) {
        printf("...\n");
      }
      printf("there are %llu buckets with %zu elements\n", (unsigned long long)count, bucket_size);
    }
    ++bucket_count;
  }
  if (bucket_limit * 2 < bucket_count_by_size.size()) {
    printf("printed first and last %zu buckets with more than one element (still %zu remaining of %zu)\n",
           bucket_limit, bucket_count_by_size.size() - bucket_limit * 2, bucket_count_by_size.size());
  }

  return 0;
}

Contributor

@cbourjau cbourjau Jul 7, 2023

Thanks for the very detailed follow-up!

Contributor

I am leaving this comment open since it seems a good discussion for future readers.

@lhrios lhrios requested a review from cbourjau July 4, 2023 16:11
std::size_t operator()(const TreeNodeElementId& key) const {
std::size_t h1 = std::hash<int64_t>()(key.tree_id);
std::size_t h2 = std::hash<int64_t>()(key.node_id);
return h1 ^ h2;
Contributor

My understanding is that the large number of collisions came from the fact that this is symmetrical. Doesn't everything just work if we do h1 ^ (h2 << 1)?

Contributor Author

I suspect the root cause here is related to the fact that both IDs (tree_ids and node_ids) for a given TreeEnsembleClassifier node are chosen sequentially during the conversion to ONNX. Also, node_ids can be reused across different trees ("Ids may restart at zero for each tree, but it not required to", from the nodes_nodeids attribute description). I confirmed that this is happening for our model, i.e. there are repeated values in the nodes_nodeids array.

In that case, the hashes of the two elements of a pair end up differing in only a few bits. After applying the XOR operator, since few bits differ, the number of collisions is high. I was able to find some examples (tables below).

Since we have few guarantees on the values chosen for tree_ids and node_ids (besides each pair being unique), I believe it's hard to beat Elegant Pairing in terms of collision avoidance. Even though Elegant Pairing assumes two non-negative integers, the edge case handling showed pretty good results.

Using (h1 ^ h2)

All have the same hash value, equal to 0xF:

  1 E F | 10 1F F | 20 2F F |     | 50 5F F                               
  2 D F | 11 1E F | 21 2E F |     | 51 5E F                               
  3 C F | 12 1D F | 22 2D F |     | 52 5D F                               
  4 B F | 13 1C F | 23 2C F |     | 53 5C F                               
  5 A F | 14 1B F | 24 2B F |     | 54 5B F                               
  6 9 F | 15 1A F | 25 2A F |     | 55 5A F                               
  7 8 F | 16 19 F | 26 29 F |     | 56 59 F                               
  8 7 F | 17 18 F | 27 28 F |     | 57 58 F                               
  9 6 F | 18 17 F | 28 27 F | ... | 58 57 F                               
  A 5 F | 19 16 F | 29 26 F |     | 59 56 F                               
  B 4 F | 1A 15 F | 2A 25 F |     | 5A 55 F                               
  C 3 F | 1B 14 F | 2B 24 F |     | 5B 54 F                               
  D 2 F | 1C 13 F | 2C 23 F |     | 5C 53 F                               
  E 1 F | 1D 12 F | 2D 22 F |     | 5D 52 F                               
        | 1E 11 F | 2E 21 F |     | 5E 51 F                               
        | 1F 10 F | 2F 20 F |     | 5F 50 F

Legend:

  • 1st column: the hash value for tree_id (starting from 1)
  • 2nd column: the hash value for the node_id (starting from 1)
  • 3rd column: (1st ^ 2nd)

Using (h1 ^ (h2 << 1))

All have the same hash value, equal to 0x1:

 3 1 1 | 13 9 1 | 23 11 1 | 33 19 1 | 43 21 1 |     | 3DB 1ED 1 |
 7 3 1 | 17 B 1 | 27 13 1 | 37 1B 1 | 47 23 1 | ... | 3DF 1EF 1 |
 B 5 1 | 1B D 1 | 2B 15 1 | 3B 1D 1 | 4B 25 1 |     | 3E3 1F1 1 |
 F 7 1 | 1F F 1 | 2F 17 1 | 3F 1F 1 | 4F 27 1 |     | 3E7 1F3 1 |

Legend:

  • 1st column: the hash value for tree_id (starting from 1)
  • 2nd column: the hash value for the node_id (starting from 1)
  • 3rd column: (1st ^ (2nd << 1))

Contributor Author

For the sake of completeness, I've repeated the experiments for the variation you proposed, h1 ^ (h2 << 1).

[-10000, 10000] Interval

ID's pair count: 400040001
distinct hash values produced: 57344
collision count: 400040001

there are 16384 buckets with 10001 elements
there are 16384 buckets with 10000 elements
there are 16384 buckets with 3617 elements
there are 2048 buckets with 1809 elements
there are 2048 buckets with 1808 elements
...
there are 512 buckets with 1569 elements
there are 415 buckets with 1536 elements
there are 33 buckets with 1537 elements
there are 32 buckets with 1553 elements
there are 32 buckets with 1552 elements
printed first and last 5 buckets with more than one element (still 2 remaining of 12)

[2^60, 2^60 + 5000] Interval

ID's pair count: 25010001
distinct hash values produced: 14336
collision count: 25010001

there are 4096 buckets with 2501 elements
there are 4096 buckets with 2500 elements
there are 4096 buckets with 905 elements
there are 512 buckets with 453 elements
there are 512 buckets with 452 elements
...
there are 128 buckets with 393 elements
there are 103 buckets with 384 elements
there are 9 buckets with 385 elements
there are 8 buckets with 389 elements
there are 8 buckets with 388 elements
printed first and last 5 buckets with more than one element (still 2 remaining of 12)

Contributor

Thanks again for the tests! This (hashing tuples without a lot of collisions) should be a solved problem; C++ might just make it hard to find. onnxruntime already depends on abseil in various places. We could consider using their Hash (which also works for tuples) here. The produced hashes seem nicely distributed: https://godbolt.org/z/qGsK6jj9a

Somewhat unrelated, but it might actually be worth generally using the abseil's hashmap rather than the unordered_map of the standard library. This won't save us from defining a hash function for TreeNodeElementId, though.

Contributor Author

Oh! That is true, I missed it.

I've changed the code accordingly. The results seem equivalent.

$> time ./onnx_test_runner -v ~/folder_with_model
result:
	Models: 1
	Total test cases: 0
		Succeeded: 0
		Not implemented: 0
		Failed: 0
	Stats by Operator type:
		Not implemented(0):
		Failed:
Failed Test Cases:

real	0m18.937s
user	0m18.284s
sys	0m0.596s

(Time of implementation with unordered_map and absl::Hash)

I've also tested with absl::flat_hash_map but the results were slightly worse.

Contributor Author

We have 3 million nodes spread across 5 thousand trees. I used 2^60 just to show that the proposed hash function would behave well even for large IDs that would cause it to overflow.

The solution you proposed works well for us. I proposed Elegant Pairing as it's a more generic approach capable of handling different ID allocation strategies, since we have few guarantees about them. My understanding is that we can only count on each pair of IDs (tree and node) being unique within the int64_t x int64_t range. In other words, since it would reduce any possible impact, I thought it would help in getting this PR approved 🙂.

But I agree that assuming 32 bits is reasonable from a hashing perspective, since converters seem to start allocating IDs from 1 and few models would have more than 4 billion nodes. I will proceed with the required changes.

Contributor

FYI, the std hash function does not (necessarily) have a uniform distribution. It seems to be a no-op in our case: https://godbolt.org/z/4675qPeYd

Either way, I was hoping there was a quick and simple way that Just Works. In its absence, I think either the current or elegant pairing is just fine :)

Member

It seems std::hash is optimized for unordered_map and std::hash(i) returns i. We can even remove it in this case.

Contributor Author

Changed accordingly 🙂

Contributor

Unresolved for future readers

@lhrios lhrios requested review from cbourjau and xadupre July 17, 2023 16:33
Contributor

@cbourjau cbourjau left a comment

It looks good to me, but I am not officially part of this project. Also: Thanks a lot! I'm looking forward to the improvements from this PR!

@lhrios
Contributor Author

lhrios commented Jul 21, 2023

> It looks good to me, but I am not officially part of this project. Also: Thanks a lot! I'm looking forward to the improvements from this PR!

Thanks for all the help and feedback throughout the process! Also looking forward to using a new version of onnxruntime with this change. 🙂

One last point: it seems some of the checks are stuck and might need to be restarted.

@xadupre
Member

xadupre commented Jul 21, 2023

/azp run Linux CPU CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 1 pipeline(s).

@xadupre
Member

xadupre commented Jul 21, 2023

/azp run Windows CPU CI Pipeline, Windows GPU CI Pipeline, MacOS CI Pipeline, Linux QNN CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

@henryhu666

Will this be merged and released as part of the next ONNX release cycle in August? As @lhrios mentioned, this would greatly reduce the time spent loading a large tree model.

@xadupre
Member

xadupre commented Jul 25, 2023

Sorry for the delay; I needed to manually run all the remaining CI tests, but one was failing. I restarted it.

@xadupre
Member

xadupre commented Jul 25, 2023

/azp run Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux OpenVINO CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

@xadupre
Member

xadupre commented Jul 25, 2023

/azp run ONNX Runtime Web CI Pipeline,Windows ARM64 QNN CI Pipeline,Windows GPU TensorRT CI Pipeline,onnxruntime-binary-size-checks-ci-pipeline

@xadupre
Member

xadupre commented Jul 25, 2023

/azp run orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 3 pipeline(s).

@xadupre xadupre merged commit feeb0b5 into microsoft:main Jul 25, 2023
jchen351 pushed a commit that referenced this pull request Aug 12, 2023
(commit message repeats the PR description above)
@sah-git

sah-git commented May 7, 2024

Hello all,

I am investigating a serious performance issue in my application using onnxruntime 1.17.1. By debugging the code and doing some investigation, I ended up at this PR and its associated commit, feeb0b5. I've also written a comment there.

I did some further investigation and I think I found the problem and a possible solution. Let me share it here. I would like to know your opinion and, if possible, to plan a fix in upcoming releases (if I'm not wrong).

My problem is the following. I'm loading a random forest regression model from disk. It is a small one: roughly 466 trees with about 480 nodes per tree. It takes around 1 minute to load on my laptop. I thought that was an excessively long load time for such a small model, so I debugged the library. I found that 90% of the time was spent at

https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h#L310

which is just a call to the std::unordered_map find method:

auto found = node_tree_ids_map.find(ind);

The node_tree_ids_map object uses TreeNodeElementId as its unique key, combining the tree and node IDs in a structure. Its definition is at

https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h#L220

std::unordered_map<TreeNodeElementId, size_t, TreeNodeElementId::hash_fn> node_tree_ids_map;

This is an std::unordered_map. The bad performance of find suggests that the hashing method employed, TreeNodeElementId::hash_fn, is not adequate for my map.

node_tree_ids_map can be characterized by the following parameters:

  • T: number of trees
  • Nt: number of nodes per tree
  • N: total number of nodes
  • maxNid: Maximum Id of nodes (node ids might repeat across different trees)
  • B: Number of buckets in std::unordered_map
  • LF: load factor, N/B.

In my example I have

  • T=466
  • Nt=485 (roughly 485 per tree), numbering is reset to 0 for each tree, so node_ids are repeated.
  • N=218007
  • B=262144
  • LF=0.831630707

Theoretically, for efficient operations on an std::unordered_map object, we need

0.5 < LF < 1

and a hash function that produces a low number of collisions and whose range stays within the number of buckets:

maxHash < B

The implemented hash function TreeNodeElementId::hash_fn (version 1.17.1) is in

https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h#L24

struct hash_fn {
    std::size_t operator()(const TreeNodeElementId& key) const {
        return static_cast<std::size_t>(static_cast<uint64_t>(key.tree_id) << 32 | static_cast<uint64_t>(key.node_id));
    }
};

where there is a bitwise shift by a hardcoded value of 32 bits. This implies a maximum hash of

maxHash = T*2^32 + Nt = 2001454760421

which exceeds my B by far. Internally, std::unordered_map will most likely perform a modulus operation on the hash, Hash % B, and therefore we lose all the significant bits corresponding to the tree IDs, producing too many collisions.

So, if I'm not wrong, the problem with this implementation is that it uses a fixed 32-bit shift, which produces very large hashes. That might work for large maps, but it can be very inefficient for small ones.

One possible solution could be to give the hash structure a constructor accepting the maximum node Id. Then the number of bits needed to represent that number can be computed and used for the binary shift, as follows:

struct hash_fn {
  hash_fn() = default;
  hash_fn(size_t maxNodeIdIndex) {
    _shiftBits = static_cast<uint64_t>(std::ceil(std::log2(static_cast<double>(maxNodeIdIndex))));
    _shiftBits = _shiftBits > 32 ? 32 : _shiftBits;
  }

  std::size_t operator()(const TreeNodeElementId& key) const {
    return static_cast<std::size_t>(static_cast<uint64_t>(key.tree_id) << _shiftBits |
                                    static_cast<uint64_t>(key.node_id));
  }

  uint64_t _shiftBits = 32;
};

This way, the binary shift could be adapted to the specific size of the map, and the hash function could perform well in all situations.

Later, at

https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h#L220

we would need to construct the map with the following code:

size_t maxNodeIdIndex = static_cast<size_t>(*max_element(nodes_nodeids.begin(), nodes_nodeids.end()));
std::unordered_map<TreeNodeElementId, size_t, TreeNodeElementId::hash_fn> node_tree_ids_map(10, TreeNodeElementId::hash_fn(maxNodeIdIndex));

By making these changes, my example now loads in 6s instead of 60s: a 10x speedup.

What do you think? Do you think that this fix (or similar) could be added to next onnxruntime version? I would really appreciate it. Thank you!

@xadupre
Member

xadupre commented May 7, 2024

That's a good catch. The branch for onnxruntime 1.18 has already been created. So it would be for the next one.
