diff --git a/test/apis/batch/inferentia/handler.py b/test/apis/batch/inferentia/handler.py index d80c00b9aa..73f257d79a 100644 --- a/test/apis/batch/inferentia/handler.py +++ b/test/apis/batch/inferentia/handler.py @@ -46,6 +46,7 @@ def __init__(self, config, job_spec): self.model.load_state_dict(torch.load(model_name)) self.model.eval() elif config["device"] == "inf": + import torch_neuron self.model = torch.jit.load(model_name) else: