diff --git a/src/blob.c b/src/blob.c index 133d49423..de0141bf8 100644 --- a/src/blob.c +++ b/src/blob.c @@ -158,8 +158,58 @@ PyGetSetDef Blob_getseters[] = { {NULL} }; +static int +Blob_getbuffer(Blob *self, Py_buffer *view, int flags) +{ + return PyBuffer_FillInfo(view, (PyObject *) self, + (void *) git_blob_rawcontent(self->blob), + git_blob_rawsize(self->blob), 1, flags); +} + +#if PY_MAJOR_VERSION == 2 -PyDoc_STRVAR(Blob__doc__, "Blob objects."); +static Py_ssize_t +Blob_getreadbuffer(Blob *self, Py_ssize_t index, const void **ptr) +{ + if (index != 0) { + PyErr_SetString(PyExc_SystemError, + "accessing non-existent blob segment"); + return -1; + } + *ptr = (void *) git_blob_rawcontent(self->blob); + return git_blob_rawsize(self->blob); +} + +static Py_ssize_t +Blob_getsegcount(Blob *self, Py_ssize_t *lenp) +{ + if (lenp) + *lenp = git_blob_rawsize(self->blob); + + return 1; +} + +static PyBufferProcs Blob_as_buffer = { + (readbufferproc)Blob_getreadbuffer, + NULL, /* bf_getwritebuffer */ + (segcountproc)Blob_getsegcount, + NULL, /* charbufferproc */ + (getbufferproc)Blob_getbuffer, +}; + +#else + +static PyBufferProcs Blob_as_buffer = { + (getbufferproc)Blob_getbuffer, +}; + +#endif /* python 2 vs python 3 buffers */ + +PyDoc_STRVAR(Blob__doc__, "Blob object.\n" + "\n" + "Blobs implement the buffer interface, which means you can get access\n" + "to its data via `memoryview(blob)` without the need to create a copy." +); PyTypeObject BlobType = { PyVarObject_HEAD_INIT(NULL, 0) @@ -180,8 +230,14 @@ PyTypeObject BlobType = { 0, /* tp_str */ 0, /* tp_getattro */ 0, /* tp_setattro */ - 0, /* tp_as_buffer */ + &Blob_as_buffer, /* tp_as_buffer */ +#if PY_MAJOR_VERSION == 2 + Py_TPFLAGS_DEFAULT | /* tp_flags */ + Py_TPFLAGS_HAVE_GETCHARBUFFER | + Py_TPFLAGS_HAVE_NEWBUFFER, +#else Py_TPFLAGS_DEFAULT, /* tp_flags */ +#endif Blob__doc__, /* tp_doc */ 0, /* tp_traverse */ 0, /* tp_clear */ diff --git a/test/test_blob.py b/test/test_blob.py index da3db3e8f..9fd1d4f9e 100644 --- a/test/test_blob.py +++ b/test/test_blob.py @@ -74,6 +74,13 @@ def test_create_blob(self): self.assertEqual(BLOB_NEW_CONTENT, blob.data) self.assertEqual(len(BLOB_NEW_CONTENT), blob.size) self.assertEqual(BLOB_NEW_CONTENT, blob.read_raw()) + blob_buffer = memoryview(blob) + self.assertEqual(len(BLOB_NEW_CONTENT), len(blob_buffer)) + self.assertEqual(BLOB_NEW_CONTENT, blob_buffer) + def set_content(): + blob_buffer[:2] = b'hi' + + self.assertRaises(TypeError, set_content) def test_create_blob_fromworkdir(self):