Skip to content

Commit

Permalink
Enhance the CachingMemoryManager to possibly handle high memory press…
Browse files Browse the repository at this point in the history
…ure (#188)

Summary:
Enhance the CachingMemoryManager to handle high memory pressure by adding 2 options.
Note: the default behavior is unchanged.
Add a new API to MemoryManagerAdapter to setOption().
Side: remove some glog includes failing compilation if glog-dev not installed locally

**Original Issue**:
flashlight/flashlight#180
closes flashlight/flashlight#180

The CachingMemoryManager being greedy in memory. This PR adds 2 options to mitigate this.

### Test Plan (required)
Implemented a new unitest in order to show OOM with the default CachingMM and no OOM if the right option is set.

Pull Request resolved: flashlight/flashlight#188

Reviewed By: vineelpratap

Differential Revision: D24435873

Pulled By: jacobkahn

fbshipit-source-id: 99186439c1e306ad987079cc326ab68fe1f028fc
  • Loading branch information
WilliamTambellini authored and facebook-github-bot committed Oct 25, 2020
1 parent 12e4646 commit f701125
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 13 deletions.
6 changes: 3 additions & 3 deletions app/asr/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ add_executable(Train ${CMAKE_CURRENT_LIST_DIR}/Train.cpp)
add_executable(Test ${CMAKE_CURRENT_LIST_DIR}/Test.cpp)
add_executable(Decoder ${CMAKE_CURRENT_LIST_DIR}/Decode.cpp)

target_link_libraries(Train flashlight-app-asr)
target_link_libraries(Test flashlight-app-asr)
target_link_libraries(Decoder flashlight-app-asr)
target_link_libraries(Train flashlight-app-asr ${CMAKE_DL_LIBS})
target_link_libraries(Test flashlight-app-asr ${CMAKE_DL_LIBS})
target_link_libraries(Decoder flashlight-app-asr ${CMAKE_DL_LIBS})

# --------------------------- Tests ---------------------------

Expand Down
10 changes: 6 additions & 4 deletions ext/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ cmake_minimum_required(VERSION 3.5.1)
set(DIR ${CMAKE_CURRENT_LIST_DIR})
set(LIBS flashlight)

build_test(
${DIR}/common/SequentialBuilderTest.cpp
${LIBS}
"ARCHDIR=\"${DIR}/common/\""
if(FL_BUILD_CONTRIB)
build_test(
${DIR}/common/SequentialBuilderTest.cpp
${LIBS}
"ARCHDIR=\"${DIR}/common/\""
)
endif()

add_library(test_module_plugin MODULE
${DIR}/common/test_module_plugin.cpp)
Expand Down
19 changes: 17 additions & 2 deletions flashlight/memory/managers/CachingMemoryManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@

#include <algorithm>
#include <iostream>
#include <limits>
#include <mutex>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>
Expand Down Expand Up @@ -87,6 +89,14 @@ CachingMemoryManager::CachingMemoryManager(

void CachingMemoryManager::initialize() {}

void CachingMemoryManager::setRecyclingSizeLimit(size_t limit) {
recyclingSizeLimit_ = limit;
}

void CachingMemoryManager::setSplitSizeLimit(size_t limit) {
splitSizeLimit_ = limit;
}

void CachingMemoryManager::shutdown() {
signalMemoryCleanup();
}
Expand Down Expand Up @@ -126,7 +136,9 @@ void* CachingMemoryManager::alloc(

CachingMemoryManager::Block* block = nullptr;
auto it = pool.lower_bound(&searchKey);
if (it != pool.end()) {
// Recycle blocks if any found, and if small alloc or the block size is not
// too large:
if (it != pool.end() && (isSmallAlloc || (*it)->size_ < recyclingSizeLimit_)) {
block = *it;
pool.erase(it);
memoryInfo.stats_.cachedBytes_ -= block->size_;
Expand All @@ -144,7 +156,10 @@ void* CachingMemoryManager::alloc(
// implementation simple.
CachingMemoryManager::Block* remaining = nullptr;
size_t diff = block->size_ - size;
if (diff >= (isSmallAlloc ? kMinBlockSize : kSmallSize)) {
if ((diff >= (isSmallAlloc ? kMinBlockSize : kSmallSize)) &&
(block->size_ < splitSizeLimit_) // possibly dont split large buffers to
// minimize risk of fragmentation
) {
remaining = block;
block = new Block(size, block->ptr_);
block->prev_ = remaining->prev_;
Expand Down
15 changes: 15 additions & 0 deletions flashlight/memory/managers/CachingMemoryManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@

#pragma once

#include <atomic>
#include <limits>
#include <memory>
#include <mutex>
#include <numeric>
#include <set>
#include <unordered_map>
#include <vector>
Expand Down Expand Up @@ -48,6 +51,9 @@ class CachingMemoryManager : public MemoryManagerAdapter {
bool jitTreeExceedsMemoryPressure(size_t bytes) override;
void addMemoryManagement(int device) override;
void removeMemoryManagement(int device) override;
// Set runtime options: RecyclingSizeLimit, SplitSizeLimit, ... Warning: not thread safe
void setRecyclingSizeLimit(size_t);
void setSplitSizeLimit(size_t);

// Block denotes a single allocated unit of memory.
struct Block {
Expand Down Expand Up @@ -133,6 +139,15 @@ class CachingMemoryManager : public MemoryManagerAdapter {

void tryMergeBlocks(Block* dst, Block* src, BlockSet& freeBlocks);
void freeBlock(Block* block);

private:
// Non-const runtime options in order to fine tune the behavior of this
// manager. Prevents to recycle some buffers, to be set by the user if
// desired:
size_t recyclingSizeLimit_{std::numeric_limits<size_t>::max()};
//size_t recyclingSizeLimit;
// Prevents to split big buffers, to be set by the user if desired:
size_t splitSizeLimit_{std::numeric_limits<size_t>::max()};
};

} // namespace fl
2 changes: 1 addition & 1 deletion flashlight/test/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.5.1)

set(DIR ${FLASHLIGHT_CORE_DIR}/test)
set(LIBS flashlight)
set(LIBS flashlight ${CMAKE_DL_LIBS})
build_test(${DIR}/autograd/AutogradTest.cpp ${LIBS} "")
build_test(${DIR}/common/DevicePtrTest.cpp ${LIBS} "")
build_test(${DIR}/common/HistogramTest.cpp ${LIBS} "")
Expand Down
55 changes: 52 additions & 3 deletions flashlight/test/memory/CachingMemoryManagerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@
class CachingMemoryManagerTest : public ::testing::Test {
protected:
virtual void SetUp() override {
auto deviceInterface_ =
std::make_shared<fl::MemoryManagerDeviceInterface>();
auto adapter_ = std::make_shared<fl::CachingMemoryManager>(
deviceInterface_ = std::make_shared<fl::MemoryManagerDeviceInterface>();
adapter_ = std::make_shared<fl::CachingMemoryManager>(
af::getDeviceCount(), deviceInterface_);
installer_ = fl::cpp::make_unique<fl::MemoryManagerInstaller>(adapter_);
installer_->setAsMemoryManager();
Expand All @@ -32,6 +31,8 @@ class CachingMemoryManagerTest : public ::testing::Test {
af_unset_memory_manager();
}

std::shared_ptr<fl::MemoryManagerDeviceInterface> deviceInterface_;
std::shared_ptr<fl::CachingMemoryManager> adapter_;
std::unique_ptr<fl::MemoryManagerInstaller> installer_;
};

Expand Down Expand Up @@ -137,6 +138,54 @@ TEST_F(CachingMemoryManagerTest, OOM) {
}
}

void testFragmentation(
std::shared_ptr<fl::MemoryManagerDeviceInterface> deviceInterface_,
std::shared_ptr<fl::CachingMemoryManager> adapter_,
bool expectOOM) {
af::Backend b = af::getActiveBackend();

if (b != AF_BACKEND_CUDA) {
GTEST_SKIP()
<< "CachingMemoryManager fragmentation tests require CUDA backend";
}

const auto mms = deviceInterface_->getMaxMemorySize(0);
const auto maxNumf32 = mms / sizeof(float); // AF f32 is supposed to be 32b
ASSERT_NE(mms, 0);
{
af::array a1(.5f * maxNumf32);
adapter_->printInfo("After creating a1:", 0);
} // The a1 buffer will not be freed here, just registered to the cache
adapter_->printInfo("After releasing a1:", 0);

af::array a2(.1f * maxNumf32);
adapter_->printInfo("After creating a2:", 0);

af::array a3;
try {
a3 = af::array(.5f * maxNumf32);
} catch (af::exception& ex) {
if (expectOOM) {
ASSERT_EQ(ex.err(), AF_ERR_NO_MEM);
} else {
EXPECT_TRUE(false)
<< "CachingMemoryManagerTest fragmentaiton not supposed to throw: "
<< ex.what();
}
}
}

TEST_F(CachingMemoryManagerTest, Fragmentation) {
testFragmentation(deviceInterface_, adapter_, true); // should OOM
}

TEST_F(CachingMemoryManagerTest, RecLimit) {
constexpr static size_t ONE_GB = 1 << 30;
// Fine set the manager in order not to recycle big tensors:
adapter_->setRecyclingSizeLimit(2 * ONE_GB);
testFragmentation(deviceInterface_, adapter_, false); // should not OOM
}

int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
Expand Down

0 comments on commit f701125

Please sign in to comment.