diff --git a/engines/python/src/test/resources/resnet18/model.py b/engines/python/src/test/resources/resnet18/model.py index 0edeccf5c..a6ee751b1 100644 --- a/engines/python/src/test/resources/resnet18/model.py +++ b/engines/python/src/test/resources/resnet18/model.py @@ -47,7 +47,11 @@ def initialize(self, properties: dict): device_id = properties.get("device_id", "-1") device_id = "cpu" if device_id == "-1" else "cuda:" + device_id self.device = torch.device(device_id) - self.model = models.resnet18(pretrained=True).to(self.device) + if os.path.exists("resnet18.pt"): + self.model = torch.jit.load("resnet18.pt", + map_location=self.device) + else: + self.model = models.resnet18(pretrained=True).to(self.device) self.model.eval() self.image_processing = transforms.Compose([ transforms.Resize(256),