Skip to content

Commit

Permalink
Warn if cuda present but not used
Browse files Browse the repository at this point in the history
  • Loading branch information
ItayGabbay committed May 24, 2022
1 parent 95d8481 commit 8be63cc
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
6 changes: 3 additions & 3 deletions deepchecks/vision/base_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def run(
self,
dataset: VisionData,
model: Optional[nn.Module] = None,
device: Union[str, torch.device, None] = 'cpu',
device: Union[str, torch.device, None] = None,
random_state: int = 42,
n_samples: Optional[int] = 10_000
) -> CheckResult:
Expand Down Expand Up @@ -110,7 +110,7 @@ def run(
train_dataset: VisionData,
test_dataset: VisionData,
model: Optional[nn.Module] = None,
device: Union[str, torch.device, None] = 'cpu',
device: Union[str, torch.device, None] = None,
random_state: int = 42,
n_samples: Optional[int] = 10_000
) -> CheckResult:
Expand Down Expand Up @@ -189,7 +189,7 @@ class ModelOnlyCheck(ModelOnlyBaseCheck):
def run(
self,
model: nn.Module,
device: Union[str, torch.device, None] = 'cpu',
device: Union[str, torch.device, None] = None,
random_state: int = 42
) -> CheckResult:
"""Run check."""
Expand Down
9 changes: 7 additions & 2 deletions deepchecks/vision/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(self,
model_name: str = '',
scorers: Mapping[str, Metric] = None,
scorers_per_class: Mapping[str, Metric] = None,
device: Union[str, torch.device, None] = 'cpu',
device: Union[str, torch.device, None] = None,
random_state: int = 42,
n_samples: int = None
):
Expand All @@ -72,9 +72,14 @@ def __init__(self,
if train and test:
train.validate_shared_label(test)

if device is None:
device = 'cpu'
if torch.cuda.is_available():
warnings.warn('Checks will run on the cpu by default.'
'To make use of cuda devices, use the device parameter in the run function.')
self._device = torch.device(device) if isinstance(device, str) else (device if device else torch.device('cpu'))
self._prediction_formatter_error = {}

self._prediction_formatter_error = {}
if model is not None:
if not isinstance(model, nn.Module):
warnings.warn('Deepchecks can\'t validate that model is in evaluation state. Make sure it is to '
Expand Down

0 comments on commit 8be63cc

Please sign in to comment.