Skip to content
Permalink
Browse files

[WIP] template precompilation

  • Loading branch information...
dw committed May 18, 2018
1 parent 7e1f88c commit 8a16a73558db7b640730ee3e88a356143381bb18
@@ -21,6 +21,7 @@
from ansible.parsing.utils.yaml import from_yaml
from ansible.parsing.vault import VaultLib, b_HEADER, is_encrypted, is_encrypted_file, parse_vaulttext_envelope
from ansible.utils.path import unfrackpath
from ansible.template import cache_templates

try:
from __main__ import display
@@ -93,6 +94,7 @@ def load_from_file(self, file_name, cache=True, unsafe=False):

# cache the file contents for next time
self._FILE_CACHE[file_name] = parsed_data
cache_templates(parsed_data)

if unsafe:
return parsed_data
@@ -29,6 +29,7 @@
from ansible.module_utils.six import text_type
from ansible.module_utils._text import to_native
from ansible.playbook.attribute import FieldAttribute
from ansible.template import cache_templates

try:
from __main__ import display
@@ -42,6 +43,35 @@
VALID_VAR_REGEX = re.compile("^[_A-Za-z][_a-zA-Z0-9]*$")


class CleansingNodeVisitor(ast.NodeVisitor):
def __init__(self, conditional):
self.conditional = conditional

def generic_visit(self, node, inside_call=False, inside_yield=False):
if isinstance(node, ast.Call):
inside_call = True
elif isinstance(node, ast.Yield):
inside_yield = True
elif isinstance(node, ast.Str):
if inside_call and node.s.startswith("__"):
# calling things with a dunder is generally bad at this point...
raise AnsibleError(
"Invalid access found in the conditional: '%s'" % self.conditional
)
elif inside_yield:
# we're inside a yield, so recursively parse and traverse the AST
# of the result to catch forbidden syntax from executing
parsed = ast.parse(node.s, mode='exec')
self.visit(parsed)
# iterate over all child nodes
for child_node in ast.iter_child_nodes(node):
self.generic_visit(
child_node,
inside_call=inside_call,
inside_yield=inside_yield
)


class Conditional:

