Skip to content
Permalink
Browse files

Re-write Annotated.Class.(implements using attributes)

Summary: Instead of reading the AST directly, use the `attributes` function to check if a class implements a protocol.

Reviewed By: sinancepel

Differential Revision: D10362812

fbshipit-source-id: 686c6a220ecd0409768432064eeca1c5a2df4f6b
  • Loading branch information...
tekknolagi authored and facebook-github-bot committed Nov 13, 2018
1 parent b05234f commit f14577db5940c0b4087fffe209786cd4075f37df
@@ -156,20 +156,6 @@ module Method = struct
end
else
annotation


let implements
{ define; _ }
~protocol_method:{ define = protocol; _ } =
let open Define in
let parameter_equal
{ Node.value = { Parameter.annotation; _ }; _ }
{ Node.value = { Parameter.annotation = protocol_annotation; _ }; _ } =
Option.equal Expression.equal annotation protocol_annotation
in
Access.equal (Define.unqualified_name define) (Define.unqualified_name protocol) &&
Option.equal Expression.equal define.return_annotation protocol.return_annotation &&
List.equal ~equal:parameter_equal define.parameters protocol.parameters
end


@@ -352,23 +338,6 @@ let is_protocol { Node.value = { Class.bases; _ }; _ } =
List.exists ~f:is_protocol bases


let implements definition ~protocol =
let rec implements instance_methods protocol_methods =
match instance_methods, protocol_methods with
| _, [] ->
true
| [], _ :: _ ->
false
| instance_method :: instance_methods,
((protocol_method :: protocol_methods) as old_protocol_methods) ->
if Method.implements ~protocol_method instance_method then
implements instance_methods protocol_methods
else
implements instance_methods old_protocol_methods
in
implements (methods definition) (methods protocol)


module Attribute = struct
type attribute = {
name: Expression.expression;
@@ -714,6 +683,46 @@ let attributes
Hashtbl.set ~key ~data:result Attribute.Cache.cache;
result

let implements ~resolution definition ~protocol =
let overload_implements (name, overload) (protocol_name, protocol_overload) =
let open Type.Callable in
Access.equal name protocol_name &&
Type.equal overload.annotation protocol_overload.annotation &&
equal_parameters Type.equal overload.parameters protocol_overload.parameters
in
let rec implements instance_methods protocol_methods =
match instance_methods, protocol_methods with
| _, [] ->
true
| [], _ :: _ ->
false
| instance_method :: instance_methods,
((protocol_method :: protocol_methods) as old_protocol_methods) ->
if overload_implements instance_method protocol_method then
implements instance_methods protocol_methods
else
implements instance_methods old_protocol_methods
in
let callables_of_attribute =
function
| { Node.value =
{ Attribute.annotation = {
Annotation.annotation = Type.Callable {
kind = Type.Record.Callable.Named callable_name;
overloads;
_ };
_ };
parent;
_ }; _ } ->
let local_name = Access.drop_prefix ~prefix:(name parent) callable_name in
List.map ~f:(fun overload -> (local_name, overload)) overloads
| _ -> []
in
let definition_attributes = attributes ~resolution definition in
let protocol_attributes = attributes ~resolution protocol in
implements
(List.concat_map ~f:callables_of_attribute definition_attributes)
(List.concat_map ~f:callables_of_attribute protocol_attributes)

let attribute_fold
?(transitive = false)
@@ -54,7 +54,6 @@ module Method : sig
-> Type.t Int.Map.t
val return_annotation: t -> resolution: Resolution.t -> Type.t

val implements: t -> protocol_method: t -> bool
end

val generics: t -> resolution: Resolution.t -> Type.t list
@@ -88,7 +87,7 @@ val immediate_superclasses
val methods: t -> Method.t list

val is_protocol: t -> bool
val implements: t -> protocol: t -> bool
val implements: resolution: Resolution.t -> t -> protocol: t -> bool

module Attribute : sig
type attribute = {
@@ -871,7 +871,7 @@ let infer_implementations (module Handler: Handler) ~implementing_classes ~proto
>>| Class.create
>>| (fun definition ->
not (Class.is_protocol definition) &&
Class.implements ~protocol:protocol_definition definition)
Class.implements ~resolution ~protocol:protocol_definition definition)
|> Option.value ~default:false
in
List.filter ~f:implements classes_to_analyze
@@ -418,20 +418,40 @@ let test_is_protocol _ =


let test_implements _ =
(* TODO(T36516076) Adapt assert_conforms to fit testing idioms *)
let assert_conforms definition protocol conforms =
match parse_last_statement definition with
| { Node.value = Statement.Class definition; _ } ->
begin
match parse_last_statement protocol with
| { Node.value = Statement.Class protocol; _ } ->
assert_equal
(Class.implements
~protocol:(Class.create (Node.create_with_default_location protocol))
(Class.create (Node.create_with_default_location definition)))
conforms
| _ ->
assert_unreached ()
end
let get_last_statement { Source.statements; _ } =
List.last_exn statements
in
let environment = Environment.Builder.create () in
let definition =
definition
|> parse
|> Preprocessing.preprocess
in
let protocol =
protocol
|> parse
|> Preprocessing.preprocess
in
Service.Environment.populate
(Environment.handler ~configuration environment)
(definition :: protocol :: Test.typeshed_stubs);
let ((module Handler: Environment.Handler) as handler) =
Environment.handler environment ~configuration
in
let resolution = Environment.resolution handler () in
Annotated.Class.Attribute.Cache.clear ();
match definition |> get_last_statement,
protocol |> get_last_statement with
| { Node.value = Statement.Class definition; _ },
{ Node.value = Statement.Class protocol; _ } ->
assert_equal
(Class.implements
~resolution
~protocol:(Class.create (Node.create_with_default_location protocol))
(Class.create (Node.create_with_default_location definition)))
conforms
| _ ->
assert_unreached ()
in
@@ -480,7 +500,50 @@ let test_implements _ =
def foo(): pass
def bar(): pass
|}
true
true;
assert_conforms
{|
class List():
def empty() -> bool: pass
def length() -> int: pass
|}
{|
class Sized(typing.Protocol):
def empty() -> bool: pass
def len() -> int: pass
|}
false;
assert_conforms
{|
class List():
def empty() -> bool: pass
@typing.overload
def length(x: int) -> str: pass
def length() -> str: pass
def length(x: int) -> int: pass
def length() -> int: pass
|}
{|
class Sized(typing.Protocol):
def empty() -> bool: pass
def length() -> int: pass
|}
true;
assert_conforms
{|
class List():
def empty() -> bool: pass
@typing.overload
def length(x: int) -> str: pass
def length() -> str: pass
def length(x: int) -> int: pass
|}
{|
class Sized(typing.Protocol):
def empty() -> bool: pass
def length() -> int: pass
|}
false


