
Conditional hyperparameter tuning bug #66

Closed
rcmagic1 opened this issue Aug 26, 2019 · 11 comments

rcmagic1 commented Aug 26, 2019

I'm using Keras-Tuner to run trials on a multi-layer NN with a variable number of layers and a variable number of units within each layer, similar to the example in the README:

for i in range(hp.Int('num_layers', 2, 20)):
        model.add(layers.Dense(units=hp.Int('units_' + str(i),
                                            min_value=32,
                                            max_value=512,
                                            step=32),
                               activation='relu'))

The "units_#" hyperpameter should be conditional upon "num_layer" hyperparameter. E.g.if "num_layers=2" then I should see "units_0" and "units_1". However in my testing I'm not seeing proper correlation (num_layers doesn't match the number of units_# hyperparameter values set). Instead I see something like the following:

[Trial summary]

Hp values:
|-num_fc_layers: 2
|-num_units_0: ...
|-num_units_1: ...
|-num_units_2: ...
|-num_units_3: ...
|-num_units_4: ...

or

[Trial summary]

Hp values:
|-num_fc_layers: 5
|-num_units_0: ...
|-num_units_1: ...
|-num_units_2: ...

This effectively makes the summary of hyperparameters used in a trial useless.
I did some debugging of the code but haven't found the culprit yet.
I'm using "randomsearch" tuner and wrapped my model build in HyperModel class (rather than function method).

Could someone please take a look? Thank you.

@omalleyt12 (Contributor)

To clarify, you would like the Trial summary to only show HyperParameters that were used in the Python code for this particular trial? I think this is possible but will require careful thought.

In general, the Oracle will attempt to provide a value for any HyperParameter it has seen so far. There's no way for the Oracle to know from your code that units_4 will only be active when num_layers > 4, so doing this would require making special note of which hyperparameters were actually accessed during build_model.
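One way to approximate that bookkeeping from user code, as a sketch only (the wrapper class below is hypothetical, not part of keras-tuner), is to record every hyperparameter name the build function actually reads:

import kerastuner as kt

class TrackingHyperParameters:
    """Hypothetical wrapper that records which hyperparameter names a build reads."""

    _TRACKED = ('Int', 'Float', 'Choice', 'Boolean', 'Fixed')

    def __init__(self, hp):
        self._hp = hp
        self.accessed = set()

    def __getattr__(self, name):
        attr = getattr(self._hp, name)
        if name in self._TRACKED:
            def wrapped(hp_name, *args, **kwargs):
                # remember the name, then delegate to the real HyperParameters method
                self.accessed.add(hp_name)
                return attr(hp_name, *args, **kwargs)
            return wrapped
        return attr

hp = kt.HyperParameters()
tracked = TrackingHyperParameters(hp)
tracked.Int('units_0', min_value=32, max_value=512, step=32)
print(tracked.accessed)  # {'units_0'}

Logging tracked.accessed next to each trial would show which hyperparameters were really used by that build, independent of what the Oracle has registered.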

We do have a concept of explicitly conditional hyperparameters (https://github.com/keras-team/keras-tuner/blob/master/kerastuner/engine/hyperparameters.py#L393), but right now that is not reflected in the Trial summary


rcmagic1 commented Aug 28, 2019

Yes, my expectation was that the trial summary would show all the hyperparameter settings used for that trial. It's fine if too many hp values are shown (e.g. if num_layers == 4 but units_[0-10] were shown), but in the case where too few are shown I get an incomplete view of that model (e.g. if num_layers == 4 but only units_[0-1] are shown; I would also like to see units_[2-3]).

In the meantime my fix is to print( model.summary() ) in my build function to verify the model architecture, but it's much more verbose than the trial summary.

Would the num_layers and units_# hps discussed above be considered conditional hps? Or is there another mechanism to explicitly define the conditional relationship between hps?

Thanks again.

@omalleyt12 (Contributor)

Thanks, will look into this

The mechanism for specifying conditional hyperparameters is:

a = hp.Int('a', 0, 10)
with hp.conditional_scope('a', 5):
  b = hp.Int('b', 0, 10)

With that syntax, b would only be active when a == 5.
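Applied to the layered example from this issue, that could look roughly like the sketch below (illustrative only, and it assumes conditional_scope also accepts a list of parent values), so that units_i is only active when num_layers is at least i + 1:

import tensorflow as tf
import kerastuner as kt

def build_model(hp):
    model = tf.keras.Sequential()
    num_layers = hp.Int('num_layers', 2, 20)
    for i in range(num_layers):
        # units_i is only meaningful for values of num_layers that reach layer i
        with hp.conditional_scope('num_layers', list(range(max(i + 1, 2), 21))):
            model.add(tf.keras.layers.Dense(
                units=hp.Int('units_' + str(i), min_value=32, max_value=512, step=32),
                activation='relu'))
    model.add(tf.keras.layers.Dense(1))
    model.compile(optimizer='adam', loss='mse')
    return model

Whether those conditions show up in the Trial summary still depends on the display logic discussed above.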

@rcmagic1 (Author)

Thank you!

@omalleyt12 (Contributor)

All HyperParameters should be shown now, can you please try with the master branch?

In the future, for display purposes we may hide hyperparameter values that were set but never accessed during a trial, but closing this issue for now as that's more of an enhancement

@rcmagic1 (Author)

Hi @omalleyt12 thank you.
I just tried the master branch and am now NOT seeing any of the hyperparameters nested in a for loop. If you recall, before I was seeing some of the nested hyperparameters, just not all the relevant ones.
For example, going back to code similar to my original example, if I have the following code:

for i in range(hp.Int('num_fc_layers', 1, 4)):
    model.add(layers.Dense(units=hp.Int('units_' + str(i),
                                        min_value=10,
                                        max_value=80,
                                        step=10),
                           activation='relu'))

Previously I saw an incomplete list of "units_" printed out; however, now I see no "units_" printed out.

[Trial complete]
[Trial summary]

Hp values:
|-num_fc_layers: 6
|-Score: 2.2857720851898193
|-Best step: 0

This isn't a major showstopper for me at this point, but I wanted to follow up to let you know. Thanks again.

omalleyt12 reopened this Oct 10, 2019
@omalleyt12 (Contributor)

@rcmagic1 Thanks for letting me know, could you provide a minimal reproduction including the Tuner code you are using?

When I run the examples below I see the expected values (the first example takes on the default value for 'num_fc_layers' so only has 'units_0'; the second has all HPs):

import tensorflow as tf
import kerastuner as kt

def build_model(hp):
    model = tf.keras.Sequential()
    for i in range(hp.Int('num_fc_layers', 1, 4)):
        model.add(tf.keras.layers.Dense(
            units=hp.Int('units_' + str(i),
                         min_value=10,
                         max_value=80,
                         step=10),
            activation='relu'))
    return model

hp = kt.HyperParameters()
build_model(hp)
print(hp.values)

hp = kt.HyperParameters()
hp.Fixed('num_fc_layers', 4)
build_model(hp)
print(hp.values)
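For reference, with those bounds the two prints should show roughly the following (hp.Int falls back to min_value as its default when none is given; exact dict ordering may differ by version):

{'num_fc_layers': 1, 'units_0': 10}
{'num_fc_layers': 4, 'units_0': 10, 'units_1': 10, 'units_2': 10, 'units_3': 10}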

@rcmagic1 (Author)

Certainly:

# define NN model architecture for hyperparameter optimization/tuning
# alternatively a function can be used; however, a class allows additional args (e.g. num_features, num_classes) to be passed in
class NNHyperModel(HyperModel):

    def __init__( self,
                  num_features,
                  num_classes ):
        self.num_features  = num_features
        self.num_classes   = num_classes

    def build(self, hp):
        # note: hp.Int behaves similar to the "range" function where the max_value is excluded, so need to add 1 to include it
        num_fc_layers_min  = 1
        num_fc_layers_max  = 4

        num_units_min  =  10
        num_units_max  =  80
        num_units_step =  10

        dropout_min  =  0
        dropout_max  =  0.3
        dropout_step =  0.1

        # construct model
        model = Sequential()

        # Note: Keras bug with InputLayer, so use a separate Dense layer with input_dim for now: https://github.com/keras-team/keras/issues/10417
        #model.add( Input( shape=(self.num_features,) ) )
        model.add( Dense( input_dim=self.num_features, units=hp.Int('num_units', min_value=num_units_min, max_value=num_units_max+1, step=num_units_step), activation='relu', kernel_initializer='he_normal' ) )
        model.add( Dropout( hp.Float('dropout', min_value=dropout_min, max_value=dropout_max, step=dropout_step) ) )
        for i in range( hp.Int('num_fc_layers', min_value=num_fc_layers_min-1, max_value=num_fc_layers_max) ):
            model.add( Dense( units=hp.Int('num_units_'+str(i+1), min_value=num_units_min, max_value=num_units_max+1, step=num_units_step), activation='relu', kernel_initializer='he_normal' ) )
            model.add( Dropout( hp.Float('dropout_'+str(i+1), min_value=dropout_min, max_value=dropout_max, step=dropout_step) ) )
        model.add( Dense( units=self.num_classes, activation='softmax', kernel_initializer='he_normal' ) )

        # compile model
        model.compile( optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'] )

        print( model.summary() )

        return model

##########
hypermodel = NNHyperModel( num_features = num_features,
                           num_classes  = num_classes )
##########
tuner = RandomSearch( hypermodel,
                      objective='val_loss',
                      max_trials=200,
                      executions_per_trial=2,
                      directory=CACHE_DIR,
                      project_name=join(CACHE_DIR,"hypertuning_randomsearch_dir") )
tuner.search_space_summary()

early_stopping = EarlyStopping(monitor='val_loss', patience=5, verbose=1)
callback_list = [ early_stopping ]

# split training data into stratified train/dev sets
X_train, X_dev, Y_train, Y_dev = train_test_split( X_training_set, Y_training_set, test_size=validation_split, stratify=Y_training_set, random_state=42 )

tuner.search( x=X_train, y=Y_train,
              validation_data=(X_dev, Y_dev),
              batch_size=32, epochs=300, callbacks=callback_list )

model = tuner.get_best_models(num_models=1)[0]

tuner.results_summary()

and I get results like this. Note the missing num_units_1 and num_units_2 when num_fc_layers=2:

[Trial complete]
[Trial summary]

Hp values:
|-dropout: 0.30000000000000004
|-num_fc_layers: 2
|-num_units: 80
|-tuner/epochs: 10
|-Score: 2.1471301317214966
|-Best step: 0

@omalleyt12 (Contributor)

@rcmagic1 Thanks for the repro!!

Ah ok, yep, I found a bug in the MultiExecutionTuner (which RandomSearch etc. rely on): the HyperParameters were being copied before being sent to build_model, so the Oracle's space wasn't updating properly after the initial build.

This issue should be fixed by #113
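For context, the failure mode can be pictured with a minimal sketch (hypothetical names, not the actual keras-tuner internals): when the tuner hands build_model a copy of the trial's HyperParameters, any hyperparameters first registered during the build exist only on that copy unless they are merged back into the shared space afterwards.

import copy

def run_trial_buggy(shared_hp, build_model):
    # Buggy pattern: new HPs registered during build live only on the
    # throwaway copy, so the shared space never learns about them.
    hp_copy = copy.deepcopy(shared_hp)
    return build_model(hp_copy)

def run_trial_fixed(shared_hp, build_model):
    # Fixed pattern: after building, copy any newly registered HPs back
    # into the shared space so later trials can sample them too.
    hp_copy = copy.deepcopy(shared_hp)
    model = build_model(hp_copy)
    shared_hp.merge(hp_copy)  # hypothetical merge step; the real fix is in the PR below
    return model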

@rcmagic1 (Author)

Great! Let me know when the fix is available and I'll validate it. Thanks again.

@omalleyt12 (Contributor)

@rcmagic1 Thanks! The fix is available now, please let me know if the HyperParameters look right.
