diff --git a/manage.py b/manage.py index 72df34a..15a64ba 100755 --- a/manage.py +++ b/manage.py @@ -75,6 +75,7 @@ def sql_files(directory, tables_only=False, where_fragment=None): :param str directory: a sub-directory containing SQL files :param bool tables_only: whether to create SQL tables instead of SQL views + :param str where_fragment: part of a WHERE clause to use when selecting the data """ files = {} diff --git a/tests/test_add.py b/tests/test_add.py index 957e76d..ad26806 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -326,7 +326,10 @@ def test_command(db, tables_only, field_counts, field_lists, tables, views, filt assert_log_records(caplog, command, expected) -@pytest.mark.parametrize('filters', [(('tender.procurementMethod', 'direct'),)]) +@pytest.mark.parametrize('filters', [ + (('tender.procurementMethod', 'direct'),), + (('tender.procurementMethod', 'direct'), ('tender.status', 'planned'),), +]) @pytest.mark.parametrize('tables_only, field_counts, field_lists, tables, views', [ (False, True, False, TABLES | SUMMARY_TABLES, SUMMARY_VIEWS), @@ -361,24 +364,34 @@ def test_command_filter(db, tables_only, field_counts, field_lists, tables, view """) for row in rows: assert row[0] == 'direct' + if len(filters) > 1: + assert len(rows) == 2 + else: + assert len(rows) == 19 # Check data_id's in the summary against the data table # This allows us to check that missing data doesn't have the filtered value - assert len(rows) == 19 rows = db.all(""" SELECT data_id FROM view_data_collection_1.release_summary """) - assert len(rows) == 19 + if len(filters) > 1: + assert len(rows) == 2 + else: + assert len(rows) == 19 data_ids = [row[0] for row in rows] rows = db.all(""" - SELECT data.id, data.data->'tender'->'procurementMethod' FROM data + SELECT + data.id, + data.data->'tender'->'procurementMethod', + data.data->'tender'->'status' + FROM data JOIN release ON release.data_id=data.id WHERE release.collection_id=1 """) for row in rows: - if row[1] == 'direct': + if row[1] == 'direct' and (len(filters) == 1 or row[2] == 'planned'): assert row[0] in data_ids else: assert row[0] not in data_ids @@ -437,7 +450,10 @@ def test_command_filter(db, tables_only, field_counts, field_lists, tables, view }, # document_documenttype_counts 5, # total_items ) - assert len(rows) == 55 + if len(filters) > 1: + assert len(rows) == 7 + else: + assert len(rows) == 55 rows = db.all(""" SELECT @@ -475,14 +491,21 @@ def test_command_filter(db, tables_only, field_counts, field_lists, tables, view 1, # total_additionalidentifiers ) - assert len(rows) == 56 + if len(filters) > 1: + assert len(rows) == 5 + else: + assert len(rows) == 56 if field_counts: # Check contents of field_counts table. rows = db.all('SELECT * FROM view_data_collection_1.field_counts') - assert len(rows) == 13077 - assert rows[0] == (1, 'release', 'awards', 19, 55, 19) + if len(filters) > 1: + assert len(rows) == 1515 + assert rows[0] == (1, 'release', 'awards', 2, 7, 2) + else: + assert len(rows) == 13077 + assert rows[0] == (1, 'release', 'awards', 19, 55, 19) if field_lists: # Check the count of keys in the field_list field for the lowest primary keys in each summary relation.