Skip to content

Commit

Permalink
Fix mock import in Python 3.2
Browse files Browse the repository at this point in the history
  • Loading branch information
mschwager committed Sep 20, 2016
1 parent 8445ceb commit 5ee4186
Showing 1 changed file with 21 additions and 15 deletions.
36 changes: 21 additions & 15 deletions tests/test_fierce.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@

import ipaddress
import unittest
import unittest.mock

try:
# Python 3.3+
from unittest import mock
except ImportError:
# Python 3.2 (third-party)
import mock

import dns.name
import dns.resolver
Expand Down Expand Up @@ -149,13 +155,13 @@ def test_recursive_query_basic_failure(self):
domain = dns.name.from_text('example.com.')
record_type = 'NS'

with unittest.mock.patch.object(fierce, 'query', return_value=None) as mock_method:
with mock.patch.object(fierce, 'query', return_value=None) as mock_method:
result = fierce.recursive_query(resolver, domain, record_type=record_type)

expected = [
unittest.mock.call(resolver, 'example.com.', record_type),
unittest.mock.call(resolver, 'com.', record_type),
unittest.mock.call(resolver, '', record_type),
mock.call(resolver, 'example.com.', record_type),
mock.call(resolver, 'com.', record_type),
mock.call(resolver, '', record_type),
]

mock_method.assert_has_calls(expected)
Expand All @@ -166,15 +172,15 @@ def test_recursive_query_long_domain_failure(self):
domain = dns.name.from_text('sd1.sd2.example.com.')
record_type = 'NS'

with unittest.mock.patch.object(fierce, 'query', return_value=None) as mock_method:
with mock.patch.object(fierce, 'query', return_value=None) as mock_method:
result = fierce.recursive_query(resolver, domain, record_type=record_type)

expected = [
unittest.mock.call(resolver, 'sd1.sd2.example.com.', record_type),
unittest.mock.call(resolver, 'sd2.example.com.', record_type),
unittest.mock.call(resolver, 'example.com.', record_type),
unittest.mock.call(resolver, 'com.', record_type),
unittest.mock.call(resolver, '', record_type),
mock.call(resolver, 'sd1.sd2.example.com.', record_type),
mock.call(resolver, 'sd2.example.com.', record_type),
mock.call(resolver, 'example.com.', record_type),
mock.call(resolver, 'com.', record_type),
mock.call(resolver, '', record_type),
]

mock_method.assert_has_calls(expected)
Expand All @@ -184,19 +190,19 @@ def test_recursive_query_basic_success(self):
resolver = dns.resolver.Resolver()
domain = dns.name.from_text('example.com.')
record_type = 'NS'
good_response = unittest.mock.MagicMock()
good_response = mock.MagicMock()
side_effect = [
None,
good_response,
None,
]

with unittest.mock.patch.object(fierce, 'query', side_effect=side_effect) as mock_method:
with mock.patch.object(fierce, 'query', side_effect=side_effect) as mock_method:
result = fierce.recursive_query(resolver, domain, record_type=record_type)

expected = [
unittest.mock.call(resolver, 'example.com.', record_type),
unittest.mock.call(resolver, 'com.', record_type),
mock.call(resolver, 'example.com.', record_type),
mock.call(resolver, 'com.', record_type),
]

mock_method.assert_has_calls(expected)
Expand Down

0 comments on commit 5ee4186

Please sign in to comment.