Skip to content

Commit

Permalink
[frontend] New API typed_args::Reader (#1705)
Browse files Browse the repository at this point in the history
- Use it in several places
  • Loading branch information
melvinw committed Aug 24, 2023
1 parent 20f7f97 commit fdb6aaf
Show file tree
Hide file tree
Showing 5 changed files with 252 additions and 82 deletions.
142 changes: 91 additions & 51 deletions frontend/typed_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,57 +2,17 @@
"""Typed_args.py."""
from __future__ import print_function

from _devbuild.gen.runtime_asdl import value_t, value_str
from _devbuild.gen.runtime_asdl import value_t
from _devbuild.gen.syntax_asdl import (loc, ArgList, BlockArg, command_t,
expr_e, expr_t, CommandSub)
from core import error
from core.error import e_usage
from mycpp.mylib import tagswitch
from mycpp.mylib import dict_erase, tagswitch
from ysh import val_ops

from typing import Optional, Dict, List, cast


class Spec(object):
"""Utility to express argument specifications (runtime typechecking)."""

def __init__(self, pos_args, named_args):
# type: (List[int], Dict[str, int]) -> None
"""Empty constructor for mycpp."""
self.pos_args = pos_args
self.named_args = named_args

def AssertArgs(self, func_name, pos_args, named_args):
# type: (str, List[value_t], Dict[str, value_t]) -> None
"""Assert any type differences between the spec and the given args."""
nargs = len(pos_args)
expected = len(self.pos_args)
if nargs != expected:
raise error.InvalidType(
"%s() expects %d arguments but %d were given" %
(func_name, expected, nargs), loc.Missing)

nargs = len(named_args)
expected = len(self.named_args)
if len(named_args) != 0:
raise error.InvalidType(
"%s() expects %d named arguments but %d were given" %
(func_name, expected, nargs), loc.Missing)

for i in xrange(len(pos_args)):
expected = self.pos_args[i]
got = pos_args[i]
if got.tag() != expected:
msg = "%s() expected %s" % (func_name, value_str(expected))
raise error.InvalidType2(got, msg, loc.Missing)

for name in named_args:
expected = self.named_args[name]
got = named_args[name]
if got.tag() != expected:
msg = "%s() expected %s" % (func_name, value_str(expected))
raise error.InvalidType2(got, msg, loc.Missing)


class Reader(object):
"""
func f(a Str) {
Expand Down Expand Up @@ -99,6 +59,7 @@ class Reader(object):
def __init__(self, pos_args, named_args):
# type: (List[value_t], Dict[str, value_t]) -> None
self.pos_args = pos_args
self.pos_consumed = 0
self.named_args = named_args

### Words: untyped args for procs
Expand All @@ -113,32 +74,104 @@ def RestWords(self):

### Typed positional args

# TODO: may need location info
def _GetNextPos(self):
# type: () -> value_t
if len(self.pos_args) == 0:
# TODO: may need location info
raise error.InvalidType(
'Expected at least %d arguments, but only got %d' %
(self.pos_consumed + 1, self.pos_consumed), loc.Missing)

self.pos_consumed += 1
return self.pos_args.pop(0)

def PosStr(self):
# type: () -> str
return None # TODO
arg = self._GetNextPos()
return val_ops.MustBeStr(arg).s

def PosInt(self):
# type: () -> int
return -1 # TODO
arg = self._GetNextPos()
return val_ops.MustBeInt(arg).i

def PosFloat(self):
# type: () -> float
arg = self._GetNextPos()
return val_ops.MustBeFloat(arg).f

def PosList(self):
# type: () -> List[value_t]
arg = self._GetNextPos()
return val_ops.MustBeList(arg).items

def PosDict(self):
# type: () -> Dict[str, value_t]
arg = self._GetNextPos()
return val_ops.MustBeDict(arg).d

def PosValue(self):
# type: () -> value_t
return self._GetNextPos()

def RestPos(self):
# type: () -> List[value_t]
return None # TODO
ret = self.pos_args
self.pos_args = []
return ret

### Typed named args

def NamedStr(self, param_name, default_):
# type: (str, str) -> str
return None # TODO
if param_name not in self.named_args:
return default_

ret = val_ops.MustBeStr(self.named_args[param_name]).s
dict_erase(self.named_args, param_name)
return ret

def NamedInt(self, param_name, default_):
# type: (str, int) -> int
return -1 # TODO
if param_name not in self.named_args:
return default_

ret = val_ops.MustBeInt(self.named_args[param_name]).i
dict_erase(self.named_args, param_name)
return ret

def NamedFloat(self, param_name, default_):
# type: (str, float) -> float
if param_name not in self.named_args:
return default_

ret = val_ops.MustBeFloat(self.named_args[param_name]).f
dict_erase(self.named_args, param_name)
return ret

def NamedList(self, param_name, default_):
# type: (str, List[value_t]) -> List[value_t]
if param_name not in self.named_args:
return default_

ret = val_ops.MustBeList(self.named_args[param_name]).items
dict_erase(self.named_args, param_name)
return ret

def NamedDict(self, param_name, default_):
# type: (str, Dict[str, value_t]) -> Dict[str, value_t]
if param_name not in self.named_args:
return default_

ret = val_ops.MustBeDict(self.named_args[param_name]).d
dict_erase(self.named_args, param_name)
return ret

def RestNamed(self):
# type: () -> Dict[str, value_t]
return None # TODO
ret = self.named_args
self.named_args = {}
return ret

def Block(self):
# type: () -> command_t
Expand All @@ -159,7 +192,14 @@ def Done(self):
problem
"""
# Note: Python throws TypeError on mismatch
pass
if len(self.pos_args):
raise error.InvalidType('Expected %d arguments, but got %d' %
(self.pos_consumed, self.pos_consumed +
len(self.pos_args)), loc.Missing)

if len(self.named_args):
bad_args = ','.join(self.named_args.keys())
raise error.InvalidType('Got unexpected named args: %s' % bad_args, loc.Missing)


def DoesNotAccept(arg_list):
Expand Down
116 changes: 116 additions & 0 deletions frontend/typed_args_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
#!/usr/bin/env python2
"""
typed_args_test.py: Tests for typed_args.py
"""

import unittest

from _devbuild.gen.runtime_asdl import value
from core import error
from frontend import typed_args # module under test

from typing import cast


class TypedArgsTest(unittest.TestCase):
def testReaderPosArgs(self):
# Not enough args...
reader = typed_args.Reader([], {})
self.assertRaises(error.InvalidType, reader.PosStr)

pos_args = [
value.Int(0xc0ffee),
value.Str('foo'),
value.List([value.Int(1), value.Int(2), value.Int(3)]),
value.Dict({'a': value.Int(0xaa), 'b': value.Int(0xbb)}),
value.Float(3.14),
value.Int(0xdead),
value.Int(0xbeef),
value.Str('bar'),
]
reader = typed_args.Reader(list(pos_args), {})

# Haven't processed any args yet...
self.assertRaises(error.InvalidType, reader.Done)

# Arg is wrong type...
self.assertRaises(error.InvalidType, reader.PosStr)

# Normal operation from here on
reader = typed_args.Reader(pos_args, {})
arg = reader.PosInt()
self.assertEqual(0xc0ffee, arg)

arg = reader.PosStr()
self.assertEqual('foo', arg)

arg = reader.PosList()
self.assertEqual(1, cast(value.Int, arg[0]).i)
self.assertEqual(2, cast(value.Int, arg[1]).i)
self.assertEqual(3, cast(value.Int, arg[2]).i)

arg = reader.PosDict()
self.assertIn('a', arg)
self.assertEqual(0xaa, arg['a'].i)
self.assertIn('b', arg)
self.assertEqual(0xbb, arg['b'].i)

arg = reader.PosFloat()
self.assertEqual(3.14, arg)

rest = reader.RestPos()
self.assertEqual(3, len(rest))
self.assertEqual(0xdead, rest[0].i)
self.assertEqual(0xbeef, rest[1].i)
self.assertEqual('bar', rest[2].s)

reader.Done()

def testReaderKwargs(self):
kwargs = {
'hot': value.Int(0xc0ffee),
'name': value.Str('foo'),
'numbers': value.List([value.Int(1), value.Int(2), value.Int(3)]),
'blah': value.Dict({'a': value.Int(0xaa), 'b': value.Int(0xbb)}),
'pi': value.Float(3.14),
'a': value.Int(0xdead),
'b': value.Int(0xbeef),
'c': value.Str('bar'),
}
reader = typed_args.Reader([], kwargs)

# Haven't processed any args yet...
self.assertRaises(error.InvalidType, reader.Done)

arg = reader.NamedInt('hot', -1)
self.assertEqual(0xc0ffee, arg)

arg = reader.NamedStr('name', '')
self.assertEqual('foo', arg)

arg = reader.NamedList('numbers', [])
self.assertEqual(1, cast(value.Int, arg[0]).i)
self.assertEqual(2, cast(value.Int, arg[1]).i)
self.assertEqual(3, cast(value.Int, arg[2]).i)

arg = reader.NamedDict('blah', {})
self.assertIn('a', arg)
self.assertEqual(0xaa, arg['a'].i)
self.assertIn('b', arg)
self.assertEqual(0xbb, arg['b'].i)

arg = reader.NamedFloat('pi', -1.0)
self.assertEqual(3.14, arg)

rest = reader.RestNamed()
self.assertEqual(3, len(rest))
self.assertIn('a', rest)
self.assertEqual(0xdead, rest['a'].i)
self.assertIn('b', rest)
self.assertEqual(0xbeef, rest['b'].i)
self.assertIn('c', rest)
self.assertEqual('bar', rest['c'].s)


if __name__ == '__main__':
unittest.main()
7 changes: 3 additions & 4 deletions library/func_hay.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,9 @@ def _Call(self, path):
def Call(self, pos_args, named_args):
# type: (List[value_t], Dict[str, value_t]) -> value_t

spec = typed_args.Spec([value_e.Str], {})
spec.AssertArgs("parseHay", pos_args, named_args)

string = cast(value.Str, pos_args[0]).s
arg_reader = typed_args.Reader(pos_args, named_args)
string = arg_reader.PosStr()
arg_reader.Done()
return self._Call(string)


Expand Down

0 comments on commit fdb6aaf

Please sign in to comment.