Skip to content

Commit

Permalink
feat: add support for multi-layer perceptron estimators
Browse files Browse the repository at this point in the history
  • Loading branch information
iamDecode committed Jan 28, 2022
1 parent 45d4b2c commit afc931e
Show file tree
Hide file tree
Showing 8 changed files with 788 additions and 2 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,13 @@ The library currently supports the following models:
| [Gaussian Naive Bayes](sklearn_pmml_model/naive_bayes) || | ✅<sup>3</sup> |
| [Support Vector Machines](sklearn_pmml_model/svm) ||| ✅<sup>3</sup> |
| [Nearest Neighbors](sklearn_pmml_model/neighbors) ||| |
| [Neural Networks](sklearn_pmml_model/neural_network) ||| |

<sub><sup>1</sup> Categorical feature support using slightly modified internals, based on [scikit-learn#12866](https://github.com/scikit-learn/scikit-learn/pull/12866).</sub>

<sub><sup>2</sup> These models differ only in training characteristics, the resulting model is of the same form. Classification is supported using `PMMLLogisticRegression` for regression models and `PMMLRidgeClassifier` for general regression models.</sub>

<sub><sup>3</sup> By one-hot encoding categorical features automatically.</sub>

---

## Example
A minimal working example (using [this PMML file](https://github.com/iamDecode/sklearn-pmml-model/blob/master/models/randomForest.pmml)) is shown below:
Expand Down
128 changes: 128 additions & 0 deletions models/nn-iris.pmml
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
<?xml version="1.0"?>
<PMML version="4.4.1" xmlns="http://www.dmg.org/PMML-4_4" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://www.dmg.org/PMML-4_4 http://www.dmg.org/pmml/v4-4/pmml-4-4.xsd">
<Header copyright="Copyright (c) 2022 decode" description="Neural Network Model">
<Extension name="user" value="decode" extender="SoftwareAG PMML Generator"/>
<Application name="SoftwareAG PMML Generator" version="2.5.1"/>
<Timestamp>2022-01-28 11:41:54</Timestamp>
</Header>
<DataDictionary numberOfFields="5">
<DataField name="Class" optype="categorical" dataType="string">
<Value value="versicolor"/>
<Value value="setosa"/>
<Value value="virginica"/>
</DataField>
<DataField name="sepal length (cm)" optype="continuous" dataType="float"/>
<DataField name="sepal width (cm)" optype="continuous" dataType="float"/>
<DataField name="petal length (cm)" optype="continuous" dataType="float"/>
<DataField name="petal width (cm)" optype="continuous" dataType="float"/>
</DataDictionary>
<NeuralNetwork modelName="NeuralNet_model" functionName="classification" numberOfLayers="2" activationFunction="logistic">
<MiningSchema>
<MiningField name="Class" usageType="predicted" invalidValueTreatment="returnInvalid"/>
<MiningField name="sepal length (cm)" usageType="active" invalidValueTreatment="returnInvalid"/>
<MiningField name="sepal width (cm)" usageType="active" invalidValueTreatment="returnInvalid"/>
<MiningField name="petal length (cm)" usageType="active" invalidValueTreatment="returnInvalid"/>
<MiningField name="petal width (cm)" usageType="active" invalidValueTreatment="returnInvalid"/>
</MiningSchema>
<Output>
<OutputField name="Predicted_Class" optype="categorical" dataType="string" feature="predictedValue"/>
<OutputField name="Probability_versicolor" optype="continuous" dataType="double" feature="probability" value="versicolor"/>
<OutputField name="Probability_setosa" optype="continuous" dataType="double" feature="probability" value="setosa"/>
<OutputField name="Probability_viginica" optype="continuous" dataType="double" feature="probability" value="virginica"/>
</Output>
<NeuralInputs numberOfInputs="4">
<NeuralInput id="1">
<DerivedField name="derivedNI_sepal length (cm)" optype="continuous" dataType="double">
<FieldRef field="sepal length (cm)"/>
</DerivedField>
</NeuralInput>
<NeuralInput id="2">
<DerivedField name="derivedNI_sepal width (cm)" optype="continuous" dataType="double">
<FieldRef field="sepal width (cm)"/>
</DerivedField>
</NeuralInput>
<NeuralInput id="3">
<DerivedField name="derivedNI_petal length (cm)" optype="continuous" dataType="double">
<FieldRef field="petal length (cm)"/>
</DerivedField>
</NeuralInput>
<NeuralInput id="4">
<DerivedField name="derivedNI_petal width (cm)" optype="continuous" dataType="double">
<FieldRef field="petal width (cm)"/>
</DerivedField>
</NeuralInput>
</NeuralInputs>
<NeuralLayer numberOfNeurons="5">
<Neuron id="5" bias="6.48585945924381">
<Con from="1" weight="4.54587431828472"/>
<Con from="2" weight="5.52146633781706"/>
<Con from="3" weight="-9.17023767457325"/>
<Con from="4" weight="-3.33030423024173"/>
</Neuron>
<Neuron id="6" bias="-0.219171178910383">
<Con from="1" weight="-0.621178185846089"/>
<Con from="2" weight="-0.251461468530296"/>
<Con from="3" weight="0.191483166902765"/>
<Con from="4" weight="0.110745459632997"/>
</Neuron>
<Neuron id="7" bias="-0.403390824874627">
<Con from="1" weight="-0.75197406561544"/>
<Con from="2" weight="-1.83034558198221"/>
<Con from="3" weight="3.14589785761204"/>
<Con from="4" weight="1.69002269627064"/>
</Neuron>
<Neuron id="8" bias="-0.503134825404882">
<Con from="1" weight="-0.49881585534317"/>
<Con from="2" weight="1.59778404626156"/>
<Con from="3" weight="-0.0324193085084592"/>
<Con from="4" weight="0.0903405534052034"/>
</Neuron>
<Neuron id="9" bias="-0.195288296527289">
<Con from="1" weight="0.641867291385391"/>
<Con from="2" weight="0.119223840923296"/>
<Con from="3" weight="-0.401058887116727"/>
<Con from="4" weight="-0.536453517669621"/>
</Neuron>
</NeuralLayer>
<NeuralLayer numberOfNeurons="3" activationFunction="identity" normalizationMethod="softmax">
<Neuron id="10" bias="-3.92364759629226">
<Con from="5" weight="9.05216771475619"/>
<Con from="6" weight="-0.754661354448251"/>
<Con from="7" weight="6.54676332447574"/>
<Con from="8" weight="-4.08076405029937"/>
<Con from="9" weight="-2.99574486936919"/>
</Neuron>
<Neuron id="11" bias="1.58442024136309">
<Con from="5" weight="2.66383665553495"/>
<Con from="6" weight="-1.25438096411401"/>
<Con from="7" weight="-8.2424400278915"/>
<Con from="8" weight="1.99704085891973"/>
<Con from="9" weight="0.448610898349359"/>
</Neuron>
<Neuron id="12" bias="2.33921725721368">
<Con from="5" weight="-11.7159867913009"/>
<Con from="6" weight="2.0090441922176"/>
<Con from="7" weight="1.69567608217778"/>
<Con from="8" weight="2.08373122967452"/>
<Con from="9" weight="2.54713021209121"/>
</Neuron>
</NeuralLayer>
<NeuralOutputs numberOfOutputs="3">
<NeuralOutput outputNeuron="10">
<DerivedField name="derivedNO_Class" optype="continuous" dataType="double">
<NormDiscrete field="Class" value="versicolor"/>
</DerivedField>
</NeuralOutput>
<NeuralOutput outputNeuron="11">
<DerivedField name="derivedNO_Class" optype="continuous" dataType="double">
<NormDiscrete field="Class" value="setosa"/>
</DerivedField>
</NeuralOutput>
<NeuralOutput outputNeuron="12">
<DerivedField name="derivedNO_Class" optype="continuous" dataType="double">
<NormDiscrete field="Class" value="virginica"/>
</DerivedField>
</NeuralOutput>
</NeuralOutputs>
</NeuralNetwork>
</PMML>
22 changes: 22 additions & 0 deletions sklearn_pmml_model/neural_network/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# sklearn-pmml-model.neural_network

This package contains `PMMLMLPClassifier` and `PMMLMLPRegressor`.

## Example
A minimal working example is shown below:

```python
import numpy as np
import pandas as pd
from sklearn_pmml_model.neural_network import PMMLMLPClassifier
from sklearn.datasets import load_iris

# Prepare data
data = load_iris(as_frame=True)
X = data.data
y = pd.Series(np.array(data.target_names)[data.target])
y.name = "Class"

clf = PMMLMLPClassifier(pmml="models/nn-iris.pmml")
clf.predict(X)
```
9 changes: 9 additions & 0 deletions sklearn_pmml_model/neural_network/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""
The :mod:`sklearn_pmml_model.neural_network` module includes models based on neural networks.
"""

# License: BSD 2-Clause

from ._classes import PMMLMLPClassifier, PMMLMLPRegressor

__all__ = ['PMMLMLPClassifier', 'PMMLMLPRegressor']
112 changes: 112 additions & 0 deletions sklearn_pmml_model/neural_network/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# License: BSD 2-Clause

import numpy as np
from sklearn_pmml_model.base import PMMLBaseClassifier


class PMMLBaseNeuralNetwork:
    """
    Abstract class for Neural Network models.

    The PMML model consists of a <NeuralNetwork> element, containing a
    <NeuralInputs> element that describes the input layer neurons with
    <NeuralInput> elements. Next, a <NeuralLayer> element describes all other
    neurons with associated weights and biases. The activation function is
    either specified globally with the activationFunction attribute on the
    <NeuralNetwork> element, or the same attribute on each layer. Note however
    that scikit-learn only supports a single activation function for all hidden
    layers. Finally, the <NeuralOutputs> element describes the output layer.
    The output is currently expected to match the target field in
    <MiningSchema>.

    This class is a mixin: its initializer reads ``self.root``,
    ``self.target_field`` and ``self.fields`` (and ``self.n_outputs_`` for
    classifiers), none of which are assigned here, so they must already exist
    on the instance when ``__init__`` runs.

    Notes
    -----
    Specification: http://dmg.org/pmml/v4-3/NeuralNetwork.html

    """

    def __init__(self):
        """
        Populate the multi-layer perceptron attributes from the PMML document.

        Sets ``n_layers_``, ``hidden_layer_sizes``, ``activation``, ``coefs_``,
        ``intercepts_`` and ``out_activation_`` on the instance.

        Raises
        ------
        Exception
            If the document lacks a NeuralNetwork, NeuralInputs or NeuralLayer
            element, if the neural inputs are not plain references to the data
            fields, or if the activation function is unsupported or differs
            between hidden layers.

        """
        nn_model = self.root.find('NeuralNetwork')

        if nn_model is None:
            raise Exception('PMML model does not contain NeuralNetwork.')

        inputs = nn_model.find('NeuralInputs')

        if inputs is None:
            raise Exception('PMML model does not contain NeuralInputs.')

        # Map every referenced data field name to the id of its input neuron.
        # An input that is not a bare <FieldRef> (for instance a normalisation
        # expression) is preprocessing; report it with the library's own error
        # instead of failing with an AttributeError on a missing element.
        mapping = {}
        for neural_input in inputs.findall('NeuralInput'):
            derived_field = neural_input.find('DerivedField')
            field_ref = None
            if derived_field is not None:
                field_ref = derived_field.find('FieldRef')
            if field_ref is None:
                raise Exception('PMML model preprocesses the data which currently unsupported.')
            mapping[field_ref.get('field')] = neural_input.get('id')

        target = self.target_field.get('name')
        fields = [
            name
            for name, field in self.fields.items()
            if name != target and field.tag == 'DataField'
        ]
        if set(mapping.keys()) != set(fields):
            raise Exception('PMML model preprocesses the data which currently unsupported.')

        layers = nn_model.findall('NeuralLayer')

        if len(layers) == 0:
            raise Exception('PMML model does not contain any NeuralLayer elements.')

        self.n_layers_ = len(layers) + 1  # +1 for input layer

        neurons = [layer.findall('Neuron') for layer in layers]
        # Count <Neuron> children only: a layer element may hold other children
        # (such as <Extension>), so the element's own length is not reliable.
        layer_sizes = [len(layer_neurons) for layer_neurons in neurons]
        # The last <NeuralLayer> is the output layer, hence not "hidden".
        self.hidden_layer_sizes = layer_sizes[:-1]

        # Determine activation function
        activation_functions = {
            'logistic': 'logistic',
            'tanh': 'tanh',
            'identity': 'identity',
            'rectifier': 'relu'
        }
        activation_function = nn_model.get('activationFunction')

        if activation_function is None:
            activation_function = layers[0].get('activationFunction')

        layer_activations = [
            layer.get('activationFunction')
            for layer in layers[:-1]
            if layer.get('activationFunction') is not None
        ]

        # A set is used rather than np.unique: np.unique sorts its input and
        # cannot order a mix of None and str, which would surface as TypeError.
        if len(set([activation_function] + layer_activations)) > 1:
            raise Exception('Neural networks with different activation functions per '
                            'layer are not currently supported by scikit-learn.')

        if activation_function not in activation_functions:
            raise Exception('PMML model uses unsupported activationFunction.')

        self.activation = activation_functions[activation_function]

        # Set neuron weights
        # One (fan_in, fan_out) weight matrix per <NeuralLayer>; the first
        # layer's fan-in is the number of input neurons.
        sizes = list(zip([len(mapping)] + layer_sizes[:-1], layer_sizes))

        self.coefs_ = [np.zeros(shape=s) for s in sizes]
        self.intercepts_ = [
            np.array([float(neuron.get('bias', 0)) for neuron in layer])
            for layer in neurons
        ]

        # Row order of the first weight matrix follows the order of `fields`.
        field_ids = [mapping[field] for field in fields]
        for li, layer in enumerate(neurons):
            if li == 0:
                layer_ids = field_ids
            else:
                layer_ids = [x.get('id') for x in neurons[li - 1]]
            for ni, neuron in enumerate(layer):
                for connection in neuron.findall('Con'):
                    # Connections absent from the document keep weight zero.
                    ci = layer_ids.index(connection.get('from'))
                    self.coefs_[li][ci, ni] = float(connection.get('weight'))

        # NOTE(review): the classifier branch reads self.n_outputs_, which is
        # not assigned in this class -- confirm which value it holds when this
        # initializer runs.
        if not isinstance(self, PMMLBaseClassifier):
            self.out_activation_ = "identity"
        elif self.n_outputs_ == 2:
            self.out_activation_ = "logistic"
        else:
            self.out_activation_ = "softmax"
59 changes: 59 additions & 0 deletions sklearn_pmml_model/neural_network/_classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# License: BSD 2-Clause

import numpy as np
from sklearn.neural_network import MLPClassifier, MLPRegressor
from sklearn.preprocessing import LabelBinarizer
from sklearn.utils.multiclass import type_of_target
from sklearn_pmml_model.base import PMMLBaseClassifier, PMMLBaseRegressor, get_type
from sklearn_pmml_model.datatypes import Category
from sklearn_pmml_model.neural_network._base import PMMLBaseNeuralNetwork


class PMMLMLPClassifier(PMMLBaseClassifier, PMMLBaseNeuralNetwork, MLPClassifier):
    """
    Multi-layer Perceptron classifier loaded from a PMML document.

    Parameters
    ----------
    pmml : str, object
        Filename or file object containing PMML data.

    Notes
    -----
    Specification: http://dmg.org/pmml/v4-3/NeuralNetwork.html

    """

    def __init__(self, pmml):
        # The three initializers are invoked explicitly and in this order:
        # the PMML base first, then the scikit-learn estimator, and finally
        # the neural network mixin that fills in the attributes read from
        # the document.
        PMMLBaseClassifier.__init__(self, pmml)
        MLPClassifier.__init__(self)
        PMMLBaseNeuralNetwork.__init__(self)

        self.n_outputs_ = len(self.classes_)

        # Assemble a label binarizer by hand, assigning its attributes
        # directly from the target field's categories instead of fitting it.
        category: Category = get_type(self.target_field)
        binarizer = LabelBinarizer(pos_label=1, neg_label=-1)
        binarizer.classes_ = np.array(category.categories)
        binarizer.y_type_ = type_of_target(category.categories)
        binarizer.sparse_input_ = False
        self._label_binarizer = binarizer


class PMMLMLPRegressor(PMMLBaseRegressor, PMMLBaseNeuralNetwork, MLPRegressor):
    """
    Multi-layer Perceptron regressor loaded from a PMML document.

    Parameters
    ----------
    pmml : str, object
        Filename or file object containing PMML data.

    Notes
    -----
    Specification: http://dmg.org/pmml/v4-3/NeuralNetwork.html

    """

    def __init__(self, pmml):
        # Initializers are called explicitly and in this order: the PMML base
        # first, then the scikit-learn estimator, and last the neural network
        # mixin, which assigns the attributes read from the document.
        PMMLBaseRegressor.__init__(self, pmml)
        MLPRegressor.__init__(self)
        PMMLBaseNeuralNetwork.__init__(self)
30 changes: 30 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Scratch script: score the iris data with the reference PMML engine (pypmml)
# and print every prediction at full precision.
from os import path

import numpy as np
import pandas as pd
from pypmml import Model
from sklearn.datasets import load_iris

# Earlier experiment, kept for reference (nearest-neighbour models):
#
# df = pd.read_csv(path.join('models/categorical-test.csv'))
# cats = np.unique(df['age'])
# df['age'] = pd.Categorical(df['age'], categories=cats).codes + 1
# Xte = df.iloc[:, 1:]
# yte = df.iloc[:, 0]
#
# #model = Model.load('tests/neignbors/knn-sklearn2pmml.pmml')
# model = Model.load('models/knn-reg-pima.pmml')
# results = model.predict(Xte)
# print(results)

# Print floats with 16 significant digits.
pd.set_option("display.precision", 16)

data = load_iris(as_frame=True)

X = data.data
y = data.target
y.name = "Class"

model = Model.load('models/nn-iris.pmml')
results = model.predict(X)
print(results.to_string())

0 comments on commit afc931e

Please sign in to comment.