@@ -81,7 +81,7 @@ def on_epoch_end(self, epoch, logs):
8181
8282
8383
84- def train_model (out_dir , fold , builder ,
84+ def train_model (out_dir , train , val , builder ,
8585 loader , val_loader , settings , seed = 1 ):
8686 """Train a single model"""
8787
@@ -93,7 +93,7 @@ def train_model(out_dir, fold, builder,
9393 batch_size = settings ['batch' ]
9494 learning_rate = settings .get ('learning_rate' , 0.01 )
9595
96- train , val = fold
96+ assert len ( train ) > len ( val ) * 5 , 'training data should be much larger than validation'
9797
9898 def top3 (y_true , y_pred ):
9999 return keras .metrics .top_k_categorical_accuracy (y_true , y_pred , k = 3 )
@@ -159,7 +159,7 @@ def parse(args):
159159 common .add_arguments (parser )
160160 Settings .add_arguments (parser )
161161
162- a ('--fold' , type = int , default = 0 ,
162+ a ('--fold' , type = int , default = 1 ,
163163 help = '' )
164164 a ('--skip_model_check' , action = 'store_true' , default = False ,
165165 help = 'Skip checking whether model fits on STM32 device' )
@@ -182,6 +182,17 @@ def setup_keras():
182182 sess = tf .Session (config = session_config )
183183 B .set_session (sess )
184184
185+ def load_training_data (data , fold ):
186+ assert fold >= 1 # should be 1 indexed
187+ folds = urbansound8k .folds (data )
188+ assert len (folds ) == 10
189+ train_data = folds [fold - 1 ][0 ]
190+ val_data = folds [fold - 1 ][1 ]
191+ test_folds = folds [fold - 1 ][2 ].fold .unique ()
192+ assert len (test_folds ) == 1
193+ assert test_folds [0 ] == fold , (test_folds [0 ], '!=' , fold ) # by convention, test fold is fold number
194+ return train_data , val_data
195+
185196def main ():
186197 setup_keras ()
187198
@@ -216,10 +227,8 @@ def main():
216227
217228 features .maybe_download (feature_settings , feature_dir )
218229
219-
220230 data = urbansound8k .load_dataset ()
221- folds , test = urbansound8k .folds (data )
222- assert len (folds ) == 9
231+ train_data , val_data = load_training_data (data , fold )
223232
224233 def load (sample , validation ):
225234 augment = not validation and train_settings ['augment' ] != 0
@@ -245,7 +254,7 @@ def build_model():
245254 print ('Training model' , name )
246255 print ('Settings' , json .dumps (exsettings ))
247256
248- h = train_model (output_dir , folds [ fold ] ,
257+ h = train_model (output_dir , train_data , val_data ,
249258 builder = build_model ,
250259 loader = functools .partial (load , validation = False ),
251260 val_loader = functools .partial (load , validation = True ),
0 commit comments