Skip to content

Commit c1fcd51

Browse files
authored
Merge pull request #308 from omegaml/improve-sql-cnx-caching
Improve sql cnx caching
2 parents c9bf188 + 3f5bd2f commit c1fcd51

File tree

3 files changed

+106
-22
lines changed

3 files changed

+106
-22
lines changed

omegaml/backends/sqlalchemy.py

Lines changed: 53 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
import warnings
2-
1+
import logging
2+
import os
33
import string
4+
import warnings
45
from getpass import getuser
6+
from hashlib import sha256
57
from logging import warning
68

7-
import logging
8-
import os
99
import pandas as pd
1010
import sqlalchemy
11-
1211
from omegaml.backends.basedata import BaseDataBackend
12+
from omegaml.util import ProcessLocal, KeepMissing
1313

1414
try:
1515
import snowflake
@@ -19,8 +19,15 @@
1919
except:
2020
pass
2121

22-
CNX_CACHE = {}
23-
ALWAYS_CACHE = False
22+
#: override by setting om.defaults.SQLALCHEMY_ALWAYS_CACHE
23+
ALWAYS_CACHE = True
24+
# -- enabled by default as this is the least-surprised option
25+
# -- consistent with sqlalchemy connection pooling defaults
26+
#: kwargs for create_engine()
27+
ENGINE_KWARGS = dict(echo=False, pool_pre_ping=True, pool_recycle=3600)
28+
# -- echo=False - do not log to stdout
29+
# -- pool_pre_ping=True - always check, re-establish connection if no longer working
30+
# -- pool_recylce=N - do not reuse connections older than N seconds
2431

2532
logger = logging.getLogger(__name__)
2633

