Skip to content

Commit

Permalink
PYTHON-1332 - Send lsid with all commands
Browse files Browse the repository at this point in the history
  • Loading branch information
ajdavis committed Sep 29, 2017
1 parent 9051b65 commit c1ec855
Show file tree
Hide file tree
Showing 12 changed files with 268 additions and 159 deletions.
4 changes: 2 additions & 2 deletions gridfs/grid_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ def add_option(self, *args, **kwargs):
def remove_option(self, *args, **kwargs):
raise NotImplementedError("Method does not exist for GridOutCursor")

def _clone_base(self, session):
    """Creates an empty GridOutCursor for information to be copied into.

    :Parameters:
      - `session`: the :class:`~pymongo.client_session.ClientSession` (or
        ``None``) that the cloned cursor should be associated with, so the
        clone keeps sending the same ``lsid`` with its commands.
    """
    # Propagate the session to the clone; before PYTHON-1332 the clone
    # silently dropped the session association.
    return GridOutCursor(self.__root_collection, session=session)
42 changes: 22 additions & 20 deletions pymongo/bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,26 +309,28 @@ def execute_command(self, sock_info, generator, write_concern, session):
db_name = self.collection.database.name
listeners = self.collection.database.client._event_listeners

for run in generator:
cmd = SON([(_COMMANDS[run.op_type], self.collection.name),
('ordered', self.ordered)])
if write_concern.document:
cmd['writeConcern'] = write_concern.document
if self.bypass_doc_val and sock_info.max_wire_version >= 4:
cmd['bypassDocumentValidation'] = True
if session:
cmd['lsid'] = session.session_id

bwc = _BulkWriteContext(db_name, cmd, sock_info, op_id, listeners)
results = _do_batched_write_command(
self.namespace, run.op_type, cmd,
run.ops, True, self.collection.codec_options, bwc)

_merge_command(run, full_result, results)
# We're supposed to continue if errors are
# at the write concern level (e.g. wtimeout)
if self.ordered and full_result['writeErrors']:
break
with self.collection.database.client._tmp_session(session) as s:
for run in generator:
cmd = SON([(_COMMANDS[run.op_type], self.collection.name),
('ordered', self.ordered)])
if write_concern.document:
cmd['writeConcern'] = write_concern.document
if self.bypass_doc_val and sock_info.max_wire_version >= 4:
cmd['bypassDocumentValidation'] = True
if s:
cmd['lsid'] = s.session_id
bwc = _BulkWriteContext(db_name, cmd, sock_info, op_id,
listeners)

results = _do_batched_write_command(
self.namespace, run.op_type, cmd,
run.ops, True, self.collection.codec_options, bwc)

_merge_command(run, full_result, results)
# We're supposed to continue if errors are
# at the write concern level (e.g. wtimeout)
if self.ordered and full_result['writeErrors']:
break

if full_result["writeErrors"] or full_result["writeConcernErrors"]:
if full_result['writeErrors']:
Expand Down
5 changes: 4 additions & 1 deletion pymongo/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,11 @@ def end_session(self):
:class:`~pymongo.collection.Collection`, or
:class:`~pymongo.cursor.Cursor` after the session has ended.
"""
self._end_session(True)

def _end_session(self, lock):
if self._server_session is not None:
self.client._return_server_session(self._server_session)
self.client._return_server_session(self._server_session, lock)
self._server_session = None

def __enter__(self):
Expand Down
178 changes: 97 additions & 81 deletions pymongo/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,19 +230,20 @@ def _command(self, sock_info, command, slave_ok=False,
(result document, address of server the command was run on)
"""
return sock_info.command(
self.__database.name,
command,
slave_ok,
read_preference or self.read_preference,
codec_options or self.codec_options,
check,
allowable_errors,
read_concern=read_concern,
write_concern=write_concern,
parse_write_concern_error=parse_write_concern_error,
collation=collation,
session=session)
with self.__database.client._tmp_session(session) as s:
return sock_info.command(
self.__database.name,
command,
slave_ok,
read_preference or self.read_preference,
codec_options or self.codec_options,
check,
allowable_errors,
read_concern=read_concern,
write_concern=write_concern,
parse_write_concern_error=parse_write_concern_error,
collation=collation,
session=s)

