1
+ from logging import warning
2
+
1
3
import logging
2
4
import os
5
+ import pandas as pd
6
+ import sqlalchemy
3
7
import string
8
+ import threading
4
9
import warnings
5
10
from getpass import getuser
6
11
from hashlib import sha256
7
- from logging import warning
12
+ from sqlalchemy . exc import StatementError
8
13
from urllib .parse import quote_plus
9
14
10
- import pandas as pd
11
- import sqlalchemy
12
15
from omegaml .backends .basedata import BaseDataBackend
13
- from omegaml .util import ProcessLocal , KeepMissing
16
+ from omegaml .util import ProcessLocal , KeepMissing , tqdm_if_interactive
14
17
15
18
try :
16
19
import snowflake
@@ -172,8 +175,14 @@ def drop(self, name, secrets=None, **kwargs):
172
175
self .__CNX_CACHE .clear ()
173
176
return super ().drop (name , ** kwargs )
174
177
178
def sign(self, values):
    """Return a SHA-256 hex digest that signs ``values`` for the calling thread.

    The digest is computed over the current thread identifier followed by
    ``str(values)``, so recomputing it for the same values only yields the
    same result when done from the same thread. It is used to check that
    values passed along with a SQL query have not been tampered with.

    Args:
        values: any object; its ``str()`` representation is signed

    Returns:
        str: the hexadecimal SHA-256 digest
    """
    thread_id = str(threading.get_ident())
    payload = (thread_id + str(values)).encode('utf-8')
    digest = sha256(payload)
    return digest.hexdigest()
