Skip to content

Commit

Permalink
postprocess probability to logits
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin committed Oct 10, 2020
1 parent c7ec1f6 commit 688632f
Show file tree
Hide file tree
Showing 10 changed files with 171 additions and 71 deletions.
13 changes: 9 additions & 4 deletions autokeras/analysers/output_analysers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def __init__(self, num_classes=None, multi_label=False, **kwargs):
self.num_classes = num_classes
self.label_encoder = None
self.multi_label = multi_label
self.encoded = False
self.labels = set()

def update(self, data):
Expand All @@ -47,9 +46,6 @@ def finalize(self):
# TODO: support raw string labels for multi-label.
self.labels = sorted(list(self.labels))

if (len(self.shape) > 1 and self.shape[1] > 1) or self.encoded_for_sigmoid():
self.encoded = True

# Infer the num_classes if not specified.
if not self.num_classes:
if self.encoded:
Expand Down Expand Up @@ -90,11 +86,20 @@ def get_expected_shape(self):
expected = [self.num_classes]
return expected

@property
def encoded(self):
    """Whether the targets are already encoded in one of the two known forms.

    `encoded_for_sigmoid` is evaluated first; `encoded_for_softmax` is only
    consulted (and returned) when the former is falsy.
    """
    sigmoid_ready = self.encoded_for_sigmoid
    return sigmoid_ready if sigmoid_ready else self.encoded_for_softmax

@property
def encoded_for_sigmoid(self):
    """Whether the collected labels are exactly the two values 0 and 1.

    Returns False as soon as the number of labels differs from two, so the
    labels are only sorted and compared when there are exactly two of them.
    """
    return len(self.labels) == 2 and sorted(self.labels) == [0, 1]

@property
def encoded_for_softmax(self):
    """Whether the target shape has a second dimension larger than one.

    A shape with fewer than two entries is never considered softmax-encoded.
    """
    if len(self.shape) <= 1:
        return False
    return self.shape[1] > 1


