/
header.py
348 lines (292 loc) · 13.4 KB
/
header.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# -----------------------------------------------------------------------------------------------------------------------
# INFO:
# -----------------------------------------------------------------------------------------------------------------------
"""
Author: Evan Hubinger
License: Apache 2.0
Description: Header utilities for the compiler.
"""
# -----------------------------------------------------------------------------------------------------------------------
# IMPORTS:
# -----------------------------------------------------------------------------------------------------------------------
from __future__ import print_function, absolute_import, unicode_literals, division
from coconut.root import * # NOQA
import os.path
from coconut.root import _indent
from coconut.constants import (
get_target_info,
hash_prefix,
tabideal,
default_encoding,
template_ext,
justify_len,
)
from coconut.exceptions import internal_assert
# -----------------------------------------------------------------------------------------------------------------------
# UTILITIES:
# -----------------------------------------------------------------------------------------------------------------------
def gethash(compiled):
    """Return the hash stored on the third line of compiled, or None if there is none."""
    header_lines = compiled.splitlines()
    # the hash, when present, is expected on the third line of the header
    if len(header_lines) >= 3:
        candidate = header_lines[2]
        if candidate.startswith(hash_prefix):
            # everything after the prefix is the hash itself
            return candidate[len(hash_prefix):]
    return None
