BUG: Set index when reading Stata file
Ensures the index is set when requested while reading a Stata dta file.
Deprecates and renames index to index_col for API consistency.

closes pandas-dev#16342
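
For reference, a minimal usage sketch of the new keyword (the file name 'data.dta' and the column name 'index' are placeholders, not part of this commit):

import pandas as pd

# New spelling: the named column actually becomes the DataFrame index.
df = pd.read_stata('data.dta', index_col='index')

# Old spelling still works for now: deprecate_kwarg remaps it to
# index_col and emits a deprecation warning.
df = pd.read_stata('data.dta', index='index')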
bashtage committed Aug 29, 2017
1 parent 0d676a3 commit 4e30df8
Showing 3 changed files with 46 additions and 28 deletions.
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v0.21.0.txt
@@ -293,6 +293,7 @@ Other API Changes
- :func:`Series.argmin` and :func:`Series.argmax` will now raise a ``TypeError`` when used with ``object`` dtypes, instead of a ``ValueError`` (:issue:`13595`)
- :class:`Period` is now immutable, and will now raise an ``AttributeError`` when a user tries to assign a new value to the ``ordinal`` or ``freq`` attributes (:issue:`17116`).
- :func:`to_datetime` when passed a tz-aware ``origin=`` kwarg will now raise a more informative ``ValueError`` rather than a ``TypeError`` (:issue:`16842`)
- Renamed non-functional ``index`` to ``index_col`` in :func:`read_stata` to improve API consistency (:issue:`16342`)


.. _whatsnew_0210.deprecations:
@@ -370,6 +371,7 @@ I/O
- Bug in :func:`read_csv` when called with ``low_memory=False`` in which a CSV with at least one column > 2GB in size would incorrectly raise a ``MemoryError`` (:issue:`16798`).
- Bug in :func:`read_csv` when called with a single-element list ``header`` would return a ``DataFrame`` of all NaN values (:issue:`7757`)
- Bug in :func:`read_stata` where value labels could not be read when using an iterator (:issue:`16923`)
- Bug in :func:`read_stata` where the index was not set (:issue:`16342`)
- Bug in :func:`read_html` where import check fails when run in multiple threads (:issue:`16928`)

Plotting
61 changes: 34 additions & 27 deletions pandas/io/stata.py
@@ -9,31 +9,30 @@
You can find more information on http://presbrey.mit.edu/PyDTA and
http://www.statsmodels.org/devel/
"""
import numpy as np

import sys
import datetime
import struct
from dateutil.relativedelta import relativedelta
import sys

from pandas.core.dtypes.common import (
is_categorical_dtype, is_datetime64_dtype,
_ensure_object)
import numpy as np
from dateutil.relativedelta import relativedelta
from pandas._libs.lib import max_len_string_array, infer_dtype
from pandas._libs.tslib import NaT, Timestamp

import pandas as pd
from pandas import compat, to_timedelta, to_datetime, isna, DatetimeIndex
from pandas.compat import (lrange, lmap, lzip, text_type, string_types, range,
zip, BytesIO)
from pandas.core.base import StringMixin
from pandas.core.categorical import Categorical
from pandas.core.dtypes.common import (is_categorical_dtype, _ensure_object,
is_datetime64_dtype)
from pandas.core.frame import DataFrame
from pandas.core.series import Series
import datetime
from pandas import compat, to_timedelta, to_datetime, isna, DatetimeIndex
from pandas.compat import lrange, lmap, lzip, text_type, string_types, range, \
zip, BytesIO
from pandas.util._decorators import Appender
import pandas as pd

from pandas.io.common import (get_filepath_or_buffer, BaseIterator,
_stringify_path)
from pandas._libs.lib import max_len_string_array, infer_dtype
from pandas._libs.tslib import NaT, Timestamp
from pandas.util._decorators import Appender
from pandas.util._decorators import deprecate_kwarg

VALID_ENCODINGS = ('ascii', 'us-ascii', 'latin-1', 'latin_1', 'iso-8859-1',
'iso8859-1', '8859', 'cp819', 'latin', 'latin1', 'L1')
@@ -53,8 +52,8 @@
Encoding used to parse the files. None defaults to latin-1."""

_statafile_processing_params2 = """\
index : identifier of index column
identifier of column that should be used as index of the DataFrame
index_col : string, optional, default: None
Column to set as index
convert_missing : boolean, defaults to False
Flag indicating whether to convert missing values to their Stata
representations. If False, missing values are replaced with nans.
@@ -159,15 +158,16 @@


@Appender(_read_stata_doc)
@deprecate_kwarg(old_arg_name='index', new_arg_name='index_col')
def read_stata(filepath_or_buffer, convert_dates=True,
convert_categoricals=True, encoding=None, index=None,
convert_categoricals=True, encoding=None, index_col=None,
convert_missing=False, preserve_dtypes=True, columns=None,
order_categoricals=True, chunksize=None, iterator=False):

reader = StataReader(filepath_or_buffer,
convert_dates=convert_dates,
convert_categoricals=convert_categoricals,
index=index, convert_missing=convert_missing,
index_col=index_col, convert_missing=convert_missing,
preserve_dtypes=preserve_dtypes,
columns=columns,
order_categoricals=order_categoricals,
@@ -944,8 +944,9 @@ def __init__(self, encoding):
class StataReader(StataParser, BaseIterator):
__doc__ = _stata_reader_doc

@deprecate_kwarg(old_arg_name='index', new_arg_name='index_col')
def __init__(self, path_or_buf, convert_dates=True,
convert_categoricals=True, index=None,
convert_categoricals=True, index_col=None,
convert_missing=False, preserve_dtypes=True,
columns=None, order_categoricals=True,
encoding='latin-1', chunksize=None):
@@ -956,7 +957,7 @@ def __init__(self, path_or_buf, convert_dates=True,
# calls to read).
self._convert_dates = convert_dates
self._convert_categoricals = convert_categoricals
self._index = index
self._index_col = index_col
self._convert_missing = convert_missing
self._preserve_dtypes = preserve_dtypes
self._columns = columns
@@ -1460,8 +1461,9 @@ def get_chunk(self, size=None):
return self.read(nrows=size)

@Appender(_read_method_doc)
@deprecate_kwarg(old_arg_name='index', new_arg_name='index_col')
def read(self, nrows=None, convert_dates=None,
convert_categoricals=None, index=None,
convert_categoricals=None, index_col=None,
convert_missing=None, preserve_dtypes=None,
columns=None, order_categoricals=None):
# Handle empty file or chunk. If reading incrementally raise
@@ -1486,6 +1488,8 @@ def read(self, nrows=None, convert_dates=None,
columns = self._columns
if order_categoricals is None:
order_categoricals = self._order_categoricals
if index_col is None:
index_col = self._index_col

if nrows is None:
nrows = self.nobs
@@ -1524,14 +1528,14 @@
self._read_value_labels()

if len(data) == 0:
data = DataFrame(columns=self.varlist, index=index)
data = DataFrame(columns=self.varlist)
else:
data = DataFrame.from_records(data, index=index)
data = DataFrame.from_records(data)
data.columns = self.varlist

# If index is not specified, use actual row number rather than
# restarting at 0 for each chunk.
if index is None:
if index_col is None:
ix = np.arange(self._lines_read - read_lines, self._lines_read)
data = data.set_index(ix)

@@ -1553,7 +1557,7 @@ def read(self, nrows=None, convert_dates=None,
cols_ = np.where(self.dtyplist)[0]

# Convert columns (if needed) to match input type
index = data.index
ix = data.index
requires_type_conversion = False
data_formatted = []
for i in cols_:
@@ -1563,7 +1567,7 @@
if dtype != np.dtype(object) and dtype != self.dtyplist[i]:
requires_type_conversion = True
data_formatted.append(
(col, Series(data[col], index, self.dtyplist[i])))
(col, Series(data[col], ix, self.dtyplist[i])))
else:
data_formatted.append((col, data[col]))
if requires_type_conversion:
@@ -1606,6 +1610,9 @@ def read(self, nrows=None, convert_dates=None,
if convert:
data = DataFrame.from_items(retyped_data)

if index_col is not None:
data = data.set_index(data.pop(index_col))

return data

def _do_convert_missing(self, data, convert_missing):
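The same keyword is also accepted by StataReader.read(), so chunked reads pick up the new behaviour. A rough sketch of that path ('data.dta' and the column name 'id' are placeholders):

import pandas as pd

# Obtain a StataReader for incremental reading.
reader = pd.read_stata('data.dta', iterator=True)

# Without index_col, each chunk keeps the running row number as its
# index rather than restarting at 0 (see the np.arange(...) block above).
first = reader.read(nrows=1000)

# With index_col (new name; 'index' is deprecated), that column is set
# as the index of the returned chunk.
second = reader.read(nrows=1000, index_col='id')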
11 changes: 10 additions & 1 deletion pandas/tests/io/test_stata.py
@@ -476,7 +476,7 @@ def test_read_write_reread_dta15(self):
tm.assert_frame_equal(parsed_114, parsed_117)

def test_timestamp_and_label(self):
original = DataFrame([(1,)], columns=['var'])
original = DataFrame([(1,)], columns=['variable'])
time_stamp = datetime(2000, 2, 29, 14, 21)
data_label = 'This is a data file.'
with tm.ensure_clean() as path:
@@ -1309,3 +1309,12 @@ def test_value_labels_iterator(self, write_index):
dta_iter = pd.read_stata(path, iterator=True)
value_labels = dta_iter.value_labels()
assert value_labels == {'A': {0: 'A', 1: 'B', 2: 'C', 3: 'E'}}

def test_set_index(self):
# GH 17328
df = tm.makeDataFrame()
df.index.name = 'index'
with tm.ensure_clean() as path:
df.to_stata(path)
reread = pd.read_stata(path, index_col='index')
tm.assert_frame_equal(df, reread)
