diff --git a/llmtune/cli/toolkit.py b/llmtune/cli/toolkit.py index 3be899a..cef55db 100644 --- a/llmtune/cli/toolkit.py +++ b/llmtune/cli/toolkit.py @@ -68,7 +68,7 @@ def run_one_experiment(config: Config, config_path: str) -> None: results_file_path = join(dir_helper.save_paths.results, "results.csv") if not exists(results_path) or exists(results_file_path): inference_runner = LoRAInference(test, test_column, config, dir_helper) - inference_runner.infer_all() + inference_runner.infer_test_set() RichUI.after_inference(results_path) else: RichUI.inference_found(results_path) diff --git a/llmtune/inference/generics.py b/llmtune/inference/generics.py index 24a2bfb..b42db50 100644 --- a/llmtune/inference/generics.py +++ b/llmtune/inference/generics.py @@ -7,5 +7,5 @@ def infer_one(self, prompt: str): pass @abstractmethod - def infer_all(self): + def infer_test_set(self): pass diff --git a/llmtune/inference/lora.py b/llmtune/inference/lora.py index 720822c..68d812c 100644 --- a/llmtune/inference/lora.py +++ b/llmtune/inference/lora.py @@ -64,7 +64,7 @@ def _get_merged_model(self, weights_path: str): return model, tokenizer - def infer_all(self): + def infer_test_set(self): results = [] prompts = self.test_dataset["formatted_prompt"] labels = self.test_dataset[self.label_column]