
Commit 068c49e

Merge branch 'develop' of https://github.com/oracle/accelerated-data-science into ODSC-39737/allow_to_use_predict_locally
2 parents: 1d8dd19 + 48b5254

File tree

1 file changed: +16 -1 lines changed


ads/templates/score_pytorch.jinja2

Lines changed: 16 additions & 1 deletion
@@ -10,8 +10,21 @@ import pandas as pd
 from io import BytesIO
 import base64
 import logging
+from random import randint
+
+
+def get_torch_device():
+    num_devices = torch.cuda.device_count()
+    if num_devices == 0:
+        return "cpu"
+    if num_devices == 1:
+        return "cuda:0"
+    else:
+        return f"cuda:{randint(0, num_devices-1)}"
+
 
 model_name = '{{model_file_name}}'
+device = torch.device(get_torch_device())
 
 """
 Inference script. This script is used for prediction by scoring server when schema is known.
@@ -59,6 +72,7 @@ def load_model(model_file_name=model_name):
 
 {% endif %}
     print("Model is successfully loaded.")
+    the_model = the_model.to(device)
     return the_model
 
 @lru_cache(maxsize=1)
@@ -158,6 +172,7 @@ def pre_inference(data, input_schema_path):
     data = deserialize(data, input_schema_path)
 
     # Add further data preprocessing if needed
+    data = data.to(device)
    return data
 
 def post_inference(yhat):
@@ -199,6 +214,6 @@ def predict(data, model=load_model(), input_schema_path=os.path.join(os.path.dir
 
     with torch.no_grad():
        yhat = post_inference(
-            model(inputs)
+            model(inputs).to("cpu")
        )
    return {'prediction': yhat}
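
For context, the device-selection logic added by this diff can be exercised on its own. The sketch below reuses the get_torch_device() helper exactly as it appears in the template; the standalone __main__ block, the sample tensor, and the printout are illustrative additions, not part of score_pytorch.jinja2.

# Minimal sketch of the device-selection behavior introduced in this commit.
# Assumes only that PyTorch is installed; works on CPU-only machines as well.
import torch
from random import randint


def get_torch_device():
    # "cpu" when no GPU is visible, "cuda:0" for a single GPU,
    # and a randomly chosen GPU index when several are available.
    num_devices = torch.cuda.device_count()
    if num_devices == 0:
        return "cpu"
    if num_devices == 1:
        return "cuda:0"
    else:
        return f"cuda:{randint(0, num_devices-1)}"


if __name__ == "__main__":
    device = torch.device(get_torch_device())
    # Mirror the template's flow: move inputs to the selected device,
    # run a computation, then bring the result back to CPU for serialization.
    inputs = torch.randn(2, 3).to(device)
    outputs = (inputs * 2).to("cpu")
    print(device, outputs.shape)

Running the sketch on a machine without CUDA simply prints the cpu device, which matches how the generated scoring script falls back when torch.cuda.device_count() returns 0.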
