Skip to content

Commit

Permalink
[FIX] Operate csv following RFC4180 (#814)
Browse files Browse the repository at this point in the history
* [FIX] Operate csv following RFC4180

* [FIX] CRLF

* for windows CRLF

* test guess_orient
  • Loading branch information
bojiang committed Jun 18, 2020
1 parent c89cc94 commit f893fc4
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 52 deletions.
122 changes: 76 additions & 46 deletions bentoml/utils/dataframe_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable, List
from typing import Iterable, List, Iterator
import sys
import json
import collections
Expand Down Expand Up @@ -45,64 +45,58 @@ def check_dataframe_column_contains(required_column_names, df):
)


def _to_str(v):
if v is None:
return ""
return str(v)


def _from_json_records(state: DataFrameState, table: list):
if not state.line_num: # make header
state.columns = table[0].keys()
yield itertools.chain(('',), state.columns)
yield state.columns

for tr in table:
tds = (tr[c] for c in state.columns) if state.columns else tr.values()
state.line_num += 1
yield itertools.chain((state.line_num - 1,), tds)
yield tds


def _from_json_values(state: DataFrameState, table: list):
if not state.line_num: # make header
yield itertools.chain(('',), range(len(table[0])))
yield range(len(table[0]))

for tr in table:
state.line_num += 1
yield itertools.chain((state.line_num - 1,), tr)
yield tr


def _from_json_columns(state: DataFrameState, table: dict):
if not state.line_num: # make header
state.columns = table.keys()
yield itertools.chain(('',), state.columns)
yield state.columns

for row in next(iter(table.values())):
if state.columns:
tr = (table[col][row] for col in state.columns)
else:
tr = (table[col][row] for col in table.keys())
state.line_num += 1
yield itertools.chain((state.line_num - 1,), tr)
yield tr


def _from_json_index(state: DataFrameState, table: dict):
if not state.line_num: # make header
state.columns = next(iter(table.values())).keys()
yield itertools.chain(('',), state.columns)
yield state.columns

for row in table.keys():
if state.columns:
tr = (table[row][col] for col in state.columns)
else:
tr = (td for td in table[row].values())
state.line_num += 1
yield itertools.chain((state.line_num - 1,), tr)
yield tr


def _from_json_split(state: DataFrameState, table: dict):
if not state.line_num: # make header
state.columns = table['columns']
yield itertools.chain(('',), state.columns)
yield state.columns

if state.columns:
_id_map = {k: i for i, k in enumerate(state.columns)}
Expand All @@ -113,29 +107,70 @@ def _from_json_split(state: DataFrameState, table: dict):
else:
tr = row
state.line_num += 1
yield itertools.chain((state.line_num - 1,), tr)


def _from_csv_with_index(state: DataFrameState, table: List[str]):
if not state.line_num:
state.columns = table[0].split(',')[1:]
yield table[0]
for row_str in table[1:]:
if not row_str.strip(): # skip blank line
continue
state.line_num += 1
yield f"{str(state.line_num - 1)},{row_str.split(',', maxsplit=1)[1]}"


def _from_csv_without_index(state: DataFrameState, table: List[str]):
yield tr


def _csv_split(string, delimiter, maxsplit=None) -> Iterator[str]:
dlen = len(delimiter)
if '"' in string:

def _iter_line(line):
quoted = False
last_cur = 0
split = 0
for i, c in enumerate(line):
if c == '"':
quoted = not quoted
if not quoted and string[i : i + dlen] == delimiter:
yield line[last_cur:i]
last_cur = i + dlen
split += 1
if maxsplit is not None and split == maxsplit:
break
yield line[last_cur:]

return _iter_line(string)
return iter(string.split(delimiter, maxsplit=maxsplit or 0))


def _csv_unquote(string):
if '"' in string:
string = string.strip()
assert string[0] == '"' and string[-1] == '"'
return string[1:-1].replace('""', '"')
return string


def _csv_quote(td):
if td is None:
td = ''
elif not isinstance(td, str):
td = str(td)
if '\n' in td or '"' in td or ',' in td or not td.strip():
return td.replace('"', '""').join('""')
return td


def _from_csv_without_index(state: DataFrameState, table: Iterator[str]):
    """Stream CSV lines that have no index column.

    The first line of *table* is the header: on the first table seen for
    this ``state`` it is parsed into ``state.columns`` and yielded; for
    subsequent tables it is consumed and dropped. Trailing ``\\r`` is
    stripped so CRLF input behaves like LF input. Blank lines are skipped;
    whitespace-only lines are re-quoted (pandas would otherwise read them
    as NaN — see _csv_quote).
    """
    row_str = next(table)  # header line (skipped for continuation tables)
    if not state.line_num:
        if row_str.endswith('\r'):  # for windows CRLF
            row_str = row_str[:-1]
        state.columns = tuple(_csv_unquote(s) for s in _csv_split(row_str, ','))
        if not row_str.strip():  # for special value ' ', which is a bug of pandas
            yield _csv_quote(row_str)
        else:
            yield row_str
    for row_str in table:
        if row_str.endswith('\r'):  # for windows CRLF
            row_str = row_str[:-1]
        if not row_str:  # skip blank line
            continue
        state.line_num += 1
        if not row_str.strip():
            yield _csv_quote(row_str)
        else:
            yield row_str


def _detect_orient(table):
Expand Down Expand Up @@ -202,7 +237,7 @@ def _dataframe_csv_from_input(tables, content_types, orients):
table.decode('utf-8'), object_pairs_hook=collections.OrderedDict
)
elif content_type.lower() == "text/csv":
table = table.decode('utf-8').split('\n')
table = _csv_split(table.decode('utf-8'), '\n')
if not table:
continue
else:
Expand Down Expand Up @@ -240,12 +275,8 @@ def _dataframe_csv_from_input(tables, content_types, orients):

