Skip to content

Commit

Permalink
Clean up CSV util primary key handling. Fixes #534
Browse files Browse the repository at this point in the history
  • Loading branch information
coleifer committed Feb 24, 2015
1 parent 189933c commit 3b29611
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
13 changes: 9 additions & 4 deletions playhouse/csv_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ class Loader(_CSVReader):
"""
def __init__(self, db_or_model, file_or_name, fields=None,
field_names=None, has_header=True, sample_size=10,
converter=None, db_table=None, **reader_kwargs):
converter=None, db_table=None, pk_in_csv=False,
**reader_kwargs):
self.file_or_name = file_or_name
self.fields = fields
self.field_names = field_names
Expand All @@ -208,8 +209,9 @@ def __init__(self, db_or_model, file_or_name, fields=None,
self.db_table = self.model._meta.db_table
self.fields = self.model._meta.get_fields()
self.field_names = self.model._meta.get_field_names()
# If using an auto-incrementing primary key, ignore it.
if self.model._meta.auto_increment:
# If using an auto-incrementing primary key, ignore it unless we
# are told the primary key is included in the CSV.
if self.model._meta.auto_increment and not pk_in_csv:
self.fields = self.fields[1:]
self.field_names = self.field_names[1:]

Expand Down Expand Up @@ -238,7 +240,10 @@ def get_model_class(self, field_names, fields):
if self.model:
return self.model
attrs = dict(zip(field_names, fields))
attrs['_auto_pk'] = PrimaryKeyField()
if 'id' not in attrs:
attrs['_auto_pk'] = PrimaryKeyField()
elif isinstance(attrs['id'], IntegerField):
attrs['id'] = PrimaryKeyField()
klass = type(self.db_table.title(), (Model,), attrs)
klass._meta.database = self.database
klass._meta.db_table = self.db_table
Expand Down
14 changes: 7 additions & 7 deletions playhouse/tests/test_csv_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def load(self, *lines, **loader_kwargs):
return TestLoader(**loader_kwargs).load()

def assertData(self, ModelClass, expected):
name_field = ModelClass._meta.get_fields()[2]
name_field = ModelClass._meta.get_fields()[1]
query = ModelClass.select().order_by(name_field).tuples()
self.assertEqual([row[1:] for row in query], expected)
self.assertEqual([row for row in query], expected)

def test_defaults(self):
ModelClass = self.load(
Expand All @@ -83,7 +83,7 @@ def test_defaults(self):
self.assertData(ModelClass, [
(10, 'F1 L1', date(1983, 1, 1), 10000., 't'),
(20, 'F2 L2', date(1983, 1, 2), 20000.5, 'f'),
(0, 'F3 L3', None, 0., ''),
(21, 'F3 L3', None, 0., ''),
])

def test_no_header(self):
Expand All @@ -95,8 +95,8 @@ def test_no_header(self):
self.assertEqual(ModelClass._meta.get_field_names(), [
'_auto_pk', 'f1', 'f2', 'f3', 'f4', 'f5'])
self.assertData(ModelClass, [
(10, 'F1 L1', date(1983, 1, 1), 10000., 't'),
(20, 'F2 L2', date(1983, 1, 2), 20000.5, 'f')])
(1, 10, 'F1 L1', date(1983, 1, 1), 10000., 't'),
(2, 20, 'F2 L2', date(1983, 1, 2), 20000.5, 'f')])

def test_no_header_no_fieldnames(self):
ModelClass = self.load(
Expand All @@ -117,7 +117,7 @@ def test_mismatch_types(self):

def test_fields(self):
fields = [
IntegerField(),
PrimaryKeyField(),
CharField(),
DateField(),
FloatField(),
Expand All @@ -129,7 +129,7 @@ def test_fields(self):
fields=fields)
self.assertEqual(
list(map(type, fields)),
list(map(type, ModelClass._meta.get_fields()[1:])))
list(map(type, ModelClass._meta.get_fields())))
self.assertData(ModelClass, [
(10, 'F1 L1', date(1983, 1, 1), 10000., 't'),
(20, 'F2 L2', date(1983, 1, 2), 20000.5, 'f')])
Expand Down

1 comment on commit 3b29611

@MartynBliss
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Please sign in to comment.