Support Databricks in SQLDatabase #4702

Merged
merged 14 commits into from
May 19, 2023
2 changes: 1 addition & 1 deletion docs/modules/chains/examples/sqlite.ipynb
@@ -34,7 +34,7 @@
}
},
"source": [
"Under the hood, LangChain uses SQLAlchemy to connect to SQL databases. The `SQLDatabaseChain` can therefore be used with any SQL dialect supported by SQLAlchemy, such as MS SQL, MySQL, MariaDB, PostgreSQL, Oracle SQL, and SQLite. Please refer to the SQLAlchemy documentation for more information about requirements for connecting to your database. For example, a connection to MySQL requires an appropriate connector such as PyMySQL. A URI for a MySQL connection might look like: `mysql+pymysql://user:pass@some_mysql_db_address/db_name`\n",
"Under the hood, LangChain uses SQLAlchemy to connect to SQL databases. The `SQLDatabaseChain` can therefore be used with any SQL dialect supported by SQLAlchemy, such as MS SQL, MySQL, MariaDB, PostgreSQL, Oracle SQL, Databricks and SQLite. Please refer to the SQLAlchemy documentation for more information about requirements for connecting to your database. For example, a connection to MySQL requires an appropriate connector such as PyMySQL. A URI for a MySQL connection might look like: `mysql+pymysql://user:pass@some_mysql_db_address/db_name`. To connect to Databricks, it is recommended to use the handy method `SQLDatabase.from_databricks(catalog, schema, host, api_token, (warehouse_id|cluster_id))`.\n",
"\n",
"This demonstration uses SQLite and the example Chinook database.\n",
"To set it up, follow the instructions on https://database.guide/2-sample-databases-sqlite/, placing the `.db` file in a notebooks folder at the root of this repository."
97 changes: 97 additions & 0 deletions langchain/sql_database.py
@@ -17,6 +17,8 @@
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.schema import CreateTable

from langchain import utils


def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
    return (
@@ -110,6 +112,101 @@ def from_uri(
        _engine_args = engine_args or {}
        return cls(create_engine(database_uri, **_engine_args), **kwargs)

    @classmethod
    def from_databricks(
        cls,
        catalog: str,
        schema: str,
        host: Optional[str] = None,
        api_token: Optional[str] = None,
        warehouse_id: Optional[str] = None,
        cluster_id: Optional[str] = None,
        **kwargs: Any,
    ) -> SQLDatabase:
        """
        Class method to create an SQLDatabase instance from a Databricks connection.
        This method requires the 'databricks-sql-connector' package. If not already
        installed, it can be added using `pip install databricks-sql-connector`.

        Args:
            catalog (str): The catalog name in the Databricks database.
            schema (str): The schema name in the catalog.
            host (Optional[str]): The host URL of the Databricks instance. If not
                provided, it will be fetched from the environment variable
                'DATABRICKS_HOST' or from the current Databricks REPL context.
                Defaults to None.
            api_token (Optional[str]): The API token for the Databricks instance.
                If not provided, it will be fetched from the environment variable
                'DATABRICKS_API_TOKEN' or from the current Databricks REPL context.
                Defaults to None.
            warehouse_id (Optional[str]): The warehouse ID in the Databricks SQL.
                If provided, the method will configure the connection to use this
                warehouse. Cannot be used in conjunction with 'cluster_id'. Defaults
                to None.
            cluster_id (Optional[str]): The cluster ID in the Databricks Runtime.
                If provided, the method will configure the connection to use this
                cluster. Cannot be used in conjunction with 'warehouse_id'. If both
                'warehouse_id' and 'cluster_id' are None, it will use the cluster ID
                from the current Databricks REPL context. Defaults to None.
            **kwargs (Any): Additional keyword arguments for the `from_uri` method.

        Returns:
            SQLDatabase: An instance of SQLDatabase configured with the provided
                Databricks connection details.

        Raises:
            ValueError: If the 'databricks-sql-connector' package is not found, or
                if both 'warehouse_id' and 'cluster_id' are provided, or if neither
                'warehouse_id' nor 'cluster_id' are provided and there is no current
                Databricks REPL context to get the cluster ID from.
        """
        try:
            from databricks import sql  # noqa: F401
        except ImportError:
            raise ValueError(
                "databricks-sql-connector package not found, please install with"
                " `pip install databricks-sql-connector`"
            )
        context = None
        try:
            from dbruntime.databricks_repl_context import get_context

            context = get_context()
        except ImportError:
            pass

        default_host = context.browserHostName if context else None
        if host is None:
            host = utils.get_from_env("host", "DATABRICKS_HOST", default_host)

        default_api_token = context.apiToken if context else None
        if api_token is None:
            api_token = utils.get_from_env(
                "api_token", "DATABRICKS_API_TOKEN", default_api_token
            )

        if warehouse_id is None and cluster_id is None:
            if context:
                cluster_id = context.clusterId
            else:
                raise ValueError(
                    "Need to provide either 'warehouse_id' or 'cluster_id'."
                )

        if warehouse_id and cluster_id:
            raise ValueError("Can't have both 'warehouse_id' and 'cluster_id'.")

        if warehouse_id:
            http_path = f"/sql/1.0/warehouses/{warehouse_id}"
        else:
            http_path = f"/sql/protocolv1/o/0/{cluster_id}"

        uri = (
            f"databricks://token:{api_token}@{host}?"
            f"http_path={http_path}&catalog={catalog}&schema={schema}"
        )
        return cls.from_uri(uri, engine_args=None, **kwargs)
Contributor

Out of curiosity, are `engine_args` explicitly not allowed in this case?

Contributor Author

Thanks for pointing it out. I just added `engine_args`.

Contributor Author

I just verified it: [image]
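
For readers following the diff, two pieces of logic in `from_databricks` are worth seeing in isolation: the order in which `host` and `api_token` are resolved, and how the HTTP path is chosen from `warehouse_id` or `cluster_id`. The following is a minimal standalone sketch of both, not the merged implementation. The names `resolve_setting` and `build_databricks_uri` are hypothetical, the environment lookup is an assumption modeled on how `utils.get_from_env` is called in the diff, and the REPL-context fallback for `cluster_id` is left out, so the sketch raises whenever neither ID is given.

```python
import os
from typing import Mapping, Optional


def resolve_setting(
    explicit: Optional[str],
    env_key: str,
    context_default: Optional[str],
    env: Mapping[str, str] = os.environ,
) -> str:
    """Hypothetical helper: explicit argument, then env var, then context default."""
    if explicit is not None:
        return explicit
    if env.get(env_key):
        return env[env_key]
    if context_default is not None:
        return context_default
    raise ValueError(f"No value found; pass it explicitly or set {env_key}.")


def build_databricks_uri(
    catalog: str,
    schema: str,
    host: str,
    api_token: str,
    warehouse_id: Optional[str] = None,
    cluster_id: Optional[str] = None,
) -> str:
    """Hypothetical helper mirroring the URI assembly shown in the diff."""
    if warehouse_id and cluster_id:
        raise ValueError("Can't have both 'warehouse_id' and 'cluster_id'.")
    if not warehouse_id and not cluster_id:
        # Simplification: the diff falls back to the REPL context here.
        raise ValueError("Need to provide either 'warehouse_id' or 'cluster_id'.")

    if warehouse_id:
        http_path = f"/sql/1.0/warehouses/{warehouse_id}"
    else:
        http_path = f"/sql/protocolv1/o/0/{cluster_id}"

    return (
        f"databricks://token:{api_token}@{host}?"
        f"http_path={http_path}&catalog={catalog}&schema={schema}"
    )


if __name__ == "__main__":
    # Placeholder values only; none of these identify a real workspace.
    host = resolve_setting(None, "DATABRICKS_HOST", "example.cloud.databricks.com")
    print(build_databricks_uri("main", "default", host, "<token>", warehouse_id="abc123"))
```

Keeping the URI assembly in a pure function like this makes the warehouse and cluster branches easy to check without a live Databricks connection; the real method additionally imports the connector and hands the URI to `from_uri`.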


    @property
    def dialect(self) -> str:
        """Return string representation of dialect to use."""