def minify(compiled):
    """Strip comments and blank lines and shrink each indentation level to one space.

    Everything from the first # on a line is dropped, so a string containing a #
    gets cut; indentation that is not a multiple of tabideal fails the
    internal_assert check.
    """
    compiled = compiled.strip()
    if not compiled:
        return compiled

    minified = []
    for raw_line in compiled.splitlines():
        # keep only what precedes the first "#", without trailing whitespace
        code = raw_line.split("#", 1)[0].rstrip()
        if not code:
            continue
        # count the leading spaces by comparing lengths before and after stripping them
        body = code.lstrip(" ")
        ind = len(code) - len(body)
        internal_assert(ind % tabideal == 0, "invalid indentation in", body)
        minified.append(" " * (ind // tabideal) + body)
    return "\n".join(minified) + "\n"
# directory named "templates" located next to this file
template_dir = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "templates",
)


def get_template(template):
    """Return the contents of the template file with the given name (without extension)."""
    template_path = os.path.join(template_dir, template) + template_ext
    with open(template_path, "r") as template_file:
        contents = template_file.read()
    return contents
def one_num_ver(target):
    """Return the leading character of the target version, or "" for an empty target."""
    # slicing rather than indexing, so that an empty target yields "" instead of raising
    first_char = target[:1]
    return first_char  # "2", "3", or ""
def section(name):
    """Build a comment line naming a section, padded with dashes out to justify_len, plus a blank line."""
    title = "# " + name + ": "
    # a title already longer than justify_len gets no dashes (negative repeat count gives "")
    dashes = "-" * (justify_len - len(title))
    return title + dashes + "\n\n"
# -----------------------------------------------------------------------------------------------------------------------
# FORMAT DICTIONARY:
# -----------------------------------------------------------------------------------------------------------------------
class comment(object):
    """Placeholder for str.format: every {comment.<>} field formats to nothing, so it works as a comment."""

    def __getattr__(self, attr):
        """Resolve any attribute that is not otherwise defined to the empty string."""
        empty = ""
        return empty
def process_header_args(which, target, use_hash, no_tco, strict):
    """Create the dictionary passed to str.format in the header, target_startswith, and target_info.

    Arguments:
        which: the header type; only compared against "__coconut__" here, to decide
            whether the type-ignore line and the module docstring are emitted.
        target: the target version string; an empty string selects the
            "not target" branches below.
        use_hash: the hash to write after hash_prefix, or None for no hash line.
        no_tco: if true, the tail call optimization definitions and imports are left out.
        strict: if true, the deprecated prepattern and datamaker definitions are left out.

    Returns:
        A (format_dict, target_startswith, target_info) tuple.
    """
    # first character of the target: "2", "3", or ""
    target_startswith = one_num_ver(target)
    # compared against version tuples such as (3, 5) throughout this function
    target_info = get_target_info(target)

    # code snippets that are spliced into the format_dict entries below
    # NOTE(review): the code embedded in the string literals of this function shows
    # no indentation in this view -- confirm the literals carry the intended indentation
    try_backport_lru_cache = r'''try:
from backports.functools_lru_cache import lru_cache
functools.lru_cache = lru_cache
except ImportError: pass
'''
    try_import_trollius = r'''try:
import trollius as asyncio
except ImportError:
class you_need_to_install_trollius: pass
asyncio = you_need_to_install_trollius()
'''

    format_dict = dict(
        # makes every {comment.<>} field format to an empty string
        comment=comment(),
        # a literal {} for strings that are themselves passed to str.format
        empty_dict="{}",
        target_startswith=target_startswith,
        default_encoding=default_encoding,
        # the conditional applies to the whole concatenation, so no hash means an empty line
        hash_line=hash_prefix + use_hash + "\n" if use_hash is not None else "",
        typing_line="# type: ignore\n" if which == "__coconut__" else "",
        VERSION_STR=VERSION_STR,
        module_docstring='"""Built-in Coconut utilities."""\n\n' if which == "__coconut__" else "",
        # explicit (object) base unless the target starts with "3"
        object="(object)" if target_startswith != "3" else "",
        # nothing for an empty target or 3.5+, a plain import for 3.4,
        # a runtime version check for other 3.x targets, trollius otherwise
        import_asyncio=_indent(
            "" if not target or target_info >= (3, 5)
            else "import asyncio\n" if target_info >= (3, 4)
            else r'''if _coconut_sys.version_info >= (3, 4):
import asyncio
else:
''' + _indent(try_import_trollius) if target_info >= (3,)
            else try_import_trollius,
        ),
        # runtime version check for an empty target, otherwise the import for that major version
        import_pickle=_indent(
            r'''if _coconut_sys.version_info < (3,):
import cPickle as pickle
else:
import pickle''' if not target
            else "import cPickle as pickle" if target_info < (3,)
            else "import pickle"
        ),
        # targets below 2.7 get plain dict in place of collections.OrderedDict
        import_OrderedDict=_indent(
            r'''if _coconut_sys.version_info >= (2, 7):
OrderedDict = collections.OrderedDict
else:
OrderedDict = dict'''
            if not target
            else "OrderedDict = collections.OrderedDict" if target_info >= (2, 7)
            else "OrderedDict = dict"
        ),
        # targets starting with "2" alias abc to collections; all others check the version at runtime
        import_collections_abc=_indent(
            r'''if _coconut_sys.version_info < (3, 3):
abc = collections
else:
import collections.abc as abc'''
            if target_startswith != "2"
            else "abc = collections"
        ),
        # lru_cache backport: version-guarded for an empty target, unconditional for "2", absent otherwise
        bind_lru_cache=_indent(
            r'''if _coconut_sys.version_info < (3, 2):
''' + _indent(try_backport_lru_cache)
            if not target
            else try_backport_lru_cache if target_startswith == "2"
            else ""
        ),
        comma_bytearray=", bytearray" if target_startswith != "3" else "",
        static_repr="staticmethod(repr)" if target_startswith != "3" else "repr",
        # targets below 3.5 pass an explicit worker count to ThreadPoolExecutor
        with_ThreadPoolExecutor=(
            r'''from multiprocessing import cpu_count # cpu_count() * 5 is the default Python 3.5 thread count
with ThreadPoolExecutor(cpu_count() * 5)''' if target_info < (3, 5)
            else '''with ThreadPoolExecutor()'''
        ),
        def_tco_func="""def _coconut_tco_func(self, *args, **kwargs):
for func in self.patterns[:-1]:
try:
with _coconut_FunctionMatchErrorContext(self.FunctionMatchError):
return func(*args, **kwargs)
except self.FunctionMatchError:
pass
return _coconut_tail_call(self.patterns[-1], *args, **kwargs)
""",
        # the deprecated helpers are only defined when strict is off
        def_prepattern=(
            r'''def prepattern(base_func):
"""DEPRECATED: Use addpattern instead."""
def pattern_prepender(func):
return addpattern(func)(base_func)
return pattern_prepender
''' if not strict else ""
        ),
        def_datamaker=(
            r'''def datamaker(data_type):
"""DEPRECATED: Use makedata instead."""
return _coconut.functools.partial(makedata, data_type)
''' if not strict else ""
        ),
        comma_tco=", _coconut_tail_call, _coconut_tco" if not no_tco else "",
    )

    # the entries below are formatted with format_dict itself, so they are
    # added only after the entries they refer to exist

    # fills in {comma_tco}
    format_dict["underscore_imports"] = "_coconut, _coconut_MatchError{comma_tco}, _coconut_igetitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_back_pipe, _coconut_star_pipe, _coconut_back_star_pipe, _coconut_dubstar_pipe, _coconut_back_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert".format(**format_dict)

    # targets below 3.6 get a stand-in typing class built on collections.namedtuple; fills in {object}
    format_dict["import_typing_NamedTuple"] = _indent(
        "import typing" if target_info >= (3, 6)
        else '''class typing{object}:
@staticmethod
def NamedTuple(name, fields):
return _coconut.collections.namedtuple(name, [x for x, t in fields])'''.format(**format_dict),
    )

    # ._coconut_tco_func is used in main.coco, so don't remove it
    # here without replacing its usage there
    # fills in {object}, {empty_dict}, and the {comment.<>} field
    format_dict["def_tco"] = "" if no_tco else '''class _coconut_tail_call{object}:
__slots__ = ("func", "args", "kwargs")
def __init__(self, func, *args, **kwargs):
self.func, self.args, self.kwargs = func, args, kwargs
_coconut_tco_func_dict = {empty_dict}
def _coconut_tco(func):
@_coconut.functools.wraps(func)
def tail_call_optimized_func(*args, **kwargs):
call_func = func
while True:{comment.weakrefs_necessary_for_ignoring_bound_methods}
wkref = _coconut_tco_func_dict.get(_coconut.id(call_func))
if (wkref is not None and wkref() is call_func) or _coconut.isinstance(call_func, _coconut_base_pattern_func):
call_func = call_func._coconut_tco_func
result = call_func(*args, **kwargs) # pass --no-tco to clean up your traceback
if not isinstance(result, _coconut_tail_call):
return result
call_func, args, kwargs = result.func, result.args, result.kwargs
tail_call_optimized_func._coconut_tco_func = func
tail_call_optimized_func.__module__ = _coconut.getattr(func, "__module__", None)
tail_call_optimized_func.__name__ = _coconut.getattr(func, "__name__", "<coconut tco function (pass --no-tco to remove)>")
tail_call_optimized_func.__qualname__ = _coconut.getattr(func, "__qualname__", tail_call_optimized_func.__name__)
_coconut_tco_func_dict[_coconut.id(tail_call_optimized_func)] = _coconut.weakref.ref(tail_call_optimized_func)
return tail_call_optimized_func
'''.format(**format_dict)

    return format_dict, target_startswith, target_info
# -----------------------------------------------------------------------------------------------------------------------
# HEADER GENERATION:
# -----------------------------------------------------------------------------------------------------------------------
def getheader(which, target="", use_hash=None, no_tco=False, strict=False):
    """Generate the specified header.

    Arguments:
        which: the header type: "none", "initial", "__coconut__", "sys", "code",
            "file", or a string starting with "package" (read below as "package:<n>",
            where <n> is the number of directory levels to go up).
        target: the target version string, passed on to process_header_args.
        use_hash: a hash to embed; only allowed for "initial" and "__coconut__" headers.
        no_tco: passed on to process_header_args.
        strict: passed on to process_header_args.

    Returns:
        The header as a string ("" for "none").
    """
    internal_assert(
        which.startswith("package") or which in (
            "none", "initial", "__coconut__", "sys", "code", "file",
        ),
        "invalid header type",
        which,
    )

    if which == "none":
        return ""

    # initial, __coconut__, package:n, sys, code, file

    format_dict, target_startswith, target_info = process_header_args(which, target, use_hash, no_tco, strict)

    # only these two header types start with the shebang/encoding/hash preamble
    if which == "initial" or which == "__coconut__":
        header = '''#!/usr/bin/env python{target_startswith}
# -*- coding: {default_encoding} -*-
{hash_line}{typing_line}
# Compiled with Coconut version {VERSION_STR}
{module_docstring}'''.format(**format_dict)
    elif use_hash is not None:
        # NOTE(review): CoconutInternalException is not among the names imported
        # explicitly at the top of this file; presumably it comes from the star
        # import of coconut.root -- confirm, otherwise this raises NameError instead
        raise CoconutInternalException("can only add a hash to an initial or __coconut__ header, not", which)
    else:
        header = ""

    if which == "initial":
        return header

    # __coconut__, package:n, sys, code, file

    header += section("Coconut Header")

    # targets that start with "3" but are below 3.5 get no __future__ import at all
    if target_startswith != "3":
        header += "from __future__ import print_function, absolute_import, unicode_literals, division\n"
    elif target_info >= (3, 7):
        header += "from __future__ import generator_stop, annotations\n"
    elif target_info >= (3, 5):
        header += "from __future__ import generator_stop\n"

    if which.startswith("package"):
        # the number after "package:" says how many directories to go up
        levels_up = int(which[len("package:"):])
        coconut_file_path = "_coconut_os_path.dirname(_coconut_os_path.abspath(__file__))"
        # wrap the path expression in one more dirname() call per level
        for _ in range(levels_up):
            coconut_file_path = "_coconut_os_path.dirname(" + coconut_file_path + ")"
        return header + '''import sys as _coconut_sys, os.path as _coconut_os_path
_coconut_file_path = {coconut_file_path}
_coconut_cached_module = _coconut_sys.modules.get({__coconut__})
if _coconut_cached_module is not None and _coconut_os_path.dirname(_coconut_cached_module.__file__) != _coconut_file_path:
del _coconut_sys.modules[{__coconut__}]
_coconut_sys.path.insert(0, _coconut_file_path)
from __coconut__ import *
from __coconut__ import {underscore_imports}
{sys_path_pop}
'''.format(
            coconut_file_path=coconut_file_path,
            # spelling of the "__coconut__" string literal chosen by target major version
            __coconut__=(
                '"__coconut__"' if target_startswith == "3"
                else 'b"__coconut__"' if target_startswith == "2"
                else 'str("__coconut__")'
            ),
            sys_path_pop=(
                # we can't pop on Python 2 if we want __coconut__ objects to be pickleable
                "_coconut_sys.path.pop(0)" if target_startswith == "3"
                else "" if target_startswith == "2"
                else '''if _coconut_sys.version_info >= (3,):
_coconut_sys.path.pop(0)'''
            ),
            **format_dict
        ) + section("Compiled Coconut")

    if which == "sys":
        return header + '''import sys as _coconut_sys
from coconut.__coconut__ import *
from coconut.__coconut__ import {underscore_imports}
'''.format(**format_dict)

    # __coconut__, code, file

    header += "import sys as _coconut_sys\n"

    # the PY*_HEADER names are not defined in this file; presumably they come
    # from the star import of coconut.root -- verify there
    if target_startswith == "3":
        header += PY3_HEADER
    elif target_info >= (2, 7):
        header += PY27_HEADER
    elif target_startswith == "2":
        header += PY2_HEADER
    else:
        header += PYCHECK_HEADER

    # the body comes from the "header" template file, filled in from format_dict
    header += get_template("header").format(**format_dict)

    if which == "file":
        header += "\n" + section("Compiled Coconut")

    return header