Skip to content

Commit

Permalink
Ignore custom data codec for internal introspection (MagicStack#618)
Browse files Browse the repository at this point in the history
  • Loading branch information
fantix committed Sep 25, 2020
1 parent 68b40cb commit e064f59
Show file tree
Hide file tree
Showing 9 changed files with 82 additions and 32 deletions.
45 changes: 33 additions & 12 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,13 +342,16 @@ async def _get_statement(
*,
named: bool=False,
use_cache: bool=True,
ignore_custom_codec=False,
record_class=None
):
if record_class is None:
record_class = self._protocol.get_record_class()

if use_cache:
statement = self._stmt_cache.get((query, record_class))
statement = self._stmt_cache.get(
(query, record_class, ignore_custom_codec)
)
if statement is not None:
return statement

Expand All @@ -371,6 +374,7 @@ async def _get_statement(
query,
timeout,
record_class=record_class,
ignore_custom_codec=ignore_custom_codec,
)
need_reprepare = False
types_with_missing_codecs = statement._init_types()
Expand Down Expand Up @@ -415,7 +419,8 @@ async def _get_statement(
)

if use_cache:
self._stmt_cache.put((query, record_class), statement)
self._stmt_cache.put(
(query, record_class, ignore_custom_codec), statement)

# If we've just created a new statement object, check if there
# are any statements for GC.
Expand All @@ -426,7 +431,12 @@ async def _get_statement(

async def _introspect_types(self, typeoids, timeout):
return await self.__execute(
self._intro_query, (list(typeoids),), 0, timeout)
self._intro_query,
(list(typeoids),),
0,
timeout,
ignore_custom_codec=True,
)

async def _introspect_type(self, typename, schema):
if (
Expand All @@ -439,20 +449,22 @@ async def _introspect_type(self, typename, schema):
[typeoid],
limit=0,
timeout=None,
ignore_custom_codec=True,
)
if rows:
typeinfo = rows[0]
else:
typeinfo = None
else:
typeinfo = await self.fetchrow(
introspection.TYPE_BY_NAME, typename, schema)
rows = await self._execute(
introspection.TYPE_BY_NAME,
[typename, schema],
limit=1,
timeout=None,
ignore_custom_codec=True,
)

if not typeinfo:
if not rows:
raise ValueError(
'unknown type: {}.{}'.format(schema, typename))

return typeinfo
return rows[0]

def cursor(
self,
Expand Down Expand Up @@ -1325,7 +1337,9 @@ def _mark_stmts_as_closed(self):
def _maybe_gc_stmt(self, stmt):
if (
stmt.refs == 0
and not self._stmt_cache.has((stmt.query, stmt.record_class))
and not self._stmt_cache.has(
(stmt.query, stmt.record_class, stmt.ignore_custom_codec)
)
):
# If low-level `stmt` isn't referenced from any high-level
# `PreparedStatement` object and is not in the `_stmt_cache`:
Expand Down Expand Up @@ -1589,6 +1603,7 @@ async def _execute(
timeout,
*,
return_status=False,
ignore_custom_codec=False,
record_class=None
):
with self._stmt_exclusive_section:
Expand All @@ -1599,6 +1614,7 @@ async def _execute(
timeout,
return_status=return_status,
record_class=record_class,
ignore_custom_codec=ignore_custom_codec,
)
return result

Expand All @@ -1610,6 +1626,7 @@ async def __execute(
timeout,
*,
return_status=False,
ignore_custom_codec=False,
record_class=None
):
executor = lambda stmt, timeout: self._protocol.bind_execute(
Expand All @@ -1620,6 +1637,7 @@ async def __execute(
executor,
timeout,
record_class=record_class,
ignore_custom_codec=ignore_custom_codec,
)

async def _executemany(self, query, args, timeout):
Expand All @@ -1637,20 +1655,23 @@ async def _do_execute(
timeout,
retry=True,
*,
ignore_custom_codec=False,
record_class=None
):
if timeout is None:
stmt = await self._get_statement(
query,
None,
record_class=record_class,
ignore_custom_codec=ignore_custom_codec,
)
else:
before = time.monotonic()
stmt = await self._get_statement(
query,
timeout,
record_class=record_class,
ignore_custom_codec=ignore_custom_codec,
)
after = time.monotonic()
timeout -= after - before
Expand Down
3 changes: 2 additions & 1 deletion asyncpg/protocol/codecs/base.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -166,5 +166,6 @@ cdef class DataCodecConfig:
dict _derived_type_codecs
dict _custom_type_codecs

cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format)
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format,
bint ignore_custom_codec=*)
cdef inline Codec get_any_local_codec(self, uint32_t oid)
22 changes: 12 additions & 10 deletions asyncpg/protocol/codecs/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -692,18 +692,20 @@ cdef class DataCodecConfig:

return codec

cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format):
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format,
bint ignore_custom_codec=False):
cdef Codec codec

codec = self.get_any_local_codec(oid)
if codec is not None:
if codec.format != format:
# The codec for this OID has been overridden by
# set_{builtin}_type_codec with a different format.
# We must respect that and not return a core codec.
return None
else:
return codec
if not ignore_custom_codec:
codec = self.get_any_local_codec(oid)
if codec is not None:
if codec.format != format:
# The codec for this OID has been overridden by
# set_{builtin}_type_codec with a different format.
# We must respect that and not return a core codec.
return None
else:
return codec

