adds options to write headers, change delimiter #1239

Merged · 1 commit · Apr 18, 2024
32 changes: 27 additions & 5 deletions dlt/common/data_writers/writers.py
@@ -381,15 +381,27 @@ def writer_spec(cls) -> FileWriterSpec:
         )


+@configspec
+class CsvDataWriterConfiguration(BaseConfiguration):
+    delimiter: str = ","
+    include_header: bool = True
+
+    __section__: ClassVar[str] = known_sections.DATA_WRITER
+
+
 class CsvWriter(DataWriter):
+    @with_config(spec=CsvDataWriterConfiguration)
     def __init__(
         self,
         f: IO[Any],
         caps: DestinationCapabilitiesContext = None,
+        *,
         delimiter: str = ",",
+        include_header: bool = True,
         bytes_encoding: str = "utf-8",
     ) -> None:
         super().__init__(f, caps)
+        self.include_header = include_header
         self.delimiter = delimiter
         self.writer: csv.DictWriter[str] = None
         self.bytes_encoding = bytes_encoding
@@ -404,7 +416,8 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None:
             delimiter=self.delimiter,
             quoting=csv.QUOTE_NONNUMERIC,
         )
-        self.writer.writeheader()
+        if self.include_header:
+            self.writer.writeheader()
         # find row items that are of the complex type (could be abstracted out for use in other writers?)
         self.complex_indices = [
             i for i, field in columns_schema.items() if field["data_type"] == "complex"
@@ -499,11 +512,19 @@ def writer_spec(cls) -> FileWriterSpec:


 class ArrowToCsvWriter(DataWriter):
+    @with_config(spec=CsvDataWriterConfiguration)
     def __init__(
-        self, f: IO[Any], caps: DestinationCapabilitiesContext = None, delimiter: bytes = b","
+        self,
+        f: IO[Any],
+        caps: DestinationCapabilitiesContext = None,
+        *,
+        delimiter: str = ",",
+        include_header: bool = True,
     ) -> None:
         super().__init__(f, caps)
         self.delimiter = delimiter
+        self._delimiter_b = delimiter.encode("ascii")
+        self.include_header = include_header
         self.writer: Any = None

     def write_header(self, columns_schema: TTableSchemaColumns) -> None:
@@ -521,7 +542,8 @@ def write_data(self, rows: Sequence[Any]) -> None:
                         self._f,
                         row.schema,
                         write_options=pyarrow.csv.WriteOptions(
-                            include_header=True, delimiter=self.delimiter
+                            include_header=self.include_header,
+                            delimiter=self._delimiter_b,
                         ),
                     )
                     self._first_schema = row.schema
@@ -573,10 +595,10 @@ def write_data(self, rows: Sequence[Any]) -> None:
             self.items_count += row.num_rows

     def write_footer(self) -> None:
-        if self.writer is None:
+        if self.writer is None and self.include_header:
             # write empty file
             self._f.write(
-                self.delimiter.join(
+                self._delimiter_b.join(
                     [
                         b'"' + col["name"].encode("utf-8") + b'"'
                         for col in self._columns_schema.values()
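For a concrete sense of what the two new options do, here is a minimal stdlib-only sketch of the header gating that `CsvWriter` now applies; `write_rows` is an illustrative helper, not dlt code:

```py
import csv
import io
from typing import Any, Dict, List

def write_rows(
    rows: List[Dict[str, Any]],
    columns: List[str],
    *,
    delimiter: str = ",",
    include_header: bool = True,
) -> str:
    buf = io.StringIO()
    writer = csv.DictWriter(
        buf,
        fieldnames=columns,
        extrasaction="ignore",
        dialect=csv.unix_dialect,
        delimiter=delimiter,
        quoting=csv.QUOTE_NONNUMERIC,
    )
    # the header row is now optional, as in CsvWriter.write_header
    if include_header:
        writer.writeheader()
    writer.writerows(rows)
    return buf.getvalue()

# default settings: header row plus one data row
assert write_rows([{"a": 1, "b": "x"}], ["a", "b"]) == '"a","b"\n1,"x"\n'
# pipe delimiter, no header: data row only
assert write_rows([{"a": 1, "b": "x"}], ["a", "b"], delimiter="|", include_header=False) == '1|"x"\n'
```

The empty-file fallback in `ArrowToCsvWriter.write_footer` follows the same rule: a header-only file is emitted only when `include_header` is left on.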
20 changes: 20 additions & 0 deletions docs/website/docs/dlt-ecosystem/file-formats/csv.md
@@ -32,6 +32,26 @@ info = pipeline.run(some_source(), loader_file_format="csv")
* UNIX new lines are used
* dates are represented as ISO 8601

### Change settings
You can change the basic **csv** settings; this may be handy when working with the **filesystem** destination. Other destinations are tested with the standard settings:

* `delimiter`: change the delimiting character (default: `,`)
* `include_header`: include the header row (default: `True`)

```toml
[normalize.data_writer]
delimiter="|"
include_header=false
```

Or using environment variables:

```sh
NORMALIZE__DATA_WRITER__DELIMITER=|
NORMALIZE__DATA_WRITER__INCLUDE_HEADER=False
```
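
The same options can also be set in-process before the pipeline runs. A sketch, not part of this PR's docs page — the resource and names are illustrative, and the bucket URL mirrors the repository tests:

```py
import os

import dlt

# equivalent to the TOML section above
os.environ["NORMALIZE__DATA_WRITER__DELIMITER"] = "|"
os.environ["NORMALIZE__DATA_WRITER__INCLUDE_HEADER"] = "False"
# write locally, as the repository tests do
os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = "file://_storage"

@dlt.resource
def items():
    yield from [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}]

pipeline = dlt.pipeline(
    pipeline_name="csv_options_demo", destination="filesystem", dataset_name="csv_demo"
)
info = pipeline.run(items(), loader_file_format="csv")
print(info)
```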

## Limitations
**arrow writer**

29 changes: 29 additions & 0 deletions tests/load/pipeline/test_filesystem_pipeline.py
@@ -145,6 +145,35 @@ def test_pipeline_csv_filesystem_destination(item_type: TestDataItemFormat) -> None:
    assert len(csv_rows) == 3


@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS)
def test_csv_options(item_type: TestDataItemFormat) -> None:
    os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True"
    os.environ["RESTORE_FROM_DESTINATION"] = "False"
    # set delimiter and disable headers
    os.environ["NORMALIZE__DATA_WRITER__DELIMITER"] = "|"
    os.environ["NORMALIZE__DATA_WRITER__INCLUDE_HEADER"] = "False"
    # store locally
    os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = "file://_storage"
    pipeline = dlt.pipeline(
        pipeline_name="parquet_test_" + uniq_id(),
        destination="filesystem",
        dataset_name="parquet_test_" + uniq_id(),
    )

    item, rows, _ = arrow_table_all_data_types(item_type, include_json=False, include_time=True)
    info = pipeline.run(item, table_name="table", loader_file_format="csv")
    info.raise_on_failed_jobs()
    job = info.load_packages[0].jobs["completed_jobs"][0].file_path
    assert job.endswith("csv")
    with open(job, "r", encoding="utf-8", newline="") as f:
        csv_rows = list(csv.reader(f, dialect=csv.unix_dialect, delimiter="|"))
    # no header
    assert len(csv_rows) == 3
    # object csv adds dlt columns
    dlt_columns = 2 if item_type == "object" else 0
    assert len(rows[0]) + dlt_columns == len(csv_rows[0])


def test_pipeline_parquet_filesystem_destination() -> None:
    import pyarrow.parquet as pq  # Module is evaluated by other tests
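One practical consequence of `include_header=False`, exercised by the test above: readers of the produced files must supply the column names themselves, e.g. from the dlt schema. A hedged sketch — the file path and column names below are hypothetical:

```py
import csv

# read back a headerless, pipe-delimited file such as the one the test produces
path = "_storage/csv_demo/table/1234.csv"  # hypothetical load-package path
with open(path, "r", encoding="utf-8", newline="") as f:
    reader = csv.DictReader(f, fieldnames=["id", "name"], delimiter="|")
    for row in reader:
        print(row["id"], row["name"])
```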