continue
elif content_type.lower() == "text/csv":
if table[0].strip().startswith(','): # csv with index column
for line in _from_csv_with_index(state, table):
yield line, table_id if state.line_num else None
else:
for line in _from_csv_without_index(state, table):
yield line, table_id if state.line_num else None
for line in _from_csv_without_index(state, table):
yield line, table_id if state.line_num else None


def _gen_slice(ids):
Expand Down Expand Up @@ -277,18 +308,17 @@ def read_dataframes_from_json_n_csv(
raise MissingDependencyException('pandas required')
try:
rows_csv_with_id = [
(tds if isinstance(tds, str) else ','.join(map(_to_str, tds)), table_id)
(tds if isinstance(tds, str) else ','.join(map(_csv_quote, tds)), table_id)
for tds, table_id in _dataframe_csv_from_input(
datas, content_types, itertools.repeat(orient)
)
if tds is not None
]
except (TypeError, ValueError) as e:
raise BadInput('Invalid input format for DataframeInput') from e

str_csv = [r for r, _ in rows_csv_with_id]
df_str_csv = '\n'.join(str_csv)
df_merged = pd.read_csv(StringIO(df_str_csv), index_col=0)
df_merged = pd.read_csv(StringIO(df_str_csv), index_col=None)

dfs_id = [i for _, i in rows_csv_with_id][1:]
slices = _gen_slice(dfs_id)
Expand Down
52 changes: 46 additions & 6 deletions tests/handlers/test_dataframe_handler.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# pylint: disable=redefined-outer-name
import itertools
import time
import pytest
import math
import pandas as pd
import numpy as np
import json

from bentoml.utils.dataframe_util import _csv_split, _guess_orient
from bentoml.adapters import DataframeInput
from bentoml.adapters.dataframe_input import (
check_dataframe_column_contains,
Expand Down Expand Up @@ -153,7 +157,8 @@ def assert_df_equal(left: pd.DataFrame, right: pd.DataFrame):
pd.DataFrame(["str1", "str2", "str3"]), # single dim sting array
pd.DataFrame([np.nan]), # special values
pd.DataFrame([math.nan]), # special values
pd.DataFrame([" "]), # special values
pd.DataFrame([" ", 'a"b', "a,b", "a\nb"]), # special values
pd.DataFrame({"test": [" ", 'a"b', "a,b", "a\nb"]}), # special values
# pd.Series(np.random.rand(2)), # TODO: Series support
# pd.DataFrame([""]), # TODO: -> NaN
)
Expand Down Expand Up @@ -187,11 +192,6 @@ def test_batch_read_dataframes_from_mixed_json_n_csv(df):
test_datas, test_types, orient=None
) # auto detect orient

# test content_type=text/csv
test_datas.extend([df.to_csv().encode()] * 3)
test_types.extend(['text/csv'] * 3)

# test content_type=text/csv without index
test_datas.extend([df.to_csv(index=False).encode()] * 3)
test_types.extend(['text/csv'] * 3)

Expand All @@ -200,6 +200,16 @@ def test_batch_read_dataframes_from_mixed_json_n_csv(df):
assert_df_equal(df_merged[s], df)


def test_batch_read_dataframes_from_csv_other_CRLF(df):
    """CSV parsing must round-trip regardless of LF vs CRLF line endings."""
    csv_str = df.to_csv(index=False)
    # Rejoin the lines with the *other* newline convention than pandas
    # produced on this platform, then feed the flipped bytes back in.
    if '\r\n' in csv_str:
        flipped = '\n'.join(_csv_split(csv_str, '\r\n'))
    else:
        flipped = '\r\n'.join(_csv_split(csv_str, '\n'))
    df_merged, _ = read_dataframes_from_json_n_csv([flipped.encode()], ['text/csv'])
    assert_df_equal(df_merged, df)


def test_batch_read_dataframes_from_json_of_orients(df, orient):
test_datas = [df.to_json(orient=orient).encode()] * 3
test_types = ['application/json'] * 3
Expand Down Expand Up @@ -236,3 +246,33 @@ def test_batch_read_dataframes_from_json_in_mixed_order():
assert_df_equal(
df_merged[s][["A", "B", "C"]], pd.read_json(df_json1)[["A", "B", "C"]]
)


def test_guess_orient(df, orient):
    """_guess_orient should recover the orient used to serialize (or at
    least include it among its candidates)."""
    parsed = json.loads(df.to_json(orient=orient))
    guessed = _guess_orient(parsed)
    assert guessed == orient or orient in guessed


def test_benchmark_load_dataframes():
    '''
    read_dataframes_from_json_n_csv should beat the naive
    pd.read_json + pd.concat approach by a wide margin (assert >20x).
    '''
    test_count = 50
    sources = [pd.DataFrame(np.random.rand(10, 100)) for _ in range(test_count)]
    inputs = [frame.to_json().encode() for frame in sources]

    # Baseline: parse each payload separately, then concatenate.
    started = time.time()
    result1 = pd.concat([pd.read_json(raw) for raw in inputs])
    baseline_elapsed = time.time() - started

    # Candidate: single streaming pass through the batched reader.
    started = time.time()
    result2, _ = read_dataframes_from_json_n_csv(
        inputs, itertools.repeat('application/json'), 'columns'
    )
    candidate_elapsed = time.time() - started

    assert_df_equal(result1, result2)
    assert baseline_elapsed / candidate_elapsed > 20

0 comments on commit f893fc4

Please sign in to comment.