codec = get_core_codec(oid, format)
if codec is not None:
Expand Down
1 change: 1 addition & 0 deletions asyncpg/protocol/prepared_stmt.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ cdef class PreparedStatementState:
readonly bint closed
readonly int refs
readonly type record_class
readonly bint ignore_custom_codec


list row_desc
Expand Down
10 changes: 7 additions & 3 deletions asyncpg/protocol/prepared_stmt.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ cdef class PreparedStatementState:
str name,
str query,
BaseProtocol protocol,
type record_class
type record_class,
bint ignore_custom_codec
):
self.name = name
self.query = query
Expand All @@ -28,6 +29,7 @@ cdef class PreparedStatementState:
self.closed = False
self.refs = 0
self.record_class = record_class
self.ignore_custom_codec = ignore_custom_codec

def _get_parameters(self):
cdef Codec codec
Expand Down Expand Up @@ -205,7 +207,8 @@ cdef class PreparedStatementState:
cols_mapping[col_name] = i
cols_names.append(col_name)
oid = row[3]
codec = self.settings.get_data_codec(oid)
codec = self.settings.get_data_codec(
oid, ignore_custom_codec=self.ignore_custom_codec)
if codec is None or not codec.has_decoder():
raise exceptions.InternalClientError(
'no decoder for OID {}'.format(oid))
Expand All @@ -230,7 +233,8 @@ cdef class PreparedStatementState:

for i from 0 <= i < self.args_num:
p_oid = self.parameters_desc[i]
codec = self.settings.get_data_codec(p_oid)
codec = self.settings.get_data_codec(
p_oid, ignore_custom_codec=self.ignore_custom_codec)
if codec is None or not codec.has_encoder():
raise exceptions.InternalClientError(
'no encoder for OID {}'.format(p_oid))
Expand Down
3 changes: 2 additions & 1 deletion asyncpg/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ cdef class BaseProtocol(CoreProtocol):
async def prepare(self, stmt_name, query, timeout,
*,
PreparedStatementState state=None,
ignore_custom_codec=False,
record_class):
if self.cancel_waiter is not None:
await self.cancel_waiter
Expand All @@ -161,7 +162,7 @@ cdef class BaseProtocol(CoreProtocol):
self.last_query = query
if state is None:
state = PreparedStatementState(
stmt_name, query, self, record_class)
stmt_name, query, self, record_class, ignore_custom_codec)
self.statement = state
except Exception as ex:
waiter.set_exception(ex)
Expand Down
3 changes: 2 additions & 1 deletion asyncpg/protocol/settings.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,5 @@ cdef class ConnectionSettings(pgproto.CodecContext):
cpdef inline set_builtin_type_codec(
self, typeoid, typename, typeschema, typekind, alias_to, format)
cpdef inline Codec get_data_codec(
self, uint32_t oid, ServerDataFormat format=*)
self, uint32_t oid, ServerDataFormat format=*,
bint ignore_custom_codec=*)
12 changes: 8 additions & 4 deletions asyncpg/protocol/settings.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,18 @@ cdef class ConnectionSettings(pgproto.CodecContext):
typekind, alias_to, _format)

cpdef inline Codec get_data_codec(self, uint32_t oid,
ServerDataFormat format=PG_FORMAT_ANY):
ServerDataFormat format=PG_FORMAT_ANY,
bint ignore_custom_codec=False):
if format == PG_FORMAT_ANY:
codec = self._data_codecs.get_codec(oid, PG_FORMAT_BINARY)
codec = self._data_codecs.get_codec(
oid, PG_FORMAT_BINARY, ignore_custom_codec)
if codec is None:
codec = self._data_codecs.get_codec(oid, PG_FORMAT_TEXT)
codec = self._data_codecs.get_codec(
oid, PG_FORMAT_TEXT, ignore_custom_codec)
return codec
else:
return self._data_codecs.get_codec(oid, format)
return self._data_codecs.get_codec(
oid, format, ignore_custom_codec)

def __getattr__(self, name):
if not name.startswith('_'):
Expand Down
15 changes: 15 additions & 0 deletions tests/test_introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,20 @@ def tearDownClass(cls):

super().tearDownClass()

def setUp(self):
super().setUp()
self.loop.run_until_complete(self._add_custom_codec(self.con))

async def _add_custom_codec(self, conn):
# mess up with the codec - builtin introspection shouldn't be affected
await conn.set_type_codec(
"oid",
schema="pg_catalog",
encoder=lambda value: None,
decoder=lambda value: None,
format="text",
)

@tb.with_connection_options(database='asyncpg_intro_test')
async def test_introspection_on_large_db(self):
await self.con.execute(
Expand Down Expand Up @@ -142,6 +156,7 @@ async def test_introspection_retries_after_cache_bust(self):
# query would cause introspection to retry.
slow_intro_conn = await self.connect(
connection_class=SlowIntrospectionConnection)
await self._add_custom_codec(slow_intro_conn)
try:
await self.con.execute('''
CREATE DOMAIN intro_1_t AS int;
Expand Down

0 comments on commit e064f59

Please sign in to comment.