Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Allow reading from buffered stdout using np.fromfile #12324

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 39 additions & 50 deletions numpy/_core/include/numpy/npy_3kcompat.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ PyUnicode_Concat2(PyObject **left, PyObject *right)
static inline FILE*
npy_PyFile_Dup2(PyObject *file, char *mode, npy_off_t *orig_pos)
{
int fd, fd2, unbuf;
int fd, fd2, seekable;
Py_ssize_t fd2_tmp;
PyObject *ret, *os, *io, *io_raw;
PyObject *ret, *os;
npy_off_t pos;
FILE *handle;

Expand All @@ -223,6 +223,17 @@ npy_PyFile_Dup2(PyObject *file, char *mode, npy_off_t *orig_pos)
return PyFile_AsFile(file);
}
#endif
// Check for seekability before performing any file operations
// in case of error.
ret = PyObject_CallMethod(file, "seekable", NULL);
if (ret == NULL){
return NULL;
}
seekable = PyObject_IsTrue(ret);
Py_DECREF(ret);
if (seekable == -1){
return NULL;
}

/* Flush first to ensure things end up in the file in the correct order */
ret = PyObject_CallMethod(file, "flush", "");
Expand Down Expand Up @@ -276,33 +287,18 @@ npy_PyFile_Dup2(PyObject *file, char *mode, npy_off_t *orig_pos)
return NULL;
}

if (seekable == 0) {
/* Set the original pos as invalid when the object is not seekable */
*orig_pos = -1;
return handle;
}

/* Record the original raw file handle position */
*orig_pos = npy_ftell(handle);
if (*orig_pos == -1) {
/* The io module is needed to determine if buffering is used */
io = PyImport_ImportModule("io");
if (io == NULL) {
fclose(handle);
return NULL;
}
/* File object instances of RawIOBase are unbuffered */
io_raw = PyObject_GetAttrString(io, "RawIOBase");
Py_DECREF(io);
if (io_raw == NULL) {
fclose(handle);
return NULL;
}
unbuf = PyObject_IsInstance(file, io_raw);
Py_DECREF(io_raw);
if (unbuf == 1) {
/* Succeed if the IO is unbuffered */
return handle;
}
else {
PyErr_SetString(PyExc_IOError, "obtaining file position failed");
fclose(handle);
return NULL;
}
PyErr_SetString(PyExc_IOError, "obtaining file position failed");
fclose(handle);
return NULL;
}

/* Seek raw handle to the Python-side position */
Expand Down Expand Up @@ -331,8 +327,8 @@ npy_PyFile_Dup2(PyObject *file, char *mode, npy_off_t *orig_pos)
static inline int
npy_PyFile_DupClose2(PyObject *file, FILE* handle, npy_off_t orig_pos)
{
int fd, unbuf;
PyObject *ret, *io, *io_raw;
int fd, seekable;
PyObject *ret;
npy_off_t position;

/* For Python 2 PyFileObject, do nothing */
Expand All @@ -356,29 +352,22 @@ npy_PyFile_DupClose2(PyObject *file, FILE* handle, npy_off_t orig_pos)
return -1;
}

if (npy_lseek(fd, orig_pos, SEEK_SET) == -1) {
ret = PyObject_CallMethod(file, "seekable", NULL);
if (ret == NULL){
return -1;
}
seekable = PyObject_IsTrue(ret);
Py_DECREF(ret);
if (seekable == -1){
return -1;
}
else if (seekable == 0) {
return 0;
}

/* The io module is needed to determine if buffering is used */
io = PyImport_ImportModule("io");
if (io == NULL) {
return -1;
}
/* File object instances of RawIOBase are unbuffered */
io_raw = PyObject_GetAttrString(io, "RawIOBase");
Py_DECREF(io);
if (io_raw == NULL) {
return -1;
}
unbuf = PyObject_IsInstance(file, io_raw);
Py_DECREF(io_raw);
if (unbuf == 1) {
/* Succeed if the IO is unbuffered */
return 0;
}
else {
PyErr_SetString(PyExc_IOError, "seeking file failed");
return -1;
}
if (npy_lseek(fd, orig_pos, SEEK_SET) == -1) {
PyErr_SetString(PyExc_IOError, "seeking file failed");
return -1;
}

if (position == -1) {
Expand Down
41 changes: 40 additions & 1 deletion numpy/_core/tests/test_longdouble.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
temppath, IS_MUSL
)
from numpy._core.tests._locales import CommaDecimalPointLocale
import subprocess as sp
import sys


LD_INFO = np.finfo(np.longdouble)
Expand Down Expand Up @@ -234,7 +236,44 @@ def test_fromfile_complex(self):
res = np.fromfile(path, dtype=ctype, sep=",")
assert_equal(res, np.array([1.j]))


@pytest.mark.parametrize('mode', ['numpy', 'python'])
@pytest.mark.parametrize(
'buffer_type',
['buffered', 'unbuffered'])
def test_fromfile_buffered_unseekable(self, buffer_type, mode):
# stdout is quite a unique file descripter as it can be buffered, and
# unseekable
if buffer_type == 'buffered':
bufsize = -1
else:
bufsize = 0
s1 = b"numpy is not np"
s2 = b" and rain is falling"
s = s1 + s2
p = sp.Popen([sys.executable, '-c',
'import sys; sys.stdout.buffer.write(' + repr(s) + ')'
],
stdout=sp.PIPE, bufsize=bufsize)
if mode == 'python':
buf = p.stdout.read(len(s1))
arr = np.frombuffer(buf, dtype=np.uint8)
elif mode == 'numpy':
arr = np.fromfile(p.stdout, dtype=np.uint8, count=len(s1))
else:
raise ValueError('Unknown mode {}'.format(mode))
assert arr.tobytes() == s1

# Read the rest of the buffer
# We can't use `count=-1` because stdout doesn't support
# ftell and/or npy_fseek(fp, 0, SEEK_END)
if mode == 'python':
buf = p.stdout.read(len(s2))
arr = np.frombuffer(buf, dtype=np.uint8)
elif mode == 'numpy':
arr = np.fromfile(p.stdout, dtype=np.uint8, count=len(s2))
else:
raise ValueError('Unknown mode {}'.format(mode))
assert arr.tobytes() == s2

@pytest.mark.skipif(string_to_longdouble_inaccurate,
reason="Need strtold_l")
Expand Down
Loading