forked from python/mypy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
expandtype.py
145 lines (119 loc) · 4.81 KB
/
expandtype.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from mtypes import (
Typ, Instance, Callable, TypeVisitor, UnboundType, ErrorType, Any, Void,
NoneTyp, TypeVar, Overloaded, TupleType, ErasedType
)
Typ expand_type(Typ typ, dict<int, Typ> map):
"""Expand any type variable references in a type with the actual values of
type variables in an instance.
"""
return typ.accept(ExpandTypeVisitor(map))
Typ expand_type_by_instance(Typ typ, Instance instance):
"""Expand type variables in type using type variable values in an
instance."""
if instance.args == []:
return typ
else:
dict<int, Typ> variables = {}
for i in range(len(instance.args)):
variables[i + 1] = instance.args[i]
typ = expand_type(typ, variables)
if isinstance(typ, Callable):
list<tuple<int, Typ>> bounds = []
for j in range(len(instance.args)):
bounds.append((j + 1, instance.args[j]))
typ = update_callable_implicit_bounds((Callable)typ, bounds)
else:
pass
return typ
class ExpandTypeVisitor(TypeVisitor<Typ>):
dict<int, Typ> variables # Lower bounds
void __init__(self, dict<int, Typ> variables):
self.variables = variables
Typ visit_unbound_type(self, UnboundType t):
return t
Typ visit_error_type(self, ErrorType t):
return t
Typ visit_any(self, Any t):
return t
Typ visit_void(self, Void t):
return t
Typ visit_none_type(self, NoneTyp t):
return t
Typ visit_erased_type(self, ErasedType t):
# Should not get here.
raise RuntimeError()
Typ visit_instance(self, Instance t):
args = self.expand_types(t.args)
return Instance(t.typ, args, t.line, t.repr)
Typ visit_type_var(self, TypeVar t):
repl = self.variables.get(t.id, t)
if isinstance(repl, Instance):
inst = (Instance)repl
# Return copy of instance with type erasure flag on.
return Instance(inst.typ, inst.args, inst.line, inst.repr, True)
else:
return repl
Typ visit_callable(self, Callable t):
return Callable(self.expand_types(t.arg_types),
t.arg_kinds,
t.arg_names,
t.ret_type.accept(self),
t.is_type_obj(),
t.name,
t.variables,
self.expand_bound_vars(t.bound_vars), t.line, t.repr)
Typ visit_overloaded(self, Overloaded t):
Callable[] items = []
for item in t.items():
items.append((Callable)item.accept(self))
return Overloaded(items)
Typ visit_tuple_type(self, TupleType t):
return TupleType(self.expand_types(t.items), t.line, t.repr)
Typ[] expand_types(self, Typ[] types):
Typ[] a = []
for t in types:
a.append(t.accept(self))
return a
list<tuple<int, Typ>> expand_bound_vars(self, list<tuple<int, Typ>> types):
list<tuple<int, Typ>> a = []
for id, t in types:
a.append((id, t.accept(self)))
return a
Callable update_callable_implicit_bounds(Callable t,
list<tuple<int, Typ>> arg_types):
# FIX what if there are existing bounds?
return Callable(t.arg_types,
t.arg_kinds,
t.arg_names,
t.ret_type,
t.is_type_obj(),
t.name,
t.variables,
arg_types, t.line, t.repr)
tuple<Typ[], Typ> expand_caller_var_args(Typ[] arg_types,
int fixed_argc):
"""Expand the caller argument types in a varargs call. Fixedargc
is the maximum number of fixed arguments that the target function
accepts.
Return (fixed argument types, type of the rest of the arguments). Return
(None, None) if the last (vararg) argument had an invalid type. If the
vararg argument was not an array (nor dynamic), the last item in the
returned tuple is None.
"""
if isinstance(arg_types[-1], TupleType):
return arg_types[:-1] + ((TupleType)arg_types[-1]).items, None
else:
Typ item_type
if isinstance(arg_types[-1], Any):
item_type = Any()
elif isinstance(arg_types[-1], Instance) and (
((Instance)arg_types[-1]).typ.full_name() == 'builtins.list'):
# List.
item_type = ((Instance)arg_types[-1]).args[0]
else:
return None, None
if len(arg_types) > fixed_argc:
return arg_types[:-1], item_type
else:
return (arg_types[:-1] +
[item_type] * (fixed_argc - len(arg_types) + 1), item_type)