class RegressionAnalyser(TargetAnalyser):
def __init__(self, output_dim=None, **kwargs):
Expand Down
39 changes: 26 additions & 13 deletions autokeras/blocks/heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def __init__(
super().__init__(loss=loss, metrics=metrics, **kwargs)
# Inferred from analyser.
self._encoded = None
self._encoded_for_sigmoid = None
self._encoded_for_softmax = None
self._add_one_dimension = False
self._labels = None

Expand Down Expand Up @@ -134,42 +136,53 @@ def config_from_analyser(self, analyser):
self.num_classes = analyser.num_classes
self.loss = self.infer_loss()
self._encoded = analyser.encoded
self._encoded_for_sigmoid = analyser.encoded_for_sigmoid
self._encoded_for_softmax = analyser.encoded_for_softmax
self._add_one_dimension = len(analyser.shape) == 1
self._labels = analyser.labels

def get_hyper_preprocessors(self):
hyper_preprocessors = []

if self._add_one_dimension:
hyper_preprocessors.append(
hpps_module.DefaultHyperPreprocessor(preprocessors.AddOneDimension())
)

if self.dtype in [tf.uint8, tf.uint16, tf.uint32, tf.uint64]:
hyper_preprocessors.append(
hpps_module.DefaultHyperPreprocessor(preprocessors.CastToInt32())
)

if not self._encoded and self.dtype != tf.string:
hyper_preprocessors.append(
hpps_module.DefaultHyperPreprocessor(preprocessors.CastToString())
)
if self.multi_label:

if self._encoded_for_sigmoid:
hyper_preprocessors.append(
hpps_module.DefaultHyperPreprocessor(
preprocessors.MultiLabelEncoder()
preprocessors.SigmoidPostprocessor()
)
)
if not self._encoded:
if self.num_classes == 2 and not self.multi_label:
hyper_preprocessors.append(
hpps_module.DefaultHyperPreprocessor(
preprocessors.LabelEncoder(self._labels)
)
elif self._encoded_for_softmax:
hyper_preprocessors.append(
hpps_module.DefaultHyperPreprocessor(
preprocessors.SoftmaxPostprocessor()
)
else:
hyper_preprocessors.append(
hpps_module.DefaultHyperPreprocessor(
preprocessors.OneHotEncoder(self._labels)
)
)
elif self.num_classes == 2:
hyper_preprocessors.append(
hpps_module.DefaultHyperPreprocessor(
preprocessors.LabelEncoder(self._labels)
)
)
else:
hyper_preprocessors.append(
hpps_module.DefaultHyperPreprocessor(
preprocessors.OneHotEncoder(self._labels)
)
)
return hyper_preprocessors


Expand Down
3 changes: 3 additions & 0 deletions autokeras/engine/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def transform(self, dataset):
"""
raise NotImplementedError

def get_config(self):
    """Return an empty configuration dict.

    A fresh dict is created on every call.
    """
    config = {}
    return config


class TargetPreprocessor(Preprocessor):
"""Preprocessor for target data."""
Expand Down
7 changes: 4 additions & 3 deletions autokeras/preprocessors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@
from autokeras.preprocessors.common import LambdaPreprocessor
from autokeras.preprocessors.common import SlidingWindow
from autokeras.preprocessors.encoders import LabelEncoder
from autokeras.preprocessors.encoders import MultiLabelEncoder
from autokeras.preprocessors.encoders import OneHotEncoder
from autokeras.preprocessors.postprocessors import SigmoidPostprocessor
from autokeras.preprocessors.postprocessors import SoftmaxPostprocessor


def serialize(encoder):
return tf.keras.utils.serialize_keras_object(encoder)
def serialize(preprocessor):
return tf.keras.utils.serialize_keras_object(preprocessor)


def deserialize(config, custom_objects=None):
Expand Down
3 changes: 2 additions & 1 deletion autokeras/preprocessors/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import tensorflow as tf

from autokeras.engine import preprocessor
from autokeras.utils import data_utils


class LambdaPreprocessor(preprocessor.Preprocessor):
Expand Down Expand Up @@ -59,7 +60,7 @@ def get_config(self):
return {}

def transform(self, dataset):
return dataset.map(tf.strings.as_string)
return dataset.map(data_utils.cast_to_string)


class SlidingWindow(preprocessor.Preprocessor):
Expand Down
24 changes: 0 additions & 24 deletions autokeras/preprocessors/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,27 +118,3 @@ def postprocess(self, data):
return np.array(
list(map(lambda x: self.labels[int(round(x[0]))], np.array(data)))
).reshape(-1, 1)


class MultiLabelEncoder(Encoder):
    """Encoder for multi-label data."""

    def __init__(self, labels=None, **kwargs):
        # TODO: support custom labels.
        # `labels` is accepted but not used yet: an empty list is always
        # forwarded to the parent constructor.
        super().__init__(labels=[], **kwargs)

    def transform(self, dataset):
        """Return the dataset unchanged."""
        return dataset

    def postprocess(self, data):
        """Transform probabilities to zeros and ones.

        # Arguments
            data: numpy.ndarray. The output probabilities of the classification
                head. The array is thresholded in place.

        # Returns
            numpy.ndarray. The zeros and ones predictions (the same array
            object that was passed in).
        """
        # Use `<=` so that a probability of exactly 0.5 is mapped to 0
        # (the same tie-break as Python's `round(0.5) == 0`) instead of being
        # left untouched, which would leak a 0.5 into the predictions.
        data[data <= 0.5] = 0
        data[data > 0.5] = 1
        return data
57 changes: 57 additions & 0 deletions autokeras/preprocessors/postprocessors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2020 The AutoKeras Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from autokeras.engine import preprocessor


class PostProcessor(preprocessor.TargetPreprocessor):
    """Target preprocessor whose `transform` leaves the dataset untouched."""

    def transform(self, dataset):
        """Return the very same dataset object that was passed in."""
        return dataset


class SigmoidPostprocessor(PostProcessor):
    """Postprocessor for sigmoid outputs."""

    def postprocess(self, data):
        """Transform probabilities to zeros and ones.

        # Arguments
            data: numpy.ndarray. The output probabilities of the classification
                head. The array is thresholded in place.

        # Returns
            numpy.ndarray. The zeros and ones predictions (the same array
            object that was passed in).
        """
        # Use `<=` so that a probability of exactly 0.5 is mapped to 0
        # (the same tie-break as Python's `round(0.5) == 0`) instead of being
        # left untouched, which would leak a 0.5 into the predictions.
        data[data <= 0.5] = 0
        data[data > 0.5] = 1
        return data


class SoftmaxPostprocessor(PostProcessor):
    """Postprocessor for softmax outputs."""

    def postprocess(self, data):
        """Transform probabilities to one-hot rows of zeros and ones.

        # Arguments
            data: numpy.ndarray. The output probabilities of the classification
                head.

        # Returns
            numpy.ndarray. A newly allocated array with the shape of `data`,
            filled with zeros except for a one at the position of the largest
            value along the last axis of each row.
        """
        winners = np.argmax(data, axis=-1)
        one_hot = np.zeros(data.shape)
        row_ids = np.arange(data.shape[0])
        one_hot[row_ids, winners] = 1
        return one_hot
5 changes: 3 additions & 2 deletions tests/autokeras/blocks/heads_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,13 @@ def test_multi_label_loss():
assert head.loss.name == "binary_crossentropy"


def test_clf_head_get_multi_label_preprocessor():
def test_clf_head_get_sigmoid_postprocessor():
head = head_module.ClassificationHead(name="a", multi_label=True)
head._encoded = True
head._encoded_for_sigmoid = True
assert isinstance(
head.get_hyper_preprocessors()[0].preprocessor,
preprocessors.MultiLabelEncoder,
preprocessors.SigmoidPostprocessor,
)


Expand Down
24 changes: 0 additions & 24 deletions tests/autokeras/preprocessors/encoders_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,30 +41,6 @@ def test_one_hot_encoder_decode_to_same_string():
assert np.array_equal(result, np.array([["a"], ["b"], ["c"]]))


def test_multi_label_postprocess_to_one_hot_labels():
    """Postprocessing random probabilities yields only zeros and ones."""
    encoder = encoders.MultiLabelEncoder()
    probabilities = np.random.rand(10, 3)

    predictions = encoder.postprocess(probabilities)

    assert set(predictions.flatten().tolist()) == {0, 1}


def test_multi_label_transform_dataset_doesnt_change():
    """`transform` hands back the identical dataset object."""
    encoder = encoders.MultiLabelEncoder()
    original = tf.data.Dataset.from_tensor_slices([1, 2]).batch(32)

    transformed = encoder.transform(original)

    assert transformed is original


def test_multi_label_deserialize_without_error():
    """An encoder rebuilt from its serialized form still passes data through."""
    encoder = encoders.MultiLabelEncoder()
    original = tf.data.Dataset.from_tensor_slices([1, 2]).batch(32)

    config = preprocessors.serialize(encoder)
    restored = preprocessors.deserialize(config)

    assert restored.transform(original) is original


def test_label_encoder_decode_to_same_string():
encoder = encoders.LabelEncoder(["a", "b"])

Expand Down
67 changes: 67 additions & 0 deletions tests/autokeras/preprocessors/postprocessors_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright 2020 The AutoKeras Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import tensorflow as tf

from autokeras import preprocessors
from autokeras.preprocessors import postprocessors


def test_sigmoid_postprocess_to_zero_one():
    """Postprocessing random probabilities yields only zeros and ones."""
    postprocessor = postprocessors.SigmoidPostprocessor()
    probabilities = np.random.rand(10, 3)

    predictions = postprocessor.postprocess(probabilities)

    assert set(predictions.flatten().tolist()) == {0, 1}


def test_sigmoid_transform_dataset_doesnt_change():
    """`transform` hands back the identical dataset object."""
    postprocessor = postprocessors.SigmoidPostprocessor()
    original = tf.data.Dataset.from_tensor_slices([1, 2]).batch(32)

    transformed = postprocessor.transform(original)

    assert transformed is original


def test_sigmoid_deserialize_without_error():
    """A postprocessor rebuilt from its serialized form still passes data through."""
    postprocessor = postprocessors.SigmoidPostprocessor()
    original = tf.data.Dataset.from_tensor_slices([1, 2]).batch(32)

    config = preprocessors.serialize(postprocessor)
    restored = preprocessors.deserialize(config)

    assert restored.transform(original) is original


def test_softmax_postprocess_to_zero_one():
    """Postprocessing random probabilities yields only zeros and ones."""
    postprocessor = postprocessors.SoftmaxPostprocessor()
    probabilities = np.random.rand(10, 3)

    predictions = postprocessor.postprocess(probabilities)

    assert set(predictions.flatten().tolist()) == {0, 1}


def test_softmax_transform_dataset_doesnt_change():
    """`transform` hands back the identical dataset object."""
    postprocessor = postprocessors.SoftmaxPostprocessor()
    original = tf.data.Dataset.from_tensor_slices([1, 2]).batch(32)

    transformed = postprocessor.transform(original)

    assert transformed is original


def test_softmax_deserialize_without_error():
    """A postprocessor rebuilt from its serialized form still passes data through."""
    postprocessor = postprocessors.SoftmaxPostprocessor()
    original = tf.data.Dataset.from_tensor_slices([1, 2]).batch(32)

    config = preprocessors.serialize(postprocessor)
    restored = preprocessors.deserialize(config)

    assert restored.transform(original) is original

0 comments on commit 688632f

Please sign in to comment.