diff --git a/dataclass_io/_lib/assertions.py b/dataclass_io/_lib/assertions.py index 36dac4c..3891110 100644 --- a/dataclass_io/_lib/assertions.py +++ b/dataclass_io/_lib/assertions.py @@ -4,6 +4,7 @@ from os import access from os import stat from pathlib import Path +from typing import Any from dataclass_io._lib.dataclass_extensions import DataclassInstance from dataclass_io._lib.dataclass_extensions import fieldnames @@ -98,8 +99,8 @@ def assert_file_is_appendable( def assert_file_header_matches_dataclass( file: Path | ReadableFileHandle, dataclass_type: type[DataclassInstance], - delimiter: str, comment_prefix: str, + **kwargs: Any, ) -> None: """ Check that the specified file has a header and its fields match those of the provided dataclass. @@ -107,11 +108,11 @@ def assert_file_header_matches_dataclass( header: FileHeader | None if isinstance(file, Path): with file.open("r") as fin: - header = get_header(fin, delimiter=delimiter, comment_prefix=comment_prefix) + header = get_header(reader=fin, comment_prefix=comment_prefix, **kwargs) else: pos = file.tell() try: - header = get_header(file, delimiter=delimiter, comment_prefix=comment_prefix) + header = get_header(reader=file, comment_prefix=comment_prefix, **kwargs) finally: file.seek(pos) diff --git a/dataclass_io/_lib/file.py b/dataclass_io/_lib/file.py index 1d9ac8c..0d7251a 100644 --- a/dataclass_io/_lib/file.py +++ b/dataclass_io/_lib/file.py @@ -1,3 +1,4 @@ +from csv import DictReader from dataclasses import dataclass from enum import Enum from enum import unique @@ -66,8 +67,8 @@ class FileHeader: def get_header( reader: ReadableFileHandle, - delimiter: str, comment_prefix: str, + **kwargs: Any, ) -> Optional[FileHeader]: """ Read the header from an open file. @@ -85,6 +86,7 @@ def get_header( Args: reader: An open, readable file handle. comment_char: The character which indicates the start of a comment line. + **kwargs: Additional keyword arguments to pass to `csv.DictReader`. Returns: A `FileHeader` containing the field names and any preceding lines. @@ -103,6 +105,9 @@ def get_header( else: return None - fieldnames = line.strip().split(delimiter) + # msto#19 Read header fields + # Use csv.DictReader because RFC4180 is tricky to implement correctly + header_reader = DictReader([line], **kwargs) + fieldnames = list(header_reader.fieldnames) return FileHeader(preface=preface, fieldnames=fieldnames) diff --git a/dataclass_io/reader.py b/dataclass_io/reader.py index 44ed599..d953694 100644 --- a/dataclass_io/reader.py +++ b/dataclass_io/reader.py @@ -25,9 +25,9 @@ def __init__( self, fin: ReadableFileHandle, dataclass_type: type[DataclassInstance], - delimiter: str = "\t", comment_prefix: str = "#", - **kwds: Any, + delimiter: str = "\t", + **kwargs: Any, ) -> None: """ Args: @@ -35,6 +35,7 @@ def __init__( dataclass_type: Dataclass type. delimiter: The input file delimiter. comment_prefix: The prefix for any comment/preface rows preceding the header row. + quoting: Quoting style (enum value from Python csv package). dataclass_type: Dataclass type. Raises: @@ -46,17 +47,22 @@ def __init__( dataclass_type=dataclass_type, delimiter=delimiter, comment_prefix=comment_prefix, + **kwargs, ) self._dataclass_type = dataclass_type self._fin = fin self._header = get_header( - reader=self._fin, delimiter=delimiter, comment_prefix=comment_prefix + reader=self._fin, + delimiter=delimiter, + comment_prefix=comment_prefix, + **kwargs, ) self._reader = DictReader( f=self._fin, fieldnames=fieldnames(dataclass_type), delimiter=delimiter, + **kwargs, ) def __iter__(self) -> "DataclassReader": diff --git a/dataclass_io/writer.py b/dataclass_io/writer.py index 1d71b9a..9f76f5f 100644 --- a/dataclass_io/writer.py +++ b/dataclass_io/writer.py @@ -31,7 +31,7 @@ def __init__( include_fields: list[str] | None = None, exclude_fields: list[str] | None = None, write_header: bool = True, - **kwds: Any, + **kwargs: Any, ) -> None: """ Args: @@ -65,6 +65,7 @@ def __init__( f=self._fout, fieldnames=self._fieldnames, delimiter=delimiter, + **kwargs, ) # TODO: permit writing comment/preface rows before header @@ -124,9 +125,9 @@ def open( dataclass_type: type[DataclassInstance], mode: str = "write", overwrite: bool = True, - delimiter: str = "\t", comment_prefix: str = "#", - **kwds: Any, + delimiter: str = "\t", + **kwargs: Any, ) -> Iterator["DataclassWriter"]: """ Open a new `DataclassWriter` from a file path. @@ -142,11 +143,11 @@ def open( `exclude_fields`. overwrite: If `True`, and `mode="write"`, the file specified at `path` will be overwritten if it exists. - delimiter: The output file delimiter. comment_prefix: The prefix for any comment/preface rows preceding the header row. (This argument is ignored when `mode="write"`. It is used when `mode="append"` to validate that the existing file's header matches the specified dataclass.) - **kwds: Additional keyword arguments to be passed to the `DataclassWriter` constructor. + delimiter: The output file delimiter. + **kwds: Additional keyword arguments to be passed to `csv.DictWriter`. Yields: A `DataclassWriter` instance. @@ -178,6 +179,7 @@ def open( dataclass_type=dataclass_type, delimiter=delimiter, comment_prefix=comment_prefix, + **kwargs, ) fout = filepath.open(write_mode.abbreviation) @@ -186,7 +188,7 @@ def open( fout=fout, dataclass_type=dataclass_type, write_header=(write_mode is WriteMode.WRITE), # Skip header when appending - **kwds, + **kwargs, ) finally: fout.close() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1ad8db4 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,10 @@ +from pathlib import Path + +import pytest + + +@pytest.fixture(scope="session") +def datadir() -> Path: + """Path to the test data directory.""" + + return Path(__file__).parent / "data" diff --git a/tests/data/reader_should_parse_quotes.tsv b/tests/data/reader_should_parse_quotes.tsv new file mode 100644 index 0000000..acad7e4 --- /dev/null +++ b/tests/data/reader_should_parse_quotes.tsv @@ -0,0 +1,3 @@ +"id" "title" +"fake" "A fake object" +"also_fake" "Another fake object" diff --git a/tests/test_reader.py b/tests/test_reader.py index 6e1309e..9d4228f 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -35,3 +35,23 @@ class FakeDataclass: assert isinstance(rows[0], FakeDataclass) assert rows[0].foo == "abc" assert rows[0].bar == 1 + + +def test_reader_should_parse_quotes(datadir: Path) -> None: + """ + Test that having quotes around column names in header row doesn't break anything + https://github.com/msto/dataclass_io/issues/19 + """ + fpath = datadir / "reader_should_parse_quotes.tsv" + + @dataclass + class FakeDataclass: + id: str + title: str + + # Parse CSV using DataclassReader + with DataclassReader.open(fpath, FakeDataclass) as reader: + records = [record for record in reader] + + assert records[0] == FakeDataclass(id="fake", title="A fake object") + assert records[1] == FakeDataclass(id="also_fake", title="Another fake object")