diff --git a/src/parse_results.py b/src/parse_results.py index 087afe7..92d84b9 100644 --- a/src/parse_results.py +++ b/src/parse_results.py @@ -4,7 +4,7 @@ from typing import List, Optional from src.metrics import Metrics -from src.utils import parse_config_args +from src.utils import parse_config_args, parse_config_arg def get_arg_parser() -> ArgumentParser: @@ -25,9 +25,9 @@ def get_arg_parser() -> ArgumentParser: DEFAULT_COLUMNS = ( "Setting", - Metrics.BATCH_SIZE, Metrics.INPUT_LENGTH, Metrics.TOKENS_SAMPLE, + Metrics.BATCH_SIZE, Metrics.THROUGHPUT_E2E, Metrics.LATENCY_E2E, ) @@ -54,6 +54,12 @@ def read_data(input_file: Path): return data +def parse_key(key: Optional[str]) -> Optional[str]: + if key is None: + return key + return getattr(Metrics, key.upper(), key) + + def make_table(data, cols): from markdownTable import markdownTable @@ -61,12 +67,6 @@ def make_table(data, cols): return markdownTable(data).getMarkdown() -def parse_key(key: Optional[str]) -> Optional[str]: - if key is None: - return key - return getattr(Metrics, key.upper(), key) - - def make_compare_table(data, cols, compare_value, compare_col): from markdownTable import markdownTable @@ -104,13 +104,20 @@ def make_compare_table(data, cols, compare_value, compare_col): def filter_data(data, filters): if filters is None: return data - filters = parse_config_args(filters) - filters = {parse_key(key): value for key, value in filters.items()} + + parsed_filters = {} + for filter in filters: + key, value = parse_config_arg(filter) + key = parse_key(key) + if key not in parsed_filters: + parsed_filters[key] = [] + parsed_filters[key].append(value) + filtered_data = [] for x in data: filter = True - for key, value in filters.items(): - filter = filter and x[key] == value + for key, value in parsed_filters.items(): + filter = filter and x[key] in value if filter: filtered_data.append(x) return filtered_data