/
Pipeline.scala
78 lines (64 loc) · 2.94 KB
/
Pipeline.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
package ml.combust.mleap.runtime.transformer
import ml.combust.mleap.core.Model
import ml.combust.mleap.core.types.{DataType, NodeShape, StructField, StructType}
import ml.combust.mleap.runtime.frame.{FrameBuilder, Transformer => FrameTransformer}
import scala.concurrent.{ExecutionContext, Future}
import scala.util.Try
import scala.util.control.NonFatal
/**
* Created by hwilkins on 11/8/15.
*/
/** Model backing a [[Pipeline]]: just the ordered stages the pipeline runs.
  *
  * The pipeline's schemas are derived from its stages by [[Pipeline]] itself,
  * so the model-level schema accessors are deliberately unsupported.
  *
  * @param transformers the pipeline stages, in the order they are applied
  */
case class PipelineModel(transformers: Seq[FrameTransformer]) extends Model {

  /** Unsupported for a pipeline model.
    *
    * @throws NotImplementedError always
    */
  override def inputSchema: StructType =
    throw new NotImplementedError("inputSchema is not implemented for a PipelineModel")

  /** Unsupported for a pipeline model.
    *
    * @throws NotImplementedError always
    */
  override def outputSchema: StructType =
    throw new NotImplementedError("outputSchema is not implemented for a PipelineModel")
}
/** A transformer that runs its stages one after another, feeding the frame
  * produced by each stage into the next.
  *
  * @param uid   unique identifier of this pipeline
  * @param shape node shape of this pipeline
  * @param model holds the ordered stages to run
  */
case class Pipeline(override val uid: String = FrameTransformer.uniqueName("pipeline"),
                    override val shape: NodeShape,
                    override val model: PipelineModel) extends FrameTransformer {

  /** The stages of this pipeline, in execution order. */
  def transformers: Seq[FrameTransformer] = model.transformers

  /** Applies every stage in order.
    *
    * Once a stage yields a `Failure`, the remaining stages are not invoked
    * (`flatMap` on a `Failure` is a no-op) and that failure is returned.
    */
  override def transform[TB <: FrameBuilder[TB]](builder: TB): Try[TB] = {
    transformers.foldLeft(Try(builder))((acc, stage) => acc.flatMap(stage.transform))
  }

  /** Asynchronous variant of [[transform]]: chains each stage's
    * `transformAsync` onto the previous stage's result. A failed future
    * prevents the remaining stages from running.
    */
  override def transformAsync[FB <: FrameBuilder[FB]](builder: FB)
                                                     (implicit ec: ExecutionContext): Future[FB] = {
    // The starting builder is already available, so wrap it directly rather
    // than via Future.apply, which would schedule a task on `ec` just to
    // hand the same value back.
    transformers.foldLeft(Future.successful(builder)) {
      (acc, stage) => acc.flatMap(b => stage.transformAsync(b))
    }
  }

  /** Closes every stage, even if some of them fail to close.
    *
    * The first non-fatal exception thrown by a stage is rethrown after all
    * stages have been attempted. Later failures are attached to it as
    * suppressed exceptions. Fatal errors propagate immediately.
    */
  override def close(): Unit = {
    val firstFailure = transformers.foldLeft(Option.empty[Throwable]) { (failure, stage) =>
      try {
        stage.close()
        failure
      } catch {
        case NonFatal(e) =>
          failure match {
            case Some(first) =>
              // Throwable.addSuppressed rejects self-suppression.
              if (!(first eq e)) first.addSuppressed(e)
              failure
            case None => Some(e)
          }
      }
    }
    firstFailure.foreach(e => throw e)
  }

  /** Fields read by some stage that no stage of this pipeline produces.
    *
    * Stage order is not taken into account: a field counts as produced even
    * if the stage producing it runs after a stage that reads it.
    */
  override def inputSchema: StructType = schemas._1

  /** Every field produced by any stage of this pipeline. */
  override def outputSchema: StructType = schemas._2

  /** Produced fields that are also read by some stage, or that are
    * intermediates of a nested pipeline.
    */
  def intermediateSchema: StructType = schemas._3

  /** Produced fields that are neither read by any stage nor intermediates
    * of a nested pipeline.
    */
  def strictOutputSchema: StructType = schemas._4

  // (input, output, intermediate, strict output) schemas, computed once on
  // first access of any of the accessors above.
  private lazy val schemas: (StructType, StructType, StructType, StructType) = {
    // Union the fields read and written by every stage, plus the intermediate
    // fields of nested pipelines. On a name clash, the data type from the later
    // stage wins, because `++` on a Map overwrites existing keys.
    val (inputs, outputs, intermediates) = transformers.foldLeft(
      (Map[String, DataType](), Map[String, DataType](), Map[String, DataType]())) {
      case ((iacc, oacc, intacc), tform) =>
        (iacc ++ tform.inputSchema.fields.map(f => f.name -> f.dataType),
          oacc ++ tform.outputSchema.fields.map(f => f.name -> f.dataType),
          intacc ++ {
            tform match {
              case pip: Pipeline => pip.intermediateSchema.fields.map(f => f.name -> f.dataType)
              case _ => Map()
            }
          })
    }

    // NOTE: field order follows Map iteration order, not stage order.
    def toFields(m: Map[String, DataType]): Seq[StructField] =
      m.map { case (name, dt) => StructField(name, dt) }.toSeq

    val strictOutputs = outputs -- inputs.keys -- intermediates.keys

    // Field names come from Map keys and are therefore unique. `.get` assumes
    // StructType construction cannot fail for such input.
    // NOTE(review): confirm StructType.apply has no other failure modes.
    (StructType(toFields(inputs -- outputs.keys)).get,
      StructType(toFields(outputs)).get,
      StructType(toFields(outputs -- strictOutputs.keys)).get,
      StructType(toFields(strictOutputs)).get)
  }
}