Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions docs/examples/litestar_extension_migrations_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Example demonstrating how to use Litestar extension migrations with SQLSpec.

This example shows how to configure SQLSpec to include Litestar's session table
migrations, which will create dialect-specific tables when you run migrations.
"""

from pathlib import Path

from litestar import Litestar

from sqlspec.adapters.sqlite.config import SqliteConfig
from sqlspec.extensions.litestar.plugin import SQLSpec
from sqlspec.extensions.litestar.store import SQLSpecSessionStore
from sqlspec.migrations.commands import MigrationCommands

# Configure database with extension migrations enabled
db_config = SqliteConfig(
pool_config={"database": "app.db"},
migration_config={
"script_location": "migrations",
"version_table_name": "ddl_migrations",
# Enable Litestar extension migrations
"include_extensions": ["litestar"],
},
)

# Create SQLSpec plugin with session store
sqlspec_plugin = SQLSpec(db_config)

# Configure session store to use the database
session_store = SQLSpecSessionStore(
config=db_config,
table_name="litestar_sessions", # Matches migration table name
)

# Create Litestar app with SQLSpec and sessions
app = Litestar(plugins=[sqlspec_plugin], stores={"sessions": session_store})


def run_migrations() -> None:
"""Run database migrations including extension migrations.

This will:
1. Create your project's migrations (from migrations/ directory)
2. Create Litestar extension migrations (session table with dialect-specific types)
"""
commands = MigrationCommands(db_config)

# Initialize migrations directory if it doesn't exist
migrations_dir = Path("migrations")
if not migrations_dir.exists():
commands.init("migrations")

# Run all migrations including extension migrations
# The session table will be created with:
# - JSONB for PostgreSQL
# - JSON for MySQL/MariaDB
# - TEXT for SQLite
commands.upgrade()

# Check current version
current = commands.current(verbose=True)
print(f"Current migration version: {current}")


if __name__ == "__main__":
# Run migrations before starting the app
run_migrations()

# Start the application
import uvicorn

uvicorn.run(app, host="0.0.0.0", port=8000)
166 changes: 166 additions & 0 deletions docs/examples/litestar_session_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
"""Example showing how to use SQLSpec session backend with Litestar."""

from typing import Any

from litestar import Litestar, get, post
from litestar.config.session import SessionConfig
from litestar.connection import Request
from litestar.datastructures import State

from sqlspec.adapters.sqlite.config import SqliteConfig
from sqlspec.extensions.litestar import SQLSpec, SQLSpecSessionBackend, SQLSpecSessionConfig

# Configure SQLSpec with SQLite database
# Include Litestar extension migrations to automatically create session tables
sqlite_config = SqliteConfig(
pool_config={"database": "sessions.db"},
migration_config={
"script_location": "migrations",
"version_table_name": "sqlspec_migrations",
"include_extensions": ["litestar"], # Include Litestar session table migrations
},
)

# Create SQLSpec plugin
sqlspec_plugin = SQLSpec(sqlite_config)

# Create session backend using SQLSpec
# Note: The session table will be created automatically when you run migrations
# Example: sqlspec migrations upgrade --head
session_backend = SQLSpecSessionBackend(
config=SQLSpecSessionConfig(
table_name="litestar_sessions",
session_id_column="session_id",
data_column="data",
expires_at_column="expires_at",
created_at_column="created_at",
)
)

# Configure session middleware
session_config = SessionConfig(
backend=session_backend,
cookie_https_only=False, # Set to True in production
cookie_secure=False, # Set to True in production with HTTPS
cookie_domain="localhost",
cookie_path="/",
cookie_max_age=3600,
cookie_same_site="lax",
cookie_http_only=True,
session_cookie_name="sqlspec_session",
)


@get("/")
async def index() -> dict[str, str]:
"""Homepage route."""
return {"message": "SQLSpec Session Example"}


@get("/login")
async def login_form() -> str:
"""Simple login form."""
return """
<html>
<body>
<h2>Login</h2>
<form method="post" action="/login">
<input type="text" name="username" placeholder="Username" required>
<input type="password" name="password" placeholder="Password" required>
<button type="submit">Login</button>
</form>
</body>
</html>
"""


@post("/login")
async def login(data: dict[str, str], request: "Request[Any, Any, Any]") -> dict[str, str]:
"""Handle login and create session."""
username = data.get("username")
password = data.get("password")

# Simple authentication (use proper auth in production)
if username == "admin" and password == "secret":
# Store user data in session
request.set_session(
{"user_id": 1, "username": username, "login_time": "2024-01-01T12:00:00Z", "roles": ["admin", "user"]}
)
return {"message": f"Welcome, {username}!"}

return {"error": "Invalid credentials"}


@get("/profile")
async def profile(request: "Request[Any, Any, Any]") -> dict[str, str]:
"""User profile route - requires session."""
session_data = request.session

if not session_data or "user_id" not in session_data:
return {"error": "Not logged in"}

return {
"user_id": session_data["user_id"],
"username": session_data["username"],
"login_time": session_data["login_time"],
"roles": session_data["roles"],
}


@post("/logout")
async def logout(request: "Request[Any, Any, Any]") -> dict[str, str]:
"""Logout and clear session."""
request.clear_session()
return {"message": "Logged out successfully"}


@get("/admin/sessions")
async def admin_sessions(request: "Request[Any, Any, Any]", state: State) -> dict[str, any]:
"""Admin route to view all active sessions."""
session_data = request.session

if not session_data or "admin" not in session_data.get("roles", []):
return {"error": "Admin access required"}

# Get session backend from state
backend = session_backend
session_ids = await backend.get_all_session_ids()

return {
"active_sessions": len(session_ids),
"session_ids": session_ids[:10], # Limit to first 10 for display
}


@post("/admin/cleanup")
async def cleanup_sessions(request: "Request[Any, Any, Any]", state: State) -> dict[str, str]:
"""Admin route to clean up expired sessions."""
session_data = request.session

if not session_data or "admin" not in session_data.get("roles", []):
return {"error": "Admin access required"}

# Clean up expired sessions
backend = session_backend
await backend.delete_expired_sessions()

return {"message": "Expired sessions cleaned up"}


# Create Litestar application
app = Litestar(
route_handlers=[index, login_form, login, profile, logout, admin_sessions, cleanup_sessions],
plugins=[sqlspec_plugin],
session_config=session_config,
debug=True,
)


if __name__ == "__main__":
import uvicorn

print("Starting SQLSpec Session Example...")
print("Visit http://localhost:8000 to view the application")
print("Login with username 'admin' and password 'secret'")

uvicorn.run(app, host="0.0.0.0", port=8000)
4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ test = [
"anyio",
"coverage>=7.6.1",
"pytest>=8.0.0",
"pytest-asyncio>=0.23.8",
"pytest-cov>=5.0.0",
"pytest-databases[postgres,oracle,mysql,bigquery,spanner,minio]>=0.12.2",
"pytest-mock>=3.14.0",
Expand Down Expand Up @@ -259,8 +258,7 @@ exclude_lines = [

[tool.pytest.ini_options]
addopts = ["-q", "-ra"]
asyncio_default_fixture_loop_scope = "function"
asyncio_mode = "auto"
anyio_mode = "auto"
filterwarnings = [
"ignore::DeprecationWarning:pkg_resources.*",
"ignore:pkg_resources is deprecated as an API:DeprecationWarning",
Expand Down
18 changes: 8 additions & 10 deletions sqlspec/adapters/adbc/data_dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from sqlspec.utils.logging import get_logger

if TYPE_CHECKING:
from collections.abc import Callable

from sqlspec.adapters.adbc.driver import AdbcDriver

logger = get_logger("adapters.adbc.data_dictionary")
Expand Down Expand Up @@ -52,35 +50,35 @@ def get_version(self, driver: SyncDriverAdapterBase) -> "VersionInfo | None":

try:
if dialect == "postgres":
version_str = adbc_driver.select_value("SELECT version()")
version_str = cast("str", adbc_driver.select_value("SELECT version()"))
if version_str:
match = POSTGRES_VERSION_PATTERN.search(str(version_str))
match = POSTGRES_VERSION_PATTERN.search(version_str)
if match:
major = int(match.group(1))
minor = int(match.group(2))
patch = int(match.group(3)) if match.group(3) else 0
return VersionInfo(major, minor, patch)

elif dialect == "sqlite":
version_str = adbc_driver.select_value("SELECT sqlite_version()")
version_str = cast("str", adbc_driver.select_value("SELECT sqlite_version()"))
if version_str:
match = SQLITE_VERSION_PATTERN.match(str(version_str))
match = SQLITE_VERSION_PATTERN.match(version_str)
if match:
major, minor, patch = map(int, match.groups())
return VersionInfo(major, minor, patch)

elif dialect == "duckdb":
version_str = adbc_driver.select_value("SELECT version()")
version_str = cast("str", adbc_driver.select_value("SELECT version()"))
if version_str:
match = DUCKDB_VERSION_PATTERN.search(str(version_str))
match = DUCKDB_VERSION_PATTERN.search(version_str)
if match:
major, minor, patch = map(int, match.groups())
return VersionInfo(major, minor, patch)

elif dialect == "mysql":
version_str = adbc_driver.select_value("SELECT VERSION()")
version_str = cast("str", adbc_driver.select_value("SELECT VERSION()"))
if version_str:
match = MYSQL_VERSION_PATTERN.search(str(version_str))
match = MYSQL_VERSION_PATTERN.search(version_str)
if match:
major, minor, patch = map(int, match.groups())
return VersionInfo(major, minor, patch)
Expand Down
Loading
Loading