Skip to content

Commit

Permalink
Add gets method and simplify connection estabilish
Browse files Browse the repository at this point in the history
- simplify connection estabilish procedure, as tornado.iostream
  permit write before connection estabilished, I removed the
  callback and connection timeout procedure in _get_server()
- add `gets(self, keys, callback, failcallback)` method, receive
  a list of keys as parameter and return a dictionary of results
  result = {key1:value1, key2:value2...}, only return the got
  keys from memcached
- some debug info, might be cleanup later

Signed-off-by: Wang Xu <gnawux@gmail.com>
  • Loading branch information
gnawux committed Mar 7, 2013
1 parent 534fc2b commit 10c5236
Showing 1 changed file with 92 additions and 51 deletions.
143 changes: 92 additions & 51 deletions tornadoasyncmemcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@ def _on_fail(self, *args):
import sys
import socket
import time
from datetime import datetime
import types
import contextlib
import logging
from tornado import iostream, ioloop
from tornado import stack_context
from functools import partial
Expand All @@ -65,7 +67,7 @@ class TooManyClients(Exception):

class ClientPool(object):

CMDS = ('get', 'replace', 'set', 'decr', 'incr', 'delete')
CMDS = ('get', 'replace', 'set', 'decr', 'incr', 'delete', 'gets')

def __init__(self,
servers,
Expand Down Expand Up @@ -98,11 +100,15 @@ def _create_clients(self, n):
for x in xrange(n)]

def _do(self, cmd, *args, **kwargs):
#print('enter _do: %s', datetime.now().strftime('%T.%f'))
fail_callback = None
if 'fail_callback' in kwargs:
fail_callback = kwargs['fail_callback']
del kwargs['fail_callback']
if not self._clients:
c = None
try:
c = self._clients.popleft()
except IndexError:
if self._maxclients > 0 and (len(self._clients)
+ len(self._used) >= self._maxclients):
fail_reason = "Max of %d clients is already reached" % self._maxclients
Expand All @@ -111,19 +117,20 @@ def _do(self, cmd, *args, **kwargs):
return
else:
raise TooManyClients(fail_reason)
self._clients.append(self._create_clients(1)[0])
c = self._clients.popleft()
c = self._create_clients(1)[0]
kwargs['callback'] = partial(self._gen_cb, c=c, _cb=kwargs['callback'])
self._used.append(c)
context = partial(self._cleanup, fail_callback = partial(self._gen_cb, c=c, _cb= fail_callback))
context = partial(self._cleanup, fail_callback = partial(self._gen_fail_cb, c=c, _cb= fail_callback))
with stack_context.StackContext(context):
#print('send to client: %s', datetime.now().strftime('%T.%f'))
getattr(c, cmd)(*args, **kwargs)

@contextlib.contextmanager
def _cleanup(self, fail_callback = None):
try:
yield
except _Error as e:
print "gotcha", e
if fail_callback:
fail_callback(e.args)

Expand All @@ -134,14 +141,24 @@ def __getattr__(self, name):
(self.__class__.__name__, name))

def _gen_cb(self, response, c, _cb, *args, **kwargs):
self._used.remove(c)
if self._maxcached == 0 or self._maxcached > len(self._clients):
self._clients.append(c)
if c in self._used:
#print('back to _cb: %s', datetime.now().strftime('%T.%f'))
self._used.remove(c)
if self._maxcached == 0 or self._maxcached > len(self._clients):
self._clients.append(c)
else:
c.disconnect_all()
#print('call callback: %s', datetime.now().strftime('%T.%f'))
if _cb:
_cb(response, *args, **kwargs)
else:
c.disconnect_all()
pass #returned but too late

def _gen_fail_cb(self, response, c, _cb, *args, **kwargs):
self._used.remove(c)
c.disconnect_all()
if _cb:
_cb(response, *args, **kwargs)


class _Error(Exception):
pass
Expand Down Expand Up @@ -240,6 +257,7 @@ def _clear_timeout(self):
def _on_timeout(self, server):
self._timeout = None
server.mark_dead('Time out')
print "got timeout event..."
raise _Error('memcache call timeout')

def set_servers(self, servers):
Expand Down Expand Up @@ -283,32 +301,21 @@ def _init_buckets(self):
for i in range(server.weight):
self.buckets.append(server)

def _get_server(self, key, connect_callback, *args, **kwargs):
def _get_server(self, key):

if type(key) == types.TupleType:
serverhash = key[0]
key = key[1]
elif len(self.buckets) <= 1:
serverhash = 0
else:
serverhash = hash(key)

retry = kwargs['retry'] if 'retry' in kwargs else 0
server = kwargs['server'] if 'server' in kwargs else None

if retry > 0:
server.mark_dead('connect failed')
if retry >= self.server_retries:
self._clear_timeout()
raise _Error('no server available')
serverhash = hash(str(serverhash)+str(retry))

server = self.buckets[serverhash % len(self.buckets)]
if self.connect_timeout:
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
self._timeout = self.io_loop.add_timeout(
time.time() + self.connect_timeout,
stack_context.wrap(partial(self._on_timeout, server)))
server.connect(callback = partial(connect_callback, server, key, *args), fail_callback =
partial(self._get_server, (serverhash, key), connect_callback, *args, server = server, retry = retry +1))
server.connect()
#server.connect(callback = partial(connect_callback, server, key, *args), fail_callback =
# partial(self._get_server, (serverhash, key), connect_callback, *args, server = server, retry = retry +1))
return (server, key)

