From 377561efae3f7494b423fbda676d96b83c493312 Mon Sep 17 00:00:00 2001 From: Yosuke Mizutani Date: Wed, 4 Nov 2015 02:25:42 +0900 Subject: [PATCH] toggle to enable getch closes #23 --- src/mog_commons/__init__.py | 2 +- src/mog_commons/terminal.py | 19 ++++++++++++------- tests/mog_commons/test_terminal.py | 16 +++++++++++++++- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/src/mog_commons/__init__.py b/src/mog_commons/__init__.py index b04cffb..60eb1af 100644 --- a/src/mog_commons/__init__.py +++ b/src/mog_commons/__init__.py @@ -1 +1 @@ -__version__ = '0.1.17' +__version__ = '0.1.18' diff --git a/src/mog_commons/terminal.py b/src/mog_commons/terminal.py index 3cdf611..2362675 100644 --- a/src/mog_commons/terminal.py +++ b/src/mog_commons/terminal.py @@ -46,7 +46,7 @@ class TerminalHandler(CaseClass): def __init__(self, term_type=None, encoding=None, stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr, getch_repeat_threshold=DEFAULT_GETCH_REPEAT_THRESHOLD, - keep_input_clean=True): + keep_input_clean=True, getch_enabled=True): CaseClass.__init__(self, ('term_type', term_type or self._detect_term_type()), ('encoding', encoding or self._detect_encoding(stdout)), @@ -55,11 +55,18 @@ def __init__(self, term_type=None, encoding=None, ('stderr', stderr), ('getch_repeat_threshold', getch_repeat_threshold), ('keep_input_clean', keep_input_clean), + ('getch_enabled', getch_enabled and self._can_getch_enable(stdin)) ) self.restore_terminal = self._get_restore_function() # binary function for restoring terminal attributes self.last_getch_time = 0.0 self.last_getch_char = '..' + @staticmethod + def _can_getch_enable(stdin): + if stdin.isatty(): + return os.name == 'nt' or hasattr(stdin, 'fileno') + return False + @staticmethod def _detect_term_type(): """ @@ -103,13 +110,11 @@ def _get_restore_function(self): Return the binary function for restoring terminal attributes. :return: function (signal, frame) => None: """ - if os.name == 'nt': + if os.name == 'nt' or not self.getch_enabled: return lambda signal, frame: None - assert hasattr(self.stdin, 'fileno'), 'Invalid input device.' - fd = self.stdin.fileno() - try: + fd = self.stdin.fileno() initial = termios.tcgetattr(fd) except termios.error: return lambda signal, frame: None @@ -147,7 +152,7 @@ def getch(self): """ Read one character from stdin. - If stdin is not a tty, read input as one line. + If stdin is not a tty or set `getch_enabled`=False, read input as one line. :return: unicode: """ ch = self._get_one_char() @@ -163,7 +168,7 @@ def getch(self): return uch if self._check_key_repeat(uch) else '' def _get_one_char(self): - if not self.stdin.isatty(): # pipeline or MinTTY + if not self.getch_enabled: return self.gets()[:1] elif os.name == 'nt': # Windows return msvcrt.getch() diff --git a/tests/mog_commons/test_terminal.py b/tests/mog_commons/test_terminal.py index 0db24fd..26d75d2 100644 --- a/tests/mog_commons/test_terminal.py +++ b/tests/mog_commons/test_terminal.py @@ -5,7 +5,7 @@ import time import six from mog_commons.terminal import TerminalHandler -from mog_commons.unittest import TestCase, base_unittest, FakeBytesInput +from mog_commons.unittest import TestCase, base_unittest, FakeBytesInput, FakeInput class TestTerminal(TestCase): @@ -23,6 +23,14 @@ def test_getch(self): self.assertEqual(TerminalHandler(stdin=FakeBytesInput('あ'.encode('utf-8'))).getch(), '') self.assertEqual(TerminalHandler(stdin=FakeBytesInput('あ'.encode('sjis'))).getch(), '') + def test_getch_disabled(self): + t = TerminalHandler(stdin=FakeInput('a\nb\ncd\ne\n'), keep_input_clean=False, getch_enabled=False) + self.assertEqual(t.getch(), 'a') + self.assertEqual(t.getch(), 'b') + self.assertEqual(t.getch(), 'c') + self.assertEqual(t.getch(), 'e') + self.assertRaises(EOFError, t.getch) + @base_unittest.skipUnless(os.name != 'nt', 'requires POSIX compatible') def test_getch_key_repeat(self): fin = FakeBytesInput(b'abcde') @@ -95,3 +103,9 @@ def test_resolve_encoding(self): out = io.TextIOWrapper(six.StringIO(), 'sjis') self.assertEqual(TerminalHandler._detect_encoding(out), 'sjis') + + def test_init(self): + self.assertEqual(TerminalHandler(stdin=six.StringIO(), getch_enabled=False).getch_enabled, False) + self.assertEqual(TerminalHandler(stdin=six.StringIO(), getch_enabled=True).getch_enabled, False) + self.assertEqual(TerminalHandler(stdin=FakeInput(), getch_enabled=False).getch_enabled, False) + self.assertEqual(TerminalHandler(stdin=FakeInput(), getch_enabled=True).getch_enabled, True)