diff --git a/quantdsl/semantics.py b/quantdsl/semantics.py index 87c9d66..4ddcd7c 100644 --- a/quantdsl/semantics.py +++ b/quantdsl/semantics.py @@ -583,7 +583,7 @@ class Name(DslExpression): relative_cost = 0 def pprint(self, indent=''): - return self.name + return indent + self.name def validate(self, args): assert isinstance(args[0], (six.string_types, String)), type(args[0]) @@ -644,8 +644,8 @@ class FunctionDef(DslObject): def pprint(self, indent=''): msg = "" for decorator_name in self.decorator_names: - msg += "@" + decorator_name + "\n" - msg += "def %s(%s):\n" % (self.name, ", ".join(self.call_arg_names)) + msg += indent + "@" + decorator_name + "\n" + msg += indent + "def %s(%s):\n" % (self.name, ", ".join(self.call_arg_names)) if isinstance(self.body, DslObject): try: msg += self.body.pprint(indent=indent + ' ') diff --git a/quantdsl/tests/test_semantics.py b/quantdsl/tests/test_semantics.py index 2415a1c..7f48d8d 100644 --- a/quantdsl/tests/test_semantics.py +++ b/quantdsl/tests/test_semantics.py @@ -373,6 +373,15 @@ def test_substitute(self): self.assertEqual(obj.substitute_names(ns), function_def) +class TestFunctionDef(TestCase): + def test_pprint(self): + fd = FunctionDef('f', [], Name('a'), []) + code = fd.pprint(indent='') + self.assertEqual(code, "def f():\n a") + code = fd.pprint(indent=' ') + self.assertEqual(code, " def f():\n a") + + class TestFunctionCall(TestCase): def test_substitute_names(self): fc = FunctionCall(Name('f'), [Name('x')])