Skip to content

Commit

Permalink
Enhance the CachingMemoryManager to possibly handle high memory pressure
Browse files Browse the repository at this point in the history
Enhance the CachingMemoryManager to possibly handle high memory
pressure by adding 2 options.
Note: the default behavior is unchanged.
  • Loading branch information
WilliamTambellini committed Oct 25, 2020
1 parent 8508fb9 commit ed35c16
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 12 deletions.
6 changes: 3 additions & 3 deletions app/asr/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ add_executable(Train ${CMAKE_CURRENT_LIST_DIR}/Train.cpp)
add_executable(Test ${CMAKE_CURRENT_LIST_DIR}/Test.cpp)
add_executable(Decoder ${CMAKE_CURRENT_LIST_DIR}/Decode.cpp)

target_link_libraries(Train flashlight-app-asr)
target_link_libraries(Test flashlight-app-asr)
target_link_libraries(Decoder flashlight-app-asr)
target_link_libraries(Train flashlight-app-asr ${CMAKE_DL_LIBS})
target_link_libraries(Test flashlight-app-asr ${CMAKE_DL_LIBS})
target_link_libraries(Decoder flashlight-app-asr ${CMAKE_DL_LIBS})

# --------------------------- Tests ---------------------------

Expand Down
10 changes: 6 additions & 4 deletions ext/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ cmake_minimum_required(VERSION 3.5.1)
set(DIR ${CMAKE_CURRENT_LIST_DIR})
set(LIBS flashlight)

