Skip to content

Commit

Permalink
Add type param constraint support
Browse files Browse the repository at this point in the history
Refs   #757.
  • Loading branch information
evhub committed May 29, 2023
1 parent 5dce827 commit c236bd0
Show file tree
Hide file tree
Showing 13 changed files with 244 additions and 233 deletions.
6 changes: 4 additions & 2 deletions DOCS.md
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ To distribute your code with checkable type annotations, you'll need to include
To explicitly annotate your code with types to be checked, Coconut supports:
* [Python 3 function type annotations](https://www.python.org/dev/peps/pep-0484/),
* [Python 3.6 variable type annotations](https://www.python.org/dev/peps/pep-0526/),
* [PEP 695 type parameter syntax](#type-parameter-syntax) for easily adding type parameters to classes, functions, [`data` types](#data), and type aliases,
* [Python 3.12 type parameter syntax](#type-parameter-syntax) for easily adding type parameters to classes, functions, [`data` types](#data), and type aliases,
* Coconut's own [enhanced type annotation syntax](#enhanced-type-annotation), and
* Coconut's [protocol intersection operator](#protocol-intersection).

Expand Down Expand Up @@ -2579,14 +2579,16 @@ _Can't be done without a long series of checks in place of the destructuring ass

### Type Parameter Syntax

Coconut fully supports [PEP 695](https://peps.python.org/pep-0695/) type parameter syntax (with the caveat that all type variables are invariant rather than inferred).
Coconut fully supports [Python 3.12 PEP 695](https://peps.python.org/pep-0695/) type parameter syntax on all Python versions.

That includes type parameters for classes, [`data` types](#data), and [all types of function definition](#function-definition). For different types of function definition, the type parameters always come in brackets right after the function name. Coconut's [enhanced type annotation syntax](#enhanced-type-annotation) is supported for all type parameter bounds.

_Warning: until `mypy` adds support for `infer_variance=True` in `TypeVar`, `TypeVar`s created this way will always be invariant._

Additionally, Coconut supports the alternative bounds syntax of `type NewType[T <: bound] = ...` rather than `type NewType[T: bound] = ...`, to make it more clear that it is an upper bound rather than a type. In `--strict` mode, `<:` is required over `:` for all type parameter bounds. _DEPRECATED: `<=` can also be used as an alternative to `<:`._

Note that the `<:` syntax should only be used for [type bounds](https://peps.python.org/pep-0695/#upper-bound-specification), not [type constraints](https://peps.python.org/pep-0695/#constrained-type-specification)—for type constraints, Coconut style prefers the vanilla Python `:` syntax, which helps to disambiguate between the two cases, as they are functionally different but otherwise hard to tell apart at a glance. This is enforced in `--strict` mode.

_Note that, by default, all type declarations are wrapped in strings to enable forward references and improve runtime performance. If you don't want that—e.g. because you want to use type annotations at runtime—simply pass the `--no-wrap-types` flag._

##### PEP 695 Docs
Expand Down
60 changes: 34 additions & 26 deletions __coconut__/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ License: Apache 2.0
Description: MyPy stub file for __coconut__.py.
"""

import sys
import typing as _t

# -----------------------------------------------------------------------------------------------------------------------
# TYPE VARS:
# -----------------------------------------------------------------------------------------------------------------------

import sys
import typing as _t

_Callable = _t.Callable[..., _t.Any]
_Iterable = _t.Iterable[_t.Any]
_Tuple = _t.Tuple[_t.Any, ...]
Expand Down Expand Up @@ -55,21 +55,14 @@ _P = _t.ParamSpec("_P")
class _SupportsIndex(_t.Protocol):
def __index__(self) -> int: ...


# -----------------------------------------------------------------------------------------------------------------------
# IMPORTS:
# -----------------------------------------------------------------------------------------------------------------------

if sys.version_info >= (3, 11):
from typing import dataclass_transform as _dataclass_transform
if sys.version_info >= (3,):
import builtins as _builtins
else:
try:
from typing_extensions import dataclass_transform as _dataclass_transform
except ImportError:
dataclass_transform = ...

import _coconut as __coconut # we mock _coconut as a package since mypy doesn't handle namespace classes very well
_coconut = __coconut
import __builtin__ as _builtins

if sys.version_info >= (3, 2):
from functools import lru_cache as _lru_cache
Expand All @@ -81,13 +74,24 @@ if sys.version_info >= (3, 7):
from dataclasses import dataclass as _dataclass
else:
@_dataclass_transform()
def _dataclass(cls: t_coype[_T], **kwargs: _t.Any) -> type[_T]: ...
def _dataclass(cls: type[_T], **kwargs: _t.Any) -> type[_T]: ...

if sys.version_info >= (3, 11):
from typing import dataclass_transform as _dataclass_transform
else:
try:
from typing_extensions import dataclass_transform as _dataclass_transform
except ImportError:
dataclass_transform = ...

try:
from typing_extensions import deprecated as _deprecated # type: ignore
except ImportError:
def _deprecated(message: _t.Text) -> _t.Callable[[_T], _T]: ... # type: ignore

import _coconut as __coconut # we mock _coconut as a package since mypy doesn't handle namespace classes very well
_coconut = __coconut


# -----------------------------------------------------------------------------------------------------------------------
# STUB:
Expand Down Expand Up @@ -153,18 +157,18 @@ py_repr = repr
py_breakpoint = breakpoint

# all py_ functions, but not py_ types, go here
chr = chr
hex = hex
input = input
map = map
oct = oct
open = open
print = print
range = range
zip = zip
filter = filter
reversed = reversed
enumerate = enumerate
chr = _builtins.chr
hex = _builtins.hex
input = _builtins.input
map = _builtins.map
oct = _builtins.oct
open = _builtins.open
print = _builtins.print
range = _builtins.range
zip = _builtins.zip
filter = _builtins.filter
reversed = _builtins.reversed
enumerate = _builtins.enumerate


_coconut_py_str = py_str
Expand Down Expand Up @@ -435,13 +439,17 @@ def recursive_iterator(func: _T_iter_func) -> _T_iter_func:
return func


# if sys.version_info >= (3, 12):
# from typing import override
# else:
try:
from typing_extensions import override as _override # type: ignore
override = _override
except ImportError:
def override(func: _Tfunc) -> _Tfunc:
return func


def _coconut_call_set_names(cls: object) -> None: ...


Expand Down
154 changes: 62 additions & 92 deletions _coconut/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ import multiprocessing as _multiprocessing
import pickle as _pickle
from multiprocessing import dummy as _multiprocessing_dummy

if sys.version_info >= (3,):
import builtins as _builtins
else:
import __builtin__ as _builtins

if sys.version_info >= (3,):
import copyreg as _copyreg
else:
Expand Down Expand Up @@ -68,41 +73,6 @@ else:
# -----------------------------------------------------------------------------------------------------------------------

typing = _t

from typing_extensions import TypeVar
typing.TypeVar = TypeVar # type: ignore

if sys.version_info < (3, 8):
try:
from typing_extensions import Protocol
except ImportError:
Protocol = ... # type: ignore
typing.Protocol = Protocol # type: ignore

if sys.version_info < (3, 10):
try:
from typing_extensions import TypeAlias, ParamSpec, Concatenate
except ImportError:
TypeAlias = ... # type: ignore
ParamSpec = ... # type: ignore
Concatenate = ... # type: ignore
typing.TypeAlias = TypeAlias # type: ignore
typing.ParamSpec = ParamSpec # type: ignore
typing.Concatenate = Concatenate # type: ignore

if sys.version_info < (3, 11):
try:
from typing_extensions import TypeVarTuple, Unpack
except ImportError:
TypeVarTuple = ... # type: ignore
Unpack = ... # type: ignore
typing.TypeVarTuple = TypeVarTuple # type: ignore
typing.Unpack = Unpack # type: ignore

# -----------------------------------------------------------------------------------------------------------------------
# STUB:
# -----------------------------------------------------------------------------------------------------------------------

collections = _collections
copy = _copy
functools = _functools
Expand Down Expand Up @@ -141,62 +111,62 @@ tee_type: _t.Any = ...
reiterables: _t.Any = ...
fmappables: _t.Any = ...

Ellipsis = Ellipsis
NotImplemented = NotImplemented
NotImplementedError = NotImplementedError
Exception = Exception
AttributeError = AttributeError
ImportError = ImportError
IndexError = IndexError
KeyError = KeyError
NameError = NameError
TypeError = TypeError
ValueError = ValueError
StopIteration = StopIteration
RuntimeError = RuntimeError
callable = callable
classmethod = classmethod
complex = complex
all = all
any = any
bool = bool
bytes = bytes
dict = dict
enumerate = enumerate
filter = filter
float = float
frozenset = frozenset
getattr = getattr
hasattr = hasattr
hash = hash
id = id
int = int
isinstance = isinstance
issubclass = issubclass
iter = iter
Ellipsis = _builtins.Ellipsis
NotImplemented = _builtins.NotImplemented
NotImplementedError = _builtins.NotImplementedError
Exception = _builtins.Exception
AttributeError = _builtins.AttributeError
ImportError = _builtins.ImportError
IndexError = _builtins.IndexError
KeyError = _builtins.KeyError
NameError = _builtins.NameError
TypeError = _builtins.TypeError
ValueError = _builtins.ValueError
StopIteration = _builtins.StopIteration
RuntimeError = _builtins.RuntimeError
callable = _builtins.callable
classmethod = _builtins.classmethod
complex = _builtins.complex
all = _builtins.all
any = _builtins.any
bool = _builtins.bool
bytes = _builtins.bytes
dict = _builtins.dict
enumerate = _builtins.enumerate
filter = _builtins.filter
float = _builtins.float
frozenset = _builtins.frozenset
getattr = _builtins.getattr
hasattr = _builtins.hasattr
hash = _builtins.hash
id = _builtins.id
int = _builtins.int
isinstance = _builtins.isinstance
issubclass = _builtins.issubclass
iter = _builtins.iter
len: _t.Callable[..., int] = ... # pattern-matching needs an untyped _coconut.len to avoid type errors
list = list
locals = locals
globals = globals
map = map
min = min
max = max
next = next
object = object
print = print
property = property
range = range
reversed = reversed
set = set
setattr = setattr
slice = slice
str = str
sum = sum
super = super
tuple = tuple
type = type
zip = zip
vars = vars
repr = repr
list = _builtins.list
locals = _builtins.locals
globals = _builtins.globals
map = _builtins.map
min = _builtins.min
max = _builtins.max
next = _builtins.next
object = _builtins.object
print = _builtins.print
property = _builtins.property
range = _builtins.range
reversed = _builtins.reversed
set = _builtins.set
setattr = _builtins.setattr
slice = _builtins.slice
str = _builtins.str
sum = _builtins.sum
super = _builtins.super
tuple = _builtins.tuple
type = _builtins.type
zip = _builtins.zip
vars = _builtins.vars
repr = _builtins.repr
if sys.version_info >= (3,):
bytearray = bytearray
bytearray = _builtins.bytearray
Loading

0 comments on commit c236bd0

Please sign in to comment.