Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature] Offloader support readv/writev #9

Merged
merged 4 commits into from
Jul 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion colo_nvme/_C/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from torch import Tensor
from typing import Callable, Optional
from typing import Callable, Optional, List


class Offloader:
Expand All @@ -11,3 +11,7 @@ class Offloader:
def sync_write_events(self) -> None: ...
def sync_read_events(self) -> None: ...
def synchronize(self) -> None: ...
def async_writev(self, tensors: List[Tensor], key: str, callback: Optional[Callable[[], None]] = None) -> None: ...
def async_readv(self, tensors: List[Tensor], key: str, callback: Optional[Callable[[], None]] = None) -> None: ...
def sync_writev(self, tensors: List[Tensor], key: str) -> None: ...
def sync_readv(self, tensors: List[Tensor], key: str) -> None: ...
50 changes: 41 additions & 9 deletions colo_nvme/offload.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import os
import torch
import uuid
from typing import Callable, Optional
from functools import partial
from typing import Callable, Optional, List
from colo_nvme._C import Offloader


Expand All @@ -19,7 +18,12 @@ def __init__(self, dir_name: str, n_entries: int = 128, backend: str = 'uring')

def async_write(self, tensor: torch.Tensor, callback: Optional[Callable[[], None]] = None) -> None:
    """Submit an asynchronous write of ``tensor`` and free its storage on completion.

    The tensor is stored under the key ``str(id(tensor))``.

    Args:
        tensor: Tensor to offload; its storage must be non-empty.
        callback: Optional zero-argument callable invoked after the storage
            has been released.
    """
    assert tensor.storage().size() > 0

    def callback_fn():
        # Release the in-memory storage first, then notify the caller.
        tensor.storage().resize_(0)
        if callback is not None:
            callback()

    # Submit exactly once; the earlier partial(...)-based submission was superseded by callback_fn.
    super().async_write(tensor, str(id(tensor)), callback_fn)

