Skip to content

Commit

Permalink
Refactor the classes extension structure for code reuse (#486)
Browse files Browse the repository at this point in the history
  • Loading branch information
chengchengeasy authored and haifeng-jin committed Jan 26, 2019
1 parent e231e7c commit fcb45df
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 137 deletions.
4 changes: 2 additions & 2 deletions autokeras/image/image_supervised.py
Expand Up @@ -7,7 +7,7 @@
from autokeras.nn.loss_function import classification_loss, regression_loss
from autokeras.nn.metric import Accuracy, MSE
from autokeras.preprocessor import OneHotEncoder, ImageDataTransformer
from autokeras.supervised import PortableDeepSupervised, DeepSupervised
from autokeras.supervised import PortableDeepSupervised, DeepTaskSupervised
from autokeras.utils import pickle_to_file, \
read_csv_file, read_image, compute_image_resize_params, resize_image_data

Expand Down Expand Up @@ -77,7 +77,7 @@ def load_image_dataset(csv_file_path, images_path, parallel=True):
return np.array(x), np.array(y)


class ImageSupervised(DeepSupervised, ABC):
class ImageSupervised(DeepTaskSupervised, ABC):
"""Abstract image supervised class.
Attributes:
Expand Down
82 changes: 28 additions & 54 deletions autokeras/predefined_model.py
@@ -1,44 +1,36 @@
import numpy as np
import torch
from abc import ABC, abstractmethod
from functools import reduce
from sklearn.model_selection import train_test_split

from autokeras.utils import rand_temp_folder_generator, validate_xy
from autokeras.utils import rand_temp_folder_generator, validate_xy, resize_image_data, compute_image_resize_params
from autokeras.nn.metric import Accuracy
from autokeras.nn.loss_function import classification_loss
from autokeras.nn.generator import ResNetGenerator, DenseNetGenerator
from autokeras.search import train
from autokeras.constant import Constant
from autokeras.preprocessor import ImageDataTransformer, OneHotEncoder
from autokeras.supervised import SingleModelSupervised


class PredefinedModel(ABC):
class PredefinedModel(SingleModelSupervised):
"""The base class for the predefined model without architecture search
Attributes:
graph: The graph form of the model.
y_encoder: Label encoder, used in transform_y or inverse_transform_y for encode the label. For example,
if one hot encoder needed, y_encoder can be OneHotEncoder.
data_transformer_class: A transformer class to process the data. See example as ImageDataTransformer.
data_transformer: A instance of data_transformer_class.
data_transformer: A instance of transformer to process the data, See example as ImageDataTransformer.
verbose: A boolean of whether the search process will be printed to stdout.
path: A string. The path to a directory, where the intermediate results are saved.
"""

def __init__(self, y_encoder=None, data_transformer=None, verbose=False, path=None):
    """Initialize the predefined model.

    Args:
        y_encoder: Label encoder used by transform_y / inverse_transform_y.
            When omitted (None), a new OneHotEncoder is created for this instance.
        data_transformer: Stored as-is on the instance as `data_transformer`.
        verbose: Passed on to the parent initializer.
        path: Passed on to the parent initializer.
    """
    super().__init__(verbose, path)
    self.graph = None
    self.generator = None
    # Build the default encoder here rather than in the signature: a default
    # written as `OneHotEncoder()` is evaluated once at definition time, so
    # every model would share (and re-fit) the same encoder object.
    self.y_encoder = OneHotEncoder() if y_encoder is None else y_encoder
    self.data_transformer = data_transformer

@abstractmethod
def _init_generator(self, n_output_node, input_shape):
Expand All @@ -50,16 +42,22 @@ def _init_generator(self, n_output_node, input_shape):
"""
pass

def compile(self, loss=classification_loss, metric=Accuracy):
"""Configures the model for training.
@property
def loss(self):
    """Loss function used for training; always `classification_loss`."""
    return classification_loss

Args:
loss: The loss function to train the model. See example as classification_loss.
metric: The metric to be evaluted by the model during training and testing.
See example as Accuracy.
"""
self.loss = loss
self.metric = metric
@property
def metric(self):
    """Metric class used for training and evaluation; always `Accuracy`."""
    return Accuracy

def preprocess(self, x):
    """Resize `x` with resize_image_data using `self.resize_shape`.

    Args:
        x: The data to resize.

    Returns:
        Whatever resize_image_data returns for `x` and `self.resize_shape`.
    """
    target_shape = self.resize_shape
    return resize_image_data(x, target_shape)

def transform_y(self, y_train):
    """Encode the labels with this instance's `y_encoder`.

    Args:
        y_train: The labels to encode.

    Returns:
        The result of `self.y_encoder.transform(y_train)`.
    """
    encoder = self.y_encoder
    return encoder.transform(y_train)

def inverse_transform_y(self, output):
    """Decode model output back to labels with this instance's `y_encoder`.

    Args:
        output: The encoded values to decode.

    Returns:
        The result of `self.y_encoder.inverse_transform(output)`.
    """
    encoder = self.y_encoder
    return encoder.inverse_transform(output)

def fit(self, x, y, trainer_args=None):
"""Trains the model on the dataset given.
Expand All @@ -72,8 +70,10 @@ def fit(self, x, y, trainer_args=None):
trainer_args: A dictionary containing the parameters of the ModelTrainer constructor.
"""
validate_xy(x, y)
self.resize_shape = compute_image_resize_params(x)
x = self.preprocess(x)
self.y_encoder.fit(y)
y = self.y_encoder.transform(y)
y = self.transform_y(y)
# Divide training data into training and testing data.
validation_set_size = int(len(y) * Constant.VALIDATION_SET_SIZE)
validation_set_size = min(validation_set_size, 500)
Expand All @@ -83,7 +83,7 @@ def fit(self, x, y, trainer_args=None):
random_state=42)

