Skip to content

Commit

Permalink
Clarify client protocol version handling.
Browse files Browse the repository at this point in the history
  • Loading branch information
madscientist committed May 27, 2018
1 parent f4c883c commit 5e76dc5
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .coveragerc
Expand Up @@ -2,4 +2,4 @@
include=pynuodb/*

[report]
omit=tests/*
omit=tests/*
25 changes: 13 additions & 12 deletions pynuodb/connection.py
Expand Up @@ -42,20 +42,20 @@ def connect(database, host, user, password, options=None):
return Connection(database, host, user, password, options)

class Connection(object):

"""Class for establishing a connection with host.
Public Functions:
testConnection -- Tests to ensure the connection was properly established.
close -- Closes the connection with the host.
commit -- Sends a message to the host to commit transaction.
rollback -- Sends a message to the host to rollback uncommitted changes.
cursor -- Return a new Cursor object using the connection.
Private Functions:
__init__ -- Constructor for the Connection class.
_check_closed -- Checks if the connection to the host is closed.
Special Function:
auto_commit (getter) -- Gets the value of auto_commit from the database.
auto_commit (setter) -- Sets the value of auto_commit on the database.
Expand All @@ -64,10 +64,10 @@ class Connection(object):
from .exception import Warning, Error, InterfaceError, DatabaseError, \
OperationalError, IntegrityError, InternalError, \
ProgrammingError, NotSupportedError

def __init__(self, dbName, broker, username, password, options):
"""Constructor for the Connection class.
Arguments:
dbName -- Name of database you are accessing.
broker -- Address of the broker you are connecting too.
Expand All @@ -76,7 +76,7 @@ def __init__(self, dbName, broker, username, password, options):
options -- A dictionary of NuoDB connection options
Some common options include:
"schema"
Returns:
a Connection instance
Expand All @@ -91,7 +91,7 @@ def __init__(self, dbName, broker, username, password, options):
self._trans_id = None

cp = ClientPassword()

parameters = {'user' : username, 'timezone' : time.strftime('%Z')}
if options:
parameters.update(options)
Expand All @@ -101,7 +101,8 @@ def __init__(self, dbName, broker, username, password, options):
parameters['clientProcessId'] = str(getpid())

version, serverKey, salt = self.__session.open_database(dbName, parameters, cp)

self.__protocolVersion = version

sessionKey = cp.computeSessionKey(username.upper(), password, salt, serverKey)
self.__session.setCiphers(RC4Cipher(sessionKey), RC4Cipher(sessionKey))

Expand All @@ -112,12 +113,12 @@ def __init__(self, dbName, broker, username, password, options):

def testConnection(self):
"""Tests to ensure the connection was properly established.
This function will test the connection and if it was created should print out:
count: 1
name: ONE
value: 1
:rtype: None
"""
self.__session.test_connection()
Expand All @@ -127,7 +128,7 @@ def auto_commit(self):
"""Gets the value of auto_commit from the database."""
self._check_closed()
return self.__session.get_autocommit()

@auto_commit.setter
def auto_commit(self, value):
"""Sets the value of auto_commit on the database."""
Expand Down
17 changes: 11 additions & 6 deletions pynuodb/encodedsession.py
Expand Up @@ -114,9 +114,9 @@ def __init__(self, host, port, service='SQL2'):
:type : uuid.UUID
"""

self.__serverVersion = 0
self.__sessionVersion = 0
"""
Server's current version
Client protocol version for the session
:type : int
"""

Expand All @@ -128,7 +128,7 @@ def __init__(self, host, port, service='SQL2'):

self.__effectivePlatformVersion = 0
"""
Agreed upon version by the server (for multiple nodes)
Database protocol version when the session is created (may change!)
:type : int
"""

Expand Down Expand Up @@ -175,7 +175,7 @@ def open_database(self, db_name, parameters, cp):
self.__connectedNodeID = self.getInt()
self.__maxNodes = self.getInt()

self.__serverVersion = protocolVersion
self.__sessionVersion = protocolVersion

return protocolVersion, serverKey, salt

Expand All @@ -188,6 +188,12 @@ def check_auth(self):
except SessionException as e:
raise ProgrammingError('Failed to authenticate: ' + str(e))

def get_version(self):
"""
:rtype sessionVersion: int
"""
return self.__sessionVersion

def get_auth_types(self):
self._putMessageId(protocol.AUTHORIZETYPESREQUEST)
self._exchangeMessages()
Expand Down Expand Up @@ -1025,7 +1031,6 @@ def getValue(self):
def _exchangeMessages(self, getResponse=True):
"""Exchange the pending message for an optional response from the server."""
try:
#print("message to server: %s" % (self.__output))
self.send(self.__output)
finally:
self.__output = None
Expand Down Expand Up @@ -1058,7 +1063,7 @@ def _setup_statement(self, handle, msgId):
:type msgId: int
"""
self._putMessageId(msgId)
if(self.__serverVersion >= protocol.PROTOCOL_VERSION17):
if self.__sessionVersion >= protocol.PROTOCOL_VERSION17:
self.putInt(self.getCommitInfo(self.__connectedNodeID))
self.putInt(handle)

Expand Down
8 changes: 4 additions & 4 deletions pynuodb/session.py
Expand Up @@ -91,7 +91,7 @@ def __init__(self, host, port=None, service="Identity"):

self.__service = service

self.__version = sys.version[0]
self.__pyversion = sys.version[0]

@property
def address(self):
Expand Down Expand Up @@ -196,7 +196,7 @@ def send(self, message):

try:
messageBuilder = None
if self.__version == '3':
if self.__pyversion == '3':
messageBuilder = lenStr + bytes(message, 'latin-1')

else:
Expand Down Expand Up @@ -228,7 +228,7 @@ def recv(self, doStrip=True):
else:
msg = self.__cipherIn.transform(msg)

if type(msg) is bytes and self.__version == '3':
if type(msg) is bytes and self.__pyversion == '3':
msg = msg.decode("latin-1")
return msg

Expand All @@ -244,7 +244,7 @@ def __readFully(self, msgLength):
if not received:
raise SessionException("Session was closed while receiving msgLength=[%d] len(msg)=[%d] "
"len(received)=[%d]" % (msgLength, len(msg), len(received)))
if self.__version == '3':
if self.__pyversion == '3':
msg = b''.join([msg, received])
msgLength = msgLength - len(received.decode('latin-1'))
else:
Expand Down

0 comments on commit 5e76dc5

Please sign in to comment.