Skip to content

Commit

Permalink
[mycpp] Translate % to StrFormat(), supporting dynamic format strings (
Browse files Browse the repository at this point in the history
…#1430)

* [mycpp/builtins] use fputc() in print()

Else we truncate strings whenever we encounter NUL.
  • Loading branch information
melvinw committed Dec 20, 2022
1 parent f2d4969 commit bf828c9
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 86 deletions.
108 changes: 24 additions & 84 deletions mycpp/cppgen_pass.py
Expand Up @@ -276,6 +276,15 @@ def get_c_return_type(t) -> Tuple[str, bool]:
return c_ret_type, False


def PythonStringLiteral(s: str) -> str:
"""
Returns a properly quoted string.
"""
# MyPy does bad escaping. Decode and push through json to get something
# workable in C++.
return json.dumps(format_strings.DecodeMyPyString(s))


class Generate(ExpressionVisitor[T], StatementVisitor[None]):

def __init__(self, types: Dict[Expression, Type], const_lookup, f,
Expand Down Expand Up @@ -618,14 +627,10 @@ def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> T:
return

rest = args[1:]
if self.decl:
fmt = args[0].value
fmt_types = [self.types[arg] for arg in rest]
temp_name = self._WriteFmtFunc(fmt, fmt_types)
self.fmt_ids[o] = temp_name
quoted_fmt = PythonStringLiteral(args[0].value)

# DEFINITION PASS: Write the call
self.write('println_stderr(%s(' % self.fmt_ids[o])
self.write('println_stderr(StrFormat(%s, ' % quoted_fmt)
for i, arg in enumerate(rest):
if i != 0:
self.write(', ')
Expand Down Expand Up @@ -654,22 +659,16 @@ def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> T:
if not rest:
pass

if self.decl:
fmt_arg = args[0]
if isinstance(fmt_arg, StrExpr):
fmt_types = [self.types[arg] for arg in rest]
temp_name = self._WriteFmtFunc(fmt_arg.value, fmt_types)
self.fmt_ids[o] = temp_name
else:
# oil_lang/expr_to_ast.py uses RANGE_POINT_TOO_LONG, etc.
self.fmt_ids[o] = "dynamic_fmt_dummy"
self.write('%s(StrFormat(' % o.callee.name)
if isinstance(args[0], StrExpr):
self.write(PythonStringLiteral(args[0].value))
else:
self.accept(args[0])

# Should p_die() be in mylib?
# DEFINITION PASS: Write the call
self.write('%s(%s(' % (o.callee.name, self.fmt_ids[o]))
for i, arg in enumerate(rest):
if i != 0:
self.write(', ')
self.write(', ')
self.accept(arg)

if has_keyword_arg:
Expand Down Expand Up @@ -707,59 +706,6 @@ def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> T:
#self.log(' arg_kinds %s', o.arg_kinds)
#self.log(' arg_names %s', o.arg_names)

def _WriteFmtFunc(self, fmt, fmt_types):
"""Append a fmtX() function to a buffer.
Returns:
the temp fmtX() name we used.
"""
temp_name = 'fmt%d' % self.fmt_ids['_counter']
self.fmt_ids['_counter'] += 1

fmt_parts = format_strings.Parse(fmt)
self.fmt_funcs.write('inline Str* %s(' % temp_name)

# NOTE: We're not calling Alloc<> inside these functions, so
# they don't need StackRoots?
for i, typ in enumerate(fmt_types):
if i != 0:
self.fmt_funcs.write(', ');
self.fmt_funcs.write('%s a%d' % (get_c_type(typ), i))

self.fmt_funcs.write(') {\n')
self.fmt_funcs.write(' gBuf.reset();\n')

for part in fmt_parts:
if isinstance(part, format_strings.LiteralPart):
# MyPy does bad escaping.
# NOTE: We could do this in the CALLER to _WriteFmtFunc?

byte_string = bytes(part.s, 'utf-8')

# In Python 3
# >>> b'\\t'.decode('unicode_escape')
# '\t'

raw_string = format_strings.DecodeMyPyString(part.s)
n = len(raw_string) # NOT using part.strlen

escaped = json.dumps(raw_string)
self.fmt_funcs.write(
' gBuf.write_const(%s, %d);\n' % (escaped, n))
elif isinstance(part, format_strings.SubstPart):
# TODO: respect part.width as rjust()
self.fmt_funcs.write(
' gBuf.format_%s(a%d);\n' %
(part.char_code, part.arg_num))
else:
raise AssertionError(part)

self.fmt_funcs.write(' return gBuf.getvalue();\n')
self.fmt_funcs.write('}\n')
self.fmt_funcs.write('\n')

return temp_name

def visit_op_expr(self, o: 'mypy.nodes.OpExpr') -> T:
c_op = o.op

Expand Down Expand Up @@ -811,8 +757,11 @@ def visit_op_expr(self, o: 'mypy.nodes.OpExpr') -> T:

# RHS can be primitive or tuple
if left_ctype == 'Str*' and c_op == '%':
if not isinstance(o.left, StrExpr):
raise AssertionError('Expected constant format string, got %s' % o.left)
self.write('StrFormat(')
if isinstance(o.left, StrExpr):
self.write(PythonStringLiteral(o.left.value))
else:
self.accept(o.left)
#log('right_type %s', right_type)
if isinstance(right_type, Instance):
fmt_types = [right_type]
Expand All @@ -826,22 +775,13 @@ def visit_op_expr(self, o: 'mypy.nodes.OpExpr') -> T:
else:
raise AssertionError(right_type)

# Write a buffer with fmtX() functions.
if self.decl:
fmt = o.left.value

# TODO: I want to do this later
temp_name = self._WriteFmtFunc(fmt, fmt_types)
self.fmt_ids[o] = temp_name

# In the definition pass, write the call site.
self.write('%s(' % self.fmt_ids[o])
if isinstance(right_type, TupleType):
for i, item in enumerate(o.right.items):
if i != 0:
self.write(', ')
self.write(', ')
self.accept(item)
else: # '[%s]' % x
self.write(', ')
self.accept(o.right)

self.write(')')
Expand Down
20 changes: 20 additions & 0 deletions mycpp/examples/strings.py
Expand Up @@ -47,6 +47,26 @@ def run_tests():
x = 'x'
print("%s\tb\n%s\td\n" % (x, x))

fmt = "%dfoo"
print(fmt % 10)

fmts = ["foo%d"]
print(fmts[0] % 10)

print(("foo " + "%s") % "bar")

print("foo\0%s" % "bar")

print("foo%s" % "\0bar")

print("%o" % 12345)
print("%17o" % 12345)
print("%017o" % 12345)

print("%%%d%%%%" % 12345)

print("%r" % "tab\tline\nline\r\n")


def run_benchmarks():
# type: () -> None
Expand Down
8 changes: 6 additions & 2 deletions mycpp/gc_builtins.cc
Expand Up @@ -4,13 +4,17 @@

// Translation of Python's print().
void print(Str* s) {
fputs(s->data(), stdout);
for (int i = 0; i < len(s); ++i) {
fputc(s->data()[i], stdout);
}
fputs("\n", stdout);
}

// Like print(..., file=sys.stderr), but Python code explicitly calls it.
void println_stderr(Str* s) {
fputs(s->data(), stderr);
for (int i = 0; i < len(s); ++i) {
fputc(s->data()[i], stderr);
}
fputs("\n", stderr);
}

Expand Down

0 comments on commit bf828c9

Please sign in to comment.