diff --git a/oauth2client/contrib/multiprocess_file_storage.py b/oauth2client/contrib/multiprocess_file_storage.py new file mode 100644 index 000000000..d9e2ad720 --- /dev/null +++ b/oauth2client/contrib/multiprocess_file_storage.py @@ -0,0 +1,261 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""File-based storage that supports multiple credentials and cross-process +access. + +This module supersedes the functionality previously found in `multistore_file`. + +This module provides the MultiprocessFileStorage class that: +* Is tied to a single credential via a user-specified key. This key can be used + to distinguish between multiple users, client ids, and/or scopes. +* Can be safely accessed and refreshed across threads and processes. + +Process & thread safety guarantees the following behavior: +* If one process refreshes a credential, subsequent refreshes from other + processes will re-fetch the credentials from the file instead of performing + an http request. +* If two processes attempt to refresh concurrently, only one process will be + able to acquire the lock and refresh, with the deadlock caveat below. +* The interprocess lock will not deadlock, instead, the if a process can not + acquire the interprocess lock within INTERPROCESS_LOCK_DEADLINE it will + allow refreshing the credential but will not write the updated credential to + disk, This logic happens during every lock cycle - if the credentials are + refreshed again it will retry locking and writing as normal. + +""" + +import base64 +import json +import logging +import os +import threading + +import fasteners +from six import iteritems + +from oauth2client.client import Credentials +from oauth2client.client import Storage as BaseStorage + +INTERPROCESS_LOCK_DEADLINE = 1 +logger = logging.getLogger(__name__) +_backends = {} +_backends_lock = threading.Lock() + + +def _get_backend(filename): + """A helper method to get or create a backend with thread locking. + + There should only be one backend per-file per-process.""" + + # This method prevents race conditions. + try: + return _backends[filename] + except KeyError: + pass + + with _backends_lock: + backend = _MultiprocessStorageBackend(filename) + _backends[filename] = backend + + return backend + + +class _MultiprocessStorageBackend(object): + + def __init__(self, filename): + self._file = None + self._filename = filename + self._process_lock = fasteners.InterProcessLock( + '{}.lock'.format(filename)) + self._thread_lock = threading.Lock() + self._read_only = False + self._credentials = {} + + def _create_file_if_needed(self): + """Creates the an empty credential storage file if it does not + exist.""" + if self._read_only: + return False + + if not os.path.exists(self._filename): + open(self._filename, 'a+b').close() + logging.info('Credential file {} created'.format(self._filename)) + return False + + return True + + def _load_credentials(self): + """(Re-)loads the credentials from the file.""" + if not self._create_file_if_needed(): + return + + self._credentials.update(self._load_credentials_file(self._file)) + + logger.debug('Read credential file') + + def _load_credentials_file(self, fh): + credentials = {} + + try: + fh.seek(0) + data = json.load(fh) + except Exception: + logger.warning( + 'Credentials file could not be loaded, will ignore and ' + 'overwrite.') + return credentials + + if data.get('file_version') != 2: + data = {} + logger.warning( + 'Credentials file is not version 2, will ignore and ' + 'overwrite.') + return credentials + + for key, encoded_credential in iteritems(data.get('credentials', {})): + try: + credential_json = base64.b64decode(encoded_credential) + credential = Credentials.new_from_json(credential_json) + credentials[key] = credential + except: + logger.warning( + 'Invalid credential {} in file, ignoring.'.format(key)) + + return credentials + + def _write_credentials(self): + if self._read_only: + logger.debug('In read-only mode, not writing credentials.') + return + + self._create_file_if_needed() + self._write_credentials_file(self._file, self._credentials) + logger.debug('Wrote credential file {}.'.format(self._filename)) + + def _write_credentials_file(self, fh, credentials): + data = {'file_version': 2, 'credentials': {}} + + for key, credential in iteritems(credentials): + credential_json = credential.to_json() + encoded_credential = base64.b64encode(credential_json) + data['credentials'][key] = encoded_credential + + fh.seek(0) + json.dump(data, fh) + fh.truncate() + + def acquire_lock(self): + self._thread_lock.acquire() + locked = self._process_lock.acquire(timeout=INTERPROCESS_LOCK_DEADLINE) + + self._create_file_if_needed() + + if locked: + self._file = open(self._filename, 'r+') + self._read_only = False + + else: + self._read_only = True + logger.warn( + 'Failed to obtain interprocess lock for credentials. ' + 'If a credential is being refreshed, other processes may ' + 'not see the updated access token and refresh as well.') + self._file = open(self._filename, 'r') + + self._load_credentials() + + def release_lock(self): + if self._file is not None: + self._file.close() + self._file = None + + if not self._read_only: + self._process_lock.release() + + self._thread_lock.release() + + def _refresh_predicate(self, credentials): + if credentials is None: + return True + if credentials.invalid: + return True + if credentials.access_token_expired: + return True + return False + + def locked_get(self, key): + # Check if the credential is already in memory. + credentials = self._credentials.get(key, None) + + # Use the refresh predicate to determine if the entire store should be + # reloaded. This basically checks if the credentials are invalid + # or expired. This covers the situation where another process has + # refreshed the credentials and this process doesn't know about it yet. + # In that case, this process won't needlessly refresh the credentials. + if self._refresh_predicate(credentials): + self._load_credentials() + credentials = self._credentials.get(key, None) + + return credentials + + def locked_put(self, key, credentials): + self._load_credentials() + self._credentials[key] = credentials + self._write_credentials() + + def locked_delete(self, key): + self._load_credentials() + try: + del self._credentials[key] + except KeyError: + pass + self._write_credentials() + + +class MultiprocessFileStorage(BaseStorage): + def __init__(self, filename, key): + self._key = key + self._backend = _get_backend(filename) + + def acquire_lock(self): + self._backend.acquire_lock() + + def release_lock(self): + self._backend.release_lock() + + def locked_get(self): + """Retrieves the current credentials from the store. + + Returns: + oauth2client.client.Credentials or None + """ + credential = self._backend.locked_get(self._key) + + if credential: + credential.set_store(self) + + return credential + + def locked_put(self, credentials): + """Writes the given credentials to the store. + + Args: + credentials: an oauth2client.client.Credentials object. + """ + return self._backend.locked_put(self._key, credentials) + + def locked_delete(self): + """Deletes the current credentials from the store.""" + return self._backend.locked_delete(self._key) diff --git a/tests/contrib/test_multiprocess_file_storage.py b/tests/contrib/test_multiprocess_file_storage.py new file mode 100644 index 000000000..87274a708 --- /dev/null +++ b/tests/contrib/test_multiprocess_file_storage.py @@ -0,0 +1,293 @@ +# Copyright 2015 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for oauth2client.multistore_file.""" + +import contextlib +import datetime +import json +import multiprocessing +import os +import tempfile +import unittest2 +import fasteners +import mock + +from oauth2client.client import OAuth2Credentials +from oauth2client.contrib import multiprocess_file_storage +from six import StringIO + +from ..http_mock import HttpMockSequence + +_filehandle, FILENAME = tempfile.mkstemp('oauth2client_test.data') +os.close(_filehandle) + + +@contextlib.contextmanager +def scoped_child_process(target, **kwargs): + die_event = multiprocessing.Event() + ready_event = multiprocessing.Event() + process = multiprocessing.Process( + target=target, args=(die_event, ready_event), kwargs=kwargs) + process.start() + try: + ready_event.wait() + yield + finally: + die_event.set() + process.join(5) + + +def _create_test_credentials(expiration=None): + access_token = 'foo' + client_secret = 'cOuDdkfjxxnv+' + refresh_token = '1/0/a.df219fjls0' + token_expiry = expiration or ( + datetime.datetime.utcnow() + datetime.timedelta(seconds=3600)) + token_uri = 'https://www.google.com/accounts/o8/oauth2/token' + user_agent = 'refresh_checker/1.0' + + credentials = OAuth2Credentials( + access_token, 'test-client-id', client_secret, + refresh_token, token_expiry, token_uri, + user_agent) + return credentials + + +class MultiprocessStorageBehaviorTests(unittest2.TestCase): + + def setUp(self): + try: + os.unlink(FILENAME) + os.unlink('{}.lock'.format(FILENAME)) + except OSError: # pragma: NO COVER + pass + + def tearDown(self): + try: + os.unlink(FILENAME) + os.unlink('{}.lock'.format(FILENAME)) + except OSError: # pragma: NO COVER + pass + + def test_basic_operations(self): + credentials = _create_test_credentials() + + store = multiprocess_file_storage.MultiprocessFileStorage( + FILENAME, 'basic') + + # Save credentials + store.put(credentials) + credentials = store.get() + + self.assertNotEquals(None, credentials) + self.assertEquals('foo', credentials.access_token) + + # Reset internal cache, ensure credentials were saved. + store._backend._credentials = {} + credentials = store.get() + + self.assertNotEquals(None, credentials) + self.assertEquals('foo', credentials.access_token) + + # Delete credentials + store.delete() + credentials = store.get() + + self.assertEquals(None, credentials) + + def _generate_token_response_http(self, new_token='new_token'): + token_response = json.dumps({ + 'access_token': new_token, + 'expires_in': '3600', + }) + http = HttpMockSequence([ + ({'status': '200'}, token_response), + ]) + + return http + + def test_single_process_refresh(self): + store = multiprocess_file_storage.MultiprocessFileStorage( + FILENAME, 'single-process') + credentials = _create_test_credentials() + credentials.set_store(store) + + http = self._generate_token_response_http() + credentials.refresh(http) + assert credentials.access_token == 'new_token' + + retrieved = store.get() + assert retrieved.access_token == 'new_token' + + def test_multi_process_refresh(self): + # This will test that two processes attempting to refresh credentials + # will only refresh once. + store = multiprocess_file_storage.MultiprocessFileStorage( + FILENAME, 'multi-process') + credentials = _create_test_credentials() + credentials.set_store(store) + store.put(credentials) + + def child_process_func( + die_event, ready_event, check_event): # pragma: NO COVER + store = multiprocess_file_storage.MultiprocessFileStorage( + FILENAME, 'multi-process') + + credentials = store.get() + assert credentials + + # Make sure this thread gets to refresh first. + original_acquire_lock = store.acquire_lock + + def replacement_acquire_lock(*args, **kwargs): + result = original_acquire_lock(*args, **kwargs) + ready_event.set() + check_event.wait() + return result + + credentials.store.acquire_lock = replacement_acquire_lock + + http = self._generate_token_response_http('b') + credentials.refresh(http) + + assert credentials.access_token == 'b' + + check_event = multiprocessing.Event() + with scoped_child_process(child_process_func, check_event=check_event): + # The lock should be currently held by the child process. + assert not store._backend._process_lock.acquire(blocking=False) + check_event.set() + + # The child process will refresh first, so we should end up + # with 'b' as the token. + http = mock.Mock() + credentials.refresh(http=http) + assert credentials.access_token == 'b' + assert not http.request.called + + retrieved = store.get() + assert retrieved.access_token == 'b' + + def test_read_only_file_fail_lock(self): + credentials = _create_test_credentials() + + # Grab the lock in another process, preventing this process from + # acquiring the lock. + def child_process(die_event, ready_event): # pragma: NO COVER + lock = fasteners.InterProcessLock( + '{}.lock'.format(FILENAME)) + with lock: + ready_event.set() + die_event.wait() + + with scoped_child_process(child_process): + store = multiprocess_file_storage.MultiprocessFileStorage( + FILENAME, 'fail-lock') + store.put(credentials) + self.assertTrue(store._backend._read_only) + + # These credentials should still be in the store's memory-only cache. + assert store.get() + + +class MultiprocessStorageUnitTests(unittest2.TestCase): + + def setUp(self): # pragma: NO COVER + try: + os.unlink(FILENAME) + os.unlink('{}.lock'.format(FILENAME)) + except OSError: + pass + + def tearDown(self): # pragma: NO COVER + try: + os.unlink(FILENAME) + os.unlink('{}.lock'.format(FILENAME)) + except OSError: + pass + + def test__read_write_credentials_file(self): + backend = multiprocess_file_storage._get_backend(FILENAME) + credentials = _create_test_credentials() + contents = StringIO() + + backend._write_credentials_file(contents, {'key': credentials}) + + contents.seek(0) + data = json.load(contents) + self.assertEqual(data['file_version'], 2) + self.assertTrue(data['credentials']['key']) + + # Read it back. + contents.seek(0) + results = backend._load_credentials_file(contents) + self.assertEqual( + results['key'].access_token, credentials.access_token) + + # Add an invalid credential and try reading it back. It should ignore + # the invalid one but still load the valid one. + data['credentials']['invalid'] = '123' + results = backend._load_credentials_file(StringIO(json.dumps(data))) + self.assertNotIn('invalid', results) + self.assertEqual( + results['key'].access_token, credentials.access_token) + + def test__load_credentials_file_invalid_json(self): + backend = multiprocess_file_storage._get_backend(FILENAME) + contents = StringIO('{[') + self.assertEqual(backend._load_credentials_file(contents), {}) + + def test__load_credentials_file_no_file_version(self): + backend = multiprocess_file_storage._get_backend(FILENAME) + contents = StringIO('{}') + self.assertEqual(backend._load_credentials_file(contents), {}) + + def test__load_credentials_file_bad_file_version(self): + backend = multiprocess_file_storage._get_backend(FILENAME) + contents = StringIO(json.dumps({'file_version': 1})) + self.assertEqual(backend._load_credentials_file(contents), {}) + + def test_release_lock_with_no_file(self): + backend = multiprocess_file_storage._get_backend(FILENAME) + backend._file = None + backend._read_only = True + backend._thread_lock.acquire() + backend.release_lock() + + def test_empty_delete(self): + backend = multiprocess_file_storage._get_backend(FILENAME) + try: + backend.acquire_lock() + backend.locked_delete('non-existent') + finally: + backend.release_lock() + + def test__refresh_predicate(self): + backend = multiprocess_file_storage._get_backend(FILENAME) + + credentials = _create_test_credentials() + self.assertFalse(backend._refresh_predicate(credentials)) + + credentials.invalid = True + self.assertTrue(backend._refresh_predicate(credentials)) + + credentials = _create_test_credentials( + expiration=( + datetime.datetime.utcnow() - datetime.timedelta(seconds=3600))) + self.assertTrue(backend._refresh_predicate(credentials)) + + +if __name__ == '__main__': # pragma: NO COVER + unittest2.main() diff --git a/tox.ini b/tox.ini index 8e94c012b..d224d08ef 100644 --- a/tox.ini +++ b/tox.ini @@ -13,6 +13,7 @@ basedeps = mock>=1.3.0 deps = {[testenv]basedeps} django keyring + fasteners setenv = pypy: with_gmp=no DJANGO_SETTINGS_MODULE=tests.contrib.test_django_settings