Skip to content
Merged
6 changes: 3 additions & 3 deletions scoring/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@
'ctc_loss',
'wer',
'l1_loss',
'loss',
]

MAX_EVAL_METRICS = ['average_precision', 'ssim', 'accuracy', 'bleu_score']
MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu']


def generate_eval_cols(metrics):
Expand Down Expand Up @@ -128,8 +129,7 @@ def get_index_that_reaches_target(workload_df,
op = operator.le if is_minimized else operator.ge
validation_target_reached = validation_series.apply(
lambda x: op(x, validation_target))

target_reached = pd.Series(validation_target_reached[0])
target_reached = pd.Series(validation_target_reached)
# Remove trials that never reach the target
target_reached = target_reached[target_reached.apply(np.any)]

Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ install_requires =
absl-py==1.4.0
numpy>=1.23
pandas>=2.0.1
tabulate==0.9.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it used anywhere else in the repo? Couldn't find anything with a quick search.

tensorflow==2.12.0
tensorflow-datasets==4.9.2
tensorflow-probability==0.20.0
Expand Down