From aad73dfe6e29f71ce2a32fd899ee2969fb2ba823 Mon Sep 17 00:00:00 2001 From: Haowen Xu Date: Wed, 18 Jul 2018 11:49:39 +0800 Subject: [PATCH 1/3] finished exec_proc --- mltoolkit/datafs/__init__.py | 5 +- mltoolkit/datafs/archivefs.py | 2 +- mltoolkit/datafs/base.py | 3 +- mltoolkit/datafs/localfs.py | 3 +- mltoolkit/datafs/mongofs.py | 121 +--------------- mltoolkit/report/report.py | 19 ++- mltoolkit/utils/__init__.py | 9 +- mltoolkit/utils/exec_proc.py | 131 ++++++++++++++++++ .../{datafs/utils.py => utils/file_utils.py} | 0 mltoolkit/utils/mongo_binder.py | 131 ++++++++++++++++++ tests/datafs/standard_checks.py | 1 + tests/datafs/test_base.py | 2 +- tests/datafs/test_localfs.py | 2 +- tests/datafs/test_mongofs.py | 26 +--- tests/helper.py | 34 +++++ tests/utils/test_exec_proc.py | 124 +++++++++++++++++ .../test_file_utils.py} | 3 +- tests/utils/test_mongo_binder.py | 37 +++++ 18 files changed, 499 insertions(+), 154 deletions(-) create mode 100644 mltoolkit/utils/exec_proc.py rename mltoolkit/{datafs/utils.py => utils/file_utils.py} (100%) create mode 100644 mltoolkit/utils/mongo_binder.py create mode 100644 tests/helper.py create mode 100644 tests/utils/test_exec_proc.py rename tests/{datafs/test_utils.py => utils/test_file_utils.py} (96%) create mode 100644 tests/utils/test_mongo_binder.py diff --git a/mltoolkit/datafs/__init__.py b/mltoolkit/datafs/__init__.py index c6d375b..e0f9374 100644 --- a/mltoolkit/datafs/__init__.py +++ b/mltoolkit/datafs/__init__.py @@ -1,7 +1,7 @@ -from . import (archivefs, base, errors, localfs, mongofs, utils) +from . import (archivefs, base, errors, localfs, mongofs) __all__ = sum( - [m.__all__ for m in [archivefs, base, errors, localfs, mongofs, utils]], + [m.__all__ for m in [archivefs, base, errors, localfs, mongofs]], [] ) @@ -10,7 +10,6 @@ from .errors import * from .localfs import * from .mongofs import * -from .utils import * try: from . import dataflow diff --git a/mltoolkit/datafs/archivefs.py b/mltoolkit/datafs/archivefs.py index 9312f9e..f5c0f60 100644 --- a/mltoolkit/datafs/archivefs.py +++ b/mltoolkit/datafs/archivefs.py @@ -2,9 +2,9 @@ import tarfile import zipfile +from mltoolkit.utils import ActiveFiles, maybe_close from .base import * from .errors import UnsupportedOperation, InvalidOpenMode, DataFileNotExist -from .utils import ActiveFiles, maybe_close __all__ = ['TarArchiveFS', 'ZipArchiveFS'] diff --git a/mltoolkit/datafs/base.py b/mltoolkit/datafs/base.py index a26c52d..715bacb 100644 --- a/mltoolkit/datafs/base.py +++ b/mltoolkit/datafs/base.py @@ -3,9 +3,8 @@ import six -from mltoolkit.utils import DocInherit, AutoInitAndCloseable +from mltoolkit.utils import maybe_close, DocInherit, AutoInitAndCloseable from .errors import UnsupportedOperation, DataFileNotExist -from .utils import maybe_close __all__ = [ 'DataFSCapacity', diff --git a/mltoolkit/datafs/localfs.py b/mltoolkit/datafs/localfs.py index 0f3fb88..1bf6587 100644 --- a/mltoolkit/datafs/localfs.py +++ b/mltoolkit/datafs/localfs.py @@ -1,7 +1,6 @@ import os -from mltoolkit.utils import makedirs -from .utils import ActiveFiles, iter_files +from mltoolkit.utils import makedirs, ActiveFiles, iter_files from .base import DataFS, DataFSCapacity from .errors import InvalidOpenMode, UnsupportedOperation, DataFileNotExist diff --git a/mltoolkit/datafs/mongofs.py b/mltoolkit/datafs/mongofs.py index e4b2078..600663b 100644 --- a/mltoolkit/datafs/mongofs.py +++ b/mltoolkit/datafs/mongofs.py @@ -1,19 +1,16 @@ import six -from gridfs import GridFS, GridFSBucket -from pymongo import MongoClient, CursorType -from pymongo.database import Database -from pymongo.collection import Collection +from pymongo import CursorType +from mltoolkit.utils import MongoBinder from .base import DataFS, DataFSCapacity from .errors import DataFileNotExist, InvalidOpenMode, MetaKeyNotExist -from .utils import ActiveFiles __all__ = ['MongoFS'] META_FIELD = 'metadata' -class MongoFS(DataFS): +class MongoFS(DataFS, MongoBinder): """ MongoDB GridFS based :class:`DataFS`. @@ -33,21 +30,10 @@ def __init__(self, conn_str, db_name, coll_name, strict=False): strict (bool): Whether or not this :class:`DataFS` works in strict mode? (default :obj:`False`) """ - super(MongoFS, self).__init__( - capacity=DataFSCapacity.ALL, - strict=strict - ) - - self._conn_str = conn_str - self._db_name = db_name - self._coll_name = coll_name - self._fs_coll_name = '{}.files'.format(coll_name) - - self._client = None # type: MongoClient - self._db = None # type: Database - self._gridfs = None # type: GridFS - self._gridfs_bucket = None # type: GridFSBucket - self._collection = None # type: Collection + DataFS.__init__( + self, capacity=DataFSCapacity.ALL, strict=strict) + MongoBinder.__init__( + self, conn_str=conn_str, db_name=db_name, coll_name=coll_name) if self.strict: def get_meta_value(r, m, k): @@ -58,7 +44,6 @@ def get_meta_value(r, m, k): get_meta_value = lambda r, m, k: m.get(k) self._get_meta_value_from_record = get_meta_value - self._active_files = ActiveFiles() def _make_query_project(self, meta_keys=None, _id=1, filename=1): ret = {'_id': _id, 'filename': filename} @@ -81,98 +66,6 @@ def _make_result_meta(self, record, meta_keys): return tuple(self._get_meta_value_from_record(record, meta_dict, k) for k in meta_keys) - @property - def conn_str(self): - """Get the MongoDB connection string.""" - return self._conn_str - - @property - def db_name(self): - """Get the MongoDB database name.""" - return self._db_name - - @property - def coll_name(self): - """Get the collection name (prefix) of the GridFS.""" - return self._coll_name - - @property - def client(self): - """ - Get the MongoDB client. Reading this property will force - the internal states of :class:`MongoFS` to be initialized. - - Returns: - MongoClient: The MongoDB client. - """ - self.init() - return self._client - - @property - def db(self): - """ - Get the MongoDB database object. Reading this property will force - the internal states of :class:`MongoFS` to be initialized. - - Returns: - Database: The MongoDB database object. - """ - self.init() - return self._db - - @property - def gridfs(self): - """ - Get the MongoDB GridFS client. Reading this property will force - the internal states of :class:`MongoFS` to be initialized. - - Returns: - GridFS: The MongoDB GridFS client. - """ - self.init() - return self._gridfs - - @property - def gridfs_bucket(self): - """ - Get the MongoDB GridFS bucket. Reading this property will force - the internal states of :class:`MongoFS` to be initialized. - - Returns: - GridFSBucket: The MongoDB GridFS bucket. - """ - self.init() - return self._gridfs_bucket - - @property - def collection(self): - """ - Get the MongoDB collection object. Reading this property will force - the internal states of :class:`MongoFS` to be initialized. - - Returns: - Collection: The MongoDB collection object. - """ - self.init() - return self._collection - - def _init(self): - self._client = MongoClient(self._conn_str) - self._db = self._client.get_database(self._db_name) - self._collection = self._db[self._coll_name] - self._gridfs = GridFS(self._db, self._coll_name) - self._gridfs_bucket = GridFSBucket(self._db, self._coll_name) - - def _close(self): - self._active_files.close_all() - try: - if self._client is not None: - self._client.close() - finally: - self._gridfs = None - self._db = None - self._client = None - def clone(self): return MongoFS(self.conn_str, self.db_name, self.coll_name, strict=self.strict) diff --git a/mltoolkit/report/report.py b/mltoolkit/report/report.py index d98802f..f61e02f 100644 --- a/mltoolkit/report/report.py +++ b/mltoolkit/report/report.py @@ -1,8 +1,9 @@ -import codecs import os + import jinja2 +import six -from mltoolkit.report.container import Container +from .container import Container __all__ = ['Report'] @@ -78,12 +79,18 @@ def to_html(self): styles=styles, scripts=scripts ) - def save(self, path): + def save(self, path_or_file): """ Save the rendered HTML as file. Args: - path (str): The path of the HTML file. + path_or_file: The file path, or a file-like object to write. """ - with codecs.open(path, 'wb', 'utf-8') as f: - f.write(self.to_html()) + if hasattr(path_or_file, 'write'): + s = self.to_html() + if isinstance(s, six.text_type): + s = s.encode('utf-8') + path_or_file.write(s) + else: + with open(path_or_file, 'wb') as f: + self.save(f) diff --git a/mltoolkit/utils/__init__.py b/mltoolkit/utils/__init__.py index 7735ee6..435e8a3 100644 --- a/mltoolkit/utils/__init__.py +++ b/mltoolkit/utils/__init__.py @@ -1,10 +1,15 @@ -from . import (concepts, doc_inherit, imported) +from . import (concepts, doc_inherit, exec_proc, file_utils, imported, + mongo_binder) __all__ = sum( - [m.__all__ for m in [concepts, doc_inherit, imported]], + [m.__all__ for m in [concepts, doc_inherit, exec_proc, file_utils, imported, + mongo_binder]], [] ) from .concepts import * from .doc_inherit import * +from .exec_proc import * +from .file_utils import * from .imported import * +from .mongo_binder import * diff --git a/mltoolkit/utils/exec_proc.py b/mltoolkit/utils/exec_proc.py new file mode 100644 index 0000000..9d8b11e --- /dev/null +++ b/mltoolkit/utils/exec_proc.py @@ -0,0 +1,131 @@ +import os +import signal +import subprocess +import sys +import time +from contextlib import contextmanager +from threading import Thread + +__all__ = ['timed_wait_proc', 'exec_proc'] + + +if sys.version_info[:2] >= (3, 3): + def timed_wait_proc(proc, timeout): + try: + return proc.wait(timeout) + except subprocess.TimeoutExpired: + return None +else: + def timed_wait_proc(proc, timeout): + itv = min(timeout * .1, .5) + tot = 0. + exit_code = None + while tot + 1e-7 < timeout and exit_code is None: + exit_code = proc.poll() + if exit_code is None: + time.sleep(itv) + tot += itv + return exit_code + + +@contextmanager +def exec_proc(args, on_stdout=None, on_stderr=None, stderr_to_stdout=False, + buffer_size=16*1024, ctrl_c_timeout=3, kill_timeout=60, **kwargs): + """ + Execute an external program within a context. + + Args: + args: Arguments of the program. + on_stdout ((bytes) -> None): Callback for capturing stdout. + on_stderr ((bytes) -> None): Callback for capturing stderr. + stderr_to_stdout (bool): Whether or not to redirect stderr to + stdout? If specified, `on_stderr` will be ignored. + (default :obj:`False`) + buffer_size (int): Size of buffers for reading from stdout and stderr. + ctrl_c_timeout (int): Seconds to wait for the program to + respond to CTRL+C signal. (default 3) + kill_timeout (int): Seconds to wait for the program to terminate after + being killed. (default 60) + **kwargs: Other named arguments passed to :func:`subprocess.Popen`. + + Yields: + subprocess.Popen: The process object. + """ + # check the arguments + if stderr_to_stdout: + kwargs['stderr'] = subprocess.STDOUT + on_stderr = None + if on_stdout is not None: + kwargs['stdout'] = subprocess.PIPE + if on_stderr is not None: + kwargs['stderr'] = subprocess.PIPE + + # output reader + def reader_func(fd, action): + while not giveup_waiting[0]: + buf = os.read(fd, buffer_size) + if not buf: + break + action(buf) + + def make_reader_thread(fd, action): + th = Thread(target=reader_func, args=(fd, action)) + th.daemon = True + th.start() + return th + + # internal flags + giveup_waiting = [False] + + # launch the process + stdout_thread = None # type: Thread + stderr_thread = None # type: Thread + proc = subprocess.Popen(args, **kwargs) + + try: + if on_stdout is not None: + stdout_thread = make_reader_thread(proc.stdout.fileno(), on_stdout) + if on_stderr is not None: + stderr_thread = make_reader_thread(proc.stderr.fileno(), on_stderr) + + try: + yield proc + except KeyboardInterrupt: # pragma: no cover + if proc.poll() is None: + # Wait for a while to ensure the program has properly dealt + # with the interruption signal. This will help to capture + # the final output of the program. + # TODO: use signal.signal instead for better treatment + _ = timed_wait_proc(proc, 1) + + finally: + if proc.poll() is None: + # First, try to interrupt the process with Ctrl+C signal + ctrl_c_signal = (signal.SIGINT if sys.platform != 'win32' + else signal.CTRL_C_EVENT) + os.kill(proc.pid, ctrl_c_signal) + if timed_wait_proc(proc, ctrl_c_timeout) is None: + # If the Ctrl+C signal does not work, terminate it. + proc.kill() + # Finally, wait for at most 60 seconds + if timed_wait_proc(proc, kill_timeout) is None: # pragma: no cover + giveup_waiting[0] = True + + # Close the pipes such that the reader threads will ensure to exit, + # if we decide to give up waiting. + def close_pipes(): + for f in (proc.stdout, proc.stderr, proc.stdin): + if f is not None: + f.close() + + if giveup_waiting[0]: # pragma: no cover + close_pipes() + + # Wait for the reader threads to exit + for th in (stdout_thread, stderr_thread): + if th is not None: + th.join() + + # Ensure all the pipes are closed. + if not giveup_waiting[0]: + close_pipes() diff --git a/mltoolkit/datafs/utils.py b/mltoolkit/utils/file_utils.py similarity index 100% rename from mltoolkit/datafs/utils.py rename to mltoolkit/utils/file_utils.py diff --git a/mltoolkit/utils/mongo_binder.py b/mltoolkit/utils/mongo_binder.py new file mode 100644 index 0000000..1c81ae7 --- /dev/null +++ b/mltoolkit/utils/mongo_binder.py @@ -0,0 +1,131 @@ +from gridfs import GridFS, GridFSBucket +from pymongo import MongoClient +from pymongo.database import Database +from pymongo.collection import Collection + +from .file_utils import ActiveFiles +from .concepts import AutoInitAndCloseable + +__all__ = ['MongoBinder'] + + +class MongoBinder(AutoInitAndCloseable): + """ + Base class for MongoDB data binder. + + A MongoDB data binder may save and load data in a MongoDB. + This class provides the basic interface for accessing the MongoDB. + """ + + def __init__(self, conn_str, db_name, coll_name): + """ + Initialize the internal states of :class:`MongoBinder`. + + Args: + conn_str (str): The MongoDB connection string. + db_name (str): The MongoDB database name. + coll_name (str): The collection name (prefix) of the GridFS. + """ + self._conn_str = conn_str + self._db_name = db_name + self._coll_name = coll_name + self._fs_coll_name = '{}.files'.format(coll_name) + + self._client = None # type: MongoClient + self._db = None # type: Database + self._gridfs = None # type: GridFS + self._gridfs_bucket = None # type: GridFSBucket + self._collection = None # type: Collection + self._active_files = ActiveFiles() + + @property + def conn_str(self): + """Get the MongoDB connection string.""" + return self._conn_str + + @property + def db_name(self): + """Get the MongoDB database name.""" + return self._db_name + + @property + def coll_name(self): + """Get the collection name (prefix) of the GridFS.""" + return self._coll_name + + @property + def client(self): + """ + Get the MongoDB client. Reading this property will force + the internal states of :class:`MongoFS` to be initialized. + + Returns: + MongoClient: The MongoDB client. + """ + self.init() + return self._client + + @property + def db(self): + """ + Get the MongoDB database object. Reading this property will force + the internal states of :class:`MongoFS` to be initialized. + + Returns: + Database: The MongoDB database object. + """ + self.init() + return self._db + + @property + def gridfs(self): + """ + Get the MongoDB GridFS client. Reading this property will force + the internal states of :class:`MongoFS` to be initialized. + + Returns: + GridFS: The MongoDB GridFS client. + """ + self.init() + return self._gridfs + + @property + def gridfs_bucket(self): + """ + Get the MongoDB GridFS bucket. Reading this property will force + the internal states of :class:`MongoFS` to be initialized. + + Returns: + GridFSBucket: The MongoDB GridFS bucket. + """ + self.init() + return self._gridfs_bucket + + @property + def collection(self): + """ + Get the MongoDB collection object. Reading this property will force + the internal states of :class:`MongoFS` to be initialized. + + Returns: + Collection: The MongoDB collection object. + """ + self.init() + return self._collection + + def _init(self): + self._client = MongoClient(self._conn_str) + self._db = self._client.get_database(self._db_name) + self._collection = self._db[self._coll_name] + self._gridfs = GridFS(self._db, self._coll_name) + self._gridfs_bucket = GridFSBucket(self._db, self._coll_name) + + def _close(self): + self._active_files.close_all() + try: + if self._client is not None: + self._client.close() + finally: + self._gridfs = None + self._db = None + self._client = None diff --git a/tests/datafs/standard_checks.py b/tests/datafs/standard_checks.py index f2f1122..8caa789 100644 --- a/tests/datafs/standard_checks.py +++ b/tests/datafs/standard_checks.py @@ -6,6 +6,7 @@ from mock import Mock from mltoolkit.datafs import * +from mltoolkit.utils import maybe_close class StandardFSChecks(object): diff --git a/tests/datafs/test_base.py b/tests/datafs/test_base.py index 93ad9fb..81ba6cb 100644 --- a/tests/datafs/test_base.py +++ b/tests/datafs/test_base.py @@ -11,7 +11,7 @@ from mltoolkit.datafs import * from mltoolkit.datafs import UnsupportedOperation, DataFileNotExist -from mltoolkit.utils import TemporaryDirectory, makedirs +from mltoolkit.utils import TemporaryDirectory, makedirs, iter_files from .standard_checks import StandardFSChecks, LocalFS from .test_dataflow import _DummyDataFS diff --git a/tests/datafs/test_localfs.py b/tests/datafs/test_localfs.py index 68ab5f6..82b08f3 100644 --- a/tests/datafs/test_localfs.py +++ b/tests/datafs/test_localfs.py @@ -5,7 +5,7 @@ import pytest import six -from mltoolkit.utils import TemporaryDirectory, makedirs +from mltoolkit.utils import TemporaryDirectory, makedirs, iter_files from mltoolkit.datafs import * from .standard_checks import StandardFSChecks diff --git a/tests/datafs/test_mongofs.py b/tests/datafs/test_mongofs.py index aac13c8..b3f206a 100644 --- a/tests/datafs/test_mongofs.py +++ b/tests/datafs/test_mongofs.py @@ -1,7 +1,5 @@ import gc -import subprocess import unittest -import uuid from contextlib import contextmanager from io import BytesIO @@ -12,14 +10,15 @@ from pymongo.database import Database from mltoolkit.datafs import * +from mltoolkit.utils import maybe_close from .standard_checks import StandardFSChecks +from ..helper import temporary_mongodb class MongoFSTestCase(unittest.TestCase, StandardFSChecks): def get_snapshot(self, fs): - conn_str = 'mongodb://root:123456@127.0.0.1:27017/admin' - client = MongoClient(conn_str) + client = MongoClient(fs.conn_str) try: ret = {} database = client.get_database('admin') @@ -40,18 +39,7 @@ def get_snapshot(self, fs): @contextmanager def temporary_fs(self, snapshot=None, **kwargs): - daemon_name = uuid.uuid4().hex - subprocess.check_call([ - 'docker', 'run', '--rm', '-d', - '--name', daemon_name, - '-e', 'MONGO_INITDB_ROOT_USERNAME=root', - '-e', 'MONGO_INITDB_ROOT_PASSWORD=123456', - '-p', '27017:27017', - 'mongo' - ]) - print('Docker daemon started: {!r}'.format(daemon_name)) - try: - conn_str = 'mongodb://root:123456@127.0.0.1:27017/admin' + with temporary_mongodb() as conn_str: if snapshot: client = MongoClient(conn_str) try: @@ -74,21 +62,19 @@ def temporary_fs(self, snapshot=None, **kwargs): client.close() with MongoFS(conn_str, 'admin', 'test', **kwargs) as fs: yield fs - finally: - subprocess.check_call(['docker', 'kill', daemon_name]) def test_standard(self): self.run_standard_checks(DataFSCapacity.ALL) def test_mongofs_props_and_methods(self): with self.temporary_fs() as fs: - # test auto-init + # test auto-init on cloned objects self.assertIsInstance(fs.clone().client, MongoClient) self.assertIsInstance(fs.clone().db, Database) self.assertIsInstance(fs.clone().gridfs, GridFS) self.assertIsInstance(fs.clone().gridfs_bucket, GridFSBucket) self.assertIsInstance(fs.clone().collection, Collection) - gc.collect() + gc.collect() # cleanup cloned objects if __name__ == '__main__': diff --git a/tests/helper.py b/tests/helper.py new file mode 100644 index 0000000..963beb3 --- /dev/null +++ b/tests/helper.py @@ -0,0 +1,34 @@ +import subprocess +import uuid +from contextlib import contextmanager + +import six + +__all__ = ['temporary_mongodb'] + + +@contextmanager +def temporary_mongodb(): + """ + Open a temporary MongoDB server for testing, within context. + + Yields: + str: The connection string to the server. + """ + daemon_name = uuid.uuid4().hex + output = subprocess.check_output([ + 'docker', 'run', '--rm', '-d', + '--name', daemon_name, + '-e', 'MONGO_INITDB_ROOT_USERNAME=root', + '-e', 'MONGO_INITDB_ROOT_PASSWORD=123456', + '-p', '27017:27017', + 'mongo' + ]) + if not (isinstance(output, six.binary_type) and six.PY2): + output = output.decode('utf-8') + print('Docker daemon started: {}'.format(output.strip())) + try: + conn_str = 'mongodb://root:123456@127.0.0.1:27017/admin' + yield conn_str + finally: + _ = subprocess.check_output(['docker', 'kill', daemon_name]) diff --git a/tests/utils/test_exec_proc.py b/tests/utils/test_exec_proc.py new file mode 100644 index 0000000..95d39bb --- /dev/null +++ b/tests/utils/test_exec_proc.py @@ -0,0 +1,124 @@ +import io +import os +import subprocess +import sys +import time +import unittest +from tempfile import TemporaryDirectory + +from mltoolkit.utils import timed_wait_proc, exec_proc + + +def _strip(text): + return '\n'.join(l.lstrip()[1:] for l in text.split('\n')) + + +class TimedWaitTestCase(unittest.TestCase): + + def test_timed_wait_proc(self): + # test no wait + wait_time = -time.time() + proc = subprocess.Popen( + [sys.executable, '-c', 'import sys; sys.exit(123)']) + self.assertEquals(123, timed_wait_proc(proc, 1.5)) + wait_time += time.time() + self.assertLess(wait_time, 1.) + + # test wait + wait_time = -time.time() + proc = subprocess.Popen( + [sys.executable, '-c', 'import sys, time; time.sleep(3); ' + 'sys.exit(123)']) + self.assertIsNone(timed_wait_proc(proc, 1.5)) + wait_time += time.time() + self.assertGreater(wait_time, 1.) + self.assertLess(wait_time, 2.) + + +class ExcProcTestCase(unittest.TestCase): + + def test_exec_proc_io(self): + with TemporaryDirectory() as tempdir: + with open(os.path.join(tempdir, 'test_payload.txt'), 'wb') as f: + f.write(b'hello, world!') + + args = ['bash', '-c', 'ls && echo error_message >&2 && exit 123'] + + # test stdout only + stdout = io.BytesIO() + with exec_proc(args, on_stdout=stdout.write, cwd=tempdir) as proc: + proc.wait() + self.assertEquals(123, proc.poll()) + self.assertIn(b'test_payload.txt', stdout.getvalue()) + self.assertNotIn(b'error_message', stdout.getvalue()) + + # test separated stdout and stderr + stdout = io.BytesIO() + stderr = io.BytesIO() + with exec_proc(args, on_stdout=stdout.write, on_stderr=stderr.write, + cwd=tempdir) as proc: + proc.wait() + self.assertEquals(123, proc.poll()) + self.assertIn(b'test_payload.txt', stdout.getvalue()) + self.assertNotIn(b'error_message', stdout.getvalue()) + self.assertNotIn(b'test_payload.txt', stderr.getvalue()) + self.assertIn(b'error_message', stderr.getvalue()) + + # test redirect stderr to stdout + stdout = io.BytesIO() + with exec_proc(args, on_stdout=stdout.write, stderr_to_stdout=True, + cwd=tempdir) as proc: + proc.wait() + self.assertEquals(123, proc.poll()) + self.assertIn(b'test_payload.txt', stdout.getvalue()) + self.assertIn(b'error_message', stdout.getvalue()) + + def test_exec_proc_kill(self): + interruptable = _strip(''' + |import time + |try: + | while True: + | time.sleep(1) + |except KeyboardInterrupt: + | print("kbd interrupt") + |print("exited") + ''') + non_interruptable = _strip(''' + |import time + |while True: + | try: + | time.sleep(1) + | except KeyboardInterrupt: + | print("kbd interrupt") + |print("exited") + ''') + + # test interruptable + stdout = io.BytesIO() + wait_time = -time.time() + with exec_proc( + ['python', '-u', '-c', interruptable], + on_stdout=stdout.write) as proc: + timed_wait_proc(proc, 1.) + wait_time += time.time() + self.assertEquals(b'kbd interrupt\nexited\n', stdout.getvalue()) + self.assertEquals(0, proc.poll()) + self.assertLess(wait_time, 1.5) + + # test non-interruptable, give up waiting + stdout = io.BytesIO() + wait_time = -time.time() + with exec_proc( + ['python', '-u', '-c', non_interruptable], + on_stdout=stdout.write, + ctrl_c_timeout=1) as proc: + timed_wait_proc(proc, 1.) + wait_time += time.time() + self.assertEquals(b'kbd interrupt\n', stdout.getvalue()) + self.assertNotEquals(0, proc.poll()) + self.assertGreaterEqual(wait_time, 2.) # timed_wait + ctrl_c + self.assertLess(wait_time, 2.5) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/datafs/test_utils.py b/tests/utils/test_file_utils.py similarity index 96% rename from tests/datafs/test_utils.py rename to tests/utils/test_file_utils.py index 9222b16..eada0e0 100644 --- a/tests/datafs/test_utils.py +++ b/tests/utils/test_file_utils.py @@ -6,8 +6,7 @@ import pytest from mock import Mock -from mltoolkit.datafs import ActiveFiles, iter_files, maybe_close -from mltoolkit.utils import makedirs, TemporaryDirectory +from mltoolkit.utils import * class ActiveFilesTestCase(unittest.TestCase): diff --git a/tests/utils/test_mongo_binder.py b/tests/utils/test_mongo_binder.py new file mode 100644 index 0000000..353b646 --- /dev/null +++ b/tests/utils/test_mongo_binder.py @@ -0,0 +1,37 @@ +import unittest + +from gridfs import GridFS, GridFSBucket +from pymongo import MongoClient +from pymongo.collection import Collection +from pymongo.database import Database + +from mltoolkit.utils import * +from ..helper import temporary_mongodb + + +class MongoBinderTestCase(unittest.TestCase): + + def test_props_and_methods(self): + with temporary_mongodb() as conn_str: + binder = MongoBinder(conn_str, 'test', 'fs') + + # test auto-init + client = binder.client + self.assertIsInstance(binder.client, MongoClient) + self.assertIsInstance(binder.db, Database) + self.assertIsInstance(binder.gridfs, GridFS) + self.assertIsInstance(binder.gridfs_bucket, GridFSBucket) + self.assertIsInstance(binder.collection, Collection) + + # test with context + with binder: + self.assertIs(client, binder.client) + self.assertIsNone(binder._client) + + # test re-init + self.assertIsInstance(binder.client, MongoClient) + self.assertIsNot(client, binder.client) + + +if __name__ == '__main__': + unittest.main() From 0bfc736f76f5baf64bff63a4e4c4d6a015f45e6b Mon Sep 17 00:00:00 2001 From: Haowen Xu Date: Wed, 18 Jul 2018 11:57:30 +0800 Subject: [PATCH 2/3] fix py2.7 TemporaryDirectory import in test_exec_proc.py --- tests/utils/test_exec_proc.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/utils/test_exec_proc.py b/tests/utils/test_exec_proc.py index 95d39bb..e530d8c 100644 --- a/tests/utils/test_exec_proc.py +++ b/tests/utils/test_exec_proc.py @@ -4,9 +4,8 @@ import sys import time import unittest -from tempfile import TemporaryDirectory -from mltoolkit.utils import timed_wait_proc, exec_proc +from mltoolkit.utils import timed_wait_proc, exec_proc, TemporaryDirectory def _strip(text): From 413f1b7cbeaa8deccf5b7f0ff8d4f646e6bca1c2 Mon Sep 17 00:00:00 2001 From: Haowen Xu Date: Wed, 18 Jul 2018 12:03:54 +0800 Subject: [PATCH 3/3] loose the time limits in test_exec_proc.py --- tests/utils/test_exec_proc.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/tests/utils/test_exec_proc.py b/tests/utils/test_exec_proc.py index e530d8c..807d71c 100644 --- a/tests/utils/test_exec_proc.py +++ b/tests/utils/test_exec_proc.py @@ -26,12 +26,12 @@ def test_timed_wait_proc(self): # test wait wait_time = -time.time() proc = subprocess.Popen( - [sys.executable, '-c', 'import sys, time; time.sleep(3); ' + [sys.executable, '-c', 'import sys, time; time.sleep(10); ' 'sys.exit(123)']) self.assertIsNone(timed_wait_proc(proc, 1.5)) wait_time += time.time() self.assertGreater(wait_time, 1.) - self.assertLess(wait_time, 2.) + self.assertLess(wait_time, 3.) class ExcProcTestCase(unittest.TestCase): @@ -94,29 +94,22 @@ def test_exec_proc_kill(self): # test interruptable stdout = io.BytesIO() - wait_time = -time.time() with exec_proc( ['python', '-u', '-c', interruptable], on_stdout=stdout.write) as proc: timed_wait_proc(proc, 1.) - wait_time += time.time() self.assertEquals(b'kbd interrupt\nexited\n', stdout.getvalue()) self.assertEquals(0, proc.poll()) - self.assertLess(wait_time, 1.5) # test non-interruptable, give up waiting stdout = io.BytesIO() - wait_time = -time.time() with exec_proc( ['python', '-u', '-c', non_interruptable], on_stdout=stdout.write, ctrl_c_timeout=1) as proc: timed_wait_proc(proc, 1.) - wait_time += time.time() self.assertEquals(b'kbd interrupt\n', stdout.getvalue()) self.assertNotEquals(0, proc.poll()) - self.assertGreaterEqual(wait_time, 2.) # timed_wait + ctrl_c - self.assertLess(wait_time, 2.5) if __name__ == '__main__':