Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add flags to set caching memory manager split+recycling size limits #420

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions flashlight/app/asr/Train.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ DEFINE_int64(
} // namespace

int main(int argc, char** argv) {
fl::init();
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
std::string exec(argv[0]);
Expand Down Expand Up @@ -172,6 +171,8 @@ int main(int argc, char** argv) {
LOG(FATAL) << "'runpath' specified by --rundir, --runname cannot be empty";
}

fl::init(FLAGS_fl_mem_recycling_size, FLAGS_fl_mem_split_size);

af::setSeed(FLAGS_seed);
fl::DynamicBenchmark::setBenchmarkMode(FLAGS_fl_benchmark_mode);

Expand Down Expand Up @@ -608,7 +609,7 @@ int main(int argc, char** argv) {
};

std::ofstream memLog;
if (FLAGS_fl_log_mem_ops_interval > 0 && isMaster) {
if (FLAGS_fl_mem_log_ops_interval > 0 && isMaster) {
auto* curMemMgr =
fl::MemoryManagerInstaller::currentlyInstalledMemoryManager();
if (curMemMgr) {
Expand All @@ -619,7 +620,7 @@ int main(int argc, char** argv) {
}
curMemMgr->setLogStream(&memLog);
curMemMgr->setLoggingEnabled(true);
curMemMgr->setLogFlushInterval(FLAGS_fl_log_mem_ops_interval);
curMemMgr->setLogFlushInterval(FLAGS_fl_mem_log_ops_interval);
}
}

Expand Down
22 changes: 21 additions & 1 deletion flashlight/app/asr/common/Flags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,12 +318,32 @@ DEFINE_string(
DEFINE_int64(fl_vlog_level, 0, "Sets the verbose logging level");

DEFINE_int64(
fl_log_mem_ops_interval,
fl_mem_log_ops_interval,
0,
"Flushes memory manager logs after a specified "
"number of log entries. 1000000 is a reasonable "
"value which will reduce overhead.");

DEFINE_int64(
fl_mem_recycling_size,
(1L << 28), /* 256MB */
"prevents the caching memory manager from recycling buffers larger "
"than this value. Recycled buffers can be split by the caching "
"manager so it helps reduce fragmentation of buffers over this value."
"Default value of 256MB works well for typical workload where "
"number of allocation is exponentially decreasing with allocation "
"size and largest allocations are ~500MB.");

DEFINE_int64(
fl_mem_split_size,
(1L << 29), /* 512MB */
"prevents the caching memory manager from splitting buffers larger "
"than this value. Helps reduce external fragmentation by allowing "
"higher internal fragmentation."
"Default value of 512MB works well for typical workload where "
"number of allocation is exponentially decreasing with allocation "
"size and largest allocations are ~500MB.");

// MIXED PRECISION OPTIONS
DEFINE_bool(
fl_amp_use_mixed_precision,
Expand Down
4 changes: 3 additions & 1 deletion flashlight/app/asr/common/Flags.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ DECLARE_bool(fl_benchmark_mode);
DECLARE_string(fl_optim_mode);
DECLARE_string(fl_log_level);
DECLARE_int64(fl_vlog_level);
DECLARE_int64(fl_log_mem_ops_interval);
DECLARE_int64(fl_mem_log_ops_interval);
DECLARE_int64(fl_mem_recycling_size);
DECLARE_int64(fl_mem_split_size);

/* ========== MIXED PRECISION OPTIONS ========== */

Expand Down
18 changes: 16 additions & 2 deletions flashlight/fl/common/Init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <af/device.h>

#include "flashlight/fl/memory/MemoryManagerInstaller.h"
#include "flashlight/fl/memory/managers/CachingMemoryManager.h"

namespace fl {
namespace {
Expand All @@ -23,14 +24,27 @@ std::once_flag flInitFlag;
*
* Can only be called once per process. Subsequent calls will be noops.
*/
void init() {
std::call_once(flInitFlag, []() {
void init(int memRecyclingSize /*=-1*/, int memSplitSize /*=-1*/) {
std::call_once(flInitFlag, [memRecyclingSize, memSplitSize]() {
af_init();
// TODO: remove this temporary workaround for TextDatasetTest crash on CPU
// backend when tearing down the test environment. This is possibly due to
// AF race conditions when tearing down our custom memory manager.
if (!FL_BACKEND_CPU) {
MemoryManagerInstaller::installDefaultMemoryManager();
auto* curMemMgr =
fl::MemoryManagerInstaller::currentlyInstalledMemoryManager();
if (curMemMgr) {
auto cachMemMgr = dynamic_cast<fl::CachingMemoryManager*>(curMemMgr);
if (cachMemMgr) {
if (memRecyclingSize > -1) {
cachMemMgr->setRecyclingSizeLimit(memRecyclingSize);
}
if (memSplitSize > -1) {
cachMemMgr->setSplitSizeLimit(memSplitSize);
}
}
}
}
});
}
Expand Down
2 changes: 1 addition & 1 deletion flashlight/fl/common/Init.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ namespace fl {
/**
* Initialize Flashlight.
*/
void init();
void init(int memRecyclingSize=-1, int memSplitSize=-1);

} // namespace fl