diff --git a/src/translators/config.py b/src/translators/config.py
index f27e1531..d41df340 100644
--- a/src/translators/config.py
+++ b/src/translators/config.py
@@ -4,6 +4,10 @@
 from utils.cfgreader import EnvReader, BoolVar, IntVar
 
 
+DEFAULT_LIMIT_VAR = 'DEFAULT_LIMIT'
+KEEP_RAW_ENTITY_VAR = 'KEEP_RAW_ENTITY'
+
+
 class SQLTranslatorConfig:
     """
     Provide access to SQL Translator config values.
@@ -15,9 +19,9 @@ def __init__(self, env: dict = os.environ):
 
     def default_limit(self) -> int:
         fallback_limit = 10000
-        var = IntVar('DEFAULT_LIMIT', default_value=fallback_limit)
+        var = IntVar(DEFAULT_LIMIT_VAR, default_value=fallback_limit)
         return self.store.safe_read(var)
 
     def keep_raw_entity(self) -> bool:
-        var = BoolVar('KEEP_RAW_ENTITY', False)
+        var = BoolVar(KEEP_RAW_ENTITY_VAR, False)
         return self.store.safe_read(var)
diff --git a/src/translators/sql_translator.py b/src/translators/sql_translator.py
index fe697810..d6a1a2a7 100644
--- a/src/translators/sql_translator.py
+++ b/src/translators/sql_translator.py
@@ -10,7 +10,6 @@
 import logging
 from geocoding.slf import SlfQuery
 import dateutil.parser
-import os
 import json
 from typing import List, Optional
 
@@ -287,8 +286,9 @@ def _insert_entities_of_type(self,
 
     def _insert_entity_rows(self, table_name: str, col_names: List[str],
                             rows: List[List], entities: List[dict]):
-        col_list = ', '.join(['"{}"'.format(c.lower()) for c in col_names])
-        placeholders = ','.join(['?'] * len(col_names))
+        col_list, placeholders, rows = \
+            self._build_insert_params_and_values(col_names, rows, entities)
+
         stmt = f"insert into {table_name} ({col_list}) values ({placeholders})"
         try:
             self.cursor.executemany(stmt, rows)
@@ -303,6 +303,24 @@ def _insert_entity_rows(self, table_name: str, col_names: List[str],
             )
             self._insert_original_entities(table_name, entities)
 
+    def _build_insert_params_and_values(
+            self, col_names: List[str], rows: List[List],
+            entities: List[dict]) -> (str, str, List[List]):
+        if self.config.keep_raw_entity():
+            original_entity_col_index = col_names.index(ORIGINAL_ENTITY_COL)
+            for i, r in enumerate(rows):
+                r[original_entity_col_index] = json.dumps(entities[i])
+
+        col_list = ', '.join(['"{}"'.format(c.lower()) for c in col_names])
+        placeholders = ','.join(['?'] * len(col_names))
+        return col_list, placeholders, rows
+    # NOTE. Brittle code.
+    # This code, like the rest of the insert workflow implicitly assumes
+    # 1. col_names[k] <-> rows[k] <-> entities[k]
+    # 2. original entity column always gets added upfront
+    # But we never really check anywhere (1) and (2) always hold true,
+    # so slight changes to the insert workflow could cause nasty bugs...
+
     def _should_insert_original_entities(self,
                                          insert_error: Exception) -> bool:
         raise NotImplementedError
diff --git a/src/translators/tests/original_data_scenarios.py b/src/translators/tests/original_data_scenarios.py
index 4eda9751..e3fde2dc 100644
--- a/src/translators/tests/original_data_scenarios.py
+++ b/src/translators/tests/original_data_scenarios.py
@@ -1,10 +1,12 @@
 import json
+import os
 import pytest
 import random
 from time import sleep
 from typing import Any, Callable, Generator, List
 
 from translators.base_translator import TIME_INDEX_NAME
+from translators.config import KEEP_RAW_ENTITY_VAR
 from translators.sql_translator import SQLTranslator, current_timex
 from translators.sql_translator import ORIGINAL_ENTITY_COL, ENTITY_ID_COL, \
     TYPE_PREFIX, TENANT_PREFIX
@@ -42,6 +44,11 @@ def gen_entity(entity_id: int, attr_type: str, attr_value) -> dict:
     }
 
 
+def assert_saved_original(actual_row, original_entity):
+    saved_entity = json.loads(actual_row[ORIGINAL_ENTITY_COL])
+    assert original_entity == saved_entity
+
+
 def assert_inserted_entity(actual_row, original_entity):
     assert actual_row['a_number'] == \
         maybe_value(original_entity, 'a_number', 'value')
@@ -55,8 +62,7 @@ def assert_failed_entity(actual_row, original_entity):
     assert actual_row['an_attr'] is None
     assert actual_row[ORIGINAL_ENTITY_COL] is not None
 
-    saved_entity = json.loads(actual_row[ORIGINAL_ENTITY_COL])
-    assert original_entity == saved_entity
+    assert_saved_original(actual_row, original_entity)
 
 
 def full_table_name(tenant: str) -> str:
@@ -149,3 +155,26 @@ def run_success_scenario(self):
         assert_inserted_entity(rs[0], e1)
         assert_inserted_entity(rs[1], e2)
         assert_inserted_entity(rs[2], e3)
+
+    def _do_success_scenario_with_keep_raw_on(self):
+        tenant = gen_tenant_id()
+        e1, e2, e3 = [gen_entity(k + 1, 'Number', k + 1) for k in range(3)]
+
+        self.insert_entities(tenant, [e1])
+        self.insert_entities(tenant, [e2, e3])
+
+        rs = self.fetch_rows(tenant)
+
+        assert len(rs) == 3
+        assert_saved_original(rs[0], e1)
+        assert_saved_original(rs[1], e2)
+        assert_saved_original(rs[2], e3)
+
+    def run_success_scenario_with_keep_raw_on(self):
+        os.environ[KEEP_RAW_ENTITY_VAR] = 'true'
+        try:
+            self._do_success_scenario_with_keep_raw_on()
+        except Exception:
+            del os.environ[KEEP_RAW_ENTITY_VAR]
+            raise
+        del os.environ[KEEP_RAW_ENTITY_VAR]
diff --git a/src/translators/tests/test_crate_original_data.py b/src/translators/tests/test_crate_original_data.py
index 57f063a0..26d2c451 100644
--- a/src/translators/tests/test_crate_original_data.py
+++ b/src/translators/tests/test_crate_original_data.py
@@ -36,3 +36,7 @@ def test_data_loss_scenario(with_crate):
 
 def test_success_scenario(with_crate):
     with_crate.run_success_scenario()
+
+
+def test_success_scenario_with_keep_raw_on(with_crate):
+    with_crate.run_success_scenario_with_keep_raw_on()
diff --git a/src/translators/tests/test_timescale_original_data.py b/src/translators/tests/test_timescale_original_data.py
index 9214b003..60089469 100644
--- a/src/translators/tests/test_timescale_original_data.py
+++ b/src/translators/tests/test_timescale_original_data.py
@@ -39,3 +39,7 @@ def test_data_loss_scenario(with_timescale):
 
 def test_success_scenario(with_timescale):
     with_timescale.run_success_scenario()
+
+
+def test_success_scenario_with_keep_raw_on(with_timescale):
+    with_timescale.run_success_scenario_with_keep_raw_on()