Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable page-locked tensors without CUDA #2775

Merged
merged 13 commits
Feb 7, 2023
Merged
5 changes: 3 additions & 2 deletions csrc/aio/common/deepspeed_aio_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
*/

#include <cmath>
#include <iostream>

#include "deepspeed_aio_utils.h"

Expand Down Expand Up @@ -113,8 +114,8 @@ void* ds_page_aligned_alloc(const size_t size, const bool lock)
auto mlock_ret = mlock(ptr, size);
if (mlock_ret != 0) {
auto mlock_error = errno;
printf("mlock failed with %d %s\n", mlock_error, strerror(mlock_error));

std::cerr << "mlock failed to allocate " << size << " bytes with error no " << mlock_error
<< " msg " << strerror(mlock_error) << std::endl;
free(ptr);
return nullptr;
}
Expand Down
43 changes: 43 additions & 0 deletions csrc/aio/py_lib/deepspeed_pin_tensor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
Copyright 2023 The Microsoft DeepSpeed Team
Licensed under the MIT license.

Functionality for managing CPU tensors occupying page-locked memory.
*/

#include "deepspeed_pin_tensor.h"

using namespace std;

deepspeed_pin_tensor_t::~deepspeed_pin_tensor_t()
{
    // Unpin every region we are still tracking. Note this only munlock()s
    // the pages; the buffers themselves are not free()d here (see the
    // leak TODO in the header).
    for (auto& entry : _locked_tensors) { munlock(entry.first, entry.second); }
    _locked_tensors.clear();
}

torch::Tensor deepspeed_pin_tensor_t::alloc(const size_t num_elem, const at::ScalarType& elem_type)
{
    // Allocate a page-aligned, page-locked (mlock'ed) buffer big enough for
    // num_elem elements of elem_type and wrap it in a CPU tensor.
    const auto num_bytes = num_elem * elementSize(elem_type);
    auto pinned_buffer = ds_page_aligned_alloc(num_bytes, true);
    // TODO(review): assert compiles away under NDEBUG, leaving from_blob to
    // receive nullptr on allocation failure — consider an explicit error path.
    assert(nullptr != pinned_buffer);

    // Remember the region so free()/the destructor can munlock it later.
    _locked_tensors[pinned_buffer] = num_bytes;

    auto options = torch::TensorOptions().dtype(elem_type).device(torch::kCPU);

    // Bug fix: from_blob's size argument is an element count, not a byte
    // count. Passing num_bytes over-declared the tensor by a factor of
    // elementSize(elem_type) for multi-byte dtypes, letting callers read and
    // write past the end of the pinned allocation.
    return at::from_blob(pinned_buffer, static_cast<long int>(num_elem), options);
}

bool deepspeed_pin_tensor_t::free(torch::Tensor& locked_tensor)
{
    // Unpin and stop tracking a tensor previously produced by alloc().
    // Returns false when the tensor's buffer is not one of ours.
    // NOTE(review): like the original, this munlock()s but never free()s the
    // buffer — consistent with the leak TODO in the header.
    auto addr = locked_tensor.data_ptr();
    auto iter = _locked_tensors.find(addr);
    if (iter != _locked_tensors.end()) {
        // Use the iterator for both the size and the erase: one map lookup
        // instead of the original find + operator[] + erase-by-key (three).
        munlock(addr, iter->second);
        _locked_tensors.erase(iter);
        return true;
    }

    return false;
}
24 changes: 24 additions & 0 deletions csrc/aio/py_lib/deepspeed_pin_tensor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
Copyright 2023 The Microsoft DeepSpeed Team
Licensed under the MIT license.

Functionality for managing CPU tensors occupying page-locked memory.
TODO: Implement a full-featured manager that
1. Avoid page-locked memory leaks
2. Minimize page-locked memory usage by reducing internal fragmentation
*/

#include <map>
#include "deepspeed_py_aio.h"

struct deepspeed_pin_tensor_t {
    // Base address -> size in bytes for every page-locked buffer handed out
    // by alloc() and not yet released via free().
    std::map<void*, size_t> _locked_tensors;

    deepspeed_pin_tensor_t() = default;

    // Unlocks (munlock) all regions still registered in _locked_tensors.
    ~deepspeed_pin_tensor_t();

