Skip to content

Commit

Permalink
ENGPROD-67: Use csv.writer to write data out in _BufferedDataFrame
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 06d37622529a84dcc0caac2aec27df8126440001
  • Loading branch information
pimlock committed Mar 31, 2022
1 parent 4672906 commit ca729e9
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
13 changes: 8 additions & 5 deletions src/gretel_synthetics/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
For example usage, please see our Jupyter Notebook.
"""
import abc
import csv
import glob
import gzip
import io
Expand Down Expand Up @@ -376,21 +377,23 @@ def __init__(self, delim: str, columns: List[str], method: str = FILE):
else:
raise ValueError("Invalid method")

self.csv_writer = csv.writer(self.buffer, delimiter=self.delim)

def add(self, record: dict):
# write the columns names into the buffer, we
# use the first dict to specify the order and
# assume subsequent dicts have the same order
if not self.headers_set:
_columns = self.delim.join(record.keys())
self.buffer.write(_columns + "\n")
self.csv_writer.writerow(record.keys())
self.headers_set = True
_row = self.delim.join(record.values())
self.buffer.write(_row + "\n")

self.csv_writer.writerow(record.values())

@property
def df(self) -> pd.DataFrame:
self.buffer.seek(0)
return pd.read_csv(self.buffer, sep=self.delim)[self.columns]
df = pd.read_csv(self.buffer, sep=self.delim)
return df[self.columns]

def get_records(self) -> pd.DataFrame:
return self.df
Expand Down
21 changes: 21 additions & 0 deletions tests/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,27 @@ def test_buffered_df(buffermode):
buffer.cleanup()


@pytest.mark.parametrize("buffermode", [MEMORY, FILE])
def test_buffered_df_with_special_chars(buffermode):
"""
Make sure special characters are handled correctly.
"""
nl = "newline\n"
uni = "unic\U0001f499ode"
comma = "and, comma"

buffer = _BufferedDataFrame(",", [nl, "foo", uni, comma, "bar"], method=buffermode)
buffer.add({"foo": "A", nl: "B", uni: "C", comma: "D", "bar": "12.3"})
buffer.add({"foo": "A", nl: "B", uni: "C", comma: "D", "bar": "56.2"})

df = buffer.df
assert list(df.columns) == [nl, "foo", uni, comma, "bar"]
assert str(df.bar.dtype) == "float64"
assert len(df) == 2

buffer.cleanup()


# bugfix: incomplete records
@pytest.mark.parametrize("buffermode", [MEMORY, FILE])
def test_buffered_df_incomplete_first_record(buffermode):
Expand Down

0 comments on commit ca729e9

Please sign in to comment.