Skip to content

Commit

Permalink
Fix min and max
Browse files Browse the repository at this point in the history
Resolves   #834.
  • Loading branch information
evhub committed Mar 11, 2024
1 parent fcb2c2e commit 2566839
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 34 deletions.
2 changes: 2 additions & 0 deletions DOCS.md
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ To make Coconut built-ins universal across Python versions, Coconut makes availa
- `py_xrange`
- `py_repr`
- `py_breakpoint`
- `py_min`
- `py_max`

_Note: Coconut's `repr` can be somewhat tricky, as it will attempt to remove the `u` before reprs of unicode strings on Python 2, but will not always be able to do so if the unicode string is nested._

Expand Down
4 changes: 4 additions & 0 deletions __coconut__/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ py_reversed = reversed
py_enumerate = enumerate
py_repr = repr
py_breakpoint = breakpoint
py_min = min
py_max = max

# all py_ functions, but not py_ types, go here
chr = _builtins.chr
Expand All @@ -189,6 +191,8 @@ zip = _builtins.zip
filter = _builtins.filter
reversed = _builtins.reversed
enumerate = _builtins.enumerate
min = _builtins.min
max = _builtins.max


_coconut_py_str = py_str
Expand Down
6 changes: 3 additions & 3 deletions coconut/compiler/templates/header.py_template
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ class map(_coconut_baseclass, _coconut.map):
def __len__(self):
if not _coconut.all(_coconut.isinstance(it, _coconut.abc.Sized) for it in self.iters):
return _coconut.NotImplemented
return _coconut.min(_coconut.len(it) for it in self.iters)
return _coconut.min((_coconut.len(it) for it in self.iters), default=0)
def __repr__(self):
return "%s(%r, %s)" % (self.__class__.__name__, self.func, ", ".join((_coconut.repr(it) for it in self.iters)))
def __reduce__(self):
Expand Down Expand Up @@ -985,7 +985,7 @@ class zip(_coconut_baseclass, _coconut.zip):
def __len__(self):
if not _coconut.all(_coconut.isinstance(it, _coconut.abc.Sized) for it in self.iters):
return _coconut.NotImplemented
return _coconut.min(_coconut.len(it) for it in self.iters)
return _coconut.min((_coconut.len(it) for it in self.iters), default=0)
def __repr__(self):
return "zip(%s%s)" % (", ".join((_coconut.repr(it) for it in self.iters)), ", strict=True" if self.strict else "")
def __reduce__(self):
Expand Down Expand Up @@ -1036,7 +1036,7 @@ class zip_longest(zip):
def __len__(self):
if not _coconut.all(_coconut.isinstance(it, _coconut.abc.Sized) for it in self.iters):
return _coconut.NotImplemented
return _coconut.max(_coconut.len(it) for it in self.iters)
return _coconut.max((_coconut.len(it) for it in self.iters), default=0)
def __repr__(self):
return "zip_longest(%s, fillvalue=%s)" % (", ".join((_coconut.repr(it) for it in self.iters)), _coconut.repr(self.fillvalue))
def __reduce__(self):
Expand Down
2 changes: 2 additions & 0 deletions coconut/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,8 @@ def get_path_env_var(env_var, default):
"py_xrange",
"py_repr",
"py_breakpoint",
"py_min",
"py_max",
"_namedtuple_of",
"reveal_type",
"reveal_locals",
Expand Down
84 changes: 53 additions & 31 deletions coconut/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
VERSION = "3.1.0"
VERSION_NAME = None
# False for release, int >= 1 for develop
DEVELOP = 2
DEVELOP = 3
ALPHA = False # for pre releases rather than post releases

assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1"
Expand Down Expand Up @@ -61,16 +61,16 @@ def _get_target_info(target):

# if a new assignment is added below, a new builtins import should be added alongside it
_base_py3_header = r'''from builtins import chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, repr
py_bytes, py_chr, py_dict, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_repr = bytes, chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, repr
_coconut_py_str, _coconut_py_super, _coconut_py_dict = str, super, dict
py_bytes, py_chr, py_dict, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_repr, py_min, py_max = bytes, chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, repr, min, max
_coconut_py_str, _coconut_py_super, _coconut_py_dict, _coconut_py_min, _coconut_py_max = str, super, dict, min, max
from functools import wraps as _coconut_wraps
exec("_coconut_exec = exec")
'''

# if a new assignment is added below, a new builtins import should be added alongside it
_base_py2_header = r'''from __builtin__ import chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, raw_input, xrange, repr, long
py_bytes, py_chr, py_dict, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_raw_input, py_xrange, py_repr = bytes, chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, raw_input, xrange, repr
_coconut_py_raw_input, _coconut_py_xrange, _coconut_py_int, _coconut_py_long, _coconut_py_print, _coconut_py_str, _coconut_py_super, _coconut_py_unicode, _coconut_py_repr, _coconut_py_dict, _coconut_py_bytes = raw_input, xrange, int, long, print, str, super, unicode, repr, dict, bytes
py_bytes, py_chr, py_dict, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_raw_input, py_xrange, py_repr, py_min, py_max = bytes, chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, raw_input, xrange, repr, min, max
_coconut_py_raw_input, _coconut_py_xrange, _coconut_py_int, _coconut_py_long, _coconut_py_print, _coconut_py_str, _coconut_py_super, _coconut_py_unicode, _coconut_py_repr, _coconut_py_dict, _coconut_py_bytes, _coconut_py_min, _coconut_py_max = raw_input, xrange, int, long, print, str, super, unicode, repr, dict, bytes, min, max
from functools import wraps as _coconut_wraps
from collections import Sequence as _coconut_Sequence
from future_builtins import *
Expand Down Expand Up @@ -278,26 +278,26 @@ def __call__(self, obj):
_coconut_operator.methodcaller = _coconut_methodcaller
'''

