Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,12 @@ jobs:
working-directory: unytdtype
run: |
pytest -vvv --color=yes
- name: Install quaddtype
working-directory: quaddtype
run: |
python -m build --no-isolation --wheel -Cbuilddir=build
find ./dist/*.whl | xargs python -m pip install
- name: Run quaddtype tests
working-directory: quaddtype
run: |
pytest -vvv --color=yes
4 changes: 2 additions & 2 deletions quaddtype/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ requires = [
"meson-python",
"patchelf",
"wheel",
"numpy @ file:///home/pdmurray/Desktop/numpy-1.25.0.dev0+167.g4ca8204c5-cp311-cp311d-linux_x86_64.whl"
"numpy"
]
build-backend = "mesonpy"

Expand All @@ -16,7 +16,7 @@ readme = 'README.md'
author = "Peyton Murray"
requires-python = ">=3.9.0"
dependencies = [
"numpy @ file:///home/pdmurray/Desktop/numpy-1.25.0.dev0+167.g4ca8204c5-cp311-cp311d-linux_x86_64.whl"
"numpy"
]

[project.optional-dependencies]
Expand Down
78 changes: 12 additions & 66 deletions quaddtype/quaddtype/src/casts.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@ quad_to_quad_resolve_descriptors(PyObject *NPY_UNUSED(self),
PyArray_Descr *given_descrs[2], PyArray_Descr *loop_descrs[2],
npy_intp *view_offset)
{
Py_INCREF(given_descrs[0]);
loop_descrs[0] = given_descrs[0];

if (given_descrs[1] == NULL) {
Py_INCREF(given_descrs[0]);
loop_descrs[1] = given_descrs[0];
}
else {
Py_INCREF(given_descrs[1]);
loop_descrs[1] = given_descrs[1];
}

return NPY_SAME_KIND_CASTING;
}

Expand Down Expand Up @@ -140,69 +152,3 @@ PyArrayMethod_Spec QuadToQuadCastSpec = {
.dtypes = QuadToQuadDtypes,
.slots = QuadToQuadSlots,
};

// Quad to Float128
static NPY_CASTING
quad_to_float128_resolve_descriptors(PyObject *NPY_UNUSED(self),
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
PyArray_Descr *given_descrs[2], PyArray_Descr *loop_descrs[2],
npy_intp *view_offset)
{
return NPY_SAME_KIND_CASTING;
}

static int
quad_to_float128_contiguous(PyArrayMethod_Context *NPY_UNUSED(context), char *const data[],
npy_intp const dimensions[], npy_intp const strides[], void *auxdata)
{
return 0;
}

static int
quad_to_float128_strided(PyArrayMethod_Context *NPY_UNUSED(context), char *const data[],
npy_intp const dimensions[], npy_intp const strides[], void *auxdata)
{
return 0;
}

static int
quad_to_float128_unaligned(PyArrayMethod_Context *NPY_UNUSED(context), char *const data[],
npy_intp const dimensions[], npy_intp const strides[], void *auxdata)
{
return 0;
}

static int
quad_to_float128_get_loop(PyArrayMethod_Context *context, int aligned,
int NPY_UNUSED(move_references), const npy_intp *strides,
PyArrayMethod_StridedLoop **out_loop, NpyAuxData **out_transferdata,
NPY_ARRAYMETHOD_FLAGS *flags)
{
int contig = (strides[0] == sizeof(__float128) && strides[1] == sizeof(__float128));

if (aligned && contig)
*out_loop = (PyArrayMethod_StridedLoop *)&quad_to_float128_contiguous;
else if (aligned)
*out_loop = (PyArrayMethod_StridedLoop *)&quad_to_float128_strided;
else
*out_loop = (PyArrayMethod_StridedLoop *)&quad_to_float128_unaligned;

*flags = 0;
return 0;
}

static PyArray_DTypeMeta *QuadToFloat128Dtypes[2] = {NULL, NULL};
static PyType_Slot QuadToFloat128Slots[] = {
{NPY_METH_resolve_descriptors, &quad_to_float128_resolve_descriptors},
{_NPY_METH_get_loop, &quad_to_float128_get_loop},
{0, NULL}};

PyArrayMethod_Spec QuadToFloat128CastSpec = {
.name = "cast_QuadDType_to_Float128",
.nin = 1,
.nout = 1,
.flags = NPY_METH_SUPPORTS_UNALIGNED,
.casting = NPY_SAFE_CASTING,
.dtypes = QuadToFloat128Dtypes,
.slots = QuadToFloat128Slots,
};
13 changes: 10 additions & 3 deletions quaddtype/quaddtype/src/dtype.c
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "dtype.h"
#include "abstract.h"
#include "casts.h"

PyTypeObject *QuadScalar_Type = NULL;
Expand Down Expand Up @@ -37,9 +38,13 @@ quad_getitem(QuadDTypeObject *descr, char *dataptr)
return NULL;
}

Py_DECREF(val_obj); // Why decrement this pointer? Shouldn't this be
// Py_INCREF?
return val_obj;
// Need to create a new QuadScalar instance here and return that...
PyObject *res = PyObject_CallFunctionObjArgs((PyObject *)QuadScalar_Type, val_obj, NULL);
if (res == NULL) {
return NULL;
}
Py_DECREF(val_obj);
return res;
}

// For two instances of the same dtype, both have the same precision. Return
Expand Down Expand Up @@ -74,6 +79,8 @@ common_dtype(PyArray_DTypeMeta *self, PyArray_DTypeMeta *other)
return (PyArray_DTypeMeta *)Py_NotImplemented;
}

// Expected to have this, and that it does an incref; see NEP42
// Without this you'll get weird memory corruption bugs in the casting code
static QuadDTypeObject *
quaddtype_ensure_canonical(QuadDTypeObject *self)
{
Expand Down
8 changes: 5 additions & 3 deletions quaddtype/quaddtype/src/umath.c
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,11 @@ quad_multiply_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *dtypes[],
Py_INCREF(given_descrs[1]);
loop_descrs[1] = given_descrs[1];

Py_INCREF(given_descrs[1]);
loop_descrs[1] = given_descrs[1];
return NPY_SAFE_CASTING;
// For now, we assume that the input dtypes are all simply quads. We therefore
// can just reuse the input dtype for the output dtype, given by loop_descrs[2]
Py_INCREF(given_descrs[0]);
loop_descrs[2] = given_descrs[0];
return NPY_NO_CASTING;
}

// Function that adds our multiply loop to NumPy's multiply ufunc.
Expand Down
3 changes: 3 additions & 0 deletions quaddtype/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import os

os.environ["NUMPY_EXPERIMENTAL_DTYPE_API"] = "1"
29 changes: 24 additions & 5 deletions quaddtype/tests/test_quaddtype.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,28 @@
import numpy as np

from quaddtype import QuadDType, QuadScalar

def test_instantiate():
"""Test that importing the dtype works."""
import quaddtype # noqa: F401

def test_dtype_creation():
assert str(QuadDType()) == "This is a quad (128-bit float) dtype."

if __name__ == '__main__':
test_instantiate()

def test_scalar_creation():
assert str(QuadScalar(3.1)) == "3.1"


def test_create_with_explicit_dtype():
assert repr(
np.array([3.0, 3.1, 3.2], dtype=QuadDType())
) == "array([3.0, 3.1, 3.2], dtype=This is a quad (128-bit float) dtype.)"


def test_multiply():
x = np.array([3, 8.0], dtype=QuadDType())
assert str(x * x) == '[9.0 64.0]'


def test_bytes():
"""Check that each quad is 16 bytes."""
x = np.array([3, 8.0, 1.4], dtype=QuadDType())
assert len(x.tobytes()) == x.size * 16