diff --git a/.env b/.env deleted file mode 100644 index d50efbb..0000000 --- a/.env +++ /dev/null @@ -1 +0,0 @@ -BASE_URL = "https://api.app.firebolt.io" \ No newline at end of file diff --git a/README.md b/README.md index d83ce14..8a73034 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,10 @@ firebolt://email@domain:password@sample_database firebolt://email@domain:password@sample_database/sample_engine ``` +To override the API url (e.g. for dev testing) +```bash +export FIREBOLT_BASE_URL= +``` ## DB API @@ -72,9 +76,7 @@ schemas = dialect.get_schema_names(connection) ``` ## Components in the Adapter: -1. Firebolt Connector: This file is used to establish a connection to the Firebolt database from 3rd party applications. It provides a ‘connect’ method which accepts parameters like database name, username, password etc. from the connecting application to identify the database and authenticate the user credentials. It returns a database connection which is used to execute queries on the database. -2. API Service: The API Service is responsible for calling Firebolt REST APIs to establish connection with the database and fire SQL queries on it. It provides methods to get access token as per user credentials, get the specific engine URL and execute/run SQL queries. Executing queries need access token and engine URL as per the Firebolt REST API specifications. -3. Firebolt Dialect: It provides methods for retrieving metadata about databases like schema data, table names, column names etc. It also maps the data types between Firebolt and SQLAlchemy along with providing a data type compiler for complex data types. +1. Firebolt Dialect: It provides methods for retrieving metadata about databases like schema data, table names, column names etc. It also maps the data types between Firebolt and SQLAlchemy along with providing a data type compiler for complex data types. ## Testing Strategy: diff --git a/setup.py b/setup.py index 88ff5ce..dc56de2 100644 --- a/setup.py +++ b/setup.py @@ -18,8 +18,7 @@ }, install_requires=[ 'sqlalchemy>=1.0.0', - "requests", - "datetime" + "firebolt-sdk" ], entry_points={ "sqlalchemy.dialects": [ diff --git a/src/firebolt_db/__init__.py b/src/firebolt_db/__init__.py index 7f90924..b031390 100644 --- a/src/firebolt_db/__init__.py +++ b/src/firebolt_db/__init__.py @@ -1,5 +1,5 @@ -from .firebolt_connector import connect -from .exceptions import ( +from firebolt.db import connect +from firebolt.common.exception import ( DatabaseError, DataError, Error, diff --git a/src/firebolt_db/constants.py b/src/firebolt_db/constants.py deleted file mode 100644 index 877ff8f..0000000 --- a/src/firebolt_db/constants.py +++ /dev/null @@ -1,18 +0,0 @@ -import os -from dotenv import load_dotenv,find_dotenv - -found_dotenv = find_dotenv() - -if found_dotenv: - load_dotenv(found_dotenv) - base_url = os.environ["FIREBOLT_BASE_URL"] -else: - base_url = "https://api.app.firebolt.io" - -token_url = f"{base_url}/auth/v1/login" -refresh_url = f"{base_url}/auth/v1/refresh" -query_engine_url = f"{base_url}/core/v1/account/engines:getURLByDatabaseName" -query_engine_url_by_engine_name = f"{base_url}/core/v1/account/engines" -engine_id_url = f"{base_url}/core/v1/account/engines:getIdbyName" -engine_start_url = f"{base_url}/core/v1/account/engines" -token_header = {"Content-Type": "application/json;charset=UTF-8"} diff --git a/src/firebolt_db/exceptions.py b/src/firebolt_db/exceptions.py deleted file mode 100644 index 3e974f2..0000000 --- a/src/firebolt_db/exceptions.py +++ /dev/null @@ -1,53 +0,0 @@ -from sqlalchemy.exc import CompileError - - -class Error(Exception): - pass - - -class Warning(Exception): - pass - - -class InterfaceError(Error): - pass - - -class DatabaseError(Error): - pass - - -class InternalError(DatabaseError): - pass - - -class OperationalError(DatabaseError): - pass - - -class ProgrammingError(DatabaseError): - pass - - -class IntegrityError(DatabaseError): - pass - - -class DataError(DatabaseError): - pass - - -class NotSupportedError(CompileError): - pass - - -class InvalidCredentialsError(DatabaseError): - pass - - -class SchemaNotFoundError(DatabaseError): - pass - - -class EngineNotFoundError(DatabaseError): - pass diff --git a/src/firebolt_db/firebolt_api_service.py b/src/firebolt_db/firebolt_api_service.py deleted file mode 100644 index d19aaed..0000000 --- a/src/firebolt_db/firebolt_api_service.py +++ /dev/null @@ -1,278 +0,0 @@ -import json - -import requests -from requests.exceptions import HTTPError - -from . import exceptions -from . import constants -from .memoized import memoized - - -class FireboltApiService: - - @staticmethod - @memoized - def get_connection(user_email, password, engine_name, db_name, date): - """ - Retrieve Authorisation details for connection - This method internally calls methods to get access token, refresh token and engine URL. - :input user-email, password, engine name, database name and date for memoization - :returns access-token, refresh-token and engine url - """ - # get access token - token_json = FireboltApiService.get_access_token(user_email, password) - access_token = token_json["access_token"] - refresh_token = token_json["refresh_token"] - - # get engine url - if engine_name is None or engine_name == '': - engine_url = FireboltApiService.get_engine_url_by_db(db_name, access_token) - else: - engine_url = FireboltApiService.get_engine_url_by_engine(engine_name, access_token) - return access_token, refresh_token, engine_url - - @staticmethod - def get_access_token(user_email, password): - """ - Retrieve authentication token - This method uses the user email and the password to fire the API to generate access-token. - :input user-email and password - :returns access-token - """ - data = {'username': user_email, 'password': password} - json_data = {} # base case - payload = {} - try: - - """ - General format of request: - curl --request POST 'https://api.app.firebolt.io/auth/v1/login' --header 'Content-Type: application/json;charset=UTF-8' --data-binary '{"username":"username","password":"password"}' - """ - token_response = requests.post(url=constants.token_url, data=json.dumps(data), - headers=constants.token_header) - token_response.raise_for_status() - - """ - General format of response: - - { - "access_token": "YOUR_ACCESS_TOKEN_VALUE", - "expires_in": 86400, - "refresh_token": "YOUR_REFRESH_TOKEN_VALUE", - "scope": "offline_access", - "token_type": "Bearer" - } - """ - - json_data = json.loads(token_response.text) - - except HTTPError as http_err: - payload = { - "error": "Access Token API Exception", - "errorMessage": http_err.response.text, - } - except Exception as err: - payload = { - "error": "Access Token API Exception", - "errorMessage": str(err), - } - - if payload != {}: - msg = "{error} : {errorMessage}".format(**payload) - raise exceptions.InvalidCredentialsError(msg) - - return json_data - - @staticmethod - def get_access_token_via_refresh(refresh_token): - """ - Refresh access token - In case the token expires or the API throws a 401 HTTP error, then this method generates a fresh token - :input refresh token generated alongside the previous expired token - :returns new access-token - """ - refresh_access_token = "" - payload = {} - try: - """ - Request: - curl --request POST 'https://api.app.firebolt.io/auth/v1/refresh' - --header 'Content-Type: application/json;charset=UTF-8' - --data-binary '{"refresh_token":"YOUR_REFRESH_TOKEN_VALUE"}' - """ - data = {'refresh_token': refresh_token} - token_response = requests.post(url=constants.refresh_url, data=json.dumps(data), - headers=constants.token_header) - token_response.raise_for_status() - - """ - Response: - { - "access_token": "YOUR_REFRESHED_ACCESS_TOKEN_VALUE", - "scope": "offline_access", - "expires_in": 86400, - "token_type": "Bearer" - } - """ - - json_data = json.loads(token_response.text) - refresh_access_token = json_data["access_token"] - - except HTTPError as http_err: - payload = { - "error": "Refresh Access Token API Exception", - "errorMessage": http_err.response.text, - } - except Exception as err: - payload = { - "error": "Refresh Access Token API Exception", - "errorMessage": str(err), - } - if payload != {}: - msg = "{error} : {errorMessage}".format(**payload) - raise exceptions.InternalError(msg) - - return refresh_access_token - - @staticmethod - def get_engine_url_by_db(db_name, access_token): - """ - Get engine url by db_name name - This method generates engine url using engine name and access-token - :input database name and access token - :returns engine url - """ - engine_url = "" # base case - payload = {} - try: - """ - Request: - curl --request GET 'https://api.app.firebolt.io/core/v1/account/engines:getURLByDatabaseName?database_name=YOUR_DATABASE_NAME' - --header 'Authorization: Bearer YOUR_ACCESS_TOKEN_VALUE' - """ - header = {'Authorization': "Bearer " + access_token} - query_engine_response = requests.get(constants.query_engine_url, params={'database_name': db_name}, - headers=header) - query_engine_response.raise_for_status() - - """ - Response: - {"engine_url": "YOUR_DATABASES_DEFAULT_ENGINE_URL"} - """ - json_data = json.loads(query_engine_response.text) - engine_url = json_data["engine_url"] - - except HTTPError as http_err: - payload = { - "error": "Engine Url API Exception", - "errorMessage": http_err.response.text, - } - except Exception as err: - payload = { - "error": "Engine Url API Exception", - "errorMessage": str(err), - } - if payload != {}: - msg = "{error} : {errorMessage}".format(**payload) - raise exceptions.SchemaNotFoundError(msg) - - return engine_url - - @staticmethod - def get_engine_url_by_engine(engine_name, access_token): - """ - Get engine url by engine name - This method generates engine url using engine name and access-token - :input engine name and access-token - :returns engine url - """ - engine_url = "" # base case - payload = {} - try: - """ - Request: - curl --request GET 'https://api.app.firebolt.io/core/v1/account/engines?filter.name_contains=YOUR_ENGINE_NAME' - --header 'Authorization: Bearer YOUR_ACCESS_TOKEN_VALUE' - """ - header = {'Authorization': "Bearer " + access_token} - query_engine_response = requests.get(constants.query_engine_url_by_engine_name, - params={'filter.name_contains': engine_name}, - headers=header) - query_engine_response.raise_for_status() - - """ - Response: - { - "page": { - ... - }, - "edges": [ - { - ... - "endpoint": "YOUR_ENGINE_URL", - ... - } - } - ] - } - """ - json_data = json.loads(query_engine_response.text) - engine_url = json_data["edges"][0]["node"]["endpoint"] - - except HTTPError as http_err: - payload = { - "error": "Engine Url API Exception", - "errorMessage": http_err.response.text, - } - except Exception as err: - payload = { - "error": "Engine Url API Exception", - "errorMessage": str(err), - } - if payload != {}: - msg = "{error} : {errorMessage}".format(**payload) - raise exceptions.EngineNotFoundError(msg) - - return engine_url - - @staticmethod - def run_query(access_token, engine_url, db_name, query): - """ - Run queries - This method is used to submit a query to run to a running engine. - You can specify multiple queries separated by a semicolon (;).. - :input access token, engine url, database name, query - :returns database metadata - """ - query_response = {} # base-case - payload = {} - - try: - - """ - Request: - --request POST 'https://YOUR_ENGINE_ENDPOINT/?database=YOUR_DATABASE_NAME' \ - --header 'Authorization: Bearer YOUR_ACCESS_TOKEN_VALUE' \ - --data-binary @- - """ - - header = {'Authorization': "Bearer " + access_token} - query_response = requests.post(url="https://" + engine_url, params={'database': db_name}, - headers=header, files={"query": (None, query)}) - query_response.raise_for_status() - - except HTTPError as http_err: - payload = { - "error": "DB-API Exception", - "errorMessage": http_err.response.text, - } - except Exception as err: - payload = { - "error": "DB-API Exception", - "errorMessage": str(err), - } - if payload != {}: - msg = "{error} : {errorMessage}".format(**payload) - raise exceptions.InternalError(msg) - - return query_response diff --git a/src/firebolt_db/firebolt_connector.py b/src/firebolt_db/firebolt_connector.py deleted file mode 100644 index fdb8354..0000000 --- a/src/firebolt_db/firebolt_connector.py +++ /dev/null @@ -1,464 +0,0 @@ -#!/usr/bin/env python -# -# See http://www.python.org/dev/peps/pep-0249/ -# -# Many docstrings in this file are based on the PEP, which is in the public domain. - -# Built as per Python DB API Specification - PEP 249 -# Responsible for connection to Database and providing database cursor for query execution - -import itertools -import json -from collections import namedtuple, OrderedDict -from datetime import date - -from .firebolt_api_service import FireboltApiService -from . import exceptions - - -class Error(Exception): - """Exception that is the base class of all other error exceptions. - You can use this to catch all errors with one single except statement. - """ - pass - - -class Type(object): - STRING = 1 - NUMBER = 2 - BOOLEAN = 3 - ARRAY = 4 - - -def connect(*args, **kwargs): - """ - Constructor for creating a connection to the database. - """ - return Connection(*args, **kwargs) - - -def check_closed(f): - """Decorator that checks if connection/cursor is closed.""" - - def g(self, *args, **kwargs): - if self.closed: - raise exceptions.Error( - "{klass} already closed".format(klass=self.__class__.__name__) - ) - return f(self, *args, **kwargs) - - return g - - -def check_result(f): - """Decorator that checks if the cursor has results from `execute`.""" - - def g(self, *args, **kwargs): - if self._results is None: - raise exceptions.Error("Called before `execute`") - return f(self, *args, **kwargs) - - return g - - -def get_description_from_row(row): - """ - Return description from a single row. - - We only return the name, type (inferred from the data) and if the values - can be NULL. String columns in Firebolt are NULLable. Numeric columns are NOT - NULL. - """ - return [ - ( - name, # name - get_type(value), # type_code - None, # [display_size] - None, # [internal_size] - None, # [precision] - None, # [scale] - get_type(value) == Type.STRING, # [null_ok] - ) - for name, value in row.items() - ] - - -def get_type(value): - """ - Infer type from value. - - Note that bool is a subclass of int so order of statements matter. - """ - - if isinstance(value, str) or value is None: - return Type.STRING - elif isinstance(value, bool): - return Type.BOOLEAN - elif isinstance(value, (int, float)): - return Type.NUMBER - elif isinstance(value, list): - return Type.ARRAY - - raise exceptions.Error("Value of unknown type: {value}".format(value=value)) - - -class Connection(object): - """Connection to a Firebolt database.""" - - def __init__(self, - host, - port, - username, - password, - db_name, - # scheme="http", - context=None, - header=False, - ssl_verify_cert=False, - ssl_client_cert=None, - proxies=None, - ): - self._host = host - self._post = port - self._username = username - self._password = password - self._db_name = db_name - connection_details = FireboltApiService.get_connection(username, password, host, db_name, date.today()) - - self.access_token = connection_details[0] - self.refresh_token = connection_details[1] - self.engine_url = connection_details[2] - self.cursors = [] - self.closed = False - - self.ssl_verify_cert = ssl_verify_cert - self.ssl_client_cert = ssl_client_cert - self.proxies = proxies - self.context = context or {} - self.header = header - - - @check_closed - def close(self): - """Close the connection now.""" - self.closed = True - for cursor in self.cursors: - try: - cursor.close() - except exceptions.Error: - pass # already closed - - @check_closed - def commit(self): - """ - Commit any pending transaction to the database. - - Not supported. - """ - pass - - @check_closed - def cursor(self): - """Return a new Cursor Object using the connection.""" - cursor = Cursor( - self._db_name, - self.access_token, - self.engine_url, - self.refresh_token, - # self.url, - self._username, - self._password, - self.context, - self.header, - self.ssl_verify_cert, - self.ssl_client_cert, - self.proxies, - ) - - self.cursors.append(cursor) - - return cursor - - def __enter__(self): - return self.cursor() - - def __exit__(self, *exc): - self.close() - - -class Cursor(object): - """Connection cursor.""" - - def __init__( - self, - db_name, - access_token, - engine_url, - refresh_token, - # url, - user=None, - password=None, - context=None, - header=False, - ssl_verify_cert=True, - proxies=None, - ssl_client_cert=None, - ): - # self.url = url - self.context = context or {} - self.header = header - self.user = user - self.password = password - self.ssl_verify_cert = ssl_verify_cert - self.ssl_client_cert = ssl_client_cert - self.proxies = proxies - self.db_name = db_name - self.access_token = access_token - self.engine_url = engine_url - self.refresh_token = refresh_token - - # This read/write attribute specifies the number of rows to fetch at a - # time with .fetchmany(). It defaults to 1 meaning to fetch a single - # row at a time. - self.arraysize = 1 - - self.closed = False - - # this is updated only after a query - self.description = None - - # this is set to an iterator after a successfull query - self._results = None - - @property - @check_result - @check_closed - def rowcount(self): - # consume the iterator - results = list(self._results) - n = len(results) - self._results = iter(results) - return n - - @check_closed - def close(self): - """Close the cursor.""" - self.closed = True - - @check_closed - def execute(self, operation, parameters=None): - query = apply_parameters(operation, parameters) - results = self._stream_query(query) - - """ - `_stream_query` returns a generator that produces the rows; we need to - consume the first row so that `description` is properly set, so let's - consume it and insert it back if it is not the header. - """ - try: - first_row = next(results) - self._results = ( - results if self.header else itertools.chain([first_row], results) - ) - except StopIteration: - self._results = iter([]) - return self - - @check_closed - def executemany(self, operation, seq_of_parameters=None): - raise exceptions.NotSupportedError( - "`executemany` is not supported, use `execute` instead" - ) - - @check_result - @check_closed - def fetchone(self): - """ - Fetch the next row of a query result set, returning a single sequence, - or `None` when no more data is available. - """ - try: - res = self.next() - return res - except StopIteration: - return None - - @check_result - @check_closed - def fetchmany(self, size=None): - """ - Fetch the next set of rows of a query result, returning a sequence of - sequences (e.g. a list of tuples). An empty sequence is returned when - no more rows are available. - """ - size = size or self.arraysize - return list(itertools.islice(self._results, size)) - - @check_result - @check_closed - def fetchall(self): - """ - Fetch all (remaining) rows of a query result, returning them as a - sequence of sequences (e.g. a list of tuples). Note that the cursor's - arraysize attribute can affect the performance of this operation. - """ - return list(self._results) - - @check_closed - def setinputsizes(self, sizes): - # not supported - pass - - @check_closed - def setoutputsizes(self, sizes): - # not supported - pass - - @check_closed - def __iter__(self): - return self - - @check_closed - def __next__(self): - return next(self._results) - - next = __next__ - - def _stream_query(self, query): - """ - Stream rows from a query. - - This method will yield rows as the data is returned in chunks from the - server. - """ - self.description = None - - r = FireboltApiService.run_query(self.access_token, - self.engine_url, - self.db_name, - query) - - # Setting `chunk_size` to `None` makes it use the server size - chunks = r.iter_content(chunk_size=4096, decode_unicode=True) - - Row = None - # for row in rows_from_lines(lines): - for row in rows_from_chunks(chunks): - # update description - if self.description is None: - self.description = ( - list(row.items()) if self.header else get_description_from_row(row) - ) - - # return row in namedtuple - if Row is None: - Row = namedtuple("Row", row.keys(), rename=True) - yield Row(*row.values()) - - -def rows_from_lines(lines): - """ - A generator that yields rows from JSON lines. - - Firebolt will return the data in lines, but they are not aligned with the - JSON objects. This function will parse all complete rows from the lines, - yielding them as soon as possible. - """ - - data_started = False - body = "" - for line in lines: - line = line.lstrip().rstrip() - if data_started: - if line == '],': - body = "".join((body,line)) - break - else: - body = "".join((body,line)) - - if not data_started and line == '"data":': - data_started = True - - rows = body.lstrip('[').rstrip('],') - - for row in json.loads( - "[{rows}]".format(rows=rows), object_pairs_hook=OrderedDict - ): - yield row - - -def rows_from_chunks(chunks): - """ - A generator that yields rows from JSON chunks. - - Firebolt will return the data in chunks, but they are not aligned with the - JSON objects. This function will parse all complete rows inside each chunk, - yielding them as soon as possible. - """ - data_started = False - old_body = "" - for chunk in chunks: - if chunk: - chunk = "".join((old_body, chunk)) - body = "" - lines = chunk.splitlines() - curly_started = False - new_data_row = "" - for line in lines: - line = line.lstrip().rstrip() - if data_started and line: - if line == '],': - data_started = False - break - else: - if curly_started: - if line == '}' or line == '},': - curly_started = False - body = "".join((body,new_data_row,line)) - new_data_row = "" - old_body = "" - else: - new_data_row = "".join((new_data_row,line)) - old_body = new_data_row - - elif not curly_started and line[0] == '{': - curly_started = True - new_data_row = "".join((new_data_row,line)) - old_body = new_data_row - - elif not data_started and line == '"data":': - data_started = True - - rows = body.lstrip().rstrip(',') - - for row in json.loads( - "[{rows}]".format(rows=rows), object_pairs_hook=OrderedDict - ): - yield row - - -def apply_parameters(operation, parameters): - if not parameters: - return operation - - escaped_parameters = {key: escape(value) for key, value in parameters.items()} - return operation % escaped_parameters - - -def escape(value): - """ - Escape the parameter value. - - Note that bool is a subclass of int so order of statements matter. - """ - - if value == "*": - return value - elif isinstance(value, str): - return "'{}'".format(value.replace("'", "''")) - elif isinstance(value, bool): - return "TRUE" if value else "FALSE" - elif isinstance(value, (int, float)): - return value - elif isinstance(value, (list, tuple)): - return ", ".join(escape(element) for element in value) diff --git a/src/firebolt_db/firebolt_dialect.py b/src/firebolt_db/firebolt_dialect.py index 904c033..8913377 100644 --- a/src/firebolt_db/firebolt_dialect.py +++ b/src/firebolt_db/firebolt_dialect.py @@ -6,6 +6,7 @@ TIMESTAMP, VARCHAR, BOOLEAN, FLOAT) import firebolt_db +import os class ARRAY(sqltypes.TypeEngine): @@ -83,19 +84,18 @@ def __init__(self, context=None, *args, **kwargs): def dbapi(cls): return firebolt_db - # Build DB-API compatible connection arguments. + # Build firebolt-sdk compatible connection arguments. # URL format : firebolt://username:password@host:port/db_name def create_connect_args(self, url): kwargs = { - "host": url.database or None, - "port": url.port or 5432, + "database": url.host or None, "username": url.username or None, "password": url.password or None, - "db_name": url.host, - # "scheme": self.scheme, - "context": self.context, - "header": False, # url.query.get("header") == "true", + "engine_name": url.database } + # If URL override is not provided leave it to the sdk to determine the endpoint + if "FIREBOLT_BASE_URL" in os.environ: + kwargs["api_endpoint"] = os.environ["FIREBOLT_BASE_URL"] return ([], kwargs) def get_schema_names(self, connection, **kwargs): diff --git a/src/firebolt_db/memoized.py b/src/firebolt_db/memoized.py deleted file mode 100644 index b99760b..0000000 --- a/src/firebolt_db/memoized.py +++ /dev/null @@ -1,34 +0,0 @@ -import collections -import functools - - -class memoized(object): - """ - Decorator. Caches a function's return value each time it is called. - If called later with the same arguments, the cached value is returned - (not reevaluated). - """ - - def __init__(self, func): - self.func = func - self.cache = {} - - def __call__(self, *args): - if not isinstance(args, collections.Hashable): - # uncacheable. a list, for instance. - # better to not cache than blow up. - return self.func(*args) - if args in self.cache: - return self.cache[args] - else: - value = self.func(*args) - self.cache[args] = value - return value - - def __repr__(self): - """Return the function's docstring.""" - return self.func.__doc__ - - def __get__(self, obj, objtype): - """Support instance methods.""" - return functools.partial(self.__call__, obj) diff --git a/tests/test_firebolt_api_service.py b/tests/test_firebolt_api_service.py deleted file mode 100644 index c3c66cc..0000000 --- a/tests/test_firebolt_api_service.py +++ /dev/null @@ -1,113 +0,0 @@ -import pytest -import os -from datetime import date -from requests.exceptions import HTTPError - -from firebolt_db.firebolt_api_service import FireboltApiService -from firebolt_db import exceptions - -test_username = os.environ["username"] -test_password = os.environ["password"] -test_engine_name = os.environ["engine_name"] -test_db_name = os.environ["db_name"] -query = 'select * from ci_fact_table limit 1' - -access_token = FireboltApiService.get_access_token(test_username, test_password) -if test_engine_name is None: - test_engine_url = FireboltApiService.get_engine_url_by_db(test_db_name, access_token["access_token"]) -else: - test_engine_url = FireboltApiService.get_engine_url_by_engine(test_engine_name, access_token["access_token"]) - - -class TestFireboltApiService: - - def test_get_connection_success(self): - response = FireboltApiService.get_connection(test_username, test_password, - test_engine_name, test_db_name, date.today()) - if type(response) == HTTPError: - assert response.response.status_code == 503 - else: - assert response != "" - - def test_get_connection_invalid_credentials(self): - with pytest.raises(Exception) as e_info: - response = FireboltApiService.get_connection('username', 'password', test_engine_name, - test_db_name, date.today())[0] - - def test_get_connection_invalid_engine_name(self): - with pytest.raises(Exception) as e_info: - response = FireboltApiService.get_connection(test_username, test_password, 'engine_name', - test_db_name, date.today())[2] - - def test_get_connection_invalid_db_name(self): - with pytest.raises(Exception) as e_info: - response = FireboltApiService.get_connection(test_username, test_password, None, - 'test_db_name', date.today())[2] - - def test_get_access_token_success(self): - assert access_token["access_token"] != "" - - def test_get_access_token_invalid_credentials(self): - with pytest.raises(Exception) as e_info: - response = FireboltApiService.get_access_token('username', 'password') - - def test_get_access_token_via_refresh_success(self): - assert FireboltApiService.get_access_token_via_refresh(access_token["refresh_token"]) != "" - - def test_get_access_token_via_refresh_invalid_token(self): - with pytest.raises(Exception) as e_info: - response = FireboltApiService.get_access_token_via_refresh('refresh_token') - - def test_get_engine_url_by_db_success(self): - if test_engine_name is None: - assert test_engine_url != "" - else: - assert FireboltApiService.get_engine_url_by_db(test_db_name, access_token["access_token"]) != "" - - def test_get_engine_url_by_db_invalid_db_name(self): - with pytest.raises(Exception) as e_info: - response = FireboltApiService.get_engine_url_by_db('db_name', access_token["access_token"]) - - def test_get_engine_url_by_db_invalid_header(self): - with pytest.raises(Exception) as e_info: - response = FireboltApiService.get_engine_url_by_db(test_db_name, 'header') != "" - - def test_get_engine_url_by_engine_success(self): - if test_engine_name is not None: - assert FireboltApiService.get_engine_url_by_engine(test_engine_name, access_token["access_token"]) != "" - else: - assert test_engine_url != "" - - def test_get_engine_url_by_engine_invalid_engine_name(self): - with pytest.raises(Exception) as e_info: - response = FireboltApiService.get_engine_url_by_engine('engine_name', access_token["access_token"]) - - def test_get_engine_url_by_engine_invalid_header(self): - with pytest.raises(Exception) as e_info: - response = FireboltApiService.get_engine_url_by_engine(test_engine_name, 'header') != "" - - def test_run_query_success(self): - try: - response = FireboltApiService.run_query(access_token["access_token"], test_engine_url, test_db_name, query) - assert response != "" - except exceptions.InternalError as http_err: - assert http_err != "" - - def test_run_query_invalid_url(self): - with pytest.raises(Exception) as e_info: - response = FireboltApiService.run_query(access_token["access_token"], "", test_db_name, query) != {} - - def test_run_query_invalid_schema(self): - with pytest.raises(Exception) as e_info: - response = FireboltApiService.run_query(access_token["access_token"], test_engine_url, 'db_name', query) - - def test_run_query_invalid_header(self): - try: - response = FireboltApiService.run_query('header', test_engine_url, test_db_name, query) - assert response != "" - except exceptions.InternalError as e_info: - assert e_info != "" - - def test_run_query_invalid_query(self): - with pytest.raises(Exception) as e_info: - response = FireboltApiService.run_query(access_token["access_token"], test_engine_url, test_db_name, 'query') diff --git a/tests/test_fireboltconnector.py b/tests/test_fireboltconnector.py deleted file mode 100644 index 2429355..0000000 --- a/tests/test_fireboltconnector.py +++ /dev/null @@ -1,178 +0,0 @@ -import os - -import pytest - -from firebolt_db import firebolt_connector -from firebolt_db import exceptions - -test_username = os.environ["username"] -test_password = os.environ["password"] -test_engine_name = os.environ["engine_name"] -test_db_name = os.environ["db_name"] - - -@pytest.fixture -def get_connection(): - return firebolt_connector.connect(test_engine_name, 8123, test_username, test_password, test_db_name) - - -class TestConnect: - - def test_connect_success(self): - user_email = test_username - password = test_password - db_name = test_engine_name - host = test_db_name - port = "8123" - connection = firebolt_connector.connect(host, port, user_email, password, db_name) - assert connection.access_token - assert connection.engine_url - - def test_connect_invalid_credentials(self): - user_email = test_username - password = "wrongpassword" - db_name = test_engine_name - host = test_db_name - port = "8123" - with pytest.raises(exceptions.InvalidCredentialsError): - firebolt_connector.connect(host, port, user_email, password, db_name) - - -def test_get_description_from_row_valid_rows(): - row = {'id': 1, 'name': 'John', 'is_eligible': True, 'some_array': [2, 4]} - result = firebolt_connector.get_description_from_row(row) - assert result[0][0] == 'id' - assert result[0][1] == firebolt_connector.Type.NUMBER - assert not result[0][6] - assert result[1][0] == 'name' - assert result[1][1] == firebolt_connector.Type.STRING - assert result[1][6] - assert result[2][0] == 'is_eligible' - assert result[2][1] == firebolt_connector.Type.BOOLEAN - assert not result[2][6] - assert result[3][0] == 'some_array' - assert result[3][1] == firebolt_connector.Type.ARRAY - assert not result[3][6] - - -def test_get_description_from_row_invalid_rows(): - row = {'id': {}} - with pytest.raises(Exception): - firebolt_connector.get_description_from_row(row) - - -def test_get_type(): - value_1 = "String Value" - value_2_1 = 5 - value_2_2 = 5.1 - value_3_1 = True - value_3_2 = False - value_4 = [] - assert firebolt_connector.get_type(value_1) == 1 - assert firebolt_connector.get_type(value_2_1) == 2 - assert firebolt_connector.get_type(value_2_2) == 2 - assert firebolt_connector.get_type(value_3_1) == 3 - assert firebolt_connector.get_type(value_3_2) == 3 - assert firebolt_connector.get_type(value_4) == 4 - - -def test_get_type_invalid_type(): - value = {} - with pytest.raises(Exception): - firebolt_connector.get_type(value) - - -class TestConnection: - - def test_cursor(self, get_connection): - connection = get_connection - assert len(connection.cursors) == 0 - cursor = connection.cursor() - assert len(connection.cursors) > 0 - assert type(cursor) == firebolt_connector.Cursor - - def test_commit(self): - pass - - def test_close(self, get_connection): - connection = get_connection - connection.cursor() - connection.close() - for cursor in connection.cursors: - assert cursor.closed - - -class TestCursor: - - def test_rowcount(self, get_connection): - connection = get_connection - query = "select * from ci_fact_table limit 10" - try: - cursor = connection.cursor().execute(query) - assert cursor.rowcount == 10 - except exceptions.InternalError as http_err: - assert http_err != "" - - def test_close(self, get_connection): - connection = get_connection - cursor = connection.cursor() - if not cursor.closed: - cursor.close() - assert cursor.closed - - def test_execute(self, get_connection): - query = 'select * from ci_fact_table ' \ - 'where l_orderkey=3184321 and l_partkey=65945' - connection = get_connection - cursor = connection.cursor() - assert not cursor._results - try: - cursor.execute(query) - assert cursor.rowcount == 1 - except exceptions.InternalError as http_err: - assert http_err != "" - - def test_executemany(self, get_connection): - query = "select * from ci_fact_table limit 10" - connection = get_connection - cursor = connection.cursor() - with pytest.raises(exceptions.NotSupportedError): - cursor.executemany(query) - - def test_fetchone(self, get_connection): - query = "select * from ci_fact_table limit 10" - connection = get_connection - cursor = connection.cursor() - assert not cursor._results - try: - cursor.execute(query) - result = cursor.fetchone() - assert isinstance(result, tuple) - except exceptions.InternalError as http_err: - assert http_err != "" - - def test_fetchmany(self, get_connection): - query = "select * from ci_fact_table limit 10" - connection = get_connection - cursor = connection.cursor() - assert not cursor._results - try: - cursor.execute(query) - result = cursor.fetchmany(3) - assert isinstance(result, list) - assert len(result) == 3 - except exceptions.InternalError as http_err: - assert http_err != "" - - def test_fetchall(self, get_connection): - query = "select * from ci_fact_table limit 10" - connection = get_connection - cursor = connection.cursor() - assert not cursor._results - try: - cursor.execute(query) - result = cursor.fetchall() - assert isinstance(result, list) - assert len(result) == 10 - except exceptions.InternalError as http_err: - assert http_err != "" diff --git a/tests/test_fireboltdialect.py b/tests/test_fireboltdialect.py index a22d920..a613ccc 100644 --- a/tests/test_fireboltdialect.py +++ b/tests/test_fireboltdialect.py @@ -4,7 +4,6 @@ import sqlalchemy from firebolt_db import firebolt_dialect -from firebolt_db import exceptions from sqlalchemy.engine import url from sqlalchemy import create_engine @@ -29,16 +28,19 @@ def get_engine(): class TestFireboltDialect: def test_create_connect_args(self): + os.environ["FIREBOLT_BASE_URL"] = "test_url" connection_url = "test_engine://test_user@email:test_password@test_db_name/test_engine_name" u = url.make_url(connection_url) result_list, result_dict = dialect.create_connect_args(u) - assert result_dict["host"] == "test_engine_name" - assert result_dict["port"] == 5432 + assert result_dict["engine_name"] == "test_engine_name" assert result_dict["username"] == "test_user@email" assert result_dict["password"] == "test_password" - assert result_dict["db_name"] == "test_db_name" - assert result_dict["context"] == {} - assert not result_dict["header"] + assert result_dict["database"] == "test_db_name" + assert result_dict["api_endpoint"] == "test_url" + # No endpoint override + del os.environ["FIREBOLT_BASE_URL"] + result_list, result_dict = dialect.create_connect_args(u) + assert "api_endpoint" not in result_dict def test_get_schema_names(self, get_engine): engine = get_engine