From 5438be4f9ae4c40043ce15fae7cac5d96ea85689 Mon Sep 17 00:00:00 2001 From: Ben Bangert Date: Thu, 29 Apr 2010 13:05:36 -0700 Subject: [PATCH] Fixes to pass all the unit tests. --- openidredis/__init__.py | 97 +++++++++++++++++++++++++++++------------ 1 file changed, 68 insertions(+), 29 deletions(-) diff --git a/openidredis/__init__.py b/openidredis/__init__.py index 86a0940..2cec35a 100644 --- a/openidredis/__init__.py +++ b/openidredis/__init__.py @@ -45,10 +45,11 @@ def _filenameEscape(s): class RedisStore(OpenIDStore): """Implementation of OpenIDStore for Redis""" - def __init__(self, host='localhost', port=6379, db=0): + def __init__(self, host='localhost', port=6379, db=0, key_prefix='oid_redis'): self.host = host self.port = port self.db = db + self.key_prefix = key_prefix self._conn = redis.Redis(host=self.host, port=self.port, db=self.db) def getAssociationFilename(self, server_url, handle): @@ -71,36 +72,57 @@ def getAssociationFilename(self, server_url, handle): else: handle_hash = '' - filename = '%s-%s-%s-%s' % (proto, domain, url_hash, handle_hash) + filename = '%s-%s-%s-%s-%s' % (self.key_prefix, proto, domain, url_hash, handle_hash) log.debug('Returning filename: %s', filename) return filename def storeAssociation(self, server_url, association): - association_s = association.serialize() - full_key_name = self.getAssociationFilename(server_url, association.handle) - server_key_name = self.getAssociationFilename(server_url, None) + # Determine how long this association is good for + issued_offset = int(time.time()) - association.issued + seconds_from_now = issued_offset + association.lifetime + + # If this association is already expired, don't even store it + if seconds_from_now < 1: + return None - for key_name in [full_key_name, server_key_name]: - self._conn.set(key_name, association_s) - log.debug('Storing key: %s', key_name) + association_s = association.serialize() + key_name = self.getAssociationFilename(server_url, association.handle) - # By default, set the expiration from the assocation expiration - self._conn.expire(key_name, association.lifetime) - log.debug('Expiring: %s, in %s seconds', key_name, association.lifetime) + self._conn.set(key_name, association_s) + log.debug('Storing key: %s', key_name) + + # By default, set the expiration from the assocation expiration + self._conn.expire(key_name, seconds_from_now) + log.debug('Expiring: %s, in %s seconds', key_name, seconds_from_now) return None def getAssociation(self, server_url, handle=None): log.debug('Association requested for server_url: %s, with handle: %s', server_url, handle) - key_name = self.getAssociationFilename(server_url, handle) if handle is None: - handle = '' - association_s = self._conn.get(key_name) - if association_s: - log.debug('getAssociation found, returning association') - return Association.deserialize(association_s) + # Retrieve all the keys for this server connection + key_name = self.getAssociationFilename(server_url, '') + assocs = self._conn.keys('%s*' % key_name) + + if not assocs: + log.debug('No association found for: %s', server_url) + return None + + # Now use the one that was issued most recently + associations = [] + for assoc in self._conn.mget(assocs): + associations.append(Association.deserialize(assoc)) + associations.sort(cmp=lambda x,y: cmp(x.issued, y.issued)) + log.debug('getAssociation found, returns most recently issued') + return associations[-1] else: - log.debug('No association found for getAssociation') - return None + key_name = self.getAssociationFilename(server_url, handle) + association_s = self._conn.get(key_name) + if association_s: + log.debug('getAssociation found, returning association') + return Association.deserialize(association_s) + else: + log.debug('No association found for getAssociation') + return None def removeAssociation(self, server_url, handle): key_name = self.getAssociationFilename(server_url, handle) @@ -109,7 +131,11 @@ def removeAssociation(self, server_url, handle): def useNonce(self, server_url, timestamp, salt): if abs(timestamp - time.time()) > nonce.SKEW: - log.debug('Invalid nonce used, time skew boom') + log.debug('Timestamp from current time is less than skew') + return False + + # We're not even holding onto nonces apparently + if nonce.SKEW < 1: return False if server_url: @@ -123,14 +149,27 @@ def useNonce(self, server_url, timestamp, salt): url_hash = _safe64(server_url) salt_hash = _safe64(salt) - anonce = '%08x-%s-%s-%s-%s' % (timestamp, proto, domain, + anonce = '%s-nonce-%08x-%s-%s-%s-%s' % (self.key_prefix, timestamp, proto, domain, url_hash, salt_hash) - new_nonce = self._conn.setnx(anonce, 'nonce') - if new_nonce: - # Expire the nonce in 5 minutes - self._conn.expire(anonce, 300) - log.debug('Unused nonce, all good') - return True - else: - log.debug('Nonce already exists, oops') + exists = self._conn.getset(anonce, '%s' % timestamp) + log.debug('And new_nonce results: %s', exists) + if exists: + log.debug('Nonce already exists, oops: %s', anonce) return False + else: + log.debug('Unused nonce, all good: %s', anonce) + # Let's set an expire time + curr_offset = time.time() - timestamp + self._conn.expire(anonce, curr_offset + nonce.SKEW) + return True + + def cleanupNonces(self): + keys = self._conn.keys('%s-nonce-*' % self.key_prefix) + expired = 0 + for key in keys: + # See if its expired + timestamp = int(self._conn.get(key)) + if abs(timestamp - time.time()) > nonce.SKEW: + self._conn.delete(key) + expired += 1 + return expired