Skip to content

Commit

Permalink
Refactor postgresql provider to allow exposing cached engine and tabl…
Browse files Browse the repository at this point in the history
…e models to other providers (#1643)

* refactored in order to allow other postgresql-based providers to make use of cached engine and table models

* Replaced lru_cache with cache
  • Loading branch information
ricardogsilva committed May 15, 2024
1 parent dcabbb9 commit 51976ac
Showing 1 changed file with 108 additions and 110 deletions.
218 changes: 108 additions & 110 deletions pygeoapi/provider/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
# psql -U postgres -h 127.0.0.1 -p 5432 test

import logging
import functools

from copy import deepcopy
from geoalchemy2 import Geometry # noqa - this isn't used explicitly but is needed to process Geometry columns
Expand All @@ -69,8 +70,6 @@
from pygeoapi.util import get_transform_from_crs


_ENGINE_STORE = {}
_TABLE_MODEL_STORE = {}
LOGGER = logging.getLogger(__name__)


Expand Down Expand Up @@ -107,7 +106,21 @@ def __init__(self, provider_def):
if provider_def.get('options'):
options = provider_def['options']
self._store_db_parameters(provider_def['data'], options)
self._engine, self.table_model = self._get_engine_and_table_model()
self._engine = get_engine(
self.db_host,
self.db_port,
self.db_name,
self.db_user,
self._db_password,
**(self.db_options or {})
)
self.table_model = get_table_model(
self.table,
self.id_field,
self.db_search_path,
self._engine
)

LOGGER.debug(f'DB connection: {repr(self._engine.url)}')
self.fields = self.get_fields()

Expand Down Expand Up @@ -272,116 +285,13 @@ def _store_db_parameters(self, parameters, options):
self.db_host = parameters.get('host')
self.db_port = parameters.get('port', 5432)
self.db_name = parameters.get('dbname')
self.db_search_path = parameters.get('search_path', ['public'])
# db_search_path gets converted to a tuple here in order to ensure it
# is hashable - which allows us to use functools.cache() when
# reflecting the table definition from the DB
self.db_search_path = tuple(parameters.get('search_path', ['public']))
self._db_password = parameters.get('password')
self.db_options = options

def _get_engine_and_table_model(self):
    """
    Create a SQL Alchemy engine for the database and reflect the table
    model. Use existing versions from stores if available to allow reuse
    of Engine connection pool and save expensive table reflection.

    :returns: tuple of (engine, table_model)
    """
    # One long-lived engine is used per database URL:
    # https://docs.sqlalchemy.org/en/14/core/connections.html#basic-usage
    engine_store_key = (self.db_user, self.db_host, self.db_port,
                        self.db_name)
    try:
        engine = _ENGINE_STORE[engine_store_key]
    except KeyError:
        conn_str = URL.create(
            'postgresql+psycopg2',
            username=self.db_user,
            password=self._db_password,
            host=self.db_host,
            port=self.db_port,
            database=self.db_name
        )
        conn_args = {
            'client_encoding': 'utf8',
            'application_name': 'pygeoapi'
        }
        if self.db_options:
            # Provider-level options override the defaults above
            conn_args.update(self.db_options)
        engine = create_engine(
            conn_str,
            connect_args=conn_args,
            pool_pre_ping=True)
        _ENGINE_STORE[engine_store_key] = engine

    # Reuse table model if one exists. The reflected model depends on the
    # schema search path and on the id_field used as primary key, so both
    # are part of the key - otherwise same-named tables in different
    # schemas (or with different id columns) would share one cached model.
    table_model_store_key = (
        self.db_host,
        self.db_port,
        self.db_name,
        tuple(self.db_search_path),
        self.table,
        self.id_field,
    )
    try:
        table_model = _TABLE_MODEL_STORE[table_model_store_key]
    except KeyError:
        table_model = self._reflect_table_model(engine)
        _TABLE_MODEL_STORE[table_model_store_key] = table_model

    return engine, table_model

def _reflect_table_model(self, engine):
    """
    Reflect database metadata to create a SQL Alchemy model corresponding
    to target table. This requires a database query and is expensive to
    perform.

    :param engine: SQL Alchemy engine used to reflect the table

    :returns: automapped model class for ``self.table``

    :raises ProviderConnectionError: if reflection raises OperationalError
    :raises ProviderQueryError: if the table is not found in the schema or
                                ``self.id_field`` is not one of its columns
    """
    metadata = MetaData()

    # Look for table in the first schema in the search path. Resolved
    # before the try block: the handlers below do not cover a failure here.
    schema = self.db_search_path[0]
    try:
        metadata.reflect(
            bind=engine, schema=schema, only=[self.table], views=True)
    except OperationalError as err:
        msg = (f"Could not connect to {repr(engine.url)} "
               "(password hidden).")
        raise ProviderConnectionError(msg) from err
    except InvalidRequestError as err:
        msg = (f"Table '{self.table}' not found in schema '{schema}' "
               f"on {repr(engine.url)}.")
        raise ProviderQueryError(msg) from err

    # Create SQLAlchemy model from reflected table
    # It is necessary to add the primary key constraint because SQLAlchemy
    # requires it to reflect the table, but a view in a PostgreSQL database
    # does not have a primary key defined.
    sqlalchemy_table_def = metadata.tables[f'{schema}.{self.table}']
    try:
        sqlalchemy_table_def.append_constraint(
            PrimaryKeyConstraint(self.id_field)
        )
    except KeyError as err:
        msg = (f"No such id_field column ({self.id_field}) on "
               f"{schema}.{self.table}.")
        raise ProviderQueryError(msg) from err

    Base = automap_base(metadata=metadata)
    Base.prepare(
        name_for_scalar_relationship=self._name_for_scalar_relationship,
    )
    TableModel = getattr(Base.classes, self.table)

    return TableModel

@staticmethod
def _name_for_scalar_relationship(
base, local_cls, referred_cls, constraint,
):
"""Function used when automapping classes and relationships from
database schema and fixes potential naming conflicts.
"""
name = referred_cls.__name__.lower()
local_table = local_cls.__table__
if name in local_table.columns:
newname = name + '_'
LOGGER.debug(
f'Already detected column name {name!r} in table '
f'{local_table!r}. Using {newname!r} for relationship name.'
)
return newname
return name

def _sqlalchemy_to_feature(self, item, crs_transform_out=None):
feature = {
'type': 'Feature'
Expand Down Expand Up @@ -516,3 +426,91 @@ def _get_crs_transform(self, crs_transform_spec=None):
else:
crs_transform = None
return crs_transform


@functools.cache
def get_engine(
        host: str,
        port: str,
        database: str,
        user: str,
        password: str,
        **connection_options
):
    """Create SQL Alchemy engine.

    The function is wrapped in ``functools.cache``: calling it again with
    the same arguments returns the engine object created the first time.
    As a consequence every argument, including each value passed through
    ``connection_options``, has to be hashable. The cache key is built
    from the arguments as passed, before ``port`` is converted with
    ``int()``.

    :param host: database host
    :param port: database port; converted with ``int()`` before use
    :param database: database name
    :param user: database user
    :param password: database password
    :param connection_options: extra driver connection arguments; they are
                               merged over the defaults ``client_encoding``
                               and ``application_name``

    :returns: SQL Alchemy engine created with ``pool_pre_ping=True``
    """
    # Defaults first, caller-supplied options last so they take precedence
    connect_args = dict(
        client_encoding='utf8',
        application_name='pygeoapi',
    )
    connect_args.update(connection_options)

    url = URL.create(
        'postgresql+psycopg2',
        username=user,
        password=password,
        host=host,
        port=int(port),
        database=database
    )
    return create_engine(url, connect_args=connect_args, pool_pre_ping=True)


@functools.cache
def get_table_model(
        table_name: str,
        id_field: str,
        db_search_path: tuple[str, ...],
        engine,
):
    """Reflect table.

    Reflects ``table_name`` from the database and returns an automapped
    SQL Alchemy model class for it. The function is wrapped in
    ``functools.cache``, so all arguments must be hashable (hence
    ``db_search_path`` being a tuple) and a successful reflection is reused
    for identical arguments.

    :param table_name: name of the table or view to reflect
    :param id_field: column to declare as primary key on the reflected table
    :param db_search_path: schema search path; only the first entry is used
    :param engine: SQL Alchemy engine used to reflect the table

    :returns: automapped model class for ``table_name``

    :raises ProviderConnectionError: if reflection raises OperationalError
    :raises ProviderQueryError: if the table is not found in the schema or
                                ``id_field`` is not one of its columns
    """
    metadata = MetaData()

    # Look for table in the first schema in the search path
    schema = db_search_path[0]
    try:
        metadata.reflect(
            bind=engine, schema=schema, only=[table_name], views=True)
    except OperationalError as err:
        raise ProviderConnectionError(
            f"Could not connect to {repr(engine.url)} (password hidden)."
        ) from err
    except InvalidRequestError as err:
        raise ProviderQueryError(
            f"Table '{table_name}' not found in schema '{schema}' "
            f"on {repr(engine.url)}."
        ) from err

    # Create SQLAlchemy model from reflected table
    # It is necessary to add the primary key constraint because SQLAlchemy
    # requires it to reflect the table, but a view in a PostgreSQL database
    # does not have a primary key defined.
    sqlalchemy_table_def = metadata.tables[f'{schema}.{table_name}']
    try:
        sqlalchemy_table_def.append_constraint(PrimaryKeyConstraint(id_field))
    except KeyError as err:
        raise ProviderQueryError(
            f"No such id_field column ({id_field}) on {schema}.{table_name}."
        ) from err

    _Base = automap_base(metadata=metadata)
    _Base.prepare(
        name_for_scalar_relationship=_name_for_scalar_relationship,
    )
    return getattr(_Base.classes, table_name)


def _name_for_scalar_relationship(base, local_cls, referred_cls, constraint):
    """Function used when automapping classes and relationships from
    database schema and fixes potential naming conflicts.

    :param base: automap base class (unused)
    :param local_cls: mapped class that will hold the relationship
    :param referred_cls: mapped class the relationship points to
    :param constraint: foreign key constraint being inspected (unused)

    :returns: lower-cased name of ``referred_cls``, with as many trailing
              underscores appended as needed to avoid clashing with a
              column name of the local table
    """
    name = referred_cls.__name__.lower()
    local_table = local_cls.__table__
    if name not in local_table.columns:
        return name

    # Keep appending underscores: a single one is not enough when the
    # table also has a column literally named '<name>_'.
    newname = name + '_'
    while newname in local_table.columns:
        newname += '_'
    LOGGER.debug(
        f'Already detected column name {name!r} in table '
        f'{local_table!r}. Using {newname!r} for relationship name.'
    )
    return newname

0 comments on commit 51976ac

Please sign in to comment.