Skip to content

Commit

Permalink
bug fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin committed Jan 6, 2018
1 parent b80fb3b commit 17cdd6a
Showing 1 changed file with 1 addition and 55 deletions.
56 changes: 1 addition & 55 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,62 +6,8 @@
import numpy as np


def test_model_rainer():
def test_model_trainer():
model = RandomConvClassifierGenerator(3, (28, 28, 1)).generate()
ModelTrainer(model, np.random.rand(2, 28, 28, 1), np.random.rand(2, 3), np.random.rand(1, 28, 28, 1),
np.random.rand(1, 3), False).train_model()


def test_extract_config1():
model_a = Sequential()

model_a.add(Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model_a.add(Conv2D(32, (3, 3), activation='relu'))
model_a.add(MaxPooling2D(pool_size=(2, 2)))
model_a.add(Dropout(0.25))

model_a.add(Flatten())
model_a.add(Dense(128, activation='relu'))
model_a.add(Dropout(0.5))
model_a.add(Dense(10, activation='softmax'))

model_b = Sequential()

model_b.add(Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model_b.add(Conv2D(32, (3, 3), activation='relu'))
model_b.add(MaxPooling2D(pool_size=(2, 2)))
model_b.add(Dropout(0.25))

model_b.add(Flatten())
model_b.add(Dense(128, activation='relu'))
model_b.add(Dropout(0.5))
model_b.add(Dense(10, activation='softmax'))

assert extract_config(model_a) == extract_config(model_b)


def test_extract_config2():
model_a = Sequential()

model_a.add(Conv2D(32, (3, 3), activation='softmax', input_shape=(28, 28, 1)))
model_a.add(Conv2D(32, (3, 3), activation='relu'))
model_a.add(MaxPooling2D(pool_size=(2, 2)))
model_a.add(Dropout(0.25))

model_a.add(Flatten())
model_a.add(Dense(128, activation='relu'))
model_a.add(Dropout(0.5))
model_a.add(Dense(10, activation='softmax'))

model_b = Sequential()

model_b.add(Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model_b.add(Conv2D(32, (3, 3), activation='relu'))
model_b.add(MaxPooling2D(pool_size=(2, 2)))
model_b.add(Dropout(0.25))

model_b.add(Flatten())
model_b.add(Dense(128, activation='relu'))
model_b.add(Dropout(0.5))
model_b.add(Dense(10, activation='softmax'))
assert extract_config(model_a) != extract_config(model_b)

0 comments on commit 17cdd6a

Please sign in to comment.