def async_read(self, tensor: torch.Tensor, callback: Optional[Callable[[], None]] = None) -> None:
if tensor.storage().size() == 0:
Expand All @@ -29,15 +33,43 @@ def async_read(self, tensor: torch.Tensor, callback: Optional[Callable[[], None]
def sync_write(self, tensor: torch.Tensor) -> None:
    """Write ``tensor`` synchronously, then release its in-memory storage.

    The tensor is stored under the key ``str(id(tensor))``.

    Args:
        tensor: Tensor to offload; its storage must be non-empty.
    """
    assert tensor.storage().size() > 0
    super().sync_write(tensor, str(id(tensor)))
    # Free the storage exactly once, directly (no helper indirection).
    tensor.storage().resize_(0)

def sync_read(self, tensor: torch.Tensor) -> None:
    """Read ``tensor`` back synchronously using the key ``str(id(tensor))``.

    If the tensor's storage is empty it is first resized to ``tensor.numel()``
    elements so there is memory to read into.
    """
    storage = tensor.storage()
    needs_allocation = storage.size() == 0
    if needs_allocation:
        storage.resize_(tensor.numel())
    tensor_key = str(id(tensor))
    super().sync_read(tensor, tensor_key)

@staticmethod
def _write_callback(tensor: torch.Tensor, callback: Optional[Callable[[], None]] = None) -> None:
tensor.storage().resize_(0)
if callback is not None:
callback()
def async_writev(self, tensors: List[torch.Tensor], callback: Optional[Callable[[], None]] = None) -> None:
    """Submit one asynchronous vectored write covering every tensor in ``tensors``.

    The group is stored under the key ``str(hash(tuple(tensors)))``. When the
    write completes, every tensor's storage is resized to 0 and ``callback``
    (if given) is invoked.

    Args:
        tensors: Tensors to offload together; each must have non-empty storage.
        callback: Optional zero-argument callable run after the storages are released.
    """
    assert all(t.storage().size() > 0 for t in tensors)
    group_key = str(hash(tuple(tensors)))

    def on_written():
        # Drop the in-memory copies before notifying the caller.
        for t in tensors:
            t.storage().resize_(0)
        if callback is not None:
            callback()

    super().async_writev(tensors, group_key, on_written)

def async_readv(self, tensors: List[torch.Tensor], callback: Optional[Callable[[], None]] = None) -> None:
    """Submit one asynchronous vectored read that refills every tensor in ``tensors``.

    Any tensor whose storage is empty is resized to ``numel()`` elements first.
    The group is looked up under ``str(hash(tuple(tensors)))``, the same key
    derivation the vectored write methods use.

    Args:
        tensors: Tensors to restore.
        callback: Optional zero-argument callable forwarded to the backend.
    """
    for t in tensors:
        storage = t.storage()
        if storage.size() == 0:
            storage.resize_(t.numel())
    group_key = str(hash(tuple(tensors)))
    super().async_readv(tensors, group_key, callback)

def sync_writev(self, tensors: List[torch.Tensor]) -> None:
    """Write all ``tensors`` with one synchronous vectored write, then free their storages.

    The group is stored under the key ``str(hash(tuple(tensors)))``.

    Args:
        tensors: Tensors to offload together; each must have non-empty storage.
    """
    assert all(t.storage().size() > 0 for t in tensors)
    group_key = str(hash(tuple(tensors)))
    super().sync_writev(tensors, group_key)
    # The data is on disk now; release every in-memory copy.
    for t in tensors:
        t.storage().resize_(0)

def sync_readv(self, tensors: List[torch.Tensor]) -> None:
    """Refill all ``tensors`` with one synchronous vectored read.

    Any tensor whose storage is empty is resized to ``numel()`` elements first.
    The group is looked up under ``str(hash(tuple(tensors)))``.

    Args:
        tensors: Tensors to restore.
    """
    for t in tensors:
        storage = t.storage()
        if storage.size() == 0:
            storage.resize_(t.numel())
    group_key = str(hash(tuple(tensors)))
    super().sync_readv(tensors, group_key)
90 changes: 88 additions & 2 deletions csrc/offload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@
#include "aio.h"
#include "space_mgr.h"

/**
 * Build an iovec array describing the memory of each tensor, in order.
 *
 * Entry i points at tensors[i].data_ptr() and spans tensors[i].storage().nbytes() bytes.
 * NOTE(review): the length is the size of the whole storage while the base is
 * data_ptr(); this presumably assumes each tensor starts at the beginning of
 * its own storage -- confirm.
 *
 * @param tensors Tensors to describe.
 * @return Array of tensors.size() entries allocated with calloc(); it must be
 *         released with free() (not delete). May be null when tensors is empty.
 * @throws std::runtime_error if the array cannot be allocated.
 */
iovec *tensors_to_iovec(const std::vector<at::Tensor> &tensors)
{
    const size_t count = tensors.size();
    iovec *iovs = static_cast<iovec *>(calloc(count, sizeof(iovec)));
    // calloc(0, ...) may legitimately return null; only a failed non-empty request is an error.
    if (iovs == nullptr && count > 0)
        throw std::runtime_error("Failed to allocate iovec array");
    for (size_t i = 0; i < count; i++)
    {
        iovs[i].iov_base = tensors[i].data_ptr();
        iovs[i].iov_len = tensors[i].storage().nbytes();
    }
    return iovs;
}

class Offloader
{
public:
Expand All @@ -34,7 +45,7 @@ class Offloader
throw std::runtime_error("Tensor must be contiguous and on cpu");
ull bytes = tensor.storage().nbytes();
ull offset = this->space_mgr.alloc(bytes);
SpaceInfo space_info = SpaceInfo(offset, bytes);
SpaceInfo space_info(offset, bytes);
this->tensors_info[key] = space_info;
return space_info;
}
Expand Down Expand Up @@ -109,6 +120,77 @@ class Offloader
printf("Remove \"%s\" error(%d): %s\n", this->filename.c_str(), errno, strerror(errno));
}

/**
 * Validate a group of tensors and reserve file space for all of them at once.
 *
 * The sum of every tensor's storage size is allocated from space_mgr and the
 * resulting (offset, bytes) pair is recorded in tensors_info under key.
 *
 * @param tensors Tensors that will be written together.
 * @param key Identifier under which the reservation is stored.
 * @return The (offset, total bytes) reservation.
 * @throws std::runtime_error if any tensor is non-contiguous or not on the CPU.
 */
SpaceInfo prepare_writev(const std::vector<at::Tensor> &tensors, const std::string &key)
{
    ull bytes_needed = 0;
    for (size_t idx = 0; idx < tensors.size(); ++idx)
    {
        const at::Tensor &current = tensors[idx];
        const bool usable = current.is_contiguous() && current.is_cpu();
        if (!usable)
            throw std::runtime_error("Tensor must be contiguous and on cpu");
        bytes_needed += current.storage().nbytes();
    }
    const ull start = this->space_mgr.alloc(bytes_needed);
    SpaceInfo reserved(start, bytes_needed);
    this->tensors_info[key] = reserved;
    return reserved;
}

/**
 * Validate a group of tensors and look up the file region recorded for key.
 *
 * On success the entry is removed from tensors_info. If the key is missing or
 * the sizes disagree, an exception is thrown and the entry is left in place.
 *
 * @param tensors Tensors that will be read into.
 * @param key Identifier the group was written under.
 * @return The recorded (offset, total bytes) pair.
 * @throws std::runtime_error if a tensor is non-contiguous or not on the CPU,
 *         if key is unknown, or if the summed storage size differs from the
 *         recorded size.
 */
SpaceInfo prepare_readv(const std::vector<at::Tensor> &tensors, const std::string &key)
{
    ull requested_bytes = 0;
    for (size_t idx = 0; idx < tensors.size(); ++idx)
    {
        const at::Tensor &current = tensors[idx];
        const bool usable = current.is_contiguous() && current.is_cpu();
        if (!usable)
            throw std::runtime_error("Tensor must be contiguous and on cpu");
        requested_bytes += current.storage().nbytes();
    }
    const auto entry = this->tensors_info.find(key);
    if (entry == this->tensors_info.end())
        throw std::runtime_error("Read error, tensor not found");
    // Copy the record out before erasing so it stays valid for the caller.
    SpaceInfo recorded = entry->second;
    if (requested_bytes != recorded.second)
        throw std::runtime_error("Read error, tensor shape mismatch");
    this->tensors_info.erase(entry);
    return recorded;
}

/**
 * Submit an asynchronous vectored write of tensors to the offload file.
 *
 * Space is reserved via prepare_writev() and the request is handed to the aio
 * backend together with callback.
 * NOTE(review): the iovec array is not freed here; presumably the aio backend
 * takes ownership of it -- confirm.
 *
 * @param tensors Tensors to write, in order.
 * @param key Identifier under which the group is recorded.
 * @param callback Optional completion callback forwarded to the backend.
 */
void async_writev(const std::vector<at::Tensor> &tensors, const std::string &key, callback_t callback = nullptr)
{
    ull file_offset, reserved_bytes;
    std::tie(file_offset, reserved_bytes) = prepare_writev(tensors, key);
    iovec *const vectors = tensors_to_iovec(tensors);
    this->aio->writev(this->fd, vectors, tensors.size(), file_offset, callback);
}

/**
 * Submit an asynchronous vectored read that refills tensors from the offload file.
 *
 * The recorded region is fetched via prepare_readv(). The completion handler
 * passed to the backend is Offloader::release bound to the region's offset,
 * its size and the user callback.
 * NOTE(review): the iovec array is not freed here; presumably the aio backend
 * takes ownership of it -- confirm.
 *
 * @param tensors Tensors to read into, in order.
 * @param key Identifier the group was written under.
 * @param callback Optional user callback passed through to release().
 */
void async_readv(const std::vector<at::Tensor> &tensors, const std::string &key, callback_t callback = nullptr)
{
    ull file_offset, region_bytes;
    std::tie(file_offset, region_bytes) = prepare_readv(tensors, key);
    iovec *const vectors = tensors_to_iovec(tensors);
    auto on_complete = std::bind(&Offloader::release, this, file_offset, region_bytes, callback);
    this->aio->readv(this->fd, vectors, tensors.size(), file_offset, on_complete);
}

/**
 * Synchronously write tensors to the offload file with a single writev().
 *
 * Space is reserved via prepare_writev(), the file position is moved to the
 * reserved offset and all tensor buffers are written in order.
 *
 * @param tensors Tensors to write, in order.
 * @param key Identifier under which the group is recorded.
 * @throws std::runtime_error if validation fails, if seeking or writing fails,
 *         or if fewer bytes than reserved were written.
 */
void sync_writev(const std::vector<at::Tensor> &tensors, const std::string &key)
{
    ull offset, bytes;
    std::tie(offset, bytes) = prepare_writev(tensors, key);
    iovec *iov = tensors_to_iovec(tensors);
    ssize_t written = -1;
    int io_errno = 0;
    if (lseek(this->fd, offset, SEEK_SET) == static_cast<off_t>(-1))
    {
        io_errno = errno;
    }
    else
    {
        written = writev(this->fd, iov, tensors.size());
        if (written < 0)
            io_errno = errno;
    }
    // The array comes from calloc(), so it must be released with free(), not delete.
    free(iov);
    if (written < 0)
        throw std::runtime_error(std::string("Write error: ") + strerror(io_errno));
    if (static_cast<ull>(written) != bytes)
        throw std::runtime_error("Write error, short write");
}

/**
 * Synchronously refill tensors from the offload file with a single readv().
 *
 * The recorded region is fetched via prepare_readv(), the file position is
 * moved to its offset and the data is read into the tensor buffers in order.
 * NOTE(review): unlike async_readv(), this path never calls release() for the
 * region -- confirm whether the space should be returned here as well.
 *
 * @param tensors Tensors to read into, in order.
 * @param key Identifier the group was written under.
 * @throws std::runtime_error if validation or lookup fails, if seeking or
 *         reading fails, or if fewer bytes than recorded were read.
 */
void sync_readv(const std::vector<at::Tensor> &tensors, const std::string &key)
{
    ull offset, bytes;
    std::tie(offset, bytes) = prepare_readv(tensors, key);
    iovec *iov = tensors_to_iovec(tensors);
    ssize_t received = -1;
    int io_errno = 0;
    if (lseek(this->fd, offset, SEEK_SET) == static_cast<off_t>(-1))
    {
        io_errno = errno;
    }
    else
    {
        received = readv(this->fd, iov, tensors.size());
        if (received < 0)
            io_errno = errno;
    }
    // The array comes from calloc(), so it must be released with free(), not delete.
    free(iov);
    if (received < 0)
        throw std::runtime_error(std::string("Read error: ") + strerror(io_errno));
    if (static_cast<ull>(received) != bytes)
        throw std::runtime_error("Read error, short read");
}

private:
const std::string filename;
int fd;
Expand All @@ -134,5 +216,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
.def("sync_read", &Offloader::sync_read, py::arg("tensor"), py::arg("key"))
.def("sync_write_events", &Offloader::sync_write_events)
.def("sync_read_events", &Offloader::sync_write_events)
.def("synchronize", &Offloader::synchronize);
.def("synchronize", &Offloader::synchronize)
.def("async_writev", &Offloader::async_writev, py::arg("tensors"), py::arg("key"), py::arg("callback") = py::none())
.def("async_readv", &Offloader::async_readv, py::arg("tensors"), py::arg("key"), py::arg("callback") = py::none())
.def("sync_writev", &Offloader::sync_writev, py::arg("tensors"), py::arg("key"))
.def("sync_readv", &Offloader::sync_readv, py::arg("tensors"), py::arg("key"));
}
109 changes: 109 additions & 0 deletions tests/test_disk_offloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import torch
import pytest
from colo_nvme import DiskOffloader


