Skip to content

Commit

Permalink
Fixes to pass all the unit tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
bbangert committed Apr 29, 2010
1 parent ff56e60 commit 5438be4
Showing 1 changed file with 68 additions and 29 deletions.
97 changes: 68 additions & 29 deletions openidredis/__init__.py
Expand Up @@ -45,10 +45,11 @@ def _filenameEscape(s):

class RedisStore(OpenIDStore):
"""Implementation of OpenIDStore for Redis"""
def __init__(self, host='localhost', port=6379, db=0):
def __init__(self, host='localhost', port=6379, db=0, key_prefix='oid_redis'):
self.host = host
self.port = port
self.db = db
self.key_prefix = key_prefix
self._conn = redis.Redis(host=self.host, port=self.port, db=self.db)

def getAssociationFilename(self, server_url, handle):
Expand All @@ -71,36 +72,57 @@ def getAssociationFilename(self, server_url, handle):
else:
handle_hash = ''

filename = '%s-%s-%s-%s' % (proto, domain, url_hash, handle_hash)
filename = '%s-%s-%s-%s-%s' % (self.key_prefix, proto, domain, url_hash, handle_hash)
log.debug('Returning filename: %s', filename)
return filename

def storeAssociation(self, server_url, association):
association_s = association.serialize()
full_key_name = self.getAssociationFilename(server_url, association.handle)
server_key_name = self.getAssociationFilename(server_url, None)
# Determine how long this association is good for
issued_offset = int(time.time()) - association.issued
seconds_from_now = issued_offset + association.lifetime

# If this association is already expired, don't even store it
if seconds_from_now < 1:
return None

for key_name in [full_key_name, server_key_name]:
self._conn.set(key_name, association_s)
log.debug('Storing key: %s', key_name)
association_s = association.serialize()
key_name = self.getAssociationFilename(server_url, association.handle)

# By default, set the expiration from the assocation expiration
self._conn.expire(key_name, association.lifetime)
log.debug('Expiring: %s, in %s seconds', key_name, association.lifetime)
self._conn.set(key_name, association_s)
log.debug('Storing key: %s', key_name)

# By default, set the expiration from the assocation expiration
self._conn.expire(key_name, seconds_from_now)
log.debug('Expiring: %s, in %s seconds', key_name, seconds_from_now)
return None

def getAssociation(self, server_url, handle=None):
log.debug('Association requested for server_url: %s, with handle: %s', server_url, handle)
key_name = self.getAssociationFilename(server_url, handle)
if handle is None:
handle = ''
association_s = self._conn.get(key_name)
if association_s:
log.debug('getAssociation found, returning association')
return Association.deserialize(association_s)
# Retrieve all the keys for this server connection
key_name = self.getAssociationFilename(server_url, '')
assocs = self._conn.keys('%s*' % key_name)

if not assocs:
log.debug('No association found for: %s', server_url)
return None

# Now use the one that was issued most recently
associations = []
for assoc in self._conn.mget(assocs):
associations.append(Association.deserialize(assoc))
associations.sort(cmp=lambda x,y: cmp(x.issued, y.issued))
log.debug('getAssociation found, returns most recently issued')
return associations[-1]
else:
log.debug('No association found for getAssociation')
return None
key_name = self.getAssociationFilename(server_url, handle)
association_s = self._conn.get(key_name)
if association_s:
log.debug('getAssociation found, returning association')
return Association.deserialize(association_s)
else:
log.debug('No association found for getAssociation')
return None

def removeAssociation(self, server_url, handle):
key_name = self.getAssociationFilename(server_url, handle)
Expand All @@ -109,7 +131,11 @@ def removeAssociation(self, server_url, handle):

def useNonce(self, server_url, timestamp, salt):
if abs(timestamp - time.time()) > nonce.SKEW:
log.debug('Invalid nonce used, time skew boom')
log.debug('Timestamp from current time is less than skew')
return False

# We're not even holding onto nonces apparently
if nonce.SKEW < 1:
return False

if server_url:
Expand All @@ -123,14 +149,27 @@ def useNonce(self, server_url, timestamp, salt):
url_hash = _safe64(server_url)
salt_hash = _safe64(salt)

anonce = '%08x-%s-%s-%s-%s' % (timestamp, proto, domain,
anonce = '%s-nonce-%08x-%s-%s-%s-%s' % (self.key_prefix, timestamp, proto, domain,
url_hash, salt_hash)
new_nonce = self._conn.setnx(anonce, 'nonce')
if new_nonce:
# Expire the nonce in 5 minutes
self._conn.expire(anonce, 300)
log.debug('Unused nonce, all good')
return True
else:
log.debug('Nonce already exists, oops')
exists = self._conn.getset(anonce, '%s' % timestamp)
log.debug('And new_nonce results: %s', exists)
if exists:
log.debug('Nonce already exists, oops: %s', anonce)
return False
else:
log.debug('Unused nonce, all good: %s', anonce)
# Let's set an expire time
curr_offset = time.time() - timestamp
self._conn.expire(anonce, curr_offset + nonce.SKEW)
return True

def cleanupNonces(self):
keys = self._conn.keys('%s-nonce-*' % self.key_prefix)
expired = 0
for key in keys:
# See if its expired
timestamp = int(self._conn.get(key))
if abs(timestamp - time.time()) > nonce.SKEW:
self._conn.delete(key)
expired += 1
return expired

0 comments on commit 5438be4

Please sign in to comment.