Skip to content

Commit

Permalink
Untangle custom codec confusion (MagicStack#662)
Browse files Browse the repository at this point in the history
Asyncpg currently erroneously prefers binary I/O for the underlying type of
arrays, effectively ignoring any custom text codec that might have been
configured on that type.

Fix this by removing the explicit preference for binary I/O, so that the
codec selection preference is now in the following order:

- custom binary codec
- custom text codec
- builtin binary codec
- builtin text codec

Fixes: MagicStack#590
Reported-by: @neumond
  • Loading branch information
elprans committed Dec 2, 2020
1 parent 7252dbe commit 50f65fb
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 105 deletions.
9 changes: 9 additions & 0 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,15 @@ async def set_type_codec(self, typename, *,
.. versionchanged:: 0.13.0
The ``binary`` keyword argument was removed in favor of
``format``.
.. note::
It is recommended to use the ``'binary'`` or ``'tuple'`` *format*
whenever possible, provided the underlying type supports it. Asyncpg
currently does not support text I/O for composite and range types,
and some other functionality, such as
:meth:`Connection.copy_to_table`, does not support types with text
codecs.
"""
self._check_open()
typeinfo = await self._introspect_type(typename, schema)
Expand Down
31 changes: 10 additions & 21 deletions asyncpg/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,9 @@
ELSE NULL
END) AS basetype,
t.typreceive::oid != 0 AND t.typsend::oid != 0
AS has_bin_io,
t.typelem AS elemtype,
elem_t.typdelim AS elemdelim,
range_t.rngsubtype AS range_subtype,
(CASE WHEN t.typtype = 'r' THEN
(SELECT
range_elem_t.typreceive::oid != 0 AND
range_elem_t.typsend::oid != 0
FROM
pg_catalog.pg_type AS range_elem_t
WHERE
range_elem_t.oid = range_t.rngsubtype)
ELSE
elem_t.typreceive::oid != 0 AND
elem_t.typsend::oid != 0
END) AS elem_has_bin_io,
(CASE WHEN t.typtype = 'c' THEN
(SELECT
array_agg(ia.atttypid ORDER BY ia.attnum)
Expand Down Expand Up @@ -98,12 +84,12 @@

INTRO_LOOKUP_TYPES = '''\
WITH RECURSIVE typeinfo_tree(
oid, ns, name, kind, basetype, has_bin_io, elemtype, elemdelim,
range_subtype, elem_has_bin_io, attrtypoids, attrnames, depth)
oid, ns, name, kind, basetype, elemtype, elemdelim,
range_subtype, attrtypoids, attrnames, depth)
AS (
SELECT
ti.oid, ti.ns, ti.name, ti.kind, ti.basetype, ti.has_bin_io,
ti.elemtype, ti.elemdelim, ti.range_subtype, ti.elem_has_bin_io,
ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
ti.elemtype, ti.elemdelim, ti.range_subtype,
ti.attrtypoids, ti.attrnames, 0
FROM
{typeinfo} AS ti
Expand All @@ -113,8 +99,8 @@
UNION ALL
SELECT
ti.oid, ti.ns, ti.name, ti.kind, ti.basetype, ti.has_bin_io,
ti.elemtype, ti.elemdelim, ti.range_subtype, ti.elem_has_bin_io,
ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
ti.elemtype, ti.elemdelim, ti.range_subtype,
ti.attrtypoids, ti.attrnames, tt.depth + 1
FROM
{typeinfo} ti,
Expand All @@ -126,7 +112,10 @@
)
SELECT DISTINCT
*
*,
basetype::regtype::text AS basetype_name,
elemtype::regtype::text AS elemtype_name,
range_subtype::regtype::text AS range_subtype_name
FROM
typeinfo_tree
ORDER BY
Expand Down
3 changes: 2 additions & 1 deletion asyncpg/protocol/codecs/base.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -168,4 +168,5 @@ cdef class DataCodecConfig:

cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format,
bint ignore_custom_codec=*)
cdef inline Codec get_any_local_codec(self, uint32_t oid)
cdef inline Codec get_custom_codec(self, uint32_t oid,
ServerDataFormat format)
137 changes: 64 additions & 73 deletions asyncpg/protocol/codecs/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -440,14 +440,7 @@ cdef class DataCodecConfig:
for ti in types:
oid = ti['oid']

if not ti['has_bin_io']:
format = PG_FORMAT_TEXT
else:
format = PG_FORMAT_BINARY

has_text_elements = False

if self.get_codec(oid, format) is not None:
if self.get_codec(oid, PG_FORMAT_ANY) is not None:
continue

name = ti['name']
Expand All @@ -468,92 +461,79 @@ cdef class DataCodecConfig:
name = name[1:]
name = '{}[]'.format(name)

if ti['elem_has_bin_io']:
elem_format = PG_FORMAT_BINARY
else:
elem_format = PG_FORMAT_TEXT

elem_codec = self.get_codec(array_element_oid, elem_format)
elem_codec = self.get_codec(array_element_oid, PG_FORMAT_ANY)
if elem_codec is None:
elem_format = PG_FORMAT_TEXT
elem_codec = self.declare_fallback_codec(
array_element_oid, name, schema)
array_element_oid, ti['elemtype_name'], schema)

elem_delim = <Py_UCS4>ti['elemdelim'][0]

self._derived_type_codecs[oid, elem_format] = \
self._derived_type_codecs[oid, elem_codec.format] = \
Codec.new_array_codec(
oid, name, schema, elem_codec, elem_delim)

elif ti['kind'] == b'c':
# Composite type

if not comp_type_attrs:
raise exceptions.InternalClientError(
'type record missing field types for '
'composite {}'.format(oid))

# Composite type
f'type record missing field types for composite {oid}')

comp_elem_codecs = []
has_text_elements = False

for typoid in comp_type_attrs:
elem_codec = self.get_codec(typoid, PG_FORMAT_BINARY)
if elem_codec is None:
elem_codec = self.get_codec(typoid, PG_FORMAT_TEXT)
has_text_elements = True
elem_codec = self.get_codec(typoid, PG_FORMAT_ANY)
if elem_codec is None:
raise exceptions.InternalClientError(
'no codec for composite attribute type {}'.format(
typoid))
f'no codec for composite attribute type {typoid}')
if elem_codec.format is PG_FORMAT_TEXT:
has_text_elements = True
comp_elem_codecs.append(elem_codec)

element_names = collections.OrderedDict()
for i, attrname in enumerate(ti['attrnames']):
element_names[attrname] = i

# If at least one element is text-encoded, we must
# encode the whole composite as text.
if has_text_elements:
format = PG_FORMAT_TEXT
elem_format = PG_FORMAT_TEXT
else:
elem_format = PG_FORMAT_BINARY

self._derived_type_codecs[oid, format] = \
self._derived_type_codecs[oid, elem_format] = \
Codec.new_composite_codec(
oid, name, schema, format, comp_elem_codecs,
oid, name, schema, elem_format, comp_elem_codecs,
comp_type_attrs, element_names)

elif ti['kind'] == b'd':
# Domain type

if not base_type:
raise exceptions.InternalClientError(
'type record missing base type for domain {}'.format(
oid))
f'type record missing base type for domain {oid}')

elem_codec = self.get_codec(base_type, format)
elem_codec = self.get_codec(base_type, PG_FORMAT_ANY)
if elem_codec is None:
format = PG_FORMAT_TEXT
elem_codec = self.declare_fallback_codec(
base_type, name, schema)
base_type, ti['basetype_name'], schema)

self._derived_type_codecs[oid, format] = elem_codec
self._derived_type_codecs[oid, elem_codec.format] = elem_codec

elif ti['kind'] == b'r':
# Range type

if not range_subtype_oid:
raise exceptions.InternalClientError(
'type record missing base type for range {}'.format(
oid))
f'type record missing base type for range {oid}')

if ti['elem_has_bin_io']:
elem_format = PG_FORMAT_BINARY
else:
elem_format = PG_FORMAT_TEXT

elem_codec = self.get_codec(range_subtype_oid, elem_format)
elem_codec = self.get_codec(range_subtype_oid, PG_FORMAT_ANY)
if elem_codec is None:
elem_format = PG_FORMAT_TEXT
elem_codec = self.declare_fallback_codec(
range_subtype_oid, name, schema)
range_subtype_oid, ti['range_subtype_name'], schema)

self._derived_type_codecs[oid, elem_format] = \
self._derived_type_codecs[oid, elem_codec.format] = \
Codec.new_range_codec(oid, name, schema, elem_codec)

elif ti['kind'] == b'e':
Expand Down Expand Up @@ -665,10 +645,6 @@ cdef class DataCodecConfig:
def declare_fallback_codec(self, uint32_t oid, str name, str schema):
cdef Codec codec

codec = self.get_codec(oid, PG_FORMAT_TEXT)
if codec is not None:
return codec

if oid <= MAXBUILTINOID:
# This is a BKI type, for which asyncpg has no
# defined codec. This should only happen for newly
Expand All @@ -695,34 +671,49 @@ cdef class DataCodecConfig:
bint ignore_custom_codec=False):
cdef Codec codec

if not ignore_custom_codec:
codec = self.get_any_local_codec(oid)
if codec is not None:
if codec.format != format:
# The codec for this OID has been overridden by
# set_{builtin}_type_codec with a different format.
# We must respect that and not return a core codec.
return None
else:
return codec

codec = get_core_codec(oid, format)
if codec is not None:
if format == PG_FORMAT_ANY:
codec = self.get_codec(
oid, PG_FORMAT_BINARY, ignore_custom_codec)
if codec is None:
codec = self.get_codec(
oid, PG_FORMAT_TEXT, ignore_custom_codec)
return codec
else:
try:
return self._derived_type_codecs[oid, format]
except KeyError:
return None
if not ignore_custom_codec:
codec = self.get_custom_codec(oid, PG_FORMAT_ANY)
if codec is not None:
if codec.format != format:
# The codec for this OID has been overridden by
# set_{builtin}_type_codec with a different format.
# We must respect that and not return a core codec.
return None
else:
return codec

codec = get_core_codec(oid, format)
if codec is not None:
return codec
else:
try:
return self._derived_type_codecs[oid, format]
except KeyError:
return None

cdef inline Codec get_any_local_codec(self, uint32_t oid):
cdef inline Codec get_custom_codec(
self,
uint32_t oid,
ServerDataFormat format
):
cdef Codec codec

codec = self._custom_type_codecs.get((oid, PG_FORMAT_BINARY))
if codec is None:
return self._custom_type_codecs.get((oid, PG_FORMAT_TEXT))
if format == PG_FORMAT_ANY:
codec = self.get_custom_codec(oid, PG_FORMAT_BINARY)
if codec is None:
codec = self.get_custom_codec(oid, PG_FORMAT_TEXT)
else:
return codec
codec = self._custom_type_codecs.get((oid, format))

return codec


cdef inline Codec get_core_codec(
Expand Down
11 changes: 1 addition & 10 deletions asyncpg/protocol/settings.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -89,16 +89,7 @@ cdef class ConnectionSettings(pgproto.CodecContext):
cpdef inline Codec get_data_codec(self, uint32_t oid,
ServerDataFormat format=PG_FORMAT_ANY,
bint ignore_custom_codec=False):
if format == PG_FORMAT_ANY:
codec = self._data_codecs.get_codec(
oid, PG_FORMAT_BINARY, ignore_custom_codec)
if codec is None:
codec = self._data_codecs.get_codec(
oid, PG_FORMAT_TEXT, ignore_custom_codec)
return codec
else:
return self._data_codecs.get_codec(
oid, format, ignore_custom_codec)
return self._data_codecs.get_codec(oid, format, ignore_custom_codec)

def __getattr__(self, name):
if not name.startswith('_'):
Expand Down
37 changes: 37 additions & 0 deletions tests/test_codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1329,6 +1329,34 @@ async def test_custom_codec_on_enum(self):
finally:
await self.con.execute('DROP TYPE custom_codec_t')

async def test_custom_codec_on_enum_array(self):
    """Test encoding/decoding using a custom codec on an enum array.

    A text codec is registered on an enum type; the test then checks that
    the codec is applied both to scalar values of the type and to the
    elements of arrays of the type.

    Bug: https://github.com/MagicStack/asyncpg/issues/590
    """
    prefix = 'enum: '

    def _encoder(value):
        # Remove exactly the prefix the decoder adds.  str.lstrip() takes
        # a *set* of characters, so lstrip('enum :') would also eat any
        # leading 'e', 'n', 'u', 'm', ' ' or ':' of the label itself.
        text = str(value)
        if text.startswith(prefix):
            text = text[len(prefix):]
        return text

    def _decoder(value):
        return prefix + str(value)

    await self.con.execute('''
        CREATE TYPE custom_codec_t AS ENUM ('foo', 'bar', 'baz')
    ''')

    try:
        await self.con.set_type_codec(
            'custom_codec_t',
            encoder=_encoder,
            decoder=_decoder)

        # Array literal decoded element-wise through the custom codec.
        v = await self.con.fetchval(
            "SELECT ARRAY['foo', 'bar']::custom_codec_t[]")
        self.assertEqual(v, ['enum: foo', 'enum: bar'])

        # Round trip: the argument is encoded, the result is decoded.
        v = await self.con.fetchval(
            'SELECT ARRAY[$1]::custom_codec_t[]', 'foo')
        self.assertEqual(v, ['enum: foo'])

        # The scalar type itself must still use the custom codec.
        v = await self.con.fetchval("SELECT 'foo'::custom_codec_t")
        self.assertEqual(v, 'enum: foo')
    finally:
        await self.con.execute('DROP TYPE custom_codec_t')

async def test_custom_codec_override_binary(self):
"""Test overriding core codecs."""
import json
Expand Down Expand Up @@ -1374,6 +1402,14 @@ def _decoder(value):
res = await conn.fetchval('SELECT $1::json', data)
self.assertEqual(data, res)

res = await conn.fetchval('SELECT $1::json[]', [data])
self.assertEqual([data], res)

await conn.execute('CREATE DOMAIN my_json AS json')

res = await conn.fetchval('SELECT $1::my_json', data)
self.assertEqual(data, res)

def _encoder(value):
return value

Expand All @@ -1389,6 +1425,7 @@ def _decoder(value):
res = await conn.fetchval('SELECT $1::uuid', data)
self.assertEqual(res, data)
finally:
await conn.execute('DROP DOMAIN IF EXISTS my_json')
await conn.close()

async def test_custom_codec_override_tuple(self):
Expand Down

0 comments on commit 50f65fb

Please sign in to comment.