package com.stargraph.risk.ml

import java.io.File

import org.apache.spark.sql.{SaveMode, SparkSession}
import org.shaded.jpmml.evaluator.spark.{EvaluatorUtil, TransformerBuilder}

/**
 * One input row for the PMML pipeline smoke test.
 *
 * Field names must match the feature names expected by the PMML model
 * (presumably a heart-disease model given `chd`, `sbp`, `ldl`, `famhist` —
 * TODO confirm against the actual `pipe.pmml` schema).
 * `alcohol` and `age` are `Option[Double]` so the test can exercise both
 * present and missing (`None`) values.
 */
case class TestEntry(
  a_date: String,
  adiposity: Double,
  alcohol: Option[Double],
  b_date: String,
  chd: Double,
  dst_sp: Double,
  famhist: String,
  from_key: String,
  ldl: Double,
  new_diff: Double,
  obesity: Double,
  sbp: Double,
  src_sp: Double,
  to_key: String,
  tobacco: Double,
  typea: Double,
  vfeature: Double,
  age: Option[Double]
)

/**
 * Manual smoke test: loads a PMML model from disk, wraps it in a Spark
 * Transformer via JPMML, and scores two single-row DataFrames — one with a
 * NaN feature (`tobacco`) and an absent optional (`age = Some(...)` vs
 * `None`) — printing the resulting rows and schemas.
 */
object PMMLUnitTest {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("haha")
      .master("local[*]")
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .config("spark.rdd.compress", "true")
      .config("spark.driver.host", "localhost")
      .getOrCreate()

    // Fix: ensure the session is always stopped, even if scoring throws.
    try {
      val modelPath = "pipe.pmml" // paras(1).trim//"lrHeart.pmml"

      // Row 1: tobacco is NaN, age is present — exercises NaN handling.
      val inputRDD1 = spark.sparkContext.parallelize(Seq(
        TestEntry(
          "2018-11-01", 32.27, Some(1.0), "2018-10-02", 1.0,
          114.0, "Absent", "node/cc", 2.95, 30.0, 26.81, 206.0, 144.0,
          "node/gg", Double.NaN, 72.0, 206.0, Some(1.0)
        )
      ))
      // Use the session API directly instead of the legacy sqlContext.
      val inputDF1 = spark.createDataFrame(inputRDD1)

      // Row 2: tobacco is a real value, age is None — exercises missing optionals.
      val inputRDD2 = spark.sparkContext.parallelize(Seq(
        TestEntry(
          "2018-11-01", 32.27, Some(56.06), "2018-10-02", 1.0,
          114.0, "Absent", "node/cc", 2.95, 30.0, 26.81, 206.0, 144.0,
          "node/gg", 6.0, 72.0, 206.0, None
        )
      ))
      val inputDF2 = spark.createDataFrame(inputRDD2)

      // Build a Spark Transformer from the PMML file. With no explicit
      // target/output columns and exploded(false), the result columns are
      // kept in a single struct column.
      val evaluator = EvaluatorUtil.createEvaluator(new File(modelPath))
      val pmmlTransformer = new TransformerBuilder(evaluator)
        .withTargetCols()
        .withOutputCols()
        .exploded(false)
        .build()

      val result1 = pmmlTransformer.transform(inputDF1)
      result1.show()
      result1.printSchema()

      val result2 = pmmlTransformer.transform(inputDF2)
      result2.show()
      result2.printSchema()
    } finally {
      spark.stop()
    }
  }
}