diff --git a/avocado/export/_base.py b/avocado/export/_base.py index abc5e1c..618c45f 100644 --- a/avocado/export/_base.py +++ b/avocado/export/_base.py @@ -58,15 +58,25 @@ def _format_row(self, row, **kwargs): yield formatter(values, preferred_formats=self.preferred_formats, **kwargs) - def read(self, iterable, force_distinct=True, *args, **kwargs): + def read(self, iterable, force_distinct=True, offset=None, limit=None, + *args, **kwargs): """Takes an iterable that produces rows to be formatted. If `force_distinct` is true, rows will be filtered based on the slice of the row that is going to be formatted. + + If `offset` is defined, only rows that are produced after the offset + index are returned. + + If `limit` is defined, only rows up the """ + emitted = 0 unique_rows = set() - for row in iterable: + for i, row in enumerate(iterable): + if limit is not None and emitted >= limit: + break + _row = row[:self.row_length] if force_distinct: @@ -77,7 +87,9 @@ def read(self, iterable, force_distinct=True, *args, **kwargs): unique_rows.add(_row_hash) - yield self._format_row(_row, **kwargs) + if offset is None or i >= offset: + emitted += 1 + yield self._format_row(_row, **kwargs) def write(self, iterable, *args, **kwargs): for row_gen in self.read(iterable, *args, **kwargs): diff --git a/tests/cases/exporting/tests.py b/tests/cases/exporting/tests.py index 6150444..e8bffad 100644 --- a/tests/cases/exporting/tests.py +++ b/tests/cases/exporting/tests.py @@ -17,27 +17,30 @@ class ExportTestCase(TestCase): def setUp(self): management.call_command('avocado', 'init', 'tests', quiet=True) - - def test_view(self): - salary_field = DataField.objects.get_by_natural_key('tests', 'title', 'salary') - salary_concept = DataConcept() - salary_concept.save() - DataConceptField(concept=salary_concept, field=salary_field, order=1).save() + salary_concept = DataField.objects.get(field_name='salary').concepts.all()[0] view = DataView(json={'ordering': [[salary_concept.pk, 'desc']]}) - query = view.apply(tree=models.Employee).raw() - + self.query = view.apply(tree=models.Employee).raw() # Ick.. - exporter = export.CSVExporter(view) - exporter.params.insert(0, (RawFormatter(keys=['pk']), 1)) - exporter.row_length += 1 + self.exporter = export.BaseExporter(view) + self.exporter.params.insert(0, (RawFormatter(keys=['pk']), 1)) + self.exporter.row_length += 1 - buff = exporter.write(query) - buff.seek(0) + def test(self): + rows = list(self.exporter.write(self.query)) + self.assertEqual([r[0] for r in rows], [2, 4, 6, 1, 3, 5]) + + def test_offset(self): + rows = list(self.exporter.write(self.query, offset=2)) + self.assertEqual([r[0] for r in rows], [6, 1, 3, 5]) + + def test_limit(self): + rows = list(self.exporter.write(self.query, limit=2)) + self.assertEqual([r[0] for r in rows], [2, 4]) - lines = buff.read().splitlines() - # Skip the header - self.assertEqual([int(x) for x in lines[1:]], [2, 4, 6, 1, 3, 5]) + def test_limit_offset(self): + rows = list(self.exporter.write(self.query, offset=2, limit=2)) + self.assertEqual([r[0] for r in rows], [6, 1]) class FileExportTestCase(TestCase):