Skip to content

Commit

Permalink
Merge pull request #104 from sigmavirus24/feature/hooks
Browse files Browse the repository at this point in the history
Start adding hooks
  • Loading branch information
sigmavirus24 committed Apr 24, 2016
2 parents 72a8839 + f52f86d commit 68da87d
Show file tree
Hide file tree
Showing 12 changed files with 226 additions and 73 deletions.
25 changes: 12 additions & 13 deletions betamax/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""
import os

from .cassette import Cassette
from . import cassette
from .exceptions import BetamaxError
from datetime import datetime, timedelta
from requests.adapters import BaseAdapter, HTTPAdapter
Expand Down Expand Up @@ -65,7 +65,7 @@ def load_cassette(self, cassette_name, serialize, options):
placeholders = self.options.get('placeholders', [])
cassette_options = {}

default_options = Cassette.default_cassette_options
default_options = cassette.Cassette.default_cassette_options

match_requests_on = self.options.get(
'match_requests_on', default_options['match_requests_on']
Expand All @@ -85,7 +85,7 @@ def load_cassette(self, cassette_name, serialize, options):
if value is None:
cassette_options.pop(option)

self.cassette = Cassette(
self.cassette = cassette.Cassette(
cassette_name, serialize, placeholders=placeholders,
cassette_library_dir=self.options.get('cassette_library_dir'),
**cassette_options
Expand All @@ -112,21 +112,22 @@ def send(self, request, stream=False, timeout=None, verify=True,
:returns: A Response object
"""
interaction = None
current_cassette = self.cassette

if not self.cassette:
if not current_cassette:
raise BetamaxError('No cassette was specified or found.')

if self.cassette.interactions:
interaction = self.cassette.find_match(request)
if current_cassette.interactions:
interaction = current_cassette.find_match(request)

if not interaction and self.cassette.is_recording():
if not interaction and current_cassette.is_recording():
interaction = self.send_and_record(
request, stream, timeout, verify, cert, proxies
)

if not interaction:
raise BetamaxError(unhandled_request_message(request,
self.cassette))
current_cassette))

resp = interaction.as_response()
resp.connection = self
Expand All @@ -145,17 +146,15 @@ def send_and_record(self, request, stream=False, timeout=None,
:param bool verify: (optional) verify SSL certificate
:param str cert: (optional) path to SSL client
:param proxies dict: (optional) mapping protocol to URL of the proxy
:return: Iteraction
:rtype: class:`betamax.cassette.iteraction`
:return: Interaction
:rtype: class:`betamax.cassette.Interaction`
"""
adapter = self.find_adapter(request.url)
response = adapter.send(
request, stream=True, timeout=timeout, verify=verify,
cert=cert, proxies=proxies
)
self.cassette.save_interaction(response, request)
return self.cassette.interactions[-1]
return self.cassette.save_interaction(response, request)

def find_adapter(self, url):
"""Find adapter.
Expand Down
4 changes: 2 additions & 2 deletions betamax/cassette/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .cassette import Cassette
from .cassette import Cassette, dispatch_hooks
from .interaction import Interaction

__all__ = ('Cassette', 'Interaction')
__all__ = ('Cassette', 'Interaction', 'dispatch_hooks')
59 changes: 41 additions & 18 deletions betamax/cassette/cassette.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# -*- coding: utf-8 -*-
from .interaction import Interaction
import collections
from datetime import datetime
from functools import partial

import os.path

from .interaction import Interaction

from .. import matchers
from .. import serializers
from betamax.util import (_option_from, serialize_prepared_request,
Expand All @@ -22,6 +23,8 @@ class Cassette(object):
'allow_playback_repeats': False,
}

hooks = collections.defaultdict(list)

def __init__(self, cassette_name, serialization_format, **kwargs):
#: Short name of the cassette
self.cassette_name = cassette_name
Expand Down Expand Up @@ -112,7 +115,7 @@ def find_match(self, request):
``use_cassette`` and passes in the request currently in progress.
:param request: ``requests.PreparedRequest``
:returns: :class:`Interaction <Interaction>`
:returns: :class:`~betamax.cassette.Interaction`
"""
# if we are recording, do not filter by match
if self.is_recording() and self.record_mode != 'all':
Expand All @@ -125,19 +128,24 @@ def find_match(self, request):
for o in opts
]

for i in self.interactions:
for interaction in self.interactions:
if not interaction.match(curried_matchers):
continue

if interaction.used or interaction.ignored:
continue

# If the interaction matches everything
if i.match(curried_matchers) and not i.used:
if self.record_mode == 'all':
# If we're recording everything and there's a matching
# interaction we want to overwrite it, so we remove it.
self.interactions.remove(i)
break

# set interaction as used before returning
if not self.allow_playback_repeats:
i.used = True
return i
if self.record_mode == 'all':
# If we're recording everything and there's a matching
# interaction we want to overwrite it, so we remove it.
self.interactions.remove(interaction)
break

# set interaction as used before returning
if not self.allow_playback_repeats:
interaction.used = True
return interaction

# No matches. So sad.
return None
Expand All @@ -162,15 +170,20 @@ def load_interactions(self):
self.interactions = [Interaction(i) for i in interactions]

for i in self.interactions:
dispatch_hooks('before_playback', i, self)
i.replace_all(self.placeholders, ('placeholder', 'replace'))

def sanitize_interactions(self):
for i in self.interactions:
i.replace_all(self.placeholders)

def save_interaction(self, response, request):
interaction = self.serialize_interaction(response, request)
self.interactions.append(Interaction(interaction, response))
serialized_data = self.serialize_interaction(response, request)
interaction = Interaction(serialized_data, response)
dispatch_hooks('before_record', interaction, self)
if not interaction.ignored: # If a hook caused this to be ignored
self.interactions.append(interaction)
return interaction

def serialize_interaction(self, response, request):
return {
Expand All @@ -191,7 +204,17 @@ def _save_cassette(self):
self.sanitize_interactions()

cassette_data = {
'http_interactions': [i.json for i in self.interactions],
'http_interactions': [i.data for i in self.interactions],
'recorded_with': 'betamax/{0}'.format(__version__)
}
self.serializer.serialize(cassette_data)


def dispatch_hooks(hook_name, *args):
"""Dispatch registered hooks."""
# Cassette.hooks is a dictionary that defaults to an empty list,
# we neither need to check for the presence of hook_name in it, nor
# need to worry about whether the return value will be iterable
hooks = Cassette.hooks[hook_name]
for hook in hooks:
hook(*args)
32 changes: 20 additions & 12 deletions betamax/cassette/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,19 @@ class Interaction(object):
"""

def __init__(self, interaction, response=None):
self.json = interaction
self.data = interaction
self.orig_response = response
self.used = False
self.recorded_response = self.deserialize()
self.used = False
self.ignored = False

def ignore(self):
"""Ignore this interaction.
This is only to be used from a before_record or a before_playback
callback.
"""
self.ignored = True

def as_response(self):
"""Return the Interaction as a Response object."""
Expand All @@ -35,18 +43,18 @@ def as_response(self):

@property
def recorded_at(self):
return datetime.strptime(self.json['recorded_at'], '%Y-%m-%dT%H:%M:%S')
return datetime.strptime(self.data['recorded_at'], '%Y-%m-%dT%H:%M:%S')

def deserialize(self):
"""Turn a serialized interaction into a Response."""
r = util.deserialize_response(self.json['response'])
r.request = util.deserialize_prepared_request(self.json['request'])
r = util.deserialize_response(self.data['response'])
r.request = util.deserialize_prepared_request(self.data['request'])
extract_cookies_to_jar(r.cookies, r.request, r.raw)
return r

def match(self, matchers):
"""Return whether this interaction is a match."""
request = self.json['request']
request = self.data['request']
return all(m(request) for m in matchers)

def replace(self, text_to_replace, placeholder):
Expand All @@ -63,29 +71,29 @@ def replace_all(self, replacements, key_order=('replace', 'placeholder')):

def replace_in_headers(self, text_to_replace, placeholder):
for obj in ('request', 'response'):
headers = self.json[obj]['headers']
headers = self.data[obj]['headers']
for k, v in list(headers.items()):
v = util.from_list(v)
headers[k] = v.replace(text_to_replace, placeholder)

def replace_in_body(self, text_to_replace, placeholder):
for obj in ('request', 'response'):
body = self.json[obj]['body']
body = self.data[obj]['body']
old_style = hasattr(body, 'replace')
if not old_style:
body = body.get('string', '')

if text_to_replace in body:
body = body.replace(text_to_replace, placeholder)
if old_style:
self.json[obj]['body'] = body
self.data[obj]['body'] = body
else:
self.json[obj]['body']['string'] = body
self.data[obj]['body']['string'] = body

def replace_in_uri(self, text_to_replace, placeholder):
for (obj, key) in (('request', 'uri'), ('response', 'url')):
uri = self.json[obj][key]
uri = self.data[obj][key]
if text_to_replace in uri:
self.json[obj][key] = uri.replace(
self.data[obj][key] = uri.replace(
text_to_replace, placeholder
)
46 changes: 46 additions & 0 deletions betamax/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,52 @@ def __setattr__(self, prop, value):
else:
super(Configuration, self).__setattr__(prop, value)

def before_playback(self, tag=None, callback=None):
"""Register a function to call before playing back an interaction.
Example usage:
.. code-block:: python
def before_playback(interaction, cassette):
pass
with Betamax.configure() as config:
config.before_playback(callback=before_playback)
:param str tag:
Limits the interactions passed to the function based on the
interaction's tag (currently unsupported).
:param callable callback:
The function which either accepts just an interaction or an
interaction and a cassette and mutates the interaction before
returning.
"""
Cassette.hooks['before_playback'].append(callback)

def before_record(self, tag=None, callback=None):
"""Register a function to call before recording an interaction.
Example usage:
.. code-block:: python
def before_record(interaction, cassette):
pass
with Betamax.configure() as config:
config.before_record(callback=before_record)
:param str tag:
Limits the interactions passed to the function based on the
interaction's tag (currently unsupported).
:param callable callback:
The function which either accepts just an interaction or an
interaction and a cassette and mutates the interaction before
returning.
"""
Cassette.hooks['before_record'].append(callback)

@property
def cassette_library_dir(self):
"""Retrieve and set the directory to store the cassettes in."""
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_backwards_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,6 @@ def tests_populates_correct_fields_with_missing_data(self):
def tests_deserializes_old_cassette_headers(self):
with betamax.Betamax(self.session).use_cassette('GitHub_emojis') as b:
self.session.get('https://api.github.com/emojis')
interaction = b.current_cassette.interactions[0].json
interaction = b.current_cassette.interactions[0].data
header = interaction['request']['headers']['Accept']
assert not isinstance(header, list)
62 changes: 62 additions & 0 deletions tests/integration/test_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import betamax

from . import helper


def prerecord_hook(interaction, cassette):
assert cassette.interactions == []
interaction.data['response']['headers']['Betamax-Fake-Header'] = 'success'


def ignoring_hook(interaction, cassette):
interaction.ignore()


def preplayback_hook(interaction, cassette):
assert cassette.interactions != []
interaction.data['response']['headers']['Betamax-Fake-Header'] = 'temp'


class TestHooks(helper.IntegrationHelper):
def tearDown(self):
super(TestHooks, self).tearDown()
# Clear out the hooks
betamax.cassette.Cassette.hooks.pop('before_record', None)
betamax.cassette.Cassette.hooks.pop('before_playback', None)

def test_prerecord_hook(self):
with betamax.Betamax.configure() as config:
config.before_record(callback=prerecord_hook)

recorder = betamax.Betamax(self.session)
with recorder.use_cassette('prerecord_hook'):
self.cassette_path = recorder.current_cassette.cassette_path
response = self.session.get('https://httpbin.org/get')
assert response.headers['Betamax-Fake-Header'] == 'success'

with recorder.use_cassette('prerecord_hook', record='none'):
response = self.session.get('https://httpbin.org/get')
assert response.headers['Betamax-Fake-Header'] == 'success'

def test_preplayback_hook(self):
with betamax.Betamax.configure() as config:
config.before_playback(callback=preplayback_hook)

recorder = betamax.Betamax(self.session)
with recorder.use_cassette('preplayback_hook'):
self.cassette_path = recorder.current_cassette.cassette_path
self.session.get('https://httpbin.org/get')

with recorder.use_cassette('preplayback_hook', record='none'):
response = self.session.get('https://httpbin.org/get')
assert response.headers['Betamax-Fake-Header'] == 'temp'

def test_prerecord_ignoring_hook(self):
with betamax.Betamax.configure() as config:
config.before_record(callback=ignoring_hook)

recorder = betamax.Betamax(self.session)
with recorder.use_cassette('ignore_hook'):
self.cassette_path = recorder.current_cassette.cassette_path
self.session.get('https://httpbin.org/get')
assert recorder.current_cassette.interactions == []

0 comments on commit 68da87d

Please sign in to comment.