Skip to content

Commit

Permalink
fix nlp device kwargs (#2632)
Browse files Browse the repository at this point in the history
* fix_device_stuff

* reqs
  • Loading branch information
JKL98ISR committed Jul 16, 2023
1 parent 78eebd2 commit 6017cb9
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
11 changes: 5 additions & 6 deletions deepchecks/nlp/utils/text_properties.py
Expand Up @@ -189,14 +189,13 @@ def get_transformer_model(
model_path = models_storage / 'onnx' / model_name

if model_path.exists():
return onnx.ORTModelForSequenceClassification.from_pretrained(model_path, device_map=device)
return onnx.ORTModelForSequenceClassification.from_pretrained(model_path).to(device or -1)

model = onnx.ORTModelForSequenceClassification.from_pretrained(
model_name,
export=True,
cache_dir=models_storage,
device_map=device
)
).to(device or -1)
# NOTE:
# 'optimum', after exporting/converting a model to the ONNX format,
# does not store it onto disk we need to save it now to not reconvert
Expand All @@ -207,7 +206,7 @@ def get_transformer_model(
model_path = models_storage / 'onnx' / 'quantized' / model_name

if model_path.exists():
return onnx.ORTModelForSequenceClassification.from_pretrained(model_path, device_map=device)
return onnx.ORTModelForSequenceClassification.from_pretrained(model_path).to(device or -1)

not_quantized_model = get_transformer_model(
property_name,
Expand All @@ -217,7 +216,7 @@ def get_transformer_model(
models_storage=models_storage
)

quantizer = onnx.ORTQuantizer.from_pretrained(not_quantized_model, device_map=device)
quantizer = onnx.ORTQuantizer.from_pretrained(not_quantized_model).to(device or -1)

quantizer.quantize(
save_dir=model_path,
Expand All @@ -227,7 +226,7 @@ def get_transformer_model(
per_channel=False
)
)
return onnx.ORTModelForSequenceClassification.from_pretrained(model_path, device_map=device)
return onnx.ORTModelForSequenceClassification.from_pretrained(model_path).to(device or -1)


def get_transformer_pipeline(
Expand Down
2 changes: 1 addition & 1 deletion requirements/nlp-prop-requirements.txt
@@ -1,2 +1,2 @@
optimum[onnxruntime]>=1.7.0
optimum[onnxruntime]>=1.8.8
fasttext>=0.8.0

0 comments on commit 6017cb9

Please sign in to comment.