Skip to content

Commit

Permalink
Prohibit custom codecs on domains
Browse files Browse the repository at this point in the history
Postgres always includes the base type OID in the RowDescription message
even if the query is technically returning domain values.  This makes
custom codecs on domains ineffective, and so prohibit them to avoid
confusion and bug reports.

See postgres/postgres@d9b679c and
https://postgr.es/m/27307.1047485980%40sss.pgh.pa.us for context.

Fixes: MagicStack#457.
  • Loading branch information
elprans committed Dec 2, 2020
1 parent b53f038 commit 50f964f
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 29 deletions.
11 changes: 10 additions & 1 deletion asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,9 +1160,18 @@ async def set_type_codec(self, typename, *,
self._check_open()
typeinfo = await self._introspect_type(typename, schema)
if not introspection.is_scalar_type(typeinfo):
raise ValueError(
raise exceptions.InterfaceError(
'cannot use custom codec on non-scalar type {}.{}'.format(
schema, typename))
if introspection.is_domain_type(typeinfo):
raise exceptions.UnsupportedClientFeatureError(
'custom codecs on domain types are not supported',
hint='Set the codec on the base type.',
detail=(
'PostgreSQL does not distinguish domains from '
'their base types in query results at the protocol level.'
)
)

oid = typeinfo['oid']
self._protocol.get_settings().add_python_codec(
Expand Down
7 changes: 6 additions & 1 deletion asyncpg/exceptions/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError',
'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage',
'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError')
'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError',
'UnsupportedClientFeatureError')


def _is_asyncpg_class(cls):
Expand Down Expand Up @@ -214,6 +215,10 @@ class DataError(InterfaceError, ValueError):
"""An error caused by invalid query input."""


class UnsupportedClientFeatureError(InterfaceError):
"""Requested feature is unsupported by asyncpg."""


class InterfaceWarning(InterfaceMessage, UserWarning):
"""A warning caused by an improper use of asyncpg API."""

Expand Down
4 changes: 4 additions & 0 deletions asyncpg/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,7 @@ def is_scalar_type(typeinfo) -> bool:
typeinfo['kind'] in SCALAR_TYPE_KINDS and
not typeinfo['elemtype']
)


def is_domain_type(typeinfo) -> bool:
return typeinfo['kind'] == b'd'
9 changes: 4 additions & 5 deletions asyncpg/protocol/codecs/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,14 @@ cdef class Codec:
self.decoder = <codec_decode_func>&self.decode_array_text
elif type == CODEC_RANGE:
if format != PG_FORMAT_BINARY:
raise NotImplementedError(
raise exceptions.UnsupportedClientFeatureError(
'cannot decode type "{}"."{}": text encoding of '
'range types is not supported'.format(schema, name))
self.encoder = <codec_encode_func>&self.encode_range
self.decoder = <codec_decode_func>&self.decode_range
elif type == CODEC_COMPOSITE:
if format != PG_FORMAT_BINARY:
raise NotImplementedError(
raise exceptions.UnsupportedClientFeatureError(
'cannot decode type "{}"."{}": text encoding of '
'composite types is not supported'.format(schema, name))
self.encoder = <codec_encode_func>&self.encode_composite
Expand Down Expand Up @@ -675,9 +675,8 @@ cdef class DataCodecConfig:
# added builtin types, for which this version of
# asyncpg is lacking support.
#
raise NotImplementedError(
'unhandled standard data type {!r} (OID {})'.format(
name, oid))
raise exceptions.UnsupportedClientFeatureError(
f'unhandled standard data type {name!r} (OID {oid})')
else:
# This is a non-BKI type, and as such, has no
# stable OID, so no possibility of a builtin codec.
Expand Down
36 changes: 14 additions & 22 deletions tests/test_codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,7 +1091,7 @@ async def test_extra_codec_alias(self):
# This should fail, as there is no binary codec for
# my_dec_t and text decoding of composites is not
# implemented.
with self.assertRaises(NotImplementedError):
with self.assertRaises(asyncpg.UnsupportedClientFeatureError):
res = await self.con.fetchval('''
SELECT ($1::my_dec_t, 'a=>1'::hstore)::rec_t AS result
''', 44)
Expand Down Expand Up @@ -1148,7 +1148,7 @@ def hstore_encoder(obj):
self.assertEqual(at[0].type, pt[0])

err = 'cannot use custom codec on non-scalar type public._hstore'
with self.assertRaisesRegex(ValueError, err):
with self.assertRaisesRegex(asyncpg.InterfaceError, err):
await self.con.set_type_codec('_hstore',
encoder=hstore_encoder,
decoder=hstore_decoder)
Expand All @@ -1160,7 +1160,7 @@ def hstore_encoder(obj):
try:
err = 'cannot use custom codec on non-scalar type ' + \
'public.mytype'
with self.assertRaisesRegex(ValueError, err):
with self.assertRaisesRegex(asyncpg.InterfaceError, err):
await self.con.set_type_codec(
'mytype', encoder=hstore_encoder,
decoder=hstore_decoder)
Expand Down Expand Up @@ -1261,13 +1261,14 @@ async def test_custom_codec_on_domain(self):
''')

try:
await self.con.set_type_codec(
'custom_codec_t',
encoder=lambda v: str(v),
decoder=lambda v: int(v))

v = await self.con.fetchval('SELECT $1::custom_codec_t', 10)
self.assertEqual(v, 10)
with self.assertRaisesRegex(
asyncpg.UnsupportedClientFeatureError,
'custom codecs on domain types are not supported'
):
await self.con.set_type_codec(
'custom_codec_t',
encoder=lambda v: str(v),
decoder=lambda v: int(v))
finally:
await self.con.execute('DROP DOMAIN custom_codec_t')

Expand Down Expand Up @@ -1666,7 +1667,7 @@ async def test_unknown_type_text_fallback(self):
# Text encoding of ranges and composite types
# is not supported yet.
with self.assertRaisesRegex(
RuntimeError,
asyncpg.UnsupportedClientFeatureError,
'text encoding of range types is not supported'):

await self.con.fetchval('''
Expand All @@ -1675,7 +1676,7 @@ async def test_unknown_type_text_fallback(self):
''', ['a', 'z'])

with self.assertRaisesRegex(
RuntimeError,
asyncpg.UnsupportedClientFeatureError,
'text encoding of composite types is not supported'):

await self.con.fetchval('''
Expand Down Expand Up @@ -1847,7 +1848,7 @@ async def test_custom_codec_large_oid(self):

expected_oid = self.LARGE_OID
if self.server_version >= (11, 0):
# PostgreSQL 11 automatically create a domain array type
# PostgreSQL 11 automatically creates a domain array type
# _before_ the domain type, so the expected OID is
# off by one.
expected_oid += 1
Expand All @@ -1858,14 +1859,5 @@ async def test_custom_codec_large_oid(self):
v = await self.con.fetchval('SELECT $1::test_domain_t', 10)
self.assertEqual(v, 10)

# Test that custom codec logic handles large OIDs
await self.con.set_type_codec(
'test_domain_t',
encoder=lambda v: str(v),
decoder=lambda v: int(v))

v = await self.con.fetchval('SELECT $1::test_domain_t', 10)
self.assertEqual(v, 10)

finally:
await self.con.execute('DROP DOMAIN test_domain_t')

0 comments on commit 50f964f

Please sign in to comment.