Skip to content

Commit

Permalink
Fixed the windows line endings in the CSVCallback. (#11124)
Browse files Browse the repository at this point in the history
* Fix a bug to write no-need blank line each line on windows

* Fixed the windows line endings in the CSVCallback.

* Converted strings to unicode when using python 2.

* Fixed the unicode issue by opening in binary mode.

* Used six to simplify the compatibility issues.

* Fixed the py2 issue with endline not available for binary mode.
  • Loading branch information
gabrieldemarmiesse authored and fchollet committed Sep 14, 2018
1 parent 5b62434 commit f60313e
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
23 changes: 18 additions & 5 deletions keras/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import time
import json
import warnings
import io
import sys

from collections import deque
from collections import OrderedDict
Expand Down Expand Up @@ -1122,17 +1124,25 @@ def __init__(self, filename, separator=',', append=False):
self.writer = None
self.keys = None
self.append_header = True
self.file_flags = 'b' if six.PY2 and os.name == 'nt' else ''
if six.PY2:
self.file_flags = 'b'
self._open_args = {}
else:
self.file_flags = ''
self._open_args = {'newline': '\n'}
super(CSVLogger, self).__init__()

def on_train_begin(self, logs=None):
if self.append:
if os.path.exists(self.filename):
with open(self.filename, 'r' + self.file_flags) as f:
self.append_header = not bool(len(f.readline()))
self.csv_file = open(self.filename, 'a' + self.file_flags)
mode = 'a'
else:
self.csv_file = open(self.filename, 'w' + self.file_flags)
mode = 'w'
self.csv_file = io.open(self.filename,
mode + self.file_flags,
**self._open_args)

def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
Expand All @@ -1156,9 +1166,12 @@ def handle_value(k):
if not self.writer:
class CustomDialect(csv.excel):
delimiter = self.sep

fieldnames = ['epoch'] + self.keys
if six.PY2:
fieldnames = [unicode(x) for x in fieldnames]
self.writer = csv.DictWriter(self.csv_file,
fieldnames=['epoch'] + self.keys, dialect=CustomDialect)
fieldnames=fieldnames,
dialect=CustomDialect)
if self.append_header:
self.writer.writeheader()

Expand Down
8 changes: 6 additions & 2 deletions tests/keras/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,11 +556,15 @@ def make_model():

# case 3, reuse of CSVLogger object
model.fit(X_train, y_train, batch_size=batch_size,
validation_data=(X_test, y_test), callbacks=cbks, epochs=1)
validation_data=(X_test, y_test), callbacks=cbks, epochs=2)

import re
with open(filepath) as csvfile:
output = " ".join(csvfile.readlines())
list_lines = csvfile.readlines()
for line in list_lines:
assert line.count(sep) == 4
assert len(list_lines) == 5
output = " ".join(list_lines)
assert len(re.findall('epoch', output)) == 1

os.remove(filepath)
Expand Down

0 comments on commit f60313e

Please sign in to comment.