Skip to content

Commit 5fdd88c

Browse files
committed
update datapipeline for parallel map-style processing
1 parent 1ee0f38 commit 5fdd88c

File tree

5 files changed

+300
-45
lines changed

5 files changed

+300
-45
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,4 +248,4 @@ mlops-export/**
248248
mlruns/**
249249
.vscode/**
250250
.dccache
251-
/scripts/mongoinit.js
251+
omegaml/tests/**/*.db

omegaml/backends/sqlalchemy.py

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
1+
from logging import warning
2+
13
import logging
24
import os
5+
import pandas as pd
6+
import sqlalchemy
37
import string
8+
import threading
49
import warnings
510
from getpass import getuser
611
from hashlib import sha256
7-
from logging import warning
12+
from sqlalchemy.exc import StatementError
813
from urllib.parse import quote_plus
914

10-
import pandas as pd
11-
import sqlalchemy
1215
from omegaml.backends.basedata import BaseDataBackend
13-
from omegaml.util import ProcessLocal, KeepMissing
16+
from omegaml.util import ProcessLocal, KeepMissing, tqdm_if_interactive
1417

1518
try:
1619
import snowflake
@@ -172,8 +175,14 @@ def drop(self, name, secrets=None, **kwargs):
172175
self.__CNX_CACHE.clear()
173176
return super().drop(name, **kwargs)
174177

