Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement first version of mypy plugin for chain and add (#201)
* ✨ Implement first version of mypy plugin for chain * ♻️ Refactor mypy plugin to work with original chain signature * 🚚 Rename chain to be a single function + mypy plugin * 🐛 Use correct types for mypy plugin, fix type checks * 🔀 Merge develop * ✅ Update tests to use singular chain function * ♻️ Refactor mypy plugin to generalize to reducers (chain, add) * ✨ Add reducer types to add, chain, to trigger mypy plugin * ✅ Add mypy plugin tests * 🚚 Rename chain module to chain_module and add to add_module * 🔧 Add mypy tests init file * 🎨 Format mypy test * ✅ Update mypy tests to raise mypy coverage to 100% * 🔊 Add Azure debugging logs * ✅ Try to convince Azure to run mypy tests from the right directory * 💚 Add mypy tests to dist package * 💚 Ignore mypy for mypy tests (that are made to error out) * 🔧 Make Thinc compatible with PEP 561, declares its own types * ✅ Refactor mypy tests to run on tempdir and be compatible with package * ✨ Add support for registry decorators with mypy plugin And revert add with secondary function as it now seems to work with a different module name * 🔧 Add thinc mypy plugin to default setup.cfg for internal development Co-authored-by: Ines Montani <ines@ines.io> Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>
- Loading branch information
1 parent
100ac83
commit 522aff0
Showing
26 changed files
with
322 additions
and
67 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
from mypy.errorcodes import ErrorCode | ||
from mypy.options import Options | ||
from mypy.plugin import FunctionContext, Plugin, CheckerPluginInterface | ||
from mypy.types import Instance, Type, CallableType, TypeVarType | ||
from mypy.nodes import Expression, CallExpr, NameExpr, FuncDef, Decorator | ||
|
||
# Fully qualified name that a layer argument's type must have for the plugin
# to treat it as a Thinc model.
thinc_model_fullname = "thinc.model.Model"
|
||
|
||
def plugin(version: str):
    """Plugin entry point: return the plugin class to instantiate.

    Args:
        version: Version string passed in by the caller. It is not inspected;
            the same plugin class is returned for every version.

    Returns:
        The ``ThincPlugin`` class itself (not an instance).
    """
    return ThincPlugin
|
||
|
||
class ThincPlugin(Plugin):
    """mypy plugin that routes function calls through ``function_hook``."""

    def __init__(self, options: Options) -> None:
        """Initialise the base ``Plugin`` with the given mypy options."""
        super().__init__(options)

    def get_function_hook(self, fullname: str):
        """Return the hook to run for the function named ``fullname``.

        The same ``function_hook`` is returned regardless of ``fullname``;
        deciding whether a call is actually relevant is left to the hook.
        """
        return function_hook
|
||
|
||
def function_hook(ctx: FunctionContext) -> Type:
    """Compute the return type for a function call seen by the plugin.

    Tries the reducer-specific inference first. ``get_reducers_type`` signals
    "this call is not something I handle" by raising ``AssertionError``; in
    that case the return type mypy inferred on its own is used unchanged.

    Args:
        ctx: The function call context supplied by mypy.

    Returns:
        The refined return type, or ``ctx.default_return_type`` as a fallback.
    """
    try:
        inferred = get_reducers_type(ctx)
    except AssertionError:
        # Add more function callbacks here
        inferred = ctx.default_return_type
    return inferred
|
||
|
||
def get_reducers_type(ctx: FunctionContext) -> Type:
    """Type-check a reducer call and infer its ``Model[In, Out]`` return type.

    A "reducer" is recognised structurally: the callee takes exactly three
    formal parameters (two layers plus ``*layers``) and is declared to return
    a two-argument generic instance whose second argument is the type variable
    ``thinc.types.Reduced_OutT``. For such a call, every pair of neighbouring
    layers is passed to ``reduce_2_layers`` (which reports mismatches), and
    the result type is built from the first layer's input type and the last
    layer's output type.

    Args:
        ctx: The function call context supplied by mypy.

    Returns:
        ``ctx.default_return_type`` when the callee returns a two-argument
        instance whose type variable is not ``Reduced_OutT``; otherwise a new
        ``Instance`` of the default return type's class parameterised with
        ``[first layer input, last layer output]``.

    Raises:
        AssertionError: The call does not have the expected shape (wrong
            number of parameters, callee is not a plain name bound to a
            function, a layer argument is missing, or a layer's type is not
            a ``thinc.model.Model`` instance with two type arguments).
            ``function_hook`` catches this and falls back to the default.
    """
    # --- Is the callee shaped like a reducer at all? ---
    assert len(ctx.args) == 3
    call = ctx.context
    assert isinstance(call, CallExpr)
    callee = call.callee
    assert isinstance(callee, NameExpr)
    definition = callee.node
    assert isinstance(definition, (FuncDef, Decorator))
    signature = definition.type
    assert isinstance(signature, CallableType)
    declared_ret = signature.ret_type
    assert isinstance(declared_ret, Instance)
    assert len(declared_ret.args) == 2
    out_type = declared_ret.args[1]
    assert isinstance(out_type, TypeVarType)
    assert out_type.fullname
    if out_type.fullname != "thinc.types.Reduced_OutT":
        return ctx.default_return_type

    # --- Collect and validate the two mandatory layers. ---
    l1_args, l2_args, layers_args = ctx.args
    l1_types, l2_types, layers_types = ctx.arg_types
    # A formal parameter can have no actual argument mapped to it (e.g. a
    # malformed call); bail out via AssertionError instead of an IndexError.
    assert l1_args and l1_types
    assert l2_args and l2_types
    l1_arg, l2_arg = l1_args[0], l2_args[0]
    l1_type_instance, l2_type_instance = l1_types[0], l2_types[0]
    assert isinstance(l1_type_instance, Instance)
    assert isinstance(l2_type_instance, Instance)
    assert isinstance(ctx.default_return_type, Instance)
    assert l1_type_instance.type.fullname == thinc_model_fullname
    assert l2_type_instance.type.fullname == thinc_model_fullname
    assert len(l1_type_instance.args) == 2
    assert len(l2_type_instance.args) == 2

    arg_in_type = l1_type_instance.args[0]
    arg_out_type = l2_type_instance.args[1]
    reduce_2_layers(
        l1_arg=l1_arg,
        l1_type=l1_type_instance,
        l2_arg=l2_arg,
        l2_type=l2_type_instance,
        api=ctx.api,
    )

    # --- Walk the remaining ``*layers``, checking each against its predecessor. ---
    last_arg, last_type = l2_arg, l2_type_instance
    for arg, type_ in zip(layers_args, layers_types):
        # Hold extra layers to the same rules as the first two. Without these
        # checks a non-Model instance (for example the list type of a
        # ``*args`` expansion) would be indexed as if it were Model[In, Out],
        # raising an IndexError that function_hook does not catch.
        assert isinstance(type_, Instance)
        assert type_.type.fullname == thinc_model_fullname
        assert len(type_.args) == 2
        reduce_2_layers(
            l1_arg=last_arg, l1_type=last_type, l2_arg=arg, l2_type=type_, api=ctx.api,
        )
        last_arg, last_type = arg, type_
        arg_out_type = type_.args[1]
    return Instance(ctx.default_return_type.type, [arg_in_type, arg_out_type])
|
||
|
||
def reduce_2_layers(
    *,
    l1_arg: Expression,
    l1_type: Instance,
    l2_arg: Expression,
    l2_type: Instance,
    api: CheckerPluginInterface
) -> None:
    """Report an error pair if two adjacent layers do not line up.

    The second type argument of ``l1_type`` (the layer's output) is compared
    for equality with the first type argument of ``l2_type`` (the next
    layer's input). This is a plain ``!=`` comparison, not a subtype check.

    Args:
        l1_arg: Expression of the earlier layer; the output error points here.
        l1_type: Type of the earlier layer.
        l2_arg: Expression of the later layer; the input error points here.
        l2_type: Type of the later layer.
        api: Checker interface used to emit the errors.
    """
    produced = l1_type.args[1]
    expected = l2_type.args[0]
    mismatch = produced != expected
    if not mismatch:
        return
    # One message per side, so both offending arguments are flagged.
    api.fail(
        "Layer mismatch, output not compatible with next layer",
        l1_arg,
        code=error_layer_output,
    )
    api.fail(
        "Layer mismatch, input not compatible with previous layer",
        l2_arg,
        code=error_layer_input,
    )
|
||
|
||
# Error codes attached to the two layer-mismatch messages emitted by
# reduce_2_layers: one for the layer receiving the input, one for the layer
# producing the output.
error_layer_input = ErrorCode("layer-mismatch-input", "Invalid layer input", "Thinc")
error_layer_output = ErrorCode("layer-mismatch-output", "Invalid layer output", "Thinc")
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
[mypy] | ||
follow_imports = silent | ||
strict_optional = True | ||
warn_redundant_casts = True | ||
warn_unused_ignores = True | ||
# disallow_any_generics = True | ||
check_untyped_defs = True | ||
disallow_untyped_defs = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
[mypy] | ||
plugins = thinc.mypy | ||
|
||
follow_imports = silent | ||
strict_optional = True | ||
warn_redundant_casts = True | ||
warn_unused_ignores = True | ||
# disallow_any_generics = True | ||
check_untyped_defs = True | ||
disallow_untyped_defs = True |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from thinc.api import chain, ReLu, MaxPool, Softmax, add | ||
|
||
bad_model = chain(ReLu(10), MaxPool(), Softmax()) | ||
|
||
bad_model2 = add(ReLu(10), MaxPool(), Softmax()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from thinc.api import chain, ReLu, MaxPool, Softmax, add | ||
|
||
bad_model = chain(ReLu(10), MaxPool(), Softmax()) | ||
|
||
bad_model2 = add(ReLu(10), MaxPool(), Softmax()) | ||
|
||
bad_model_only_plugin = chain( | ||
ReLu(10), ReLu(10), ReLu(10), ReLu(10), MaxPool(), Softmax() | ||
) | ||
|
||
bad_model_only_plugin2 = add( | ||
ReLu(10), ReLu(10), ReLu(10), ReLu(10), MaxPool(), Softmax() | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from thinc.api import chain, ReLu, MaxPool, Softmax, add | ||
|
||
good_model = chain(ReLu(10), ReLu(10), Softmax()) | ||
reveal_type(good_model) | ||
|
||
good_model2 = add(ReLu(10), ReLu(10), Softmax()) | ||
reveal_type(good_model2) | ||
|
||
bad_model_undetected = chain(ReLu(10), ReLu(10), MaxPool(), Softmax()) | ||
reveal_type(bad_model_undetected) | ||
|
||
bad_model_undetected2 = add(ReLu(10), ReLu(10), MaxPool(), Softmax()) | ||
reveal_type(bad_model_undetected2) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from typing import Any, TypeVar | ||
|
||
from thinc.api import chain, ReLu, MaxPool, Softmax, add, Model | ||
|
||
good_model = chain(ReLu(10), ReLu(10), Softmax()) | ||
reveal_type(good_model) | ||
|
||
good_model2 = add(ReLu(10), ReLu(10), Softmax()) | ||
reveal_type(good_model2) | ||
|
||
bad_model_undetected = chain(ReLu(10), ReLu(10), ReLu(10), ReLu(10), Softmax()) | ||
reveal_type(bad_model_undetected) | ||
|
||
bad_model_undetected2 = add(ReLu(10), ReLu(10), ReLu(10), ReLu(10), Softmax()) | ||
reveal_type(bad_model_undetected2) | ||
|
||
|
||
def forward() -> None: | ||
pass | ||
|
||
|
||
OtherType = TypeVar("OtherType") | ||
|
||
|
||
def other_function( | ||
layer1: Model, layer2: Model, *layers: Model | ||
) -> Model[Any, OtherType]: | ||
return Model("some_model", forward) | ||
|
||
|
||
non_combinator_model = other_function( | ||
Model("x", forward), Model("y", forward), Model("z", forward) | ||
) | ||
reveal_type(non_combinator_model) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
3: error: Cannot infer type argument 2 of "chain" [misc] | ||
5: error: Cannot infer type argument 1 of "add" [misc] |
Oops, something went wrong.