From 0ffa1565e1f9e6c74b8f7b161d6041b8992aecbb Mon Sep 17 00:00:00 2001 From: Justin Joyce Date: Sat, 1 Apr 2023 00:10:22 +0100 Subject: [PATCH] #905 --- opteryx/__main__.py | 41 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 37 insertions(+), 4 deletions(-) diff --git a/opteryx/__main__.py b/opteryx/__main__.py index 45a54c85..7778d7ee 100644 --- a/opteryx/__main__.py +++ b/opteryx/__main__.py @@ -23,15 +23,18 @@ from opteryx.components.sql_rewriter.sql_rewriter import clean_statement from opteryx.components.sql_rewriter.sql_rewriter import remove_comments +# Define ANSI color codes +ANSI_RED = "\u001b[31m" +ANSI_RESET = "\u001b[0m" # fmt:off def main( - o: str = typer.Option(default="console", help="Output location", ), + o: str = typer.Option(default="console", help="Output location (ignored by REPL)", ), color: bool = typer.Option(default=True, help="Colorize the table displayed to the console."), table_width: bool = typer.Option(default=True, help="Limit console display to the screen width."), max_col_width: int = typer.Option(default=30, help="Maximum column width"), stats: bool = typer.Option(default=False, help="Report statistics."), - sql: str = typer.Argument(None, show_default=False, help="SQL statement to execute."), + sql: str = typer.Argument(None, show_default=False, help="Execute SQL statement and quit."), ): # fmt:on """ @@ -42,6 +45,32 @@ def main( if hasattr(table_width, "default"): table_width = table_width.default + if sql is None: + + import readline + + if o != "console": + raise ValueError("Cannot specify output location and not provide a SQL statement.") + + # Start the REPL loop + while True: # pragma: no cover + # Prompt the user for a SQL statement + statement = input('>> ') + + # If the user entered "quit", exit the loop + if statement == 'quit': + break + + try: + # Execute the SQL statement and display the results + result = opteryx.query(statement) + print(result.display(limit=-1, display_width=table_width, colorize=color, max_column_width=max_col_width)) + except Exception as e: + # Display a friendly error message if an exception occurs + print(f"{ANSI_RED}Error{ANSI_RESET}: {e}") + + quit() + # tidy up the statement sql = clean_statement(remove_comments(sql)) @@ -78,8 +107,12 @@ def main( file.write(orjson.dumps(row.as_dict) + b"\n") return - print(f"Unknown output format '{ext}'") # pragma: no cover + raise ValueError(f"Unknown output format '{ext}'") # pragma: no cover if __name__ == "__main__": # pragma: no cover - typer.run(main) + try: + typer.run(main) + except Exception as e: + # Display a friendly error message if an exception occurs + print(f"{ANSI_RED}Error{ANSI_RESET}: {e}")