@@ -135,26 +142,42 @@ class SQLAlchemyBackend(BaseDataBackend):
135142
136143
"""
137144
KIND = 'sqlalchemy.conx'
145+
#: sqlalchemy.Engine cache to enable pooled connections
146+
__CNX_CACHE = ProcessLocal()
147+
148+
# -- https://docs.sqlalchemy.org/en/14/core/pooling.html#module-sqlalchemy.pool
149+
# -- create_engine() must be called per-process, hence using ProcessLocal
150+
# -- meaning when using a multiprocessing.Pool or other fork()-ed processes,
151+
# the cache will be cleared in child processes, forcing the engine to be
152+
# recreated automatically in _get_connection
138153

139154
@classmethod
140155
def supports(cls, obj, name, insert=False, data_store=None, model_store=None, *args, **kwargs):
141156
valid = cls._is_valid_url(cls, obj)
142157
support_via = cls._supports_via(cls, data_store, name, obj)
143158
return valid or support_via
144159

145-
def drop(self, name, **kwargs):
146-
if name in CNX_CACHE:
147-
del CNX_CACHE[name]
160+
def drop(self, name, secrets=None, **kwargs):
161+
# ensure cache is cleared
162+
clear_cache = True if secrets is None else False
163+
try:
164+
self.get(name, secrets=secrets, raw=True, keep=False)
165+
except KeyError as e:
166+
warnings.warn(f'Connection cache was cleared, however secret {e} was missing.')
167+
clear_cache = True
168+
if clear_cache:
169+
self.__CNX_CACHE.clear()
148170
return super().drop(name, **kwargs)
149171

150172
def get(self, name, sql=None, chunksize=None, raw=False, sqlvars=None,
151-
secrets=None, index=True, keep=False, lazy=False, table=None, *args, **kwargs):
173+
secrets=None, index=True, keep=None, lazy=False, table=None, *args, **kwargs):
152174
""" retrieve a stored connection or query data from connection
153175
154176
Args:
155177
name (str): the name of the connection
156178
secrets (dict): dict to resolve variables in the connection string
157-
keep (bool): if True connection is kept open.
179+
keep (bool): if True connection is kept open, defaults to True (change
180+
default as om.defaults.SQLALCHEMY_ALWAYS_CACHE)
158181
table (str): the name of the table, will be prefixed with the
159182
store's bucket name unless the table is specified as ':name'
160183
@@ -195,8 +218,10 @@ def get(self, name, sql=None, chunksize=None, raw=False, sqlvars=None,
195218
if not raw and not valid_sql(sql):
196219
sql = f'select * from {table}'
197220
chunksize = chunksize or meta.kind_meta.get('chunksize')
198-
keep = getattr(self.data_store.defaults, 'SQLALCHEMY_ALWAYS_CACHE',
199-
ALWAYS_CACHE) or keep
221+
_default_keep = getattr(self.data_store.defaults,
222+
'SQLALCHEMY_ALWAYS_CACHE',
223+
ALWAYS_CACHE)
224+
keep = keep if keep is not None else _default_keep
200225
if connection_str:
201226
secrets = self._get_secrets(meta, secrets)
202227
connection = self._get_connection(name, connection_str, secrets=secrets, keep=keep)
@@ -352,24 +377,30 @@ def _get_connection(self, name, connection_str, secrets=None, keep=False):
352377
from sqlalchemy import create_engine
353378

354379
connection = None
380+
cache_key = None
355381
try:
382+
# SECDEV: the cache key is a secret in order to avoid privilege escalation
383+
# -- if it is not secret, user A could create the connection (=> cache)
384+
# -- user B could reuse the connection by retrieving the dataset without secrets
385+
# -- this way the user needs to have the same secrets in order to reuse the connection
356386
connection_str = connection_str.format(**(secrets or {}))
357-
engine = create_engine(connection_str, echo=False)
358-
connection = CNX_CACHE.get(name) or engine.connect()
387+
cache_key = sha256(f'{name}:{connection_str}'.encode('utf8')).hexdigest()
388+
engine = self.__CNX_CACHE.get(cache_key) or create_engine(connection_str, **ENGINE_KWARGS)
389+
connection = engine.connect()
359390
except KeyError as e:
360391
msg = ('{e}, ensure secrets are specified for connection '
361392
'>{connection_str}<'.format(**locals()))
362393
raise KeyError(msg)
363394
except Exception as e:
364395
if connection is not None:
365396
connection.close()
397+
self.__CNX_CACHE.pop(cache_key, None)
366398
raise
367399
else:
368400
if keep:
369-
CNX_CACHE[name] = connection
401+
self.__CNX_CACHE[cache_key] = engine
370402
else:
371-
if name in CNX_CACHE:
372-
del CNX_CACHE[name]
403+
self.__CNX_CACHE.pop(cache_key, None)
373404
return connection
374405

375406
def copy_from_sql(self, sql, connstr, name, chunksize=10000,
@@ -445,7 +476,7 @@ def _supports_via(self, data_store, name, obj):
445476

446477
def _get_secrets(self, meta, secrets):
447478
secrets_specs = meta.kind_meta.get('secrets')
448-
values = dict(os.environ)
479+
values = dict(os.environ) if self.data_store.defaults.OMEGA_ALLOW_ENV_CONFIG else dict()
449480
values.update(**self.data_store.defaults)
450481
if not secrets and secrets_specs:
451482
dsname = secrets_specs['dsname']
@@ -500,7 +531,7 @@ def _sanitize_statement(self, sql, sqlvars):
500531
# -- sqlvars is not used in constructing sql text
501532
v = sqlvars[k]
502533
if isinstance(v, (list, tuple)):
503-
bind_vars = { f'{k}_{i}': lv for i, lv in enumerate(v)}
534+
bind_vars = {f'{k}_{i}': lv for i, lv in enumerate(v)}
504535
placeholders = ','.join(f':{bk}' for bk in bind_vars)
505536
sql = sql.replace(f':{k}', f'({placeholders})')
506537
sqlvars.update(bind_vars)
@@ -565,7 +596,7 @@ def _format_dict(d, replace=None, **kwargs):
565596
if replace:
566597
del d[k]
567598
k = k.replace(*replace) if replace else k
568-
d[k] = v.format(**kwargs) if isinstance(v, str) else v
599+
d[k] = v.format_map(KeepMissing(kwargs)) if isinstance(v, str) else v
569600
return d
570601

571602

omegaml/tests/core/test_sqlalchemy.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,27 @@ def test_put_connection_with_secrets(self):
5151
conn = om.datasets.get('testsqlite', raw=True)
5252
self.assertIsInstance(conn, Connection)
5353

54+
def test_connection_cache(self):
55+
""" test connection caching
56+
"""
57+
from omegaml.backends import sqlalchemy
58+
om = self.om
59+
cnx = 'sqlite:///{user}.db'
60+
om.datasets.put(cnx, 'testsqlite', kind=SQLAlchemyBackend.KIND)
61+
conn = om.datasets.get('testsqlite', raw=True, secrets=dict(user='user'), keep=True)
62+
conn_ = om.datasets.get('testsqlite', raw=True, secrets=dict(user='user'), keep=True)
63+
self.assertEqual(conn.engine, conn_.engine)
64+
# drop should clear cache
65+
om.datasets.drop('testsqlite', secrets=dict(user='user'))
66+
conn_ = om.datasets.get('testsqlite', raw=True, secrets=dict(user='user'), keep=True)
67+
self.assertIsNone(conn_)
68+
self.assertTrue(len(sqlalchemy.SQLAlchemyBackend._SQLAlchemyBackend__CNX_CACHE) == 0)
69+
# even if we drop without secrets, cache is cleared
70+
om.datasets.put(cnx, 'testsqlite', kind=SQLAlchemyBackend.KIND)
71+
om.datasets.get('testsqlite', raw=True, secrets=dict(user='user'), keep=True)
72+
om.datasets.drop('testsqlite')
73+
self.assertTrue(len(sqlalchemy.SQLAlchemyBackend._SQLAlchemyBackend__CNX_CACHE) == 0)
74+
5475
def test_put_connection_with_sql(self):
5576
"""
5677
store generic sqlalchemy connection with sql, same principle as a view

omegaml/util.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,3 +1088,35 @@ def __str__(self):
10881088
# on posix systems, this is a noop
10891089
path = super().__str__()
10901090
return path if os.name != 'nt' else path.replace('\\', '/')
1091+
1092+
1093+
class ProcessLocal(dict):
1094+
def __init__(self, *args, **kwargs):
1095+
self._pid = os.getpid()
1096+
super().__init__(*args, **kwargs)
1097+
1098+
def _check_pid(self):
1099+
if self._pid != os.getpid():
1100+
self.clear()
1101+
self._pid = os.getpid()
1102+
1103+
def __getitem__(self, k):
1104+
self._check_pid()
1105+
return super().__getitem__(k)
1106+
1107+
def keys(self):
1108+
self._check_pid()
1109+
return super().keys()
1110+
1111+
def __contains__(self, item):
1112+
self._check_pid()
1113+
return super().__contains__(item)
1114+
1115+
1116+
class KeepMissing(dict):
1117+
# a missing '{key}' is replaced by '{key}'
1118+
# in order to avoid raising KeyError
1119+
# see str.format_map
1120+
def __missing__(self, key):
1121+
return '{' + key + '}'
1122+

0 commit comments

Comments
 (0)