'''
@@ -62,9 +92,13 @@ def __init__(self, loader=None):
self._loader = loader
super(Conditional, self).__init__()

def _decorate(self, s):
return "{%% if %s %%} True {%% else %%} False {%% endif %%}" % s

def _validate_when(self, attr, name, value):
if not isinstance(value, list):
setattr(self, name, [value])
cache_templates(value, decorator=self._decorate)

def extract_defined_undefined(self, conditional):
results = []
@@ -117,7 +151,7 @@ def _check_conditional(self, conditional, templar, all_vars):
if conditional is None or conditional == '':
return True

if templar.is_template(conditional):
if templar.is_template_uncached(conditional):
display.warning('when statements should not include jinja2 '
'templating delimiters such as {{ }} or {%% %%}. '
'Found: %s' % conditional)
@@ -148,32 +182,6 @@ def _check_conditional(self, conditional, templar, all_vars):

# First, we do some low-level jinja2 parsing involving the AST format of the
# statement to ensure we don't do anything unsafe (using the disable_lookup flag above)
class CleansingNodeVisitor(ast.NodeVisitor):
def generic_visit(self, node, inside_call=False, inside_yield=False):
if isinstance(node, ast.Call):
inside_call = True
elif isinstance(node, ast.Yield):
inside_yield = True
elif isinstance(node, ast.Str):
if disable_lookups:
if inside_call and node.s.startswith("__"):
# calling things with a dunder is generally bad at this point...
raise AnsibleError(
"Invalid access found in the conditional: '%s'" % conditional
)
elif inside_yield:
# we're inside a yield, so recursively parse and traverse the AST
# of the result to catch forbidden syntax from executing
parsed = ast.parse(node.s, mode='exec')
cnv = CleansingNodeVisitor()
cnv.visit(parsed)
# iterate over all child nodes
for child_node in ast.iter_child_nodes(node):
self.generic_visit(
child_node,
inside_call=inside_call,
inside_yield=inside_yield
)
try:
e = templar.environment.overlay()
e.filters.update(templar._get_filters(e.filters))
@@ -183,13 +191,13 @@ def generic_visit(self, node, inside_call=False, inside_yield=False):
res = generate(res, e, None, None)
parsed = ast.parse(res, mode='exec')

cnv = CleansingNodeVisitor()
cnv.visit(parsed)
if disable_lookups:
CleansingNodeVisitor(conditional).visit(parsed)
except Exception as e:
raise AnsibleError("Invalid conditional detected: %s" % to_native(e))

# and finally we generate and template the presented string and look at the resulting string
presented = "{%% if %s %%} True {%% else %%} False {%% endif %%}" % conditional
presented = self._decorate(conditional)
val = templar.template(presented, disable_lookups=disable_lookups).strip()
if val == "True":
return True
@@ -49,6 +49,8 @@

__all__ = ['PlayContext']

PROCESS_USERNAME = pwd.getpwuid(os.getuid()).pw_name

# TODO: needs to be configurable
b_SU_PROMPT_LOCALIZATIONS = [
to_bytes('Password'),
@@ -430,7 +432,7 @@ def set_task_and_variable_override(self, task, variables, templar):
if new_info.connection == 'local':
if not new_info.connection_user:
new_info.connection_user = new_info.remote_user
new_info.remote_user = pwd.getpwuid(os.getuid()).pw_name
new_info.remote_user = PROCESS_USERNAME

# set no_log to default if it was not previously set
if new_info.no_log is None:
@@ -36,6 +36,7 @@
from ansible.playbook.loop_control import LoopControl
from ansible.playbook.role import Role
from ansible.playbook.taggable import Taggable
from ansible.template import cache_templates


try:
@@ -158,6 +159,7 @@ def _preprocess_with_loop(self, ds, new_ds, k, v):
raise AnsibleError("you must specify a value when using %s" % k, obj=ds)
new_ds['loop_with'] = loop_name
new_ds['loop'] = v
cache_templates(v)
# FIXME: reenable afte 2.5
# display.deprecated("with_ type loops are being phased out, use the 'loop' keyword instead", version="2.10")

@@ -133,20 +133,23 @@ def _all_directories(self, dir):
results.append(os.path.join(root, x))
return results

_package_path = None
_package_subdir_path = None

def _get_package_paths(self, subdirs=True):
''' Gets the path of a Python package '''

if not self.package:
return []
if not hasattr(self, 'package_path'):
m = __import__(self.package)
parts = self.package.split('.')[1:]
for parent_mod in parts:
m = getattr(m, parent_mod)
self.package_path = os.path.dirname(m.__file__)
if not self._package_path:
module = __import__(self.package)
for part in self.package.split('.')[1:]:
module = getattr(module, part)
self._package_path = os.path.dirname(module.__file__)
if subdirs:
return self._all_directories(self.package_path)
return [self.package_path]
if not self._package_subdir_path:
self._package_subdir_path = self._all_directories(self._package_path)
return self._package_subdir_path
return [self._package_path]

def _get_paths(self, subdirs=True):
''' Return a list of paths to search for plugins in '''
@@ -46,6 +46,7 @@
from ansible.playbook.task_include import TaskInclude
from ansible.playbook.role_include import IncludeRole
from ansible.plugins.loader import action_loader, connection_loader, filter_loader, lookup_loader, module_loader, test_loader
import ansible.plugins.loader
from ansible.template import Templar
from ansible.utils.vars import combine_vars
from ansible.vars.clean import strip_internal_keys
@@ -211,6 +212,15 @@ def __init__(self, tqm):

self.debugger_active = C.ENABLE_TASK_DEBUGGER

# TODO MOVE ME
list(ansible.plugins.loader.vars_loader.all(class_only=True))
for k, v in vars(ansible.plugins.loader).items():
if isinstance(v, ansible.plugins.loader.PluginLoader):
try:
v._get_paths()
except Exception as e:
pass

def cleanup(self):
# close active persistent connections
for sock in itervalues(self._active_connections):
@@ -20,6 +20,7 @@
__metaclass__ = type

import ast
import collections
import contextlib
import datetime
import os
@@ -37,7 +38,7 @@
except ImportError:
from sha import sha as sha1

from jinja2 import Environment
from jinja2 import Environment, Template
from jinja2.exceptions import TemplateSyntaxError, UndefinedError
from jinja2.loaders import FileSystemLoader
from jinja2.runtime import Context, StrictUndefined
@@ -71,6 +72,116 @@
JINJA2_OVERRIDE = '#jinja2:'
B_JINJA2_OVERRIDE = b'#jinja2:'

_template_cache = {}
_snoop_templar = None


class StubOverlayDict(collections.Mapping):
def __init__(self, dct=None):
self._dct = dct or {}
self.names = set()

def __iter__(self):
return iter(())

def __len__(self):
return 0

def _stub(self):
assert False, "_stub() invoked."

def __getitem__(self, key):
if key in self._dct:
return self._dct[key]
self.names.add(key)
return self._stub


def _cache_one_template(env, s):
original_filters = env.filters
original_tests = env.tests
env.filters = StubOverlayDict(original_filters)
env.tests = StubOverlayDict(original_tests)

if s in _template_cache:
return

try:
try:
code = env.compile(s)
except TemplateSyntaxError as e:
raise AnsibleError("template error while templating string: %s. String: %s" % (to_native(e), to_native(data)))

_template_cache[s] = code, env.filters.names, env.tests.names
finally:
env.filters = original_filters
env.tests = original_tests


def _get_cached_template(env, s):
try:
code, filter_names, test_names = _template_cache[s]
return Template.from_code(env, code, env.globals)
except KeyError:
if not s.endswith('\n'):
# FML.
"""
- command: >
echo gcloud compute instances create {{instance_name}} --can-ip-forward --machine-type=n1-standard-8 --preemptible --scopes=compute-ro --image-project=debian-cloud --image-family=debian-9
"""
return _get_cached_template(env, s+'\n')
print(['KeyError for', s])
import traceback;traceback.print_stack(limit=10)
return env.from_string(s)
except Exception:
import traceback;traceback.print_exc()
raise

def walk_strings(root):
stack = [root]
while stack:
obj = stack.pop()
if isinstance(obj, (text_type, binary_type)):
yield obj
elif isinstance(obj, (list, tuple)):
stack.extend(obj)
elif isinstance(obj, dict):
stack.extend(obj)
stack.extend(viewvalues(obj))


def is_template(s):
return (
s.startswith(JINJA2_OVERRIDE) or
_snoop_templar._template_token_regex.search(s)
)


def cache_templates(obj, decorator=None):
'''
Return :data:`True` if any string or bytes in the graph `obj` contains
template directives. Unlike :meth:`is_template`, this does not cause
caching of the compiled template as a side-effect.
'''
global _snoop_templar
if _snoop_templar is None:
_snoop_templar = Templar(None)

env = AnsibleEnvironment(
trim_blocks=True,
extensions=_snoop_templar._get_extensions(),
finalize=_snoop_templar._finalize,
)
_snoop_templar._get_filters({})

if decorator:
for s in walk_strings(obj):
_cache_one_template(env, decorator(s))
else:
for s in walk_strings(obj):
if is_template(s):
_cache_one_template(env, s)


def generate_ansible_template_vars(path):
b_path = to_bytes(path)
@@ -725,7 +836,7 @@ def do_template(self, data, preserve_trailing_newlines=True, escape_backslashes=
data = _escape_backslashes(data, myenv)

try:
t = myenv.from_string(data)
t = _get_cached_template(myenv, data)
except TemplateSyntaxError as e:
raise AnsibleError("template error while templating string: %s. String: %s" % (to_native(e), to_native(data)))
except Exception as e:
@@ -449,7 +449,7 @@ def _get_magic_variables(self, play, host, task, include_hostvars, include_deleg
if self._inventory is not None:
variables['groups'] = self._inventory.get_groups_dict()
if play:
if self._templar.is_template(play.hosts):
if self._templar.is_template_uncached(play.hosts):
pattern = 'all'
else:
pattern = play.hosts or 'all'

0 comments on commit 8a16a73

Please sign in to comment.
You can’t perform that action at this time.