diff --git a/src/db.py b/src/db.py index 8610a84..eed4be3 100644 --- a/src/db.py +++ b/src/db.py @@ -1,7 +1,11 @@ from mysql.connector.abstracts import MySQLConnectionAbstract +from mysql.connector.cursor import MySQLCursorAbstract +import exception from clui import log +SHOW_TABLES = "SHOW TABLES" + def test_connection(cnx: MySQLConnectionAbstract) -> None: log("Testing database connection") @@ -9,3 +13,12 @@ def test_connection(cnx: MySQLConnectionAbstract) -> None: log(f"Successfully connected to MySQL {cnx.get_server_info()} on {cnx.server_host}.") else: log("Connection not working!") + + +def assert_table_exists(cursor: MySQLCursorAbstract, table: str) -> None: + cursor.execute(SHOW_TABLES) + tables = cursor.fetchall() + for row in tables: + if table in row: + return + raise exception.TableNotFoundError(table) diff --git a/src/exception.py b/src/exception.py new file mode 100644 index 0000000..d22dbd3 --- /dev/null +++ b/src/exception.py @@ -0,0 +1,9 @@ +class Error(Exception): + def __init__(self, message: str): + self.msg = message + + +class TableNotFoundError(Error): + def __init__(self, table: str): + message = f"Could not find table {table}" + super().__init__(message)