|
1 |
| -import warnings |
2 |
| - |
| 1 | +import logging |
| 2 | +import os |
3 | 3 | import string
|
| 4 | +import warnings |
4 | 5 | from getpass import getuser
|
| 6 | +from hashlib import sha256 |
5 | 7 | from logging import warning
|
6 | 8 |
|
7 |
| -import logging |
8 |
| -import os |
9 | 9 | import pandas as pd
|
10 | 10 | import sqlalchemy
|
11 |
| - |
12 | 11 | from omegaml.backends.basedata import BaseDataBackend
|
| 12 | +from omegaml.util import ProcessLocal, KeepMissing |
13 | 13 |
|
14 | 14 | try:
|
15 | 15 | import snowflake
|
|
19 | 19 | except:
|
20 | 20 | pass
|
21 | 21 |
|
22 |
| -CNX_CACHE = {} |
23 |
| -ALWAYS_CACHE = False |
| 22 | +#: override by setting om.defaults.SQLALCHEMY_ALWAYS_CACHE |
| 23 | +ALWAYS_CACHE = True |
| 24 | +# -- enabled by default as this is the least-surprising option |
| 25 | +# -- consistent with sqlalchemy connection pooling defaults |
| 26 | +#: kwargs for create_engine() |
| 27 | +ENGINE_KWARGS = dict(echo=False, pool_pre_ping=True, pool_recycle=3600) |
| 28 | +# -- echo=False - do not log to stdout |
| 29 | +# -- pool_pre_ping=True - always check, re-establish connection if no longer working |
| 30 | +# -- pool_recycle=N - do not reuse connections older than N seconds |
24 | 31 |
|
25 | 32 | logger = logging.getLogger(__name__)
|
26 | 33 |
|
@@ -135,26 +142,42 @@ class SQLAlchemyBackend(BaseDataBackend):
|
135 | 142 |
|
136 | 143 | """
|
137 | 144 | KIND = 'sqlalchemy.conx'
|
| 145 | + #: sqlalchemy.Engine cache to enable pooled connections |
| 146 | + __CNX_CACHE = ProcessLocal() |
| 147 | + |
| 148 | + # -- https://docs.sqlalchemy.org/en/14/core/pooling.html#module-sqlalchemy.pool |
| 149 | + # -- create_engine() must be called per-process, hence using ProcessLocal |
| 150 | + # -- meaning when using a multiprocessing.Pool or other fork()-ed processes, |
| 151 | + # the cache will be cleared in child processes, forcing the engine to be |
| 152 | + # recreated automatically in _get_connection |
138 | 153 |
|
139 | 154 | @classmethod
|
140 | 155 | def supports(cls, obj, name, insert=False, data_store=None, model_store=None, *args, **kwargs):
|
141 | 156 | valid = cls._is_valid_url(cls, obj)
|
142 | 157 | support_via = cls._supports_via(cls, data_store, name, obj)
|
143 | 158 | return valid or support_via
|
144 | 159 |
|
145 |
| - def drop(self, name, **kwargs): |
146 |
| - if name in CNX_CACHE: |
147 |
| - del CNX_CACHE[name] |
| 160 | + def drop(self, name, secrets=None, **kwargs): |
| 161 | + # ensure cache is cleared |
| 162 | + clear_cache = True if secrets is None else False |
| 163 | + try: |
| 164 | + self.get(name, secrets=secrets, raw=True, keep=False) |
| 165 | + except KeyError as e: |
| 166 | + warnings.warn(f'Connection cache was cleared, however secret {e} was missing.') |
| 167 | + clear_cache = True |
| 168 | + if clear_cache: |
| 169 | + self.__CNX_CACHE.clear() |
148 | 170 | return super().drop(name, **kwargs)
|
149 | 171 |
|
150 | 172 | def get(self, name, sql=None, chunksize=None, raw=False, sqlvars=None,
|
151 |
| - secrets=None, index=True, keep=False, lazy=False, table=None, *args, **kwargs): |
| 173 | + secrets=None, index=True, keep=None, lazy=False, table=None, *args, **kwargs): |
152 | 174 | """ retrieve a stored connection or query data from connection
|
153 | 175 |
|
154 | 176 | Args:
|
155 | 177 | name (str): the name of the connection
|
156 | 178 | secrets (dict): dict to resolve variables in the connection string
|
157 |
| - keep (bool): if True connection is kept open. |
| 179 | + keep (bool): if True connection is kept open, defaults to True (change |
| 180 | + default as om.defaults.SQLALCHEMY_ALWAYS_CACHE) |
158 | 181 | table (str): the name of the table, will be prefixed with the
|
159 | 182 | store's bucket name unless the table is specified as ':name'
|
160 | 183 |
|
@@ -195,8 +218,10 @@ def get(self, name, sql=None, chunksize=None, raw=False, sqlvars=None,
|
195 | 218 | if not raw and not valid_sql(sql):
|
196 | 219 | sql = f'select * from {table}'
|
197 | 220 | chunksize = chunksize or meta.kind_meta.get('chunksize')
|
198 |
| - keep = getattr(self.data_store.defaults, 'SQLALCHEMY_ALWAYS_CACHE', |
199 |
| - ALWAYS_CACHE) or keep |
| 221 | + _default_keep = getattr(self.data_store.defaults, |
| 222 | + 'SQLALCHEMY_ALWAYS_CACHE', |
| 223 | + ALWAYS_CACHE) |
| 224 | + keep = keep if keep is not None else _default_keep |
200 | 225 | if connection_str:
|
201 | 226 | secrets = self._get_secrets(meta, secrets)
|
202 | 227 | connection = self._get_connection(name, connection_str, secrets=secrets, keep=keep)
|
@@ -352,24 +377,30 @@ def _get_connection(self, name, connection_str, secrets=None, keep=False):
|
352 | 377 | from sqlalchemy import create_engine
|
353 | 378 |
|
354 | 379 | connection = None
|
| 380 | + cache_key = None |
355 | 381 | try:
|
| 382 | + # SECDEV: the cache key is a secret in order to avoid privilege escalation |
| 383 | + # -- if it is not secret, user A could create the connection (=> cache) |
| 384 | + # -- user B could reuse the connection by retrieving the dataset without secrets |
| 385 | + # -- this way the user needs to have the same secrets in order to reuse the connection |
356 | 386 | connection_str = connection_str.format(**(secrets or {}))
|
357 |
| - engine = create_engine(connection_str, echo=False) |
358 |
| - connection = CNX_CACHE.get(name) or engine.connect() |
| 387 | + cache_key = sha256(f'{name}:{connection_str}'.encode('utf8')).hexdigest() |
| 388 | + engine = self.__CNX_CACHE.get(cache_key) or create_engine(connection_str, **ENGINE_KWARGS) |
| 389 | + connection = engine.connect() |
359 | 390 | except KeyError as e:
|
360 | 391 | msg = ('{e}, ensure secrets are specified for connection '
|
361 | 392 | '>{connection_str}<'.format(**locals()))
|
362 | 393 | raise KeyError(msg)
|
363 | 394 | except Exception as e:
|
364 | 395 | if connection is not None:
|
365 | 396 | connection.close()
|
| 397 | + self.__CNX_CACHE.pop(cache_key, None) |
366 | 398 | raise
|
367 | 399 | else:
|
368 | 400 | if keep:
|
369 |
| - CNX_CACHE[name] = connection |
| 401 | + self.__CNX_CACHE[cache_key] = engine |
370 | 402 | else:
|
371 |
| - if name in CNX_CACHE: |
372 |
| - del CNX_CACHE[name] |
| 403 | + self.__CNX_CACHE.pop(cache_key, None) |
373 | 404 | return connection
|
374 | 405 |
|
375 | 406 | def copy_from_sql(self, sql, connstr, name, chunksize=10000,
|
@@ -445,7 +476,7 @@ def _supports_via(self, data_store, name, obj):
|
445 | 476 |
|
446 | 477 | def _get_secrets(self, meta, secrets):
|
447 | 478 | secrets_specs = meta.kind_meta.get('secrets')
|
448 |
| - values = dict(os.environ) |
| 479 | + values = dict(os.environ) if self.data_store.defaults.OMEGA_ALLOW_ENV_CONFIG else dict() |
449 | 480 | values.update(**self.data_store.defaults)
|
450 | 481 | if not secrets and secrets_specs:
|
451 | 482 | dsname = secrets_specs['dsname']
|
@@ -500,7 +531,7 @@ def _sanitize_statement(self, sql, sqlvars):
|
500 | 531 | # -- sqlvars is not used in constructing sql text
|
501 | 532 | v = sqlvars[k]
|
502 | 533 | if isinstance(v, (list, tuple)):
|
503 |
| - bind_vars = { f'{k}_{i}': lv for i, lv in enumerate(v)} |
| 534 | + bind_vars = {f'{k}_{i}': lv for i, lv in enumerate(v)} |
504 | 535 | placeholders = ','.join(f':{bk}' for bk in bind_vars)
|
505 | 536 | sql = sql.replace(f':{k}', f'({placeholders})')
|
506 | 537 | sqlvars.update(bind_vars)
|
@@ -565,7 +596,7 @@ def _format_dict(d, replace=None, **kwargs):
|
565 | 596 | if replace:
|
566 | 597 | del d[k]
|
567 | 598 | k = k.replace(*replace) if replace else k
|
568 |
| - d[k] = v.format(**kwargs) if isinstance(v, str) else v |
| 599 | + d[k] = v.format_map(KeepMissing(kwargs)) if isinstance(v, str) else v |
569 | 600 | return d
|
570 | 601 |
|
571 | 602 |
|
|
0 commit comments