diff --git a/scoring/scoring.py b/scoring/scoring.py index 12aae1357..7e52bd08c 100644 --- a/scoring/scoring.py +++ b/scoring/scoring.py @@ -47,9 +47,10 @@ 'ctc_loss', 'wer', 'l1_loss', + 'loss', ] -MAX_EVAL_METRICS = ['average_precision', 'ssim', 'accuracy', 'bleu_score'] +MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu'] def generate_eval_cols(metrics): @@ -128,8 +129,7 @@ def get_index_that_reaches_target(workload_df, op = operator.le if is_minimized else operator.ge validation_target_reached = validation_series.apply( lambda x: op(x, validation_target)) - - target_reached = pd.Series(validation_target_reached[0]) + target_reached = pd.Series(validation_target_reached) # Remove trials that never reach the target target_reached = target_reached[target_reached.apply(np.any)] diff --git a/setup.cfg b/setup.cfg index a7ce5ebb2..4c2d9e6d3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,6 +37,7 @@ install_requires = absl-py==1.4.0 numpy>=1.23 pandas>=2.0.1 + tabulate==0.9.0 tensorflow==2.12.0 tensorflow-datasets==4.9.2 tensorflow-probability==0.20.0