[issue-787] added device parameter to the vision checks context #832

Merged (23 commits) on Feb 14, 2022
2 changes: 0 additions & 2 deletions .github/workflows/build.yml
@@ -31,8 +31,6 @@ jobs:
           python-version: ${{ matrix.python-version }}
       - name: Set Up Env
         run: make env
-      - name: Print Requirements
-        run: pip freeze
       - name: Run Tests
         run: make test
   coverage:
6 changes: 3 additions & 3 deletions deepchecks/core/check.py
@@ -449,7 +449,7 @@ class SingleDatasetBaseCheck(BaseCheck):
     context_type: ClassVar[Optional[Type[Any]]] = None  # TODO: Base context type

     @abc.abstractmethod
-    def run(self, dataset, model=None) -> CheckResult:
+    def run(self, dataset, model=None, **kwargs) -> CheckResult:
         """Run check."""
         raise NotImplementedError()

@@ -463,7 +463,7 @@ class TrainTestBaseCheck(BaseCheck):
     context_type: ClassVar[Optional[Type[Any]]] = None  # TODO: Base context type

     @abc.abstractmethod
-    def run(self, train_dataset, test_dataset, model=None) -> CheckResult:
+    def run(self, train_dataset, test_dataset, model=None, **kwargs) -> CheckResult:
         """Run check."""
         raise NotImplementedError()

@@ -474,7 +474,7 @@ class ModelOnlyBaseCheck(BaseCheck):
     context_type: ClassVar[Optional[Type[Any]]] = None  # TODO: Base context type

     @abc.abstractmethod
-    def run(self, model) -> CheckResult:
+    def run(self, model, **kwargs) -> CheckResult:
         """Run check."""
         raise NotImplementedError()
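The `**kwargs` added to the abstract `run` signatures lets framework-specific subclasses declare extra keyword arguments, such as the `device` parameter introduced in the vision package below, without breaking the base contract. A minimal sketch of that idea, with a hypothetical `MyVisionCheck` subclass that is not part of this PR:

import abc


class SingleDatasetBaseCheck(abc.ABC):
    """Trimmed-down stand-in for the base class above."""

    @abc.abstractmethod
    def run(self, dataset, model=None, **kwargs):
        """Run check."""
        raise NotImplementedError()


class MyVisionCheck(SingleDatasetBaseCheck):
    """Hypothetical subclass that narrows **kwargs into an explicit `device` keyword."""

    def run(self, dataset, model=None, device=None, **kwargs):
        # Extra keyword arguments beyond `device` are simply ignored here.
        return f'would run on {device or "cpu"}'


print(MyVisionCheck().run([], device='cuda'))  # would run on cuda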
72 changes: 52 additions & 20 deletions deepchecks/vision/base.py
@@ -12,11 +12,12 @@
 # TODO: This file should be completely modified
 # pylint: disable=broad-except,not-callable
 import copy
-from typing import Tuple, Mapping, Optional, Any
+from typing import Tuple, Mapping, Optional, Any, Union
 from collections import OrderedDict

-from ignite.metrics import Metric
+import torch
 from torch import nn
+from ignite.metrics import Metric

 from deepchecks.vision.utils.validation import validate_model
 from deepchecks.core.check import (
@@ -32,6 +33,7 @@
     DeepchecksNotSupportedError, DeepchecksValueError
 )
 from deepchecks.vision.dataset import VisionData, TaskType
+from deepchecks.vision.utils.validation import apply_to_tensor


 __all__ = [
@@ -63,6 +65,9 @@ class Context:
         See <a href=
         "https://scikit-learn.org/stable/modules/model_evaluation.html#from-binary-to-multiclass-and-multilabel">
         scikit-learn docs</a>
+    device : Union[str, torch.device], default: None
+        processing unit for use
+
     """

     def __init__(self,
@@ -71,7 +76,8 @@ def __init__(self,
                  model: nn.Module = None,
                  model_name: str = '',
                  scorers: Mapping[str, Metric] = None,
-                 scorers_per_class: Mapping[str, Metric] = None
+                 scorers_per_class: Mapping[str, Metric] = None,
+                 device: Union[str, torch.device, None] = None
                  ):
         # Validations
         if train is None and test is None and model is None:
@@ -90,6 +96,7 @@
         self._user_scorers = scorers
         self._user_scorers_per_class = scorers_per_class
         self._model_name = model_name
+        self._device = torch.device(device) if isinstance(device, str) else device

         # Properties
         # Validations note: We know train & test fit each other so all validations can be run only on train
@@ -124,6 +131,11 @@ def model_name(self):
         """Return model name."""
         return self._model_name

+    @property
+    def device(self) -> Optional[torch.device]:
+        """Return device specified by the user."""
+        return self._device
+
     def have_test(self):
         """Return whether there is test dataset defined."""
         return self._test is not None
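For reference, a standalone sketch of the normalization the context applies to the `device` argument in `__init__` above: strings are converted to `torch.device`, while existing `torch.device` objects and `None` pass through unchanged. The `normalize_device` helper name is hypothetical; the Context inlines the expression.

import torch


def normalize_device(device):
    # Mirrors the expression added to Context.__init__ above.
    return torch.device(device) if isinstance(device, str) else device


print(normalize_device('cuda:0'))             # device(type='cuda', index=0)
print(normalize_device(torch.device('cpu')))  # device(type='cpu')
print(normalize_device(None))                 # None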
@@ -161,17 +173,20 @@ class SingleDatasetCheck(SingleDatasetBaseCheck):

     context_type = Context

-    def run(self, dataset, model=None) -> CheckResult:
+    def run(
+        self,
+        dataset: VisionData,
+        model: Optional[nn.Module] = None,
+        device: Union[str, torch.device, None] = None
+    ) -> CheckResult:
         """Run check."""
         assert self.context_type is not None
-        context = self.context_type(  # pylint: disable=not-callable
-            dataset,
-            model=model
-        )
+        context = self.context_type(dataset, model=model, device=device)

         self.initialize_run(context)

         for batch in dataset.get_data_loader():
+            batch = apply_to_tensor(batch, lambda x: x.to(device))
             self.update(context, batch)
             context.flush_cached_inference()
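Each batch is moved to the requested device with `apply_to_tensor` before `update` is called. A rough sketch of the kind of helper this is assumed to be (not the deepchecks implementation): apply a function to every tensor found in a possibly nested batch structure and leave everything else untouched.

from typing import Any, Callable

import torch


def apply_to_tensor_sketch(batch: Any, func: Callable[[torch.Tensor], torch.Tensor]) -> Any:
    # Hypothetical stand-in for apply_to_tensor, shown only to illustrate the idea.
    if isinstance(batch, torch.Tensor):
        return func(batch)
    if isinstance(batch, (list, tuple)):
        return type(batch)(apply_to_tensor_sketch(item, func) for item in batch)
    if isinstance(batch, dict):
        return {key: apply_to_tensor_sketch(value, func) for key, value in batch.items()}
    return batch


# Moving an (images, labels) batch with the same call shape as in the diff above:
images_and_labels = (torch.zeros(2, 3), ['cat', 'dog'])
images_and_labels = apply_to_tensor_sketch(images_and_labels, lambda x: x.to('cpu'))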

@@ -198,22 +213,26 @@ class TrainTestCheck(TrainTestBaseCheck):

     context_type = Context

-    def run(self, train_dataset, test_dataset, model=None) -> CheckResult:
+    def run(
+        self,
+        train_dataset: VisionData,
+        test_dataset: VisionData,
+        model: Optional[nn.Module] = None,
+        device: Union[str, torch.device, None] = None
+    ) -> CheckResult:
         """Run check."""
         assert self.context_type is not None
-        context = self.context_type(  # pylint: disable=not-callable
-            train_dataset,
-            test_dataset,
-            model=model
-        )
+        context = self.context_type(train_dataset, test_dataset, model=model, device=device)

         self.initialize_run(context)

         for batch in context.train.get_data_loader():
+            batch = apply_to_tensor(batch, lambda x: x.to(device))
             self.update(context, batch, dataset_name='train')
             context.flush_cached_inference()

         for batch in context.test.get_data_loader():
+            batch = apply_to_tensor(batch, lambda x: x.to(device))
             self.update(context, batch, dataset_name='test')
             context.flush_cached_inference()

@@ -237,10 +256,14 @@ class ModelOnlyCheck(ModelOnlyBaseCheck):

     context_type = Context

-    def run(self, model) -> CheckResult:
+    def run(
+        self,
+        model: nn.Module,
+        device: Union[str, torch.device, None] = None
+    ) -> CheckResult:
         """Run check."""
         assert self.context_type is not None
-        context = self.context_type(model=model)  # pylint: disable=not-callable
+        context = self.context_type(model=model, device=device)

         self.initialize_run(context)
         return finalize_check_result(self.compute(context), self)
@@ -268,7 +291,8 @@ def run(
         test_dataset: Optional[VisionData] = None,
         model: nn.Module = None,
         scorers: Mapping[str, Metric] = None,
-        scorers_per_class: Mapping[str, Metric] = None
+        scorers_per_class: Mapping[str, Metric] = None,
+        device: Union[str, torch.device, None] = None
     ) -> SuiteResult:
         """Run all checks.

@@ -287,14 +311,22 @@
             See <a href=
             "https://scikit-learn.org/stable/modules/model_evaluation.html#from-binary-to-multiclass-and-multilabel">
             scikit-learn docs</a>
+        device : Union[str, torch.device], default: None
+            processing unit for use
+
         Returns
         -------
         SuiteResult
             All results by all initialized checks
         """
-        context = Context(train_dataset, test_dataset, model,
-                          scorers=scorers,
-                          scorers_per_class=scorers_per_class)
+        context = Context(
+            train_dataset,
+            test_dataset,
+            model,
+            scorers=scorers,
+            scorers_per_class=scorers_per_class,
+            device=device
+        )

         # Create instances of SingleDatasetCheck for train and test if train and test exist.
         # This is needed because in the vision package checks update their internal state with update, so it will be
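Taken together, the changes above let callers choose the processing unit when running a vision check or suite. A hedged usage sketch follows: the `ClassPerformance` import is inferred from the file path below, and the datasets and model are assumed to be prepared by the caller, so the call is wrapped in a function rather than executed directly.

import torch
from torch import nn

from deepchecks.vision.dataset import VisionData
from deepchecks.vision.checks.performance import ClassPerformance


def run_class_performance(train_ds: VisionData, test_ds: VisionData, model: nn.Module):
    # Pick a GPU when available; a plain string such as 'cuda' would also be
    # accepted, since the Context converts strings to torch.device.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    return ClassPerformance().run(train_ds, test_ds, model, device=device)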
20 changes: 13 additions & 7 deletions deepchecks/vision/checks/performance/class_performance.py
@@ -72,23 +72,29 @@ def update(self, context: Context, batch: Any, dataset_name: str = 'train'):

     def compute(self, context: Context) -> CheckResult:
         """Compute the metric result using the ignite metrics compute method and create display."""
-        self._state['train']['n_samples'] = context.train.get_samples_per_class()
-        self._state['test']['n_samples'] = context.test.get_samples_per_class()
-        self._state['classes'] = sorted(context.train.get_samples_per_class().keys())
+        self._state['train']['n_samples'] = context.train.n_of_samples_per_class
+        self._state['test']['n_samples'] = context.test.n_of_samples_per_class
+        self._state['classes'] = sorted(context.train.n_of_samples_per_class.keys())

         results = []

         for dataset_name in ['train', 'test']:
             n_samples = self._state[dataset_name]['n_samples']
+            computed_metrics = (
+                (name, metric.compute().tolist())
+                for name, metric in self._state[dataset_name]['scorers'].items()
+            )
             results.extend(
                 [dataset_name, class_name, name, class_score, n_samples[class_name]]
-                for name, score in [(name, metric.compute().tolist()) for name, metric in
-                                    self._state[dataset_name]['scorers'].items()]
+                for name, score in computed_metrics
                 # scorer returns numpy array of results with item per class
                 for class_score, class_name in zip(score, self._state['classes'])
             )

-        results_df = pd.DataFrame(results, columns=['Dataset', 'Class', 'Metric', 'Value', 'Number of samples']
-                                  ).sort_values(by=['Class'])
+        results_df = pd.DataFrame(
+            results,
+            columns=['Dataset', 'Class', 'Metric', 'Value', 'Number of samples']
+        ).sort_values(by=['Class'])

         fig = px.histogram(
             results_df,
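The refactored compute() above first materializes the per-scorer results as a generator and then flattens them into one row per dataset, class, and metric. A standalone sketch of that flattening pattern, with hard-coded stand-in values in place of real ignite metric objects:

import pandas as pd

classes = [0, 1, 2]
n_samples = {0: 50, 1: 30, 2: 20}
# Stand-ins for (name, metric.compute().tolist()) pairs: one score per class.
computed_metrics = [('Precision', [0.9, 0.8, 0.7]), ('Recall', [0.85, 0.75, 0.65])]

results = [
    ['train', class_name, name, class_score, n_samples[class_name]]
    for name, score in computed_metrics
    for class_score, class_name in zip(score, classes)
]

results_df = pd.DataFrame(
    results, columns=['Dataset', 'Class', 'Metric', 'Value', 'Number of samples']
).sort_values(by=['Class'])
print(results_df)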