Skip to content

Commit f1b79c2

Browse files
committed
Fixed repr for quad dtype; added tests
1 parent 914afe8 commit f1b79c2

File tree

7 files changed

+65
-79
lines changed

7 files changed

+65
-79
lines changed

.github/workflows/ci.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,12 @@ jobs:
5555
working-directory: unytdtype
5656
run: |
5757
pytest -vvv --color=yes
58+
- name: Install quaddtype
59+
working-directory: quaddtype
60+
run: |
61+
python -m build --no-isolation --wheel -Cbuilddir=build
62+
find ./dist/*.whl | xargs python -m pip install
63+
- name: Run quaddtype tests
64+
working-directory: quaddtype
65+
run: |
66+
pytest -vvv --color=yes

quaddtype/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ requires = [
44
"meson-python",
55
"patchelf",
66
"wheel",
7-
"numpy @ file:///home/pdmurray/Desktop/numpy-1.25.0.dev0+167.g4ca8204c5-cp311-cp311d-linux_x86_64.whl"
7+
"numpy"
88
]
99
build-backend = "mesonpy"
1010

@@ -16,7 +16,7 @@ readme = 'README.md'
1616
author = "Peyton Murray"
1717
requires-python = ">=3.9.0"
1818
dependencies = [
19-
"numpy @ file:///home/pdmurray/Desktop/numpy-1.25.0.dev0+167.g4ca8204c5-cp311-cp311d-linux_x86_64.whl"
19+
"numpy"
2020
]
2121

2222
[project.optional-dependencies]

quaddtype/quaddtype/src/casts.c

Lines changed: 12 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,18 @@ quad_to_quad_resolve_descriptors(PyObject *NPY_UNUSED(self),
2626
PyArray_Descr *given_descrs[2], PyArray_Descr *loop_descrs[2],
2727
npy_intp *view_offset)
2828
{
29+
Py_INCREF(given_descrs[0]);
30+
loop_descrs[0] = given_descrs[0];
31+
32+
if (given_descrs[1] == NULL) {
33+
Py_INCREF(given_descrs[0]);
34+
loop_descrs[1] = given_descrs[0];
35+
}
36+
else {
37+
Py_INCREF(given_descrs[1]);
38+
loop_descrs[1] = given_descrs[1];
39+
}
40+
2941
return NPY_SAME_KIND_CASTING;
3042
}
3143

@@ -140,69 +152,3 @@ PyArrayMethod_Spec QuadToQuadCastSpec = {
140152
.dtypes = QuadToQuadDtypes,
141153
.slots = QuadToQuadSlots,
142154
};
143-
144-
// Quad to Float128
145-
static NPY_CASTING
146-
quad_to_float128_resolve_descriptors(PyObject *NPY_UNUSED(self),
147-
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
148-
PyArray_Descr *given_descrs[2], PyArray_Descr *loop_descrs[2],
149-
npy_intp *view_offset)
150-
{
151-
return NPY_SAME_KIND_CASTING;
152-
}
153-
154-
static int
155-
quad_to_float128_contiguous(PyArrayMethod_Context *NPY_UNUSED(context), char *const data[],
156-
npy_intp const dimensions[], npy_intp const strides[], void *auxdata)
157-
{
158-
return 0;
159-
}
160-
161-
static int
162-
quad_to_float128_strided(PyArrayMethod_Context *NPY_UNUSED(context), char *const data[],
163-
npy_intp const dimensions[], npy_intp const strides[], void *auxdata)
164-
{
165-
return 0;
166-
}
167-
168-
static int
169-
quad_to_float128_unaligned(PyArrayMethod_Context *NPY_UNUSED(context), char *const data[],
170-
npy_intp const dimensions[], npy_intp const strides[], void *auxdata)
171-
{
172-
return 0;
173-
}
174-
175-
static int
176-
quad_to_float128_get_loop(PyArrayMethod_Context *context, int aligned,
177-
int NPY_UNUSED(move_references), const npy_intp *strides,
178-
PyArrayMethod_StridedLoop **out_loop, NpyAuxData **out_transferdata,
179-
NPY_ARRAYMETHOD_FLAGS *flags)
180-
{
181-
int contig = (strides[0] == sizeof(__float128) && strides[1] == sizeof(__float128));
182-
183-
if (aligned && contig)
184-
*out_loop = (PyArrayMethod_StridedLoop *)&quad_to_float128_contiguous;
185-
else if (aligned)
186-
*out_loop = (PyArrayMethod_StridedLoop *)&quad_to_float128_strided;
187-
else
188-
*out_loop = (PyArrayMethod_StridedLoop *)&quad_to_float128_unaligned;
189-
190-
*flags = 0;
191-
return 0;
192-
}
193-
194-
static PyArray_DTypeMeta *QuadToFloat128Dtypes[2] = {NULL, NULL};
195-
static PyType_Slot QuadToFloat128Slots[] = {
196-
{NPY_METH_resolve_descriptors, &quad_to_float128_resolve_descriptors},
197-
{_NPY_METH_get_loop, &quad_to_float128_get_loop},
198-
{0, NULL}};
199-
200-
PyArrayMethod_Spec QuadToFloat128CastSpec = {
201-
.name = "cast_QuadDType_to_Float128",
202-
.nin = 1,
203-
.nout = 1,
204-
.flags = NPY_METH_SUPPORTS_UNALIGNED,
205-
.casting = NPY_SAFE_CASTING,
206-
.dtypes = QuadToFloat128Dtypes,
207-
.slots = QuadToFloat128Slots,
208-
};

quaddtype/quaddtype/src/dtype.c

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "dtype.h"
2+
#include "abstract.h"
23
#include "casts.h"
34

45
PyTypeObject *QuadScalar_Type = NULL;
@@ -37,9 +38,13 @@ quad_getitem(QuadDTypeObject *descr, char *dataptr)
3738
return NULL;
3839
}
3940

40-
Py_DECREF(val_obj); // Why decrement this pointer? Shouldn't this be
41-
// Py_INCREF?
42-
return val_obj;
41+
// Need to create a new QuadScalar instance here and return that...
42+
PyObject *res = PyObject_CallFunctionObjArgs((PyObject *)QuadScalar_Type, val_obj, NULL);
43+
if (res == NULL) {
44+
return NULL;
45+
}
46+
Py_DECREF(val_obj);
47+
return res;
4348
}
4449

4550
// For two instances of the same dtype, both have the same precision. Return
@@ -74,6 +79,8 @@ common_dtype(PyArray_DTypeMeta *self, PyArray_DTypeMeta *other)
7479
return (PyArray_DTypeMeta *)Py_NotImplemented;
7580
}
7681

82+
// Expected to have this, and that it does an incref; see NEP42
83+
// Without this you'll get weird memory corruption bugs in the casting code
7784
static QuadDTypeObject *
7885
quaddtype_ensure_canonical(QuadDTypeObject *self)
7986
{

quaddtype/quaddtype/src/umath.c

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,11 @@ quad_multiply_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *dtypes[],
6565
Py_INCREF(given_descrs[1]);
6666
loop_descrs[1] = given_descrs[1];
6767

68-
Py_INCREF(given_descrs[1]);
69-
loop_descrs[1] = given_descrs[1];
70-
return NPY_SAFE_CASTING;
68+
// For now, we assume that the input dtypes are all simply quads. We therefore
69+
// can just reuse the input dtype for the output dtype, given by loop_descrs[2]
70+
Py_INCREF(given_descrs[0]);
71+
loop_descrs[2] = given_descrs[0];
72+
return NPY_NO_CASTING;
7173
}
7274

7375
// Function that adds our multiply loop to NumPy's multiply ufunc.

quaddtype/tests/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import os
2+
3+
os.environ["NUMPY_EXPERIMENTAL_DTYPE_API"] = "1"

quaddtype/tests/test_quaddtype.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,28 @@
1+
import numpy as np
12

3+
from quaddtype import QuadDType, QuadScalar
24

3-
def test_instantiate():
4-
"""Test that importing the dtype works."""
5-
import quaddtype # noqa: F401
65

6+
def test_dtype_creation():
7+
assert str(QuadDType()) == "This is a quad (128-bit float) dtype."
78

8-
if __name__ == '__main__':
9-
test_instantiate()
9+
10+
def test_scalar_creation():
11+
assert str(QuadScalar(3.1)) == "3.1"
12+
13+
14+
def test_create_with_explicit_dtype():
15+
assert repr(
16+
np.array([3.0, 3.1, 3.2], dtype=QuadDType())
17+
) == "array([3.0, 3.1, 3.2], dtype=This is a quad (128-bit float) dtype.)"
18+
19+
20+
def test_multiply():
21+
x = np.array([3, 8.0], dtype=QuadDType())
22+
assert str(x * x) == '[9.0 64.0]'
23+
24+
25+
def test_bytes():
26+
"""Check that each quad is 16 bytes."""
27+
x = np.array([3, 8.0, 1.4], dtype=QuadDType())
28+
assert len(x.tobytes()) == x.size * 16

0 commit comments

Comments
 (0)