Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Merged from python-memcached changes since 1.44 including CAS, IPv6 and
many bugfixes and performance improvements.
  • Loading branch information...
commit fac56ebb16dbab690a96fac2ba63cb8b04e2c4af 2 parents 20814b4 + 6c921fe
@eguven authored
Showing with 367 additions and 149 deletions.
  1. +2 −1  README
  2. +365 −148 memcache.py
View
3  README
@@ -1,4 +1,5 @@
-Python3 port for pure python memcached client library.
+Python3 port for pure python memcached client library, ported and being
+kept up to date by https://github.com/eguven
Please report issues and submit code changes to the github repository at:
View
513 memcache.py
@@ -48,14 +48,20 @@
import time
import os
import re
-import types
-try:
- import pickle as pickle
-except ImportError:
- import pickle
+import pickle
from io import StringIO, BytesIO
+from binascii import crc32 # zlib version is not cross-platform
+def cmemcache_hash(key):
+ return((((crc32(key) & 0xffffffff) >> 16) & 0x7fff) or 1)
+serverHashFunction = cmemcache_hash
+
+def useOldServerHashFunction():
+ """Use the old python-memcache server hash function."""
+ global serverHashFunction
+ serverHashFunction = crc32
+
try:
from zlib import compress, decompress
_supports_compress = True
@@ -63,15 +69,16 @@
_supports_compress = False
# quickly define a decompress just in case we recv compressed data.
def decompress(val):
- raise _Error("received compressed data but I don't support compession (import error)")
+ raise _Error("received compressed data but I don't support compression (import error)")
-from binascii import crc32 # zlib version is not cross-platform
-serverHashFunction = crc32
+invalid_key_characters = ''.join(map(chr, list(range(33)) + [127]))
-__author__ = "Evan Martin <martine@danga.com>"
-__version__ = "1.44.1"
+# Original author: Evan Martin of Danga Interactive
+__author__ = "Sean Reifschneider <jafo-memcached@tummy.com>"
+__version__ = "1.51"
__copyright__ = "Copyright (C) 2003 Danga Interactive"
-__license__ = "Python"
+# http://en.wikipedia.org/wiki/Python_Software_Foundation_License
+__license__ = "Python Software Foundation License"
SERVER_MAX_KEY_LENGTH = 250
# Storing values larger than 1MB requires recompiling memcached. If you do,
@@ -82,13 +89,16 @@ def decompress(val):
class _Error(Exception):
pass
-try:
- # Only exists in Python 2.4+
- from threading import local
-except ImportError:
- # TODO: add the pure-python local implementation
- class local(object):
- pass
+
+class _ConnectionDeadError(Exception):
+ pass
+
+
+from threading import local
+
+
+_DEAD_RETRY = 30 # number of seconds before retrying a dead server.
+_SOCKET_TIMEOUT = 3 # number of seconds before sockets timeout.
class Client(local):
@@ -135,7 +145,11 @@ class MemcachedStringEncodingError(Exception):
def __init__(self, servers, debug=0, pickleProtocol=0,
pickler=pickle.Pickler, unpickler=pickle.Unpickler,
- pload=None, pid=None):
+ pload=None, pid=None,
+ server_max_key_length=SERVER_MAX_KEY_LENGTH,
+ server_max_value_length=SERVER_MAX_VALUE_LENGTH,
+ dead_retry=_DEAD_RETRY, socket_timeout=_SOCKET_TIMEOUT,
+ cache_cas = False, flush_on_reconnect=0, check_keys=True):
"""
Create a new Client object with the given list of servers.
@@ -149,11 +163,38 @@ def __init__(self, servers, debug=0, pickleProtocol=0,
Useful for cPickle since subclassing isn't allowed.
@param pid: optional persistent_id function to call on pickle storing.
Useful for cPickle since subclassing isn't allowed.
+ @param dead_retry: number of seconds before retrying a blacklisted
+ server. Defaults to 30 seconds.
+ @param socket_timeout: timeout in seconds for all calls to a server. Defaults
+ to 3 seconds.
+ @param cache_cas: (default False) If true, cas operations will be
+ cached. WARNING: This cache is not expired internally, if you have
+ a long-running process you will need to expire it manually via
+ client.reset_cas(), or the cache can grow unlimited.
+ @param server_max_key_length: (default SERVER_MAX_KEY_LENGTH)
+ Data that is larger than this will not be sent to the server.
+ @param server_max_value_length: (default SERVER_MAX_VALUE_LENGTH)
+ Data that is larger than this will not be sent to the server.
+ @param flush_on_reconnect: optional flag which prevents a scenario that
+ can cause stale data to be read: If there's more than one memcached
+ server and the connection to one is interrupted, keys that mapped to
+ that server will get reassigned to another. If the first server comes
+ back, those keys will map to it again. If it still has its data, get()s
+ can read stale data that was overwritten on another server. This flag
+ is off by default for backwards compatibility.
+ @param check_keys: (default True) If True, the key is checked to
+ ensure it is the correct length and composed of the right characters.
"""
local.__init__(self)
- self.set_servers(servers)
self.debug = debug
+ self.dead_retry = dead_retry
+ self.socket_timeout = socket_timeout
+ self.flush_on_reconnect = flush_on_reconnect
+ self.set_servers(servers)
self.stats = {}
+ self.cache_cas = cache_cas
+ self.reset_cas()
+ self.do_check_key = check_keys
# Allow users to modify pickling/unpickling behavior
self.pickleProtocol = pickleProtocol
@@ -161,6 +202,8 @@ def __init__(self, servers, debug=0, pickleProtocol=0,
self.unpickler = unpickler
self.persistent_load = pload
self.persistent_id = pid
+ self.server_max_key_length = server_max_key_length
+ self.server_max_value_length = server_max_value_length
# figure out the pickler style
file = StringIO()
@@ -170,6 +213,17 @@ def __init__(self, servers, debug=0, pickleProtocol=0,
except TypeError:
self.picklerIsKeyword = False
+
+ def reset_cas(self):
+ """
+ Reset the cas cache. This is only used if the Client() object
+ was created with "cache_cas=True". If used, this cache does not
+ expire internally, so it can grow unbounded if you do not clear it
+ yourself.
+ """
+ self.cas_ids = {}
+
+
def set_servers(self, servers):
"""
Set the pool of servers used by this client.
@@ -180,12 +234,18 @@ def set_servers(self, servers):
2. Tuples of the form C{("host:port", weight)}, where C{weight} is
an integer weight value.
"""
- self.servers = [_Host(s, self.debuglog) for s in servers]
+ self.servers = [_Host(s, self.debug, dead_retry=self.dead_retry,
+ socket_timeout=self.socket_timeout,
+ flush_on_reconnect=self.flush_on_reconnect)
+ for s in servers]
self._init_buckets()
- def get_stats(self):
+ def get_stats(self, stat_args = None):
'''Get statistics from each of the servers.
+ @param stat_args: Additional arguments to pass to the memcache
+ "stats" command.
+
@return: A list of tuples ( server_identifier, stats_dictionary ).
The dictionary contains a number of name/value pairs specifying
the name of the status field and the string value associated with
@@ -196,9 +256,16 @@ def get_stats(self):
if not s.connect(): continue
if s.family == socket.AF_INET:
name = '%s:%s (%s)' % ( s.ip, s.port, s.weight )
+ elif s.family == socket.AF_INET6:
+ name = '[%s]:%s (%s)' % ( s.ip, s.port, s.weight )
else:
name = 'unix:%s (%s)' % ( s.address, s.weight )
- s.send_cmd(b'stats')
+ if not stat_args:
+ s.send_cmd(b'stats')
+ elif isinstance(stat_args, bytes):
+ s.send_cmd(b'stats ' + stat_args)
+ else:
+ s.send_cmd(b'stats ' + str(stat_args).encode('utf-8'))
serverData = {}
data.append(( name.encode('ascii'), serverData ))
readline = s.readline
@@ -216,6 +283,8 @@ def get_slabs(self):
if not s.connect(): continue
if s.family == socket.AF_INET:
name = '%s:%s (%s)' % ( s.ip, s.port, s.weight )
+ elif s.family == socket.AF_INET6:
+ name = '[%s]:%s (%s)' % ( s.ip, s.port, s.weight )
else:
name = 'unix:%s (%s)' % ( s.address, s.weight )
serverData = {}
@@ -235,11 +304,10 @@ def get_slabs(self):
return data
def flush_all(self):
- 'Expire all data currently in the memcache servers.'
+ """Expire all data in memcache servers that are reachable."""
for s in self.servers:
if not s.connect(): continue
- s.send_cmd(b'flush_all')
- s.expect(b"OK")
+ s.flush()
def debuglog(self, str):
if self.debug:
@@ -265,7 +333,7 @@ def _init_buckets(self):
self.buckets.append(server)
def _get_server(self, key):
- if type(key) == tuple:
+ if isinstance(key, tuple):
serverhash, key = key
else:
serverhash = serverHashFunction(key.encode('utf-8'))
@@ -330,7 +398,7 @@ def delete_multi(self, keys, time=0, key_prefix=''):
server.send_cmds(''.join(bigcmd))
except socket.error as msg:
rc = 0
- if type(msg) is tuple: msg = msg[1]
+ if isinstance(msg, tuple): msg = msg[1]
server.mark_dead(msg)
dead_servers.append(server)
@@ -338,13 +406,12 @@ def delete_multi(self, keys, time=0, key_prefix=''):
for server in dead_servers:
del server_keys[server]
- notstored = [] # original keys.
for server, keys in server_keys.items():
try:
for key in keys:
server.expect(b"DELETED")
except socket.error as msg:
- if type(msg) is tuple: msg = msg[1]
+ if isinstance(msg, tuple): msg = msg[1]
server.mark_dead(msg)
rc = 0
return rc
@@ -353,27 +420,31 @@ def delete(self, key, time=0):
'''Deletes a key from the memcache.
@return: Nonzero on success.
- @param time: number of seconds any subsequent set / update commands should fail. Defaults to 0 for no delay.
+ @param time: number of seconds any subsequent set / update commands
+ should fail. Defaults to None for no delay.
@rtype: int
'''
- check_key(key)
+ if self.do_check_key:
+ self.check_key(key)
server, key = self._get_server(key)
if not server:
return 0
self._statlog('delete')
- if time != None:
+ if time != None and time != 0:
cmd = "delete %s %d" % (key, time)
else:
cmd = "delete %s" % key
try:
server.send_cmd(cmd.encode('utf-8'))
- server.expect(b"DELETED")
+ line = server.readline()
+ if line and line.strip() in [b'DELETED', b'NOT_FOUND']: return 1
+ self.debuglog('Delete expected DELETED or NOT_FOUND, got: %s'
+ % repr(line))
except socket.error as msg:
- if type(msg) is tuple: msg = msg[1]
+ if isinstance(msg, tuple): msg = msg[1]
server.mark_dead(msg)
- return 0
- return 1
+ return 0
def incr(self, key, delta=1):
"""
@@ -408,16 +479,17 @@ def decr(self, key, delta=1):
returns 0, not -1.
@param delta: Integer amount to decrement by (should be zero or greater).
- @return: New value after decrementing.
+ @return: New value after decrementing or None on error.
@rtype: int
"""
return self._incrdecr("decr", key, delta)
def _incrdecr(self, cmd, key, delta):
- check_key(key)
+ if self.do_check_key:
+ self.check_key(key)
server, key = self._get_server(key)
if not server:
- return 0
+ return None
self._statlog(cmd)
cmd = "%s %s %d" % (cmd, key, delta)
try:
@@ -501,6 +573,35 @@ def set(self, key, val, time=0, min_compress_len=0):
return self._set("set", key, val, time, min_compress_len)
+ def cas(self, key, val, time=0, min_compress_len=0):
+ '''Sets a key to a given value in the memcache if it hasn't been
+ altered since last fetched. (See L{gets}).
+
+ The C{key} can optionally be an tuple, with the first element
+ being the server hash value and the second being the key.
+ If you want to avoid making this module calculate a hash value.
+ You may prefer, for example, to keep all of a given user's objects
+ on the same memcache server, so you could use the user's unique
+ id as the hash value.
+
+ @return: Nonzero on success.
+ @rtype: int
+ @param time: Tells memcached the time which this value should expire,
+ either as a delta number of seconds, or an absolute unix
+ time-since-the-epoch value. See the memcached protocol docs section
+ "Storage Commands" for more info on <exptime>. We default to
+ 0 == cache forever.
+ @param min_compress_len: The threshold length to kick in
+ auto-compression of the value using the zlib.compress() routine. If
+ the value being cached is a string, then the length of the string is
+ measured, else if the value is an object, then the length of the
+ pickle result is measured. If the resulting attempt at compression
+ yields a larger string than the input, then it is discarded. For
+ backwards compatibility, this parameter defaults to 0, indicating
+ don't ever try to compress.
+ '''
+ return self._set("cas", key, val, time, min_compress_len)
+
def _map_and_prefix_keys(self, key_iterable, key_prefix):
"""Compute the mapping of server (_Host instance) -> list of keys to stuff onto that server, as well as the mapping of
prefixed key -> original key.
@@ -509,8 +610,8 @@ def _map_and_prefix_keys(self, key_iterable, key_prefix):
"""
# Check it just once ...
key_extra_len=len(key_prefix)
- if key_prefix:
- check_key(key_prefix)
+ if key_prefix and self.do_check_key:
+ self.check_key(key_prefix)
# server (_Host) -> list of unprefixed server keys in mapping
server_keys = {}
@@ -518,7 +619,7 @@ def _map_and_prefix_keys(self, key_iterable, key_prefix):
prefixed_to_orig_key = {}
# build up a list for each server of all the keys we want.
for orig_key in key_iterable:
- if type(orig_key) is tuple:
+ if isinstance(orig_key, tuple):
# Tuple of hashvalue, key ala _get_server(). Caller is essentially telling us what server to stuff this on.
# Ensure call to _get_server gets a Tuple as well.
str_orig_key = str(orig_key[1])
@@ -528,7 +629,8 @@ def _map_and_prefix_keys(self, key_iterable, key_prefix):
server, key = self._get_server(key_prefix + str_orig_key)
# Now check to make sure key length is proper ...
- check_key(str_orig_key, key_extra_len=key_extra_len)
+ if self.do_check_key:
+ self.check_key(str_orig_key, key_extra_len=key_extra_len)
if not server:
continue
@@ -583,12 +685,11 @@ def set_multi(self, mapping, time=0, key_prefix='', min_compress_len=0):
self._statlog('set_multi')
-
-
server_keys, prefixed_to_orig_key = self._map_and_prefix_keys(iter(mapping.keys()), key_prefix)
# send out all requests on each server before reading anything
dead_servers = []
+ notstored = [] # original keys.
for server in server_keys.keys():
bigcmd = bytearray()
@@ -597,15 +698,18 @@ def set_multi(self, mapping, time=0, key_prefix='', min_compress_len=0):
newline = "\r\n".encode('utf-8')
for key in server_keys[server]: # These are mangled keys
store_info = self._val_to_store_info(mapping[prefixed_to_orig_key[key]], min_compress_len)
- cmd = ("set %s %d %d %d\r\n" % (key, store_info[0], time, store_info[1])).encode('utf-8')
- if not isinstance(store_info[2],bytes):
- cmd += store_info[2].encode('utf-8')
+ if store_info:
+ cmd = ("set %s %d %d %d\r\n" % (key, store_info[0], time, store_info[1])).encode('utf-8')
+ if not isinstance(store_info[2],bytes):
+ cmd += store_info[2].encode('utf-8')
+ else:
+ cmd += store_info[2]
+ write(cmd + newline)
else:
- cmd += store_info[2]
- write(cmd + newline)
+ notstored.append(prefixed_to_orig_key[key])
server.send_cmds(bytes(bigcmd))
except socket.error as msg:
- if type(msg) is tuple: msg = msg[1]
+ if isinstance(msg, tuple): msg = msg[1]
server.mark_dead(msg)
dead_servers.append(server)
@@ -616,7 +720,6 @@ def set_multi(self, mapping, time=0, key_prefix='', min_compress_len=0):
# short-circuit if there are no servers, just return all keys
if not server_keys: return(list(mapping.keys()))
- notstored = [] # original keys.
for server, keys in server_keys.items():
try:
for key in keys:
@@ -626,7 +729,7 @@ def set_multi(self, mapping, time=0, key_prefix='', min_compress_len=0):
else:
notstored.append(prefixed_to_orig_key[key]) #un-mangle.
except (_Error, socket.error) as msg:
- if type(msg) is tuple: msg = msg[1]
+ if isinstance(msg, tuple): msg = msg[1]
server.mark_dead(msg)
return notstored
@@ -670,58 +773,130 @@ def _val_to_store_info(self, val, min_compress_len):
val = comp_val
# silently do not store if value length exceeds maximum
- if len(val) >= SERVER_MAX_VALUE_LENGTH: return(0)
+ if self.server_max_value_length != 0 and \
+ len(val) > self.server_max_value_length: return(0)
return (flags, len(val), val)
+ def _cmd_builder(self, cmd, key, time, store_info):
+ '''A utility method to build platform specific fullcmd, mainly due
+ to pickle return value type.
+ '''
+ # TODO: change _val_to_store_info so we can get rid of this
+ if cmd == 'cas':
+ c = "cas %s %d %d %d %d\r\n" % (
+ key, store_info[0], time, store_info[1], self.cas_ids[key])
+ else:
+ c = "%s %s %d %d %d\r\n" % (
+ cmd, key, store_info[0], time, store_info[1])
+ if isinstance(store_info[2], str):
+ return (c + store_info[2]).encode('utf-8')
+ elif isinstance(store_info[2], bytes):
+ return c.encode('utf-8') + store_info[2]
+ else:
+ raise _Error("_cmd_builder: unknown data type (%s)" %
+ type(store_info[2]))
+
def _set(self, cmd, key, val, time, min_compress_len = 0):
- check_key(key)
+ if self.do_check_key:
+ self.check_key(key)
server, key = self._get_server(key)
if not server:
return 0
- self._statlog(cmd)
+ def _unsafe_set():
+ self._statlog(cmd)
- store_info = self._val_to_store_info(val, min_compress_len)
- if not store_info: return(0)
+ store_info = self._val_to_store_info(val, min_compress_len)
+ if not store_info: return(0)
- fullcmd = ("%s %s %d %d %d\r\n" % (cmd, key, store_info[0], time, store_info[1])).encode('utf-8');
- if not isinstance(store_info[2],bytes):
- fullcmd += store_info[2].encode('utf-8')
- else:
- fullcmd += store_info[2]
- try:
- server.send_cmd(fullcmd)
- return (server.expect(b"STORED") == b"STORED")
- except socket.error as msg:
- if type(msg) is tuple: msg = msg[1]
- server.mark_dead(msg)
- return 0
+ if cmd == 'cas':
+ if key not in self.cas_ids:
+ return self._set('set', key, val, time, min_compress_len)
+ fullcmd = self._cmd_builder(cmd, key, time, store_info)
+ else:
+ fullcmd = self._cmd_builder(cmd, key, time, store_info)
+ try:
+ server.send_cmd(fullcmd)
+ return(server.expect(b"STORED", raise_exception=True)
+ == b"STORED")
+ except socket.error as msg:
+ if isinstance(msg, tuple): msg = msg[1]
+ server.mark_dead(msg)
+ return 0
- def get(self, key):
- '''Retrieves a key from the memcache.
+ try:
+ return _unsafe_set()
+ except _ConnectionDeadError:
+ # retry once
+ try:
+ if server._get_socket():
+ return _unsafe_set()
+ except (_ConnectionDeadError, socket.error) as msg:
+ server.mark_dead(msg)
+ return 0
- @return: The value or None.
- '''
- check_key(key)
+ def _get(self, cmd, key):
+ if self.do_check_key:
+ self.check_key(key)
server, key = self._get_server(key)
if not server:
return None
- self._statlog('get')
+ def _unsafe_get():
+ self._statlog(cmd)
+
+ try:
+ server.send_cmd("{0} {1}".format(cmd, key).encode("utf-8"))
+ rkey = flags = rlen = cas_id = None
+
+ if cmd == 'gets':
+ rkey, flags, rlen, cas_id, = self._expect_cas_value(server,
+ raise_exception=True)
+ if rkey and self.cache_cas:
+ self.cas_ids[rkey] = cas_id
+ else:
+ rkey, flags, rlen, = self._expectvalue(server,
+ raise_exception=True)
+
+ if not rkey:
+ return None
+ try:
+ value = self._recv_value(server, flags, rlen)
+ finally:
+ server.expect(b"END", raise_exception=True)
+ except (_Error, socket.error) as msg:
+ if isinstance(msg, tuple): msg = msg[1]
+ server.mark_dead(msg)
+ return None
+
+ return value
try:
- server.send_cmd("get {0}".format(key).encode("utf-8"))
- rkey, flags, rlen, = self._expectvalue(server)
- if not rkey:
+ return _unsafe_get()
+ except _ConnectionDeadError:
+ # retry once
+ try:
+ if server.connect():
+ return _unsafe_get()
return None
- value = self._recv_value(server, flags, rlen)
- server.expect(b"END")
- except (_Error, socket.error) as msg:
- if type(msg) is tuple: msg = msg[1]
- server.mark_dead(msg)
+ except (_ConnectionDeadError, socket.error) as msg:
+ server.mark_dead(msg)
return None
- return value
+
+ def get(self, key):
+ '''Retrieves a key from the memcache.
+
+ @return: The value or None.
+ '''
+ return self._get('get', key)
+
+ def gets(self, key):
+ '''Retrieves a key from the memcache. Used in conjunction with 'cas'.
+
+ @return: The value or None.
+ '''
+ return self._get('gets', key)
def get_multi(self, keys, key_prefix=''):
'''
@@ -771,7 +946,7 @@ def get_multi(self, keys, key_prefix=''):
try:
server.send_cmd("get %s" % " ".join(server_keys[server]))
except socket.error as msg:
- if type(msg) is tuple: msg = msg[1]
+ if isinstance(msg, tuple): msg = msg[1]
server.mark_dead(msg)
dead_servers.append(server)
@@ -780,6 +955,7 @@ def get_multi(self, keys, key_prefix=''):
del server_keys[server]
retvals = {}
+
for server in server_keys.keys():
try:
line = server.readline()
@@ -793,15 +969,25 @@ def get_multi(self, keys, key_prefix=''):
retvals[prefixed_to_orig_key[rkey]] = val # un-prefix returned key.
line = server.readline()
except (_Error, socket.error) as msg:
- if type(msg) is tuple: msg = msg[1]
+ if isinstance(msg, tuple): msg = msg[1]
server.mark_dead(msg)
return retvals
- def _expectvalue(self, server, line=None):
+ def _expect_cas_value(self, server, line=None, raise_exception=False):
if not line:
- line = server.readline()
+ line = server.readline(raise_exception)
+
+ if line and line[:5] == b'VALUE':
+ resp, rkey, flags, len, cas_id = line.split()
+ return (rkey, int(flags), int(len), int(cas_id))
+ else:
+ return (None, None, None, None)
- if line[:5] == b'VALUE':
+ def _expectvalue(self, server, line=None, raise_exception=False):
+ if not line:
+ line = server.readline(raise_exception)
+
+ if line and line[:5] == b'VALUE':
resp, rkey, flags, len = line.split()
flags = int(flags)
rlen = int(len)
@@ -812,7 +998,6 @@ def _expectvalue(self, server, line=None):
def _recv_value(self, server, flags, rlen):
rlen += 2 # include \r\n
buf = server.recv(rlen)
-
if len(buf) != rlen:
raise _Error("received %d bytes when expecting %d" % (len(buf), rlen))
@@ -844,12 +1029,42 @@ def _recv_value(self, server, flags, rlen):
return val
+ def check_key(self, key, key_extra_len=0):
+ """Checks sanity of key. Fails if:
+ Key length is > SERVER_MAX_KEY_LENGTH (Raises MemcachedKeyLength).
+ Contains control characters (Raises MemcachedKeyCharacterError).
+ Is not a string (Raises MemcachedKeyError)
+ Is None (Raises MemcachedKeyError)
+ """
+ if isinstance(key, tuple): key = key[1]
+ if not key:
+ raise Client.MemcachedKeyNoneError("Key is None")
+ if not isinstance(key, str):
+ raise Client.MemcachedKeyTypeError("Key must be str()'s")
+
+ if isinstance(key, bytes):
+ keylen = len(key)
+ else:
+ keylen = len(key.encode("utf-8"))
+
+ if self.server_max_key_length != 0 and \
+ keylen + key_extra_len > self.server_max_key_length:
+ raise Client.MemcachedKeyLengthError("Key length is > %s"
+ % self.server_max_key_length)
+ after_translate = key.translate(key.maketrans('', '', invalid_key_characters))
+ if len(key) != len(after_translate):
+ raise Client.MemcachedKeyCharacterError(
+ "Control characters not allowed")
-class _Host:
- _DEAD_RETRY = 30 # number of seconds before retrying a dead server.
- _SOCKET_TIMEOUT = 3 # number of seconds before sockets timeout.
- def __init__(self, host, debugfunc=None):
+class _Host(object):
+
+ def __init__(self, host, debug=0, dead_retry=_DEAD_RETRY,
+ socket_timeout=_SOCKET_TIMEOUT, flush_on_reconnect=0):
+ self.dead_retry = dead_retry
+ self.socket_timeout = socket_timeout
+ self.debug = debug
+ self.flush_on_reconnect = flush_on_reconnect
if isinstance(host, tuple):
host, self.weight = host
else:
@@ -858,9 +1073,12 @@ def __init__(self, host, debugfunc=None):
# parse the connection string
m = re.match(r'^(?P<proto>unix):(?P<path>.*)$', host)
if not m:
+ m = re.match(r'^(?P<proto>inet6):'
+ r'\[(?P<host>[^\[\]]+)\](:(?P<port>[0-9]+))?$', host)
+ if not m:
m = re.match(r'^(?P<proto>inet):'
r'(?P<host>[^:]+)(:(?P<port>[0-9]+))?$', host)
- if not m: m = re.match(r'^(?P<host>[^:]+):(?P<port>[0-9]+)$', host)
+ if not m: m = re.match(r'^(?P<host>[^:]+)(:(?P<port>[0-9]+))?$', host)
if not m:
raise ValueError('Unable to parse connection string: "%s"' % host)
@@ -868,21 +1086,27 @@ def __init__(self, host, debugfunc=None):
if hostData.get('proto') == 'unix':
self.family = socket.AF_UNIX
self.address = hostData['path']
+ elif hostData.get('proto') == 'inet6':
+ self.family = socket.AF_INET6
+ self.ip = hostData['host']
+ self.port = int(hostData.get('port') or 11211)
+ self.address = ( self.ip, self.port )
else:
self.family = socket.AF_INET
self.ip = hostData['host']
- self.port = int(hostData.get('port', 11211))
+ self.port = int(hostData.get('port') or 11211)
self.address = ( self.ip, self.port )
- if not debugfunc:
- debugfunc = lambda x: x
- self.debuglog = debugfunc
-
self.deaduntil = 0
self.socket = None
+ self.flush_on_next_connect = 0
self.buffer = b''
+ def debuglog(self, str):
+ if self.debug:
+ sys.stderr.write("MemCached: %s\n" % str)
+
def _check_dead(self):
if self.deaduntil and self.deaduntil > time.time():
return 1
@@ -896,7 +1120,9 @@ def connect(self):
def mark_dead(self, reason):
self.debuglog("MemCache: %s: %s. Marking dead." % (self, reason))
- self.deaduntil = time.time() + _Host._DEAD_RETRY
+ self.deaduntil = time.time() + self.dead_retry
+ if self.flush_on_reconnect:
+ self.flush_on_next_connect = 1
self.close_socket()
def _get_socket(self):
@@ -905,18 +1131,22 @@ def _get_socket(self):
if self.socket:
return self.socket
s = socket.socket(self.family, socket.SOCK_STREAM)
- if hasattr(s, 'settimeout'): s.settimeout(self._SOCKET_TIMEOUT)
+
+ if hasattr(s, 'settimeout'): s.settimeout(self.socket_timeout)
try:
s.connect(self.address)
except socket.timeout as msg:
self.mark_dead("connect: %s" % msg)
return None
except socket.error as msg:
- if type(msg) is tuple: msg = msg[1]
+ if isinstance(msg, tuple): msg = msg[1]
self.mark_dead("connect: %s" % msg)
return None
self.socket = s
self.buffer = b''
+ if self.flush_on_next_connect:
+ self.flush()
+ self.flush_on_next_connect = 0
return s
def close_socket(self):
@@ -928,7 +1158,7 @@ def send_cmd(self, cmd):
if not isinstance(cmd, bytes):
self.socket.sendall((cmd + '\r\n').encode('ascii'))
else:
- self.socket.sendall(cmd + '\r\n'.encode('ascii'))
+ self.socket.sendall(cmd + b'\r\n')
def send_cmds(self, cmds):
@@ -938,46 +1168,53 @@ def send_cmds(self, cmds):
else:
self.socket.sendall(cmds)
- def readline(self):
+ def readline(self, raise_exception=False):
+ """Read a line and return it. If "raise_exception" is set,
+ raise _ConnectionDeadError if the read fails, otherwise return
+ an empty string.
+ """
buf = self.buffer
recv = self.socket.recv
while True:
- index = buf.find('\r\n'.encode('ascii'))
-
+ index = buf.find(b'\r\n')
if index >= 0:
break
data = recv(4096)
if not data:
- self.mark_dead('Connection closed while reading from %s'
- % repr(self))
- break
- buf += data
- if index >= 0:
- self.buffer = buf[index+2:]
- buf = buf[:index]
- else:
- self.buffer = b''
+ # connection close, let's kill it and raise
+ self.close_socket()
+ if raise_exception:
+ raise _ConnectionDeadError()
+ else:
+ return b''
- return buf
+ buf += data
+ self.buffer = buf[index+2:]
+ return buf[:index]
- def expect(self, text):
- line = self.readline()
+ def expect(self, text, raise_exception=False):
+ line = self.readline(raise_exception)
if line != text:
- self.debuglog("while expecting '%s', got unexpected response '%s'" % (text, line))
+ self.debuglog("while expecting '%s', got unexpected response '%s'"
+ % (text, line))
return line
def recv(self, rlen):
self_socket_recv = self.socket.recv
buf = self.buffer
while len(buf) < rlen:
- foo = self_socket_recv(4096)
+ foo = self_socket_recv(max(rlen - len(buf), 4096))
buf += foo
- if len(foo) == 0:
+ if not foo:
raise _Error( 'Read %d bytes, expecting %d, '
'read returned 0 length bytes' % ( len(buf), rlen ))
self.buffer = buf[rlen:]
return buf[:rlen]
+ def flush(self):
+ self.send_cmd(b'flush_all')
+ self.expect(b'OK')
+
def __str__(self):
d = ''
if self.deaduntil:
@@ -985,32 +1222,11 @@ def __str__(self):
if self.family == socket.AF_INET:
return "inet:%s:%d%s" % (self.address[0], self.address[1], d)
+ elif self.family == socket.AF_INET6:
+ return "inet6:[%s]:%d%s" % (self.address[0], self.address[1], d)
else:
return "unix:%s%s" % (self.address, d)
-def check_key(key, key_extra_len=0):
- """Checks sanity of key. Fails if:
- Key length is > SERVER_MAX_KEY_LENGTH (Raises MemcachedKeyLength).
- Contains control characters (Raises MemcachedKeyCharacterError).
- Is not a string (Raises MemcachedStringEncodingError)
- Is an unicode string (Raises MemcachedStringEncodingError)
- Is not a string (Raises MemcachedKeyError)
- Is None (Raises MemcachedKeyError)
- """
- if type(key) == tuple: key = key[1]
- if not key:
- raise Client.MemcachedKeyNoneError(("Key is None"))
- if not isinstance(key, str):
- raise Client.MemcachedKeyTypeError(("Key must be str()'s"))
-
- if isinstance(key, str):
- key = key.encode("utf-8")
- if len(key) + key_extra_len > SERVER_MAX_KEY_LENGTH:
- raise Client.MemcachedKeyLengthError("Key length is > %s"
- % SERVER_MAX_KEY_LENGTH)
- for char in key:
- if char < 33 or char == 127:
- raise Client.MemcachedKeyCharacterError("Control characters not allowed")
def _doctest():
import doctest, memcache
@@ -1136,7 +1352,8 @@ class StrSubclass(str):
else:
print("FAIL")
- print("Testing using a value larger than the memcached value limit...", end=' ')
+ print("Testing using a value larger than the memcached value limit...")
+ print('NOTE: "MemCached: while expecting[...]" is normal...')
x = mc.set('keyhere', 'a'*SERVER_MAX_VALUE_LENGTH)
if mc.get('keyhere') == None:
print("OK", end=' ')
Please sign in to comment.
Something went wrong with that request. Please try again.