From ef182f79be294431164275c37b76f4fa929fdb99 Mon Sep 17 00:00:00 2001
From: Marcin Rudolf
Date: Wed, 17 Apr 2024 22:45:00 +0200
Subject: [PATCH] adds options to write headers, change delimiter

---
 dlt/common/data_writers/writers.py             | 32 ++++++++++++++++---
 .../docs/dlt-ecosystem/file-formats/csv.md     | 20 ++++++++++++
 .../load/pipeline/test_filesystem_pipeline.py  | 29 +++++++++++++++++
 3 files changed, 76 insertions(+), 5 deletions(-)

diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py
index 2b14d8cd72..850a27e8bc 100644
--- a/dlt/common/data_writers/writers.py
+++ b/dlt/common/data_writers/writers.py
@@ -381,15 +381,27 @@ def writer_spec(cls) -> FileWriterSpec:
         )
 
 
+@configspec
+class CsvDataWriterConfiguration(BaseConfiguration):
+    delimiter: str = ","
+    include_header: bool = True
+
+    __section__: ClassVar[str] = known_sections.DATA_WRITER
+
+
 class CsvWriter(DataWriter):
+    @with_config(spec=CsvDataWriterConfiguration)
     def __init__(
         self,
         f: IO[Any],
         caps: DestinationCapabilitiesContext = None,
+        *,
         delimiter: str = ",",
+        include_header: bool = True,
         bytes_encoding: str = "utf-8",
     ) -> None:
         super().__init__(f, caps)
+        self.include_header = include_header
         self.delimiter = delimiter
         self.writer: csv.DictWriter[str] = None
         self.bytes_encoding = bytes_encoding
@@ -404,7 +416,8 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None:
             delimiter=self.delimiter,
             quoting=csv.QUOTE_NONNUMERIC,
         )
-        self.writer.writeheader()
+        if self.include_header:
+            self.writer.writeheader()
         # find row items that are of the complex type (could be abstracted out for use in other writers?)
         self.complex_indices = [
             i for i, field in columns_schema.items() if field["data_type"] == "complex"
@@ -499,11 +512,19 @@ def writer_spec(cls) -> FileWriterSpec:
 
 
 class ArrowToCsvWriter(DataWriter):
+    @with_config(spec=CsvDataWriterConfiguration)
     def __init__(
-        self, f: IO[Any], caps: DestinationCapabilitiesContext = None, delimiter: bytes = b","
+        self,
+        f: IO[Any],
+        caps: DestinationCapabilitiesContext = None,
+        *,
+        delimiter: str = ",",
+        include_header: bool = True,
     ) -> None:
         super().__init__(f, caps)
         self.delimiter = delimiter
+        self._delimiter_b = delimiter.encode("ascii")
+        self.include_header = include_header
         self.writer: Any = None
 
     def write_header(self, columns_schema: TTableSchemaColumns) -> None:
@@ -521,7 +542,8 @@ def write_data(self, rows: Sequence[Any]) -> None:
                         self._f,
                         row.schema,
                         write_options=pyarrow.csv.WriteOptions(
-                            include_header=True, delimiter=self.delimiter
+                            include_header=self.include_header,
+                            delimiter=self._delimiter_b,
                         ),
                     )
                     self._first_schema = row.schema
@@ -573,10 +595,10 @@ def write_data(self, rows: Sequence[Any]) -> None:
             self.items_count += row.num_rows
 
     def write_footer(self) -> None:
-        if self.writer is None:
+        if self.writer is None and self.include_header:
             # write empty file
             self._f.write(
-                self.delimiter.join(
+                self._delimiter_b.join(
                     [
                         b'"' + col["name"].encode("utf-8") + b'"'
                         for col in self._columns_schema.values()
diff --git a/docs/website/docs/dlt-ecosystem/file-formats/csv.md b/docs/website/docs/dlt-ecosystem/file-formats/csv.md
index 4eb94b5ff0..dcd9e251f5 100644
--- a/docs/website/docs/dlt-ecosystem/file-formats/csv.md
+++ b/docs/website/docs/dlt-ecosystem/file-formats/csv.md
@@ -32,6 +32,26 @@ info = pipeline.run(some_source(), loader_file_format="csv")
 * UNIX new lines are used
 * dates are represented as ISO 8601
 
+### Change settings
+You can change the basic **csv** settings; this may be handy when working with the **filesystem**
+destination. Other destinations are tested with the standard settings:
+
+* `delimiter`: change the delimiting character (default: `,`)
+* `include_header`: include the header row (default: `True`)
+
+```toml
+[normalize.data_writer]
+delimiter="|"
+include_header=false
+```
+
+Or using environment variables:
+
+```sh
+NORMALIZE__DATA_WRITER__DELIMITER=|
+NORMALIZE__DATA_WRITER__INCLUDE_HEADER=False
+```
+
 ## Limitations
 
 **arrow writer**
diff --git a/tests/load/pipeline/test_filesystem_pipeline.py b/tests/load/pipeline/test_filesystem_pipeline.py
index 20f326b160..b02525f4a4 100644
--- a/tests/load/pipeline/test_filesystem_pipeline.py
+++ b/tests/load/pipeline/test_filesystem_pipeline.py
@@ -145,6 +145,35 @@ def test_pipeline_csv_filesystem_destination(item_type: TestDataItemFormat) -> N
     assert len(csv_rows) == 3
 
 
+@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS)
+def test_csv_options(item_type: TestDataItemFormat) -> None:
+    os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True"
+    os.environ["RESTORE_FROM_DESTINATION"] = "False"
+    # set delimiter and disable headers
+    os.environ["NORMALIZE__DATA_WRITER__DELIMITER"] = "|"
+    os.environ["NORMALIZE__DATA_WRITER__INCLUDE_HEADER"] = "False"
+    # store locally
+    os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = "file://_storage"
+    pipeline = dlt.pipeline(
+        pipeline_name="parquet_test_" + uniq_id(),
+        destination="filesystem",
+        dataset_name="parquet_test_" + uniq_id(),
+    )
+
+    item, rows, _ = arrow_table_all_data_types(item_type, include_json=False, include_time=True)
+    info = pipeline.run(item, table_name="table", loader_file_format="csv")
+    info.raise_on_failed_jobs()
+    job = info.load_packages[0].jobs["completed_jobs"][0].file_path
+    assert job.endswith("csv")
+    with open(job, "r", encoding="utf-8", newline="") as f:
+        csv_rows = list(csv.reader(f, dialect=csv.unix_dialect, delimiter="|"))
+    # no header
+    assert len(csv_rows) == 3
+    # object csv adds dlt columns
+    dlt_columns = 2 if item_type == "object" else 0
+    assert len(rows[0]) + dlt_columns == len(csv_rows[0])
+
+
 def test_pipeline_parquet_filesystem_destination() -> None:
     import pyarrow.parquet as pq  # Module is evaluated by other tests
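
---

For reviewers, a minimal end-to-end sketch of the new options (not part of the patch). It uses the same environment-variable configuration exercised in `test_csv_options`; the pipeline name, dataset name, sample rows, and the assumption of the default filesystem layout (`{dataset_name}/{table_name}/*.csv` under the bucket) are illustrative:

```python
import csv
import glob
import os

import dlt

# configure the csv writer as described in the csv.md section above
os.environ["NORMALIZE__DATA_WRITER__DELIMITER"] = "|"
os.environ["NORMALIZE__DATA_WRITER__INCLUDE_HEADER"] = "False"
# write uncompressed files to a local "bucket" so they are easy to inspect
os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True"
os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = "file://_storage"

pipeline = dlt.pipeline(
    pipeline_name="csv_options_demo",  # illustrative name
    destination="filesystem",
    dataset_name="csv_options_demo_data",
)
pipeline.run(
    [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}],
    table_name="items",
    loader_file_format="csv",
)

# read a produced file back: pipe-delimited, with no header row;
# note that object rows gain the _dlt_load_id and _dlt_id columns,
# which is what the test accounts for with `dlt_columns`
for path in glob.glob("_storage/csv_options_demo_data/items/*.csv"):
    with open(path, "r", encoding="utf-8", newline="") as f:
        print(list(csv.reader(f, dialect=csv.unix_dialect, delimiter="|")))
```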