Skip to content

Commit

Permalink
Add support for Python 3.12's type aliases
Browse files Browse the repository at this point in the history
This adds support for the new syntactic type aliases added in Python
3.12. A few examples:

```
type NullableStr = str | None

type Pair[T] = tuple[T, T]

type NullableStrPair = Pair[NullableStr]
```

msgspec now supports these type aliases, *except* in cases where the
type alias is recursive. For example, the following type isn't
supported:

```
type Link[T] = tuple[T, Link[T] | None]
```

The internal data structure we use to store type information was not
designed to handle recursive types like these; supporting them will
require a larger refactor.
  • Loading branch information
jcrist committed Dec 8, 2023
1 parent ba316c4 commit f8d2c1a
Show file tree
Hide file tree
Showing 6 changed files with 258 additions and 30 deletions.
32 changes: 32 additions & 0 deletions docs/source/supported-types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ Most combinations of the following types are supported (with a few restrictions)
- `typing.Literal`
- `typing.NewType`
- `typing.Final`
- `typing.TypeAliasType`
- `typing.TypeAlias`
- `typing.NamedTuple` / `collections.namedtuple`
- `typing.TypedDict`
- `typing.Generic`
Expand Down Expand Up @@ -1170,6 +1172,36 @@ support here is purely to aid static analysis tools like mypy_ or pyright_.
File "<stdin>", line 1, in <module>
msgspec.ValidationError: Expected `int`, got `str`
Type Aliases
------------

For complex types, sometimes it can be nice to write the type once so you can
reuse it later.

.. code-block:: python

    Point = tuple[float, float]

Here ``Point`` is a "type alias" for ``tuple[float, float]`` - ``msgspec``
will substitute in ``tuple[float, float]`` whenever the ``Point`` type
is used in an annotation.

``msgspec`` supports the following equivalent forms:

.. code-block:: python

    # Using variable assignment
    Point = tuple[float, float]

    # Using variable assignment, annotated as a `TypeAlias`
    Point: TypeAlias = tuple[float, float]

    # Using Python 3.12's new `type` statement. This only works on Python 3.12+
    type Point = tuple[float, float]

To learn more about Type Aliases, see Python's `Type Alias docs here
<https://docs.python.org/3/library/typing.html#type-aliases>`__.

Generic Types
-------------

Expand Down
107 changes: 83 additions & 24 deletions msgspec/_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
#include "ryu.h"
#include "atof.h"

/* Python version checks */
#define PY39_PLUS (PY_VERSION_HEX >= 0x03090000)
#define PY310_PLUS (PY_VERSION_HEX >= 0x030a0000)
#define PY311_PLUS (PY_VERSION_HEX >= 0x030b0000)
#define PY312_PLUS (PY_VERSION_HEX >= 0x030c0000)

/* Hint to the compiler not to store `x` in a register since it is likely to
* change. Results in much higher performance on GCC, with smaller benefits on
* clang */
Expand All @@ -36,18 +42,18 @@ ms_popcount(uint64_t i) { \
}
#endif

#if PY_VERSION_HEX < 0x03090000
#define CALL_ONE_ARG(f, a) PyObject_CallFunctionObjArgs(f, a, NULL)
#define CALL_NO_ARGS(f) PyObject_CallFunctionObjArgs(f, NULL)
#define CALL_METHOD_ONE_ARG(o, n, a) PyObject_CallMethodObjArgs(o, n, a, NULL)
#define CALL_METHOD_NO_ARGS(o, n) PyObject_CallMethodObjArgs(o, n, NULL)
#define SET_SIZE(obj, size) (((PyVarObject *)obj)->ob_size = size)
#else
#if PY39_PLUS
#define CALL_ONE_ARG(f, a) PyObject_CallOneArg(f, a)
#define CALL_NO_ARGS(f) PyObject_CallNoArgs(f)
#define CALL_METHOD_ONE_ARG(o, n, a) PyObject_CallMethodOneArg(o, n, a)
#define CALL_METHOD_NO_ARGS(o, n) PyObject_CallMethodNoArgs(o, n)
#define SET_SIZE(obj, size) Py_SET_SIZE(obj, size)
#else
#define CALL_ONE_ARG(f, a) PyObject_CallFunctionObjArgs(f, a, NULL)
#define CALL_NO_ARGS(f) PyObject_CallFunctionObjArgs(f, NULL)
#define CALL_METHOD_ONE_ARG(o, n, a) PyObject_CallMethodObjArgs(o, n, a, NULL)
#define CALL_METHOD_NO_ARGS(o, n) PyObject_CallMethodObjArgs(o, n, NULL)
#define SET_SIZE(obj, size) (((PyVarObject *)obj)->ob_size = size)
#endif

