Skip to content

Commit

Permalink
improved static optimizer
Browse files Browse the repository at this point in the history
--HG--
branch : trunk
  • Loading branch information
mitsuhiko committed Apr 8, 2008
1 parent 149aa4e commit 81b8817
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 29 deletions.
132 changes: 106 additions & 26 deletions jinja2/optimizer.py
Expand Up @@ -25,54 +25,134 @@
from jinja2.runtime import subscribe


class ContextStack(object):
"""Simple compile time context implementation."""

def __init__(self, initial=None):
self.stack = [{}]
if initial is not None:
self.stack.insert(0, initial)

def push(self):
self.stack.append({})

def pop(self):
self.stack.pop()

def __getitem__(self, key):
for level in reversed(self.stack):
if key in level:
return level[key]
raise KeyError(key)

def __setitem__(self, key, value):
self.stack[-1][key] = value

def blank(self):
"""Return a new context with nothing but the root scope."""
return ContextStack(self.stack[0])


class Optimizer(NodeTransformer):

def __init__(self, environment, context={}):
def __init__(self, environment):
self.environment = environment
self.context = context

def visit_Filter(self, node):
def visit_Filter(self, node, context):
"""Try to evaluate filters if possible."""
# XXX: nonconstant arguments? not-called visitors? generic visit!
try:
x = self.visit(node.node).as_const()
x = self.visit(node.node, context).as_const()
except nodes.Impossible:
return node
return self.generic_visit(node, context)
for filter in reversed(node.filters):
# XXX: call filters with arguments
x = self.environment.filters[filter.name](self.environment, x)
# XXX: don't optimize context dependent filters
return nodes.Const(x)

def visit_For(self, node):
"""Loop unrolling for constant values."""
def visit_For(self, node, context):
"""Loop unrolling for iterable constant values."""
try:
iter = self.visit(node.iter).as_const()
except nodes.Impossible:
return node
iterable = iter(self.visit(node.iter, context).as_const())
except (nodes.Impossible, TypeError):
return self.generic_visit(node, context)
context.push()
result = []
# XXX: tuple unpacking (for key, value in foo)
target = node.target.name
for item in iter:
# XXX: take care of variable scopes
self.context[target] = item
result.extend(self.visit(n) for n in deepcopy(node.body))
iterated = False
for item in iterable:
context[target] = item
result.extend(self.visit(n, context) for n in deepcopy(node.body))
iterated = True
if not iterated and node.else_:
result.extend(self.visit(n, context) for n in deepcopy(node.else_))
context.pop()
return result

def visit_Name(self, node):
# XXX: take care of variable scopes!
if node.name not in self.context:
def visit_If(self, node, context):
try:
val = self.visit(node.test, context).as_const()
except nodes.Impossible:
return self.generic_visit(node, context)
if val:
return node.body
return node.else_

def visit_Name(self, node, context):
if node.ctx == 'load':
try:
return nodes.Const(context[node.name], lineno=node.lineno)
except KeyError:
pass
return node

def visit_Assign(self, node, context):
try:
target = node.target = self.generic_visit(node.target, context)
value = self.generic_visit(node.node, context).as_const()
except nodes.Impossible:
return node
return nodes.Const(self.context[node.name])

def visit_Subscript(self, node):
result = []
lineno = node.lineno
def walk(target, value):
if isinstance(target, nodes.Name):
const_value = nodes.Const(value, lineno=lineno)
result.append(nodes.Assign(target, const_value, lineno=lineno))
context[target.name] = value
elif isinstance(target, nodes.Tuple):
try:
value = tuple(value)
except TypeError:
raise nodes.Impossible()
if len(target) != len(value):
raise nodes.Impossible()
for name, val in zip(target, value):
walk(name, val)
else:
raise AssertionError('unexpected assignable node')

try:
item = self.visit(node.node).as_const()
arg = self.visit(node.arg).as_const()
walk(target, value)
except nodes.Impossible:
return node
# XXX: what does the 3rd parameter mean?
return nodes.Const(subscribe(item, arg, None))
return result

def visit_Subscript(self, node, context):
if node.ctx == 'load':
try:
item = self.visit(node.node, context).as_const()
arg = self.visit(node.arg, context).as_const()
except nodes.Impossible:
return self.generic_visit(node, context)
return nodes.Const(subscribe(item, arg, 'load'))
return self.generic_visit(node, context)


def optimize(node, environment, context={}):
optimizer = Optimizer(environment, context=context)
return optimizer.visit(node)
def optimize(node, environment, context_hint=None):
"""The context hint can be used to perform an static optimization
based on the context given."""
optimizer = Optimizer(environment)
return optimizer.visit(node, ContextStack(context_hint))
4 changes: 2 additions & 2 deletions jinja2/runtime.py
Expand Up @@ -24,8 +24,8 @@ def extends(template, namespace):
def subscribe(obj, argument, undefined_factory):
"""Get an item or attribute of an object."""
try:
return getattr(obj, argument)
except AttributeError:
return getattr(obj, str(argument))
except (AttributeError, UnicodeError):
try:
return obj[argument]
except LookupError:
Expand Down
9 changes: 8 additions & 1 deletion test_optimizer.py
Expand Up @@ -16,12 +16,19 @@
{% for forum in forums %}
{{ readstatus(forum.id) }} {{ forum.id|e }} {{ forum.name|e }}
{% endfor %}
{% navigation = [('#foo', 'Foo'), ('#bar', 'Bar')] %}
<ul>
{% for item in navigation %}
<li><a href="{{ item[0] }}">{{ item[1] }}</a></li>
{% endfor %}
</ul>
""")
print ast
print
print generate(ast, env, "foo.html")
print
ast = optimize(ast, env, context={'forums': forums})
ast = optimize(ast, env, context_hint={'forums': forums})
print ast
print
print generate(ast, env, "foo.html")

0 comments on commit 81b8817

Please sign in to comment.