Skip to content

Commit

Permalink
Fix set_type_codec() to accept standard SQL type names (MagicStack#619)
Browse files Browse the repository at this point in the history
Currently, `Connection.set_type_codec()` only accepts type names as they
appear in `pg_catalog.pg_type` and would refuse to handle a standard SQL
spelling of a type like `character varying`.  This is an oversight, as
the internal type names aren't really supposed to be treated as public
Postgres API.  Additionally, for historical reasons, Postgres has a
single-byte `"char"` type, which is distinct from both `varchar` and
SQL `char`, which may lead to massive confusion if a user sets up a
custom codec on it expecting to handle the `char(n)` type instead.

Issue: MagicStack#617
  • Loading branch information
elprans committed Sep 22, 2020
1 parent 4a627d5 commit 68b40cb
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 22 deletions.
51 changes: 30 additions & 21 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,32 @@ async def _introspect_types(self, typeoids, timeout):
return await self.__execute(
self._intro_query, (list(typeoids),), 0, timeout)

async def _introspect_type(self, typename, schema):
if (
schema == 'pg_catalog'
and typename.lower() in protocol.BUILTIN_TYPE_NAME_MAP
):
typeoid = protocol.BUILTIN_TYPE_NAME_MAP[typename.lower()]
rows = await self._execute(
introspection.TYPE_BY_OID,
[typeoid],
limit=0,
timeout=None,
)
if rows:
typeinfo = rows[0]
else:
typeinfo = None
else:
typeinfo = await self.fetchrow(
introspection.TYPE_BY_NAME, typename, schema)

if not typeinfo:
raise ValueError(
'unknown type: {}.{}'.format(schema, typename))

return typeinfo

def cursor(
self,
query,
Expand Down Expand Up @@ -1110,12 +1136,7 @@ async def set_type_codec(self, typename, *,
``format``.
"""
self._check_open()

typeinfo = await self.fetchrow(
introspection.TYPE_BY_NAME, typename, schema)
if not typeinfo:
raise ValueError('unknown type: {}.{}'.format(schema, typename))

typeinfo = await self._introspect_type(typename, schema)
if not introspection.is_scalar_type(typeinfo):
raise ValueError(
'cannot use custom codec on non-scalar type {}.{}'.format(
Expand All @@ -1142,15 +1163,9 @@ async def reset_type_codec(self, typename, *, schema='public'):
.. versionadded:: 0.12.0
"""

typeinfo = await self.fetchrow(
introspection.TYPE_BY_NAME, typename, schema)
if not typeinfo:
raise ValueError('unknown type: {}.{}'.format(schema, typename))

oid = typeinfo['oid']

typeinfo = await self._introspect_type(typename, schema)
self._protocol.get_settings().remove_python_codec(
oid, typename, schema)
typeinfo['oid'], typename, schema)

# Statement cache is no longer valid due to codec changes.
self._drop_local_statement_cache()
Expand Down Expand Up @@ -1191,13 +1206,7 @@ async def set_builtin_type_codec(self, typename, *,
core data type. Added the *format* keyword argument.
"""
self._check_open()

typeinfo = await self.fetchrow(
introspection.TYPE_BY_NAME, typename, schema)
if not typeinfo:
raise exceptions.InterfaceError(
'unknown type: {}.{}'.format(schema, typename))

typeinfo = await self._introspect_type(typename, schema)
if not introspection.is_scalar_type(typeinfo):
raise exceptions.InterfaceError(
'cannot alias non-scalar type {}.{}'.format(
Expand Down
12 changes: 12 additions & 0 deletions asyncpg/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,18 @@
'''


TYPE_BY_OID = '''\
SELECT
t.oid,
t.typelem AS elemtype,
t.typtype AS kind
FROM
pg_catalog.pg_type AS t
WHERE
t.oid = $1
'''


# 'b' for a base type, 'd' for a domain, 'e' for enum.
SCALAR_TYPE_KINDS = (b'b', b'd', b'e')

Expand Down
3 changes: 2 additions & 1 deletion asyncpg/protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

# flake8: NOQA

from .protocol import Protocol, Record, NO_TIMEOUT # NOQA
from .protocol import Protocol, Record, NO_TIMEOUT, BUILTIN_TYPE_NAME_MAP
18 changes: 18 additions & 0 deletions asyncpg/protocol/pgtypes.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -216,5 +216,23 @@ BUILTIN_TYPE_NAME_MAP['double precision'] = \
BUILTIN_TYPE_NAME_MAP['timestamp with timezone'] = \
BUILTIN_TYPE_NAME_MAP['timestamptz']

BUILTIN_TYPE_NAME_MAP['timestamp without timezone'] = \
BUILTIN_TYPE_NAME_MAP['timestamp']

BUILTIN_TYPE_NAME_MAP['time with timezone'] = \
BUILTIN_TYPE_NAME_MAP['timetz']

BUILTIN_TYPE_NAME_MAP['time without timezone'] = \
BUILTIN_TYPE_NAME_MAP['time']

BUILTIN_TYPE_NAME_MAP['char'] = \
BUILTIN_TYPE_NAME_MAP['bpchar']

BUILTIN_TYPE_NAME_MAP['character'] = \
BUILTIN_TYPE_NAME_MAP['bpchar']

BUILTIN_TYPE_NAME_MAP['character varying'] = \
BUILTIN_TYPE_NAME_MAP['varchar']

BUILTIN_TYPE_NAME_MAP['bit varying'] = \
BUILTIN_TYPE_NAME_MAP['varbit']
33 changes: 33 additions & 0 deletions tests/test_codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1255,6 +1255,39 @@ async def test_custom_codec_on_domain(self):
finally:
await self.con.execute('DROP DOMAIN custom_codec_t')

async def test_custom_codec_on_stdsql_types(self):
types = [
'smallint',
'int',
'integer',
'bigint',
'decimal',
'real',
'double precision',
'timestamp with timezone',
'time with timezone',
'timestamp without timezone',
'time without timezone',
'char',
'character',
'character varying',
'bit varying',
'CHARACTER VARYING'
]

for t in types:
with self.subTest(type=t):
try:
await self.con.set_type_codec(
t,
schema='pg_catalog',
encoder=str,
decoder=str,
format='text'
)
finally:
await self.con.reset_type_codec(t, schema='pg_catalog')

async def test_custom_codec_on_enum(self):
"""Test encoding/decoding using a custom codec on an enum."""
await self.con.execute('''
Expand Down

0 comments on commit 68b40cb

Please sign in to comment.