Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions django_pandas/io.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import pandas as pd
from django.db import connections
from pandas.io.sql import execute, _safe_fetch
from .utils import update_with_verbose


Expand Down Expand Up @@ -49,17 +51,25 @@ def read_frame(qs, fieldnames=(), index_col=None, coerce_float=False,
fields = to_fields(qs, fieldnames)
else:
fields = qs.model._meta.fields
fieldnames = [f.name for f in fields]
fieldnames = getattr(qs, '_fields', None) or [f.name for f in fields]

recs = list(qs.values_list(*fieldnames))
qs = qs.values_list(*fieldnames)
compiler = qs.query.get_compiler(using=qs.db)
connection = connections[qs.db]
query, args = compiler.as_sql()

df = pd.DataFrame.from_records(recs, columns=fieldnames,
coerce_float=coerce_float)
# because pandas.io.sql.read_frame always runs con.commit(),
# some code extracted from there.
cur = execute(query, connection, params=args)
rows = _safe_fetch(cur)
cur.close()

if verbose:
update_with_verbose(df, fieldnames, fields)
df = pd.DataFrame.from_records(rows, columns=fieldnames, coerce_float=coerce_float)

if index_col is not None:
df.set_index(index_col, inplace=True)

if verbose:
update_with_verbose(df, fieldnames, fields)

return df
53 changes: 51 additions & 2 deletions django_pandas/managers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from django.db.models.query import QuerySet
from django.db.models.query import QuerySet, ValuesListQuerySet, ValuesQuerySet
from model_utils.managers import PassThroughManager
from .io import read_frame


class DataFrameQuerySet(QuerySet):
class DataFrameMixin(object):

def to_pivot_table(self, fieldnames=(), verbose=True,
values=None, rows=None, cols=None,
Expand Down Expand Up @@ -197,6 +197,55 @@ def to_dataframe(self, fieldnames=(), verbose=True, index=None,
return df


class ValuesMixin(object):
"""
Mixin for overriding return type for values() and values_list().
"""

def values(self, *fields):
return self._clone(klass=DataFrameValuesQuerySet, setup=True, _fields=fields)

def values_list(self, *fields, **kwargs):
flat = kwargs.pop('flat', False)
if kwargs:
raise TypeError('Unexpected keyword arguments to values_list: %s'
% (list(kwargs),))
if flat and len(fields) > 1:
raise TypeError("'flat' is not valid when values_list is called with more than one field.")
return self._clone(klass=DataFrameValuesListQuerySet, setup=True, flat=flat,
_fields=fields)


class ConstraintValuesMixin(object):
"""
Limits field list for values() and values_list() with original
field list.
"""

def values(self, *fields):
fields = filter(lambda f: f in self._fields, fields)
return super(ConstraintValuesMixin, self).values(*fields)

def values_list(self, *fields, **kwargs):
fields = filter(lambda f: f in self._fields, fields)
return super(ConstraintValuesMixin, self).values_list(*fields, **kwargs)


class DataFrameValuesQuerySet(DataFrameMixin, ConstraintValuesMixin,
ValuesMixin, ValuesQuerySet):
pass


class DataFrameValuesListQuerySet(DataFrameMixin, ConstraintValuesMixin,
ValuesMixin, ValuesListQuerySet):
pass


class DataFrameQuerySet(DataFrameMixin, ValuesMixin, QuerySet):
pass


class DataFrameManager(PassThroughManager):

def get_query_set(self):
return DataFrameQuerySet(self.model)
10 changes: 10 additions & 0 deletions django_pandas/tests/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@ def test_dataframe(self):
n, c = df2.shape
self.assertEqual((n, c), (3, 3))

def test_values_list_qs(self):
qs = DataFrame.objects.values_list('col1', 'col2')
df = qs.to_dataframe()
self.assertNotIn('id', df.columns.values.tolist())

def test_values_qs(self):
qs = DataFrame.objects.values('col1', 'col2')
df = qs.to_dataframe()
self.assertNotIn('id', df.columns.values.tolist())


class TimeSeriesTest(TestCase):
def unpivot(self, frame):
Expand Down