diff --git a/.gitignore b/.gitignore index ecb8c2d..67bbe44 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,16 @@ cover .tox *.egg-info +.cache +.eggs +bin +include +lib +lib64 +local +man +pip-selfcheck.json +pyvenv.cfg +share +_build +dist diff --git a/bmemcached/client.py b/bmemcached/client.py index 42011d7..02c84c8 100644 --- a/bmemcached/client.py +++ b/bmemcached/client.py @@ -1,7 +1,3 @@ -try: - import cPickle as pickle -except ImportError: - import pickle import six @@ -14,27 +10,37 @@ class Client(object): """ This is intended to be a client class which implement standard cache interface that common libs do. + + :param servers: A list of servers with ip[:port] or unix socket. + :type servers: list + :param username: If your server requires SASL authentication, provide the username. + :type username: six.string_types + :param password: If your server requires SASL authentication, provide the password. + :type password: six.string_types + :param compression: This memcached client uses zlib compression by default, + but you can change it to any Python module that provides + `compress` and `decompress` functions, such as `bz2`. + :type compression: Python module + :param dumps: Use this to replace the object serialization mechanism. + The default is JSON encoding. + :type dumps: function + :param loads: Use this to replace the object deserialization mechanism. + The default is JSON decoding. + :type dumps: function + :param socket_timeout: The timeout applied to memcached connections. + :type socket_timeout: float """ def __init__(self, servers=('127.0.0.1:11211',), username=None, password=None, compression=None, socket_timeout=_SOCKET_TIMEOUT, - pickle_protocol=0, - pickler=pickle.Pickler, unpickler=pickle.Unpickler): - """ - :param servers: A list of servers with ip[:port] or unix socket. - :type servers: list - :param username: If your server have auth activated, provide it's username. - :type username: six.string_type - :param password: If your server have auth activated, provide it's password. - :type password: six.string_type - """ + dumps=None, + loads=None): self.username = username self.password = password self.compression = compression self.socket_timeout = socket_timeout - self.pickle_protocol = pickle_protocol - self.pickler = pickler - self.unpickler = unpickler + self.dumps = dumps + self.loads = loads self.set_servers(servers) @property @@ -55,14 +61,15 @@ def set_servers(self, servers): servers = [servers] assert servers, "No memcached servers supplied" - self._servers = [Protocol(server, - self.username, - self.password, - self.compression, - self.socket_timeout, - self.pickle_protocol, - self.pickler, - self.unpickler) for server in servers] + self._servers = [Protocol( + server=server, + username=self.username, + password=self.password, + compression=self.compression, + socket_timeout=self.socket_timeout, + dumps=self.dumps, + loads=self.loads, + ) for server in servers] def _set_retry_delay(self, value): for server in self._servers: @@ -88,7 +95,7 @@ def get(self, key, get_cas=False): Get a key from server. :param key: Key's name - :type key: six.string_type + :type key: six.string_types :param get_cas: If true, return (value, cas), where cas is the new CAS value. :type get_cas: boolean :return: Returns a key data from server. @@ -109,7 +116,7 @@ def gets(self, key): This method is for API compatibility with other implementations. :param key: Key's name - :type key: six.string_type + :type key: six.string_types :return: Returns (key data, value), or (None, None) if the value is not in cache. :rtype: object """ @@ -143,45 +150,53 @@ def get_multi(self, keys, get_cas=False): break return d - def set(self, key, value, time=0): + def set(self, key, value, time=0, compress_level=-1): """ Set a value for a key on server. :param key: Key's name - :type key: six.string_type + :type key: str :param value: A value to be stored on server. :type value: object :param time: Time in seconds that your key will expire. :type time: int + :param compress_level: How much to compress. + 0 = no compression, 1 = fastest, 9 = slowest but best, + -1 = default compression level. + :type compress_level: int :return: True in case of success and False in case of failure :rtype: bool """ returns = [] for server in self.servers: - returns.append(server.set(key, value, time)) + returns.append(server.set(key, value, time, compress_level=compress_level)) return any(returns) - def cas(self, key, value, cas, time=0): + def cas(self, key, value, cas, time=0, compress_level=-1): """ Set a value for a key on server if its CAS value matches cas. :param key: Key's name - :type key: six.string_type + :type key: six.string_types :param value: A value to be stored on server. :type value: object :param time: Time in seconds that your key will expire. :type time: int + :param compress_level: How much to compress. + 0 = no compression, 1 = fastest, 9 = slowest but best, + -1 = default compression level. + :type compress_level: int :return: True in case of success and False in case of failure :rtype: bool """ returns = [] for server in self.servers: - returns.append(server.cas(key, value, cas, time)) + returns.append(server.cas(key, value, cas, time, compress_level=compress_level)) return any(returns) - def set_multi(self, mappings, time=0): + def set_multi(self, mappings, time=0, compress_level=-1): """ Set multiple keys with it's values on server. @@ -189,51 +204,63 @@ def set_multi(self, mappings, time=0): :type mappings: dict :param time: Time in seconds that your key will expire. :type time: int + :param compress_level: How much to compress. + 0 = no compression, 1 = fastest, 9 = slowest but best, + -1 = default compression level. + :type compress_level: int :return: True in case of success and False in case of failure :rtype: bool """ returns = [] if mappings: for server in self.servers: - returns.append(server.set_multi(mappings, time)) + returns.append(server.set_multi(mappings, time, compress_level=compress_level)) return all(returns) - def add(self, key, value, time=0): + def add(self, key, value, time=0, compress_level=-1): """ Add a key/value to server ony if it does not exist. :param key: Key's name - :type key: six.string_type + :type key: six.string_types :param value: A value to be stored on server. :type value: object :param time: Time in seconds that your key will expire. :type time: int + :param compress_level: How much to compress. + 0 = no compression, 1 = fastest, 9 = slowest but best, + -1 = default compression level. + :type compress_level: int :return: True if key is added False if key already exists :rtype: bool """ returns = [] for server in self.servers: - returns.append(server.add(key, value, time)) + returns.append(server.add(key, value, time, compress_level=compress_level)) return any(returns) - def replace(self, key, value, time=0): + def replace(self, key, value, time=0, compress_level=-1): """ Replace a key/value to server ony if it does exist. :param key: Key's name - :type key: six.string_type + :type key: six.string_types :param value: A value to be stored on server. :type value: object :param time: Time in seconds that your key will expire. :type time: int + :param compress_level: How much to compress. + 0 = no compression, 1 = fastest, 9 = slowest but best, + -1 = default compression level. + :type compress_level: int :return: True if key is replace False if key does not exists :rtype: bool """ returns = [] for server in self.servers: - returns.append(server.replace(key, value, time)) + returns.append(server.replace(key, value, time, compress_level=compress_level)) return any(returns) @@ -242,7 +269,7 @@ def delete(self, key, cas=0): Delete a key/value from server. If key does not exist, it returns True. :param key: Key's name to be deleted - :type key: six.string_type + :type key: six.string_types :return: True in case o success and False in case of failure. :rtype: bool """ @@ -264,7 +291,7 @@ def incr(self, key, value): Increment a key, if it exists, returns it's actual value, if it don't, return 0. :param key: Key's name - :type key: six.string_type + :type key: six.string_types :param value: Number to be incremented :type value: int :return: Actual value of the key on server @@ -282,7 +309,7 @@ def decr(self, key, value): Minimum value of decrement return is 0. :param key: Key's name - :type key: six.string_type + :type key: six.string_types :param value: Number to be decremented :type value: int :return: Actual value of the key on server @@ -314,7 +341,7 @@ def stats(self, key=None): Return server stats. :param key: Optional if you want status from a key. - :type key: six.string_type + :type key: six.string_types :return: A dict with server stats :rtype: dict """ diff --git a/bmemcached/protocol.py b/bmemcached/protocol.py index adbda22..2bf8888 100644 --- a/bmemcached/protocol.py +++ b/bmemcached/protocol.py @@ -9,15 +9,10 @@ except ImportError: from urllib.parse import splitport -try: - import cPickle as pickle -except ImportError: - import pickle - assert pickle - import zlib -from io import BytesIO import six +from six import binary_type, text_type +import json from bmemcached.compat import long from bmemcached.exceptions import AuthenticationNotSupported, InvalidCredentials, MemcachedException @@ -72,18 +67,19 @@ class Protocol(threading.local): } FLAGS = { - 'pickle': 1 << 0, + 'object': 1 << 0, 'integer': 1 << 1, 'long': 1 << 2, - 'compressed': 1 << 3 + 'compressed': 1 << 3, + 'binary': 1 << 4, } MAXIMUM_EXPIRE_TIME = 0xfffffffe COMPRESSION_THRESHOLD = 128 - def __init__(self, server, username=None, password=None, compression=None, socket_timeout=None, pickle_protocol=0, - pickler=None, unpickler=None): + def __init__(self, server, username=None, password=None, compression=None, socket_timeout=None, + dumps=None, loads=None): super(Protocol, self).__init__() self.server = server self._username = username @@ -93,9 +89,8 @@ def __init__(self, server, username=None, password=None, compression=None, socke self.connection = None self.authenticated = False self.socket_timeout = socket_timeout - self.pickle_rotocol = pickle_protocol - self.pickler = pickler - self.unpickler = unpickler + self.dumps = dumps + self.loads = loads self.reconnects_deferred_until = None @@ -169,7 +164,7 @@ def _read_socket(self, size): :param size: Size in bytes to be read. :type size: int :return: Data from socket - :rtype: six.string_type + :rtype: six.string_types """ value = b'' while len(value) < size: @@ -233,9 +228,9 @@ def authenticate(self, username, password): Authenticate user on server. :param username: Username used to be authenticated. - :type username: six.string_type + :type username: six.string_types :param password: Password used to be authenticated. - :type password: six.string_type + :type password: six.string_types :return: True if successful. :raises: InvalidCredentials, AuthenticationNotSupported, MemcachedException :rtype: bool @@ -252,7 +247,7 @@ def _send_authentication(self): if not self._username or not self._password: return False - logger.info('Authenticating as %s' % self._username) + logger.info('Authenticating as %s', self._username) self._send(struct.pack(self.HEADER_STRUCT, self.MAGIC['request'], self.COMMANDS['auth_negotiation']['command'], @@ -277,7 +272,7 @@ def _send_authentication(self): method = b'PLAIN' auth = '\x00%s\x00%s' % (self._username, self._password) - if six.PY3: + if isinstance(auth, text_type): auth = auth.encode() self._send(struct.pack(self.HEADER_STRUCT + @@ -297,24 +292,29 @@ def _send_authentication(self): if status != self.STATUS['success']: raise MemcachedException('Code: %d Message: %s' % (status, extra_content)) - logger.debug('Auth OK. Code: %d Message: %s' % (status, extra_content)) + logger.debug('Auth OK. Code: %d Message: %s', status, extra_content) self.authenticated = True return True - def serialize(self, value): + def serialize(self, value, compress_level=-1): """ - Serializes a value based on it's type. + Serializes a value based on its type. :param value: Something to be serialized - :type value: six.string_type, int, long, object + :type value: six.string_types, int, long, object + :param compress_level: How much to compress. + 0 = no compression, 1 = fastest, 9 = slowest but best, + -1 = default compression level. + :type compress_level: int :return: Serialized type :rtype: str """ flags = 0 - if isinstance(value, str): - if six.PY3: - value = value.encode('utf8') + if isinstance(value, binary_type): + flags |= self.FLAGS['binary'] + elif isinstance(value, text_type): + value = value.encode('utf8') elif isinstance(value, int) and isinstance(value, bool) is False: flags |= self.FLAGS['integer'] value = str(value) @@ -322,15 +322,23 @@ def serialize(self, value): flags |= self.FLAGS['long'] value = str(value) else: - flags |= self.FLAGS['pickle'] - buf = BytesIO() - pickler = self.pickler(buf, self.pickle_rotocol) - pickler.dump(value) - value = buf.getvalue() - - if len(value) > self.COMPRESSION_THRESHOLD: - value = self.compression.compress(value) - flags |= self.FLAGS['compressed'] + flags |= self.FLAGS['object'] + dumps = self.dumps + if dumps is None: + dumps = self.json_dumps + value = dumps(value) + + if compress_level != 0 and len(value) > self.COMPRESSION_THRESHOLD: + if compress_level is not None and compress_level > 0: + # Use the specified compression level. + compressed_value = self.compression.compress(value, compress_level) + else: + # Use the default compression level. + compressed_value = self.compression.compress(value) + # Use the compressed value only if it is actually smaller. + if compressed_value and len(compressed_value) < len(value): + value = compressed_value + flags |= self.FLAGS['compressed'] return flags, value @@ -339,28 +347,47 @@ def deserialize(self, value, flags): Deserialized values based on flags or just return it if it is not serialized. :param value: Serialized or not value. - :type value: six.string_type, int + :type value: six.string_types, int :param flags: Value flags :type flags: int :return: Deserialized value - :rtype: six.string_type|int + :rtype: six.string_types|int """ - to_str = lambda v: v.decode('utf8') if six.PY3 else v + FLAGS = self.FLAGS - if flags & self.FLAGS['compressed']: # pragma: no branch + if flags & FLAGS['compressed']: # pragma: no branch value = self.compression.decompress(value) - if flags & self.FLAGS['integer']: - return int(to_str(value)) - elif flags & self.FLAGS['long']: - return long(to_str(value)) - elif flags & self.FLAGS['pickle']: - buf = BytesIO(value) + if flags & FLAGS['binary']: + return value + + if flags & FLAGS['integer']: + return int(value) + elif flags & FLAGS['long']: + return long(value) + elif flags & FLAGS['object']: + loads = self.loads + if loads is None: + loads = self.json_loads + return loads(value) + + if six.PY3: + return value.decode('utf8') + + # In Python 2, mimic the behavior of the json library: return a str + # unless the value contains unicode characters. + try: + value.decode('ascii') + except UnicodeDecodeError: + return value.decode('utf8') + else: + return value - unpickler = self.unpickler(buf) - return unpickler.load() + def json_dumps(self, value): + return json.dumps(value).encode('utf8') - return to_str(value) + def json_loads(self, value): + return json.loads(value.decode('utf8')) def get(self, key): """ @@ -368,11 +395,11 @@ def get(self, key): (None, None). :param key: Key's name - :type key: six.string_type + :type key: six.string_types :return: Returns (value, cas). :rtype: object """ - logger.info('Getting key %s' % key) + logger.debug('Getting key %s', key) data = struct.pack(self.HEADER_STRUCT + self.COMMANDS['get']['struct'] % (len(key)), self.MAGIC['request'], @@ -383,13 +410,12 @@ def get(self, key): (magic, opcode, keylen, extlen, datatype, status, bodylen, opaque, cas, extra_content) = self._get_response() - logger.debug('Value Length: %d. Body length: %d. Data type: %d' % ( - extlen, bodylen, datatype)) + logger.debug('Value Length: %d. Body length: %d. Data type: %d', + extlen, bodylen, datatype) if status != self.STATUS['success']: if status == self.STATUS['key_not_found']: - logger.debug('Key not found. Message: %s' - % extra_content) + logger.debug('Key not found. Message: %s', extra_content) return None, None if status == self.STATUS['server_disconnected']: @@ -450,26 +476,30 @@ def get_multi(self, keys): return d - def _set_add_replace(self, command, key, value, time, cas=0): + def _set_add_replace(self, command, key, value, time, cas=0, compress_level=-1): """ Function to set/add/replace commands. :param key: Key's name - :type key: six.string_type + :type key: six.string_types :param value: A value to be stored on server. :type value: object :param time: Time in seconds that your key will expire. :type time: int :param cas: The CAS value that must be matched for this operation to complete, or 0 for no CAS. :type cas: int + :param compress_level: How much to compress. + 0 = no compression, 1 = fastest, 9 = slowest but best, + -1 = default compression level. + :type compress_level: int :return: True in case of success and False in case of failure :rtype: bool """ time = time if time >= 0 else self.MAXIMUM_EXPIRE_TIME - logger.info('Setting/adding/replacing key %s.' % key) - flags, value = self.serialize(value) - logger.info('Value bytes %d.' % len(value)) - if six.PY3 and isinstance(value, str): + logger.debug('Setting/adding/replacing key %s.', key) + flags, value = self.serialize(value, compress_level=compress_level) + logger.debug('Value bytes %s.', len(value)) + if isinstance(value, text_type): value = value.encode('utf8') self._send(struct.pack(self.HEADER_STRUCT + @@ -493,31 +523,39 @@ def _set_add_replace(self, command, key, value, time, cas=0): return True - def set(self, key, value, time): + def set(self, key, value, time, compress_level=-1): """ Set a value for a key on server. :param key: Key's name - :type key: six.string_type + :type key: six.string_types :param value: A value to be stored on server. :type value: object :param time: Time in seconds that your key will expire. :type time: int + :param compress_level: How much to compress. + 0 = no compression, 1 = fastest, 9 = slowest but best, + -1 = default compression level. + :type compress_level: int :return: True in case of success and False in case of failure :rtype: bool """ - return self._set_add_replace('set', key, value, time) + return self._set_add_replace('set', key, value, time, compress_level=compress_level) - def cas(self, key, value, cas, time): + def cas(self, key, value, cas, time, compress_level=-1): """ Add a key/value to server ony if it does not exist. :param key: Key's name - :type key: six.string_type + :type key: six.string_types :param value: A value to be stored on server. :type value: object :param time: Time in seconds that your key will expire. :type time: int + :param compress_level: How much to compress. + 0 = no compression, 1 = fastest, 9 = slowest but best, + -1 = default compression level. + :type compress_level: int :return: True if key is added False if key already exists and has a different CAS :rtype: bool """ @@ -529,41 +567,49 @@ def cas(self, key, value, cas, time): # If we get a cas of None, interpret that as "compare against nonexistant and set", # which is simply Add. if cas is None: - return self._set_add_replace('add', key, value, time) + return self._set_add_replace('add', key, value, time, compress_level=compress_level) else: - return self._set_add_replace('set', key, value, time, cas=cas) + return self._set_add_replace('set', key, value, time, cas=cas, compress_level=compress_level) - def add(self, key, value, time): + def add(self, key, value, time, compress_level=-1): """ Add a key/value to server ony if it does not exist. :param key: Key's name - :type key: six.string_type + :type key: six.string_types :param value: A value to be stored on server. :type value: object :param time: Time in seconds that your key will expire. :type time: int + :param compress_level: How much to compress. + 0 = no compression, 1 = fastest, 9 = slowest but best, + -1 = default compression level. + :type compress_level: int :return: True if key is added False if key already exists :rtype: bool """ - return self._set_add_replace('add', key, value, time) + return self._set_add_replace('add', key, value, time, compress_level=compress_level) - def replace(self, key, value, time): + def replace(self, key, value, time, compress_level=-1): """ Replace a key/value to server ony if it does exist. :param key: Key's name - :type key: six.string_type + :type key: six.string_types :param value: A value to be stored on server. :type value: object :param time: Time in seconds that your key will expire. :type time: int + :param compress_level: How much to compress. + 0 = no compression, 1 = fastest, 9 = slowest but best, + -1 = default compression level. + :type compress_level: int :return: True if key is replace False if key does not exists :rtype: bool """ - return self._set_add_replace('replace', key, value, time) + return self._set_add_replace('replace', key, value, time, compress_level=compress_level) - def set_multi(self, mappings, time=100): + def set_multi(self, mappings, time=100, compress_level=-1): """ Set multiple keys with its values on server. @@ -574,6 +620,10 @@ def set_multi(self, mappings, time=100): :type mappings: dict :param time: Time in seconds that your key will expire. :type time: int + :param compress_level: How much to compress. + 0 = no compression, 1 = fastest, 9 = slowest but best, + -1 = default compression level. + :type compress_level: int :return: True :rtype: bool """ @@ -593,7 +643,7 @@ def set_multi(self, mappings, time=100): else: command = 'setq' - flags, value = self.serialize(value) + flags, value = self.serialize(value, compress_level=compress_level) m = struct.pack(self.HEADER_STRUCT + self.COMMANDS[command]['struct'] % (len(key), len(value)), self.MAGIC['request'], @@ -634,7 +684,7 @@ def _incr_decr(self, command, key, value, default, time): Function which increments and decrements. :param key: Key's name - :type key: six.string_type + :type key: six.string_types :param value: Number to be (de|in)cremented :type value: int :param default: Default value if key does not exist. @@ -665,10 +715,10 @@ def _incr_decr(self, command, key, value, default, time): def incr(self, key, value, default=0, time=1000000): """ - Increment a key, if it exists, returns it's actual value, if it don't, return 0. + Increment a key, if it exists, returns its actual value, if it doesn't, return 0. :param key: Key's name - :type key: six.string_type + :type key: six.string_types :param value: Number to be incremented :type value: int :param default: Default value if key does not exist. @@ -682,11 +732,11 @@ def incr(self, key, value, default=0, time=1000000): def decr(self, key, value, default=0, time=100): """ - Decrement a key, if it exists, returns it's actual value, if it don't, return 0. + Decrement a key, if it exists, returns its actual value, if it doesn't, return 0. Minimum value of decrement return is 0. :param key: Key's name - :type key: six.string_type + :type key: six.string_types :param value: Number to be decremented :type value: int :param default: Default value if key does not exist. @@ -700,16 +750,16 @@ def decr(self, key, value, default=0, time=100): def delete(self, key, cas=0): """ - Delete a key/value from server. If key existed and was deleted, it returns True. + Delete a key/value from server. If key existed and was deleted, return True. :param key: Key's name to be deleted - :type key: six.string_type + :type key: six.string_types :param cas: If set, only delete the key if its CAS value matches. :type cas: int :return: True in case o success and False in case of failure. :rtype: bool """ - logger.info('Deleting key %s' % key) + logger.debug('Deleting key %s', key) self._send(struct.pack(self.HEADER_STRUCT + self.COMMANDS['delete']['struct'] % len(key), self.MAGIC['request'], @@ -724,7 +774,7 @@ def delete(self, key, cas=0): if status != self.STATUS['success'] and status not in (self.STATUS['key_not_found'], self.STATUS['key_exists']): raise MemcachedException('Code: %d message: %s' % (status, extra_content)) - logger.debug('Key deleted %s' % key) + logger.debug('Key deleted %s', key) return status != self.STATUS['key_exists'] def delete_multi(self, keys): @@ -736,7 +786,7 @@ def delete_multi(self, keys): :return: True in case of success and False in case of failure. :rtype: bool """ - logger.info('Deleting keys %r' % keys) + logger.debug('Deleting keys %r', keys) if six.PY2: msg = '' else: @@ -800,13 +850,13 @@ def stats(self, key=None): Return server stats. :param key: Optional if you want status from a key. - :type key: six.string_type + :type key: six.string_types :return: A dict with server stats :rtype: dict """ # TODO: Stats with key is not working. if key is not None: - if isinstance(key, str) and six.PY3: + if isinstance(key, text_type): key = str_to_bytes(key) keylen = len(key) packed = struct.pack( diff --git a/test/test_compression.py b/test/test_compression.py index 512cb55..26b2ba7 100644 --- a/test/test_compression.py +++ b/test/test_compression.py @@ -2,6 +2,12 @@ import bz2 import bmemcached +import six +if six.PY3: + from unittest import mock +else: + import mock + class MemcachedTests(unittest.TestCase): def setUp(self): @@ -9,7 +15,7 @@ def setUp(self): self.client = bmemcached.Client(self.server, 'user', 'password') self.bzclient = bmemcached.Client(self.server, 'user', 'password', compression=bz2) - self.data = b'this is test data. ' * 32 + self.data = 'this is test data. ' * 32 def tearDown(self): self.client.delete('test_key') @@ -31,3 +37,24 @@ def testCompressionMissmatch(self): self.assertEqual(self.client.get('test_key'), self.bzclient.get('test_key2')) self.assertRaises(IOError, self.bzclient.get, 'test_key') + + def testCompressionEnabled(self): + import zlib + compression = mock.Mock() + compression.compress.side_effect = zlib.compress + compression.decompress.side_effect = zlib.decompress + for proto in self.client._servers: + proto.compression = compression + self.client.set('test_key', self.data) + self.assertEqual(self.data, self.client.get('test_key')) + compression.compress.assert_called_with(self.data.encode('ascii')) + self.assertEqual(1, compression.decompress.call_count) + + def testCompressionDisabled(self): + compression = mock.Mock() + for proto in self.client._servers: + proto.compression = compression + self.client.set('test_key', self.data, compress_level=0) + self.assertEqual(self.data, self.client.get('test_key')) + compression.compress.assert_not_called() + compression.decompress.assert_not_called() diff --git a/test/test_error_handling.py b/test/test_error_handling.py index 8eae8a1..73a3625 100644 --- a/test/test_error_handling.py +++ b/test/test_error_handling.py @@ -95,7 +95,7 @@ def setUp(self): self._stop_proxy() self._start_proxy() - self.client = bmemcached.Client(self.server) + self.client = bmemcached.Client(self.server, 'user', 'password') # Disable retry delays, so we can disconnect and reconnect from the # server without needing to put delays in most of the tests. diff --git a/test/test_pickler.py b/test/test_pickler.py index 786f552..4fed713 100644 --- a/test/test_pickler.py +++ b/test/test_pickler.py @@ -1,57 +1,46 @@ -from io import BytesIO try: import cPickle as pickle except ImportError: import pickle -import json import unittest import bmemcached -class JsonPickler(object): +class PickleableThing(object): + pass - def __init__(self, f, protocol=0): - self.f = f - def dump(self, obj): - # if isinstance(obj, str): - # obj = obj.encode() - - if isinstance(self.f, BytesIO): - return self.f.write(json.dumps(obj).encode()) - - return json.dump(obj, self.f) - - def load(self): - if isinstance(self.f, BytesIO): - return json.loads(self.f.read().decode()) - return json.load(self.f) - - -class MemcachedTests(unittest.TestCase): +class PicklerTests(unittest.TestCase): def setUp(self): self.server = '127.0.0.1:11211' - self.dclient = bmemcached.Client(self.server, 'user', 'password') - self.jclient = bmemcached.Client(self.server, 'user', 'password', - pickler=JsonPickler, - unpickler=JsonPickler) + self.json_client = bmemcached.Client(self.server, 'user', 'password') + self.pickle_client = bmemcached.Client(self.server, 'user', 'password', + dumps=pickle.dumps, + loads=pickle.loads) self.data = {'a': 'b'} def tearDown(self): - self.jclient.delete('test_key') - self.jclient.disconnect_all() - self.dclient.disconnect_all() - - def testJson(self): - self.jclient.set('test_key', self.data) - self.assertEqual(self.data, self.jclient.get('test_key')) - - def testDefaultVsJson(self): - self.dclient.set('test_key', self.data) - self.assertRaises(ValueError, self.jclient.get, 'test_key') - - def testJsonVsDefault(self): - self.jclient.set('test_key', self.data) - self.assertRaises(pickle.UnpicklingError, self.dclient.get, 'test_key') + self.json_client.delete('test_key') + self.json_client.disconnect_all() + self.pickle_client.disconnect_all() + + def testPickleDict(self): + self.pickle_client.set('test_key', self.data) + self.assertEqual(self.data, self.pickle_client.get('test_key')) + + def testPickleClassInstance(self): + to_pickle = PickleableThing() + self.pickle_client.set('test_key', to_pickle) + unpickled = self.pickle_client.get('test_key') + self.assertEqual(type(unpickled), PickleableThing) + self.assertFalse(unpickled is to_pickle) + + def testPickleVsJson(self): + self.pickle_client.set('test_key', self.data) + self.assertRaises(ValueError, self.json_client.get, 'test_key') + + def testJsonVsPickle(self): + self.json_client.set('test_key', self.data) + self.assertRaises(pickle.UnpicklingError, self.pickle_client.get, 'test_key') diff --git a/test/test_simple_functions.py b/test/test_simple_functions.py index de9f923..9f204f7 100644 --- a/test/test_simple_functions.py +++ b/test/test_simple_functions.py @@ -2,12 +2,19 @@ import bmemcached from bmemcached.compat import long, unicode +import six +if six.PY3: + from unittest import mock +else: + import mock + class MemcachedTests(unittest.TestCase): def setUp(self): self.server = '127.0.0.1:11211' self.server = '/tmp/memcached.sock' - self.client = bmemcached.Client(self.server) + self.client = bmemcached.Client(self.server, 'user', 'password') + self.reset() def tearDown(self): self.reset() @@ -29,10 +36,20 @@ def testSetMultiBigData(self): self.client.set_multi( dict((unicode(k), b'value') for k in range(32767))) - def testGet(self): + def testGetSimple(self): self.client.set('test_key', 'test') self.assertEqual('test', self.client.get('test_key')) + def testGetBytes(self): + # Ensure the code is 8-bit clean. + value = b'\x01z\x7f\x00\x80\xfe\xff\x00' + self.client.set('test_key', value) + self.assertEqual(value, self.client.get('test_key')) + + def testGetDecodedText(self): + self.client.set('test_key', u'\u30b7') + self.assertEqual(u'\u30b7', self.client.get('test_key')) + def testCas(self): value, cas = self.client.gets('nonexistant') self.assertTrue(value is None) @@ -113,8 +130,8 @@ def testGetEmptyString(self): self.assertEqual('', self.client.get('test_key')) def testGetUnicodeString(self): - self.client.set('test_key', '\xac') - self.assertEqual('\xac', self.client.get('test_key')) + self.client.set('test_key', u'\xac') + self.assertEqual(u'\xac', self.client.get('test_key')) def testGetMulti(self): self.assertTrue(self.client.set_multi({ @@ -221,6 +238,13 @@ def tearDown(self): def testTimeout(self): self.client = bmemcached.Client(self.server, 'user', 'password', socket_timeout=0.00000000000001) + + for proto in self.client._servers: + # Set up a mock connection that gives the impression of + # timing out in every recv() call. + proto.connection = mock.Mock() + proto.connection.recv.return_value = b'' + self.client.set('timeout_key', 'test') self.assertEqual(self.client.get('timeout_key'), None)