diff --git a/src/instructlab/eval/mmlu.py b/src/instructlab/eval/mmlu.py index 5589abb4..8ae4eb73 100644 --- a/src/instructlab/eval/mmlu.py +++ b/src/instructlab/eval/mmlu.py @@ -6,6 +6,7 @@ # Third Party from lm_eval.evaluator import simple_evaluate # type: ignore from lm_eval.tasks import TaskManager # type: ignore +import torch # First Party from instructlab.eval.evaluator import Evaluator @@ -58,6 +59,7 @@ def run(self) -> tuple: tasks=self.tasks, num_fewshot=self.few_shots, batch_size=self.batch_size, + device=("cuda" if torch.cuda.is_available() else "cpu"), ) results = mmlu_output["results"]