Skip to content

Commit

Permalink
[MRG] Evaluation criteria (#568)
Browse files Browse the repository at this point in the history
* added evaluation function

* improved code quality

* changed condition of assertion
  • Loading branch information
droidadroit authored and haifeng-jin committed Mar 13, 2019
1 parent 9a65cd6 commit ec8b348
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
19 changes: 19 additions & 0 deletions autokeras/net_module.py
Expand Up @@ -11,6 +11,7 @@

from autokeras.utils import pickle_to_file, rand_temp_folder_generator, ensure_dir
from autokeras.nn.generator import CnnGenerator, MlpGenerator, ResNetGenerator, DenseNetGenerator
from autokeras.utils import get_device


class NetworkModule:
Expand Down Expand Up @@ -123,6 +124,24 @@ def predict(self, test_loader):
output = reduce(lambda x, y: np.concatenate((x, y)), outputs)
return output

def evaluate(self, test_data):
"""Evaluate the performance of the best architecture in terms of the loss.
Args:
test_data: A DataLoader instance representing the testing data.
"""
model = self.best_model.produce_model()
model.eval()
device = get_device()
target, prediction = [], []

with torch.no_grad():
for _, (x, y) in enumerate(test_data):
x, y = x.to(device), y.to(device)
prediction.append(model(x))
target.append(y)
return self.metric().compute(prediction, target)


class CnnModule(NetworkModule):
""" Class to create a CNN module."""
Expand Down
3 changes: 3 additions & 0 deletions tests/test_preprocessor.py
Expand Up @@ -24,4 +24,7 @@ def test_batch_dataset(_, _1):
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
cnn = CnnModule(classification_loss, Accuracy, {}, TEST_TEMP_DIR, True)
cnn.fit(2, (4, 250, 250, 3), train_dataloader, test_dataloader, 12 * 60 * 60)
score = cnn.evaluate(test_dataloader)
if score < 0 or score > 1.0:
raise AssertionError()
clean_dir(TEST_TEMP_DIR)

0 comments on commit ec8b348

Please sign in to comment.