diff --git a/src/memory_pool.c b/src/memory_pool.c index ef2c0fa66b..8e9e510fdd 100644 --- a/src/memory_pool.c +++ b/src/memory_pool.c @@ -208,3 +208,10 @@ umf_result_t umfPoolGetTag(umf_memory_pool_handle_t hPool, void **tag) { utils_mutex_unlock(&hPool->lock); return UMF_RESULT_SUCCESS; } + +void *umfPoolGetPoolPriv(umf_memory_pool_handle_t hPool) { + if (hPool == NULL) { + return NULL; + } + return hPool->pool_priv; +} diff --git a/src/memory_pool_internal.h b/src/memory_pool_internal.h index ab3378163d..be68e01c07 100644 --- a/src/memory_pool_internal.h +++ b/src/memory_pool_internal.h @@ -36,6 +36,8 @@ typedef struct umf_memory_pool_t { void *tag; } umf_memory_pool_t; +void *umfPoolGetPoolPriv(umf_memory_pool_handle_t hPool); + #ifdef __cplusplus } #endif diff --git a/src/pool/pool_disjoint.c b/src/pool/pool_disjoint.c index 0bd88bd246..088e5a8b79 100644 --- a/src/pool/pool_disjoint.c +++ b/src/pool/pool_disjoint.c @@ -17,6 +17,7 @@ #include #include "base_alloc_global.h" +#include "memory_pool_internal.h" #include "pool_disjoint_internal.h" #include "provider/provider_tracking.h" #include "uthash/utlist.h" @@ -819,6 +820,11 @@ umf_result_t disjoint_pool_free(void *pool, void *ptr) { return ret; } + if (pool != umfPoolGetPoolPriv(allocInfo.pool)) { + LOG_ERR("pool mismatch"); + return UMF_RESULT_ERROR_INVALID_ARGUMENT; + } + size_t size = allocInfo.baseSize; umf_memory_provider_handle_t provider = disjoint_pool->provider; ret = umfMemoryProviderFree(provider, ptr, size); diff --git a/src/pool/pool_proxy.c b/src/pool/pool_proxy.c index eedddb0acb..81dced62e3 100644 --- a/src/pool/pool_proxy.c +++ b/src/pool/pool_proxy.c @@ -13,8 +13,10 @@ #include #include "base_alloc_global.h" +#include "memory_pool_internal.h" #include "provider/provider_tracking.h" #include "utils_common.h" +#include "utils_log.h" static __TLS umf_result_t TLS_last_allocation_error; @@ -100,6 +102,11 @@ static umf_result_t proxy_free(void *pool, void *ptr) { umf_alloc_info_t allocInfo = {NULL, 0, NULL}; umf_result_t umf_result = umfMemoryTrackerGetAllocInfo(ptr, &allocInfo); if (umf_result == UMF_RESULT_SUCCESS) { + if (pool != umfPoolGetPoolPriv(allocInfo.pool)) { + LOG_ERR("pool mismatch"); + return UMF_RESULT_ERROR_INVALID_ARGUMENT; + } + size = allocInfo.baseSize; } }