178+
def sign(self, values):
179+
# sign a set of values
180+
# -- this is used to ensure the values are not tampered with
181+
# -- when passed to a SQL query
182+
return sha256((str(threading.get_ident()) + str(values)).encode('utf-8')).hexdigest()
183+
175184
def get(self, name, sql=None, chunksize=None, raw=False, sqlvars=None,
176-
secrets=None, index=True, keep=None, lazy=False, table=None, *args, **kwargs):
185+
secrets=None, index=True, keep=None, lazy=False, table=None, trusted=False, *args, **kwargs):
177186
""" retrieve a stored connection or query data from connection
178187
179188
Args:
@@ -183,6 +192,8 @@ def get(self, name, sql=None, chunksize=None, raw=False, sqlvars=None,
183192
default by setting om.defaults.SQLALCHEMY_ALWAYS_CACHE = False)
184193
table (str): the name of the table, will be prefixed with the
185194
store's bucket name unless the table is specified as ':name'
195+
trusted (bool|str): if passed must be the value for store.sign(sqlvars or kwargs),
196+
otherwise a warning is issued for any remaining variables in the sql statement
186197
187198
Returns:
188199
connection
@@ -230,7 +241,7 @@ def get(self, name, sql=None, chunksize=None, raw=False, sqlvars=None,
230241
sqlvars = sqlvars or {}
231242
table = self._default_table(table or meta.kind_meta.get('table') or name)
232243
if not raw and not valid_sql(sql):
233-
sql = f'select * from {table}'
244+
sql = f'select * from :sqltable'
234245
chunksize = chunksize or meta.kind_meta.get('chunksize')
235246
_default_keep = getattr(self.data_store.defaults,
236247
'SQLALCHEMY_ALWAYS_CACHE',
@@ -242,8 +253,9 @@ def get(self, name, sql=None, chunksize=None, raw=False, sqlvars=None,
242253
else:
243254
raise ValueError('no connection string')
244255
if not raw and valid_sql(sql):
256+
sql = sql.replace(':sqltable', table)
245257
index_cols = _meta_to_indexcols(meta) if index else kwargs.get('index_col')
246-
stmt = self._sanitize_statement(sql, sqlvars)
258+
stmt = self._sanitize_statement(sql, sqlvars, trusted=trusted)
247259
kwargs = meta.kind_meta.get('kwargs') or {}
248260
kwargs.update(kwargs)
249261
if not lazy:
@@ -428,15 +440,9 @@ def copy_from_sql(self, sql, connstr, name, chunksize=10000,
428440
connection = self._get_connection(name, connstr, secrets=secrets)
429441
chunksize = chunksize or 10000 # avoid None
430442
pditer = pd.read_sql(sql, connection, chunksize=chunksize, **kwargs)
431-
try:
432-
import tqdm
433-
except:
443+
with tqdm_if_interactive().tqdm(unit='rows') as pbar:
434444
meta = self._chunked_insert(pditer, name, append=append,
435-
transform=transform)
436-
else:
437-
with tqdm.tqdm(unit='rows') as pbar:
438-
meta = self._chunked_insert(pditer, name, append=append,
439-
transform=transform, pbar=pbar)
445+
transform=transform, pbar=pbar)
440446
connection.close()
441447
return meta
442448

@@ -448,7 +454,7 @@ def _chunked_to_sql(self, df, table, connection, if_exists='append', chunksize=N
448454
def chunker(seq, size):
449455
return (seq.iloc[pos:pos + size] for pos in range(0, len(seq), size))
450456

451-
def to_sql(df, table, connection, pbar=False):
457+
def to_sql(df, table, connection, pbar=None):
452458
for i, cdf in enumerate(chunker(df, chunksize)):
453459
exists_action = if_exists if i == 0 else "append"
454460
cdf.to_sql(table, con=connection, if_exists=exists_action, **kwargs)
@@ -457,15 +463,8 @@ def to_sql(df, table, connection, pbar=False):
457463
else:
458464
print("writing chunk {}".format(i))
459465

460-
try:
461-
import tqdm
462-
if pbar is False:
463-
pbar = None
464-
except Exception as e:
466+
with tqdm_if_interactive().tqdm(total=len(df), unit='rows') as pbar:
465467
to_sql(df, table, connection, pbar=pbar)
466-
else:
467-
with tqdm.tqdm(total=len(df), unit='rows') as pbar:
468-
to_sql(df, table, connection, pbar=pbar)
469468

470469
def _chunked_insert(self, pditer, name, append=True, transform=None, pbar=None):
471470
# insert into om dataset
@@ -525,7 +524,7 @@ def _default_table(self, name):
525524
name = name[1:]
526525
return name
527526

528-
def _sanitize_statement(self, sql, sqlvars):
527+
def _sanitize_statement(self, sql, sqlvars, trusted=False):
529528
# sanitize sql:string statement in two steps
530529
# -- step 1: replace all {} variables by :notation
531530
# -- step 2: replace all remaining {} variables from sqlvars
@@ -538,7 +537,7 @@ def _sanitize_statement(self, sql, sqlvars):
538537
# replace all {...} variables with bound parameters
539538
# sql = "select * from foo where user={username}"
540539
# => "select * from foo where user=:username"
541-
placeholders = string.Formatter().parse(sql)
540+
placeholders = list(string.Formatter().parse(sql))
542541
vars = [spec[1] for spec in placeholders if spec[1]]
543542
safe_replacements = {var: f':{var}' for var in vars}
544543
sql = sql.format(**safe_replacements)
@@ -562,15 +561,15 @@ def _sanitize_statement(self, sql, sqlvars):
562561
# => "select a, b from foo where user=:username
563562
placeholders = list(string.Formatter().parse(sql))
564563
vars = [spec[1] for spec in placeholders if spec[1]]
565-
if vars:
564+
if vars and trusted != self.sign(sqlvars):
566565
warnings.warn(f'Statement >{sql}< contains unsafe variables {vars}. Use :notation or sanitize input.')
567566
sql = sql.format(**{**sqlvars, **safe_replacements})
568567
except KeyError as e:
569568
raise KeyError('{e}, specify sqlvars= to build query >{sql}<'.format(**locals()))
570569
# prepare sql statement with bound variables
571570
try:
572571
stmt = sqlalchemy.sql.text(sql)
573-
except sqlalchemy.exc.StatementError as exc:
572+
except StatementError as exc:
574573
raise
575574
return stmt
576575

@@ -657,7 +656,7 @@ def load_sql(om=None, kind=SQLAlchemyBackend.KIND):
657656
from unittest.mock import MagicMock
658657
from IPython import get_ipython
659658
import omegaml as om
660-
from sql.connection import Connection
659+
from sql.connection import Connection # noqa
661660

662661
class ConnectionShim:
663662
# this is required to trick sql magic into accepting existing connection objects

omegaml/datapipeline.py

Lines changed: 128 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,151 @@
1-
from sklearn.utils import Parallel, delayed
1+
from joblib import Parallel, delayed
2+
from sqlalchemy.exc import OperationalError
3+
4+
from omegaml.util import ProcessLocal
25

36

47
class ParallelStep(object):
58
def __init__(self, steps=None, agg=None, n_jobs=-1):
69
self.steps = steps
7-
self.agg = agg or np.mean
10+
self.agg = agg or self._default_agg
811
self.n_jobs = n_jobs
912

10-
def aggregate(self, values):
11-
return self.agg(values)
13+
def aggregate(self, values, **kwargs):
14+
return self.agg(values, **kwargs)
15+
16+
def _default_agg(self, values, **kwargs):
17+
return list(values)
1218

13-
def __call__(self, value):
14-
values = Parallel(n_jobs=self.n_jobs)(delayed(pfn)(value) for pfn in self.steps)
15-
return self.aggregate(values)
19+
def __call__(self, value, **kwargs):
20+
values = Parallel(n_jobs=self.n_jobs)(delayed(pfn)(value, **kwargs) for pfn in self.steps)
21+
return self.aggregate(values, **kwargs)
1622

1723

1824
class DataPipeline(object):
19-
def __init__(self, steps):
25+
""" A data pipeline that processes data through a series of steps
26+
27+
The pipeline is a sequence of steps, each of which is a callable that
28+
processes the data.
29+
30+
.. versionchanged:: 0.17
31+
The pipeline now automatically gets a context object that is passed
32+
along to each step as a keyword argument.
33+
"""
34+
def __init__(self, steps, *args, context=None, n_jobs=None, **kwargs):
2035
self.steps = steps
21-
self.args = None
22-
self.kwargs = None
36+
self.args = args
37+
self.kwargs = kwargs
38+
self.context = context or ProcessLocal()
39+
self.n_jobs = n_jobs or -1
2340

2441
def set_params(self, *args, **kwargs):
2542
self.args = args
2643
self.kwargs = kwargs
2744
return self
2845

29-
def process(self, value=None):
46+
def process(self, value=None, **kwargs):
3047
for stepfn in self.steps:
31-
value = stepfn(value)
48+
value = stepfn(value, context=self.context, **kwargs)
3249
return value
3350

34-
def __call__(self):
35-
return self.process()
51+
def __call__(self, value=None, **kwargs):
52+
return self.process(value, **kwargs)
53+
54+
def map(self, values, **kwargs):
55+
parallel_steps = self.steps[:-1] if len(self.steps) > 1 else self.steps
56+
finalizer = self.steps[-1] if len(self.steps) > 1 else lambda values, **kwargs: values
57+
pfn = DataPipeline(*self.args, steps=parallel_steps, context=self.context, **self.kwargs)
58+
parallel = Parallel(n_jobs=self.n_jobs)
59+
results = parallel(delayed(pfn)(**kwargs) for kwargs in values)
60+
return finalizer(results, context=self.context, **kwargs)
61+
62+
63+
class Model:
64+
dburl = 'sqlite:///:memory:'
65+
name = ''
66+
sql = 'select * from :sqltable'
67+
table = None
68+
chunksize = None
69+
delete_sql = 'delete from :sqltable'
70+
keys = None
71+
join_sql = '''
72+
with join_:sqltable as (
73+
select {{_join_vars_a}}, {{_join_vars_b}}
74+
from :sqltable as a
75+
join {{_join_sqltable}} as b
76+
on 1 = 1 {{_join_cond}}
77+
)
78+
79+
select *
80+
from join_:sqltable
81+
'''
82+
83+
def __init__(self, sql=None, om=None):
84+
self.sql = sql or self.sql
85+
self._om = om
86+
self.setup()
87+
88+
@property
89+
def om(self):
90+
import omegaml as _baseom
91+
self._om = self._om or _baseom
92+
return self._om
93+
94+
@property
95+
def store(self):
96+
return self.om.datasets
97+
98+
def setup(self):
99+
assert self.name, 'name must be set'
100+
return self.store.put(self.dburl, self.name, sql=self.sql, table=self.table)
101+
102+
def query(self, *args, sql=None, chunksize=None, trusted=False, **vars):
103+
chunksize = chunksize or self.chunksize
104+
return self.store.get(self.name, chunksize=chunksize, sqlvars=vars, sql=sql, trusted=trusted)
105+
106+
def insert(self, data):
107+
return self.store.put(data, self.name, index=False)
108+
109+
def transform(self, value, **kwargs):
110+
return value
111+
112+
def delete(self, *args, sql=None, **kwargs):
113+
sql = sql or self.delete_sql
114+
try:
115+
cursor = self.store.get(self.name, sql=sql, sqlvars=kwargs, lazy=True)
116+
except OperationalError as e:
117+
pass
118+
else:
119+
cursor.close()
120+
121+
def drop(self, force=False):
122+
return self.store.drop(self.name, force=force)
123+
124+
def join(self, other, on=None, on_left=None, on_right=None, columns_left=None, columns_right=None, **kwargs):
125+
join_sql = self.join_sql
126+
join_keys = {
127+
ka: kb for ka, kb in zip(on or on_left or [], on or on_right or [])
128+
}
129+
_join_cond = ' and '.join([f'a.{k} = b.{v}' for k, v in join_keys.items()])
130+
_join_cond = ' and ' + _join_cond if _join_cond else ''
131+
_left_cols = ','.join(columns_left or ['a.*'])
132+
_right_cols = ','.join(columns_right or ['b.*'])
133+
sqlvars = {
134+
'_join_vars_a': _left_cols,
135+
'_join_vars_b': _right_cols,
136+
'_join_sqltable': self.store.get_backend(self.name)._default_table(other.table or other.name),
137+
'_join_cond': _join_cond,
138+
}
139+
return self.query(sql=join_sql, trusted=self.store.get_backend(self.name).sign(sqlvars), **sqlvars)
36140

141+
def count(self, sql=None, raw=False, **vars):
142+
if raw and not vars:
143+
sql = sql or 'select count(*) from :sqltable'
144+
count = self.query(sql=sql).values[-1]
145+
else:
146+
count = len(self.query(sql=sql, **vars))
147+
return count
37148

149+
def __call__(self, value, **kwargs):
150+
value = self.query(**kwargs)
151+
return self.transform(value, **kwargs)

0 commit comments

Comments
 (0)