Skip to content

Commit

Permalink
remove onnxruntime-gpu requirement
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH committed Feb 28, 2024
1 parent e525054 commit f4fdad0
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
from albumentations.core.transforms_interface import ImageOnlyTransform
import cv2
import pandas as pd
import onnxruntime as ort

from rastervision.pipeline.file_system.utils import (file_exists, file_to_json,
get_tmp_dir)
from rastervision.pipeline.config import (build_config, Config, ConfigError,
upgrade_config)

if TYPE_CHECKING:
import onnxruntime as ort
from rastervision.pytorch_learner import LearnerConfig

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -456,7 +456,7 @@ class ONNXRuntimeAdapter:
also outputs PyTorch Tensors.
"""

def __init__(self, ort_session: ort.InferenceSession) -> None:
def __init__(self, ort_session: 'ort.InferenceSession') -> None:
"""Constructor.
Args:
Expand All @@ -482,6 +482,8 @@ def from_file(cls, path: str, providers: Optional[List[str]] = None
Returns:
ONNXRuntimeAdapter: An ONNXRuntimeAdapter instance.
"""
import onnxruntime as ort

if providers is None:
providers = ort.get_available_providers()
log.info(f'Using ONNX execution providers: {providers}')
Expand Down
1 change: 0 additions & 1 deletion rastervision_pytorch_learner/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,3 @@ opencv-python-headless==4.9.0.80
matplotlib==3.8.2
tqdm==4.66.1
onnx==1.15.0
onnxruntime-gpu==1.17

0 comments on commit f4fdad0

Please sign in to comment.