build_test(
${DIR}/common/SequentialBuilderTest.cpp
${LIBS}
"ARCHDIR=\"${DIR}/common/\""
if(FL_BUILD_CONTRIB)
build_test(
${DIR}/common/SequentialBuilderTest.cpp
${LIBS}
"ARCHDIR=\"${DIR}/common/\""
)
endif()

add_library(test_module_plugin MODULE
${DIR}/common/test_module_plugin.cpp)
Expand Down
19 changes: 17 additions & 2 deletions flashlight/memory/managers/CachingMemoryManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@

#include <algorithm>
#include <iostream>
#include <limits>
#include <mutex>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>
Expand Down Expand Up @@ -87,6 +89,14 @@ CachingMemoryManager::CachingMemoryManager(

void CachingMemoryManager::initialize() {}

void CachingMemoryManager::setRecyclingSizeLimit(size_t limit) {
recyclingSizeLimit_ = limit;
}

void CachingMemoryManager::setSplitSizeLimit(size_t limit) {
splitSizeLimit_ = limit;
}

void CachingMemoryManager::shutdown() {
signalMemoryCleanup();
}
Expand Down Expand Up @@ -126,7 +136,9 @@ void* CachingMemoryManager::alloc(

CachingMemoryManager::Block* block = nullptr;
auto it = pool.lower_bound(&searchKey);
if (it != pool.end()) {
// Recycle blocks if any found, and if small alloc or the block size is not
// too large:
if (it != pool.end() && (isSmallAlloc || (*it)->size_ < recyclingSizeLimit_)) {
block = *it;
pool.erase(it);
memoryInfo.stats_.cachedBytes_ -= block->size_;
Expand All @@ -144,7 +156,10 @@ void* CachingMemoryManager::alloc(
// implementation simple.
CachingMemoryManager::Block* remaining = nullptr;
size_t diff = block->size_ - size;
if (diff >= (isSmallAlloc ? kMinBlockSize : kSmallSize)) {
if ((diff >= (isSmallAlloc ? kMinBlockSize : kSmallSize)) &&
(block->size_ < splitSizeLimit_) // possibly dont split large buffers to
// minimize risk of fragmentation
) {
remaining = block;
block = new Block(size, block->ptr_);
block->prev_ = remaining->prev_;
Expand Down
15 changes: 15 additions & 0 deletions flashlight/memory/managers/CachingMemoryManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@

#pragma once

#include <atomic>
#include <limits>
#include <memory>
#include <mutex>
#include <numeric>
#include <set>
#include <unordered_map>
#include <vector>
Expand Down Expand Up @@ -48,6 +51,9 @@ class CachingMemoryManager : public MemoryManagerAdapter {
bool jitTreeExceedsMemoryPressure(size_t bytes) override;
void addMemoryManagement(int device) override;
void removeMemoryManagement(int device) override;
// Set runtime options: RecyclingSizeLimit, SplitSizeLimit, ... Warning: not thread safe
void setRecyclingSizeLimit(size_t);
void setSplitSizeLimit(size_t);

// Block denotes a single allocated unit of memory.
struct Block {
Expand Down Expand Up @@ -133,6 +139,15 @@ class CachingMemoryManager : public MemoryManagerAdapter {

void tryMergeBlocks(Block* dst, Block* src, BlockSet& freeBlocks);
void freeBlock(Block* block);

private:
// Non-const runtime options in order to fine tune the behavior of this
// manager. Prevents to recycle some buffers, to be set by the user if
// desired:
size_t recyclingSizeLimit_{std::numeric_limits<size_t>::max()};
//size_t recyclingSizeLimit;
// Prevents to split big buffers, to be set by the user if desired:
size_t splitSizeLimit_{std::numeric_limits<size_t>::max()};
};

} // namespace fl
55 changes: 52 additions & 3 deletions flashlight/test/memory/CachingMemoryManagerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@
class CachingMemoryManagerTest : public ::testing::Test {
protected:
virtual void SetUp() override {
auto deviceInterface_ =
std::make_shared<fl::MemoryManagerDeviceInterface>();
auto adapter_ = std::make_shared<fl::CachingMemoryManager>(
deviceInterface_ = std::make_shared<fl::MemoryManagerDeviceInterface>();
adapter_ = std::make_shared<fl::CachingMemoryManager>(
af::getDeviceCount(), deviceInterface_);
installer_ = fl::cpp::make_unique<fl::MemoryManagerInstaller>(adapter_);
installer_->setAsMemoryManager();
Expand All @@ -32,6 +31,8 @@ class CachingMemoryManagerTest : public ::testing::Test {
af_unset_memory_manager();
}

std::shared_ptr<fl::MemoryManagerDeviceInterface> deviceInterface_;
std::shared_ptr<fl::CachingMemoryManager> adapter_;
std::unique_ptr<fl::MemoryManagerInstaller> installer_;
};

Expand Down Expand Up @@ -137,6 +138,54 @@ TEST_F(CachingMemoryManagerTest, OOM) {
}
}

void testFragmentation(
std::shared_ptr<fl::MemoryManagerDeviceInterface> deviceInterface_,
std::shared_ptr<fl::CachingMemoryManager> adapter_,
bool expectOOM) {
af::Backend b = af::getActiveBackend();

if (b != AF_BACKEND_CUDA) {
GTEST_SKIP()
<< "CachingMemoryManager fragmentation tests require CUDA backend";
}

const auto mms = deviceInterface_->getMaxMemorySize(0);
const auto maxNumf32 = mms / sizeof(float); // AF f32 is supposed to be 32b
ASSERT_NE(mms, 0);
{
af::array a1(.5f * maxNumf32);
adapter_->printInfo("After creating a1:", 0);
} // The a1 buffer will not be freed here, just registered to the cache
adapter_->printInfo("After releasing a1:", 0);

af::array a2(.1f * maxNumf32);
adapter_->printInfo("After creating a2:", 0);

af::array a3;
try {
a3 = af::array(.5f * maxNumf32);
} catch (af::exception& ex) {
if (expectOOM) {
ASSERT_EQ(ex.err(), AF_ERR_NO_MEM);
} else {
EXPECT_TRUE(false)
<< "CachingMemoryManagerTest fragmentaiton not supposed to throw: "
<< ex.what();
}
}
}

TEST_F(CachingMemoryManagerTest, Fragmentation) {
testFragmentation(deviceInterface_, adapter_, true); // should OOM
}

TEST_F(CachingMemoryManagerTest, RecLimit) {
constexpr static size_t ONE_GB = 1 << 30;
// Fine set the manager in order not to recycle big tensors:
adapter_->setRecyclingSizeLimit(2 * ONE_GB);
testFragmentation(deviceInterface_, adapter_, false); // should not OOM
}

int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
Expand Down

0 comments on commit ed35c16

Please sign in to comment.