def __create(self, options, collation, session):
"""Sends a create command with the given options.
Expand Down Expand Up @@ -562,19 +563,20 @@ def _insert_one(
('documents', [doc])])
if concern:
command['writeConcern'] = concern
if session:
command['lsid'] = session.session_id

if sock_info.max_wire_version > 1 and acknowledged:
if bypass_doc_val and sock_info.max_wire_version >= 4:
command['bypassDocumentValidation'] = True

# Insert command.
result = sock_info.command(
self.__database.name,
command,
codec_options=self.__write_response_codec_options,
check_keys=check_keys)
_check_write_command_response([(0, result)])
with self.__database.client._tmp_session(session) as s:
result = sock_info.command(
self.__database.name,
command,
codec_options=self.__write_response_codec_options,
check_keys=check_keys,
session=s)
_check_write_command_response([(0, result)])
else:
# Legacy OP_INSERT.
self._legacy_write(
Expand Down Expand Up @@ -630,8 +632,6 @@ def gen():
('ordered', ordered)])
if concern:
command['writeConcern'] = concern
if session:
command['lsid'] = session.session_id
if op_id is None:
op_id = message._randint()
if bypass_doc_val and sock_info.max_wire_version >= 4:
Expand All @@ -641,10 +641,13 @@ def gen():
self.database.client._event_listeners)
if sock_info.max_wire_version > 1 and acknowledged:
# Batched insert command.
results = message._do_batched_write_command(
self.database.name + ".$cmd", message._INSERT, command,
gen(), check_keys, self.__write_response_codec_options, bwc)
_check_write_command_response(results)
with self.__database.client._tmp_session(session) as s:
if s:
command['lsid'] = s.session_id
results = message._do_batched_write_command(
self.database.name + ".$cmd", message._INSERT, command,
gen(), check_keys, self.__write_response_codec_options, bwc)
_check_write_command_response(results)
else:
# Legacy batched OP_INSERT.
message._do_batched_insert(self.__full_name, gen(), check_keys,
Expand Down Expand Up @@ -756,7 +759,8 @@ def gen():

blk = _Bulk(self, ordered, bypass_document_validation)
blk.ops = [doc for doc in gen()]
blk.execute(self.write_concern.document, session=session)
with self.__database.client._tmp_session(session) as s:
blk.execute(self.write_concern.document, session=s)
return InsertManyResult(inserted_ids, self.write_concern.acknowledged)

def _update(self, sock_info, criteria, document, upsert=False,
Expand Down Expand Up @@ -803,13 +807,14 @@ def _update(self, sock_info, criteria, document, upsert=False,
if bypass_doc_val and sock_info.max_wire_version >= 4:
command['bypassDocumentValidation'] = True

# The command result has to be published for APM unmodified
# so we make a shallow copy here before adding updatedExisting.
result = sock_info.command(
self.__database.name,
command,
codec_options=self.__write_response_codec_options,
session=session).copy()
with self.__database.client._tmp_session(session) as s:
# The command result has to be published for APM unmodified
# so we make a shallow copy here before adding updatedExisting.
result = sock_info.command(
self.__database.name,
command,
codec_options=self.__write_response_codec_options,
session=s).copy()
_check_write_command_response([(0, result)])
# Add the updatedExisting field for compatibility.
if result.get('n') and 'upserted' not in result:
Expand Down Expand Up @@ -1081,12 +1086,13 @@ def _delete(
command['writeConcern'] = concern

if sock_info.max_wire_version > 1 and acknowledged:
# Delete command.
result = sock_info.command(
self.__database.name,
command,
codec_options=self.__write_response_codec_options,
session=session)
with self.__database.client._tmp_session(session) as s:
# Delete command.
result = sock_info.command(
self.__database.name,
command,
codec_options=self.__write_response_codec_options,
session=s)
_check_write_command_response([(0, result)])
return result
else:
Expand Down Expand Up @@ -1455,13 +1461,14 @@ def parallel_scan(self, num_cursors, session=None, **kwargs):
('numCursors', num_cursors)])
cmd.update(kwargs)

s = self.__database.client._ensure_session(session)
with self._socket_for_reads() as (sock_info, slave_ok):
result = self._command(sock_info, cmd, slave_ok,
read_concern=self.read_concern,
session=session)
session=s)

return [CommandCursor(self, cursor['cursor'], sock_info.address,
session=session)
session=s, session_owned=session is None)
for cursor in result['cursors']]

def _count(self, cmd, collation=None, session=None):
Expand Down Expand Up @@ -1883,18 +1890,20 @@ def list_indexes(self, session=None):
with self._socket_for_primary_reads() as (sock_info, slave_ok):
cmd = SON([("listIndexes", self.__name), ("cursor", {})])
if sock_info.max_wire_version > 2:
s = self.__database.client._ensure_session(session)
try:
cursor = self._command(sock_info, cmd, slave_ok,
ReadPreference.PRIMARY,
codec_options,
session=session)["cursor"]
session=s)["cursor"]
except OperationFailure as exc:
# Ignore NamespaceNotFound errors to match the behavior
# of reading from *.system.indexes.
if exc.code != 26:
raise
cursor = {'id': 0, 'firstBatch': []}
return CommandCursor(coll, cursor, sock_info.address)
return CommandCursor(coll, cursor, sock_info.address,
session=s, session_owned=session is None)
else:
namespace = _UJOIN % (self.__database.name, "system.indexes")
res = helpers._first_batch(
Expand Down Expand Up @@ -2019,41 +2028,47 @@ def _aggregate(self, pipeline, cursor_class, first_batch_size, session,
cmd['writeConcern'] = self.write_concern.document

cmd.update(kwargs)

# Apply this Collection's read concern if $out is not in the
# pipeline.
if sock_info.max_wire_version >= 4 and 'readConcern' not in cmd:
if dollar_out:
result = self._command(sock_info, cmd, slave_ok,
parse_write_concern_error=True,
collation=collation,
session=session)
session_owned = session is None
s = self.__database.client._ensure_session(session)
try:
# Apply this Collection's read concern if $out is not in the
# pipeline.
if sock_info.max_wire_version >= 4 and 'readConcern' not in cmd:
if dollar_out:
result = self._command(sock_info, cmd, slave_ok,
parse_write_concern_error=True,
collation=collation,
session=s)
else:
result = self._command(sock_info, cmd, slave_ok,
read_concern=self.read_concern,
collation=collation,
session=s)
else:
result = self._command(sock_info, cmd, slave_ok,
read_concern=self.read_concern,
parse_write_concern_error=dollar_out,
collation=collation,
session=session)
else:
result = self._command(sock_info, cmd, slave_ok,
parse_write_concern_error=dollar_out,
collation=collation,
session=session)
session=s)

if "cursor" in result:
cursor = result["cursor"]
else:
# Pre-MongoDB 2.6. Fake a cursor.
cursor = {
"id": 0,
"firstBatch": result["result"],
"ns": self.full_name,
}

return cursor_class(
self, cursor, sock_info.address,
batch_size=batch_size or 0,
max_await_time_ms=max_await_time_ms,
session=session)
if "cursor" in result:
cursor = result["cursor"]
else:
# Pre-MongoDB 2.6. Fake a cursor.
cursor = {
"id": 0,
"firstBatch": result["result"],
"ns": self.full_name,
}

return cursor_class(
self, cursor, sock_info.address,
batch_size=batch_size or 0,
max_await_time_ms=max_await_time_ms,
session=s, session_owned=session_owned)
except Exception:
if session_owned:
s.end_session()
raise

def aggregate(self, pipeline, session=None, **kwargs):
"""Perform an aggregation using the aggregation framework on this
Expand Down Expand Up @@ -2334,11 +2349,12 @@ def rename(self, new_name, session=None, **kwargs):
new_name = "%s.%s" % (self.__database.name, new_name)
cmd = SON([("renameCollection", self.__full_name), ("to", new_name)])
with self._socket_for_writes() as sock_info:
if sock_info.max_wire_version >= 5 and self.write_concern:
cmd['writeConcern'] = self.write_concern.document
cmd.update(kwargs)
sock_info.command('admin', cmd, parse_write_concern_error=True,
session=session)
with self.__database.client._tmp_session(session) as s:
if sock_info.max_wire_version >= 5 and self.write_concern:
cmd['writeConcern'] = self.write_concern.document
cmd.update(kwargs)
sock_info.command('admin', cmd, parse_write_concern_error=True,
session=s)

def distinct(self, key, filter=None, session=None, **kwargs):
"""Get a list of distinct values for `key` among all documents
Expand Down

0 comments on commit c1ec855

Please sign in to comment.