Skip to content

Commit

Permalink
Merge pull request #85 from perpetua1/master
Browse files Browse the repository at this point in the history
read_frame: Fix fieldname deduplication bug
  • Loading branch information
chrisdev committed Sep 12, 2017
2 parents cf39363 + 5298ac2 commit 538e095
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 17 deletions.
32 changes: 15 additions & 17 deletions django_pandas/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def read_frame(qs, fieldnames=(), index_col=None, coerce_float=False,
"""

if fieldnames:
fieldnames = pd.unique(fieldnames)
if index_col is not None and index_col not in fieldnames:
# Add it to the field names if not already there
fieldnames = tuple(fieldnames) + (index_col,)
Expand All @@ -87,34 +88,31 @@ def read_frame(qs, fieldnames=(), index_col=None, coerce_float=False,
if annotation_field_names is None:
annotation_field_names = []

extra_names = qs.extra_names
if extra_names is None:
extra_names = []
extra_field_names = qs.extra_names
if extra_field_names is None:
extra_field_names = []

fieldnames = qs.field_names + annotation_field_names + extra_names

fields = [None if '__' in f else qs.model._meta.get_field(f)
for f in qs.field_names] + \
[None] * (len(annotation_field_names) + len(extra_names))
select_field_names = qs.field_names

else:
annotation_field_names = list(qs.query.annotation_select)

select_field_names = list(qs.query.values_select)
extra_field_names = list(qs.query.extra_select)
select_field_names = list(qs.query.values_select)

fieldnames = select_field_names + annotation_field_names \
+ extra_field_names
fieldnames = select_field_names + annotation_field_names + \
extra_field_names
fields = [None if '__' in f else qs.model._meta.get_field(f)
for f in select_field_names] + \
[None] * (len(annotation_field_names) + len(extra_field_names))

fields = [None if '__' in f else qs.model._meta.get_field(f)
for f in select_field_names] + \
[None] * (len(annotation_field_names) + len(extra_field_names))
uniq_fields = set()
fieldnames, fields = zip(
*(f for f in zip(fieldnames, fields)
if f[0] not in uniq_fields and not uniq_fields.add(f[0])))
else:
fields = qs.model._meta.fields
fieldnames = [f.name for f in fields]

fieldnames = pd.unique(fieldnames)

if is_values_queryset(qs):
recs = list(qs)
else:
Expand Down
17 changes: 17 additions & 0 deletions django_pandas/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,23 @@ def test_verbose(self):
df2.trader.tolist()
)

def test_verbose_duplicates_fieldnames(self):
qs = TradeLog.objects.all()
df = read_frame(qs, fieldnames=['trader', 'trader', 'price'])
self.assertListEqual(
list(qs.values_list('price', flat=True)),
df.price.tolist()
)

def test_verbose_duplicate_values(self):
qs = TradeLog.objects.all()
qs = qs.values('trader', 'trader', 'price')
df = read_frame(qs)
self.assertListEqual(
list(qs.values_list('price', flat=True)),
df.price.tolist()
)

def test_related_selected_field(self):
qs = TradeLog.objects.all().values('trader__name')
df = read_frame(qs)
Expand Down

0 comments on commit 538e095

Please sign in to comment.