    // Returns a CPU tensor of num_elem elements of elem_type backed by
    // page-aligned, page-locked memory.
    torch::Tensor alloc(const size_t num_elem, const at::ScalarType& elem_type);

    // Unpins the tensor's buffer; returns false if it was not allocated here.
    bool free(torch::Tensor& locked_tensor);
};
14 changes: 13 additions & 1 deletion csrc/aio/py_lib/deepspeed_py_aio_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ deepspeed_aio_handle_t::deepspeed_aio_handle_t(const int block_size,
_overlap_events(overlap_events),
_num_threads(num_threads),
_aio_config(block_size, queue_depth, single_submit, overlap_events, false),
_num_pending_ops(0)
_num_pending_ops(0),
_pinned_tensor_mgr(new deepspeed_pin_tensor_t())
{
for (auto i = 0; i < num_threads; ++i) {
_thread_contexts.push_back(std::make_shared<deepspeed_aio_thread_t>(i, _aio_config));
Expand Down Expand Up @@ -280,3 +281,14 @@ int deepspeed_aio_handle_t::async_pwrite(const torch::Tensor& buffer, const char
{
return pwrite(buffer, filename, false, true);
}

at::Tensor deepspeed_aio_handle_t::new_cpu_locked_tensor(const size_t num_elem,
                                                         const torch::Tensor& example_tensor)
{
    // Delegate to the pinned-tensor manager; the example tensor contributes
    // only its dtype.
    const auto dtype = example_tensor.scalar_type();
    return _pinned_tensor_mgr->alloc(num_elem, dtype);
}

bool deepspeed_aio_handle_t::free_cpu_locked_tensor(torch::Tensor& locked_tensor)
{
    // Forward to the manager; true iff the tensor's buffer was pinned by us.
    const bool released = _pinned_tensor_mgr->free(locked_tensor);
    return released;
}
7 changes: 7 additions & 0 deletions csrc/aio/py_lib/deepspeed_py_aio_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
#include <condition_variable>
#include <memory>
#include "deepspeed_aio_thread.h"
#include "deepspeed_pin_tensor.h"

struct deepspeed_aio_handle_t {
std::unique_ptr<struct aio_context> _aio_ctxt;
Expand All @@ -19,6 +20,7 @@ struct deepspeed_aio_handle_t {
std::vector<std::shared_ptr<struct deepspeed_aio_thread_t>> _thread_contexts;
std::vector<std::thread> _threads;
int _num_pending_ops;
std::unique_ptr<struct deepspeed_pin_tensor_t> _pinned_tensor_mgr;

deepspeed_aio_handle_t(const int block_size,
const int queue_depth,
Expand Down Expand Up @@ -56,6 +58,11 @@ struct deepspeed_aio_handle_t {

int async_pwrite(const torch::Tensor& buffer, const char* filename);

// TODO: Make API's args to be shape and dtype.
torch::Tensor new_cpu_locked_tensor(const size_t num_elem, const torch::Tensor& example_tensor);

bool free_cpu_locked_tensor(torch::Tensor&);

int wait();

void _stop_threads();
Expand Down
3 changes: 3 additions & 0 deletions csrc/aio/py_lib/py_ds_aio.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
.def("async_pread", &deepspeed_aio_handle_t::async_pread)
.def("async_pwrite", &deepspeed_aio_handle_t::async_pwrite)

.def("new_cpu_locked_tensor", &deepspeed_aio_handle_t::new_cpu_locked_tensor)
.def("free_cpu_locked_tensor", &deepspeed_aio_handle_t::free_cpu_locked_tensor)

.def("wait", &deepspeed_aio_handle_t::wait);
}
3 changes: 2 additions & 1 deletion op_builder/async_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def sources(self):
'csrc/aio/py_lib/deepspeed_aio_thread.cpp',
'csrc/aio/common/deepspeed_aio_utils.cpp',
'csrc/aio/common/deepspeed_aio_common.cpp',
'csrc/aio/common/deepspeed_aio_types.cpp'
'csrc/aio/common/deepspeed_aio_types.cpp',
'csrc/aio/py_lib/deepspeed_pin_tensor.cpp'
]

def include_paths(self):
Expand Down
Loading