diff --git a/README.rst b/README.rst index e72ff98..1ab5cb6 100644 --- a/README.rst +++ b/README.rst @@ -349,6 +349,13 @@ Running Vlads Programatically field. Optional, defaults to the class variable `row_validators` if set, otherwise `[]`, which does not perform any row-level validation. + :``fieldnames=None``: + Sequence of field names to be passed through to the underlying + `csv.DictReader` instance. If provided, the reader will use these field + names instead of inferring them from the input CSV's first row. Intended + only for use with CSVs that do not have header rows. Optional, defaults + to `None`. + :``delimiter=','``: The delimiter used within your csv source. Optional, defaults to `,`. diff --git a/tests/test_vlads.py b/tests/test_vlads.py index 0d0c881..4579b97 100644 --- a/tests/test_vlads.py +++ b/tests/test_vlads.py @@ -90,6 +90,32 @@ class TestVlad(Vlad): assert not TestVlad(source=source).validate() +def test_explicit_fieldnames(): + source = String("Dracula,Vampire") + + class TestVlad(Vlad): + validators = { + "Name": [UniqueValidator()], + "Status": [SetValidator(["Vampire", "Not A Vampire"])], + } + fieldnames = ["Name", "Status"] + + assert TestVlad(source=source).validate() + + +def test_explicit_fieldnames_conflict_fails(): + source = LocalFile("vladiate/examples/vampires.csv") + + class TestVlad(Vlad): + validators = { + "Name": [UniqueValidator()], + "Status": [SetValidator(["Vampire", "Not A Vampire"])], + } + fieldnames = ["Name", "Status"] + + assert not TestVlad(source=source).validate() + + def test_fails_validation(): source = LocalFile("vladiate/examples/vampires.csv") diff --git a/vladiate/vlad.py b/vladiate/vlad.py index c6f8a06..ad2e6f3 100644 --- a/vladiate/vlad.py +++ b/vladiate/vlad.py @@ -17,6 +17,7 @@ def __init__( file_validation_failure_threshold=None, quiet=False, row_validators=[], + fieldnames=None, ): self.logger = logs.logger self.failures = defaultdict(lambda: defaultdict(list)) @@ -26,6 +27,7 @@ def __init__( self.source = source self.validators = validators or getattr(self, "validators", {}) self.row_validators = row_validators or getattr(self, "row_validators", []) + self.fieldnames = fieldnames or getattr(self, "fieldnames", None) self.delimiter = delimiter or getattr(self, "delimiter", ",") self.line_count = 0 self.ignore_missing_validators = ignore_missing_validators @@ -124,7 +126,9 @@ def _log_missing(self, missing_items): ) def _get_total_lines(self): - reader = csv.DictReader(self.source.open(), delimiter=self.delimiter) + reader = csv.DictReader( + self.source.open(), delimiter=self.delimiter, fieldnames=self.fieldnames + ) self.total_lines = sum(1 for _ in reader) return self.total_lines @@ -132,7 +136,9 @@ def validate(self): self.logger.info( "\nValidating {}(source={})".format(self.__class__.__name__, self.source) ) - reader = csv.DictReader(self.source.open(), delimiter=self.delimiter) + reader = csv.DictReader( + self.source.open(), delimiter=self.delimiter, fieldnames=self.fieldnames + ) if not reader.fieldnames: self.logger.info(