From dd0b4d09ae86c0bc0d5505de14a057e52a9da4d0 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Sat, 10 Dec 2022 11:50:03 -0800 Subject: [PATCH] [python] use jitscript model if available --- engines/python/src/test/resources/resnet18/model.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/engines/python/src/test/resources/resnet18/model.py b/engines/python/src/test/resources/resnet18/model.py index 0edeccf5c..a6ee751b1 100644 --- a/engines/python/src/test/resources/resnet18/model.py +++ b/engines/python/src/test/resources/resnet18/model.py @@ -47,7 +47,11 @@ def initialize(self, properties: dict): device_id = properties.get("device_id", "-1") device_id = "cpu" if device_id == "-1" else "cuda:" + device_id self.device = torch.device(device_id) - self.model = models.resnet18(pretrained=True).to(self.device) + if os.path.exists("resnet18.pt"): + self.model = torch.jit.load("resnet18.pt", + map_location=self.device) + else: + self.model = models.resnet18(pretrained=True).to(self.device) self.model.eval() self.image_processing = transforms.Compose([ transforms.Resize(256),