Skip to content

Commit

Permalink
Fix scoping problems in literal Python blocks in PySQL.
Browse files Browse the repository at this point in the history
We now only pass one variables dictionary instead of two for global and local variables.
  • Loading branch information
doerwalter committed Jul 9, 2021
1 parent 27ae93e commit f65d02d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 20 deletions.
3 changes: 3 additions & 0 deletions docs/NEWS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ Changes in HEAD (released 07/??/2021)
* :meth:`ll.orasql.Table.uniques` returns all unique contraints for the table;
* :meth:`ll.orasql.Table.checks` returns all check contraints for the table.

* Fixed scoping problems in literal Python blocks in PySQL scripts: List
comprehension were not able to access local variables.


Changes in 5.67.2 (released 06/30/2021)
---------------------------------------
Expand Down
34 changes: 14 additions & 20 deletions src/ll/pysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,26 +1123,19 @@ def __init__(self, code, raiseexceptions=None, cond=True):
def __repr__(self):
return f"<{self.__class__.__module__}.{self.__class__.__qualname__} code={self.code!r} location={self.location} at {id(self):#x}>"

def globals(self, context, connection):
vars = {command.__name__: CommandExecutor(command, context) for command in Command.commands.values()}
vars["sqlexpr"] = sqlexpr
vars["datetime"] = datetime
vars["connection"] = connection
return vars

def execute(self, context):
connection = context.connections[-1] if context.connections else None

if not self.cond:
self.finish(f"Skipped Python block")
return None

vars = self.globals(context, connection)
context._locals["connection"] = connection

code = self.location.source(True) if self.location is not None else self.code
code += "\n"
code = compile(code, context.filename, "exec")
exec(code, vars, context._locals)
exec(code, context._locals)

self.finish(f"Executed Python block")
self.count(connectstring(connection))
Expand Down Expand Up @@ -2223,7 +2216,7 @@ def __init__(self, connectstring=None, scpdirectory="", filedirectory="", commit
self.filename = None
self._lastlocation = None
self._lastcommand = None
self._locals = dict(vars) if vars else {}

for fd in range(3):
try:
self._width = os.get_terminal_size(fd)[0]
Expand All @@ -2233,9 +2226,17 @@ def __init__(self, connectstring=None, scpdirectory="", filedirectory="", commit
break
else:
self._width = 80

if connectstring is not None:
self.connect(connectstring, None)

self._locals = dict(vars) if vars else {}
for command in Command.commands.values():
self._locals[command.__name__] = CommandExecutor(command, self)
self._locals["sqlexpr"] = sqlexpr
self._locals["datetime"] = datetime
self._locals["connection"] = self.connections[-1] if self.connections else None

def connect(self, connectstring, mode=None):
mode = cx_Oracle.SYSDBA if mode == "sysdba" else 0
if orasql is not None:
Expand Down Expand Up @@ -2295,20 +2296,13 @@ def changed_filename(self, filename):
self.filename = oldfilename
os.chdir(oldcwd)

def globals(self):
vars = {command.__name__: CommandExecutor(command, self) for command in Command.commands.values()}
vars["sqlexpr"] = sqlexpr
vars["datetime"] = datetime
vars["connection"] = self.connections[-1] if self.connections else None
return vars

def _load(self, stream):
"""
Load a PySQL file from ``stream`` and executes the commands in the file.
``stream`` must be an iterable over lines that contain the PySQL
commands.
"""
vars = self.globals()
self._locals["connection"] = self.connections[-1] if self.connections else None,

def blocks():
# ``state`` is the state of the "parser", values have the following meaning
Expand Down Expand Up @@ -2391,14 +2385,14 @@ def blocks():
CommandExecutor(literalpy, self)(source)
elif state == "dict":
code = compile(source, self._location.filename, "eval")
args = eval(code, vars, self._locals)
args = eval(code, self._locals)
type = args.pop("type", "procedure")
if type not in Command.commands:
raise ValueError(f"command type {type!r} unknown")
CommandExecutor(Command.commands[type], self)(**args)
else:
code = compile(source, self._location.filename, "exec")
exec(code, vars, self._locals)
exec(code, self._locals)

def executeall(self, *filenames):
"""
Expand Down

0 comments on commit f65d02d

Please sign in to comment.