# initialize data_transformer
self.data_transformer = self.data_transformer_class(x_train)
self.data_transformer = ImageDataTransformer(x_train)
# Wrap the data into DataLoaders
train_loader = self.data_transformer.transform_train(x_train, y_train)
test_loader = self.data_transformer.transform_test(x_test, y_test)
Expand All @@ -97,32 +97,6 @@ def fit(self, x, y, trainer_args=None):
trainer_args, self.metric, self.loss,
self.verbose, self.path)

def predict(self, x_test):
"""Return predict results for the testing data.
Args:
x_test: An instance of numpy.ndarray containing the testing data.
Returns:
A numpy.ndarray containing the results.
"""
test_loader = self.data_transformer.transform_test(x_test)
model = self.graph.produce_model()
model.eval()

outputs = []
with torch.no_grad():
for index, inputs in enumerate(test_loader):
outputs.append(model(inputs).numpy())
output = reduce(lambda x, y: np.concatenate((x, y)), outputs)
return self.y_encoder.inverse_transform(output)

def evaluate(self, x_test, y_test):
"""Return the accuracy score between predict value and `y_test`.
"""
y_predict = self.predict(x_test)
return self.metric().evaluate(y_predict, y_test)


class PredefinedResnet(PredefinedModel):
def _init_generator(self, n_output_node, input_shape):
Expand Down
160 changes: 82 additions & 78 deletions autokeras/supervised.py
Expand Up @@ -42,19 +42,6 @@ def fit(self, x, y, time_limit=None):
time_limit: The time limit for the search in seconds.
"""

@abstractmethod
def final_fit(self, x_train, y_train, x_test, y_test, trainer_args=None, retrain=False):
"""Final training after found the best architecture.
Args:
x_train: A numpy.ndarray of training data.
y_train: A numpy.ndarray of training targets.
x_test: A numpy.ndarray of testing data.
y_test: A numpy.ndarray of testing targets.
trainer_args: A dictionary containing the parameters of the ModelTrainer constructor.
retrain: A boolean of whether reinitialize the weights of the model.
"""

@abstractmethod
def predict(self, x_test):
"""Return predict results for the testing data.
Expand All @@ -73,7 +60,25 @@ def evaluate(self, x_test, y_test):
pass


class DeepSupervised(Supervised):
class SearchSupervised(Supervised):
    """Base class for every supervised task that uses architecture search."""

    @abstractmethod
    def final_fit(self, x_train, y_train, x_test, y_test, trainer_args=None, retrain=False):
        """Run the final training once the best architecture has been found.

        Args:
            x_train: A numpy.ndarray of training data.
            y_train: A numpy.ndarray of training targets.
            x_test: A numpy.ndarray of testing data.
            y_test: A numpy.ndarray of testing targets.
            trainer_args: A dictionary containing the parameters of the
                ModelTrainer constructor.
            retrain: A boolean of whether to reinitialize the weights of the model.
        """


class DeepTaskSupervised(SearchSupervised):

def __init__(self, verbose=False, path=None, resume=False, searcher_args=None,
search_type=BayesianSearcher):
Expand Down Expand Up @@ -209,80 +214,21 @@ def evaluate(self, x_test, y_test):
return self.metric().evaluate(y_predict, y_test)


class PortableClass(ABC):
def __init__(self, graph, verbose=False):
self.graph = graph
self.verbose = verbose

@abstractmethod
def fit(self, **kwargs):
"""further training of the model (graph).
"""
pass

@abstractmethod
def predict(self, x_test):
"""Return predict results for the testing data.
Args:
x_test: An instance of numpy.ndarray containing the testing data.
Returns:
A numpy.ndarray containing the results.
"""
pass

@abstractmethod
def evaluate(self, x_test, y_test):
"""Return the accuracy score between predict value and `y_test`."""
pass


class PortableDeepSupervised(PortableClass):
def __init__(self, graph, y_encoder, data_transformer, verbose=False, path=None):
class SingleModelSupervised(Supervised):
"""The base class for all supervised task without architecture search.
"""
def __init__(self, verbose=False, path=None):
    """Initialize the instance.

    Args:
        verbose: Passed on to the parent initializer.
        path: A string. The path to a directory, where the intermediate
            results are saved. When None, the value returned by
            rand_temp_folder_generator() is used instead.
    """
    super().__init__(verbose)
    self.path = rand_temp_folder_generator() if path is None else path

