Skip to content

Commit

Permalink
v0.0.26 Only cast row values that are being tested (#36)
Browse files Browse the repository at this point in the history
* Only cast row values that are being tested

* v0.0.26

* lint
  • Loading branch information
akariv committed Oct 17, 2018
1 parent 4d1e494 commit 1df6436
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 17 deletions.
2 changes: 1 addition & 1 deletion dataflows/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.0.25
0.0.26
20 changes: 7 additions & 13 deletions dataflows/base/schema_validator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import logging

from datapackage import Resource
from tableschema import Schema
from tableschema.exceptions import CastError
Expand All @@ -19,21 +17,17 @@ def __init__(self, resource_name, row, index, cast_error):
self.cast_error = cast_error


def schema_validator(resource: Resource, iterator):
def schema_validator(resource: Resource, iterator, field_names=None):
schema: Schema = resource.schema
field_names = [f.name for f in schema.fields]
warned_fields = set()
if field_names is None:
field_names = [f.name for f in schema.fields]
schema_fields = [f for f in schema.fields if f.name in field_names]
for i, row in enumerate(iterator):
to_cast = [row.get(f) for f in field_names]

try:
casted = schema.cast_row(to_cast)
row = dict(zip(field_names, casted))
for f in schema_fields:
row[f.name] = f.cast_value(row[f.name])
except CastError as e:
raise ValidationError(resource.name, row, i, e)

for k in set(row.keys()) - set(field_names):
if k not in warned_fields:
warned_fields.add(k)
logging.warning('Encountered field %r, not in schema', k)

yield row
7 changes: 6 additions & 1 deletion dataflows/processors/set_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@ def __init__(self, name, resources=-1, **options):
self.name = re.compile(f'^{name}$')
self.options = options
self.resources = resources
self.field_names = []

def process_resources(self, resources):
for res in resources:
if self.matcher.match(res.res.name):
yield schema_validator(res.res, res)
if len(self.field_names) > 0:
yield schema_validator(res.res, res, field_names=self.field_names)
else:
yield res
else:
yield res

Expand All @@ -28,6 +32,7 @@ def process_datapackage(self, dp):
for field in res['schema']['fields']:
if self.name.match(field['name']):
field.update(self.options)
self.field_names.append(field['name'])
added = True
assert added, 'Failed to find field {} in schema'.format(self.name)
return dp
4 changes: 2 additions & 2 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def test_example_7():
def add_is_guitarist_column(package):

# Add a new field to the first resource
package.pkg.resources[0].descriptor['schema']['fields'].append(dict(
package.pkg.descriptor['resources'][0]['schema']['fields'].append(dict(
name='is_guitarist',
type='boolean'
))
Expand Down Expand Up @@ -171,7 +171,7 @@ def test_example_75():
def add_is_guitarist_column_to_schema(package):

# Add a new field to the first resource
package.pkg.resources[0].descriptor['schema']['fields'].append(dict(
package.pkg.descriptor['resources'][0]['schema']['fields'].append(dict(
name='is_guitarist',
type='boolean'
))
Expand Down

0 comments on commit 1df6436

Please sign in to comment.