diff --git a/spoonbill/flatten.py b/spoonbill/flatten.py index a884a45..9646ad9 100644 --- a/spoonbill/flatten.py +++ b/spoonbill/flatten.py @@ -29,6 +29,7 @@ class TableFlattenConfig: only: List[str] = field(default_factory=list) repeat: List[str] = field(default_factory=list) unnest: List[str] = field(default_factory=list) + only: List[str] = field(default_factory=list) name: str = "" @@ -58,7 +59,7 @@ class Flattener: In order to export data correctly Flattener requires previously analyzed tables data. During the process flattener could add columns not based on schema analysis, such as `itemsCount`. - In every generated row, depending on table type, flattener will always few add augenerated columns. + In every generated row, depending on table type, flattener will always few add autogenerated columns. For root table: * rowID * id @@ -100,12 +101,15 @@ def _init_table_cache(self, tables, table): def _init_options(self, tables): for table in tables.values(): + name = table.name count = self.options.count options = self.options.selection[name] unnest = options.unnest split = options.split repeat = options.repeat + only = options.only + if count: for array in table.arrays: parts = array.split("/") @@ -142,6 +146,19 @@ def _init_options(self, tables): child_table = self.tables.get(c_name) child_table.columns[col_id] = col child_table.titles[col_id] = title + if only: + if split: + table.columns = {c_id: c for c_id, c in table.columns.items() if c_id in only} + else: + table.combined_columns = {c_id: c for c_id, c in table.combined_columns.items() if c_id in only} + + def _only(self, table, only, split): + table.types = {c_id: c for c_id, c in table.types.items() if c_id in only} + + if split: + table.columns = {c_id: c for c_id, c in table.columns.items() if c_id in only} + return + table.combined_columns = {c_id: c for c_id, c in table.combined_columns.items() if c_id in only} def _init(self): # init cache and filter only selected tables @@ -149,8 +166,11 @@ def _init(self): for name, table in self.tables.items(): if name not in self.options.selection: continue + options = self.options.selection[name] + split = options.split + if options.only: + self._only(table, options.only, split) self._init_table_cache(tables, table) - split = self.options.selection[name].split if split: for c_name in table.child_tables: if c_name in self.options.exclude: diff --git a/tests/test_flatten.py b/tests/test_flatten.py index e929778..4cc993f 100644 --- a/tests/test_flatten.py +++ b/tests/test_flatten.py @@ -120,11 +120,20 @@ def test_flatten_with_exclude(spec_analyzed, releases): def test_flatten_with_only(spec_analyzed, releases): - options = FlattenOptions(**{"selection": {"tenders": {"split": True}}, "exclude": "tender_items"}) + options = FlattenOptions(**{"selection": {"tenders": {"split": True, "only": ["/tender/id"]}}}) flattener = Flattener(options, spec_analyzed.tables) all_rows = defaultdict(list) for count, flat in flattener.flatten(releases): for name, rows in flat.items(): all_rows[name].extend(rows) + for row in all_rows["tenders"]: + assert row == ["/tender/id"] - assert "tender_items" not in all_rows + options = FlattenOptions(**{"selection": {"tenders": {"split": False, "only": ["/tender/id"]}}}) + flattener = Flattener(options, spec_analyzed.tables) + all_rows = defaultdict(list) + for count, flat in flattener.flatten(releases): + for name, rows in flat.items(): + all_rows[name].extend(rows) + for row in all_rows["tenders"]: + assert row == ["/tender/id"]