PMML（Predictive Model Markup Language，预测模型标记语言）是数据挖掘领域的一种通用规范，它用统一的 XML 格式来描述训练得到的机器学习模型。这样无论模型是由 sklearn、R 还是 Spark MLlib 生成的，都可以将其转化为标准的 XML 格式来存储。当需要部署这个 PMML 模型时，可以使用目标环境中解析 PMML 的库来加载模型，并进行预测。

使用 PMML 需要两步工作：第一步是将离线训练得到的模型转化为 PMML 模型文件；第二步是将 PMML 模型文件载入在线预测环境，进行预测。

Java 中需要添加 PMML 相关依赖：

In [None]:
<!-- JPMML-Evaluator: the demo class below imports org.jpmml.evaluator.* from this library. -->
<dependency>
        <groupId>org.jpmml</groupId>
        <artifactId>pmml-evaluator</artifactId>
        <version>1.4.1</version>
    </dependency>
    <!-- NOTE(review): the demo code does not visibly reference the extension module; confirm it is required. -->
    <dependency>
        <groupId>org.jpmml</groupId>
        <artifactId>pmml-evaluator-extension</artifactId>
        <version>1.4.1</version>
    </dependency>

需要添加对xml的解析依赖：

In [None]:
<!-- JAXB API: the demo class imports javax.xml.bind.JAXBException, declared by the PMML unmarshal call. -->
<!-- NOTE(review): presumably only needed on JDKs that no longer bundle JAXB; confirm against the target JDK. -->
<dependency>
        <groupId>javax.xml.bind</groupId>
        <artifactId>jaxb-api</artifactId>
        <version>2.3.0</version>
    </dependency>
    <!-- JAXB implementation artifacts, kept at the same version as the API. -->
    <dependency>
        <groupId>com.sun.xml.bind</groupId>
        <artifactId>jaxb-impl</artifactId>
        <version>2.3.0</version>
    </dependency>
    <dependency>
        <groupId>com.sun.xml.bind</groupId>
        <artifactId>jaxb-core</artifactId>
        <version>2.3.0</version>
    </dependency>
    <!-- NOTE(review): presumably a runtime requirement of the JAXB implementation; confirm. -->
    <dependency>
        <groupId>javax.activation</groupId>
        <artifactId>activation</artifactId>
        <version>1.1.1</version>
    </dependency>

测试逻辑回归：

sex	cp	fbs	restecg	exang	slop	thal	ifhealth	age	trestbps	chol	thalach	oldpeak	ca

1	0.5	0	0	0	1	0	0	0.16666666666666666	0.33962264150943394	0.2831050228310502	0.8854961832061069	0.564516129032258	0

预测结果：

"0": 0.6719725244993267

"1": 0.3280274755006733


In [None]:
package com.fun.recommend.implement.service;

 import com.alibaba.fastjson.JSONObject;
 import org.dmg.pmml.FieldName;
 import org.dmg.pmml.PMML;
 import org.jpmml.evaluator.*;
 import org.xml.sax.SAXException;


import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
 import javax.xml.bind.JAXBException;


/**
 * Minimal demo of scoring a PMML model with JPMML-Evaluator.
 *
 * <p>Loads a PMML file from disk, builds an {@link Evaluator} for it, feeds one
 * hand-written feature row to the model and prints the class probabilities of
 * the first target field.
 *
 * @author maoyaozong
 * @date Created in 2020/3/19 15:18
 */
public class PMMLDemo {

    /** Model file used when no path is supplied on the command line. */
    private static final String DEFAULT_MODEL_PATH = "E:\\Learn-note\\data\\regression.pmml";

    /** Class label whose probability is printed by {@link #predict}. */
    private static final String POSITIVE_LABEL = "1";

    /**
     * Loads the model stored at {@link #DEFAULT_MODEL_PATH}.
     *
     * @return an evaluator for the model, or {@code null} if it could not be loaded
     */
    private Evaluator loadPmml() {
        return loadPmml(DEFAULT_MODEL_PATH);
    }

    /**
     * Reads a PMML file and builds an evaluator for the model it contains.
     *
     * <p>Failures are reported on {@code System.err} and signalled by a
     * {@code null} return value. In particular a file that cannot be parsed no
     * longer yields an evaluator built from an empty {@link PMML} object.
     *
     * @param path file system path of the PMML file
     * @return an evaluator for the model, or {@code null} if the file could not
     *         be read or parsed
     */
    private Evaluator loadPmml(String path) {
        PMML pmml;
        // try-with-resources closes the stream on every exit path.
        try (InputStream in = new FileInputStream(path)) {
            pmml = org.jpmml.model.PMMLUtil.unmarshal(in);
        } catch (IOException e) {
            System.err.println("Cannot read PMML file: " + path);
            e.printStackTrace();
            return null;
        } catch (SAXException | JAXBException e) {
            System.err.println("Cannot parse PMML file: " + path);
            e.printStackTrace();
            return null;
        }
        return ModelEvaluatorFactory.newInstance().newModelEvaluator(pmml);
    }

    /**
     * Scores one record and prints the result for the model's first target field.
     *
     * <p>If the target value is a {@link ProbabilityDistribution}, its categories
     * and the probability of {@link #POSITIVE_LABEL} are printed as well. Other
     * result types are only printed via {@code toString()}.
     *
     * @param evaluator evaluator of the model to score with
     * @param json      feature values keyed by the model's input field names
     * @return always {@code 0}; the value carries no information and is kept
     *         only so the method signature stays unchanged
     */
    private int predict(Evaluator evaluator, JSONObject json) {
        // Raw model features, looked up by input field name, become the model input.
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
        for (InputField inputField : evaluator.getInputFields()) {
            FieldName inputName = inputField.getName();
            Object rawValue = json.get(inputName.getValue());
            arguments.put(inputName, inputField.prepare(rawValue));
        }

        Map<FieldName, ?> results = evaluator.evaluate(arguments);

        List<TargetField> targetFields = evaluator.getTargetFields();
        if (targetFields.isEmpty()) {
            // Nothing to look up by target name; show the raw result map instead.
            System.out.println("no target field, results: " + results);
            return 0;
        }

        FieldName targetName = targetFields.get(0).getName();
        Object targetValue = results.get(targetName);
        System.out.println("target: " + targetName.getValue() + " value: " + targetValue);

        // Only classification results expose per-class probabilities.
        if (targetValue instanceof ProbabilityDistribution) {
            ProbabilityDistribution distribution = (ProbabilityDistribution) targetValue;
            System.out.println(distribution.getCategories());
            System.out.println(distribution.getProbability(POSITIVE_LABEL));
        }
        return 0;
    }

    /**
     * Runs the demo on one hard-coded feature row.
     *
     * @param args optional; {@code args[0]} overrides the default model path
     */
    public static void main(String[] args) {
        PMMLDemo demo = new PMMLDemo();
        Evaluator model = args.length > 0 ? demo.loadPmml(args[0]) : demo.loadPmml();
        if (model == null) {
            System.err.println("PMML model could not be loaded, aborting.");
            System.exit(1);
        }

        JSONObject json = new JSONObject();
        json.put("sex", 1.0);
        json.put("cp", 0.5);
        json.put("fbs", 0);
        json.put("restecg", 0);
        json.put("exang", 0);
        json.put("slop", 1.0);
        json.put("thal", 0);
        json.put("ifhealth", 0);
        json.put("age", 0.166666667);
        json.put("trestbps", 0.3396226);
        json.put("chol", 0.283105023);
        json.put("thalach", 0.885496183);
        json.put("oldpeak", 0.564516129);
        json.put("ca", 0);

        System.out.println(model.getSummary());
        demo.predict(model, json);
    }
}



输出结果：

Regression

target: ifhealth value: ProbabilityDistribution{result=0.0, probability_entries=[1=0.3280274458009889, 0=0.6719725541990111]}
[1, 0]

0.3280274458009889

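结论：可以复现。Java 端得到的概率（"1": 0.32802744…，"0": 0.67197255…）与前面给出的预测结果在小数点后约 7 位内一致；微小差异应来自代码中输入特征的小数位被截断（例如 trestbps 只取了 0.3396226）。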