def fit(self, x_train, y_train, x_test, y_test, trainer_args=None, retrain=False):
"""further training of the model (graph).
Args:
x_train: A numpy.ndarray of training data.
y_train: A numpy.ndarray of training targets.
x_test: A numpy.ndarray of testing data.
y_test: A numpy.ndarray of testing targets.
trainer_args: A dictionary containing the parameters of the ModelTrainer constructor.
retrain: A boolean of whether reinitialize the weights of the model.
"""
x_train = self.preprocess(x_train)
x_test = self.preprocess(x_test)
if trainer_args is None:
trainer_args = {'max_no_improvement_num': 30}

y_train = self.transform_y(y_train)
y_test = self.transform_y(y_test)

train_data = self.data_transformer.transform_train(x_train, y_train)
test_data = self.data_transformer.transform_test(x_test, y_test)

if retrain:
self.graph.weighted = False
_, _1, self.graph = train(None, self.graph, train_data, test_data, trainer_args,
self.metric, self.loss, self.verbose, self.path)

@property
@abstractmethod
def metric(self):
Expand Down Expand Up @@ -333,3 +279,61 @@ def evaluate(self, x_test, y_test):
"""Return the accuracy score between predict value and `y_test`."""
y_predict = self.predict(x_test)
return self.metric().evaluate(y_predict, y_test)

def save(self, model_path):
    """Save the model in Keras format.

    Args:
        model_path: The path the Keras model is saved to.
    """
    keras_model = self.graph.produce_keras_model()
    keras_model.save(model_path)


class PortableDeepSupervised(SingleModelSupervised):
    """Single-model supervised wrapper around an already-built graph."""

    def __init__(self, graph, y_encoder, data_transformer, verbose=False, path=None):
        """Store the graph together with its label encoder and data transformer.

        Args:
            graph: The graph form of the learned model.
            y_encoder: The encoder of the label. See example as OneHotEncoder.
            data_transformer: Transformer whose transform_train / transform_test
                methods wrap the data for training. See example as
                ImageDataTransformer.
            verbose: Passed on to the parent initializer.
            path: Passed on to the parent initializer.
        """
        super().__init__(verbose, path)
        self.graph = graph
        self.y_encoder = y_encoder
        self.data_transformer = data_transformer

    def fit(self, x, y, trainer_args=None, retrain=False):
        """Train the stored graph further on the given dataset.

        Args:
            x: A numpy.ndarray instance containing the training data, or the
                training data combined with the validation data.
            y: A numpy.ndarray instance containing the labels of the training
                data, or those labels combined with the validation labels.
            trainer_args: A dictionary containing the parameters of the
                ModelTrainer constructor. Defaults to
                {'max_no_improvement_num': 30} when None.
            retrain: A boolean of whether to reinitialize the weights of the
                model; when true, `graph.weighted` is set to False before training.
        """
        if trainer_args is None:
            trainer_args = {'max_no_improvement_num': 30}

        x = self.preprocess(x)

        # Hold out len(y) * Constant.VALIDATION_SET_SIZE samples for validation,
        # clamped to the range [1, 500].
        holdout = int(len(y) * Constant.VALIDATION_SET_SIZE)
        holdout = max(min(holdout, 500), 1)
        x_train, x_test, y_train, y_test = train_test_split(
            x, y, test_size=holdout, random_state=42)

        encoded_train = self.transform_y(y_train)
        encoded_test = self.transform_y(y_test)

        train_data = self.data_transformer.transform_train(x_train, encoded_train)
        test_data = self.data_transformer.transform_test(x_test, encoded_test)

        if retrain:
            self.graph.weighted = False
        # train() returns three values; only the third (the graph) is kept.
        _first, _second, self.graph = train(
            None, self.graph, train_data, test_data, trainer_args,
            self.metric, self.loss, self.verbose, self.path)
4 changes: 2 additions & 2 deletions autokeras/text/text_supervised.py
Expand Up @@ -5,11 +5,11 @@
from autokeras.nn.loss_function import classification_loss, regression_loss
from autokeras.nn.metric import Accuracy, MSE
from autokeras.preprocessor import OneHotEncoder, TextDataTransformer
from autokeras.supervised import DeepSupervised
from autokeras.supervised import DeepTaskSupervised
from autokeras.text.text_preprocessor import text_preprocess


class TextSupervised(DeepSupervised, ABC):
class TextSupervised(DeepTaskSupervised, ABC):
"""TextClassifier class.
Attributes:
Expand Down
2 changes: 1 addition & 1 deletion tests/image/test_image_supervised.py
Expand Up @@ -216,7 +216,7 @@ def test_export_keras_model(_, _1):
score = model.evaluate(train_x, train_y)
assert score <= 1.0
before = model.graph
model.fit(train_x, train_y, train_x, train_y)
model.fit(train_x, train_y)
assert model.graph == before
clean_dir(TEST_TEMP_DIR)

Expand Down

0 comments on commit fcb45df

Please sign in to comment.