183
+
175
184
def get (self , name , sql = None , chunksize = None , raw = False , sqlvars = None ,
176
- secrets = None , index = True , keep = None , lazy = False , table = None , * args , ** kwargs ):
185
+ secrets = None , index = True , keep = None , lazy = False , table = None , trusted = False , * args , ** kwargs ):
177
186
""" retrieve a stored connection or query data from connection
178
187
179
188
Args:
@@ -183,6 +192,8 @@ def get(self, name, sql=None, chunksize=None, raw=False, sqlvars=None,
183
192
default by setting om.defaults.SQLALCHEMY_ALWAYS_CACHE = False)
184
193
table (str): the name of the table, will be prefixed with the
185
194
store's bucket name unless the table is specified as ':name'
195
+ trusted (bool|str): if passed must be the value for store.sign(sqlvars or kwargs),
196
+ otherwise a warning is issued for any remaining variables in the sql statement
186
197
187
198
Returns:
188
199
connection
@@ -230,7 +241,7 @@ def get(self, name, sql=None, chunksize=None, raw=False, sqlvars=None,
230
241
sqlvars = sqlvars or {}
231
242
table = self ._default_table (table or meta .kind_meta .get ('table' ) or name )
232
243
if not raw and not valid_sql (sql ):
233
- sql = f'select * from { table } '
244
+ sql = f'select * from :sqltable '
234
245
chunksize = chunksize or meta .kind_meta .get ('chunksize' )
235
246
_default_keep = getattr (self .data_store .defaults ,
236
247
'SQLALCHEMY_ALWAYS_CACHE' ,
@@ -242,8 +253,9 @@ def get(self, name, sql=None, chunksize=None, raw=False, sqlvars=None,
242
253
else :
243
254
raise ValueError ('no connection string' )
244
255
if not raw and valid_sql (sql ):
256
+ sql = sql .replace (':sqltable' , table )
245
257
index_cols = _meta_to_indexcols (meta ) if index else kwargs .get ('index_col' )
246
- stmt = self ._sanitize_statement (sql , sqlvars )
258
+ stmt = self ._sanitize_statement (sql , sqlvars , trusted = trusted )
247
259
kwargs = meta .kind_meta .get ('kwargs' ) or {}
248
260
kwargs .update (kwargs )
249
261
if not lazy :
@@ -428,15 +440,9 @@ def copy_from_sql(self, sql, connstr, name, chunksize=10000,
428
440
connection = self ._get_connection (name , connstr , secrets = secrets )
429
441
chunksize = chunksize or 10000 # avoid None
430
442
pditer = pd .read_sql (sql , connection , chunksize = chunksize , ** kwargs )
431
- try :
432
- import tqdm
433
- except :
443
+ with tqdm_if_interactive ().tqdm (unit = 'rows' ) as pbar :
434
444
meta = self ._chunked_insert (pditer , name , append = append ,
435
- transform = transform )
436
- else :
437
- with tqdm .tqdm (unit = 'rows' ) as pbar :
438
- meta = self ._chunked_insert (pditer , name , append = append ,
439
- transform = transform , pbar = pbar )
445
+ transform = transform , pbar = pbar )
440
446
connection .close ()
441
447
return meta
442
448
@@ -448,7 +454,7 @@ def _chunked_to_sql(self, df, table, connection, if_exists='append', chunksize=N
448
454
def chunker(seq, size):
    # split seq into consecutive positional slices of at most `size` rows
    # -- the offsets are computed up front, the slices are produced lazily
    offsets = range(0, len(seq), size)
    return (seq.iloc[start:start + size] for start in offsets)
450
456
451
- def to_sql (df , table , connection , pbar = False ):
457
+ def to_sql (df , table , connection , pbar = None ):
452
458
for i , cdf in enumerate (chunker (df , chunksize )):
453
459
exists_action = if_exists if i == 0 else "append"
454
460
cdf .to_sql (table , con = connection , if_exists = exists_action , ** kwargs )
@@ -457,15 +463,8 @@ def to_sql(df, table, connection, pbar=False):
457
463
else :
458
464
print ("writing chunk {}" .format (i ))
459
465
460
- try :
461
- import tqdm
462
- if pbar is False :
463
- pbar = None
464
- except Exception as e :
466
+ with tqdm_if_interactive ().tqdm (total = len (df ), unit = 'rows' ) as pbar :
465
467
to_sql (df , table , connection , pbar = pbar )
466
- else :
467
- with tqdm .tqdm (total = len (df ), unit = 'rows' ) as pbar :
468
- to_sql (df , table , connection , pbar = pbar )
469
468
470
469
def _chunked_insert (self , pditer , name , append = True , transform = None , pbar = None ):
471
470
# insert into om dataset
@@ -525,7 +524,7 @@ def _default_table(self, name):
525
524
name = name [1 :]
526
525
return name
527
526
528
- def _sanitize_statement (self , sql , sqlvars ):
527
+ def _sanitize_statement (self , sql , sqlvars , trusted = False ):
529
528
# sanitize sql:string statement in two steps
530
529
# -- step 1: replace all {} variables by :notation
531
530
# -- step 2: replace all remaining {} variables from sqlvars
@@ -538,7 +537,7 @@ def _sanitize_statement(self, sql, sqlvars):
538
537
# replace all {...} variables with bound parameters
539
538
# sql = "select * from foo where user={username}"
540
539
# => "select * from foo where user=:username"
541
- placeholders = string .Formatter ().parse (sql )
540
+ placeholders = list ( string .Formatter ().parse (sql ) )
542
541
vars = [spec [1 ] for spec in placeholders if spec [1 ]]
543
542
safe_replacements = {var : f':{ var } ' for var in vars }
544
543
sql = sql .format (** safe_replacements )
@@ -562,15 +561,15 @@ def _sanitize_statement(self, sql, sqlvars):
562
561
# => "select a, b from foo where user=:username
563
562
placeholders = list (string .Formatter ().parse (sql ))
564
563
vars = [spec [1 ] for spec in placeholders if spec [1 ]]
565
- if vars :
564
+ if vars and trusted != self . sign ( sqlvars ) :
566
565
warnings .warn (f'Statement >{ sql } < contains unsafe variables { vars } . Use :notation or sanitize input.' )
567
566
sql = sql .format (** {** sqlvars , ** safe_replacements })
568
567
except KeyError as e :
569
568
raise KeyError ('{e}, specify sqlvars= to build query >{sql}<' .format (** locals ()))
570
569
# prepare sql statement with bound variables
571
570
try :
572
571
stmt = sqlalchemy .sql .text (sql )
573
- except sqlalchemy . exc . StatementError as exc :
572
+ except StatementError as exc :
574
573
raise
575
574
return stmt
576
575
@@ -657,7 +656,7 @@ def load_sql(om=None, kind=SQLAlchemyBackend.KIND):
657
656
from unittest .mock import MagicMock
658
657
from IPython import get_ipython
659
658
import omegaml as om
660
- from sql .connection import Connection
659
+ from sql .connection import Connection # noqa
661
660
662
661
class ConnectionShim :
663
662
# this is required to trick sql magic into accepting existing connection objects
0 commit comments