diff --git a/vizier/_src/benchmarks/analyzers/plot_utils.py b/vizier/_src/benchmarks/analyzers/plot_utils.py index 2e4e3d143..3e4cde805 100644 --- a/vizier/_src/benchmarks/analyzers/plot_utils.py +++ b/vizier/_src/benchmarks/analyzers/plot_utils.py @@ -225,7 +225,6 @@ def _metadata_to_str(metadata: vz.Metadata) -> str: percentiles=(elem_for_metric.percentile_error_bar,), **kwargs, ) - ax.set_xlabel('# of Trials') elif plot_type == 'scatter': plot = elem_for_metric.plot_array ax.scatter( @@ -254,6 +253,7 @@ def _metadata_to_str(metadata: vz.Metadata) -> str: ) else: raise ValueError(f'{plot_type} plot not yet supported!') + ax.set_xlabel(elem_for_metric.xlabel) ax.set_yscale(elem_for_metric.yscale) ax.yaxis.set_major_locator(mpl.ticker.LinearLocator(20)) ax.yaxis.set_minor_locator(mpl.ticker.LinearLocator(100)) diff --git a/vizier/_src/benchmarks/analyzers/state_analyzer.py b/vizier/_src/benchmarks/analyzers/state_analyzer.py index df951780b..1d6d321b7 100644 --- a/vizier/_src/benchmarks/analyzers/state_analyzer.py +++ b/vizier/_src/benchmarks/analyzers/state_analyzer.py @@ -54,6 +54,9 @@ class PlotElement: default='error-bar', validator=attrs.validators.in_(['error-bar', 'histogram', 'scatter']), ) + xlabel: str = attrs.field( + default='Num Trials', validator=attrs.validators.instance_of(str) + ) yscale: str = attrs.field( default='linear', validator=attrs.validators.in_(['linear', 'symlog', 'logit']),