Skip to content

Commit

Permalink
Change to n+1 / positional default parameters.
Browse files Browse the repository at this point in the history
  • Loading branch information
jacqueswww committed Aug 13, 2018
1 parent 65afb40 commit f2901ee
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 24 deletions.
14 changes: 3 additions & 11 deletions tests/parser/functions/test_default_parameters.py
Expand Up @@ -9,12 +9,11 @@ def safeTransferFrom(_data: bytes[100] = "test", _b: int128 = 1):
"""
abi = get_contract(code)._classic_contract.abi

assert len(abi) == 4
assert len(abi) == 3
assert set([fdef['name'] for fdef in abi]) == {'safeTransferFrom'}
assert abi[0]['inputs'] == []
assert abi[1]['inputs'] == [{'type': 'int128', 'name': '_b'}]
assert abi[2]['inputs'] == [{'type': 'bytes', 'name': '_data'}]
assert abi[3]['inputs'] == [{'type': 'bytes', 'name': '_data'}, {'type': 'int128', 'name': '_b'}]
assert abi[1]['inputs'] == [{'type': 'bytes', 'name': '_data'}]
assert abi[2]['inputs'] == [{'type': 'bytes', 'name': '_data'}, {'type': 'int128', 'name': '_b'}]


def test_basic_default_param_passthrough(get_contract):
Expand All @@ -27,7 +26,6 @@ def fooBar(_data: bytes[100] = "test", _b: int128 = 1) -> int128:
c = get_contract(code)

assert c.fooBar() == 12321
assert c.fooBar(2) == 12321
assert c.fooBar(b"drum drum") == 12321
assert c.fooBar(b"drum drum", 2) == 12321

Expand Down Expand Up @@ -60,8 +58,6 @@ def fooBar(a:int128, b: uint256 = 999, c: address = 0x00000000000000000000000000
assert c.fooBar(123) == [123, b_default_value, c_default_value]
# c default_value, b set from param
assert c.fooBar(456, 444) == [456, 444, c_default_value]
# b default set, c set from param
assert c.fooBar(555, addr2) == [555, b_default_value, addr2]
# no default values
assert c.fooBar(6789, 4567, addr2) == [6789, 4567, addr2]

Expand All @@ -78,8 +74,6 @@ def fooBar(a: bytes[100], b: int128, c: bytes[100] = "testing", d: uint256 = 999

# c set, 7d default value
assert c.fooBar(b"booo", 12321, b'woo') == [b"booo", 12321, b'woo', d_default]
# d set, c default value
assert c.fooBar(b"booo", 12321, 888) == [b"booo", 12321, c_default, 888]
# d set, c set
assert c.fooBar(b"booo", 12321, b"lucky", 777) == [b"booo", 12321, b"lucky", 777]
# no default values
Expand All @@ -98,8 +92,6 @@ def fooBar(a: bytes[100], b: uint256[2], c: bytes[6] = "hello", d: int128[3] = [

# c set, d default value
assert c.fooBar(b"booo", [99, 88], b'woo') == [b"booo", 88, b'woo', d_default]
# d set, c default value
assert c.fooBar(b"booo", [99, 88], [34, 35, 36]) == [b"booo", 88, c_default, 36]
# d set, c set
assert c.fooBar(b"booo", [22, 11], b"lucky", [24, 25, 26]) == [b"booo", 11, b"lucky", 26]
# no default values
Expand Down
26 changes: 13 additions & 13 deletions vyper/parser/parser.py
Expand Up @@ -2,12 +2,8 @@
import copy
import tokenize
import io
import itertools
import re

from collections import (
Counter
)
from tokenize import (
OP,
NAME,
Expand Down Expand Up @@ -381,11 +377,17 @@ def generate_default_arg_sigs(code, _contracts, _custom_units):
return [FunctionSignature.from_definition(code, sigs=_contracts, custom_units=_custom_units)]
base_args = code.args.args[:-total_default_args]
default_args = code.args.args[-total_default_args:]
truth_table = list(itertools.product([False, True], repeat=total_default_args))

# Generate a list of default function combinations.
row = [False] * (total_default_args)
table = [row.copy()]
for i in range(total_default_args):
row[i] = True
table.append(row.copy())

default_sig_strs = []
sig_fun_defs = []
for truth_row in truth_table:
for truth_row in table:
new_code = copy.deepcopy(code)
new_code.args.args = copy.deepcopy(base_args)
new_code.args.default = []
Expand All @@ -397,10 +399,6 @@ def generate_default_arg_sigs(code, _contracts, _custom_units):
default_sig_strs.append(sig.sig)
sig_fun_defs.append(sig)

if len(default_sig_strs) != len(set(default_sig_strs)):
violations = [item for item, count in Counter(default_sig_strs).items() if count > 1]
raise FunctionDeclarationException('Default variables are causing a conflict: {}'.format(','.join(violations)), code)

return sig_fun_defs


Expand Down Expand Up @@ -631,7 +629,6 @@ def parse_func(code, _globals, sigs, origcode, _custom_units, _vars=None):
# Variables to be populated from calldata
copier_arg_count = len(default_sig.args) - len(base_args)
default_copiers = []

if copier_arg_count > 0:
current_sig_arg_names = {x.name for x in default_sig.args}
base_arg_names = {arg.name for arg in base_args}
Expand All @@ -647,12 +644,13 @@ def parse_func(code, _globals, sigs, origcode, _custom_units, _vars=None):
for arg_name in copier_arg_names:
var = context.vars[arg_name]
calldata_offset = calldata_offset_map[arg_name]
# Add clampers.
default_copiers.append(make_clamper(calldata_offset - 4, var.pos, var.typ))
# Add copying code.
if isinstance(var.typ, ByteArrayType):
default_copiers.append(['calldatacopy', var.pos, ['add', 4, ['calldataload', calldata_offset]], var.size * 32])
else:
default_copiers.append(['calldatacopy', var.pos, calldata_offset, var.size * 32])
# Add clampers.
default_copiers.append(make_clamper(calldata_offset - 4, var.pos, var.typ))

sig_chain.append([
'if', ['eq', ['mload', 0], method_id_node],
Expand All @@ -669,7 +667,9 @@ def parse_func(code, _globals, sigs, origcode, _custom_units, _vars=None):
['seq',
['label', function_routine],
['seq'] + clampers + [parse_body(c, context) for c in code.body] + ['stop']]]], typ=None, pos=getpos(code))

else:
# Function without default parameters.
method_id_node = LLLnode.from_list(sig.method_id, pos=getpos(code), annotation='%s' % sig.sig)
o = LLLnode.from_list(
['if',
Expand Down

0 comments on commit f2901ee

Please sign in to comment.