_non_py37_extras = r'''def _coconut_default_breakpointhook(*args, **kwargs):
hookname = _coconut.os.getenv("PYTHONBREAKPOINT")
if hookname != "0":
if not hookname:
hookname = "pdb.set_trace"
modname, dot, funcname = hookname.rpartition(".")
if not dot:
modname = "builtins" if _coconut_sys.version_info >= (3,) else "__builtin__"
if _coconut_sys.version_info >= (2, 7):
import importlib
module = importlib.import_module(modname)
_below_py34_extras = '''def min(*args, **kwargs):
if len(args) == 1 and "default" in kwargs:
obj = tuple(args[0])
default = kwargs.pop("default")
if len(obj):
return _coconut_py_min(obj, **kwargs)
else:
import imp
module = imp.load_module(modname, *imp.find_module(modname))
hook = _coconut.getattr(module, funcname)
return hook(*args, **kwargs)
if not hasattr(_coconut_sys, "__breakpointhook__"):
_coconut_sys.__breakpointhook__ = _coconut_default_breakpointhook
def breakpoint(*args, **kwargs):
return _coconut.getattr(_coconut_sys, "breakpointhook", _coconut_default_breakpointhook)(*args, **kwargs)
return default
else:
return _coconut_py_min(*args, **kwargs)
def max(*args, **kwargs):
if len(args) == 1 and "default" in kwargs:
obj = tuple(args[0])
default = kwargs.pop("default")
if len(obj):
return _coconut_py_max(obj, **kwargs)
else:
return default
else:
return _coconut_py_max(*args, **kwargs)
'''

_finish_dict_def = '''
Expand All @@ -321,6 +321,26 @@ def __subclasscheck__(cls, subcls):
'''

_below_py37_extras = '''from collections import OrderedDict as _coconut_OrderedDict
def _coconut_default_breakpointhook(*args, **kwargs):
hookname = _coconut.os.getenv("PYTHONBREAKPOINT")
if hookname != "0":
if not hookname:
hookname = "pdb.set_trace"
modname, dot, funcname = hookname.rpartition(".")
if not dot:
modname = "builtins" if _coconut_sys.version_info >= (3,) else "__builtin__"
if _coconut_sys.version_info >= (2, 7):
import importlib
module = importlib.import_module(modname)
else:
import imp
module = imp.load_module(modname, *imp.find_module(modname))
hook = _coconut.getattr(module, funcname)
return hook(*args, **kwargs)
if not hasattr(_coconut_sys, "__breakpointhook__"):
_coconut_sys.__breakpointhook__ = _coconut_default_breakpointhook
def breakpoint(*args, **kwargs):
return _coconut.getattr(_coconut_sys, "breakpointhook", _coconut_default_breakpointhook)(*args, **kwargs)
class _coconut_dict_base(_coconut_OrderedDict):
__slots__ = ()
__doc__ = getattr(_coconut_OrderedDict, "__doc__", "<see help(py_dict)>")
Expand Down Expand Up @@ -385,15 +405,17 @@ def _get_root_header(version="universal"):
header += r'''py_breakpoint = breakpoint
'''
elif version == "3":
header += r'''if _coconut_sys.version_info < (3, 7):
''' + _indent(_non_py37_extras) + r'''else:
header += r'''if _coconut_sys.version_info >= (3, 7):
py_breakpoint = breakpoint
'''
else:
assert version.startswith("2"), version
header += _non_py37_extras
if version == "2":
header += _py26_extras
elif version == "2":
header += _py26_extras

if version.startswith("2"):
header += _below_py34_extras
elif version_info < (3, 4):
header += r'''if _coconut_sys.version_info < (3, 4):
''' + _indent(_below_py34_extras)

if version == "3":
header += r'''if _coconut_sys.version_info < (3, 7):
Expand Down
3 changes: 3 additions & 0 deletions coconut/tests/src/cocotest/agnostic/primary_2.coco
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,9 @@ def primary_test_2() -> bool:
match def maybe_dup(x, y=x) = (x, y)
assert maybe_dup(1) == (1, 1) == maybe_dup(x=1)
assert maybe_dup(1, 2) == (1, 2) == maybe_dup(x=1, y=2)
assert min((), default=10) == 10 == max((), default=10)
assert py_min(3, 4) == 3 == py_max(2, 3)
assert len(zip()) == 0 == len(zip_longest()) # type: ignore

with process_map.multiple_sequential_calls(): # type: ignore
assert map((+), range(3), range(4)$[:-1], strict=True) |> list == [0, 2, 4] == process_map((+), range(3), range(4)$[:-1], strict=True) |> list # type: ignore
Expand Down

0 comments on commit 2566839

Please sign in to comment.