In [None]:
import das.utils, das.data, das.train, das.io, das.evaluate, das.predict
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import flammkuchen

## Load the model
Set `model_savename` to the path of the saved model without the `_model.h5` suffix. For example, if the model is saved at `res/20220601_103543_model.h5`, use `model_savename="res/20220601_103543"`.

In [None]:
# Path stem of the saved model, i.e. the model file name without "_model.h5".
model_savename = "res/20220601_103543"

# Load the model together with its parameters for this path stem.
model, params = das.utils.load_model_and_params(
    model_savename,
)

## Load the data

In [None]:
# The data directory and the x/y suffixes are taken from the loaded params.
data = das.io.load(
    params["data_dir"],
    x_suffix=params["x_suffix"],
    y_suffix=params["y_suffix"],
)

## Prep the data for TensorFlow

In [None]:
# Training generator: built from the 'train' split, with shuffling enabled.
# All remaining options are forwarded from the loaded params.
data_gen = das.data.AudioSequence(
    data["train"]["x"],
    data["train"]["y"],
    shuffle=True,
    **params,
)

# Validation generator: built from the 'val' split, without shuffling.
val_gen = das.data.AudioSequence(
    data["val"]["x"],
    data["val"]["y"],
    shuffle=False,
    **params,
)

## Prep callbacks
The new model will be saved to `model_savename + "_continued_model.h5"` — in our example, `res/20220601_103543_continued_model.h5`.

In [None]:
# Everything produced by this run is stored under the "<model_savename>_continued" stem.
continued_savename = model_savename + "_continued"
das.utils.save_params(params, continued_savename)

# A single checkpoint file is used, so it overwrites intermediates from previous epochs.
checkpoint_save_name = continued_savename + "_model.h5"

# Save the full model (not only the weights) whenever 'val_loss' improves.
best_model_checkpoint = ModelCheckpoint(
    checkpoint_save_name,
    save_best_only=True,
    save_weights_only=False,
    monitor="val_loss",
    verbose=1,
)

# Stop training once 'val_loss' has not improved for 20 epochs.
early_stopping = EarlyStopping(
    monitor="val_loss",
    patience=20,
    verbose=1,
)

callbacks = [best_model_checkpoint, early_stopping]

## Fit the model

In [None]:
# Training settings are read from the loaded params.
nb_epochs = params["nb_epoch"]

# Use at most 100 batches per epoch, or fewer if the generator is shorter.
steps_per_epoch = min(100, len(data_gen))

fit_verbosity = params["verbose"]
class_weights = params["class_weights"]

# Continue training the loaded model; `callbacks` handle checkpointing and early stopping.
fit_hist = model.fit(
    data_gen,
    epochs=nb_epochs,
    steps_per_epoch=steps_per_epoch,
    verbose=fit_verbosity,
    validation_data=val_gen,
    callbacks=callbacks,
    class_weight=class_weights,
)

## Test the model

In [None]:
# Restore the weights from the checkpoint file written during training.
print(f'   Re-loading last best model from {checkpoint_save_name}.')
model.load_weights(checkpoint_save_name)

# Run the model on the 'test' split.
print('   Predicting.')
x_test, y_test, y_pred = das.evaluate.evaluate_probabilities(
    x=data["test"]["x"],
    y=data["test"]["y"],
    model=model,
    params=params,
)

# Convert the test targets and the predictions from probabilities to labels.
labels_test, labels_pred = (
    das.predict.labels_from_probabilities(probabilities)
    for probabilities in (y_test, y_pred)
)

# Compare predicted and true labels; the report is requested as a dict.
print('   Evaluating.')
conf_mat, report = das.evaluate.evaluate_segments(
    labels_test,
    labels_pred,
    params["class_names"],
    report_as_dict=True,
)
print(conf_mat)
print(report)

# Results are written next to the continued model: "<model_savename>_continued_results.h5".
save_filename = model_savename + "_continued" + "_results.h5"
print(f'   Saving to {save_filename}.')

# NOTE: the training history (fit_hist.history) is currently not part of the saved results.
ddd = dict(
    confusion_matrix=conf_mat,
    classification_report=report,
    x_test=x_test,
    y_test=y_test,
    y_pred=y_pred,
    labels_test=labels_test,
    labels_pred=labels_pred,
    params=params,
)
flammkuchen.save(save_filename, ddd)