def disconnect_all(self):
for s in self.servers:
Expand All @@ -320,9 +327,8 @@ def delete(self, key, time=0, callback=None):
@return: Nonzero on success.
@rtype: int
'''
self._get_server(key, self._real_delete, time, callback)
server, key = self._get_server(key)

def _real_delete(self, server, key, time, callback):
if not server:
self.finish(partial(callback,0))
self._statlog('delete')
Expand Down Expand Up @@ -380,9 +386,8 @@ def decr(self, key, delta=1, callback=None):
self._incrdecr("decr", key, delta, callback=callback)

def _incrdecr(self, cmd, key, delta, callback):
self._get_server(key, self._real_incrdecr, cmd, delta, callback)
server, key = self._get_server(key)

def _real_incrdecr(self, server, key, cmd, delta, callback):
if not server:
self.finish(partial(callback, 0))
self._statlog(cmd)
Expand Down Expand Up @@ -435,9 +440,8 @@ def set(self, key, val, time=0, callback=None):
self._set("set", key, val, time, callback)

def _set(self, cmd, key, val, time, callback):
self._get_server(key, self._real_set, cmd, val, time, callback)
server, key = self._get_server(key)

def _real_set(self, server, key, cmd, val, time, callback):
if not server:
self.finish(partial(callback,0))

Expand Down Expand Up @@ -472,11 +476,9 @@ def get(self, key, callback):
@return: The value or None.
'''
self._get_server(key, self._real_get, callback)
server, key = self._get_server(key)

def _real_get(self, server, key, callback):
if not server:
self._clear_timeout()
raise _Error('No available server for %s' % key)

self._statlog('get')
Expand All @@ -495,6 +497,48 @@ def _get_expectval_cb(self, rkey, flags, rlen, server, callback):
def _get_recv_cb(self, value, server, callback):
server.expect("END", partial(self._expect_cb, value=value, callback=callback))

def gets(self, keys, callback):
'''Retrieves several keys from the memcache.
@return: The value list
'''
servers = dict()
for key in keys:
server, key = self._get_server(key)
if not server:
raise _Error('No available server for %s' % key)
if server in servers:
servers[server].append(key)
else:
servers[server] = [key]

self._statlog('gets')

gets_stat = {'server':servers.keys(), 'finished':0, 'result':{} }

for server in servers:
#print 'get %s' % ' '.join(servers[server])
server.send_cmd("get %s" % ' '.join(servers[server]), partial(self._gets_send_cb, server=server, status=gets_stat, callback=self._set_timeout(server, callback)))

def _gets_send_cb(self, server, status, callback):
self._expectvalue(server, line=None, callback=partial(self._gets_expectval_cb, server=server, status=status, callback=callback))

def _gets_expectval_cb(self, rkey, flags, rlen, server, status, callback):
if not rkey:
#print 'rkey is None'
status['finished'] += 1
if status['finished'] == len(status['server']):
self.finish(partial(callback,status['result']))
return
#print 'get vline for %s' % rkey
self._recv_value(server, flags, rlen, partial(self._gets_recv_cb, key=rkey, server=server, status=status, callback=callback))

def _gets_recv_cb(self, value, key, server, status, callback):
#server.expect("END", partial(self._expect_cb, value=value, callback=callback))
status['result'][key] = value
#print 'get value for %s' % key
self._expectvalue(server, line=None, callback=partial(self._gets_expectval_cb, server=server, status=status, callback=callback))

def _expect_cb(self, expected=None, value=None, callback=None):
# print "in expect cb"
self.finish(partial(callback,value))
Expand Down Expand Up @@ -586,33 +630,28 @@ def _check_dead(self):
self.deaduntil = 0
return 0

def connect(self, callback, fail_callback = None):
return self._get_socket(callback, fail_callback)
def connect(self):
return self._get_socket()

def mark_dead(self, reason):
print "MemCache: %s: %s. Marking dead." % (self, reason)
if self.deaduntil == 0:
self.deaduntil = time.time() + self.dead_retry
self.close_socket()

def _get_socket(self, callback, fail_callback = None):
def _get_socket(self):
if self._check_dead():
if fail_callback:
fail_callback()
return None
else:
return None
raise _Error('server dead')

if self.stream:
callback()
return None
return

addrinfo = socket.getaddrinfo(self.ip, self.port, socket.AF_INET, socket.SOCK_STREAM, 0, 0)
af, socktype, proto, canonname, sockaddr = addrinfo[0]
self.stream = iostream.IOStream(socket.socket(af, socktype, proto),
io_loop = self.io_loop)
self.stream.debug=True
self.stream.connect(sockaddr, callback)
return None
self.stream.debug=False
self.stream.connect(sockaddr)

def close_socket(self):
if self.stream:
Expand All @@ -622,7 +661,9 @@ def close_socket(self):

def send_cmd(self, cmd, callback):
# print "in sendcmd", repr(cmd), callback
#print('begin stream write: %s', datetime.now().strftime('%T.%f'))
self.stream.write(cmd+"\r\n", callback)
#print('end stream write: %s', datetime.now().strftime('%T.%f'))
#self.socket.sendall(cmd + "\r\n")

def readline(self, callback):
Expand Down

0 comments on commit 10c5236

Please sign in to comment.