From ae491709a20499b7aaa4fa4ba83310d0f3cef497 Mon Sep 17 00:00:00 2001 From: dead-beef Date: Sun, 25 Feb 2018 20:04:39 +0000 Subject: [PATCH] update tests --- setup.py | 6 +- test | 4 +- tests/cli/test_util.py | 562 +++++++++++++++------------------- tests/image/test_markov.py | 253 ++++++++------- tests/image/test_scanner.py | 291 +++++++++--------- tests/image/test_traversal.py | 341 +++++++++++---------- tests/image/test_util.py | 94 +++--- tests/storage/__init__.py | 0 tests/storage/test_base.py | 59 ++++ tests/storage/test_json.py | 82 +++++ tests/storage/test_sqlite.py | 79 +++++ tests/test_base.py | 136 ++++++-- tests/test_json.py | 91 ------ tests/test_parser.py | 287 ++++++++--------- tests/test_scanner.py | 212 +++++++------ tests/test_sqlite.py | 147 --------- tests/test_util.py | 344 ++++++++++----------- tests/text/test_markov.py | 44 +++ tests/text/test_util.py | 109 ++++--- 19 files changed, 1603 insertions(+), 1538 deletions(-) create mode 100644 tests/storage/__init__.py create mode 100644 tests/storage/test_base.py create mode 100644 tests/storage/test_json.py create mode 100644 tests/storage/test_sqlite.py delete mode 100644 tests/test_json.py delete mode 100644 tests/test_sqlite.py create mode 100644 tests/text/test_markov.py diff --git a/setup.py b/setup.py index bd9128d..86bac66 100755 --- a/setup.py +++ b/setup.py @@ -1,11 +1,8 @@ #!/usr/bin/env python3 import os -from unittest import TestLoader from setuptools import setup, find_packages -def tests(): - return TestLoader().discover('tests') BASE_DIR = os.path.abspath(os.path.dirname(__file__)) try: @@ -39,12 +36,13 @@ def tests(): entry_points={ 'console_scripts': ['markovchain=markovchain.cli:main'], }, - test_suite='setup.tests', install_requires=['tqdm', 'ijson', 'custom_inherit'], extras_require={ 'image': ['pillow'], 'dev': [ 'pillow', + 'pytest', + 'pytest-mock', 'coverage', 'sphinx', 'sphinx_rtd_theme', diff --git a/test b/test index ede5223..ab6a048 100755 --- a/test +++ b/test @@ -5,7 +5,7 @@ if [ -d env ]; then fi if which coverage >/dev/null 2>&1; then - coverage run --include 'markovchain/*' setup.py test && coverage report -m + coverage run --include 'markovchain/*' -m pytest && coverage report -m else - ./setup.py test + pytest fi diff --git a/tests/cli/test_util.py b/tests/cli/test_util.py index c4dee05..60b0c46 100644 --- a/tests/cli/test_util.py +++ b/tests/cli/test_util.py @@ -1,335 +1,251 @@ -from unittest import TestCase -from unittest.mock import Mock, MagicMock, mock_open, patch +from unittest.mock import Mock, MagicMock, mock_open from io import StringIO from argparse import Namespace import json import sys +import pytest + +from markovchain import JsonStorage, SqliteStorage from markovchain.cli.util import ( no_tqdm, NoProgressBar, - pprint, - load, save, IJSON_MIN_SIZE, IJSON_MIN_COMPRESSED_SIZE, + pprint, load, save, set_args, JSON, SQLITE, check_output_format, infiles, outfiles ) -from markovchain import MarkovBase, MarkovJsonMixin, MarkovSqliteMixin - -class TestNoProgressBar(TestCase): - @patch('sys.stderr', new_callable=StringIO) - def testWarning(self, stderr): - p = NoProgressBar() - self.assertEqual(stderr.getvalue(), '') - self.assertFalse(p.warning) - NoProgressBar.print_warning() - msg = stderr.getvalue() - self.assertNotEqual(msg, '') - self.assertTrue(p.warning) - NoProgressBar.print_warning() - self.assertEqual(stderr.getvalue(), msg) - self.assertTrue(p.warning) - - @patch('markovchain.cli.util.NoProgressBar') - def testNoTqdm(self, no_pbar): - p = no_tqdm(total=100, leave=False, desc='') - self.assertIsInstance(p, Mock) - no_pbar.assert_called_with() - no_pbar.print_warning.assert_called_with() - no_pbar.reset_mock() - it = range(3) - p = no_tqdm(it) - self.assertIs(p, it) - no_pbar.assert_not_called() - no_pbar.print_warning.assert_called_with() - -class TestPPrint(TestCase): - @patch('sys.stdout', new_callable=StringIO) - def testPPrint(self, stdout): - tests = [ - {'x': 0, 'y': [{'z': '0'}, {'z': '1'}]}, - [{'x': 0.1, 'y': {'y': False, 'z': None, 'u': 3}}, {'z': 0}] - ] - for test in tests: - stdout.seek(0) - stdout.truncate(0) - pprint(test) - self.assertEqual( - stdout.getvalue(), - json.dumps(test, indent=4, sort_keys=True) + '\n' - ) - - @patch('sys.stdout', new_callable=StringIO) - def testPPrintArray(self, stdout): - tests = [ - ([], '[]\n'), - ([0, 1, 2], '[0, 1, 2]\n'), - ([True, None], '[true, null]\n'), - ([0, [1, 2], 3], '[\n 0,\n [1, 2],\n 3\n]\n') - ] - for test, res in tests: - stdout.seek(0) - stdout.truncate(0) - pprint(test) - self.assertEqual(stdout.getvalue(), res) - -class TestLoad(TestCase): - @patch('builtins.open', new_callable=mock_open) - @patch('bz2.open', new_callable=mock_open) - @patch('os.path.getsize', return_value=IJSON_MIN_SIZE - 1) - @patch('sys.stdout', new_callable=StringIO) - def testOpen(self, stdout, getsize, bz2op, op): - class Test: - pass - Test.load = MagicMock(return_value=0) - markov = Test - fname = 'test' - args = Namespace(progress=False, settings={}) - - self.assertEqual(load(markov, fname, args), 0) - self.assertFalse(getsize.called) - self.assertFalse(op.called) - self.assertFalse(bz2op.called) - markov.load.assert_called_with(fname, args.settings) - - class Test2(MarkovJsonMixin): - pass - Test2.load = MagicMock(return_value=0) - markov = Test2 - - self.assertEqual(load(markov, fname, args), 0) - getsize.assert_called_with(fname) - op.assert_called_with(fname, 'rt') - self.assertFalse(bz2op.called) - - getsize.return_value = IJSON_MIN_SIZE + 1 - self.assertEqual(load(markov, fname, args), 0) - op.assert_called_with(fname, 'rb') - self.assertFalse(bz2op.called) - - fname = '.bz2' - getsize.return_value = IJSON_MIN_COMPRESSED_SIZE - 1 - op.called = False - self.assertEqual(load(markov, fname, args), 0) - bz2op.assert_called_with(fname, 'rt') - self.assertFalse(op.called) - - getsize.return_value = IJSON_MIN_COMPRESSED_SIZE + 1 - self.assertEqual(load(markov, fname, args), 0) - bz2op.assert_called_with(fname, 'rb') - self.assertFalse(op.called) - - self.assertEqual(stdout.getvalue(), '') - - @patch('builtins.open', new_callable=mock_open) - @patch('bz2.open', new_callable=mock_open) - @patch('os.path.getsize', return_value=IJSON_MIN_SIZE - 1) - @patch('sys.stdout', new_callable=StringIO) - def testProgress(self, stdout, getsize, bz2op, op): # pylint:disable=unused-argument - class Test(MarkovJsonMixin): - pass - Test.load = MagicMock(return_value=0) - markov = Test - fname = 'test' - args = Namespace(progress=True, settings={}) - - self.assertEqual(load(markov, fname, args), 0) - self.assertNotEqual(stdout.getvalue(), '') - - class Test2(): - pass - Test2.load = MagicMock(return_value=0) - markov = Test2 - stdout.seek(0) - stdout.truncate(0) - self.assertEqual(load(markov, fname, args), 0) - self.assertEqual(stdout.getvalue(), '') - - -class TestSave(TestCase): - @patch('builtins.open', new_callable=mock_open) - @patch('bz2.open', new_callable=mock_open) - @patch('sys.stdout', new_callable=StringIO) - def testOpen(self, stdout, bz2op, op): - markov = MagicMock() - fname = 'test' - args = Namespace(progress=False) - - save(markov, fname, args) - self.assertFalse(op.called) - self.assertFalse(bz2op.called) - markov.save.assert_called_with() - - markov = MagicMock(spec=MarkovJsonMixin) - save(markov, fname, args) - op.assert_called_with(fname, 'wt') - self.assertFalse(bz2op.called) - self.assertTrue(markov.save.called) - - markov.save.reset_mock() - op.called = False - save(markov, None, args) - self.assertFalse(op.called) - self.assertFalse(bz2op.called) - markov.save.assert_called_with(stdout) - - fname = '.bz2' - op.called = False - markov.save.reset_mock() - save(markov, fname, args) - bz2op.assert_called_with(fname, 'wt') - self.assertFalse(op.called) - self.assertTrue(markov.save.called) - - self.assertEqual(stdout.getvalue(), '') - - @patch('builtins.open', new_callable=mock_open) - @patch('bz2.open', new_callable=mock_open) - @patch('sys.stdout', new_callable=StringIO) - def testProgress(self, stdout, bz2op, op): # pylint:disable=unused-argument - markov = MagicMock(spec=MarkovJsonMixin) - fname = 'test' - args = Namespace(progress=True) - - save(markov, fname, args) - self.assertNotEqual(stdout.getvalue(), '') - - stdout.seek(0) - stdout.truncate(0) - save(markov, None, args) - self.assertEqual(stdout.getvalue(), '') - - markov = MagicMock() - save(markov, fname, args) - self.assertEqual(stdout.getvalue(), '') - - -class TestSetArgs(TestCase): - def testErrors(self): - args = Namespace(output=sys.stdout, progress=True) - with self.assertRaises(ValueError): - set_args(args, ()) - - def testType(self): - args = Namespace() - set_args(args, ()) - self.assertEqual(args.type, JSON) - self.assertTrue(issubclass(args.markov, MarkovBase)) - self.assertTrue(issubclass(args.markov, MarkovJsonMixin)) - - args = Namespace() - set_args(args, (Mock,)) - self.assertEqual(args.type, JSON) - self.assertTrue(issubclass(args.markov, Mock)) - self.assertTrue(issubclass(args.markov, MarkovBase)) - self.assertTrue(issubclass(args.markov, MarkovJsonMixin)) - - args = Namespace(output='.db') - set_args(args, ()) - self.assertEqual(args.type, SQLITE) - self.assertTrue(issubclass(args.markov, MarkovBase)) - self.assertTrue(issubclass(args.markov, MarkovSqliteMixin)) - - args = Namespace(output='.db', state='.json.bz2') - set_args(args, ()) - self.assertEqual(args.type, JSON) - - args = Namespace(output='.db', state='.json.bz2', type='sqlite') - set_args(args, ()) - self.assertEqual(args.type, SQLITE) - - args = Namespace(output='.db', state='.json.bz2', type='json') - set_args(args, ()) - self.assertEqual(args.type, JSON) - - @patch('json.load', new=lambda x: x()) - def testSettings(self): - args = Namespace() - set_args(args, ()) - self.assertIsNone(args.settings) - - args = Namespace(settings=None) - set_args(args, ()) - self.assertEqual(args.settings, {}) - - s = list(range(3)) - mock = MagicMock(return_value=s) - args = Namespace(settings=mock) - set_args(args, ()) - self.assertEqual(args.settings, s) - mock.close.assert_called_with() - -class TestCheckOutputFormat(TestCase): - def testError(self): - tests = [ - ('test', -1), - ('test', 0), - ('test', 2), - ('test%d%d', 2) - ] - for test in tests: - with self.assertRaises(ValueError): - check_output_format(*test) - - def testNoError(self): # pylint: disable=no-self-use - tests = [ - ('test', 1), - ('test%d', 2) - ] - for test in tests: - check_output_format(*test) - - -class TestInFiles(TestCase): - @patch('markovchain.cli.util.tqdm') - def testNoProgress(self, tqdm): - tests = [ - ([], True), - (list(range(3)), False) - ] - for test in tests: - with infiles(*test) as it: - self.assertFalse(tqdm.called) - self.assertIs(it, test[0]) - - @patch('markovchain.cli.util.tqdm') - def testProgress(self, tqdm): - test = list(range(3)) - with infiles(test, True) as pbar: - self.assertTrue(tqdm.called) - self.assertTrue((test,) in tqdm.call_args) - self.assertIsInstance(pbar, Mock) - pbar.close.assert_called_with() - - -class TestOutFiles(TestCase): - def testError(self): - tests = [ - ('', -1, False), - ('', 0, False) - ] - for test in tests: - with self.assertRaises(ValueError), outfiles(*test): - pass - - @patch('markovchain.cli.util.tqdm') - def testNoProgress(self, tqdm): - tests = [ - (('', 1, False), ['']), - (('%d', 2, False), ['0', '1']) - ] - for test, res in tests: - with outfiles(*test) as it: - it = list(it) - self.assertFalse(tqdm.called) - self.assertEqual(it, res) - - @patch('markovchain.cli.util.tqdm') - def testProgress(self, tqdm): - test = ('%d', 3, True) - res = ['0', '1', '2'] - with outfiles(*test) as pbar: - self.assertTrue(tqdm.called) - self.assertEqual(list(tqdm.call_args[0][0]), res) - self.assertIsInstance(pbar, Mock) - pbar.close.assert_called_with() + + +def test_no_progress_bar_warning(mocker): + stderr = mocker.patch('sys.stderr', new_callable=StringIO) + pbar = NoProgressBar() + assert stderr.getvalue() == '' + assert not NoProgressBar.warning + pbar.print_warning() + assert stderr.getvalue() != '' + assert NoProgressBar.warning + pbar.print_warning() + assert stderr.getvalue() != '' + assert NoProgressBar.warning + +def test_no_tqdm(mocker): + no_pbar = mocker.patch('markovchain.cli.util.NoProgressBar') + pbar = no_tqdm(total=100, leave=False, desc='') + assert isinstance(pbar, Mock) + no_pbar.assert_called_with() + no_pbar.print_warning.assert_called_with() + no_pbar.reset_mock() + iterable = range(3) + pbar = no_tqdm(iterable) + assert pbar is iterable + no_pbar.assert_not_called() + no_pbar.print_warning.assert_called_with() + + +@pytest.mark.parametrize('test', [ + {'x': 0, 'y': [{'z': '0'}, {'z': '1'}]}, + [{'x': 0.1, 'y': {'y': False, 'z': None, 'u': 3}}, {'z': 0}] +]) +def test_pprint(mocker, test): + stdout = mocker.patch('sys.stdout', new_callable=StringIO) + res = json.dumps(test, indent=4, sort_keys=True) + '\n' + pprint(test) + assert stdout.getvalue() == res + +@pytest.mark.parametrize('test,res', [ + ([], '[]\n'), + ([0, 1, 2], '[0, 1, 2]\n'), + ([True, None], '[true, null]\n'), + ([0, [1, 2], 3], '[\n 0,\n [1, 2],\n 3\n]\n') +]) +def test_pprint_array(mocker, test, res): + stdout = mocker.patch('sys.stdout', new_callable=StringIO) + pprint(test) + assert stdout.getvalue() == res + + +@pytest.mark.parametrize('fname,bz2,stdout', [ + ('test.json', False, False), + ('test.json.bz2', True, True) +]) +def test_load_json(mocker, fname, bz2, stdout): + open_ = mocker.patch('builtins.open', new_callable=mock_open) + bz2open = mocker.patch('bz2.open', new_callable=mock_open) + stdout_ = mocker.patch('sys.stdout', new_callable=StringIO) + json_storage = MagicMock() + json_storage_cls = mocker.patch( + 'markovchain.cli.util.JsonStorage', + load=Mock(return_value=json_storage) + ) + if bz2: + handle = bz2open() + else: + handle = open_() + + cls = Mock(load=Mock(return_value=0)) + args = Namespace(type=JSON, progress=stdout, settings={}) + + assert load(cls, fname, args) == 0 + assert (stdout_.getvalue() != '') == stdout + if bz2: + assert not open_.called + bz2open.assert_called_with(fname, 'rt') + else: + assert not bz2open.called + open_.assert_called_with(fname, 'rt') + json_storage_cls.load.assert_called_once_with(handle) + cls.load.assert_called_once_with(json_storage) + +def test_load_sqlite(mocker): + sqlite_storage = MagicMock() + sqlite_storage_cls = mocker.patch( + 'markovchain.cli.util.SqliteStorage', + load=Mock(return_value=sqlite_storage) + ) + fname = 'test' + cls = Mock(load=Mock(return_value=0)) + args = Namespace(type=SQLITE, progress=False, settings={}) + assert load(cls, fname, args) == 0 + sqlite_storage_cls.load.assert_called_once_with(fname) + cls.load.assert_called_once_with(sqlite_storage) + + +@pytest.mark.parametrize('fname,bz2,stdout', [ + ('test.json', False, False), + ('test.json.bz2', True, True) +]) +def test_save_json(mocker, fname, bz2, stdout): + open_ = mocker.patch('builtins.open', new_callable=mock_open) + bz2open = mocker.patch('bz2.open', new_callable=mock_open) + stdout_ = mocker.patch('sys.stdout', new_callable=StringIO) + markov = Mock(storage=JsonStorage(), save=Mock()) + + if bz2: + handle = bz2open() + else: + handle = open_() + + args = Namespace(progress=stdout, settings={}) + + save(markov, fname, args) + assert (stdout_.getvalue() != '') == stdout + if bz2: + assert not open_.called + bz2open.assert_called_with(fname, 'wt') + else: + assert not bz2open.called + open_.assert_called_with(fname, 'wt') + markov.save.assert_called_once_with(handle) + +def test_save_sqlite(): + markov = Mock(storage=SqliteStorage(), save=Mock()) + fname = 'test' + args = Namespace(progress=False, settings={}) + save(markov, fname, args) + markov.save.assert_called_once_with() + + +def test_set_args_error(): + args = Namespace(output=sys.stdout, progress=True) + with pytest.raises(ValueError): + set_args(args) + +@pytest.mark.parametrize('args,res', [ + (Namespace(), JSON), + (Namespace(output='.db'), SQLITE), + (Namespace(state='.json'), JSON), + (Namespace(state='.json.bz2'), JSON), + (Namespace(type='json'), JSON), + (Namespace(output='.db', state='.json.bz2'), JSON), + (Namespace(output='.db', state='.json.bz2', type='sqlite'), SQLITE) +]) +def test_set_args_type(args, res): + set_args(args) + assert args.type == res + +def test_set_args_settings(mocker): + mocker.patch('json.load', new=lambda x: x()) + + args = Namespace() + set_args(args) + assert args.settings == {} # pylint:disable=no-member + + args = Namespace(settings=None) + set_args(args) + assert args.settings == {} # pylint:disable=no-member + + s = list(range(3)) + mock = MagicMock(return_value=s) + args = Namespace(settings=mock) + set_args(args) + assert args.settings == s # pylint:disable=no-member + mock.close.assert_called_with() + + +@pytest.mark.parametrize('test', [ + ('test', -1), + ('test', 0), + ('test', 2), + ('test%d%d', 2) +]) +def test_check_output_format_error(test): + with pytest.raises(ValueError): + check_output_format(*test) + +@pytest.mark.parametrize('test', [ + ('test', 1), + ('test%d', 2) +]) +def test_check_output_format(test): + check_output_format(*test) + + +@pytest.mark.parametrize('test', [ + ([], True), + (list(range(3)), False) +]) +def test_infiles(mocker, test): + tqdm = mocker.patch('markovchain.cli.util.tqdm') + with infiles(*test) as iter_: + assert not tqdm.called + assert iter_ is test[0] + +def test_infiles_progress(mocker): + tqdm = mocker.patch('markovchain.cli.util.tqdm') + test = [0, 1] + with infiles(test, True) as pbar: + assert tqdm.called + assert (test,) in tqdm.call_args + assert isinstance(pbar, Mock) + pbar.close.assert_called_once_with() + + +@pytest.mark.parametrize('test', [ + ('', -1, False), + ('', 0, False) +]) +def test_outfiles_error(test): + with pytest.raises(ValueError), outfiles(*test): + pass + +@pytest.mark.parametrize('test,res', [ + (('', 1, False), ['']), + (('%d', 2, False), ['0', '1']) +]) +def test_outfiles(mocker, test, res): + tqdm = mocker.patch('markovchain.cli.util.tqdm') + with outfiles(*test) as files: + files = list(files) + assert not tqdm.called + assert files == res + +def test_outfiles_progress(mocker): + tqdm = mocker.patch('markovchain.cli.util.tqdm') + test = ('%d', 3, True) + res = ['0', '1', '2'] + with outfiles(*test) as pbar: + assert tqdm.called + assert list(tqdm.call_args[0][0]) == res + assert isinstance(pbar, Mock) + pbar.close.assert_called_once_with() diff --git a/tests/image/test_markov.py b/tests/image/test_markov.py index 6bc14d6..481c858 100644 --- a/tests/image/test_markov.py +++ b/tests/image/test_markov.py @@ -1,70 +1,111 @@ -from unittest import TestCase +import pytest from PIL import Image -from markovchain import MarkovBase, MarkovJsonMixin, Scanner, Parser -from markovchain.image import MarkovImageMixin, HLines, VLines - - -class TestMarkovImage(TestCase): - class Markov(MarkovImageMixin, MarkovJsonMixin, MarkovBase): - pass - - @classmethod - def setUpClass(cls): - palette = [ - 0x00, 0x00, 0x00, - 0x44, 0x44, 0x44, - 0xaa, 0xaa, 0xaa, - 0xdd, 0xdd, 0xdd +from markovchain import Scanner, Parser, LevelParser +from markovchain.image import MarkovImage, ImageScanner, HLines, VLines + + +@pytest.fixture +def palette_test(): + palette = [ + 0x00, 0x00, 0x00, + 0x44, 0x44, 0x44, + 0xaa, 0xaa, 0xaa, + 0xdd, 0xdd, 0xdd + ] + palette.extend(0 for _ in range((256 - len(palette)) * 3)) + return palette + + +def test_markov_image_properties(): + markov = MarkovImage() + assert isinstance(markov.scanner, ImageScanner) + assert isinstance(markov.parser, LevelParser) + assert markov.scanner.levels == markov.levels + assert markov.parser.levels == markov.levels + markov.levels = 3 + assert markov.scanner.levels == markov.levels + assert markov.parser.levels == markov.levels + with pytest.raises(ValueError): + markov.levels = -1 + assert markov.scanner.levels == markov.levels + assert markov.parser.levels == markov.levels + assert markov.levels == 3 + +def test_markov_image_generate_error(): + markov = MarkovImage() + with pytest.raises(RuntimeError): + markov(2, 2) + +@pytest.mark.parametrize('test,res', [ + ((2, 2), [0, 1, 2, 3]), + ((4, 2), [0, 1, 2, 3, 0, 1, 2, 3]) +]) +def test_markov_image_generate(palette_test, test, res): + scanner = Scanner(lambda x: x) + scanner.traversal = [HLines()] + scanner.levels = 1 + scanner.level_scale = [] + markov = MarkovImage( + palette=palette_test, + scanner=scanner, + parser=Parser() + ) + markov.data([['00', '01', '02', '03']]) + assert list(markov(*test).getdata()) == res + +@pytest.mark.parametrize('args,kwargs,data,res', [ + ((2, 2), {}, False, RuntimeError), + ((2, 2), {'levels': 0}, True, ValueError), + ((2, 2), {'levels': 3}, True, ValueError), + ((2, 2), {'levels': 1}, True, [0, 1, 2, 3]), + ( + (2, 2), + {'levels': 2}, + True, + [ + 0, 1, 1, 2, + 1, 1, 2, 2, + 2, 3, 3, 0, + 3, 3, 0, 0 ] - palette.extend(0 for _ in range((256 - len(palette)) * 3)) - cls.palette = palette - - def test_properties(self): - m = self.Markov() - self.assertEqual(m.scanner.levels, m.levels) - self.assertEqual(m.parser.levels, m.levels) - m.levels = 3 - self.assertEqual(m.scanner.levels, m.levels) - self.assertEqual(m.parser.levels, m.levels) - with self.assertRaises(ValueError): - m.levels = -1 - self.assertEqual(m.scanner.levels, m.levels) - self.assertEqual(m.parser.levels, m.levels) - self.assertEqual(m.levels, 3) - - def test_generate(self): - scanner = Scanner(lambda x: x) - scanner.traversal = [HLines()] - scanner.levels = 1 - scanner.level_scale = [] - - m = self.Markov(palette=self.palette, scanner=scanner, parser=Parser()) - - with self.assertRaises(RuntimeError): - m.image(2, 2) - - m.data([['00', '01', '02', '03']]) - - data = list(m.image(2, 2).getdata()) - self.assertEqual(data, [0, 1, 2, 3]) - - data = list(m.image(4, 2).getdata()) - self.assertEqual(data, [0, 1, 2, 3, 0, 1, 2, 3]) - - def test_generate_levels(self): - scanner = Scanner(lambda x: x) - scanner.traversal = [HLines(), VLines()] - scanner.levels = 2 - scanner.level_scale = [2] - scanner.set_palette = lambda img: img - - m = self.Markov(levels=2, palette=self.palette, scanner=scanner) - - with self.assertRaises(RuntimeError): - m.image(2, 2) - - m.data([ + ), + ( + (2, 1), + {'levels': 1, 'start_level': -5, 'start_image': True}, + True, + [0, 1] + ), + ( + (None, None), + {'levels': 1, 'start_level': 0, 'start_image': True}, + True, + [1] + ), + ( + (None, None), + {'levels': 1, 'start_level': 2, 'start_image': True}, + True, + [1] + ), + ( + (None, None), + {'levels': 2, 'start_level': 0, 'start_image': True}, + True, + [1, 2, 2, 2] + ), +]) +def test_markov_image_generate_levels(palette_test, args, kwargs, data, res): + scanner = Scanner(lambda x: x) + scanner.traversal = [HLines(), VLines()] + scanner.levels = 2 + scanner.level_scale = [2] + scanner.set_palette = lambda img: img + + markov = MarkovImage(levels=2, palette=palette_test, scanner=scanner) + + if data: + markov.data([ ['00', '01', '02', '03',], [(Scanner.START, '0000'), '0101', (Scanner.START, '0001'), '0102', @@ -72,55 +113,39 @@ def test_generate_levels(self): (Scanner.START, '0003'), '0100',] ]) - data = list(m.image(2, 2, levels=1).getdata()) # pylint: disable=no-member - self.assertEqual(data, [0, 1, 2, 3]) - - data = list(m.image(2, 2, levels=2).getdata()) # pylint: disable=no-member - self.assertEqual( - data, - [ - 0, 1, 1, 2, - 1, 1, 2, 2, - 2, 3, 3, 0, - 3, 3, 0, 0 - ] - ) - - with self.assertRaises(ValueError): - m.image(2, 2, levels=3) - with self.assertRaises(ValueError): - m.image(2, 2, levels=0) - + if 'start_image' in kwargs: img = Image.new('P', (1, 1)) - img.putpalette(self.palette) + img.putpalette(palette_test) img.putpixel((0, 0), 1) - - data = list(m.image( - 2, 1, levels=1, - start_level=-5, start_image=img - ).getdata()) # pylint: disable=no-member - self.assertEqual(data, [0, 1]) - - data = list(m.image( - None, None, levels=1, - start_level=0, start_image=img - ).getdata()) # pylint: disable=no-member - self.assertEqual(data, [1]) - - data = list(m.image( - None, None, levels=1, - start_level=2, start_image=img - ).getdata()) # pylint: disable=no-member - self.assertEqual(data, [1]) - - data = list(m.image( - None, None, levels=2, - start_level=0, start_image=img - ).getdata()) # pylint: disable=no-member - self.assertEqual(data, [1, 2, 2, 2]) - - def test_save_load(self): - m = self.Markov() - saved = m.get_save_data() - loaded = self.Markov(**saved) - self.assertEqual(m, loaded) + kwargs['start_image'] = img + + if isinstance(res, type): + with pytest.raises(res): + markov(*args, **kwargs) + else: + assert list(markov(*args, **kwargs).getdata()) == res + +def test_markov_image_get_settings_json(mocker, palette_test): + get_settings_json = mocker.patch( + 'markovchain.Markov.get_settings_json', + return_value={'x': 0} + ) + markov = MarkovImage( + levels=2, + palette=palette_test + ) + data = markov.get_settings_json() + assert data == { + 'x': 0, + 'levels': 2, + 'palette': palette_test + } + get_settings_json.assert_called_once_with() + +@pytest.mark.parametrize('test,test2,res', [ + ((), (), True), + ((1,), (2,), False), + ((1, [8, 4, 8]), (1, [8, 4, 4]), False) +]) +def test_markov_image_eq(test, test2, res): + assert (MarkovImage(*test) == MarkovImage(*test2)) == res diff --git a/tests/image/test_scanner.py b/tests/image/test_scanner.py index 6d89502..d986d43 100644 --- a/tests/image/test_scanner.py +++ b/tests/image/test_scanner.py @@ -1,4 +1,4 @@ -from unittest import TestCase +import pytest from PIL import Image from markovchain import Scanner @@ -6,145 +6,152 @@ from markovchain.image.util import palette as default_palette -class TestImageScanner(TestCase): - @classmethod - def setUpClass(cls): - data = b'\x00\x00\x00\xaa\xaa\xaa\xdd\xdd\xdd\x44\x44\x44' - palette = [ - 0x00, 0x00, 0x00, - 0x44, 0x44, 0x44, - 0xaa, 0xaa, 0xaa, - 0xdd, 0xdd, 0xdd +@pytest.fixture +def image_test(): + data = b'\x00\x00\x00\xaa\xaa\xaa\xdd\xdd\xdd\x44\x44\x44' + palette = [ + 0x00, 0x00, 0x00, + 0x44, 0x44, 0x44, + 0xaa, 0xaa, 0xaa, + 0xdd, 0xdd, 0xdd + ] + palette.extend(0 for _ in range((256 - len(palette)) * 3)) + image = Image.frombytes('RGB', (2, 2), data, 'raw', 'RGB') + return image, palette + + +def test_image_scanner_init(): + scan = ImageScanner() + assert scan.palette == default_palette(8, 4, 8) + assert isinstance(scan.palette_image, Image.Image) + assert scan.levels is not None + assert scan.min_size is not None + assert isinstance(scan.level_scale, list) + assert isinstance(scan.traversal, list) + assert len(scan.level_scale) == scan.levels - 1 + assert len(scan.traversal) == scan.levels + +def test_image_scanner_properties(image_test): + _, palette = image_test + scan = ImageScanner() + + scan.palette = palette + + assert scan.palette == palette + assert isinstance(scan.palette_image, Image.Image) + + with pytest.raises(ValueError): + scan.levels = 0 + scan.levels = 3 + assert len(scan.traversal) == scan.levels + + with pytest.raises(ValueError): + scan.level_scale = [] + with pytest.raises(ValueError): + scan.level_scale = [1, -1] + scan.level_scale = range(2, 10) + assert scan.level_scale == [2, 3] + assert scan.min_size == 6 + scan.level_scale = 2 + assert scan.level_scale == [2, 2] + assert scan.min_size == 4 + scan.levels = 2 + assert len(scan.traversal) == scan.levels + assert scan.level_scale == [2] + assert scan.min_size == 2 + scan.levels = 1 + assert len(scan.traversal) == scan.levels + assert scan.level_scale == [] + assert scan.min_size == 1 + + traversal = [HLines(), VLines()] + scan.traversal = traversal + scan.levels = 2 + assert scan.traversal is traversal + assert scan.level_scale == [2] + +@pytest.mark.parametrize('test', [ + (1, 1), (4, 1), (1, 4) +]) +def test_image_scanner_input(test, image_test): + image, palette = image_test + scan = ImageScanner(palette=palette, resize=test) + img = scan.input(image) + assert img.size <= test + assert img.mode == 'RGB' + img = [scan.level(img, level) for level in range(scan.levels)] + assert len(img) == 1 + img = img[0] + assert img.mode == 'P' + assert img.size == (1, 1) + assert list(img.getdata()) in [[0], [1]] + +def test_image_scanner_input_error(image_test): + image, palette = image_test + scan = ImageScanner(palette=palette, levels=3, level_scale=2) + with pytest.raises(ValueError): + scan.input(image) + +def test_image_scanner_input_levels(image_test): + image, palette = image_test + scan = ImageScanner(palette=palette, levels=2, level_scale=2) + + img = scan.input(image) + img = [scan.level(img, level) for level in range(scan.levels)] + + assert len(img) == 2 + + assert img[0].mode == 'P' + assert img[0].size == (1, 1) + assert list(img[0].getdata()) in [[0], [1]] + + assert img[1].mode == 'P' + assert img[1].size == (2, 2) + assert list(img[1].getdata()) == [0, 2, 3, 1] + +def test_image_scanner_input_level_scale(image_test): + img = Image.new(mode='RGB', size=(48, 48)) + scan = ImageScanner(palette=image_test[1], + levels=4, level_scale=[2, 3, 4]) + img = scan.input(img) + size = [scan.level(img, level).size for level in range(scan.levels)] + assert size == [(2, 2), (4, 4), (12, 12), (48, 48)] + +def test_image_scanner_scan(image_test): + image, palette = image_test + scan = ImageScanner(palette=palette, traversal=HLines()) + assert [list(level) for level in scan(image)] == [ + ['00', '02', '03', '01', scan.END] + ] + +def test_image_scanner_scan_levels(image_test): + image, palette = image_test + scan = ImageScanner(palette=palette, + levels=2, level_scale=2, + traversal=[HLines(), VLines()]) + assert [list(level) for level in scan(image)] in [ + [ + ['00', scan.END], + [(scan.START, '0000'), + '0100', '0103', '0102', '0101', + scan.END] + ], + [ + ['01', scan.END], + [(scan.START, '0001'), + '0100', '0103', '0102', '0101', + scan.END] ] - palette.extend(0 for _ in range((256 - len(palette)) * 3)) - cls.image = Image.frombytes('RGB', (2, 2), data, 'raw', 'RGB') - cls.palette = palette - - def test_properties(self): - scan = ImageScanner() - - self.assertEqual(scan.palette, default_palette(8, 4, 8)) - self.assertIsNotNone(scan.palette_image) - self.assertIsNotNone(scan.levels) - self.assertIsNotNone(scan.min_size) - self.assertIsInstance(scan.level_scale, list) - self.assertIsInstance(scan.traversal, list) - self.assertEqual(len(scan.level_scale), scan.levels - 1) - self.assertEqual(len(scan.traversal), scan.levels) - - scan.palette = self.palette - self.assertIsNotNone(scan.palette) - self.assertIsNotNone(scan.palette_image) - - with self.assertRaises(ValueError): - scan.levels = 0 - scan.levels = 3 - self.assertEqual(len(scan.traversal), scan.levels) - - with self.assertRaises(ValueError): - scan.level_scale = [] - with self.assertRaises(ValueError): - scan.level_scale = [1, -1] - scan.level_scale = range(2, 10) - self.assertEqual(scan.level_scale, [2, 3]) - self.assertEqual(scan.min_size, 6) - scan.level_scale = 2 - self.assertEqual(scan.level_scale, [2, 2]) - self.assertEqual(scan.min_size, 4) - scan.levels = 2 - self.assertEqual(len(scan.traversal), scan.levels) - self.assertEqual(scan.level_scale, [2]) - self.assertEqual(scan.min_size, 2) - scan.levels = 1 - self.assertEqual(len(scan.traversal), scan.levels) - self.assertEqual(scan.level_scale, []) - self.assertEqual(scan.min_size, 1) - - traversal = [HLines(), VLines()] - scan.traversal = traversal - scan.levels = 2 - self.assertIs(scan.traversal, traversal) - self.assertEqual(scan.level_scale, [2]) - - def test_input(self): - tests = [(1, 1), (4, 1), (1, 4)] - for test in tests: - scan = ImageScanner(palette=self.palette, resize=test) - img = scan.input(self.image) - self.assertLessEqual(img.size, test) - self.assertEqual(img.mode, 'RGB') - img = [scan.level(img, level) for level in range(scan.levels)] - self.assertEqual(len(img), 1) - img = img[0] - self.assertEqual(img.mode, 'P') - self.assertEqual(img.size, (1, 1)) - self.assertIn(list(img.getdata()), [[0], [1]]) - - def test_input_levels(self): - scan = ImageScanner(palette=self.palette, levels=2, level_scale=2) - - img = scan.input(self.image) - img = [scan.level(img, level) for level in range(scan.levels)] - - self.assertEqual(len(img), 2) - - self.assertEqual(img[0].mode, 'P') - self.assertEqual(img[0].size, (1, 1)) - self.assertIn(list(img[0].getdata()), [[0], [1]]) - - self.assertEqual(img[1].mode, 'P') - self.assertEqual(img[1].size, (2, 2)) - self.assertEqual(list(img[1].getdata()), [0, 2, 3, 1]) - - scan = ImageScanner(palette=self.palette, levels=3, level_scale=2) - with self.assertRaises(ValueError): - scan.input(self.image) - - def test_input_level_scale(self): - img = Image.new(mode='RGB', size=(48, 48)) - scan = ImageScanner(palette=self.palette, - levels=4, level_scale=[2, 3, 4]) - - img = scan.input(img) - size = [scan.level(img, level).size for level in range(scan.levels)] - - self.assertEqual(size, [(2, 2), (4, 4), (12, 12), (48, 48)]) - - def test_scan(self): - scan = ImageScanner(palette=self.palette, traversal=HLines()) - self.assertEqual([list(level) for level in scan(self.image)], - [['00', '02', '03', '01', scan.END]]) - - def test_scan_levels(self): - scan = ImageScanner(palette=self.palette, - levels=2, level_scale=2, - traversal=[HLines(), VLines()]) - self.assertIn( - [list(level) for level in scan(self.image)], - [ - [ - ['00', scan.END], - [(scan.START, '0000'), - '0100', '0103', '0102', '0101', - scan.END] - ], - [ - ['01', scan.END], - [(scan.START, '0001'), - '0100', '0103', '0102', '0101', - scan.END] - ], - ] - ) - - def test_save_load(self): - tests = [ - (), - ((4, 4), 0, True, self.palette, 2, 2, - Image.NEAREST, [HLines(), VLines()]) - ] - for test in tests: - scanner = ImageScanner(*test) - saved = scanner.save() - loaded = Scanner.load(saved) - self.assertEqual(scanner, loaded) + ] + +@pytest.mark.parametrize('test', [ + (), + ((4, 4), 0, True, image_test()[1], 2, 2, + Image.NEAREST, [HLines(), VLines()]) +]) +def test_image_scanner_save_load(test): + scanner = ImageScanner(*test) + saved = scanner.save() + loaded = Scanner.load(saved) + assert isinstance(loaded, ImageScanner) + assert scanner == loaded diff --git a/tests/image/test_traversal.py b/tests/image/test_traversal.py index 39c695e..db3fc9e 100644 --- a/tests/image/test_traversal.py +++ b/tests/image/test_traversal.py @@ -1,176 +1,177 @@ -from unittest import TestCase from itertools import product +import pytest from markovchain.image.traversal import ( - Traversal, Lines, HLines, VLines, Spiral, Hilbert, Blocks + Traversal, HLines, VLines, Spiral, Hilbert, Blocks ) -class TestLines(TestCase): - def setUp(self): - Traversal.add_class(Lines) - - def tearDown(self): - Traversal.remove_class(Lines) - - def test_save_load(self): - tests = [(0, False), (2, True)] - for test in tests: - test = Lines(*test) - saved = test.save() - loaded = Traversal.load(saved) - self.assertEqual(test, loaded) - - -class TestHLines(TestCase): - def test_traverse(self): - test = HLines() - - test.reverse = 0 - self.assertEqual(list(test(2, 3, False)), - [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2)]) - self.assertEqual(list(test(2, 3, True)), - [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2)]) - test.reverse = 1 - self.assertEqual(list(test(2, 3)), - [(1, 0), (0, 0), (1, 1), (0, 1), (1, 2), (0, 2)]) - test.reverse = 2 - self.assertEqual(list(test(2, 3)), - [(1, 0), (0, 0), (0, 1), (1, 1), (1, 2), (0, 2)]) - test.reverse = 0 - test.line_sentences = True - self.assertEqual(list(test(2, 3, True)), - [(0, 0), (1, 0), None, - (0, 1), (1, 1), None, - (0, 2), (1, 2), None]) - self.assertEqual(list(test(2, 3, False)), - [(0, 0), (1, 0), (0, 1), - (1, 1), (0, 2), (1, 2)]) - - def test_save_load(self): - tests = [(0, False), (2, True)] - for test in tests: - test = HLines(*test) - saved = test.save() - loaded = Traversal.load(saved) - self.assertEqual(test, loaded) - - -class TestVLines(TestCase): - def test_traverse(self): - test = VLines() - - test.reverse = 0 - self.assertEqual(list(test(2, 3)), - [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]) - test.reverse = 1 - self.assertEqual(list(test(2, 3)), - [(0, 2), (0, 1), (0, 0), (1, 2), (1, 1), (1, 0)]) - test.reverse = 2 - self.assertEqual(list(test(2, 3)), - [(0, 2), (0, 1), (0, 0), (1, 0), (1, 1), (1, 2)]) - test.reverse = 0 - test.line_sentences = True - self.assertEqual(list(test(2, 3)), - [(0, 0), (0, 1), (0, 2), None, - (1, 0), (1, 1), (1, 2), None]) - self.assertEqual(list(test(2, 3, False)), - [(0, 0), (0, 1), (0, 2), - (1, 0), (1, 1), (1, 2)]) - - def test_save_load(self): - tests = [(0, False), (2, True)] - for test in tests: - test = VLines(*test) - saved = test.save() - loaded = Traversal.load(saved) - self.assertEqual(test, loaded) - - -class TestSpiral(TestCase): - def test_traverse(self): - test = Spiral() - tests = product(range(1, 7), range(1, 7)) - for width, height in tests: - test.reverse = False - res = list(test(width, height, False)) - self.assertCountEqual(res, product(range(width), range(height))) - test.reverse = True - res2 = list(test(width, height, False)) - res.reverse() - self.assertEqual(res2, res) - - def test_save_load(self): - tests = [(False,), (True,)] - for test in tests: - test = Spiral(*test) - saved = test.save() - loaded = Traversal.load(saved) - self.assertEqual(test, loaded) - - -class TestHilbert(TestCase): - def test_traverse(self): - test = Hilbert() - tests = product(range(1, 7), range(1, 7)) - for width, height in tests: - res = list(test(width, height)) - self.assertCountEqual(res, product(range(width), range(height))) - - def test_save_load(self): - test = Hilbert() - saved = test.save() - loaded = Traversal.load(saved) - self.assertEqual(test, loaded) - - -class TestBlocks(TestCase): - def test_traverse(self): - def hline(width, height, ends): # pylint: disable=unused-argument - for x in range(width): - yield (x, 0) - if ends and x == 0: - yield None - - def vline(width, height, ends): # pylint: disable=unused-argument - for y in range(height): - yield (0, y) - if ends: - yield None - - traverse = Blocks(block_size=(2, 2), - block_sentences=False, - traverse_image=hline, - traverse_block=vline) - - tests = [ - ((4, 4, False), [(0, 0), (0, 1), (2, 0), (2, 1)]), - ((3, 3, False), [(0, 0), (0, 1)]), - ((4, 4, True), [(0, 0), (0, 1), None, (2, 0), (2, 1)]) - ] - - for test, res in tests: - self.assertEqual(list(traverse(*test)), res) - - traverse.block_sentences = True - - tests = [ - ((4, 4, False), [(0, 0), (0, 1), (2, 0), (2, 1)]), - ((3, 3, True), [(0, 0), (0, 1), None, None]), - ((4, 4, True), [(0, 0), (0, 1), None, None, (2, 0), (2, 1), None]) - ] - - for test, res in tests: - self.assertEqual(list(traverse(*test)), res) - - def test_save_load(self): - tests = [ - ((8, 8), True, - HLines(reverse=1, line_sentences=True), - VLines(reverse=2, line_sentences=False)) - ] - for test in tests: - test = Blocks(*test) - saved = test.save() - loaded = Traversal.load(saved) - self.assertEqual(test, loaded) +@pytest.mark.parametrize('args,test,res', [ + ((0,), (2, 3, False), [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2)]), + ((0,), (2, 3, True), [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2)]), + ((1,), (2, 3), [(1, 0), (0, 0), (1, 1), (0, 1), (1, 2), (0, 2)]), + ((2,), (2, 3), [(1, 0), (0, 0), (0, 1), (1, 1), (1, 2), (0, 2)]), + ((0, True), (2, 3, False), + [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2)]), + ((0, True), (2, 3, True), + [(0, 0), (1, 0), None, (0, 1), (1, 1), None, (0, 2), (1, 2), None]) +]) +def test_hlines_traverse(args, test, res): + assert list(HLines(*args)(*test)) == res + +@pytest.mark.parametrize('test,test2,res', [ + ((), (), True), + ((1, True), (1, True), True), + ((0, False), (0, True), False), + ((0, False), (1, False), False) +]) +def test_hlines_eq(test, test2, res): + assert (HLines(*test) == HLines(*test2)) == res + +@pytest.mark.parametrize('test', [ + (0, False), (2, True) +]) +def test_hlines_save_load(test): + test = HLines(*test) + saved = test.save() + loaded = Traversal.load(saved) + assert isinstance(loaded, HLines) + assert test == loaded + + +@pytest.mark.parametrize('args,test,res', [ + ((0,), (2, 3, False), [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]), + ((0,), (2, 3, True), [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]), + ((1,), (2, 3), [(0, 2), (0, 1), (0, 0), (1, 2), (1, 1), (1, 0)]), + ((2,), (2, 3), [(0, 2), (0, 1), (0, 0), (1, 0), (1, 1), (1, 2)]), + ((0, True), (2, 3, False), + [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]), + ((0, True), (2, 3, True), + [(0, 0), (0, 1), (0, 2), None, (1, 0), (1, 1), (1, 2), None]) +]) +def test_vlines_traverse(args, test, res): + assert list(VLines(*args)(*test)) == res + +@pytest.mark.parametrize('test,test2,res', [ + ((), (), True), + ((1, True), (1, True), True), + ((0, False), (0, True), False), + ((0, False), (1, False), False) +]) +def test_vlines_eq(test, test2, res): + assert (VLines(*test) == VLines(*test2)) == res + +@pytest.mark.parametrize('test', [ + (0, False), (2, True) +]) +def test_vlines_save_load(test): + test = VLines(*test) + saved = test.save() + loaded = Traversal.load(saved) + assert isinstance(loaded, VLines) + assert test == loaded + + +@pytest.mark.parametrize('width,height', product(range(1, 7), range(1, 7))) +def test_spiral_traverse(width, height): + test = Spiral() + test.reverse = False + res = list(test(width, height, False)) + assert sorted(res) == list(product(range(width), range(height))) + test.reverse = True + res2 = list(test(width, height, False)) + res.reverse() + assert res2 == res + +@pytest.mark.parametrize('test,test2,res', [ + ((), (), True), + ((True,), (False,), False), + ((True,), (True,), True) +]) +def test_spiral_eq(test, test2, res): + assert (Spiral(*test) == Spiral(*test2)) == res + +@pytest.mark.parametrize('test', [ + (False,), (True,) +]) +def test_spiral_save_load(test): + test = Spiral(*test) + saved = test.save() + loaded = Traversal.load(saved) + assert isinstance(loaded, Spiral) + assert test == loaded + + +@pytest.mark.parametrize('width,height', product(range(1, 7), range(1, 7))) +def test_hilbert_traverse(width, height): + test = Hilbert() + res = list(test(width, height)) + assert sorted(res) == list(product(range(width), range(height))) + +def test_hilbert_save_load(): + test = Hilbert() + saved = test.save() + loaded = Traversal.load(saved) + assert isinstance(loaded, Hilbert) + assert test == loaded + + +@pytest.mark.parametrize('test,res,res2', [ + ( + (4, 4, False), + [(0, 0), (0, 1), (2, 0), (2, 1)], + [(0, 0), (0, 1), (2, 0), (2, 1)] + ), + ( + (3, 3, True), + [(0, 0), (0, 1), None], + [(0, 0), (0, 1), None, None] + ), + ( + (4, 4, True), + [(0, 0), (0, 1), None, (2, 0), (2, 1)], + [(0, 0), (0, 1), None, None, (2, 0), (2, 1), None] + ) +]) +def test_blocks_traverse(test, res, res2): + def hline(width, height, ends): # pylint: disable=unused-argument + for x in range(width): + yield (x, 0) + if ends and x == 0: + yield None + + def vline(width, height, ends): # pylint: disable=unused-argument + for y in range(height): + yield (0, y) + if ends: + yield None + + traverse = Blocks(block_size=(2, 2), + block_sentences=False, + traverse_image=hline, + traverse_block=vline) + assert list(traverse(*test)) == res + traverse.block_sentences = True + assert list(traverse(*test)) == res2 + +@pytest.mark.parametrize('test,test2,res', [ + ((), (), True), + (((8, 8),), ((4, 4),), False), + (((4, 4), True), ((4, 4), False), False), + (((4, 4), True, 0), ((4, 4), True, 1), False), + (((4, 4), True, 0, 0), ((4, 4), True, 0, 1), False) +]) +def test_blocks_eq(test, test2, res): + assert (Blocks(*test) == Blocks(*test2)) == res + +@pytest.mark.parametrize('test', [ + ((8, 8), True, + HLines(reverse=1, line_sentences=True), + VLines(reverse=2, line_sentences=False)) +]) +def test_blocks_save_load(test): + test = Blocks(*test) + saved = test.save() + loaded = Traversal.load(saved) + assert isinstance(loaded, Blocks) + assert test == loaded diff --git a/tests/image/test_util.py b/tests/image/test_util.py index ac549cf..09532aa 100644 --- a/tests/image/test_util.py +++ b/tests/image/test_util.py @@ -1,57 +1,53 @@ -from unittest import TestCase from collections import Counter from itertools import islice +import pytest from markovchain.image.util import palette -class TestPalette(TestCase): - def test_errors(self): - tests = [ - (8, 8, 5), - (0, 4, 4), - (4, 0, 4), - (4, 4, 0), - (8, -4, 8), - (-8, 4, -8) - ] - for test in tests: - with self.assertRaises(ValueError): - palette(*test) +@pytest.mark.parametrize('test', [ + (8, 8, 5), + (0, 4, 4), + (4, 0, 4), + (4, 4, 0), + (8, -4, 8), + (-8, 4, -8) +]) +def test_palette_errors(test): + with pytest.raises(ValueError): + palette(*test) - def test_generate(self): - tests = [ - (8, 4, 8), - (2, 2, 2), - (16, 1, 1) - ] - for test in tests: - res = palette(*test) - self.assertEqual(len(res), 768) - res = list(zip(islice(res, 0, None, 3), - islice(res, 1, None, 3), - islice(res, 2, None, 3))) - counter = Counter(res) - size = test[0] * test[1] * test[2] - self.assertEqual(len(counter.items()), min(256, size + 1)) - self.assertEqual(counter[(0, 0, 0)], 256 - size) +@pytest.mark.parametrize('test', [ + (8, 4, 8), + (2, 2, 2), + (16, 1, 1) +]) +def test_palette(test): + res = palette(*test) + assert len(res) == 768 + res = list(zip(islice(res, 0, None, 3), + islice(res, 1, None, 3), + islice(res, 2, None, 3))) + counter = Counter(res) + size = test[0] * test[1] * test[2] + assert len(counter.items()) == min(256, size + 1) + assert counter[(0, 0, 0)] == 256 - size - def test_grayscale(self): - tests = [ - (1, 1, 1), - (1, 1, 2), - (1, 1, 256) - ] - for test in tests: - res = palette(*test) - self.assertEqual(len(res), 768) - res = list(zip(islice(res, 0, None, 3), - islice(res, 1, None, 3), - islice(res, 2, None, 3))) - self.assertTrue(all(r == g and g == b for r, g, b in res)) - counter = Counter(res) - size = test[0] * test[1] * test[2] - self.assertEqual(len(counter.items()), size) - self.assertEqual(counter[(0, 0, 0)], 256 - size + 1) - if size > 1: - self.assertEqual(counter[(255, 255, 255)], 1) +@pytest.mark.parametrize('test', [ + (1, 1, 1), + (1, 1, 2), + (1, 1, 256) +]) +def test_palette_grayscale(test): + res = palette(*test) + assert len(res) == 768 + res = list(zip(islice(res, 0, None, 3), + islice(res, 1, None, 3), + islice(res, 2, None, 3))) + assert all(r == g and g == b for r, g, b in res) + counter = Counter(res) + size = test[0] * test[1] * test[2] + assert len(counter.items()) == size + assert counter[(0, 0, 0)] == 256 - size + 1 + if size > 1: + assert counter[(255, 255, 255)] == 1 diff --git a/tests/storage/__init__.py b/tests/storage/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py new file mode 100644 index 0000000..11fd201 --- /dev/null +++ b/tests/storage/test_base.py @@ -0,0 +1,59 @@ +from unittest.mock import Mock +import pytest + +from markovchain.storage.base import Storage + + +class StorageTest(Storage): + def replace_state_separator(self, old_separator, new_separator): + pass + def links(self, links): + pass + def random_link(self, state): + pass + def do_save(self, fp=None): + pass + @classmethod + def load(cls, fp): + pass + +def test_storage_base_abstract(): + with pytest.raises(TypeError): + Storage() + +def test_storage_base_properties(): + storage = StorageTest() + storage.replace_state_separator = Mock() + assert storage.state_separator == ' ' + storage.state_separator = ':' + storage.replace_state_separator.assert_called_once_with(' ', ':') + assert storage.state_separator == ':' + +@pytest.mark.parametrize('test,test2,res', [ + ((), (), True), + (({'state_separator': '-'},), ({'state_separator': '+'},), False) +]) +def test_storage_base_eq(test, test2, res): + assert (StorageTest(*test) == StorageTest(*test2)) == res + +@pytest.mark.parametrize('settings,test,res', [ + ({}, 'a b c', ['a', 'b', 'c']), + ({'storage': {'state_separator': '-'}}, 'a b-c d', ['a b', 'c d']), +]) +def test_storage_base_split_state(settings, test, res): + assert StorageTest(settings).split_state(test) == res + +@pytest.mark.parametrize('settings,test,res', [ + ({}, map(str, range(3)), '0 1 2'), + ({'storage': {'state_separator': '-'}}, ['a', 'b'], 'a-b'), +]) +def test_storage_base_join_state(settings, test, res): + assert StorageTest(settings).join_state(test) == res + +def test_storage_base_save(): + storage = StorageTest() + storage.do_save = Mock() + storage.state_separator = '+' + storage.save(0) + assert storage.settings['storage']['state_separator'] == '+' + storage.do_save.assert_called_once_with(0) diff --git a/tests/storage/test_json.py b/tests/storage/test_json.py new file mode 100644 index 0000000..3f97d5b --- /dev/null +++ b/tests/storage/test_json.py @@ -0,0 +1,82 @@ +from io import StringIO +import pytest + +from markovchain import JsonStorage + + +def test_json_storage_empty(): + storage = JsonStorage() + assert storage.nodes == {} + assert storage.settings == {} + +def test_json_storage_add_links(): + storage = JsonStorage() + storage.links([(('x',), 'y'), (('y',), 'z'), (('x',), 'y')]) + assert storage.nodes == { + 'x': ['y', 2], + 'y': ['z', 1] + } + storage.links([(('z',), 'x'), (('x',), 'z'), (('x',), 'y')]) + assert storage.nodes == { + 'x': [['y', 'z'], [3, 1]], + 'y': ['z', 1], + 'z': ['x', 1] + } + +@pytest.mark.parametrize('links,state,random,call,res', [ + ([], 'x', None, None, None), + ([(('x',), 'y')], 'y', None, None, None), + ([(('x',), 'y')], 'x', 0, None, 'y'), + ([(('x',), 'y'), (('x',), 'z')], 'x', 0, (0, 1), 'y'), + ([(('x',), 'y'), (('x',), 'z')], 'x', 1, (0, 1), 'z') +]) +def test_json_storage_random_link(mocker, links, state, random, call, res): + randint = mocker.patch( + 'markovchain.storage.json.randint', + return_value=random + ) + storage = JsonStorage() + storage.links(links) + link, next_state = storage.random_link([state]) + assert link == res + if res is None: + assert next_state is None + else: + assert next_state == [state, res] + if call is None: + assert randint.call_count == 0 + else: + randint.assert_called_once_with(*call) + +def test_json_storage_state_separator(): + storage = JsonStorage() + storage.links([(('x', 'y'), 'z'), (('y', 'z'), 'x')]) + assert storage.nodes == { + 'x y': ['z', 1], + 'y z': ['x', 1] + } + storage.state_separator = ':' + assert storage.nodes == { + 'x:y': ['z', 1], + 'y:z': ['x', 1] + } + +@pytest.mark.parametrize('test,test2,res', [ + ((), (), True), + ((), ({}), True), + ((), ({'x':['y', 1]}), False), + (({'x':['y', 1]}), ({'x':['y', 1]}), True), + ((), ({}, {'state_separator': ':'}), False) +]) +def test_json_storage_eq(test, test2, res): + assert (JsonStorage(*test) == JsonStorage(*test2)) == res + +def test_json_storage_save_load(): + storage = JsonStorage(settings={'state_separator': ':'}) + storage.links([(('x',), 'y'), (('y',), 'z'), (('x',), 'y'), (('x',), 'z')]) + + fp = StringIO() + storage.save(fp) + fp.seek(0) + loaded = JsonStorage.load(fp) + assert storage == loaded diff --git a/tests/storage/test_sqlite.py b/tests/storage/test_sqlite.py new file mode 100644 index 0000000..70f9220 --- /dev/null +++ b/tests/storage/test_sqlite.py @@ -0,0 +1,79 @@ +import os +import pytest + +from markovchain import SqliteStorage + + +def get_nodes(cursor): + cursor.execute('SELECT id, value FROM nodes') + return cursor.fetchall() + +def get_links(cursor, source): + cursor.execute( + 'SELECT value, count FROM links WHERE source=?', + (source,) + ) + return cursor.fetchall() + +def test_sqlite_storage_empty(): + storage = SqliteStorage() + assert storage.db + assert storage.cursor + tables = storage.get_tables() + assert 'main' in tables + assert 'nodes' in tables + assert 'links' in tables + +def test_sqlite_storage_add_links(): + storage = SqliteStorage() + storage.links([(('x',), 'y'), (('y',), 'z'), (('x',), 'y')]) + assert get_nodes(storage.cursor) == [(1, 'x'), (2, 'y'), (3, 'z')] + assert get_links(storage.cursor, 1) == [('y', 2)] + assert get_links(storage.cursor, 2) == [('z', 1)] + assert get_links(storage.cursor, 3) == [] + + storage.links([(('x',), 'z'), (('x',), 'y')]) + assert get_links(storage.cursor, 1) == [('y', 3), ('z', 1)] + +@pytest.mark.parametrize('links,state,random,call,res', [ + ([], 'x', None, None, None), + ([(('x',), 'y')], 'y', None, None, None), + ([(('x',), 'y')], 'x', 0, (0, 0), 'y'), + ([(('x',), 'y'), (('x',), 'z')], 'x', 0, (0, 1), 'y'), + ([(('x',), 'y'), (('x',), 'z')], 'x', 1, (0, 1), 'z') +]) +def test_sqlite_storage_random_link(mocker, links, state, random, call, res): + randint = mocker.patch( + 'markovchain.storage.sqlite.randint', + return_value=random + ) + storage = SqliteStorage() + storage.links(links) + link, next_state = storage.random_link([state]) + assert link == res + if res is None: + assert next_state is None + else: + assert next_state == storage.get_node(link) + if call is None: + assert randint.call_count == 0 + else: + randint.assert_called_once_with(*call) + +def test_sqlite_storage_state_separator(): + storage = SqliteStorage() + storage.links([(('x', 'y'), 'z'), (('y', 'z'), 'x')]) + assert get_nodes(storage.cursor) == [(1, 'x y'), (2, 'y z'), (3, 'z x')] + storage.state_separator = ':' + assert get_nodes(storage.cursor) == [(1, 'x:y'), (2, 'y:z'), (3, 'z:x')] + +def test_sqlite_storage_save_load(tmpdir): + db = os.path.join(str(tmpdir), 'test.db') + storage = SqliteStorage(db=db) + storage.state_separator = ':' + storage.links([(('x', 'y'), 'z'), (('y', 'z'), 'x')]) + nodes = get_nodes(storage.cursor) + storage.save() + loaded = SqliteStorage.load(db) + assert nodes == get_nodes(loaded.cursor) + assert storage.state_separator == loaded.state_separator diff --git a/tests/test_base.py b/tests/test_base.py index 646db04..310f924 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,32 +1,104 @@ -from unittest import TestCase - -from markovchain import MarkovBase, Parser - - -class TestMarkovBase(TestCase): - class Scanner: - def reset(self): - pass - def __call__(self, data, part): - pass - - class Markov(MarkovBase): # pylint:disable=abstract-method - def links(self, links): - pass - - def test_properties(self): - m = self.Markov() - self.assertIsInstance(m.parser, Parser) - self.assertIsNotNone(m.separator) - - def test_generate_error(self): - m = self.Markov() - m.parser = None - with self.assertRaises(ValueError): - list(m.generate(10)) - - def test_save_load(self): - m = self.Markov(separator=':', parser=Parser()) - saved = m.get_save_data() - loaded = self.Markov(**saved) - self.assertEqual(m, loaded) +from collections import deque +from unittest.mock import Mock, call +import pytest + +from markovchain import Markov +from markovchain.parser import ParserBase +from markovchain.scanner import Scanner + + +def test_markov_base_properties(): + storage = Mock(settings={}) + markov = Markov(storage=storage) + assert isinstance(markov.parser, ParserBase) + assert isinstance(markov.scanner, Scanner) + assert markov.storage is storage + +@pytest.mark.parametrize('data,part', [ + ([], False), + ([1, 2, 3], True) +]) +def test_markov_base_data(data, part): + storage = Mock(settings={}) + scanner = Mock(return_value=0) + parser = Mock(return_value=1) + markov = Markov(parser=parser, scanner=scanner, storage=storage) + markov.data(data, part) + scanner.assert_called_once_with(data, part) + parser.assert_called_once_with(0, part) + storage.links.assert_called_once_with(1) + +def test_markov_base_generate_empty(): + markov = Markov(parser=Mock(state_sizes=[])) + assert list(markov.generate(state_size=None)) == [] + +def test_markov_base_generate_error(): + markov = Markov() + markov.parser = None + with pytest.raises(ValueError): + list(markov.generate(state_size=None)) + +@pytest.mark.parametrize('state_size,start,res', [ + (1, None, ['']), + (2, None, ['', '']), + (3, None, ['', '', '']), + (3, 'a b', ['', 'a', 'b']), + (3, range(2), ['', 0, 1]) +]) +def test_markov_base_generate(state_size, start, res): + storage = Mock( + random_link=Mock(side_effect=[('link', 'state'), (None, None)]), + split_state=lambda s: s.split() + ) + markov = Markov( + parser=Mock(), + scanner=Mock(), + storage=storage + ) + assert list(markov.generate(start=start, state_size=state_size)) == ['link'] + assert storage.random_link.call_count == 2 + storage.random_link.assert_has_calls([ + call(deque(res, maxlen=state_size)), + call('state') + ]) + +def test_markov_base_get_settings_json(): + markov = Markov( + parser=Mock(save=lambda: 0), + scanner=Mock(save=lambda: 1), + storage=Mock() + ) + json = markov.get_settings_json() + assert json == { + 'scanner': 1, + 'parser': 0 + } + +@pytest.mark.parametrize('test,res', [ + (((0, 0, 0), (0, 0, 0)), True), + (((0, 0, 0), (0, 0, 1)), False), + (((0, 0, 0), (0, 1, 0)), False), + (((0, 0, 0), (1, 0, 0)), False) +]) +def test_markov_eq(test, res): + markov = [] + for scanner, parser, storage in test: + m = Markov() + m.scanner = scanner + m.parser = parser + m.storage = storage + markov.append(m) + assert (markov[0] == markov[1]) == res + +def test_markov_base_save(mocker): + mocker.patch('markovchain.Markov.get_settings_json', return_value=0) + storage = Mock(settings={}) + markov = Markov( + parser=Mock(), + scanner=Mock(), + storage=storage + ) + markov.save(1) + assert storage.settings == {'markov': 0} + storage.save.assert_called_once_with(1) + markov.get_settings_json.assert_called_once_with() # pylint:disable=no-member diff --git a/tests/test_json.py b/tests/test_json.py deleted file mode 100644 index 49eedb9..0000000 --- a/tests/test_json.py +++ /dev/null @@ -1,91 +0,0 @@ -from unittest import TestCase -from io import StringIO, BytesIO - -from markovchain import (MarkovBase, MarkovJsonMixin, - Scanner, CharScanner, Parser) - - -class TestMarkovJson(TestCase): - class Markov(MarkovJsonMixin, MarkovBase): - pass - - def test_empty(self): - m = self.Markov() - self.assertFalse(m.nodes) - - def test_properties(self): - m = self.Markov(scanner=Scanner(lambda x: x)) - m.links([(('x', 'y'), 'z')]) - m.separator = '::' - self.assertEqual(list(m.nodes.keys()), ['x::y']) - - def test_add_links(self): - m = self.Markov() - m.links([(('x',), 'y'), (('y',), 'z'), (('x',), 'y')]) - self.assertEqual( - m.nodes, - { - 'x': ['y', 2], - 'y': ['z', 1] - } - ) - m.links([(('z',), 'x'), (('x',), 'z')]) - self.assertEqual( - m.nodes, - { - 'x': [['y', 'z'], [2, 1]], - 'y': ['z', 1], - 'z': ['x', 1] - } - ) - - def test_generate_empty(self): - m = self.Markov() - self.assertEqual(''.join(m.generate(10)), '') - m = self.Markov() - m.links([(('x',), 'y')]) - self.assertEqual(''.join(m.generate(-1, start='x')), '') - self.assertEqual(''.join(m.generate(0, start='x')), '') - m.parser = None - self.assertEqual(''.join(m.generate(10, state_size=4)), '') - - def test_generate(self): - m = self.Markov(scanner=Scanner(lambda x: x)) - m.data(['x', 'y']) - self.assertEqual(''.join(m.generate(1, start='')), 'x') - self.assertEqual(''.join(m.generate(10, start='x')), 'y') - self.assertEqual(''.join(m.generate(10, start='y')), '') - self.assertIn(''.join(m.generate(10)), ['y', 'xy']) - - def test_generate_state_size(self): - m = self.Markov(separator=':', - parser=Parser(state_sizes=[2, 3]), - scanner=Scanner(lambda x: x)) - m.data(['x', 'y', 'z']) - self.assertEqual(''.join(m.generate(10, state_size=2)), 'xyz') - self.assertEqual(''.join(m.generate(10, state_size=3)), 'xyz') - - def test_save_load(self): - m = self.Markov(separator=':', - parser=Parser(state_sizes=[2, 3]), - scanner=Scanner(lambda x: x)) - m.data(['', 'x', 'y', 'z', None]) - m.scanner = CharScanner() - - fp = StringIO() - m.save(fp) - fp.seek(0) - loaded = self.Markov.load(fp) - self.assertEqual(m, loaded) - - fp.seek(0) - fp1 = BytesIO() - fp1.write(fp.read().encode('utf-8')) - fp1.seek(0) - loaded = self.Markov.load(fp1) - self.assertEqual(m, loaded) - - fp.seek(0) - loaded = self.Markov.load(fp, {'separator': ''}) - self.assertNotEqual(m, loaded) - self.assertEqual(loaded.separator, '') diff --git a/tests/test_parser.py b/tests/test_parser.py index f0361fa..6524e24 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1,144 +1,153 @@ -from unittest import TestCase +from unittest.mock import Mock +import pytest from markovchain import Parser, LevelParser, Scanner from markovchain.parser import ParserBase -class TestParser(TestCase): - @staticmethod - def parse(parser, scanner, data, part=False, separator=' '): - return [(separator.join(src), dst) - for src, dst in parser(scanner(data, part), part)] - - def test_properties(self): - parser = Parser() - - with self.assertRaises(ValueError): - parser.state_sizes = [1, 0, 2] - with self.assertRaises(ValueError): - parser.state_sizes = [] - self.assertEqual(parser.state_sizes, [1]) - self.assertEqual(parser.state_size, 1) - - parser.state_sizes = [2, 1] - self.assertEqual(parser.state_sizes, [2, 1]) - self.assertEqual(parser.state_size, 2) - self.assertEqual(parser.state.maxlen, 2) - self.assertEqual(list(parser.state), ['', '']) - - parser.state.append('test') - parser.state_size = parser.state_size - self.assertEqual(list(parser.state), ['', 'test']) - - parser.state_sizes = [3] - self.assertEqual(list(parser.state), ['', '', '']) - - def test_properties_parse(self): - scanner = Scanner(lambda x: x) - parser = Parser(state_sizes=[2], - reset_on_sentence_end=False) - self.assertEqual(self.parse(parser, scanner, ['a', scanner.END, 'b'], - True, '::'), - [('::', 'a'), ('::a', 'b')]) - self.assertEqual(self.parse(parser, scanner, 'c', separator='::'), - [('a::b', 'c')]) - - def test_default_parse(self): - scanner = Scanner(lambda x: x) - parser = Parser() - - self.assertEqual(self.parse(parser, scanner, ''), []) - - self.assertEqual(self.parse(parser, scanner, 'abc', True), - [('', 'a'), ('a', 'b'), ('b', 'c')]) - self.assertEqual(self.parse(parser, scanner, - ['a', 'b', scanner.END, 'c']), - [('c', 'a'), ('a', 'b'), ('', 'c')]) - self.assertEqual(self.parse(parser, scanner, - ['a', Scanner.END, Scanner.END, 'c']), - [('', 'a'), ('', 'c')]) - self.assertEqual(self.parse(parser, scanner, [Scanner.END] * 4), []) - - def test_state_size(self): - scanner = Scanner(lambda x: x) - parser = Parser(state_sizes=[3]) - - self.assertEqual(self.parse(parser, scanner, 'abcde'), - [(' ', 'a'), (' a', 'b'), (' a b', 'c'), - ('a b c', 'd'), ('b c d', 'e')]) - self.assertEqual(self.parse(parser, scanner, - ['a', 'b', 'c', Scanner.END, 'd', 'e']), - [(' ', 'a'), (' a', 'b'), (' a b', 'c'), - (' ', 'd'), (' d', 'e')]) - self.assertEqual(self.parse(parser, scanner, - ['a', 'b', 'c', (Scanner.START, 'd'), 'e']), - [(' ', 'a'), (' a', 'b'), (' a b', 'c'), - (' d', 'e')]) - - def test_save_load(self): - parser = Parser(state_sizes=[1, 2, 3], - reset_on_sentence_end=False) - saved = parser.save() - loaded = Parser.load(saved) - self.assertEqual(parser, loaded) - - -class TestLevelParser(TestCase): - class ParserTest(ParserBase): - def __init__(self, parse=None): - super().__init__(parse) - self.is_reset = False - - def reset(self): - self.is_reset = True - - def test_properties(self): - parser = LevelParser() - self.assertEqual(parser.levels, 1) - self.assertEqual(parser.parsers, [Parser()]) - - with self.assertRaises(ValueError): - parser.levels = 0 - with self.assertRaises(ValueError): - parser.levels = -1 - self.assertEqual(parser.levels, 1) - - parser.levels = 2 - self.assertEqual(parser.parsers, [Parser(), Parser()]) - - level = Parser(state_sizes=[2, 3]) - parser.parsers = level - self.assertEqual(parser.parsers, [level, level]) - - parser.parsers = [Parser(), level, Parser()] - self.assertEqual(parser.parsers, [Parser(), level]) - - parser.levels = 1 - self.assertEqual(parser.parsers, [Parser()]) - - parser.levels = 2 - self.assertIs(parser.parsers[1], level) - - def test_reset(self): - parser = LevelParser(levels=2, - parsers=[self.ParserTest(), self.ParserTest()]) - self.assertEqual([x.is_reset for x in parser.parsers], [False, False]) # pylint:disable=no-member - parser.reset() - self.assertEqual([x.is_reset for x in parser.parsers], [True, True]) # pylint:disable=no-member - - def test_parse(self): - parser = LevelParser( - levels=2, - parsers=[self.ParserTest(lambda x: [0]), - self.ParserTest(lambda x: [1])] - ) - self.assertEqual(list(parser([[0], [1]])), [0, 1]) - self.assertEqual(list(parser([[0]] * 5)), [0, 1]) - self.assertEqual(list(parser([])), []) - - def test_save_load(self): - level = Parser(state_sizes=[2, 3]) - parser = LevelParser(levels=3, parsers=[level, Parser()]) - saved = parser.save() - loaded = Parser.load(saved) - self.assertEqual(parser, loaded) +def parse(parser, data, part=False, separator=' '): + return [(separator.join(src), dst) + for src, dst in parser(data, part)] + + +def test_parser_properties(): + parser = Parser() + + with pytest.raises(ValueError): + parser.state_sizes = [1, 0, 2] + with pytest.raises(ValueError): + parser.state_sizes = [] + + assert parser.state_sizes == [1] + assert parser.state_size == 1 + + parser.state_sizes = [2, 1] + assert parser.state_sizes == [2, 1] + assert parser.state_size == 2 + assert parser.state.maxlen == 2 + assert list(parser.state) == ['', ''] + + parser.state.append('test') + parser.state_size = parser.state_size + assert list(parser.state) == ['', 'test'] + + parser.state_sizes = [3] + assert list(parser.state) == ['', '', ''] + +def test_parser_parse(): + parser = Parser(state_sizes=[2], reset_on_sentence_end=False) + res = parse(parser, ['a', Scanner.END, 'b'], True, '::') + assert res == [('::', 'a'), ('::a', 'b')] + res = parse(parser, 'c', separator='::') + assert res == [('a::b', 'c')] + +def test_parser_parse_default(): + parser = Parser() + assert parse(parser, '') == [] + assert parse(parser, 'abc', True) == [ + ('', 'a'), ('a', 'b'), ('b', 'c') + ] + assert parse(parser, ['a', 'b', Scanner.END, 'c']) == [ + ('c', 'a'), ('a', 'b'), ('', 'c') + ] + assert parse(parser, ['a', Scanner.END, Scanner.END, 'c']) == [ + ('', 'a'), ('', 'c') + ] + assert parse(parser, [Scanner.END] * 4) == [] + +@pytest.mark.parametrize('test,res', [ + ('abcde', [(' ', 'a'), (' a', 'b'), + (' a b', 'c'), ('a b c', 'd'), ('b c d', 'e')]), + (['a', 'b', 'c', Scanner.END, 'd', 'e'], + [(' ', 'a'), (' a', 'b'), + (' a b', 'c'), (' ', 'd'), (' d', 'e')]), + (['a', 'b', 'c', (Scanner.START, 'd'), 'e'], + [(' ', 'a'), (' a', 'b'), (' a b', 'c'), (' d', 'e')]) +]) +def test_parser_state_size(test, res): + parser = Parser(state_sizes=[3]) + assert parse(parser, test) == res + +@pytest.mark.parametrize('test,test2,res', [ + ((), (), True), + (([2], False), ([2], False), True), + (([2], False), ([2], True), False), + (([2], False), ([2, 3], False), False) +]) +def test_parser_eq(test, test2, res): + parser = Parser(*test) + parser2 = Parser(*test2) + assert (parser == parser2) == res + +@pytest.mark.parametrize('test', [ + ([1, 2, 3], False) +]) +def test_parser_save_load(test): + parser = Parser(*test) + saved = parser.save() + loaded = Parser.load(saved) + assert parser == loaded + + +def test_level_parser_properties(): + parser = LevelParser() + assert parser.levels == 1 + assert parser.parsers == [Parser()] + + with pytest.raises(ValueError): + parser.levels = 0 + with pytest.raises(ValueError): + parser.levels = -1 + assert parser.levels == 1 + + parser.levels = 2 + assert parser.parsers == [Parser(), Parser()] + + level = Parser(state_sizes=[2, 3]) + parser.parsers = level + assert parser.parsers == [level, level] + + parser.parsers = [Parser(), level, Parser()] + assert parser.parsers == [Parser(), level] + + parser.levels = 1 + assert parser.parsers == [Parser()] + + parser.levels = 2 + assert parser.parsers[1] is level + +def test_level_parser_reset(): + parsers = [Mock(), Mock()] + parser = LevelParser(levels=2, parsers=parsers) + parser.reset() + for level in parsers: + level.reset.assert_called_once_with() + +@pytest.mark.parametrize('test,res', [ + ([[0], [1]], [0, 1]), + ([[0]] * 5, [0, 1]), + ([], []) +]) +def test_level_parser_parse(test, res): + parser = LevelParser( + levels=2, + parsers=[ParserBase(lambda x: [0]), + ParserBase(lambda x: [1])] + ) + assert list(parser(test)) == res + +@pytest.mark.parametrize('test,test2,res', [ + ((), (), True), + ((2, Parser()), (2, Parser()), True), + ((1,), (2,), False), + ((2, [Parser(), Parser(3)]), (2, Parser()), False) +]) +def test_level_parser_eq(test, test2, res): + assert (LevelParser(*test) == LevelParser(*test2)) == res + +def test_level_parser_save_load(): + level = Parser(state_sizes=[2, 3]) + parser = LevelParser(levels=3, parsers=[level, Parser()]) + saved = parser.save() + loaded = Parser.load(saved) + assert parser == loaded diff --git a/tests/test_scanner.py b/tests/test_scanner.py index 1bc49e8..6ce7ce2 100644 --- a/tests/test_scanner.py +++ b/tests/test_scanner.py @@ -1,93 +1,127 @@ -from unittest import TestCase +import pytest from markovchain.scanner import Scanner, CharScanner, RegExpScanner -class TestScanner(TestCase): - def test_id(self): - scan = Scanner(lambda x: x) - test = 'ab c.d' - self.assertEqual(''.join(scan(test)), test) - - -class ScannerTestCase(TestCase): - @staticmethod - def scan_str(scanner, data, part=False, sep=''): - return sep.join(word - for word in scanner(data, part) - if word != scanner.END) - - -class TestCharScanner(ScannerTestCase): - def test_id(self): - scan = CharScanner(None, None) - test = 'ab c.d' - self.assertEqual(list(scan('', True)), []) - self.assertEqual(list(scan('', False)), []) - self.assertEqual(list(scan(test, True)), list(test)) - self.assertEqual(list(scan(test, False)), list(test) + [None]) - - def test_default(self): - scan = CharScanner() - - self.assertEqual(list(scan('ab..c')), - ['a', 'b', '.', '.', scan.END, 'c', '.', scan.END]) - - self.assertEqual(list(scan('a b..c', True)), - ['a', ' ', 'b', '.', '.', scan.END, 'c']) - self.assertEqual(list(scan('.', True)), ['.']) - self.assertEqual(list(scan('', False)), [scan.END]) - - self.assertEqual(list(scan('abc', True)), ['a', 'b', 'c']) - self.assertEqual(list(scan('', False)), ['.', scan.END]) - - self.assertEqual(list(scan('...')), []) - - self.assertEqual(self.scan_str(scan, 'abc.de?!f'), 'abc.de?!f.') - self.assertEqual(self.scan_str(scan, '.?!.a'), 'a.') - self.assertEqual(self.scan_str(scan, '.?!.a', True), 'a') - self.assertEqual(self.scan_str(scan, 'a.'), 'a.') - - def test_save_load(self): - tests = [(), (None, None), ('ab', 'cd')] - for test in tests: - scanner = CharScanner(*test) - saved = scanner.save() - loaded = Scanner.load(saved) - self.assertEqual(scanner, loaded) - - -class TestRegExp(ScannerTestCase): - def test_id(self): - scan = RegExpScanner('.', None) - test = 'ab c.d' - self.assertEqual(list(scan('', True)), []) - self.assertEqual(list(scan('', False)), []) - self.assertEqual(list(scan(test, True)), list(test)) - self.assertEqual(list(scan(test, False)), list(test) + [scan.END]) - - def test_default(self): - scan = RegExpScanner() - - self.assertEqual(list(scan('ab..c')), - ['ab', '..', scan.END, 'c', '.', scan.END]) - - self.assertEqual(list(scan('a \n b?!. .. !! ??c', True)), - ['a', 'b', '?!.', scan.END, 'c']) - self.assertEqual(list(scan('.', True)), ['.', scan.END]) - self.assertEqual(list(scan('', False)), []) - - self.assertEqual(list(scan('... .. . ! \n ? ?! ')), []) - - self.assertEqual(self.scan_str(scan, 'a\t\nb\nc.d e ?!f'), - 'abc.de?!f.') - self.assertEqual(self.scan_str(scan, '.?!.a', True), 'a') - self.assertEqual(self.scan_str(scan, 'a.'), 'a.') - - def test_save_load(self): - tests = [(), ('.*', None)] - for test in tests: - scanner = RegExpScanner(*test) - saved = scanner.save() - loaded = Scanner.load(saved) - self.assertEqual(scanner, loaded) +def scan_str(scanner, data, part=False, sep=''): + return sep.join(word for word in scanner(data, part) if word != scanner.END) + + +def test_scanner_id(): + scan = Scanner(lambda x: x) + test = 'ab c.d' + assert ''.join(scan(test)) == test + + +@pytest.mark.parametrize('test,res', [ + (('', True), []), + (('', False), []), + (('ab c.d', True), list('ab c.d')), + (('ab c.d', False), list('ab c.d') + [Scanner.END]) +]) +def test_char_scanner_id(test, res): + scan = CharScanner(None, None) + assert list(scan(*test)) == res + +def test_char_scanner_default(): + scan = CharScanner() + + assert list(scan('ab..c')) == [ + 'a', 'b', '.', '.', scan.END, 'c', '.', scan.END + ] + assert list(scan('a b..c', True)) == [ + 'a', ' ', 'b', '.', '.', scan.END, 'c' + ] + assert list(scan('.', True)) == ['.'] + assert list(scan('', False)) == [scan.END] + + assert list(scan('abc', True)) == ['a', 'b', 'c'] + assert list(scan('', False)) == ['.', scan.END] + + assert list(scan('...')) == [] + +@pytest.mark.parametrize('test,res', [ + (('abc.de?!f',), 'abc.de?!f.'), + (('.?!.a',), 'a.'), + (('.?!.a', True), 'a'), + (('a.',), 'a.') +]) +def test_char_scanner_default_str(test, res): + scan = CharScanner() + assert scan_str(scan, *test) == res + +@pytest.mark.parametrize('test,test2,res', [ + ((), (), True), + ((None, None), (None, None), True), + (('x', 'y'), ('x', 'y'), True), + (('x', 'y'), ('x', 'x'), False), + (('x', 'y'), ('y', 'y'), False) +]) +def test_char_scanner_eq(test, test2, res): + scan = CharScanner(*test) + scan2 = CharScanner(*test2) + assert (scan == scan2) == res + +@pytest.mark.parametrize('test', [ + (), (None, None), ('ab', 'cd') +]) +def test_char_scanner_save_load(test): + scanner = CharScanner(*test) + saved = scanner.save() + loaded = Scanner.load(saved) + assert scanner == loaded + + +@pytest.mark.parametrize('test,res', [ + (('', True), []), + (('', False), []), + (('ab c.d', True), list('ab c.d')), + (('ab c.d', False), list('ab c.d') + [Scanner.END]) +]) +def test_regexp_scanner_id(test, res): + scan = RegExpScanner('.', None) + assert list(scan(*test)) == res + +def test_regexp_scanner_default(): + scan = RegExpScanner() + + assert list(scan('ab..c')) == [ + 'ab', '..', scan.END, 'c', '.', scan.END + ] + assert list(scan('a \n b?!. .. !! ??c', True)) == [ + 'a', 'b', '?!.', scan.END, 'c' + ] + assert list(scan('.', True)) == ['.', scan.END] + assert list(scan('', False)) == [] + assert list(scan('... .. . ! \n ? ?! ')) == [] + +@pytest.mark.parametrize('test,res', [ + (('a\t\nb\nc.d e ?!f',), 'abc.de?!f.'), + (('.?!.a', True), 'a'), + (('a.',), 'a.') +]) +def test_regexp_scanner_default_str(test, res): + scan = RegExpScanner() + assert scan_str(scan, *test) == res + +@pytest.mark.parametrize('test,test2,res', [ + ((), (), True), + (('.', None), ('.', None), True), + (('.', None), ('.', '.'), False), + (('.', 'x'), ('.', 'x'), True), + (('x', '.'), ('.', '.'), False), + (('x', '.'), ('x', 'x'), False) +]) +def test_regexp_scanner_eq(test, test2, res): + scan = RegExpScanner(*test) + scan2 = RegExpScanner(*test2) + assert (scan == scan2) == res + +@pytest.mark.parametrize('test', [ + (), ('.*', None) +]) +def test_regexp_scanner_save_load(test): + scanner = RegExpScanner(*test) + saved = scanner.save() + loaded = Scanner.load(saved) + assert scanner == loaded diff --git a/tests/test_sqlite.py b/tests/test_sqlite.py deleted file mode 100644 index 738d28d..0000000 --- a/tests/test_sqlite.py +++ /dev/null @@ -1,147 +0,0 @@ -from unittest import TestCase -from tempfile import TemporaryDirectory -from collections import Counter -from itertools import chain, repeat -import os - -from markovchain import (MarkovBase, MarkovSqliteMixin, - Scanner, CharScanner, Parser) - - -class TestMarkovSqlite(TestCase): - class Markov(MarkovSqliteMixin, MarkovBase): - pass - - @classmethod - def setUpClass(cls): - cls.tmpdir = TemporaryDirectory() - - @classmethod - def tearDownClass(cls): - cls.tmpdir.cleanup() - - def test_empty(self): - m = self.Markov() - m.save() - self.assertTrue(m.db) - self.assertTrue(m.cursor) - tables = m.get_tables() - self.assertIn('main', tables) - self.assertIn('nodes', tables) - self.assertIn('links', tables) - - def test_properties(self): - m = self.Markov(scanner=Scanner(lambda x: x)) - m.links([(('x', 'y'), 'z')]) - m.separator = '::' - m.cursor.execute('SELECT value FROM nodes') - nodes = m.cursor.fetchall() - self.assertEqual(nodes, [('x::y',), ('y::z',)]) - - def test_add_links(self): - m = self.Markov() - m.links([(('x',), 'y'), (('y',), 'z'), (('x',), 'y')]) - - m.cursor.execute('SELECT id, value FROM nodes') - nodes = m.cursor.fetchall() - self.assertEqual(nodes, [(1, 'x'), (2, 'y'), (3, 'z')]) - - node = m.get_node('z') - m.cursor.execute('SELECT value FROM links WHERE source=?', (node,)) - nodes = m.cursor.fetchall() - self.assertEqual(nodes, []) - - node = m.get_node('y') - m.cursor.execute( - 'SELECT value, count FROM links WHERE source=?', - (node,) - ) - nodes = m.cursor.fetchall() - self.assertCountEqual(nodes, [('z', 1)]) - - node = m.get_node('x') - m.cursor.execute( - 'SELECT value, count FROM links WHERE source=?', - (node,) - ) - nodes = m.cursor.fetchall() - self.assertCountEqual(nodes, [('y', 2)]) - - m.links([(('x',), 'z'), (('x',), 'y')]) - - m.cursor.execute( - 'SELECT value, count FROM links WHERE source=?', - (node,) - ) - nodes = m.cursor.fetchall() - self.assertCountEqual(nodes, [('y', 3), ('z', 1)]) - - def test_random_link(self): - m = self.Markov() - values = list(str(x) for x in range(4)) - m.links(('', y) for y in values) - counter = Counter(m.random_link(('',))[0] - for _ in range(10 * len(values))) - self.assertEqual(len(counter.items()), len(values)) - self.assertTrue( - all(count and item in values for item, count in counter.items()) - ) - - def test_random_link_frequency(self): - m = self.Markov() - values = list(range(4)) - counts = (x * x for x in values) - m.links(chain(*( - repeat(('', str(v)), n) for v, n in zip(values, counts) - ))) - counter = Counter(m.random_link('')[0] - for _ in range(20 * len(values))) - common = [value for value, count in counter.most_common()] - values.sort(key=lambda x: -x) - values = [str(value) for value in values] - self.assertGreater(len(common), 1) - self.assertEqual(common, values[:len(common)]) - - def test_generate_empty(self): - m = self.Markov() - self.assertEqual(''.join(m.generate(10)), '') - m = self.Markov() - m.links([('x', 'y')]) - self.assertEqual(''.join(m.generate(-1, start='x')), '') - self.assertEqual(''.join(m.generate(0, start='x')), '') - m.parser = None - self.assertEqual(''.join(m.generate(10, state_size=4)), '') - - def test_generate(self): - m = self.Markov(scanner=Scanner(lambda x: x)) - m.data(['x', 'y']) - self.assertEqual(''.join(m.generate(1, start='')), 'x') - self.assertEqual(''.join(m.generate(10, start='x')), 'y') - self.assertEqual(''.join(m.generate(10, start='y')), '') - self.assertIn(''.join(m.generate(10)), ['y', 'xy']) - - def test_generate_state_size(self): - m = self.Markov(separator=':', - parser=Parser(state_sizes=[2, 3]), - scanner=Scanner(lambda x: x)) - m.data(['x', 'y', 'z']) - self.assertEqual(''.join(m.generate(10, state_size=2)), 'xyz') - self.assertEqual(''.join(m.generate(10, state_size=3)), 'xyz') - - def test_save_load(self): - db = os.path.join(self.tmpdir.name, 'test.db') - m = self.Markov(db=db, - separator=':', - parser=Parser(state_sizes=[2, 3]), - scanner=Scanner(lambda x: x)) - m.data(['x', 'y', 'z']) - m.scanner = CharScanner() - m.save() - - loaded = self.Markov.load(db) - self.assertEqual(m, loaded) - self.assertEqual(''.join(loaded.generate(10, state_size=2)), 'xyz') - - loaded = self.Markov.load(db, {'separator': ''}) - self.assertNotEqual(m, loaded) - self.assertEqual(loaded.separator, '') diff --git a/tests/test_util.py b/tests/test_util.py index 5de94e6..a8a71ca 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,4 +1,4 @@ -from unittest import TestCase +import pytest from markovchain.util import ( SaveLoad, ObjectWrapper, const, @@ -6,7 +6,8 @@ ) -class TestSaveLoad(TestCase): +@pytest.fixture +def save_load_test(): class SaveLoadTest(SaveLoad): classes = {} @@ -20,186 +21,159 @@ def save(self): data = super().save() data['value'] = self.value return data - - def test_add_remove(self): - self.SaveLoadTest.add_class(self.SaveLoadTest) - self.assertIs(self.SaveLoadTest.classes['SaveLoadTest'], - self.SaveLoadTest) - self.SaveLoadTest.remove_class(self.SaveLoadTest) - self.SaveLoadTest.remove_class(self.SaveLoadTest) - with self.assertRaises(KeyError): - raise AssertionError(self.SaveLoadTest.classes['SaveLoadTest']) - - def test_save_load(self): - self.SaveLoadTest.add_class(self.SaveLoadTest) - test = self.SaveLoadTest(0) - saved = test.save() - loaded = self.SaveLoadTest.load(saved) - self.assertIsInstance(loaded, self.SaveLoadTest) - self.assertEqual(loaded, test) - - -class TestObjectWrapper(TestCase): - class ObjectTest: - def __init__(self, x, y): - self.x = x - self.y = y - def method(self): - return self.x + self.y - def method2(self): - return self.x - self.y - - class ObjectWrapperTest(ObjectWrapper): - def __init__(self, obj, z): - super().__init__(obj) - self.y *= 2 - self.z = z - def method(self): - return super().method() + self.z - - def test_wrap(self): - obj = self.ObjectTest(1, 2) - wrapped = ObjectWrapper(obj) - self.assertIsInstance(wrapped, self.ObjectTest) - self.assertEqual(wrapped.__dict__, obj.__dict__) - self.assertEqual(wrapped.method(), 3) - self.assertEqual(wrapped.method2(), -1) - - def test_override(self): - obj = self.ObjectTest(1, 2) - wrapped = self.ObjectWrapperTest(obj, 3) - self.assertIsInstance(wrapped, self.ObjectTest) - self.assertEqual(wrapped.x, 1) - self.assertEqual(wrapped.y, 4) - self.assertEqual(wrapped.z, 3) - self.assertEqual(wrapped.method(), 8) - self.assertEqual(wrapped.method2(), -3) - - -class TestConst(TestCase): - def test(self): - func = const('x') - self.assertEqual(func(), 'x') - self.assertEqual(func(1, [2], None), 'x') - - -class TestToList(TestCase): - def test(self): - tests = [ - ([], []), - (range(3), list(range(3))), - (0, [0]), - ({'x': 0}, [{'x': 0}]) - ] - for test, res in tests: - self.assertEqual(to_list(test), res) - - -class TestFill(TestCase): - def test_empty(self): - tests = [ - (None, -1), - ([], 0), - (0, 0) - ] - for test in tests: - self.assertEqual(fill(*test), []) - with self.assertRaises(ValueError): - fill([], 1) - - def test_single(self): - tests = [ - ((1, 1), [1]), - ((1, 5), [1] * 5) - ] - for test, res in tests: - self.assertEqual(fill(*test), res) - - def test_multiple(self): - tests = [ - ((range(10), 2), [0, 1]), - ((range(3), 5), [0, 1, 2, 2, 2]) - ] - for test, res in tests: - self.assertEqual(fill(*test), res) - - def test_no_copy(self): - tests = [ - ([], 0), - (list(range(3)), 3) - ] - for lst, sz in tests: - self.assertIs(fill(lst, sz), lst) - - def test_copy(self): - tests = [ - ([[0], [1]], 4) - ] - for test in tests: - res = fill(*test, copy=False) - self.assertIs(res[-1], test[0][-1]) - res = fill(*test, copy=True) - self.assertIsNot(res[-1], test[0][-1]) - self.assertEqual(res[-1], test[0][-1]) - - -class TestLoad(TestCase): - class LoadTest: - @staticmethod - def load(data): - return data['x'] - - def testDefault(self): - x = load(None, self.LoadTest, lambda: 0) - self.assertEqual(x, 0) - - def testLoadClass(self): - x = load({'x': 1}, self.LoadTest, lambda: 0) - self.assertEqual(x, 1) - - def testLoadObject(self): - obj = object() - x = load(obj, self.LoadTest, lambda: 0) - self.assertIs(x, obj) - - -class TestExtend(TestCase): - def test(self): - tests = [ - (({'x': 0}, {'y': 1}), {'x': 0, 'y': 1}), - (({'x': 0}, {'y': 1}, {'x': 1}), {'x': 1, 'y': 1}), - (({'x': {'y': 0}}, {'x': {'z': 1}}), {'x': {'y': 0, 'z': 1}}), - (({'x': {'y': 0}}, {'x': 1}), {'x': 1}), - (({'x': 1}, {'x': {'y': 0}}), {'x': {'y': 0}}), - (({}, {'x': {'y': 0}}), {'x': {'y': 0}}), - ] - for test, res in tests: - self.assertEqual(extend(*test), res) - - -class TestTruncate(TestCase): - def testError(self): - tests = [ - ('', 3, True), - ('0', -1, False) - ] - for test in tests: - with self.assertRaises(ValueError): - truncate(*test) - - def testNoTruncate(self): - tests = [ - ('0', 4, True), - ('1234', 4, False) - ] - for test in tests: - self.assertIs(truncate(*test), test[0]) - - def testTruncate(self): - tests = [ - (('1234567', 5), '12...'), - (('1234567', 5, True), '12...'), - (('1234567', 5, False), '...67') - ] - for test, res in tests: - self.assertEqual(truncate(*test), res) + SaveLoadTest.add_class(SaveLoadTest) + return SaveLoadTest + + +def test_saveload_add_remove_class(save_load_test): + assert save_load_test.classes['SaveLoadTest'] is save_load_test + save_load_test.remove_class(save_load_test) + save_load_test.remove_class(save_load_test) + with pytest.raises(KeyError): + save_load_test.classes['SaveLoadTest'] + +def test_saveload_save_load(save_load_test): + test = save_load_test(0) + saved = test.save() + loaded = save_load_test.load(saved) + assert isinstance(loaded, save_load_test) + assert loaded == test + + +class ObjectTest: + def __init__(self, x, y): + self.x = x + self.y = y + def method(self): + return self.x + self.y + def method2(self): + return self.x - self.y + +class ObjectWrapperTest(ObjectWrapper): + def __init__(self, obj, z): + super().__init__(obj) + self.y *= 2 + self.z = z + def method(self): + return super().method() + self.z + +def test_object_wrapper_wrap(): + obj = ObjectTest(1, 2) + wrapped = ObjectWrapper(obj) + assert isinstance(wrapped, ObjectTest) + assert wrapped.__dict__ == obj.__dict__ + assert wrapped.method() == 3 + assert wrapped.method2() == -1 + +def test_object_wrapper_override(): + obj = ObjectTest(1, 2) + wrapped = ObjectWrapperTest(obj, 3) + assert wrapped.x == 1 + assert wrapped.y == 4 + assert wrapped.z == 3 + assert wrapped.method() == 8 + assert wrapped.method2() == -3 + + +def test_const(): + assert const(0)() == 0 + assert const(1)(1, [2], key=3) == 1 + + +@pytest.mark.parametrize('test,res', [ + ([], []), + (range(3), list(range(3))), + (0, [0]), + ({'x': 0}, [{'x': 0}]) +]) +def test_to_list(test, res): + assert to_list(test) == res + + +def test_fill_error(): + with pytest.raises(ValueError): + fill([], 1) + +@pytest.mark.parametrize('test,res', [ + ((None, -1), []), + (([], 0), []), + ((0, 0), []), + ((1, 1), [1]), + ((1, 5), [1] * 5), + ((range(10), 2), [0, 1]), + ((range(3), 5), [0, 1, 2, 2, 2]) +]) +def test_fill(test, res): + assert fill(*test) == res + +@pytest.mark.parametrize('lst,size', [ + ([], 0), + (list(range(3)), 3) +]) +def test_fill_no_copy(lst, size): + assert fill(lst, size) is lst + +@pytest.mark.parametrize('lst,size', [ + ([[0], [1]], 4) +]) +def test_fill_copy(lst, size): + res = fill(lst, size, copy=False) + assert res[-1] is lst[-1] + res = fill(lst, size, copy=True) + assert res[-1] is not lst[-1] + assert res[-1] == lst[-1] + + +class LoadTest: + @staticmethod + def load(data): + return data['x'] + +def test_load_efault(): + x = load(None, LoadTest, lambda: 0) + assert x == 0 + +def test_load_class(): + x = load({'x': 1}, LoadTest, lambda: 0) + assert x == 1 + +def test_load_object(): + obj = object() + x = load(obj, LoadTest, lambda: 0) + assert x is obj + + +@pytest.mark.parametrize('test,res', [ + (({'x': 0}, {'y': 1}), {'x': 0, 'y': 1}), + (({'x': 0}, {'y': 1}, {'x': 1}), {'x': 1, 'y': 1}), + (({'x': {'y': 0}}, {'x': {'z': 1}}), {'x': {'y': 0, 'z': 1}}), + (({'x': {'y': 0}}, {'x': 1}), {'x': 1}), + (({'x': 1}, {'x': {'y': 0}}), {'x': {'y': 0}}), + (({}, {'x': {'y': 0}}), {'x': {'y': 0}}) +]) +def test_extend(test, res): + assert extend(*test) == res + + +@pytest.mark.parametrize('test', [ + ('', 3, True), + ('0', -1, False) +]) +def test_truncate_error(test): + with pytest.raises(ValueError): + truncate(*test) + +@pytest.mark.parametrize('test', [ + ('0', 4, True), + ('1234', 4, False) +]) +def test_truncate_noop(test): + assert truncate(*test) is test[0] + +@pytest.mark.parametrize('test,res', [ + (('1234567', 5), '12...'), + (('1234567', 5, True), '12...'), + (('1234567', 5, False), '...67') +]) +def test_truncate(test, res): + assert truncate(*test) == res diff --git a/tests/text/test_markov.py b/tests/text/test_markov.py new file mode 100644 index 0000000..3ad600a --- /dev/null +++ b/tests/text/test_markov.py @@ -0,0 +1,44 @@ +import pytest + +from markovchain.text import MarkovText +from markovchain.scanner import Scanner, CharScanner + + +def test_markov_text_data(mocker): + mock = mocker.patch('markovchain.Markov.data', return_value=1) + markov = MarkovText() + assert markov.data([1, 2], True) == 1 + mock.assert_called_once_with([1, 2], True) + +@pytest.mark.parametrize('test,scanner,join_with', [ + (['1', '2', '3'], CharScanner(), ''), + (['1', '2', '3'], None, ' ') +]) +def test_markov_text_format(mocker, test, scanner, join_with): + fmt = mocker.patch( + 'markovchain.text.markov.format_sentence', + return_value=2 + ) + markov = MarkovText(scanner=scanner) + assert markov.format(test) == 2 + fmt.assert_called_with(join_with.join(test)) + +@pytest.mark.parametrize('data,args,res', [ + ([], (), []), + ('xy', (), ['x', 'y']), + ('xy', (None, None, 'z'), ['z']), + ('xy', (None, None, 'xyx'), ['xyx', 'y']), + ('xy', (None, None, (x for x in 'xyx')), ['x', 'y', 'x', 'y']), + ('xxxxx', (2,), ['x', 'x']), + ('xxxxx', (-10,), []), + ('xxxxx', (0,), []) +]) +def test_markov_text_generate(mocker, data, args, res): + fmt = mocker.patch( + 'markovchain.MarkovText.format', + wraps=list + ) + markov = MarkovText(scanner=Scanner(lambda x: x)) + markov.data(data) + assert markov(*args) == res + assert fmt.call_count == 1 diff --git a/tests/text/test_util.py b/tests/text/test_util.py index 58ef900..5042f92 100644 --- a/tests/text/test_util.py +++ b/tests/text/test_util.py @@ -1,5 +1,4 @@ -from unittest import TestCase -from unittest.mock import patch +import pytest from markovchain.text.util import ( ispunct, lstrip_ws_and_chars, capitalize, @@ -7,52 +6,60 @@ ) -class TestTextUtils(TestCase): - def test_ispunct(self): - self.assertTrue(ispunct('\'"?,+-.[]{}()<>')) - self.assertFalse(ispunct('\'"?,+-x.[]{}()<>')) - self.assertFalse(ispunct('')) - - def test_capitalize(self): - self.assertEqual(capitalize('worD WORD WoRd'), 'Word word word') - self.assertEqual(capitalize('x'), 'X') - self.assertEqual(capitalize(''), '') - - def test_lstrip_ws_and_chars(self): - self.assertEqual(lstrip_ws_and_chars('', ''), '') - self.assertEqual(lstrip_ws_and_chars(' ', ''), '') - self.assertEqual(lstrip_ws_and_chars(' x ', 'xy'), '') - self.assertEqual(lstrip_ws_and_chars(' \t.\n , .x. ', '.,?!'), 'x. ') - - def test_format_sentence_string(self): - fmt = format_sentence_string - self.assertEqual(fmt(''), '') - self.assertEqual(fmt(' '), '') - self.assertEqual(fmt(' ...'), '') - self.assertEqual(fmt('.?!word'), 'Word.') - self.assertEqual(fmt('word', default_end='/'), 'Word/') - self.assertEqual(fmt('word', end_chars='d'), 'Word') - self.assertEqual(fmt('word , (word).. word'), 'Word, (word).. word.') - self.assertEqual(fmt('word,wo[rd..wo]rd'), 'Word, wo [rd.. wo] rd.') - self.assertEqual(fmt('wo--*--rd'), 'Wo --*-- rd.') - - @patch('markovchain.text.util.format_sentence_string', return_value=1) - def test_format_sentence(self, fmt): - self.assertEqual(format_sentence('word'), 1) - fmt.assert_called_with('word', '.?!', '.') - fmt.reset_mock() - - self.assertEqual( - format_sentence((str(x) for x in range(3)), - end_chars='/[', default_end='/'), - 1 - ) - fmt.assert_called_with('0 1 2', '/[', '/') - fmt.reset_mock() - - self.assertEqual( - format_sentence(['a', 'b', 'c'], join_with='.'), - 1 - ) - fmt.assert_called_with('a.b.c', '.?!', '.') - fmt.reset_mock() +@pytest.mark.parametrize('test,res', [ + ('\'"?,+-.[]{}()<>', True), + ('\'"?,+-x.[]{}()<>', False), + ('', False) +]) +def test_ispunct(test, res): + assert ispunct(test) == res + + +@pytest.mark.parametrize('test,res', [ + ('worD WORD WoRd', 'Word word word'), + ('x', 'X'), + ('', '') +]) +def test_capitalize(test, res): + assert capitalize(test) == res + + +@pytest.mark.parametrize('test,res', [ + (('', ''), ''), + ((' ', ''), ''), + ((' x ', 'xy'), ''), + ((' \t.\n , .x. ', '.,?!'), 'x. ') +]) +def test_lstrip_ws_and_chars(test, res): + assert lstrip_ws_and_chars(*test) == res + +@pytest.mark.parametrize('arg,kwargs,res', [ + ('', {}, ''), + (' ', {}, ''), + (' ...', {}, ''), + ('.?!word', {}, 'Word.'), + ('word', {'default_end': '/'}, 'Word/'), + ('word', {'end_chars': 'd'}, 'Word'), + ('word , (word).. word', {}, 'Word, (word).. word.'), + ('word,wo[rd..wo]rd', {}, 'Word, wo [rd.. wo] rd.'), + ('wo--*--rd', {}, 'Wo --*-- rd.') +]) +def test_format_sentence_string(arg, kwargs, res): + assert format_sentence_string(arg, **kwargs) == res + +@pytest.mark.parametrize('arg,kwargs,call', [ + ('word', {}, ('word', '.?!', '.')), + ( + (str(x) for x in range(3)), + {'end_chars': '/[', 'default_end': '/'}, + ('0 1 2', '/[', '/') + ), + (['a', 'b', 'c'], {'join_with': '.'}, ('a.b.c', '.?!', '.')) +]) +def test_format_sentence(mocker, arg, kwargs, call): + fmt = mocker.patch( + 'markovchain.text.util.format_sentence_string', + return_value=1 + ) + assert format_sentence(arg, **kwargs) == 1 + fmt.assert_called_once_with(*call)