Skip to content

Commit

Permalink
Implement first version of mypy plugin for chain and add (#201)
Browse files Browse the repository at this point in the history
* ✨ Implement first version of mypy plugin for chain

* ♻️ Refactor mypy plugin to work with original chain signature

* 🚚 Rename chain to be a single function + mypy plugin

* 🐛 Use correct types for mypy plugin, fix type checks

* 🔀 Merge develop

* ✅ Update tests to use singular chain function

* ♻️ Refactor mypy plugin to generalize to reducers (chain, add)

* ✨ Add reducer types to add, chain, to trigger mypy plugin

* ✅ Add mypy plugin tests

* 🚚 Rename chain module to chain_module and add to add_module

* 🔧 Add mypy tests init file

* 🎨 Format mypy test

* ✅ Update mypy tests to raise mypy coverage to 100%

* 🔊 Add Azure debugging logs

* ✅ Try to convince Azure to run mypy tests from the right directory

* 💚 Add mypy tests to dist package

* 💚 Ignore mypy for mypy tests (that are made to error out)

* 🔧 Make Thinc compatible with PEP 561, declares its own types

* ✅ Refactor mypy tests to run on tempdir and be compatible with package

* ✨ Add support for registry decorators with mypy plugin

And revert add with secondary function as it now seems to work with a different module name

* 🔧 Add thinc mypy plugin to default setup.cfg for internal development

Co-authored-by: Ines Montani <ines@ines.io>
Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>
  • Loading branch information
3 people committed Jan 15, 2020
1 parent 100ac83 commit 522aff0
Show file tree
Hide file tree
Showing 26 changed files with 322 additions and 67 deletions.
3 changes: 3 additions & 0 deletions MANIFEST.in
Expand Up @@ -2,3 +2,6 @@ recursive-include thinc *.cu *.pyx *.pxd
include LICENSE
include README.md
prune tmp/
include thinc/tests/mypy/configs/*.ini
include thinc/tests/mypy/outputs/*.txt
include thinc/py.typed
5 changes: 4 additions & 1 deletion setup.cfg
Expand Up @@ -126,4 +126,7 @@ exclude_lines =
[mypy]
ignore_missing_imports = True
no_implicit_optional = True
plugins = pydantic.mypy
plugins = pydantic.mypy, thinc.mypy

[mypy-thinc.tests.mypy.*]
ignore_errors = True
4 changes: 2 additions & 2 deletions thinc/layers/__init__.py
Expand Up @@ -23,9 +23,9 @@
from .tensorflowwrapper import TensorFlowWrapper

# Combinators
from .add import add
from .add_module import add
from .bidirectional import bidirectional
from .chain import chain
from .chain_module import chain
from .clone import clone
from .concatenate import concatenate
from .noop import noop
Expand Down
8 changes: 4 additions & 4 deletions thinc/layers/add.py → thinc/layers/add_module.py
Expand Up @@ -4,19 +4,19 @@
from ..config import registry
from ..types import Array
from ..util import get_width
from thinc.types import Reduced_OutT


InT = TypeVar("InT", bound=Array)


@registry.layers("add.v0")
def add(*layers: Model) -> Model[InT, InT]:
def add(layer1: Model[InT, InT], layer2: Model[InT, InT], *layers: Model) -> Model[InT, Reduced_OutT]:
"""Compose two or more models `f`, `g`, etc, such that their outputs are
added, i.e. `add(f, g)(x)` computes `f(x) + g(x)`.
"""
if len(layers) < 2: # we need variable arguments for the config
raise TypeError("The 'add' combinator needs at least 2 layers")
if layers and layers[0].name == "add":
layers = (layer1, layer2) + layers
if layers[0].name == "add":
layers[0].layers.extend(layers[1:])
return layers[0]
return Model(
Expand Down
69 changes: 16 additions & 53 deletions thinc/layers/chain.py → thinc/layers/chain_module.py
Expand Up @@ -3,23 +3,33 @@
from ..model import Model
from ..config import registry
from ..util import get_width
from ..types import Ragged, Padded
from ..types import Ragged, Padded, Reduced_OutT


InT = TypeVar("InT")
OutT = TypeVar("OutT")
Mid1T = TypeVar("Mid1T")
Mid2T = TypeVar("Mid2T")


# This implementation is named 'chains' because we have a type-shenanigans
# function 'chain' below.
# TODO: Unhack this when we can
# We currently have an issue with Pydantic when arguments have generic types.
# https://github.com/samuelcolvin/pydantic/issues/1158
# For now we work around the issue by applying the decorator to this blander
# version of the function.
@registry.layers("chain.v0")
def chains(*layers: Model) -> Model[InT, Any]:
def chain_no_types(*layer: Model) -> Model:
return chain(*layer)


def chain(
layer1: Model[InT, Mid1T], layer2: Model[Mid1T, OutT], *layers: Model
) -> Model[InT, Reduced_OutT]:
"""Compose two models `f` and `g` such that they become layers of a single
feed-forward model that computes `g(f(x))`.
Also supports chaining more than 2 layers.
"""
if len(layers) < 2: # we need variable arguments for the config
raise TypeError("The 'chain' combinator needs at least 2 layers")
layers = (layer1, layer2) + layers
model: Model[InT, Any] = Model(
">>".join(layer.name for layer in layers),
forward,
Expand Down Expand Up @@ -95,50 +105,3 @@ def init(model: Model, X: Optional[InT] = None, Y: Optional[OutT] = None) -> Non
layers_with_nO = [lyr for lyr in model.layers if lyr.has_dim("nO")]
if layers_with_nO:
model.set_dim("nO", layers_with_nO[-1].get_dim("nO"))


# Unfortunately mypy doesn't support type-level checking on the cardinality
# of variadic arguments: in other words, if you have an *args, you can't have
# a type-checked condition on len(args). But we *can* get sneaky:
# you can have a type-checked condition on *optional* args, and these *will*
# get read by mypy. Hence the trickery below.

Mid1T = TypeVar("Mid1T")
Mid2T = TypeVar("Mid2T")
Mid3T = TypeVar("Mid3T")
Mid4T = TypeVar("Mid4T")
Mid5T = TypeVar("Mid5T")
Mid6T = TypeVar("Mid6T")
Mid7T = TypeVar("Mid7T")
Mid8T = TypeVar("Mid8T")
Mid9T = TypeVar("Mid9T")


def chain(
l1: Model[InT, Mid1T],
l2: Model[Mid1T, Mid2T],
l3: Optional[Model[Mid2T, Mid3T]] = None,
l4: Optional[Model[Mid3T, Mid4T]] = None,
l5: Optional[Model[Mid4T, Mid5T]] = None,
l6: Optional[Model[Mid5T, Mid6T]] = None,
l7: Optional[Model[Mid6T, Mid7T]] = None,
l8: Optional[Model[Mid7T, Mid8T]] = None,
l9: Optional[Model[Mid8T, Mid9T]] = None,
*etc: Model
) -> Model[InT, Any]: # pragma: no cover
if l3 is None:
return chains(l1, l2)
elif l4 is None:
return chains(l1, l2, l3)
elif l5 is None:
return chains(l1, l2, l3, l4)
elif l6 is None:
return chains(l1, l2, l3, l4, l5)
elif l7 is None:
return chains(l1, l2, l3, l4, l5, l6)
elif l8 is None:
return chains(l1, l2, l3, l4, l5, l6, l7)
elif l9 is None:
return chains(l1, l2, l3, l4, l5, l6, l7, l8)
else:
return chains(l1, l2, l3, l4, l5, l6, l7, l8, l9, *etc)
2 changes: 1 addition & 1 deletion thinc/layers/clone.py
@@ -1,7 +1,7 @@
from typing import TypeVar, cast, List

from .noop import noop
from .chain import chain
from .chain_module import chain
from ..model import Model
from ..config import registry

Expand Down
2 changes: 1 addition & 1 deletion thinc/layers/maxout.py
Expand Up @@ -7,7 +7,7 @@
from ..util import get_width
from .dropout import Dropout
from .layernorm import LayerNorm
from .chain import chain
from .chain_module import chain


InT = Array2d
Expand Down
2 changes: 1 addition & 1 deletion thinc/layers/mish.py
Expand Up @@ -4,7 +4,7 @@
from ..initializers import xavier_uniform_init, zero_init
from ..config import registry
from ..types import Array2d
from .chain import chain
from .chain_module import chain
from .layernorm import LayerNorm
from .dropout import Dropout

Expand Down
2 changes: 1 addition & 1 deletion thinc/layers/relu.py
Expand Up @@ -4,7 +4,7 @@
from ..initializers import xavier_uniform_init, zero_init
from ..config import registry
from ..types import Array2d
from .chain import chain
from .chain_module import chain
from .layernorm import LayerNorm
from .dropout import Dropout

Expand Down
99 changes: 99 additions & 0 deletions thinc/mypy.py
@@ -0,0 +1,99 @@
from mypy.errorcodes import ErrorCode
from mypy.options import Options
from mypy.plugin import FunctionContext, Plugin, CheckerPluginInterface
from mypy.types import Instance, Type, CallableType, TypeVarType
from mypy.nodes import Expression, CallExpr, NameExpr, FuncDef, Decorator

# Fully qualified name of Thinc's Model class. The plugin compares it against
# the `type.fullname` of the layer argument types it inspects.
thinc_model_fullname = "thinc.model.Model"


def plugin(version: str):
    """Return the plugin class defined in this module.

    The `version` argument is accepted but not used: the same class is
    returned whatever its value.
    """
    # NOTE(review): presumably this is the entry point mypy looks up for
    # modules listed under `plugins =` in its config -- confirm in mypy docs.
    return ThincPlugin


class ThincPlugin(Plugin):
    """Plugin subclass that hands every function call to `function_hook`."""

    def __init__(self, options: Options) -> None:
        # No state of its own: the options go straight to the base class.
        super().__init__(options)

    def get_function_hook(self, fullname: str):
        """Return `function_hook`, whatever `fullname` is.

        The name is not inspected here; `function_hook` itself decides
        whether a given call is one it should handle.
        """
        return function_hook


def function_hook(ctx: FunctionContext) -> Type:
    """Work out the return type for a call, defaulting to mypy's own result.

    The call is offered to `get_reducers_type`. An `AssertionError` coming
    out of it is taken to mean "this call is not handled", in which case the
    type mypy already inferred (`ctx.default_return_type`) is returned.
    """
    try:
        reduced_type = get_reducers_type(ctx)
    except AssertionError:
        # Add more function callbacks here
        return ctx.default_return_type
    else:
        return reduced_type


def _as_model_instance(type_: Type) -> Instance:
    """Return `type_` narrowed to a `Model[In, Out]` instance type.

    Raises:
        AssertionError: if `type_` is not an instance of `thinc.model.Model`
            carrying exactly two type arguments.
    """
    if not isinstance(type_, Instance):
        raise AssertionError("Layer type is not an instance type")
    if type_.type.fullname != thinc_model_fullname or len(type_.args) != 2:
        raise AssertionError("Layer type is not a thinc Model[In, Out]")
    return type_


def get_reducers_type(ctx: FunctionContext) -> Type:
    """Infer `Model[In, Out]` for a call to a reducer such as `chain` or `add`.

    A call is treated as a reducer call when it has three formal arguments
    (`layer1`, `layer2`, `*layers`), goes through a plain name that refers to
    a function or decorated function, and that function's declared return
    type is a two-argument instance whose second argument is the type
    variable `thinc.types.Reduced_OutT`.

    Each consecutive pair of layers is passed to `reduce_2_layers`, which
    reports mismatches through `ctx.api`.

    Args:
        ctx: The function call context supplied by mypy.

    Returns:
        `ctx.default_return_type` when the return type variable is not
        `Reduced_OutT`; otherwise an instance of the default return type's
        class parametrized with the first layer's input type and the last
        layer's output type.

    Raises:
        AssertionError: if the call does not have the expected shape. The
            checks are explicit `raise` statements rather than `assert`, so
            they still run when Python is started with `-O`.
    """
    if len(ctx.args) != 3:
        raise AssertionError("Expected a (layer1, layer2, *layers) signature")
    call = ctx.context
    if not (isinstance(call, CallExpr) and isinstance(call.callee, NameExpr)):
        raise AssertionError("Not a call through a plain name")
    callee_node = call.callee.node
    if not isinstance(callee_node, (FuncDef, Decorator)):
        raise AssertionError("Callee is not a function definition")
    signature = callee_node.type
    if not isinstance(signature, CallableType):
        raise AssertionError("Callee has no callable type")
    declared_return = signature.ret_type
    if not (isinstance(declared_return, Instance) and len(declared_return.args) == 2):
        raise AssertionError("Declared return type is not a two-argument instance")
    out_type = declared_return.args[1]
    if not (isinstance(out_type, TypeVarType) and out_type.fullname):
        raise AssertionError("Declared output type is not a named type variable")
    if out_type.fullname != "thinc.types.Reduced_OutT":
        return ctx.default_return_type

    l1_args, l2_args, layers_args = ctx.args
    l1_types, l2_types, layers_types = ctx.arg_types
    # A call with too few arguments leaves these per-parameter lists empty;
    # bail out cleanly instead of failing on the [0] lookups below.
    if not (l1_args and l2_args and l1_types and l2_types):
        raise AssertionError("Call does not supply both required layers")
    l1_arg, l2_arg = l1_args[0], l2_args[0]
    l1_type = _as_model_instance(l1_types[0])
    l2_type = _as_model_instance(l2_types[0])
    default_return = ctx.default_return_type
    if not isinstance(default_return, Instance):
        raise AssertionError("Default return type is not an instance type")

    arg_in_type = l1_type.args[0]
    arg_out_type = l2_type.args[1]
    reduce_2_layers(
        l1_arg=l1_arg, l1_type=l1_type, l2_arg=l2_arg, l2_type=l2_type, api=ctx.api,
    )
    last_arg, last_type = l2_arg, l2_type
    for arg, type_ in zip(layers_args, layers_types):
        # Hold the variadic layers to the same standard as the first two, so
        # a non-Model argument (e.g. an unpacked `*rest` list) stops the
        # analysis here rather than being compared or indexed as a Model.
        next_type = _as_model_instance(type_)
        reduce_2_layers(
            l1_arg=last_arg,
            l1_type=last_type,
            l2_arg=arg,
            l2_type=next_type,
            api=ctx.api,
        )
        last_arg, last_type = arg, next_type
        arg_out_type = next_type.args[1]
    return Instance(default_return.type, [arg_in_type, arg_out_type])


def reduce_2_layers(
    *,
    l1_arg: Expression,
    l1_type: Instance,
    l2_arg: Expression,
    l2_type: Instance,
    api: CheckerPluginInterface
):
    """Report an error pair when two consecutive layers do not line up.

    The second type argument of `l1_type` is compared with the first type
    argument of `l2_type`. If they differ, two failures are reported through
    `api`: one attached to `l1_arg` and one attached to `l2_arg`. Nothing is
    reported when they are equal.
    """
    produced = l1_type.args[1]
    expected = l2_type.args[0]
    if produced != expected:
        # (message, expression the error is attached to, error code)
        reports = (
            (
                "Layer mismatch, output not compatible with next layer",
                l1_arg,
                error_layer_output,
            ),
            (
                "Layer mismatch, input not compatible with previous layer",
                l2_arg,
                error_layer_input,
            ),
        )
        for message, node, code in reports:
            api.fail(message, node, code=code)


# Error codes passed as `code=` to `api.fail` by `reduce_2_layers`: one for the
# layer whose input does not match, one for the layer whose output does not.
error_layer_input = ErrorCode("layer-mismatch-input", "Invalid layer input", "Thinc")
error_layer_output = ErrorCode("layer-mismatch-output", "Invalid layer output", "Thinc")
Empty file added thinc/py.typed
Empty file.
6 changes: 3 additions & 3 deletions thinc/tests/layers/test_combinators.py
@@ -1,8 +1,8 @@
import pytest
import numpy
from thinc.api import chain, clone, concatenate, noop, add
from thinc.api import clone, concatenate, noop, add
from thinc.api import Linear, Dropout, Model
from thinc.layers.chain import chains
from thinc.layers import chain


@pytest.fixture(params=[1, 2, 9])
Expand Down Expand Up @@ -104,7 +104,7 @@ def test_chain():
with pytest.raises(TypeError):
chain(Linear())
with pytest.raises(TypeError):
chains()
chain()


def test_concatenate_one(model1):
Expand Down
Empty file added thinc/tests/mypy/__init__.py
Empty file.
8 changes: 8 additions & 0 deletions thinc/tests/mypy/configs/mypy-default.ini
@@ -0,0 +1,8 @@
[mypy]
follow_imports = silent
strict_optional = True
warn_redundant_casts = True
warn_unused_ignores = True
# disallow_any_generics = True
check_untyped_defs = True
disallow_untyped_defs = True
10 changes: 10 additions & 0 deletions thinc/tests/mypy/configs/mypy-plugin.ini
@@ -0,0 +1,10 @@
[mypy]
plugins = thinc.mypy

follow_imports = silent
strict_optional = True
warn_redundant_casts = True
warn_unused_ignores = True
# disallow_any_generics = True
check_untyped_defs = True
disallow_untyped_defs = True
Empty file.
5 changes: 5 additions & 0 deletions thinc/tests/mypy/modules/fail_no_plugin.py
@@ -0,0 +1,5 @@
from thinc.api import chain, ReLu, MaxPool, Softmax, add

bad_model = chain(ReLu(10), MaxPool(), Softmax())

bad_model2 = add(ReLu(10), MaxPool(), Softmax())
13 changes: 13 additions & 0 deletions thinc/tests/mypy/modules/fail_plugin.py
@@ -0,0 +1,13 @@
from thinc.api import chain, ReLu, MaxPool, Softmax, add

bad_model = chain(ReLu(10), MaxPool(), Softmax())

bad_model2 = add(ReLu(10), MaxPool(), Softmax())

bad_model_only_plugin = chain(
ReLu(10), ReLu(10), ReLu(10), ReLu(10), MaxPool(), Softmax()
)

bad_model_only_plugin2 = add(
ReLu(10), ReLu(10), ReLu(10), ReLu(10), MaxPool(), Softmax()
)
13 changes: 13 additions & 0 deletions thinc/tests/mypy/modules/success_no_plugin.py
@@ -0,0 +1,13 @@
from thinc.api import chain, ReLu, MaxPool, Softmax, add

good_model = chain(ReLu(10), ReLu(10), Softmax())
reveal_type(good_model)

good_model2 = add(ReLu(10), ReLu(10), Softmax())
reveal_type(good_model2)

bad_model_undetected = chain(ReLu(10), ReLu(10), MaxPool(), Softmax())
reveal_type(bad_model_undetected)

bad_model_undetected2 = add(ReLu(10), ReLu(10), MaxPool(), Softmax())
reveal_type(bad_model_undetected2)
34 changes: 34 additions & 0 deletions thinc/tests/mypy/modules/success_plugin.py
@@ -0,0 +1,34 @@
from typing import Any, TypeVar

from thinc.api import chain, ReLu, MaxPool, Softmax, add, Model

good_model = chain(ReLu(10), ReLu(10), Softmax())
reveal_type(good_model)

good_model2 = add(ReLu(10), ReLu(10), Softmax())
reveal_type(good_model2)

bad_model_undetected = chain(ReLu(10), ReLu(10), ReLu(10), ReLu(10), Softmax())
reveal_type(bad_model_undetected)

bad_model_undetected2 = add(ReLu(10), ReLu(10), ReLu(10), ReLu(10), Softmax())
reveal_type(bad_model_undetected2)


def forward() -> None:
pass


OtherType = TypeVar("OtherType")


def other_function(
layer1: Model, layer2: Model, *layers: Model
) -> Model[Any, OtherType]:
return Model("some_model", forward)


non_combinator_model = other_function(
Model("x", forward), Model("y", forward), Model("z", forward)
)
reveal_type(non_combinator_model)
2 changes: 2 additions & 0 deletions thinc/tests/mypy/outputs/fail-no-plugin.txt
@@ -0,0 +1,2 @@
3: error: Cannot infer type argument 2 of "chain" [misc]
5: error: Cannot infer type argument 1 of "add" [misc]

0 comments on commit 522aff0

Please sign in to comment.