@pytest.mark.parametrize('backend', ['uring', 'aio'])
def test_sync_io(backend):
    """Round-trip one tensor through sync_write / sync_read."""
    x = torch.rand(2, 2)
    expected = x.clone()
    offloader = DiskOffloader('.', backend=backend)

    # Reading a tensor that was never written must fail.
    try:
        offloader.sync_read(x)
    except RuntimeError:
        pass
    else:
        assert False

    offloader.sync_write(x)
    assert x.storage().size() == 0

    offloader.sync_read(x)
    assert torch.equal(x, expected)


@pytest.mark.parametrize('backend', ['uring', 'aio'])
def test_async_io(backend):
    """Round-trip one tensor through async_write / async_read."""
    x = torch.rand(2, 2)
    x_copy = x.clone()
    of = DiskOffloader('.', backend=backend)

    # Reading a tensor that was never written must fail.
    try:
        of.async_read(x)
        assert False
    except RuntimeError:
        pass

    of.async_write(x)
    # Storage is only released by the completion callback, not at submit time.
    assert x.storage().size() > 0
    of.sync_write_events()
    assert x.storage().size() == 0

    # Use the asynchronous read so that sync_read_events() has something to wait for.
    of.async_read(x)
    of.sync_read_events()
    assert torch.equal(x, x_copy)


