[feature] Offloader support readv/writev (#9)
* offloader supports readv/writev

* update pyi

* update DiskOffloader

* add test dist offload
ver217 committed Jul 7, 2022
1 parent 1ddde2e commit 9550cbf
Showing 4 changed files with 243 additions and 12 deletions.
6 changes: 5 additions & 1 deletion colo_nvme/_C/__init__.pyi
@@ -1,5 +1,5 @@
from torch import Tensor
from typing import Callable, Optional
from typing import Callable, Optional, List


class Offloader:
@@ -11,3 +11,7 @@ class Offloader:
def sync_write_events(self) -> None: ...
def sync_read_events(self) -> None: ...
def synchronize(self) -> None: ...
def async_writev(self, tensors: List[Tensor], key: str, callback: Optional[Callable[[], None]] = None) -> None: ...
def async_readv(self, tensors: List[Tensor], key: str, callback: Optional[Callable[[], None]] = None) -> None: ...
def sync_writev(self, tensors: List[Tensor], key: str) -> None: ...
def sync_readv(self, tensors: List[Tensor], key: str) -> None: ...
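
The four new stub entries mirror the scalar async_write/async_read/sync_write/sync_read methods: each takes a list of tensors plus a string key, and the async variants accept an optional zero-argument completion callback. A minimal sketch of the expected call shape, assuming an already-constructed low-level Offloader instance `of` (its constructor signature is outside this diff) and illustrative keys such as 'batch-0':

import torch
from colo_nvme._C import Offloader

def vector_io_demo(of: Offloader) -> None:
    # Two contiguous CPU tensors stored together under one key.
    tensors = [torch.rand(2, 2), torch.rand(4)]
    of.sync_writev(tensors, 'batch-0')   # blocking vectored write
    of.sync_readv(tensors, 'batch-0')    # blocking vectored read into the same tensors

    # Async variants take an optional callback fired when the I/O completes.
    of.async_writev(tensors, 'batch-1', lambda: print('writev done'))
    of.synchronize()                     # wait for all outstanding requests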
50 changes: 41 additions & 9 deletions colo_nvme/offload.py
@@ -1,8 +1,7 @@
import os
import torch
import uuid
from typing import Callable, Optional
from functools import partial
from typing import Callable, Optional, List
from colo_nvme._C import Offloader


@@ -19,7 +18,12 @@ def __init__(self, dir_name: str, n_entries: int = 128, backend: str = 'uring')

def async_write(self, tensor: torch.Tensor, callback: Optional[Callable[[], None]] = None) -> None:
assert tensor.storage().size() > 0
super().async_write(tensor, str(id(tensor)), partial(DiskOffloader._write_callback, tensor, callback))

def callback_fn():
tensor.storage().resize_(0)
if callback is not None:
callback()
super().async_write(tensor, str(id(tensor)), callback_fn)

def async_read(self, tensor: torch.Tensor, callback: Optional[Callable[[], None]] = None) -> None:
if tensor.storage().size() == 0:
@@ -29,15 +33,43 @@ def async_read(self, tensor: torch.Tensor, callback: Optional[Callable[[], None]
def sync_write(self, tensor: torch.Tensor) -> None:
assert tensor.storage().size() > 0
super().sync_write(tensor, str(id(tensor)))
self._write_callback(tensor)
tensor.storage().resize_(0)

def sync_read(self, tensor: torch.Tensor) -> None:
if tensor.storage().size() == 0:
tensor.storage().resize_(tensor.numel())
super().sync_read(tensor, str(id(tensor)))

@staticmethod
def _write_callback(tensor: torch.Tensor, callback: Optional[Callable[[], None]] = None) -> None:
tensor.storage().resize_(0)
if callback is not None:
callback()
def async_writev(self, tensors: List[torch.Tensor], callback: Optional[Callable[[], None]] = None) -> None:
for tensor in tensors:
assert tensor.storage().size() > 0
key = str(hash(tuple(tensors)))

def callback_fn():
for tensor in tensors:
tensor.storage().resize_(0)
if callback is not None:
callback()
super().async_writev(tensors, key, callback_fn)

def async_readv(self, tensors: List[torch.Tensor], callback: Optional[Callable[[], None]] = None) -> None:
for tensor in tensors:
if tensor.storage().size() == 0:
tensor.storage().resize_(tensor.numel())
key = str(hash(tuple(tensors)))
super().async_readv(tensors, key, callback)

def sync_writev(self, tensors: List[torch.Tensor]) -> None:
for tensor in tensors:
assert tensor.storage().size() > 0
key = str(hash(tuple(tensors)))
super().sync_writev(tensors, key)
for tensor in tensors:
tensor.storage().resize_(0)

def sync_readv(self, tensors: List[torch.Tensor]) -> None:
for tensor in tensors:
if tensor.storage().size() == 0:
tensor.storage().resize_(tensor.numel())
key = str(hash(tuple(tensors)))
super().sync_readv(tensors, key)
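
Taken together with the scalar paths above, the vector methods batch several tensors into one I/O request and manage their storage the same way: writev frees each tensor's storage once the data is on disk, readv re-allocates it and reads the data back. A short usage sketch, consistent with the tests added below (directory '.' and the 'uring' backend are illustrative choices):

import torch
from colo_nvme import DiskOffloader

of = DiskOffloader('.', backend='uring')
xs = [torch.rand(2, 2), torch.rand(2, 2, 2)]
copies = [t.clone() for t in xs]

of.sync_writev(xs)                               # offload both tensors in one request
assert all(t.storage().size() == 0 for t in xs)  # storage is freed after the write

of.sync_readv(xs)                                # same list, same order
assert all(torch.equal(t, c) for t, c in zip(xs, copies))

Because the key is derived from hash(tuple(tensors)), a readv must pass the same tensor objects in the same order as the matching writev; reordering them raises a RuntimeError, as the tests below verify.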
90 changes: 88 additions & 2 deletions csrc/offload.cpp
@@ -13,6 +13,17 @@
#include "aio.h"
#include "space_mgr.h"

iovec *tensors_to_iovec(const std::vector<at::Tensor> &tensors)
{
iovec *iovs = static_cast<iovec *>(calloc(tensors.size(), sizeof(iovec)));
for (size_t i = 0; i < tensors.size(); i++)
{
iovs[i].iov_base = tensors[i].data_ptr();
iovs[i].iov_len = tensors[i].storage().nbytes();
}
return iovs;
}

class Offloader
{
public:
@@ -34,7 +45,7 @@ class Offloader
throw std::runtime_error("Tensor must be contiguous and on cpu");
ull bytes = tensor.storage().nbytes();
ull offset = this->space_mgr.alloc(bytes);
SpaceInfo space_info = SpaceInfo(offset, bytes);
SpaceInfo space_info(offset, bytes);
this->tensors_info[key] = space_info;
return space_info;
}
@@ -109,6 +120,77 @@ class Offloader
printf("Remove \"%s\" error(%d): %s\n", this->filename.c_str(), errno, strerror(errno));
}

SpaceInfo prepare_writev(const std::vector<at::Tensor> &tensors, const std::string &key)
{
ull total_bytes = 0;
for (const at::Tensor &tensor : tensors)
{
if (!tensor.is_contiguous() || !tensor.is_cpu())
throw std::runtime_error("Tensor must be contiguous and on cpu");
total_bytes += tensor.storage().nbytes();
}
ull offset = this->space_mgr.alloc(total_bytes);
SpaceInfo space_info(offset, total_bytes);
this->tensors_info[key] = space_info;
return space_info;
}

SpaceInfo prepare_readv(const std::vector<at::Tensor> &tensors, const std::string &key)
{
ull total_bytes = 0;
for (const at::Tensor &tensor : tensors)
{
if (!tensor.is_contiguous() || !tensor.is_cpu())
throw std::runtime_error("Tensor must be contiguous and on cpu");
total_bytes += tensor.storage().nbytes();
}
if (this->tensors_info.find(key) == this->tensors_info.end())
throw std::runtime_error("Read error, tensor not found");
SpaceInfo space_info = this->tensors_info[key];
if (total_bytes != space_info.second)
throw std::runtime_error("Read error, tensor shape mismatch");
this->tensors_info.erase(key);
return space_info;
}

void async_writev(const std::vector<at::Tensor> &tensors, const std::string &key, callback_t callback = nullptr)
{
ull offset, bytes;
std::tie(offset, bytes) = prepare_writev(tensors, key);
iovec *iov = tensors_to_iovec(tensors);
this->aio->writev(this->fd, iov, tensors.size(), offset, callback);
}

void async_readv(const std::vector<at::Tensor> &tensors, const std::string &key, callback_t callback = nullptr)
{

ull offset, bytes;
std::tie(offset, bytes) = prepare_readv(tensors, key);
iovec *iov = tensors_to_iovec(tensors);
auto fn = std::bind(&Offloader::release, this, offset, bytes, callback);
this->aio->readv(this->fd, iov, tensors.size(), offset, fn);
}

void sync_writev(const std::vector<at::Tensor> &tensors, const std::string &key)
{
ull offset, bytes;
std::tie(offset, bytes) = prepare_writev(tensors, key);
iovec *iov = tensors_to_iovec(tensors);
lseek(this->fd, offset, SEEK_SET);
writev(this->fd, iov, tensors.size());
free(iov); // allocated with calloc in tensors_to_iovec, so release with free
}

void sync_readv(const std::vector<at::Tensor> &tensors, const std::string &key)
{
ull offset, bytes;
std::tie(offset, bytes) = prepare_readv(tensors, key);
iovec *iov = tensors_to_iovec(tensors);
lseek(this->fd, offset, SEEK_SET);
readv(this->fd, iov, tensors.size());
free(iov);
}

private:
const std::string filename;
int fd;
@@ -134,5 +216,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
.def("sync_read", &Offloader::sync_read, py::arg("tensor"), py::arg("key"))
.def("sync_write_events", &Offloader::sync_write_events)
.def("sync_read_events", &Offloader::sync_write_events)
.def("synchronize", &Offloader::synchronize);
.def("synchronize", &Offloader::synchronize)
.def("async_writev", &Offloader::async_writev, py::arg("tensors"), py::arg("key"), py::arg("callback") = py::none())
.def("async_readv", &Offloader::async_readv, py::arg("tensors"), py::arg("key"), py::arg("callback") = py::none())
.def("sync_writev", &Offloader::sync_writev, py::arg("tensors"), py::arg("key"))
.def("sync_readv", &Offloader::sync_readv, py::arg("tensors"), py::arg("key"));
}
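
The vectored paths reuse the key-based bookkeeping of the scalar paths: prepare_writev sums the byte sizes of all tensors, allocates one contiguous span from the space manager, and records (offset, total_bytes) under the key; prepare_readv looks the key up, checks that the byte total matches, and consumes the entry. Below is a schematic pure-Python model of that bookkeeping, for illustration only; the real allocator is the C++ SpaceManager, which is not part of this diff.

import torch

class VecSpaceBookkeeping:
    """Schematic model of prepare_writev/prepare_readv above, not the real allocator."""

    def __init__(self):
        self.tensors_info = {}   # key -> (offset, total_bytes)
        self._cursor = 0         # stand-in for SpaceManager.alloc()

    def prepare_writev(self, tensors, key):
        total = sum(t.numel() * t.element_size() for t in tensors)
        offset = self._cursor    # the C++ code calls this->space_mgr.alloc(total)
        self._cursor += total
        self.tensors_info[key] = (offset, total)
        return offset, total

    def prepare_readv(self, tensors, key):
        total = sum(t.numel() * t.element_size() for t in tensors)
        if key not in self.tensors_info:
            raise RuntimeError('Read error, tensor not found')
        offset, stored = self.tensors_info.pop(key)   # the entry is consumed by the read
        if total != stored:
            raise RuntimeError('Read error, tensor shape mismatch')
        return offset, stored

bk = VecSpaceBookkeeping()
ts = [torch.rand(2, 2), torch.rand(3)]
span = bk.prepare_writev(ts, 'k')
assert bk.prepare_readv(ts, 'k') == span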
109 changes: 109 additions & 0 deletions tests/test_disk_offloader.py
@@ -0,0 +1,109 @@
import torch
import pytest
from colo_nvme import DiskOffloader


@pytest.mark.parametrize('backend', ['uring', 'aio'])
def test_sync_io(backend):
x = torch.rand(2, 2)
x_copy = x.clone()
of = DiskOffloader('.', backend=backend)
try:
of.sync_read(x)
assert False
except RuntimeError:
pass
of.sync_write(x)
assert x.storage().size() == 0
of.sync_read(x)
assert torch.equal(x, x_copy)


@pytest.mark.parametrize('backend', ['uring', 'aio'])
def test_async_io(backend):
x = torch.rand(2, 2)
x_copy = x.clone()
of = DiskOffloader('.', backend=backend)
try:
of.async_read(x)
assert False
except RuntimeError:
pass
of.async_write(x)
assert x.storage().size() > 0
of.sync_write_events()
assert x.storage().size() == 0
of.sync_read(x)
of.sync_read_events()
assert torch.equal(x, x_copy)


@pytest.mark.parametrize('backend', ['uring', 'aio'])
def test_sync_vec_io(backend):
x = torch.rand(2, 2)
y = torch.rand(2, 2, 2)
x_copy = x.clone()
y_copy = y.clone()
of = DiskOffloader('.', backend=backend)
try:
of.sync_readv([x, y])
assert False
except RuntimeError:
pass
of.sync_writev([x, y])
assert x.storage().size() == 0
assert y.storage().size() == 0
try:
of.sync_readv(x)
assert False
except RuntimeError:
pass
try:
of.sync_readv([y, x])
assert False
except RuntimeError:
pass
of.sync_readv([x, y])
assert torch.equal(x, x_copy)
assert torch.equal(y, y_copy)


@pytest.mark.parametrize('backend', ['uring', 'aio'])
def test_async_vec_io(backend):
x = torch.rand(2, 2)
y = torch.rand(2, 2, 2)
x_copy = x.clone()
y_copy = y.clone()
of = DiskOffloader('.', backend=backend)
try:
of.async_readv([x, y])
assert False
except RuntimeError:
pass
of.async_writev([x, y])
assert x.storage().size() > 0
assert y.storage().size() > 0
of.sync_write_events()
assert x.storage().size() == 0
assert y.storage().size() == 0
try:
of.async_readv(x)
assert False
except RuntimeError:
pass
try:
of.async_readv([y, x])
assert False
except RuntimeError:
pass
of.async_readv([x, y])
of.sync_read_events()
assert torch.equal(x, x_copy)
assert torch.equal(y, y_copy)


if __name__ == '__main__':
test_sync_io('uring')
test_async_io('uring')
test_sync_vec_io('uring')
test_async_vec_io('uring')