let test_class_attributes _ =
@@ -1137,106 +1200,6 @@ let test_overrides _ =
(Access.create "Foo")


let test_method_implements _ =
let definition ?(parameters = []) ?return_annotation name =
Method.create
~define:{
Statement.Define.name = Access.create name;
parameters;
body = [+Pass];
decorators = [];
docstring = None;
return_annotation;
async = false;
generated = false;
parent = Some (Access.create "Parent");
}
~parent:
(Class.create
(Node.create_with_default_location
{
Statement.Class.name = Access.create "Parent";
bases = [];
body = [+Pass];
decorators = [];
docstring = None;
}))
in

assert_true
(Method.implements
~protocol_method:(definition "match")
(definition "match"));
assert_false
(Method.implements
~protocol_method:(definition "mismatch")
(definition "match"));

let parameters =
[
Parameter.create ~name:(~~"a") ();
Parameter.create ~name:(~~"b") ();
]
in
assert_true
(Method.implements
~protocol_method:(definition ~parameters "match")
(definition ~parameters "match"));

(* Naming of parameters doesn't matter. *)
let definition_parameters =
[
Parameter.create ~name:(~~"a") ();
Parameter.create ~name:(~~"b") ();
]
in
let protocol_parameters =
[
Parameter.create ~name:(~~"a") ();
Parameter.create ~name:(~~"c") ();
]
in
assert_true
(Method.implements
~protocol_method:(definition ~parameters:protocol_parameters "match")
(definition ~parameters:definition_parameters "match"));

(* Number of parameters, parameter and return annotations matter. *)
let definition_parameters =
[
Parameter.create ~name:(~~"a") ();
Parameter.create ~name:(~~"b") ();
]
in
let protocol_parameters =
[
Parameter.create ~name:(~~"a") ();
Parameter.create ~name:(~~"c") ();
Parameter.create ~name:(~~"z") ();
]
in
assert_false
(Method.implements
~protocol_method:(definition ~parameters:protocol_parameters "match")
(definition ~parameters:definition_parameters "match"));

let definition_parameters = [Parameter.create ~name:(~~"a") ~annotation:!"int" ()] in
let protocol_parameters = [Parameter.create ~name:(~~"a") ()] in
assert_false
(Method.implements
~protocol_method:(definition ~parameters:protocol_parameters "match")
(definition ~parameters:definition_parameters "match"));

assert_true
(Method.implements
~protocol_method:(definition ~return_annotation:!"int" "match")
(definition ~return_annotation:!"int" "match"));
assert_false
(Method.implements
~protocol_method:(definition ~return_annotation:!"int" "match")
(definition ~return_annotation:!"float" "match"))


let () =
"class">:::[
"attributes">::test_class_attributes;
@@ -1253,8 +1216,4 @@ let () =
"overrides">::test_overrides;
"superclasses">::test_superclasses;
]
|> Test.run;
"method">:::[
"implements">::test_method_implements;
]
|> Test.run
@@ -44,6 +44,12 @@ let populate_with_sources sources =
let populate source =
populate_with_sources [parse source]

let populate_preprocess source =
populate_with_sources [
source
|> parse
|> Preprocessing.preprocess
]

let global environment =
Environment.resolution environment ()
@@ -785,10 +791,10 @@ let test_populate _ =
|> Annotation.create_immutable ~global:true ~original:(Some Type.Top))
let test_infer_protocols _ =
let test_infer_protocols_edges _ =
let edges =
let environment =
populate {|
populate_preprocess {|
class Empty(typing.Protocol):
pass
class Sized(typing.Protocol):
@@ -1349,6 +1355,7 @@ let test_infer_protocols _ =
let configuration = Configuration.Analysis.create () in
let type_sources = Test.typeshed_stubs in
let assert_protocols ?classes_to_infer source expected_edges =
Annotated.Class.Attribute.Cache.clear ();
let expected_edges =
let to_edge (source, target) =
{
@@ -1463,7 +1470,7 @@ let test_infer_protocols _ =
def foo() -> str:
pass
|}
["A", "P"]
["A", "P"; "C", "P"]
let () =
@@ -1479,7 +1486,7 @@ let () =
"class_definition">::test_class_definition;
"connect_definition">::test_connect_definition;
"import_dependencies">::test_import_dependencies;
"infer_protocols">::test_infer_protocols;
"infer_protocols_edges">::test_infer_protocols_edges;
"infer_protocols">::test_infer_protocols;
"modules">::test_modules;
"populate">::test_populate;

0 comments on commit f14577d

Please sign in to comment.
You can’t perform that action at this time.