@pytest.mark.parametrize('backend', ['uring', 'aio'])
def test_sync_vec_io(backend):
    """Round-trip two tensors through sync_writev / sync_readv."""
    x = torch.rand(2, 2)
    y = torch.rand(2, 2, 2)
    expected_x = x.clone()
    expected_y = y.clone()
    offloader = DiskOffloader('.', backend=backend)

    # Reading a group that was never written must fail.
    try:
        offloader.sync_readv([x, y])
    except RuntimeError:
        pass
    else:
        assert False

    offloader.sync_writev([x, y])
    assert x.storage().size() == 0
    assert y.storage().size() == 0

    # Passing a bare tensor instead of the written group must fail.
    try:
        offloader.sync_readv(x)
    except RuntimeError:
        pass
    else:
        assert False

    # The same tensors in a different order must fail as well.
    try:
        offloader.sync_readv([y, x])
    except RuntimeError:
        pass
    else:
        assert False

    offloader.sync_readv([x, y])
    assert torch.equal(x, expected_x)
    assert torch.equal(y, expected_y)


@pytest.mark.parametrize('backend', ['uring', 'aio'])
def test_async_vec_io(backend):
    """Round-trip two tensors through async_writev / async_readv."""
    x = torch.rand(2, 2)
    y = torch.rand(2, 2, 2)
    expected_x = x.clone()
    expected_y = y.clone()
    offloader = DiskOffloader('.', backend=backend)

    # Reading a group that was never written must fail.
    try:
        offloader.async_readv([x, y])
    except RuntimeError:
        pass
    else:
        assert False

    offloader.async_writev([x, y])
    # Storages are only released by the completion callback.
    assert x.storage().size() > 0
    assert y.storage().size() > 0
    offloader.sync_write_events()
    assert x.storage().size() == 0
    assert y.storage().size() == 0

    # Passing a bare tensor instead of the written group must fail.
    try:
        offloader.async_readv(x)
    except RuntimeError:
        pass
    else:
        assert False

    # The same tensors in a different order must fail as well.
    try:
        offloader.async_readv([y, x])
    except RuntimeError:
        pass
    else:
        assert False

    offloader.async_readv([x, y])
    offloader.sync_read_events()
    assert torch.equal(x, expected_x)
    assert torch.equal(y, expected_y)


if __name__ == '__main__':
    # Allow running the suite directly (without pytest) against the uring backend.
    for run_test in (test_sync_io, test_async_io, test_sync_vec_io, test_async_vec_io):
        run_test('uring')