From 9a6efd4e3926e30d7540cf11f80c992a41a78750 Mon Sep 17 00:00:00 2001 From: Yosuke Mizutani Date: Tue, 3 Nov 2015 20:39:41 +0900 Subject: [PATCH 1/2] implement TerminalHandler closes #15 add FakeInput and FakeBytesInput to unittest module --- src/mog_commons/__init__.py | 2 +- src/mog_commons/terminal.py | 196 ++++++++++++++++++++++++ src/mog_commons/unittest.py | 32 ++++ tests/mog_commons/test_terminal.py | 46 ++++++ tests/resources/test_terminal_input.txt | 1 + 5 files changed, 276 insertions(+), 1 deletion(-) create mode 100644 src/mog_commons/terminal.py create mode 100644 tests/mog_commons/test_terminal.py create mode 100644 tests/resources/test_terminal_input.txt diff --git a/src/mog_commons/__init__.py b/src/mog_commons/__init__.py index 112abf1..ac1125f 100644 --- a/src/mog_commons/__init__.py +++ b/src/mog_commons/__init__.py @@ -1 +1 @@ -__version__ = '0.1.14' +__version__ = '0.1.15' diff --git a/src/mog_commons/terminal.py b/src/mog_commons/terminal.py new file mode 100644 index 0000000..f029a8c --- /dev/null +++ b/src/mog_commons/terminal.py @@ -0,0 +1,196 @@ +from __future__ import division, print_function, absolute_import, unicode_literals + +import os +import sys +import codecs +import subprocess +import locale +import platform +import time + +if os.name == 'nt': + # for Windows + import msvcrt +else: + # for Unix/Linux/Mac/CygWin + import termios + import tty + +from mog_commons.case_class import CaseClass +from mog_commons.string import to_unicode + +__all__ = [ + 'TerminalHandler', +] + +DEFAULT_GETCH_REPEAT_THRESHOLD = 0.3 # in seconds + + +class TerminalHandler(CaseClass): + """ + IMPORTANT: When you use this class in POSIX environment, make sure to set signal function for restoring terminal + attributes. The function `restore_terminal` is for that purpose. See the example below. + + :example: + import signal + + t = TerminalHandler() + signal.signal(signal.SIGTERM, t.restore_terminal) + + try: + (do your work) + finally: + t.restore_terminal(None, None) + """ + + def __init__(self, term_type=None, encoding=None, + stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr, + getch_repeat_threshold=DEFAULT_GETCH_REPEAT_THRESHOLD): + CaseClass.__init__(self, + ('term_type', term_type or self._detect_term_type()), + ('encoding', encoding or self._detect_encoding(stdout)), + ('stdin', stdin), + ('stdout', stdout), + ('stderr', stderr), + ('getch_repeat_threshold', getch_repeat_threshold) + ) + self.restore_terminal = self._get_restore_function() # binary function for restoring terminal attributes + self.last_getch_time = 0.0 + self.last_getch_char = '..' + + @staticmethod + def _detect_term_type(): + """ + Detect the type of the terminal. + """ + if os.name == 'nt': + if os.environ.get('TERM') == 'xterm': + # maybe MinTTY + return 'mintty' + else: + return 'nt' + if platform.system().upper().startswith('CYGWIN'): + return 'cygwin' + return 'posix' + + @staticmethod + def _detect_encoding(stdout): + """ + Detect the default encoding for the terminal's output. + :return: string: encoding string + """ + if stdout.encoding: + return stdout.encoding + + if os.environ.get('LANG'): + encoding = os.environ.get('LANG').split('.')[-1] + + # validate the encoding string + ret = None + try: + ret = codecs.lookup(encoding) + except LookupError: + pass + if ret: + return encoding + + return locale.getpreferredencoding() + + def _get_restore_function(self): + """ + Return the binary function for restoring terminal attributes. + :return: function (signal, frame) => None: + """ + if os.name == 'nt': + return lambda signal, frame: None + + assert hasattr(self.stdin, 'fileno'), 'Invalid input device.' + fd = self.stdin.fileno() + + try: + initial = termios.tcgetattr(fd) + except termios.error: + return lambda signal, frame: None + + return lambda signal, frame: termios.tcsetattr(fd, termios.TCSADRAIN, initial) + + def clear(self): + """ + Clear the terminal screen. + """ + if self.stdout.isatty() or self.term_type == 'mintty': + cmd, shell = { + 'posix': ('clear', False), + 'nt': ('cls', True), + 'cygwin': (['echo', '-en', r'\ec'], False), + 'mintty': (r'echo -en "\ec', False), + }[self.term_type] + subprocess.call(cmd, shell=shell, stdin=self.stdin, stdout=self.stdout, stderr=self.stderr) + + def clear_input_buffer(self): + """ + Clear the input buffer. + """ + if self.stdin.isatty(): + if os.name == 'nt': + while msvcrt.kbhit(): + msvcrt.getch() + else: + try: + self.stdin.seek(0, 2) # may fail in some unseekable file object + except IOError: + pass + + def getch(self): + """ + Read one character from stdin. + + If stdin is not a tty, read input as one line. + :return: unicode: + """ + ch = self._get_one_char() + self.clear_input_buffer() + + try: + # accept only unicode characters (for Python 2) + uch = to_unicode(ch, 'ascii') + except UnicodeError: + return '' + + return uch if self._check_key_repeat(uch) else '' + + def _get_one_char(self): + if not self.stdin.isatty(): # pipeline or MinTTY + return self.gets()[:1] + elif os.name == 'nt': # Windows + return msvcrt.getch() + else: # POSIX + try: + tty.setraw(self.stdin.fileno()) + return self.stdin.read(1) + finally: + self.restore_terminal(None, None) + + def _check_key_repeat(self, ch): + if self.getch_repeat_threshold <= 0.0: + return True + + t = time.time() + if ch == self.last_getch_char and t < self.last_getch_time + self.getch_repeat_threshold: + return False + + self.last_getch_time = t + self.last_getch_char = ch + return True + + def gets(self): + """ + Read line from stdin. + + The trailing newline will be omitted. + :return: string: + """ + ret = self.stdin.readline() + if ret == '': + raise EOFError # To break out of EOF loop + return ret.rstrip('\n') diff --git a/src/mog_commons/unittest.py b/src/mog_commons/unittest.py index 019228e..f68ec5b 100644 --- a/src/mog_commons/unittest.py +++ b/src/mog_commons/unittest.py @@ -15,6 +15,12 @@ from mog_commons.string import to_bytes, to_str +__all__ = [ + 'FakeInput', + 'FakeBytesInput', + 'TestCase', +] + class StringBuffer(object): """ @@ -39,6 +45,32 @@ def getvalue(self, encoding='utf-8', errors='strict'): return self._buffer.decode(encoding, errors) +class FakeInput(six.StringIO): + """Fake input object""" + + def __init__(self, buff=None): + six.StringIO.__init__(self, buff or '') + + def fileno(self): + return 0 + + def isatty(self): + return True + + +class FakeBytesInput(six.BytesIO): + """Fake bytes input object""" + + def __init__(self, buff=None): + six.BytesIO.__init__(self, buff or b'') + + def fileno(self): + return 0 + + def isatty(self): + return True + + class TestCase(base_unittest.TestCase): def assertRaisesRegexp(self, expected_exception, expected_regexp, callable_obj=None, *args, **kwargs): """ diff --git a/tests/mog_commons/test_terminal.py b/tests/mog_commons/test_terminal.py new file mode 100644 index 0000000..eb0175a --- /dev/null +++ b/tests/mog_commons/test_terminal.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +from __future__ import division, print_function, absolute_import, unicode_literals + +import os +import time +from mog_commons.terminal import TerminalHandler +from mog_commons.unittest import TestCase, base_unittest, FakeBytesInput + + +class TestTerminal(TestCase): + def test_getch_from_file(self): + with open(os.path.join('tests', 'resources', 'test_terminal_input.txt')) as f: + t = TerminalHandler(stdin=f) + self.assertEqual(t.getch(), 'a') + self.assertRaises(EOFError, t.getch) + + @base_unittest.skipUnless(os.name != 'nt', 'requires POSIX compatible') + def test_getch(self): + self.assertEqual(TerminalHandler(stdin=FakeBytesInput(b'')).getch(), '') + self.assertEqual(TerminalHandler(stdin=FakeBytesInput(b'\x03')).getch(), '\x03') + self.assertEqual(TerminalHandler(stdin=FakeBytesInput(b'abc')).getch(), 'a') + self.assertEqual(TerminalHandler(stdin=FakeBytesInput('あ'.encode('utf-8'))).getch(), '') + self.assertEqual(TerminalHandler(stdin=FakeBytesInput('あ'.encode('sjis'))).getch(), '') + + @base_unittest.skipUnless(os.name != 'nt', 'requires POSIX compatible') + def test_getch_key_repeat(self): + fin = FakeBytesInput(b'abcde') + t = TerminalHandler(stdin=fin) + + def append_char(ch): + fin.write(ch) + fin.seek(-len(ch), 1) + + self.assertEqual(t.getch(), 'a') + append_char(b'x') + self.assertEqual(t.getch(), 'x') + append_char(b'x') + self.assertEqual(t.getch(), '') + append_char(b'y') + self.assertEqual(t.getch(), 'y') + append_char(b'y') + self.assertEqual(t.getch(), '') + + time.sleep(1) + append_char(b'y') + self.assertEqual(t.getch(), 'y') diff --git a/tests/resources/test_terminal_input.txt b/tests/resources/test_terminal_input.txt new file mode 100644 index 0000000..00dedf6 --- /dev/null +++ b/tests/resources/test_terminal_input.txt @@ -0,0 +1 @@ +abcde From be6b1f152f90e3a014c7beca336d4856e361e976 Mon Sep 17 00:00:00 2001 From: Yosuke Mizutani Date: Wed, 4 Nov 2015 00:14:01 +0900 Subject: [PATCH 2/2] add a test case connects #15 --- tests/mog_commons/test_terminal.py | 41 +++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/tests/mog_commons/test_terminal.py b/tests/mog_commons/test_terminal.py index eb0175a..43848ce 100644 --- a/tests/mog_commons/test_terminal.py +++ b/tests/mog_commons/test_terminal.py @@ -25,22 +25,49 @@ def test_getch(self): @base_unittest.skipUnless(os.name != 'nt', 'requires POSIX compatible') def test_getch_key_repeat(self): fin = FakeBytesInput(b'abcde') - t = TerminalHandler(stdin=fin) def append_char(ch): fin.write(ch) fin.seek(-len(ch), 1) - self.assertEqual(t.getch(), 'a') + t1 = TerminalHandler(stdin=fin) + self.assertEqual(t1.getch(), 'a') append_char(b'x') - self.assertEqual(t.getch(), 'x') + self.assertEqual(t1.getch(), 'x') append_char(b'x') - self.assertEqual(t.getch(), '') + self.assertEqual(t1.getch(), '') + append_char(b'x') + self.assertEqual(t1.getch(), '') + append_char(b'y') + self.assertEqual(t1.getch(), 'y') + append_char(b'y') + self.assertEqual(t1.getch(), '') + + time.sleep(1) + append_char(b'y') + self.assertEqual(t1.getch(), 'y') + + @base_unittest.skipUnless(os.name != 'nt', 'requires POSIX compatible') + def test_getch_key_repeat_disabled(self): + fin = FakeBytesInput(b'abcde') + + def append_char(ch): + fin.write(ch) + fin.seek(-len(ch), 1) + + t1 = TerminalHandler(stdin=fin, getch_repeat_threshold=0) + self.assertEqual(t1.getch(), 'a') + append_char(b'x') + self.assertEqual(t1.getch(), 'x') + append_char(b'x') + self.assertEqual(t1.getch(), 'x') + append_char(b'x') + self.assertEqual(t1.getch(), 'x') append_char(b'y') - self.assertEqual(t.getch(), 'y') + self.assertEqual(t1.getch(), 'y') append_char(b'y') - self.assertEqual(t.getch(), '') + self.assertEqual(t1.getch(), 'y') time.sleep(1) append_char(b'y') - self.assertEqual(t.getch(), 'y') + self.assertEqual(t1.getch(), 'y')