#define DIV_ROUND_CLOSEST(n, d) ((((n) < 0) == ((d) < 0)) ? (((n) + (d)/2)/(d)) : (((n) - (d)/2)/(d)))
Expand Down Expand Up @@ -157,7 +163,7 @@ fast_long_extract_parts(PyObject *vv, bool *neg, uint64_t *scale) {
uint64_t prev, x = 0;
bool negative;

#if PY_VERSION_HEX >= 0x030c0000
#if PY312_PLUS
/* CPython 3.12 changed int internal representation */
int sign = 1 - (v->long_value.lv_tag & _PyLong_SIGN_MASK);
negative = sign == -1;
Expand Down Expand Up @@ -405,6 +411,9 @@ typedef struct {
PyObject *str___dataclass_fields__;
PyObject *str___attrs_attrs__;
PyObject *str___supertype__;
#if PY312_PLUS
PyObject *str___value__;
#endif
PyObject *str___bound__;
PyObject *str___constraints__;
PyObject *str_int;
Expand All @@ -427,8 +436,11 @@ typedef struct {
PyObject *get_typeddict_info;
PyObject *get_dataclass_info;
PyObject *rebuild;
#if PY_VERSION_HEX >= 0x030a00f0
#if PY310_PLUS
PyObject *types_uniontype;
#endif
#if PY312_PLUS
PyObject *typing_typealiastype;
#endif
PyObject *astimezone;
PyObject *re_compile;
Expand Down Expand Up @@ -2122,7 +2134,7 @@ PyTypeObject NoDefault_Type = {
.tp_basicsize = 0
};

#if PY_VERSION_HEX >= 0x030c0000
#if PY312_PLUS
PyObject _NoDefault_Object = {
_PyObject_EXTRA_INIT
{ _Py_IMMORTAL_REFCNT },
Expand Down Expand Up @@ -2226,7 +2238,7 @@ PyTypeObject Unset_Type = {
.tp_basicsize = 0
};

#if PY_VERSION_HEX >= 0x030c0000
#if PY312_PLUS
PyObject _Unset_Object = {
_PyObject_EXTRA_INIT
{ _Py_IMMORTAL_REFCNT },
Expand Down Expand Up @@ -4459,6 +4471,21 @@ typenode_origin_args_metadata(
t = temp;
continue;
}
/* Check for parametrized TypeAliasType if Python 3.12+ */
#if PY312_PLUS
if (Py_TYPE(origin) == (PyTypeObject *)(state->mod->typing_typealiastype)) {
PyObject *value = PyObject_GetAttr(origin, state->mod->str___value__);
if (value == NULL) goto error;
PyObject *temp = PyObject_GetItem(value, args);
Py_DECREF(value);
if (temp == NULL) goto error;
Py_CLEAR(args);
Py_CLEAR(origin);
Py_DECREF(t);
t = temp;
continue;
}
#endif
}
else {
/* Custom non-parametrized generics won't have __args__
Expand Down Expand Up @@ -4487,14 +4514,23 @@ typenode_origin_args_metadata(
t = supertype;
continue;
}
else {
PyErr_Clear();
break;
PyErr_Clear();

/* Check for TypeAliasType if Python 3.12+ */
#if PY312_PLUS
if (Py_TYPE(t) == (PyTypeObject *)(state->mod->typing_typealiastype)) {
PyObject *value = PyObject_GetAttr(t, state->mod->str___value__);
if (value == NULL) goto error;
Py_DECREF(t);
t = value;
continue;
}
#endif
break;
}
}

#if PY_VERSION_HEX >= 0x030a00f0
#if PY310_PLUS
if (Py_TYPE(t) == (PyTypeObject *)(state->mod->types_uniontype)) {
/* Handle types.UnionType unions (`int | float | ...`) */
args = PyObject_GetAttr(t, state->mod->str___args__);
Expand Down Expand Up @@ -4692,13 +4728,18 @@ typenode_collect_type(TypeNodeCollectState *state, PyObject *obj) {
}
}
else if (origin == state->mod->typing_union) {
if (Py_EnterRecursiveCall(" while analyzing a type")) {
out = -1;
goto done;
}
for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(args); i++) {
PyObject *arg = PyTuple_GET_ITEM(args, i);
/* Ignore UnsetType in unions */
if (arg == (PyObject *)(&Unset_Type)) continue;
out = typenode_collect_type(state, arg);
if (out < 0) break;
}
Py_LeaveRecursiveCall();
}
else if (origin == state->mod->typing_literal) {
if (state->literals == NULL) {
Expand Down Expand Up @@ -4761,6 +4802,8 @@ TypeNode_Convert(PyObject *obj) {
state.mod = msgspec_get_global_state();
state.context = obj;

if (Py_EnterRecursiveCall(" while analyzing a type")) return NULL;

/* Traverse `obj` to collect all type annotations at this level */
if (typenode_collect_type(&state, obj) < 0) goto done;
/* Handle structs in a second pass */
Expand All @@ -4773,6 +4816,7 @@ TypeNode_Convert(PyObject *obj) {
out = typenode_from_collect_state(&state);
done:
typenode_collect_clear_state(&state);
Py_LeaveRecursiveCall();
return out;
}

Expand Down Expand Up @@ -9717,14 +9761,14 @@ ms_encode_err_type_unsupported(PyTypeObject *type) {
*************************************************************************/

#define MS_HAS_TZINFO(o) (((_PyDateTime_BaseTZInfo *)(o))->hastzinfo)
#if PY_VERSION_HEX < 0x030a00f0
#if PY310_PLUS
#define MS_DATE_GET_TZINFO(o) PyDateTime_DATE_GET_TZINFO(o)
#define MS_TIME_GET_TZINFO(o) PyDateTime_TIME_GET_TZINFO(o)
#else
#define MS_DATE_GET_TZINFO(o) (MS_HAS_TZINFO(o) ? \
((PyDateTime_DateTime *)(o))->tzinfo : Py_None)
#define MS_TIME_GET_TZINFO(o) (MS_HAS_TZINFO(o) ? \
((PyDateTime_Time *)(o))->tzinfo : Py_None)
#else
#define MS_DATE_GET_TZINFO(o) PyDateTime_DATE_GET_TZINFO(o)
#define MS_TIME_GET_TZINFO(o) PyDateTime_TIME_GET_TZINFO(o)
#endif

#ifndef TIMEZONE_CACHE_SIZE
Expand Down Expand Up @@ -15472,7 +15516,7 @@ static struct PyMethodDef Decoder_methods[] = {
"decode", (PyCFunction) Decoder_decode, METH_FASTCALL,
Decoder_decode__doc__,
},
#if PY_VERSION_HEX >= 0x03090000
#if PY39_PLUS
{"__class_getitem__", Py_GenericAlias, METH_O|METH_CLASS},
#endif
{NULL, NULL} /* sentinel */
Expand Down Expand Up @@ -18512,7 +18556,7 @@ static struct PyMethodDef JSONDecoder_methods[] = {
"decode_lines", (PyCFunction) JSONDecoder_decode_lines, METH_FASTCALL,
JSONDecoder_decode_lines__doc__,
},
#if PY_VERSION_HEX >= 0x03090000
#if PY39_PLUS
{"__class_getitem__", Py_GenericAlias, METH_O|METH_CLASS},
#endif
{NULL, NULL} /* sentinel */
Expand Down Expand Up @@ -21029,6 +21073,9 @@ msgspec_clear(PyObject *m)
Py_CLEAR(st->str___dataclass_fields__);
Py_CLEAR(st->str___attrs_attrs__);
Py_CLEAR(st->str___supertype__);
#if PY312_PLUS
Py_CLEAR(st->str___value__);
#endif
Py_CLEAR(st->str___bound__);
Py_CLEAR(st->str___constraints__);
Py_CLEAR(st->str_int);
Expand All @@ -21051,8 +21098,11 @@ msgspec_clear(PyObject *m)
Py_CLEAR(st->get_typeddict_info);
Py_CLEAR(st->get_dataclass_info);
Py_CLEAR(st->rebuild);
#if PY_VERSION_HEX >= 0x030a00f0
#if PY310_PLUS
Py_CLEAR(st->types_uniontype);
#endif
#if PY312_PLUS
Py_CLEAR(st->typing_typealiastype);
#endif
Py_CLEAR(st->astimezone);
Py_CLEAR(st->re_compile);
Expand Down Expand Up @@ -21118,8 +21168,11 @@ msgspec_traverse(PyObject *m, visitproc visit, void *arg)
Py_VISIT(st->get_typeddict_info);
Py_VISIT(st->get_dataclass_info);
Py_VISIT(st->rebuild);
#if PY_VERSION_HEX >= 0x030a00f0
#if PY310_PLUS
Py_VISIT(st->types_uniontype);
#endif
#if PY312_PLUS
Py_VISIT(st->typing_typealiastype);
#endif
Py_VISIT(st->astimezone);
Py_VISIT(st->re_compile);
Expand Down Expand Up @@ -21315,6 +21368,9 @@ PyInit__core(void)
SET_REF(typing_final, "Final");
SET_REF(typing_generic, "Generic");
SET_REF(typing_generic_alias, "_GenericAlias");
#if PY312_PLUS
SET_REF(typing_typealiastype, "TypeAliasType");
#endif
Py_DECREF(temp_module);

temp_module = PyImport_ImportModule("msgspec._utils");
Expand All @@ -21328,7 +21384,7 @@ PyInit__core(void)
SET_REF(rebuild, "rebuild");
Py_DECREF(temp_module);

#if PY_VERSION_HEX >= 0x030a00f0
#if PY310_PLUS
temp_module = PyImport_ImportModule("types");
if (temp_module == NULL) return NULL;
SET_REF(types_uniontype, "UnionType");
Expand Down Expand Up @@ -21411,6 +21467,9 @@ PyInit__core(void)
CACHED_STRING(str___dataclass_fields__, "__dataclass_fields__");
CACHED_STRING(str___attrs_attrs__, "__attrs_attrs__");
CACHED_STRING(str___supertype__, "__supertype__");
#if PY312_PLUS
CACHED_STRING(str___value__, "__value__");
#endif
CACHED_STRING(str___bound__, "__bound__");
CACHED_STRING(str___constraints__, "__constraints__");
CACHED_STRING(str_int, "int");
Expand Down
9 changes: 9 additions & 0 deletions msgspec/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
except Exception:
_types_UnionType = type("UnionType", (), {}) # type: ignore

try:
from typing import TypeAliasType as _TypeAliasType # type: ignore
except Exception:
_TypeAliasType = type("TypeAliasType", (), {}) # type: ignore

import msgspec
from msgspec import NODEFAULT, UNSET, UnsetType as _UnsetType

Expand Down Expand Up @@ -628,6 +633,8 @@ def _origin_args_metadata(t):
t = origin
elif origin == Final:
t = t.__args__[0]
elif type(origin) is _TypeAliasType:
t = origin.__value__[t.__args__]
else:
args = getattr(t, "__args__", None)
origin = _CONCRETE_TYPES.get(origin, origin)
Expand All @@ -636,6 +643,8 @@ def _origin_args_metadata(t):
supertype = getattr(t, "__supertype__", None)
if supertype is not None:
t = supertype
elif type(t) is _TypeAliasType:
t = t.__value__
else:
origin = t
args = None
Expand Down
Loading

0 comments on commit f8d2c1a

Please sign in to comment.