diff --git a/doc/quickstart.rst b/doc/quickstart.rst index 06466a72..46e9d11b 100644 --- a/doc/quickstart.rst +++ b/doc/quickstart.rst @@ -1,15 +1,24 @@ Quick Start =========== +Key Points +---------- + +- Runs on Python version 2.5, 2.6, 2.7, 3.2, 3.3 and 3.4 +- Runs on CPython, Jython and PyPy +- Although it's possible for threads to share cursors and connections, for + performance reasons it's best to use one thread per connection. +- Internally, all queries use prepared statements. pg8000 remembers that a + prepared statement has been created, and uses it on subsequent queries. + Installation ------------ -pg8000 is available for Python 2.5, 2.6, 2.7, 3.2, 3.3 and 3.4 (and has been -tested on CPython, Jython and PyPy). To install pg8000 using `pip `_ type: ``pip install pg8000`` + Interactive Example ------------------- @@ -74,21 +83,3 @@ turned on by using the autocommit property of the connection. >>> conn.autocommit = False >>> cursor.close() >>> conn.close() - -Try the use_cache feature: - -.. code-block:: python - - >>> conn = pg8000.connect( - ... user="postgres", password="C.P.Snow", use_cache=True) - >>> cur = conn.cursor() - >>> cur.execute("select cast(%s as varchar) as f1", ('Troon',)) - >>> res = cur.fetchall() - -Now subsequent queries with the same parameter types and SQL will use the -cached prepared statement. - -.. code-block:: python - - >>> cur.execute("select cast(%s as varchar) as f1", ('Trunho',)) - >>> res = cur.fetchall() diff --git a/pg8000/__init__.py b/pg8000/__init__.py index 4b046676..36962fcb 100644 --- a/pg8000/__init__.py +++ b/pg8000/__init__.py @@ -142,12 +142,10 @@ def __neq__(self, other): # @return An instance of {@link #ConnectionWrapper ConnectionWrapper}. def connect( user, host='localhost', unix_sock=None, port=5432, database=None, - password=None, socket_timeout=60, ssl=False, use_cache=False, - **kwargs): + password=None, socket_timeout=60, ssl=False, **kwargs): return pg8000.core.Connection( - user, host, unix_sock, port, database, password, socket_timeout, ssl, - use_cache) + user, host, unix_sock, port, database, password, socket_timeout, ssl) ## # The DBAPI level supported. Currently 2.0. This property is part of the diff --git a/pg8000/core.py b/pg8000/core.py index a2d653f6..fce699b8 100644 --- a/pg8000/core.py +++ b/pg8000/core.py @@ -467,16 +467,12 @@ def int_in(data, offset, length): class Cursor(Iterator): def __init__(self, connection): self._c = connection - self._stmt = None self.arraysize = 1 + self.ps = None self._row_count = -1 - - def require_stmt(func): - def retval(self, *args, **kwargs): - if self._stmt is None: - raise ProgrammingError("attempting to use unexecuted cursor") - return func(self, *args, **kwargs) - return retval + self._cached_rows = deque() + self.portal_name = None + self.portal_suspended = False ## # This read-only attribute returns a reference to the connection object on @@ -513,11 +509,10 @@ def rowcount(self): # Stability: Part of the DBAPI 2.0 specification. description = property(lambda self: self._getDescription()) - @require_open_cursor def _getDescription(self): - if self._stmt is None: + if self.ps is None: return None - row_desc = self._stmt.get_row_description() + row_desc = self.ps['row_desc'] if len(row_desc) == 0: return None columns = [] @@ -532,24 +527,22 @@ def _getDescription(self): #

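With ``use_cache`` gone, statement caching is always on, so the example deleted from the quickstart above still works once the flag is dropped. A sketch, assuming the same illustrative credentials the quickstart already uses:

.. code-block:: python

    >>> import pg8000
    >>> conn = pg8000.connect(user="postgres", password="C.P.Snow")
    >>> cur = conn.cursor()
    >>> # The first execution parses and binds a server-side prepared
    >>> # statement for this SQL and parameter type.
    >>> cur.execute("select cast(%s as varchar) as f1", ('Troon',))
    >>> res = cur.fetchall()
    >>> # Same SQL, same parameter type: the cached statement is reused.
    >>> cur.execute("select cast(%s as varchar) as f1", ('Trunho',))
    >>> res = cur.fetchall()
    >>> conn.rollback()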
# Stability: Part of the DBAPI 2.0 specification. def execute(self, operation, args=None, stream=None): - if args is None: - args = tuple() - - self._row_count = -1 - if not self._c.use_cache: - self._c.statement_cache.clear() - try: - self._c.begin() + self._c._lock.acquire() + self.stream = stream + + if not self._c.in_transaction and not self._c.autocommit: + self._c.execute(self, "begin transaction", None) + self._c.execute(self, operation, args) except AttributeError: if self._c is None: raise InterfaceError("Cursor closed") + elif self._c._sock is None: + raise InterfaceError("Connection closed") else: raise exc_info()[1] - - self._stmt = self._get_ps(operation, args) - self._stmt.execute(args, stream=stream) - self._row_count = self._stmt.row_count + finally: + self._c._lock.release() ## # Prepare a database operation and then execute it against all parameter @@ -557,28 +550,9 @@ def execute(self, operation, args=None, stream=None): #

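The rewritten ``execute()`` issues ``begin transaction`` itself whenever the connection is not already in a transaction and autocommit is off. A sketch of both modes from the caller's side, reusing the connection opened above; the temporary table and the titles are only illustrative:

.. code-block:: python

    cur = conn.cursor()

    # autocommit is off by default, so this first execute() sends an
    # implicit "begin transaction" before the statement itself.
    cur.execute("create temporary table book (title varchar)")
    cur.execute("insert into book (title) values (%s)", ("Ender's Game",))
    conn.commit()

    # With autocommit on, no implicit begin is sent, which is what
    # statements such as vacuum need.
    conn.autocommit = True
    cur.execute("vacuum")
    conn.autocommit = False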
# Stability: Part of the DBAPI 2.0 specification. def executemany(self, operation, param_sets): - self._row_count = -1 - - try: - self._c.begin() - except AttributeError: - if self._c is None: - raise InterfaceError("Cursor closed") - else: - raise exc_info()[1] - - if not self._c.use_cache: - self._c.statement_cache.clear() - for parameters in param_sets: - self._stmt = self._get_ps(operation, parameters) - self._stmt.execute(parameters) - if self._stmt.row_count == -1: - self._row_count = -1 - elif self._row_count == -1: - self._row_count = self._stmt.row_count - else: - self._row_count += self._stmt.row_count + self.execute(operation, parameters) + self._row_count = -1 def copy_from(self, fileobj, table=None, sep='\t', null=None, query=None): if query is None: @@ -598,22 +572,6 @@ def copy_to(self, fileobj, table=None, sep='\t', null=None, query=None): query += " NULL '%s'" % (null,) self.copy_execute(fileobj, query) - def _get_ps(self, operation, vals): - if pg8000.paramstyle in ('numeric', 'qmark', 'format'): - args = vals - else: - args = tuple(vals[k] for k in sorted(vals.keys())) - - key = tuple(oid for oid, x, y in self._c.make_params(args)), operation - - try: - return self._c.statement_cache[key] - except KeyError: - ps = PreparedStatement(self._c, operation, vals) - self._c.statement_cache[key] = ps - return ps - - @require_open_cursor def copy_execute(self, fileobj, query): self.execute(query, stream=fileobj) @@ -624,7 +582,7 @@ def copy_execute(self, fileobj, query): # Stability: Part of the DBAPI 2.0 specification. def fetchone(self): try: - return next(self._stmt) + return next(self) except StopIteration: return None except TypeError: @@ -681,15 +639,11 @@ def fetchall(self): # Close the cursor. #

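``executemany()`` is now a plain loop over ``execute()`` that finishes by resetting the row count. A sketch continuing with the cursor above (titles again illustrative):

.. code-block:: python

    cur.executemany(
        "insert into book (title) values (%s)",
        (("Speaker for the Dead",), ("Xenocide",)))

    # The implementation sets the row count back to -1 after the loop,
    # so don't rely on rowcount here.
    assert cur.rowcount == -1
    conn.commit()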
# Stability: Part of the DBAPI 2.0 specification. - @require_open_cursor def close(self): - if self._stmt is not None and not self._c.use_cache: - self._stmt.close() - self._stmt = None self._c = None def __iter__(self): - return self._stmt + return self def setinputsizes(self, sizes): pass @@ -697,6 +651,29 @@ def setinputsizes(self, sizes): def setoutputsize(self, size, column=None): pass + def __next__(self): + try: + self._c._lock.acquire() + return self._cached_rows.popleft() + except IndexError: + if self.portal_suspended: + self._c.send_EXECUTE(self) + self._c._write(SYNC_MSG) + self._c._flush() + self._c.handle_messages(self) + if not self.portal_suspended: + self._c.close_portal(self) + try: + return self._cached_rows.popleft() + except IndexError: + if self.ps is None: + raise ProgrammingError("A query hasn't been issued.") + elif len(self.ps['row_desc']) == 0: + raise ProgrammingError("no result set") + else: + raise StopIteration() + finally: + self._c._lock.release() # Message codes NOTICE_RESPONSE = b("N") @@ -827,6 +804,14 @@ class Connection(object): NotSupportedError = property( lambda self: self._getError(NotSupportedError)) + # Determines the number of rows to read from the database server at once. + # Reading more rows increases performance at the cost of memory. The + # default value is 100 rows. The effect of this parameter is transparent. + # That is, the library reads more rows when the cache is empty + # automatically. + _row_cache_size = 100 + _row_cache_size_bin = i_pack(_row_cache_size) + def _getError(self, error): warn( "DB-API extension connection.%s used" % @@ -835,24 +820,19 @@ def _getError(self, error): def __init__( self, user, host, unix_sock, port, database, password, - socket_timeout, ssl, use_cache): + socket_timeout, ssl): self._client_encoding = "ascii" self._commands_with_count = ( b("INSERT"), b("DELETE"), b("UPDATE"), b("MOVE"), b("FETCH"), b("COPY"), b("SELECT")) - self._sock_lock = threading.Lock() + self._lock = threading.Lock() self.user = user self.password = password self.autocommit = False - self.statement_cache = {} - - self.statement_number_lock = threading.Lock() + self._caches = defaultdict(lambda: defaultdict(dict)) self.statement_number = 0 - - self.portal_number_lock = threading.Lock() self.portal_number = 0 - self.use_cache = use_cache try: if unix_sock is None and host is not None: @@ -873,7 +853,7 @@ def __init__( if ssl: try: - self._sock_lock.acquire() + self._lock.acquire() import ssl as sslmodule # Int32(8) - Message length, including self. # Int32(80877103) - The SSL request code. @@ -888,7 +868,7 @@ def __init__( "SSL required but ssl module not available in " "this python installation") finally: - self._sock_lock.release() + self._lock.release() # settimeout causes ssl failure, on windows. Python bug 1462352. 
self._usock.settimeout(socket_timeout) @@ -1189,19 +1169,18 @@ def inet_in(data, offset, length): self._write(val) self._flush() + self._cursor = self.cursor() try: try: - self._sock_lock.acquire() - self.handle_messages() + self._lock.acquire() + self.handle_messages(None) finally: - self._sock_lock.release() + self._lock.release() except: self.close() raise exc_info()[1] - self._begin = PreparedStatement(self, "BEGIN TRANSACTION", ()) - self._commit = PreparedStatement(self, "COMMIT TRANSACTION", ()) - self._rollback = PreparedStatement(self, "ROLLBACK TRANSACTION", ()) + self._cursor = self.cursor() self.in_transaction = False self.notifies = [] self.notifies_lock = threading.Lock() @@ -1226,8 +1205,8 @@ def handle_PARSE_COMPLETE(self, data, ps): def handle_BIND_COMPLETE(self, data, ps): pass - def handle_PORTAL_SUSPENDED(self, data, ps): - ps.portal_suspended = True + def handle_PORTAL_SUSPENDED(self, data, cursor): + cursor.portal_suspended = True def handle_PARAMETER_DESCRIPTION(self, data, ps): # Well, we don't really care -- we're going to send whatever we @@ -1258,7 +1237,7 @@ def handle_COPY_IN_RESPONSE(self, data, ps): # Int16(N) - Format codes for each column (0 text, 1 binary) is_binary, num_cols = bh_unpack(data) # column_formats = unpack_from('!' + 'h' * num_cols, data, 3) - assert self._sock_lock.locked() + assert self._lock.locked() if ps.stream is None: raise CopyQueryWithoutStreamError() @@ -1324,23 +1303,22 @@ def cursor(self): #

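The cursor now buffers rows in ``_cached_rows`` and asks the server for at most ``_row_cache_size`` (100) rows per Execute message, resuming a suspended portal automatically when the buffer runs dry. A sketch of what that looks like to the caller; ``generate_series`` is just a convenient row source:

.. code-block:: python

    cur = conn.cursor()
    cur.execute("select * from generate_series(1, %s)", (250,))

    # 250 rows arrive in batches (100 + 100 + 50) behind the scenes;
    # the iteration below never sees the batching.
    total = sum(row[0] for row in cur)
    assert total == sum(range(1, 251))
    conn.rollback()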
     # Stability: Part of the DBAPI 2.0 specification.
     def commit(self):
-        # There's a threading bug here. If a query is sent after the
-        # commit, but before the begin, it will be executed immediately
-        # without a surrounding transaction. Like all threading bugs -- it
-        # sounds unlikely, until it happens every time in one
-        # application... however, to fix this, we need to lock the
-        # database connection entirely, so that no cursors can execute
-        # statements on other threads. Support for that type of lock will
-        # be done later.
-        self._commit.execute(())
+        try:
+            self._lock.acquire()
+            self.execute(self._cursor, "commit", None)
+        finally:
+            self._lock.release()
 
     ##
     # Rolls back the current database transaction.
     # <p>
     # Stability: Part of the DBAPI 2.0 specification.
     def rollback(self):
-        # see bug description in commit.
-        self._rollback.execute(())
+        try:
+            self._lock.acquire()
+            self.execute(self._cursor, "rollback", None)
+        finally:
+            self._lock.release()
 
     ##
     # Closes the database connection.
@@ -1348,7 +1326,7 @@ def rollback(self):
     def close(self):
         try:
-            self._sock_lock.acquire()
+            self._lock.acquire()
             # Byte1('X') - Identifies the message as a terminate message.
             # Int32(4) - Message length, including self.
             self._write(TERMINATE_MSG)
@@ -1361,18 +1339,10 @@ def close(self):
         except ValueError:
             raise pg8000.InterfaceError("Connection is closed.")
         finally:
-            self._sock_lock.release()
-
-    ##
-    # Begins a new transaction.
-    # <p>
- # Stability: Added in v1.00, stability guaranteed for v1.xx. - def begin(self): - if not self.in_transaction and not self.autocommit: - self._begin.execute(()) + self._lock.release() - def handle_AUTHENTICATION_REQUEST(self, data, ps): - assert self._sock_lock.locked() + def handle_AUTHENTICATION_REQUEST(self, data, cursor): + assert self._lock.locked() # Int32 - An authentication code that represents different # authentication messages: # 0 = AuthenticationOk @@ -1449,7 +1419,7 @@ def make_params(self, values): "not mapped to pg type") return params - def handle_ROW_DESCRIPTION(self, data, ps): + def handle_ROW_DESCRIPTION(self, data, cursor): count = h_unpack(data)[0] idx = 2 for i in range(count): @@ -1461,7 +1431,7 @@ def handle_ROW_DESCRIPTION(self, data, ps): "type_size", "type_modifier", "format"), ihihih_unpack(data, idx)))) idx += 18 - ps.row_desc.append(field) + cursor.ps['row_desc'].append(field) try: field['pg8000_fc'], field['func'] = self.pg_types[ field['type_oid']] @@ -1469,9 +1439,38 @@ def handle_ROW_DESCRIPTION(self, data, ps): raise NotSupportedError( "type oid " + exc_info()[1] + " not supported") - def parse(self, ps, statement): + def execute(self, cursor, operation, vals): + if vals is None: + vals = () + paramstyle = pg8000.paramstyle + cache = self._caches[paramstyle] + try: - self._sock_lock.acquire() + statement, make_args = cache['statement'][operation] + except KeyError: + statement, make_args = convert_paramstyle(paramstyle, operation) + cache['statement'][operation] = statement, make_args + + args = make_args(vals) + params = self.make_params(args) + + key = tuple(oid for oid, x, y in params), operation + + try: + ps = cache['ps'][key] + cursor.ps = ps + except KeyError: + statement_name = "pg8000_statement_" + str(self.statement_number) + self.statement_number += 1 + statement_name_bin = statement_name.encode('ascii') + NULL_BYTE + ps = { + 'row_desc': [], + 'param_funcs': tuple(x[2] for x in params), + } + cursor.ps = ps + + param_fcs = tuple(x[1] for x in params) + # Byte1('P') - Identifies the message as a Parse command. # Int32 - Message length, including self. # String - Prepared statement name. An empty string selects the @@ -1480,10 +1479,10 @@ def parse(self, ps, statement): # Int16 - Number of parameter data types specified (can be zero). # For each parameter: # Int32 - The OID of the parameter data type. - val = bytearray(ps.statement_name_bin) + val = bytearray(statement_name_bin) val.extend(statement.encode(self._client_encoding) + NULL_BYTE) - val.extend(h_pack(len(ps.params))) - for oid, fc, send_func in ps.params: + val.extend(h_pack(len(params))) + for oid, fc, send_func in params: # Parse message doesn't seem to handle the -1 type_oid for NULL # values that other messages handle. So we'll provide type_oid # 705, the PG "unknown" type. @@ -1494,17 +1493,25 @@ def parse(self, ps, statement): # Byte1 - 'S' for prepared statement, 'P' for portal. # String - The name of the item to describe. 
self._send_message(PARSE, val) - self._send_message(DESCRIBE, STATEMENT + ps.statement_name_bin) + self._send_message(DESCRIBE, STATEMENT + statement_name_bin) self._write(SYNC_MSG) - self._flush() - self.handle_messages(ps) - finally: - self._sock_lock.release() - def bind(self, ps, values): - try: - self._sock_lock.acquire() + try: + self._flush() + except AttributeError: + if self._sock is None: + raise InterfaceError("Connection closed") + else: + raise exc_info()[1] + + self.handle_messages(cursor) + # We've got row_desc that allows us to identify what we're + # going to get back from this statement. + output_fc = tuple( + self.pg_types[f['type_oid']][0] for f in ps['row_desc']) + + ps['input_funcs'] = tuple(f['func'] for f in ps['row_desc']) # Byte1('B') - Identifies the Bind command. # Int32 - Message length, including self. # String - Name of the destination portal. @@ -1521,25 +1528,56 @@ def bind(self, ps, values): # Int16 - The number of result-column format codes. # For each result-column format code: # Int16 - The format code. - retval = bytearray(ps.portal_name_bin + ps.bind_1) - for value, send_func in zip(values, ps.param_funcs): - if value is None: - val = NULL - else: - val = send_func(value) - retval.extend(i_pack(len(val))) - retval.extend(val) - retval.extend(ps.bind_2) + ps['bind_1'] = statement_name_bin + h_pack(len(params)) + \ + pack("!" + "h" * len(param_fcs), *param_fcs) + \ + h_pack(len(params)) - self._send_message(BIND, retval) - self.send_EXECUTE(ps) - self._write(SYNC_MSG) - self._flush() - self.handle_messages(ps) - except AttributeError: - raise pg8000.InterfaceError("Connection is closed.") - finally: - self._sock_lock.release() + ps['bind_2'] = h_pack(len(output_fc)) + \ + pack("!" + "h" * len(output_fc), *output_fc) + + cache['ps'][key] = ps + + cursor._cached_rows.clear() + cursor._row_count = -1 + cursor.portal_name = "pg8000_portal_" + str(self.portal_number) + self.portal_number += 1 + cursor.portal_name_bin = cursor.portal_name.encode('ascii') + NULL_BYTE + cursor.execute_msg = cursor.portal_name_bin + \ + Connection._row_cache_size_bin + + # Byte1('B') - Identifies the Bind command. + # Int32 - Message length, including self. + # String - Name of the destination portal. + # String - Name of the source prepared statement. + # Int16 - Number of parameter format codes. + # For each parameter format code: + # Int16 - The parameter format code. + # Int16 - Number of parameter values. + # For each parameter value: + # Int32 - The length of the parameter value, in bytes, not + # including this length. -1 indicates a NULL parameter + # value, in which no value bytes follow. + # Byte[n] - Value of the parameter. + # Int16 - The number of result-column format codes. + # For each result-column format code: + # Int16 - The format code. 
+ retval = bytearray(cursor.portal_name_bin + ps['bind_1']) + for value, send_func in zip(args, ps['param_funcs']): + if value is None: + val = NULL + else: + val = send_func(value) + retval.extend(i_pack(len(val))) + retval.extend(val) + retval.extend(ps['bind_2']) + + self._send_message(BIND, retval) + self.send_EXECUTE(cursor) + self._write(SYNC_MSG) + self._flush() + self.handle_messages(cursor) + if not cursor.portal_suspended: + self.close_portal(cursor) def _send_message(self, code, data): try: @@ -1555,41 +1593,32 @@ def _send_message(self, code, data): except AttributeError: raise pg8000.InterfaceError("Connection is closed.") - def send_EXECUTE(self, ps): + def send_EXECUTE(self, cursor): # Byte1('E') - Identifies the message as an execute message. # Int32 - Message length, including self. # String - The name of the portal to execute. # Int32 - Maximum number of rows to return, if portal # contains a query # that returns rows. # 0 = no limit. - ps.cmd = None - ps.portal_suspended = False - self._send_message( - EXECUTE, ps.portal_name_bin + ps.row_cache_size_bin) + cursor.portal_suspended = False + self._send_message(EXECUTE, cursor.execute_msg) def handle_NO_DATA(self, msg, ps): pass - def handle_COMMAND_COMPLETE(self, data, ps): - ps.cmd = {} - data = data[:-1] - values = data.split(b(" ")) + def handle_COMMAND_COMPLETE(self, data, cursor): + values = data[:-1].split(b(" ")) if values[0] in self._commands_with_count: - ps.cmd['command'] = values[0] row_count = int(values[-1]) - if ps.row_count == -1: - ps.row_count = row_count + if cursor._row_count == -1: + cursor._row_count = row_count else: - ps.row_count += row_count - if values[0] == b("INSERT"): - ps.cmd['oid'] = int(values[1]) - else: - ps.cmd['command'] = data + cursor._row_count += row_count - def handle_DATA_ROW(self, data, ps): + def handle_DATA_ROW(self, data, cursor): data_idx = 2 row = [] - for func in ps.input_funcs: + for func in cursor.ps['input_funcs']: vlen = i_unpack(data, data_idx)[0] data_idx += 4 if vlen == -1: @@ -1597,25 +1626,27 @@ def handle_DATA_ROW(self, data, ps): else: row.append(func(data, data_idx, vlen)) data_idx += vlen - ps._cached_rows.append(row) + cursor._cached_rows.append(row) - def handle_messages(self, ps=None): + def handle_messages(self, cursor): message_code = None error = None while message_code != READY_FOR_QUERY: message_code, data_len = ci_unpack(self._read(5)) try: - self.message_types[message_code](self._read(data_len - 4), ps) + self.message_types[message_code]( + self._read(data_len - 4), cursor) except KeyError: raise InternalError( "Unrecognised message code " + message_code) except pg8000.errors.Error: e = exc_info()[1] - if ps is None: + if cursor is None: raise e else: error = e + if error is not None: raise error @@ -1623,29 +1654,11 @@ def handle_messages(self, ps=None): # Int32 - Message length, including self. # Byte1 - 'S' for prepared statement, 'P' for portal. # String - The name of the item to close. - def close_statement(self, ps): - try: - self._sock_lock.acquire() - self._send_message(CLOSE, STATEMENT + ps.statement_name_bin) - self._write(SYNC_MSG) - self._flush() - self.handle_messages(ps) - finally: - self._sock_lock.release() - - # Byte1('C') - Identifies the message as a close command. - # Int32 - Message length, including self. - # Byte1 - 'S' for prepared statement, 'P' for portal. - # String - The name of the item to close. 
- def close_portal(self, ps): - try: - self._sock_lock.acquire() - self._send_message(CLOSE, PORTAL + ps.portal_name_bin) - self._write(SYNC_MSG) - self._flush() - self.handle_messages(ps) - finally: - self._sock_lock.release() + def close_portal(self, cursor): + self._send_message(CLOSE, PORTAL + cursor.portal_name_bin) + self._write(SYNC_MSG) + self._flush() + self.handle_messages(cursor) def handle_NOTICE_RESPONSE(self, data, ps): resp = data_into_dict(data) @@ -1914,161 +1927,3 @@ def array_dim_lengths(arr): else: return [len(arr)] return retval - - -## -# This class represents a prepared statement. A prepared statement is -# pre-parsed on the server, which reduces the need to parse the query every -# time it is run. The statement can have parameters in the form of $1, $2, $3, -# etc. When parameters are used, the types of the parameters need to be -# specified when creating the prepared statement. -#

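The new ``Connection.execute()`` keeps two caches per paramstyle: the converted SQL plus its argument adapter, and prepared statements keyed on ``(parameter OIDs, SQL)``. One practical consequence is that the same SQL text can map to more than one server-side statement when the Python parameter types differ. A sketch with throwaway queries, continuing with the connection from the earlier sketches:

.. code-block:: python

    cur = conn.cursor()

    # Same SQL, same parameter type: the prepared statement created for
    # the first call is looked up and reused for the second.
    cur.execute("select cast(%s as integer)", (1,))
    cur.execute("select cast(%s as integer)", (2,))

    # Same SQL text, but a float maps to a different parameter OID, so a
    # second prepared statement is created and cached alongside the first.
    cur.execute("select cast(%s as integer)", (2.0,))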
-# As of v1.01, instances of this class are thread-safe. This means that a
-# single PreparedStatement can be accessed by multiple threads without the
-# internal consistency of the statement being altered. However, the
-# responsibility is on the client application to ensure that one thread reading
-# from a statement isn't affected by another thread starting a new query with
-# the same statement.
-# <p>
-# Stability: Added in v1.00, stability guaranteed for v1.xx.
-#
-# @param connection An instance of {@link Connection Connection}.
-#
-# @param statement The SQL statement to be represented, often containing
-# parameters in the form of $1, $2, $3, etc.
-#
-# @param types Python type objects for each parameter in the SQL
-# statement. For example, int, float, str.
-class PreparedStatement(Iterator):
-
-    ##
-    # Determines the number of rows to read from the database server at once.
-    # Reading more rows increases performance at the cost of memory. The
-    # default value is 100 rows. The affect of this parameter is transparent.
-    # That is, the library reads more rows when the cache is empty
-    # automatically.
-    # <p>
- # Stability: Added in v1.00, stability guaranteed for v1.xx. It is - # possible that implementation changes in the future could cause this - # parameter to be ignored. - row_cache_size = 100 - - def __init__(self, connection, query, values): - - # Stability: Added in v1.03, stability guaranteed for v1.xx. - self.row_count = -1 - - self.c = connection - self.portal_name = None - self.row_cache_size_bin = i_pack(PreparedStatement.row_cache_size) - - try: - self.c.statement_number_lock.acquire() - self.statement_name = "pg8000_statement_" + \ - str(self.c.statement_number) - self.c.statement_number += 1 - finally: - self.c.statement_number_lock.release() - - self.statement_name_bin = self.statement_name.encode('ascii') + \ - NULL_BYTE - self._cached_rows = deque() - self.statement, self.make_args = convert_paramstyle( - pg8000.paramstyle, query) - self.params = self.c.make_params(self.make_args(values)) - param_fcs = tuple(x[1] for x in self.params) - self.param_funcs = tuple(x[2] for x in self.params) - self.row_desc = [] - self.c.parse(self, self.statement) - self._lock = threading.RLock() - self.cmd = None - - # We've got row_desc that allows us to identify what we're - # going to get back from this statement. - output_fc = tuple( - self.c.pg_types[f['type_oid']][0] for f in self.row_desc) - - self.input_funcs = tuple(f['func'] for f in self.row_desc) - # Byte1('B') - Identifies the Bind command. - # Int32 - Message length, including self. - # String - Name of the destination portal. - # String - Name of the source prepared statement. - # Int16 - Number of parameter format codes. - # For each parameter format code: - # Int16 - The parameter format code. - # Int16 - Number of parameter values. - # For each parameter value: - # Int32 - The length of the parameter value, in bytes, not - # including this length. -1 indicates a NULL parameter - # value, in which no value bytes follow. - # Byte[n] - Value of the parameter. - # Int16 - The number of result-column format codes. - # For each result-column format code: - # Int16 - The format code. - self.bind_1 = self.statement_name_bin + h_pack(len(self.params)) + \ - pack("!" + "h" * len(param_fcs), *param_fcs) + \ - h_pack(len(self.params)) - - self.bind_2 = h_pack(len(output_fc)) + \ - pack("!" + "h" * len(output_fc), *output_fc) - - def close(self): - if self.statement_name != "": # don't close unnamed statement - self.c.close_statement(self) - if self.portal_name is not None: - self.c.close_portal(self) - self.portal_name = None - - def get_row_description(self): - return self.row_desc - - ## - # Run the SQL prepared statement with the given parameters. - #

- # Stability: Added in v1.00, stability guaranteed for v1.xx. - def execute(self, values, stream=None): - try: - self._lock.acquire() - # cleanup last execute - self._cached_rows.clear() - self.row_count = -1 - self.portal_suspended = False - try: - self.c.portal_number_lock.acquire() - self.portal_name = "pg8000_portal_" + str(self.c.portal_number) - self.c.portal_number += 1 - finally: - self.c.portal_number_lock.release() - self.portal_name_bin = self.portal_name.encode('ascii') + NULL_BYTE - self.cmd = None - self.stream = stream - self.c.bind(self, self.make_args(values)) - if len(self.row_desc) == 0: - self.c.close_portal(self) - finally: - self._lock.release() - - def __next__(self): - try: - self._lock.acquire() - return self._cached_rows.popleft() - except IndexError: - if self.portal_suspended: - try: - self.c._sock_lock.acquire() - self.c.send_EXECUTE(self) - self.c._write(SYNC_MSG) - self.c._flush() - self.c.handle_messages(self) - finally: - self.c._sock_lock.release() - - try: - return self._cached_rows.popleft() - except IndexError: - if len(self.row_desc) == 0: - raise ProgrammingError("no result set") - self.c.close_portal(self) - raise StopIteration() - finally: - self._lock.release() diff --git a/pg8000/tests/test_query.py b/pg8000/tests/test_query.py index 774b65a3..992450e8 100644 --- a/pg8000/tests/test_query.py +++ b/pg8000/tests/test_query.py @@ -84,6 +84,40 @@ def testParallelQueries(self): cursor.close() self.db.rollback() + def testParallelOpenPortals(self): + try: + c1, c2 = self.db.cursor(), self.db.cursor() + c1count, c2count = 0, 0 + q = "select * from generate_series(1, %s)" + params = (self.db._row_cache_size + 1,) + c1.execute(q, params) + c2.execute(q, params) + for c2row in c2: + c2count += 1 + for c1row in c1: + c1count += 1 + finally: + c1.close() + c2.close() + self.db.rollback() + + self.assertEqual(c1count, c2count) + + # Test query works if the number of rows returned is exactly the same as + # the size of the row cache + + def testQuerySizeCache(self): + try: + cursor = self.db.cursor() + cursor.execute( + "select * from generate_series(1, %s)", + (self.db._row_cache_size,)) + for row in cursor: + pass + finally: + cursor.close() + self.db.rollback() + def testInsertReturning(self): try: cursor = self.db.cursor() diff --git a/tox.ini b/tox.ini index d90e603c..9f67da15 100644 --- a/tox.ini +++ b/tox.ini @@ -9,7 +9,6 @@ envlist = py26, py27, py32, py33, py34, pypy [testenv] commands = nosetests - nosetests --tc=use_cache:true deps = nose pytz @@ -18,13 +17,11 @@ deps = [testenv:py34] commands = nosetests - nosetests --tc=use_cache:true python -m doctest README.creole python -m doctest doc/quickstart.rst flake8 pg8000 python setup.py check deps = nose - nose-testconfig flake8 pytz
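``testParallelOpenPortals`` above drives two portals that are open on the same connection at the same time. Outside the test suite the pattern looks roughly like this, with ``conn`` an open connection as in the quickstart and ``generate_series`` again just a row generator:

.. code-block:: python

    c1, c2 = conn.cursor(), conn.cursor()
    q = "select * from generate_series(1, %s)"
    c1.execute(q, (200,))
    c2.execute(q, (200,))

    # Both portals stay open; each is drained in turn, and the library
    # re-executes whichever suspended portal runs out of cached rows.
    c2_rows = [row[0] for row in c2]
    c1_rows = [row[0] for row in c1]
    assert len(c1_rows) == len(c2_rows) == 200

    conn.rollback()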