77import json
88import functools
99import datetime
10+ import csv
1011
1112import pandas
1213import numpy
1314import keras
1415import librosa
16+ import sklearn .metrics
1517
1618from . import features , urbansound8k , common
1719from .models import sbcnn
@@ -38,28 +40,90 @@ def dataframe_generator(X, Y, loader, batchsize=10, n_classes=10):
3840 yield batch
3941
4042
class LogCallback(keras.callbacks.Callback):
    """Keras callback that appends per-epoch metrics to a CSV file.

    On every epoch end it merges the metrics Keras reports in ``logs``
    with extra scores produced by the ``score_epoch`` callable (which may
    use the current model, e.g. to compute a voted validation accuracy),
    and writes one row per epoch to ``log_path``.
    """

    def __init__(self, log_path, score_epoch):
        super().__init__()

        self.log_path = log_path
        self.score = score_epoch  # callable () -> dict of extra metrics

        # Created lazily on the first write, once the field names are known
        self._log_file = None
        self._csv_writer = None

    def __del__(self):
        # Best-effort cleanup only: __del__ is not guaranteed to run, but
        # every row is flushed immediately after writing, so at most the
        # file handle leaks — no data is lost.
        if self._log_file and not self._log_file.closed:
            self._log_file.close()

    def write_entry(self, epoch, data):
        """Append one CSV row of metrics for @epoch, creating the file on first use."""
        data = data.copy()  # don't mutate the caller's dict

        if not self._csv_writer:
            # Create writer when we know what fields.
            # newline='' is required by the csv module to avoid
            # spurious blank lines on platforms with \r\n line endings.
            self._log_file = open(self.log_path, 'w', newline='')
            fields = ['epoch'] + sorted(data.keys())
            self._csv_writer = csv.DictWriter(self._log_file, fields)
            self._csv_writer.writeheader()

        data['epoch'] = epoch
        self._csv_writer.writerow(data)
        self._log_file.flush()  # ensure data hits disk

    def on_epoch_end(self, epoch, logs=None):
        # Some Keras versions invoke callbacks with logs=None; guard for it.
        logs = dict(logs or {})

        more = self.score()  # uses current model
        logs.update(more)

        self.write_entry(epoch, logs)
4184def train_model (out_dir , fold , builder ,
42- loader , val_loader ,
43- frame_samples , window_frames ,
44- train_samples = 12000 , val_samples = 3000 ,
45- batch_size = 200 , epochs = 50 , seed = 1 , learning_rate = 3e-4 ):
85+ loader , val_loader , settings , seed = 1 ):
4686 """Train a single model"""
4787
88+ frame_samples = settings ['hop_length' ]
89+ train_samples = settings ['train_samples' ]
90+ window_frames = settings ['frames' ]
91+ val_samples = settings ['val_samples' ]
92+ epochs = settings ['epochs' ]
93+ batch_size = settings ['batch' ]
94+ #learning_rate = settings['learning_rate']
95+
96+ train , val = fold
97+
98+ def top3 (y_true , y_pred ):
99+ return keras .metrics .top_k_categorical_accuracy (y_true , y_pred , k = 3 )
100+
48101 model = builder ()
49102 model .compile (loss = 'categorical_crossentropy' ,
50103 optimizer = keras .optimizers .SGD (lr = 0.001 , momentum = 0.95 , nesterov = True ),
51104 metrics = ['accuracy' ])
52105
53-
54106 model_path = os .path .join (out_dir , 'e{epoch:02d}-v{val_loss:.2f}.t{loss:.2f}.model.hdf5' )
55107 checkpoint = keras .callbacks .ModelCheckpoint (model_path , monitor = 'val_acc' , mode = 'max' ,
56108 period = 1 , verbose = 1 , save_best_only = False )
57- callbacks_list = [checkpoint ]
58109
59- train , val = fold
110+ def voted_score ():
111+ y_pred = features .predict_voted (settings , model , val ,
112+ loader = val_loader , method = 'mean' , overlap = 0.5 )
113+ class_pred = numpy .argmax (y_pred , axis = 1 )
114+ acc = sklearn .metrics .accuracy_score (val .classID , class_pred )
115+ d = {
116+ 'voted_val_acc' : acc ,
117+ }
118+ return d
119+ log_path = os .path .join (out_dir , 'train.csv' )
120+ log = LogCallback (log_path , voted_score )
121+
122+
60123 train_gen = dataframe_generator (train , train .classID , loader = loader , batchsize = batch_size )
61124 val_gen = dataframe_generator (val , val .classID , loader = val_loader , batchsize = batch_size )
62125
126+ callbacks_list = [checkpoint , log ]
63127 hist = model .fit_generator (train_gen , validation_data = val_gen ,
64128 steps_per_epoch = math .ceil (train_samples / batch_size ),
65129 validation_steps = math .ceil (val_samples / batch_size ),
@@ -147,7 +211,6 @@ def settings(args):
147211 train_settings = {}
148212 for k in default_training_settings .keys ():
149213 v = args .get (k , default_training_settings [k ])
150- print ('v' , k , v , args .get (k ))
151214 train_settings [k ] = v
152215 return train_settings
153216
@@ -217,12 +280,7 @@ def build_model():
217280 builder = build_model ,
218281 loader = functools .partial (load , validation = False ),
219282 val_loader = functools .partial (load , validation = True ),
220- frame_samples = feature_settings ['hop_length' ],
221- window_frames = model_settings ['frames' ],
222- epochs = train_settings ['epochs' ],
223- train_samples = train_settings ['train_samples' ],
224- val_samples = train_settings ['val_samples' ],
225- batch_size = train_settings ['batch' ])
283+ settings = exsettings )
226284
227285
228286
0 commit comments