# Handling dependencies

What if my "unit" uses system facilities?

- os.remove / datetime.now / etc.
- database access
- network connectivity
- user input

## Option 1: Dependency injection

Make the dependency part of the unit interface

In [1]:
import os, time
from collections import namedtuple

AuditRecord = namedtuple('AuditRecord', 'uid timestamp action')


def make_audit_record(action):
    return AuditRecord(os.getuid(), time.time(), action)

In [2]:
make_audit_record('update')

AuditRecord(uid=1000, timestamp=1655320980.4269292, action='update')

Depends on: current user, current time (hard to test)

### With dependency injection

In [3]:
def make_audit_record(action, time=time.time, getuid=os.getuid):
    return AuditRecord(getuid(), time(), action)

In [4]:
make_audit_record('update')

AuditRecord(uid=1000, timestamp=1655321111.1369326, action='update')

In [5]:
def mytime():
    return 200

def mygetuid():
    return 1024

ar = make_audit_record('update', time=mytime, getuid=mygetuid)
assert ar.timestamp == 200
assert ar.uid == 1024
assert ar.action == 'update'
ar

AuditRecord(uid=1024, timestamp=200, action='update')

Tests can inject replacement versions of `time` and `getuid` and check that things work as expected. Unfortunately, we've "uglified" our interface.
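One way to limit the damage to the interface is to make the injected dependencies keyword-only. The sketch below is a variation on the function above, not a replacement for it: the parameter name `clock` is an assumption chosen here to avoid shadowing the `time` module, and `os.getuid` is Unix-only, as in the original.

```python
import os
import time
from collections import namedtuple

AuditRecord = namedtuple('AuditRecord', 'uid timestamp action')


def make_audit_record(action, *, clock=time.time, getuid=os.getuid):
    # The bare `*` makes the injected dependencies keyword-only, so a
    # stray positional argument can never replace one by accident.
    return AuditRecord(getuid(), clock(), action)


# Production callers are unchanged: the defaults are the real functions.
real = make_audit_record('update')

# Tests pass deterministic stand-ins by name.
fake = make_audit_record('update', clock=lambda: 200, getuid=lambda: 1024)
print(fake)
```

Callers that never touch the dependencies still see a one-argument function; only tests need to know the extra keywords exist.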

## Option 2: Mocking and patching

The `unittest.mock` library provides the ability to patch dependencies:

In [6]:
# original implementation

def make_audit_record(action):
    return AuditRecord(os.getuid(), time.time(), action)

In [7]:
# Test code
import os, time
from unittest.mock import patch

with patch('os.getuid', return_value=1024), patch('time.time', return_value=200):
    rec = make_audit_record('update')
rec

AuditRecord(uid=1024, timestamp=200, action='update')

In [8]:
os.getuid(), time.time()

(1000, 1655321283.6480753)
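The two cells above suggest that the patch is undone when the `with` block exits. A minimal sketch that checks this directly, and also shows a common pitfall: a name imported with `from time import time` before patching is a separate reference and is not affected by patching `time.time`. The alias `imported_time` is an illustrative name.

```python
import time
from time import time as imported_time  # bound before any patching happens
from unittest.mock import patch

original = time.time

with patch('time.time', return_value=200) as fake_time:
    inside = time.time()               # looked up on the module at call time
    early_binding = imported_time()    # still the real function
    fake_time.assert_called_once_with()

# On exit the original attribute is put back.
restored = time.time is original
print(inside, early_binding != 200, restored)
```

This is why the usual advice is to patch the name where the code under test looks it up, not necessarily where it was defined.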

In [9]:
import unittest

class TestCase(unittest.TestCase):
    # Decorators apply bottom-up, so the mock for time.time is passed first.
    @patch('os.getuid')
    @patch('time.time', return_value=200)
    def test_userstuff(self, p_time, p_getuid):
        p_getuid.return_value = 1024  # could also be passed to @patch
        print(make_audit_record('test'))
        p_getuid.assert_called()
        p_time.assert_called()



In [10]:
TestCase().test_userstuff()

AuditRecord(uid=1024, timestamp=200, action='test')
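The test above prints the record, which only helps a human reading the output. A sketch of the same test with assertions, run through the standard `unittest` runner; the class and method names are illustrative, and `os.getuid` again assumes a Unix system.

```python
import os
import time
import unittest
from collections import namedtuple
from unittest.mock import patch

AuditRecord = namedtuple('AuditRecord', 'uid timestamp action')


def make_audit_record(action):
    return AuditRecord(os.getuid(), time.time(), action)


class AuditRecordTest(unittest.TestCase):
    # Decorators apply bottom-up, so the mock for time.time arrives first.
    @patch('os.getuid', return_value=1024)
    @patch('time.time', return_value=200)
    def test_make_audit_record(self, p_time, p_getuid):
        rec = make_audit_record('test')
        self.assertEqual(rec, AuditRecord(1024, 200, 'test'))
        p_getuid.assert_called_once_with()
        p_time.assert_called_once_with()


# Run the test case programmatically (in a notebook, for example).
suite = unittest.defaultTestLoader.loadTestsFromTestCase(AuditRecordTest)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```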


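Replacing attributes at runtime like this is known as monkey patching. The name is said to have evolved as: patch -> guerrilla patch -> monkey patch.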

# Database Mocking

Depending on how much DBMS-specific SQL you use, you *may* be able to use SQLite, via Python's built-in `sqlite3` module, as a stand-in for the real database:

In [11]:
!pip install pandas



In [12]:
import pandas as pd

data = pd.read_csv('data/closing-prices.csv', index_col=0, parse_dates=[0])
data.head()

| | F | TSLA | GOOG | IBM | AAPL |
|---|---|---|---|---|---|
| 2014-01-02 | 12.089 | 150.1 | NaN | 157.6001 | 72.7741 |
| 2014-01-03 | 12.1438 | 149.56 | NaN | 158.543 | 71.1756 |
| 2014-01-06 | 12.1986 | 147.0 | NaN | 157.9993 | 71.5637 |
| 2014-01-07 | 12.042 | 149.36 | NaN | 161.1508 | 71.0516 |
| 2014-01-08 | 12.1673 | 151.28 | NaN | 159.6728 | 71.5019 |


## We can dump the dataframe to an in-memory (SQL) database

In [13]:
import sqlite3

conn = sqlite3.connect(':memory:')
data.to_sql('prices', conn)

In [17]:
for row in conn.execute('SELECT * FROM prices WHERE IBM > 150 LIMIT 5'):
    print(row)

('2014-01-02 00:00:00', 12.089, 150.1, None, 157.6001, 72.7741)
('2014-01-03 00:00:00', 12.1438, 149.56, None, 158.543, 71.1756)
('2014-01-06 00:00:00', 12.1986, 147.0, None, 157.9993, 71.5637)
('2014-01-07 00:00:00', 12.042, 149.36, None, 161.1508, 71.0516)
('2014-01-08 00:00:00', 12.1673, 151.28, None, 159.6728, 71.5019)


## We can read it back into a dataframe

In [15]:
data2 = pd.read_sql('SELECT * FROM prices WHERE IBM > 160', conn)
data2.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 177 entries, 0 to 176
Data columns (total 6 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   index   177 non-null    object 
 1   F       177 non-null    float64
 2   TSLA    177 non-null    float64
 3   GOOG    167 non-null    float64
 4   IBM     177 non-null    float64
 5   AAPL    177 non-null    float64
dtypes: float64(5), object(1)
memory usage: 8.4+ KB


In [16]:
data2.head()

| | index | F | TSLA | GOOG | IBM | AAPL |
|---|---|---|---|---|---|---|
| 0 | 2014-01-07 00:00:00 | 12.042 | 149.36 | NaN | 161.1508 | 71.0516 |
| 1 | 2014-01-16 00:00:00 | 13.099 | 170.97 | NaN | 160.3438 | 72.9215 |
| 2 | 2014-01-17 00:00:00 | 12.9346 | 170.01 | NaN | 161.4736 | 71.1348 |
| 3 | 2014-01-21 00:00:00 | 12.8484 | 176.68 | NaN | 160.0635 | 72.24 |
| 4 | 2014-03-06 00:00:00 | 12.3674 | 252.94 | NaN | 160.2662 | 70.2476 |
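To tie this back to testing: code that accepts a database connection can be exercised against an in-memory SQLite database seeded with a few known rows. The sketch below uses only the standard library; the function `ibm_days_above`, the simplified two-column schema, and the column name `date` are assumptions made for illustration, with values taken from the first rows shown earlier.

```python
import sqlite3


def ibm_days_above(conn, threshold):
    """Hypothetical code under test: dates on which IBM closed above threshold."""
    rows = conn.execute(
        'SELECT date FROM prices WHERE IBM > ? ORDER BY date', (threshold,))
    return [date for (date,) in rows]


def make_test_db():
    # A fresh in-memory database per test keeps tests isolated from each other.
    conn = sqlite3.connect(':memory:')
    conn.execute('CREATE TABLE prices (date TEXT PRIMARY KEY, IBM REAL)')
    conn.executemany('INSERT INTO prices VALUES (?, ?)', [
        ('2014-01-02', 157.6001),
        ('2014-01-03', 158.543),
        ('2014-01-07', 161.1508),
    ])
    return conn


conn = make_test_db()
result = ibm_days_above(conn, 158)
print(result)
conn.close()
```

This only works as long as the SQL under test stays within the dialect that SQLite shares with the production database.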


# More mock examples

In [18]:
from unittest import mock

m = mock.Mock()
print(m.whatever.else_.i.want)

<Mock name='mock.whatever.else_.i.want' id='140292021754944'>


In [19]:
m.whatever()

<Mock name='mock.whatever()' id='140292021754800'>

In [20]:
m.amethod.return_value = 42

In [21]:
m.amethod(1, 2, 3)

42

In [22]:
m.amethod.assert_called()

In [23]:
m.amethod.assert_called_with(1, 2, 3)

In [24]:
m.amethod.call_args_list

[call(1, 2, 3)]

In [25]:
mock_session = mock.Mock()

In [26]:
# code under test
mock_session.post('some_incorrect_url', json={'a': 5})

<Mock name='mock.post()' id='140292021752880'>

In [27]:
# Assertions in test case
mock_session.post.assert_called_with('some_url', json={'a': 5})

AssertionError: expected call not found.
Expected: post('some_url', json={'a': 5})
Actual: post('some_incorrect_url', json={'a': 5})
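Note that `assert_called_with` checks only the most recent call. A short sketch of the related assertion helpers on `unittest.mock.Mock`, using an illustrative `session` mock:

```python
from unittest import mock

session = mock.Mock()
session.post('some_url', json={'a': 5})
session.post('some_url', json={'a': 6})

# assert_called_with only looks at the most recent call...
session.post.assert_called_with('some_url', json={'a': 6})
# ...while assert_any_call searches the whole call history.
session.post.assert_any_call('some_url', json={'a': 5})

# call_args exposes the last call as an (args, kwargs) pair.
args, kwargs = session.post.call_args

# assert_called_once_with fails here because post was called twice.
failure = None
try:
    session.post.assert_called_once_with('some_url', json={'a': 6})
except AssertionError as exc:
    failure = str(exc)
print(session.post.call_count, failure is not None)
```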

In [28]:
%%file data/test-examples/test8.py
import unittest
from unittest import mock


def echo_data(socket):
    data = socket.recv(42)
    socket.send(data)


class MyTest(unittest.TestCase):

    def test_send_recv(self):
        socket = mock.Mock()
        socket.recv.return_value = b'Some data'
        echo_data(socket)
        socket.recv.assert_called_with(42)
        socket.send.assert_called_with(b'Some data')


Overwriting data/test-examples/test8.py


In [29]:
!python -m unittest data/test-examples/test8.py

.
----------------------------------------------------------------------
Ran 1 test in 0.001s

OK


In [30]:
from unittest import mock

In [31]:
m = mock.Mock(name='bertrand')

In [32]:
m

<Mock name='bertrand' id='140293301039008'>

In [33]:
m.foo

<Mock name='bertrand.foo' id='140292021753648'>

In [34]:
m()

<Mock name='bertrand()' id='140292021753840'>

In [35]:
m.foo.return_value = 16
m.foo()

16

In [36]:
m.foo.assert_called_with()

In [37]:
m.side_effect = ValueError

In [38]:
m()

ValueError: 

In [39]:
m[10]

TypeError: 'Mock' object is not subscriptable

In [40]:
m + m

TypeError: unsupported operand type(s) for +: 'Mock' and 'Mock'

In [41]:
for obj in m:
    print(obj)

TypeError: 'Mock' object is not iterable

In [42]:
m = mock.MagicMock(name='russell')

In [43]:
m - m

<MagicMock name='russell.__sub__()' id='140292020333776'>

In [44]:
m.__add__.return_value = 'foo'

In [45]:
m + m

'foo'

In [46]:
m + None

'foo'

In [47]:
m - 5

<MagicMock name='russell.__sub__()' id='140292020333776'>
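The same configuration style works for the other protocol methods that `MagicMock` supports. A minimal sketch covering indexing, iteration, `len()`, and the context-manager protocol:

```python
from unittest import mock

m = mock.MagicMock(name='russell')

# Container and iterator protocols are configured like any other method.
m.__getitem__.return_value = 'item'
m.__iter__.return_value = [1, 2, 3]
m.__len__.return_value = 3

item = m[10]
items = list(m)
length = len(m)
m.__getitem__.assert_called_with(10)

# MagicMock also works as a context manager without extra setup.
with m:
    pass
m.__enter__.assert_called_once_with()
m.__exit__.assert_called_once_with(None, None, None)
print(item, items, length)
```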

# Mock a REST API call

In [48]:
import requests
sess = requests.Session()
resp = sess.get('https://api.github.com').json()

In [49]:
resp

{'current_user_url': 'https://api.github.com/user',
 'current_user_authorizations_html_url': 'https://github.com/settings/connections/applications{/client_id}',
 'authorizations_url': 'https://api.github.com/authorizations',
 'code_search_url': 'https://api.github.com/search/code?q={query}{&page,per_page,sort,order}',
 'commit_search_url': 'https://api.github.com/search/commits?q={query}{&page,per_page,sort,order}',
 'emails_url': 'https://api.github.com/user/emails',
 'emojis_url': 'https://api.github.com/emojis',
 'events_url': 'https://api.github.com/events',
 'feeds_url': 'https://api.github.com/feeds',
 'followers_url': 'https://api.github.com/user/followers',
 'following_url': 'https://api.github.com/user/following{/target}',
 'gists_url': 'https://api.github.com/gists{/gist_id}',
 'hub_url': 'https://api.github.com/hub',
 'issue_search_url': 'https://api.github.com/search/issues?q={query}{&page,per_page,sort,order}',
 'issues_url': 'https://api.github.com/issues',
 'keys_url': '

In [50]:
# code under test
def get_feeds_url(sess):
    resp = sess.get('https://api.github.com')
    data = resp.json()
    return data['feeds_url']

In [51]:
mock_sess = mock.Mock()
mock_resp = mock_sess.get()
mock_resp.json.return_value = {'feeds_url': 'TESTFEEDS'}


In [52]:
result = get_feeds_url(mock_sess)

In [53]:
assert result == 'TESTFEEDS'

In [54]:
mock_sess.get.assert_called_with('https://api.github.com')

In [55]:
mock_sess.get().json.assert_called()
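Calling `mock_sess.get()` to obtain the response mock works, but it also records an extra call on `get`. A variation that configures the response through `return_value` instead, so the call history contains only what the code under test did and the stricter `assert_called_once_with` can be used:

```python
from unittest import mock


def get_feeds_url(sess):
    # Same code under test as above, repeated so the block is self-contained.
    resp = sess.get('https://api.github.com')
    data = resp.json()
    return data['feeds_url']


mock_sess = mock.Mock()
# Configure the response without calling get() ourselves.
mock_sess.get.return_value.json.return_value = {'feeds_url': 'TESTFEEDS'}

result = get_feeds_url(mock_sess)

mock_sess.get.assert_called_once_with('https://api.github.com')
mock_sess.get.return_value.json.assert_called_once_with()
print(result)
```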


# Autospec

In [56]:
from datetime import datetime

In [57]:
datetime.utcnow()

datetime.datetime(2022, 6, 15, 19, 46, 40, 867235)

In [58]:
datetime.gmtnow()

AttributeError: type object 'datetime.datetime' has no attribute 'gmtnow'

In [59]:
mock_dt = mock.Mock()

In [60]:
mock_dt.utcnow()

<Mock name='mock.utcnow()' id='140293172682272'>

In [61]:
mock_dt.gmtnow()

<Mock name='mock.gmtnow()' id='140293172681936'>

In [62]:
mock_dt = mock.create_autospec(datetime)

In [63]:
mock_dt.utcnow()

<MagicMock name='mock.utcnow()' id='140293172240144'>

In [64]:
mock_dt.gmtnow()

AttributeError: Mock object has no attribute 'gmtnow'

In [65]:
mock_dt.utcnow('this is a spurious argument')

<MagicMock name='mock.utcnow()' id='140293172240144'>

In [66]:
mock_dt.utcnow.return_value = datetime(2011, 1, 1)

In [67]:
mock_dt.utcnow()

datetime.datetime(2011, 1, 1, 0, 0)

In [68]:
mock_dt.utcnow.assert_called_with()
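In the output above the spurious argument to `utcnow` was accepted, presumably because the signature of that C-implemented method was not available for introspection. For classes written in Python, autospec can also enforce call signatures. A sketch with a hypothetical `Mailer` class (not part of any library):

```python
from unittest import mock


class Mailer:
    """Hypothetical dependency, defined in Python so its signature is visible."""

    def send(self, to, body):
        raise RuntimeError('would talk to a real mail server')


# instance=True specs the mock after a Mailer instance, not the class itself.
mock_mailer = mock.create_autospec(Mailer, instance=True)
mock_mailer.send('a@example.com', 'hello')      # matches the real signature
mock_mailer.send.assert_called_once_with('a@example.com', 'hello')

errors = []
try:
    mock_mailer.send('only one argument')       # signature mismatch
except TypeError as exc:
    errors.append(type(exc).__name__)

try:
    mock_mailer.sned                            # typo: not an attribute of Mailer
except AttributeError as exc:
    errors.append(type(exc).__name__)
print(errors)
```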

Autospec + API

In [69]:
mock_sess = mock.create_autospec(requests.Session())
mock_resp = mock_sess.get('/some_url')
mock_resp.json.return_value = {'feeds_url': 'TESTFEEDS'}


In [70]:
result = get_feeds_url(mock_sess)

In [71]:
assert result == 'TESTFEEDS'

In [72]:
mock_sess.get.assert_called_with('https://api.github.com')

In [73]:
mock_resp.json.assert_called()

In [74]:
mock_sess.some_nonexistent_method()

AttributeError: Mock object has no attribute 'some_nonexistent_method'

[placebo] is also worth a look for mocking AWS (boto3) calls.

[placebo]: https://github.com/garnaat/placebo

In [None]:
# import boto3
# ec2 = boto3.resource('ec2')
ec2 = mock.Mock()
ec2.meta.client.describe_instances.return_value = {
    'Instances': [
        {'id': ...}
    ]
}

ec2.meta.client.describe_instances()
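Deeply nested return values like this can also be set up in one step by passing dotted names to the `Mock` constructor. In the sketch below, `count_instances` is a hypothetical function under test, the instance ids are made-up test values, and the response shape follows the cell above, which is not necessarily the shape real boto3 returns.

```python
from unittest import mock


def count_instances(ec2):
    """Hypothetical code under test."""
    response = ec2.meta.client.describe_instances()
    return len(response['Instances'])


# Dotted keys configure child mocks; the dict is unpacked as keyword arguments.
ec2 = mock.Mock(**{
    'meta.client.describe_instances.return_value': {
        'Instances': [{'id': 'i-test-1'}, {'id': 'i-test-2'}],
    },
})

count = count_instances(ec2)
ec2.meta.client.describe_instances.assert_called_once_with()
print(count)
```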

Multiple different mock results

In [75]:
m = mock.Mock()

In [76]:
m.side_effect = [1,2,3,ValueError]

In [77]:
for x in range(4):
    print(m())

1
2
3


ValueError: 
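A typical use for a `side_effect` sequence is testing retry logic: the mock fails a couple of times and then succeeds. The function `fetch_with_retry` below is a hypothetical example of code under test. The last lines show that `side_effect` can also be a function that computes the result from the arguments.

```python
from unittest import mock


def fetch_with_retry(fetch, attempts=3):
    """Hypothetical code under test: retry on ConnectionError."""
    last_error = None
    for _ in range(attempts):
        try:
            return fetch()
        except ConnectionError as exc:
            last_error = exc
    raise last_error


# Exceptions in the sequence are raised; other values are returned in order.
fetch = mock.Mock(side_effect=[
    ConnectionError('down'), ConnectionError('down'), 'payload'])
result = fetch_with_retry(fetch)

# A callable side_effect receives the call's arguments.
double = mock.Mock(side_effect=lambda x: x * 2)
doubled = double(21)
print(result, fetch.call_count, doubled)
```

Once the sequence is exhausted, a further call to the mock raises `StopIteration`.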