Skip to content

Commit

Permalink
Added documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
droidadroit committed Nov 16, 2018
1 parent 2fcc50b commit 1a449ca
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions autokeras/net_module.py
Expand Up @@ -8,6 +8,17 @@


class NetworkModule:
""" Class to create a network module.
Attributes:
loss: A function taking two parameters, the predictions and the ground truth.
metric: An instance of the Metric subclasses.
searcher_args: A dictionary containing the parameters for the searcher's __init__ function.
searcher: An instance of the Searcher class.
path: A string. The path to the directory to save the searcher.
verbose: A boolean. Setting it to true prints to stdout.
generators: A list of instances of the NetworkGenerator class or its subclasses.
"""
def __init__(self, loss, metric, searcher_args, path, verbose=False):
self.searcher_args = searcher_args
self.searcher = None
Expand All @@ -18,14 +29,14 @@ def __init__(self, loss, metric, searcher_args, path, verbose=False):
self.generators = []

def fit(self, n_output_node, input_shape, train_data, test_data, time_limit=24 * 60 * 60):
""" Search the best CnnModule.
""" Search the best network.
Args:
n_output_node: A integer value represent the number of output node in the final layer.
input_shape: A tuple to express the shape of every train entry. For example,
MNIST dataset would be (28,28,1)
train_data: A PyTorch DataLoader instance represents the training data
test_data: A PyTorch DataLoader instance represents the testing data
MNIST dataset would be (28,28,1).
train_data: A PyTorch DataLoader instance representing the training data.
test_data: A PyTorch DataLoader instance representing the testing data.
time_limit: A integer value represents the time limit on searching for models.
"""
# Create the searcher and save on disk
Expand Down Expand Up @@ -66,8 +77,8 @@ def final_fit(self, train_data, test_data, trainer_args=None, retrain=False):
Args:
trainer_args: A dictionary containing the parameters of the ModelTrainer constructor.
retrain: A boolean of whether reinitialize the weights of the model.
train_data: A DataLoader instance representing the training data
test_data: A DataLoader instance representing the testing data
train_data: A DataLoader instance representing the training data.
test_data: A DataLoader instance representing the testing data.
"""
graph = self.searcher.load_best_model()
Expand All @@ -91,12 +102,16 @@ def best_model(self):


class CnnModule(NetworkModule):
""" Class to create a CNN module.
"""
def __init__(self, loss, metric, searcher_args, path, verbose=False):
super(CnnModule, self).__init__(loss, metric, searcher_args, path, verbose)
self.generators.append(CnnGenerator)


class MlpModule(NetworkModule):
""" Class to create an MLP module.
"""
def __init__(self, loss, metric, searcher_args, path, verbose=False):
super(MlpModule, self).__init__(loss, metric, searcher_args, path, verbose)
self.generators.append(MlpGenerator)

0 comments on commit 1a449ca

Please sign in to comment.