From 1a449cad40f53a5fee15930896e1b5ac98cdcd5e Mon Sep 17 00:00:00 2001 From: droidadroit Date: Fri, 16 Nov 2018 09:57:30 -0600 Subject: [PATCH] Added documentation --- autokeras/net_module.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/autokeras/net_module.py b/autokeras/net_module.py index f50c537e6..886be676e 100644 --- a/autokeras/net_module.py +++ b/autokeras/net_module.py @@ -8,6 +8,17 @@ class NetworkModule: + """ Class to create a network module. + + Attributes: + loss: A function taking two parameters, the predictions and the ground truth. + metric: An instance of the Metric subclasses. + searcher_args: A dictionary containing the parameters for the searcher's __init__ function. + searcher: An instance of the Searcher class. + path: A string. The path to the directory to save the searcher. + verbose: A boolean. Setting it to true prints to stdout. + generators: A list of instances of the NetworkGenerator class or its subclasses. + """ def __init__(self, loss, metric, searcher_args, path, verbose=False): self.searcher_args = searcher_args self.searcher = None @@ -18,14 +29,14 @@ def __init__(self, loss, metric, searcher_args, path, verbose=False): self.generators = [] def fit(self, n_output_node, input_shape, train_data, test_data, time_limit=24 * 60 * 60): - """ Search the best CnnModule. + """ Search the best network. Args: n_output_node: A integer value represent the number of output node in the final layer. input_shape: A tuple to express the shape of every train entry. For example, - MNIST dataset would be (28,28,1) - train_data: A PyTorch DataLoader instance represents the training data - test_data: A PyTorch DataLoader instance represents the testing data + MNIST dataset would be (28,28,1). + train_data: A PyTorch DataLoader instance representing the training data. + test_data: A PyTorch DataLoader instance representing the testing data. time_limit: A integer value represents the time limit on searching for models. """ # Create the searcher and save on disk @@ -66,8 +77,8 @@ def final_fit(self, train_data, test_data, trainer_args=None, retrain=False): Args: trainer_args: A dictionary containing the parameters of the ModelTrainer constructor. retrain: A boolean of whether reinitialize the weights of the model. - train_data: A DataLoader instance representing the training data - test_data: A DataLoader instance representing the testing data + train_data: A DataLoader instance representing the training data. + test_data: A DataLoader instance representing the testing data. """ graph = self.searcher.load_best_model() @@ -91,12 +102,16 @@ def best_model(self): class CnnModule(NetworkModule): + """ Class to create a CNN module. + """ def __init__(self, loss, metric, searcher_args, path, verbose=False): super(CnnModule, self).__init__(loss, metric, searcher_args, path, verbose) self.generators.append(CnnGenerator) class MlpModule(NetworkModule): + """ Class to create an MLP module. + """ def __init__(self, loss, metric, searcher_args, path, verbose=False): super(MlpModule, self).__init__(loss, metric, searcher_args, path, verbose) self.generators.append(MlpGenerator)