-
Notifications
You must be signed in to change notification settings - Fork 238
/
InferType.scala
291 lines (289 loc) · 11.5 KB
/
InferType.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
package is.hail.expr.ir
import is.hail.expr.Nat
import is.hail.types.tcoerce
import is.hail.types.virtual._
import is.hail.utils._
/** Structural type inference for value [[IR]] nodes.
  *
  * Computes the virtual [[Type]] of an IR node from its constructor and the
  * (already inferred) types of its children. No unification is performed
  * beyond the local sanity `assert`s below; nodes that carry their type
  * explicitly (e.g. `NA`, `Ref`, `Literal`) simply report it.
  */
object InferType {

  /** Returns the virtual result type of `ir`.
    *
    * @param ir the IR node whose type is inferred
    * @throws scala.MatchError if `ir` is a node kind not handled here
    */
  def apply(ir: IR): Type = {
    ir match {
      // Literals and nodes that carry their type directly.
      case I32(_) => TInt32
      case I64(_) => TInt64
      case F32(_) => TFloat32
      case F64(_) => TFloat64
      case Str(_) => TString
      case UUID4(_) => TString
      case Literal(t, _) => t
      case EncodedLiteral(codec, _) => codec.encodedVirtualType
      case True() | False() => TBoolean
      case Void() => TVoid
      case Cast(_, t) => t
      case CastRename(_, t) => t
      case NA(t) => t
      case IsNA(_) => TBoolean
      case Coalesce(values) => values.head.typ
      case Consume(_) => TInt64
      case Ref(_, t) => t
      case RelationalRef(_, t) => t
      case RelationalLet(_, _, body) => body.typ
      case In(_, t) => t.virtualType
      case MakeArray(_, t) => t
      case MakeStream(_, t, _) => t
      case MakeNDArray(data, shape, _, _) =>
        // Dimensionality comes from the static arity of the shape tuple.
        TNDArray(tcoerce[TIterable](data.typ).elementType, Nat(shape.typ.asInstanceOf[TTuple].size))
      case StreamBufferedAggregate(_, _, newKey, _, _, aggSignatures, _) =>
        // Each aggregator contributes one serialized (TBinary) slot in the "agg" tuple.
        val tupleFieldTypes = TTuple(aggSignatures.map(_ => TBinary): _*)
        TStream(newKey.typ.asInstanceOf[TStruct].insertFields(IndexedSeq(("agg", tupleFieldTypes))))
      case _: ArrayLen => TInt32
      case _: StreamIota => TStream(TInt32)
      case _: StreamRange => TStream(TInt32)
      case _: SeqSample => TStream(TInt32)
      case _: ArrayZeros => TArray(TInt32)
      case _: LowerBoundOnOrderedCollection => TInt32
      case _: StreamFor => TVoid
      // Aggregator plumbing: side-effecting ops are TVoid.
      case _: InitOp => TVoid
      case _: SeqOp => TVoid
      case _: CombOp => TVoid
      case ResultOp(_, aggSig) =>
        aggSig.resultType
      case AggStateValue(_, _) => TBinary
      case _: CombOpValue => TVoid
      case _: InitFromSerializedValue => TVoid
      case _: SerializeAggs => TVoid
      case _: DeserializeAggs => TVoid
      case _: Begin => TVoid
      case Die(_, t, _) => t
      // Trap wraps the error (message, id) alongside the child's value.
      case Trap(child) => TTuple(TTuple(TString, TInt32), child.typ)
      case ConsoleLog(_, result) => result.typ
      case If(cond, cnsq, altr) =>
        assert(cond.typ == TBoolean)
        assert(cnsq.typ == altr.typ)
        cnsq.typ
      case Let(_, _, body) =>
        body.typ
      case AggLet(_, _, body, _) =>
        body.typ
      case TailLoop(_, _, body) =>
        body.typ
      case Recur(_, _, typ) =>
        typ
      case ApplyBinaryPrimOp(op, l, r) =>
        BinaryOp.getReturnType(op, l.typ, r.typ)
      case ApplyUnaryPrimOp(op, v) =>
        UnaryOp.getReturnType(op, v.typ)
      case ApplyComparisonOp(op, l, r) =>
        assert(l.typ == r.typ)
        op match {
          case _: Compare => TInt32 // three-way comparison
          case _ => TBoolean
        }
      case a: ApplyIR => a.explicitNode.typ
      case a: AbstractApplyNode[_] =>
        val typeArgs = a.typeArgs
        val argTypes = a.args.map(_.typ)
        assert(a.implementation.unify(typeArgs, argTypes, a.returnType))
        a.returnType
      case ArrayRef(a, i, _) =>
        assert(i.typ == TInt32)
        tcoerce[TArray](a.typ).elementType
      case ArraySlice(a, start, stop, step, _) =>
        assert(start.typ == TInt32)
        stop.foreach(ir => assert(ir.typ == TInt32))
        assert(step.typ == TInt32)
        tcoerce[TArray](a.typ)
      case ArraySort(a, _, _, lessThan) =>
        assert(lessThan.typ == TBoolean)
        TArray(tcoerce[TStream](a.typ).elementType)
      case ArrayMaximalIndependentSet(edges, _) =>
        // Node type is the first field of the edge struct.
        val et = tcoerce[TArray](edges.typ).elementType.asInstanceOf[TBaseStruct].types.head
        TArray(et)
      // Collection conversions.
      case ToSet(a) =>
        TSet(tcoerce[TStream](a.typ).elementType)
      case ToDict(a) =>
        // Stream elements are (key, value) structs.
        val elt = tcoerce[TBaseStruct](tcoerce[TStream](a.typ).elementType)
        TDict(elt.types(0), elt.types(1))
      case ToArray(a) =>
        TArray(tcoerce[TStream](a.typ).elementType)
      case CastToArray(a) =>
        TArray(tcoerce[TContainer](a.typ).elementType)
      case ToStream(a, _) =>
        TStream(tcoerce[TIterable](a.typ).elementType)
      case RNGStateLiteral() =>
        TRNGState
      case RNGSplit(_, _) =>
        TRNGState
      case _: StreamLen => TInt32
      case GroupByKey(collection) =>
        val elt = tcoerce[TBaseStruct](tcoerce[TStream](collection.typ).elementType)
        TDict(elt.types(0), TArray(elt.types(1)))
      // Stream combinators: element-type-preserving ops return the child's type.
      case StreamTake(a, _) =>
        a.typ
      case StreamDrop(a, _) =>
        a.typ
      case StreamGrouped(a, _) =>
        TStream(a.typ)
      case StreamGroupByKey(a, _, _) =>
        TStream(a.typ)
      case StreamMap(_, _, body) =>
        TStream(body.typ)
      case StreamZip(_, _, body, _, _) =>
        TStream(body.typ)
      case StreamZipJoin(_, _, _, _, joinF) =>
        TStream(joinF.typ)
      case StreamMultiMerge(as, _) =>
        TStream(tcoerce[TStream](as.head.typ).elementType)
      case StreamFilter(a, _, _) =>
        a.typ
      case StreamTakeWhile(a, _, _) =>
        a.typ
      case StreamDropWhile(a, _, _) =>
        a.typ
      case StreamFlatMap(_, _, body) =>
        TStream(tcoerce[TStream](body.typ).elementType)
      case StreamFold(_, zero, _, _, body) =>
        assert(body.typ == zero.typ)
        zero.typ
      case StreamFold2(_, _, _, _, result) => result.typ
      case StreamDistribute(_, pivots, _, _, _) =>
        val keyType = pivots.typ.asInstanceOf[TContainer].elementType
        TArray(TStruct(
          ("interval", TInterval(keyType)),
          ("fileName", TString),
          ("numElements", TInt32),
          ("numBytes", TInt64)))
      case StreamScan(_, zero, _, _, body) =>
        assert(body.typ == zero.typ)
        TStream(zero.typ)
      case StreamAgg(_, _, query) =>
        query.typ
      case StreamAggScan(_, _, query) =>
        TStream(query.typ)
      case StreamLocalLDPrune(streamChild, _, _, _, _) =>
        val childType = tcoerce[TStruct](tcoerce[TStream](streamChild.typ).elementType)
        TStream(TStruct(
          "locus" -> childType.fieldType("locus"),
          "alleles" -> childType.fieldType("alleles"),
          "mean" -> TFloat64,
          "centered_length_rec" -> TFloat64))
      case RunAgg(_, result, _) =>
        result.typ
      case RunAggScan(_, _, _, _, result, _) =>
        TStream(result.typ)
      case StreamJoinRightDistinct(_, _, _, _, _, _, join, _) =>
        TStream(join.typ)
      // NDArray ops.
      case NDArrayShape(nd) =>
        nd.typ.asInstanceOf[TNDArray].shapeType
      case NDArrayReshape(nd, shape, _) =>
        TNDArray(tcoerce[TNDArray](nd.typ).elementType, Nat(shape.typ.asInstanceOf[TTuple].size))
      case NDArrayConcat(nds, _) =>
        tcoerce[TArray](nds.typ).elementType
      case NDArrayMap(nd, _, body) =>
        TNDArray(body.typ, tcoerce[TNDArray](nd.typ).nDimsBase)
      case NDArrayMap2(l, _, _, _, body, _) =>
        TNDArray(body.typ, tcoerce[TNDArray](l.typ).nDimsBase)
      case NDArrayReindex(nd, indexExpr) =>
        TNDArray(tcoerce[TNDArray](nd.typ).elementType, Nat(indexExpr.length))
      case NDArrayAgg(nd, axes) =>
        // Each aggregated axis removes one dimension.
        val childType = tcoerce[TNDArray](nd.typ)
        TNDArray(childType.elementType, Nat(childType.nDims - axes.length))
      case NDArrayRef(nd, idxs, _) =>
        assert(idxs.forall(_.typ == TInt64))
        tcoerce[TNDArray](nd.typ).elementType
      case NDArraySlice(nd, slices) =>
        // Scalar indices (non-tuple slice components) drop a dimension;
        // only tuple (start, stop, step) slices survive as dimensions.
        val childTyp = tcoerce[TNDArray](nd.typ)
        val slicesTyp = tcoerce[TTuple](slices.typ)
        val tuplesOnly = slicesTyp.types.collect { case x: TTuple => x }
        TNDArray(childTyp.elementType, Nat(tuplesOnly.length))
      case NDArrayFilter(nd, _) =>
        nd.typ
      case NDArrayMatMul(l, r, _) =>
        val lTyp = tcoerce[TNDArray](l.typ)
        val rTyp = tcoerce[TNDArray](r.typ)
        TNDArray(lTyp.elementType, Nat(TNDArray.matMulNDims(lTyp.nDims, rTyp.nDims)))
      case NDArrayQR(_, mode, _) =>
        // Result shape follows numpy.linalg.qr's mode conventions.
        mode match {
          case "complete" | "reduced" =>
            TTuple(TNDArray(TFloat64, Nat(2)), TNDArray(TFloat64, Nat(2)))
          case "raw" =>
            TTuple(TNDArray(TFloat64, Nat(2)), TNDArray(TFloat64, Nat(1)))
          case "r" =>
            TNDArray(TFloat64, Nat(2))
          case _ =>
            throw new NotImplementedError(s"Cannot infer type for mode $mode")
        }
      case NDArraySVD(_, _, computeUV, _) =>
        if (computeUV)
          TTuple(TNDArray(TFloat64, Nat(2)), TNDArray(TFloat64, Nat(1)), TNDArray(TFloat64, Nat(2)))
        else
          TNDArray(TFloat64, Nat(1))
      case NDArrayInv(_, _) =>
        TNDArray(TFloat64, Nat(2))
      case NDArrayWrite(_, _) => TVoid
      // Aggregation expressions.
      case AggFilter(_, aggIR, _) =>
        aggIR.typ
      case AggExplode(_, _, aggBody, _) =>
        aggBody.typ
      case AggGroupBy(key, aggIR, _) =>
        TDict(key.typ, aggIR.typ)
      case AggArrayPerElement(_, _, _, aggBody, _, _) => TArray(aggBody.typ)
      case ApplyAggOp(_, _, aggSig) =>
        aggSig.returnType
      case ApplyScanOp(_, _, aggSig) =>
        aggSig.returnType
      case AggFold(zero, _, _, _, _, _) =>
        zero.typ
      // Struct and tuple construction/access.
      case MakeStruct(fields) =>
        TStruct(fields.map { case (name, a) => (name, a.typ) }: _*)
      case SelectFields(old, fields) =>
        tcoerce[TStruct](old.typ).select(fields.toFastIndexedSeq)._1
      case InsertFields(old, fields, fieldOrder) =>
        val tbs = tcoerce[TStruct](old.typ)
        val s = tbs.insertFields(fields.map(f => (f._1, f._2.typ)))
        fieldOrder.map { fds =>
          // A requested ordering must mention every field exactly once.
          assert(fds.length == s.size, s"${fds} != ${s.types.toIndexedSeq}")
          TStruct(fds.map(f => f -> s.fieldType(f)): _*)
        }.getOrElse(s)
      case GetField(o, name) =>
        val t = tcoerce[TStruct](o.typ)
        if (t.index(name).isEmpty)
          throw new RuntimeException(s"$name not in $t")
        t.field(name).typ
      case MakeTuple(values) =>
        TTuple(values.map { case (i, value) => TupleField(i, value.typ) }.toFastIndexedSeq)
      case GetTupleElement(o, idx) =>
        val t = tcoerce[TTuple](o.typ)
        t.fields(t.fieldIndex(idx)).typ
      // Relational (table / matrix / block-matrix) value nodes.
      case TableCount(_) => TInt64
      case MatrixCount(_) => TTuple(TInt64, TInt32)
      case TableAggregate(_, query) =>
        query.typ
      case MatrixAggregate(_, query) =>
        query.typ
      case _: TableWrite => TVoid
      case _: TableMultiWrite => TVoid
      case _: MatrixWrite => TVoid
      case _: MatrixMultiWrite => TVoid
      case _: BlockMatrixCollect => TNDArray(TFloat64, Nat(2))
      case BlockMatrixWrite(_, writer) => writer.loweredTyp
      case _: BlockMatrixMultiWrite => TVoid
      case TableGetGlobals(child) => child.typ.globalType
      case TableCollect(child) =>
        TStruct("rows" -> TArray(child.typ.rowType), "global" -> child.typ.globalType)
      case TableToValueApply(child, function) => function.typ(child.typ)
      case MatrixToValueApply(child, function) => function.typ(child.typ)
      case BlockMatrixToValueApply(child, function) => function.typ(child.typ)
      case CollectDistributedArray(_, _, _, _, body, _, _, _) => TArray(body.typ)
      case ReadPartition(_, rowType, _) => TStream(rowType)
      case WritePartition(_, _, writer) => writer.returnType
      case _: WriteMetadata => TVoid
      case ReadValue(_, _, typ) => typ
      case _: WriteValue => TString
      case LiftMeOut(child) => child.typ
    }
  }
}