Skip to content

Commit

Permalink
Use intrusive reference counting
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed May 20, 2024
1 parent b543d50 commit b659cad
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 45 deletions.
17 changes: 14 additions & 3 deletions src/mempool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@
#include <iostream>
#include "bitlog.hpp"

#ifndef PYGPU_PYCUDA
#include <nanobind/intrusive/ref.h>
#include <nanobind/intrusive/counter.h>


namespace nb = nanobind;
#endif


namespace PYGPU_PACKAGE
{
Expand All @@ -53,7 +61,7 @@ namespace PYGPU_PACKAGE
#ifdef PYGPU_PYCUDA
#define PYGPU_SHARED_PTR boost::shared_ptr
#else
#define PYGPU_SHARED_PTR std::shared_ptr
#define PYGPU_SHARED_PTR nb::ref
#endif

template <class T>
Expand Down Expand Up @@ -89,6 +97,9 @@ namespace PYGPU_PACKAGE

template<class Allocator>
class memory_pool : mp_noncopyable
#ifndef PYGPU_PYCUDA
, public nb::intrusive_base
#endif
{
public:
typedef typename Allocator::pointer_type pointer_type;
Expand All @@ -102,7 +113,7 @@ namespace PYGPU_PACKAGE
container_t m_container;
typedef typename container_t::value_type bin_pair_t;

std::shared_ptr<Allocator> m_allocator;
PYGPU_SHARED_PTR<Allocator> m_allocator;

// A held block is one that's been released by the application, but that
// we are keeping around to dish out again.
Expand All @@ -125,7 +136,7 @@ namespace PYGPU_PACKAGE
unsigned m_leading_bits_in_bin_id;

public:
memory_pool(std::shared_ptr<Allocator> alloc, unsigned leading_bits_in_bin_id=4)
memory_pool(PYGPU_SHARED_PTR<Allocator> alloc, unsigned leading_bits_in_bin_id=4)
: m_allocator(alloc),
m_held_blocks(0), m_active_blocks(0),
m_managed_bytes(0), m_active_bytes(0),
Expand Down
11 changes: 11 additions & 0 deletions src/wrap_cl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#define PY_ARRAY_UNIQUE_SYMBOL pyopencl_ARRAY_API

#include "wrap_cl.hpp"
#include <nanobind/intrusive/counter.inl>



Expand All @@ -49,6 +50,16 @@ static bool import_numpy_helper()

NB_MODULE(_cl, m)
{
py::intrusive_init(
[](PyObject *o) noexcept {
py::gil_scoped_acquire guard;
Py_INCREF(o);
},
[](PyObject *o) noexcept {
py::gil_scoped_acquire guard;
Py_DECREF(o);
});

if (!import_numpy_helper())
throw py::python_error();

Expand Down
23 changes: 11 additions & 12 deletions src/wrap_cl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1143,7 +1143,7 @@ namespace pyopencl

// {{{ context

class context : public noncopyable
class context : public noncopyable, public py::intrusive_base
{
private:
cl_context m_context;
Expand Down Expand Up @@ -1415,7 +1415,7 @@ namespace pyopencl

// {{{ command_queue

class command_queue
class command_queue: public py::intrusive_base
{
private:
cl_command_queue m_queue;
Expand Down Expand Up @@ -1625,13 +1625,12 @@ namespace pyopencl
}
}

std::unique_ptr<context> get_context() const
py::ref<context> get_context() const
{
cl_context param_value;
PYOPENCL_CALL_GUARDED(clGetCommandQueueInfo,
(data(), CL_QUEUE_CONTEXT, sizeof(param_value), &param_value, 0));
return std::unique_ptr<context>(
new context(param_value, /*retain*/ true));
return py::ref<context>(new context(param_value, /*retain*/ true));
}

#if PYOPENCL_CL_VERSION < 0x1010
Expand Down Expand Up @@ -3437,12 +3436,12 @@ namespace pyopencl
{
private:
bool m_valid;
std::shared_ptr<command_queue> m_queue;
py::ref<command_queue> m_queue;
memory_object m_mem;
void *m_ptr;

public:
memory_map(std::shared_ptr<command_queue> cq, memory_object const &mem, void *ptr)
memory_map(py::ref<command_queue> cq, memory_object const &mem, void *ptr)
: m_valid(true), m_queue(cq), m_mem(mem), m_ptr(ptr)
{
}
Expand Down Expand Up @@ -3479,7 +3478,7 @@ namespace pyopencl
#ifndef PYPY_VERSION
inline
py::object enqueue_map_buffer(
std::shared_ptr<command_queue> cq,
py::ref<command_queue> cq,
memory_object_holder &buf,
cl_map_flags flags,
size_t offset,
Expand Down Expand Up @@ -3563,7 +3562,7 @@ namespace pyopencl
#ifndef PYPY_VERSION
inline
py::object enqueue_map_image(
std::shared_ptr<command_queue> cq,
py::ref<command_queue> cq,
memory_object_holder &img,
cl_map_flags flags,
py::object py_origin,
Expand Down Expand Up @@ -3697,15 +3696,15 @@ namespace pyopencl
class svm_allocation : public svm_pointer
{
private:
std::shared_ptr<context> m_context;
py::ref<context> m_context;
void *m_allocation;
size_t m_size;
command_queue_ref m_queue;
// FIXME Should maybe also allow keeping a list of events so that we can
// wait for users to finish in the case of out-of-order queues.

public:
svm_allocation(std::shared_ptr<context> const &ctx, size_t size, cl_uint alignment,
svm_allocation(py::ref<context> const &ctx, size_t size, cl_uint alignment,
cl_svm_mem_flags flags, const command_queue *queue = nullptr)
: m_context(ctx), m_size(size)
{
Expand Down Expand Up @@ -3738,7 +3737,7 @@ namespace pyopencl
}
}

svm_allocation(std::shared_ptr<context> const &ctx, void *allocation, size_t size,
svm_allocation(py::ref<context> const &ctx, void *allocation, size_t size,
const cl_command_queue queue)
: m_context(ctx), m_allocation(allocation), m_size(size)
{
Expand Down
14 changes: 12 additions & 2 deletions src/wrap_cl_part_1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,13 @@ void pyopencl_expose_part_1(py::module_ &m)

{
typedef context cls;
py::class_<cls>(m, "Context", py::dynamic_attr(), py::is_weak_referenceable())
py::class_<cls>(
m, "Context",
py::dynamic_attr(),
py::is_weak_referenceable(),
py::intrusive_ptr<cls>(
[](cls *o, PyObject *po) noexcept { o->set_self_py(po); })
)
.def(
"__init__",
[](cls *self, py::object py_devices, py::object py_properties,
Expand Down Expand Up @@ -112,7 +118,11 @@ void pyopencl_expose_part_1(py::module_ &m)
// {{{ command queue
{
typedef command_queue cls;
py::class_<cls>(m, "CommandQueue", py::dynamic_attr())
py::class_<cls>(
m, "CommandQueue",
py::dynamic_attr(),
py::intrusive_ptr<cls>(
[](cls *o, PyObject *po) noexcept { o->set_self_py(po); }) )
.def(
py::init<const context &, const device *, py::object>(),
py::arg("context"),
Expand Down
2 changes: 1 addition & 1 deletion src/wrap_cl_part_2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ void pyopencl_expose_part_2(py::module_ &m)
{
typedef svm_allocation cls;
py::class_<cls, svm_pointer>(m, "SVMAllocation", py::dynamic_attr())
.def(py::init<std::shared_ptr<context>, size_t, cl_uint, cl_svm_mem_flags, const command_queue *>(),
.def(py::init<py::ref<context>, size_t, cl_uint, cl_svm_mem_flags, const command_queue *>(),
py::arg("context"),
py::arg("size"),
py::arg("alignment"),
Expand Down
3 changes: 2 additions & 1 deletion src/wrap_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@

#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/intrusive/counter.h>
#include <nanobind/intrusive/ref.h>
#include <nanobind/ndarray.h>


Expand Down
Loading

0 comments on commit b659cad

Please sign in to comment.