From 1f650f28edf11339e6b15c3d03892effb0015c59 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Fri, 19 Aug 2016 12:27:06 +0100 Subject: [PATCH] Test for class inheritance instead of class name (#34) --- opsdroid/core.py | 14 ++++++++++---- tests/mockmodules/connectors/connector.py | 4 +++- tests/mockmodules/databases/database.py | 4 +++- tests/test_core.py | 13 +++++-------- 4 files changed, 21 insertions(+), 14 deletions(-) diff --git a/opsdroid/core.py b/opsdroid/core.py index 73520e009..5b5575312 100644 --- a/opsdroid/core.py +++ b/opsdroid/core.py @@ -6,6 +6,8 @@ from multiprocessing import Process from opsdroid.helper import match from opsdroid.memory import Memory +from opsdroid.connector import Connector +from opsdroid.database import Database class OpsDroid(): @@ -54,7 +56,9 @@ def start_connectors(self, connectors): self.critical("All connectors failed to load", 1) elif len(connectors) == 1: for name, cls in connectors[0]["module"].__dict__.items(): - if isinstance(cls, type) and "Connector" in name: + if isinstance(cls, type) and \ + isinstance(cls({}), Connector): + logging.debug("Adding connector: " + name) connectors[0]["config"]["bot-name"] = self.bot_name connector = cls(connectors[0]["config"]) self.connectors.append(connector) @@ -62,7 +66,8 @@ def start_connectors(self, connectors): else: for connector_module in connectors: for name, cls in connector_module["module"].__dict__.items(): - if isinstance(cls, type) and "Connector" in name: + if isinstance(cls, type) and \ + isinstance(cls({}), Connector): connector_module["config"]["bot-name"] = self.bot_name connector = cls(connector_module["config"]) self.connectors.append(connector) @@ -78,11 +83,12 @@ def start_databases(self, databases): logging.warning("All databases failed to load") for database_module in databases: for name, cls in database_module["module"].__dict__.items(): - if isinstance(cls, type) and "Database" in name: + if isinstance(cls, type) and \ + isinstance(cls({}), Database): logging.debug("Adding database: " + name) database = cls(database_module["config"]) self.memory.databases.append(database) - database.connect() + database.connect(self) def load_regex_skill(self, regex, skill): """Load skills.""" diff --git a/tests/mockmodules/connectors/connector.py b/tests/mockmodules/connectors/connector.py index 5e4455961..0938160c6 100644 --- a/tests/mockmodules/connectors/connector.py +++ b/tests/mockmodules/connectors/connector.py @@ -2,8 +2,10 @@ import unittest.mock as mock +from opsdroid.connector import Connector -class ConnectorTest: + +class ConnectorTest(Connector): """The mocked connector class.""" def __init__(self, config): diff --git a/tests/mockmodules/databases/database.py b/tests/mockmodules/databases/database.py index b1119534b..3a5ae8640 100644 --- a/tests/mockmodules/databases/database.py +++ b/tests/mockmodules/databases/database.py @@ -2,8 +2,10 @@ import unittest.mock as mock +from opsdroid.database import Database -class DatabaseTest: + +class DatabaseTest(Database): """The mocked database class.""" def __init__(self, config): diff --git a/tests/test_core.py b/tests/test_core.py index b2bac73bb..e309f8b5e 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -55,10 +55,8 @@ def test_start_databases(self): module["config"] = {} module["module"] = importlib.import_module( "tests.mockmodules.databases.database") - opsdroid.start_databases([module]) - self.assertEqual(len(opsdroid.memory.databases), 1) - self.assertEqual( - len(opsdroid.memory.databases[0].connect.mock_calls), 1) + with self.assertRaises(NotImplementedError): + opsdroid.start_databases([module]) def test_start_connectors(self): with OpsDroid() as opsdroid: @@ -67,12 +65,11 @@ def test_start_connectors(self): module["config"] = {} module["module"] = importlib.import_module( "tests.mockmodules.connectors.connector") - opsdroid.start_connectors([module]) - self.assertEqual(len(opsdroid.connectors), 1) + + with self.assertRaises(NotImplementedError): + opsdroid.start_connectors([module]) opsdroid.start_connectors([module, module]) - self.assertEqual(len(opsdroid.connectors), 3) - self.assertEqual(len(opsdroid.connector_jobs), 2) def test_multiple_opsdroids(self): with OpsDroid() as opsdroid: