Skip to content

Commit

Permalink
fix_tensor_gpu (#1197)
Browse files Browse the repository at this point in the history
* fix_tensor_gpu

* fix_tensor_gpu

* fix_tensor_gpu

* fix_tensor_gpu
  • Loading branch information
JKL98ISR committed Apr 5, 2022
1 parent 22d4a72 commit 0e7ec77
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 7 deletions.
5 changes: 4 additions & 1 deletion deepchecks/vision/batch_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@
# ----------------------------------------------------------------------------
#
"""Contains code for BatchWrapper."""
from typing import Tuple, Iterable, Any, TypeVar, Callable, cast
from typing import Tuple, Iterable, Any, TypeVar, Callable, cast, TYPE_CHECKING

import torch

from deepchecks.core import DatasetKind

if TYPE_CHECKING:
from deepchecks.vision.context import Context


__all__ = ['Batch']

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def _label_to_image_batch(self, label_batch: List[torch.Tensor], image_batch: Li
return_bbox_image_batch = []
for image, label in zip(image_batch, label_batch):
return_bbox_image_batch.append(
self._label_to_image(label.detach().cpu().numpy(), image.shape[:2])
self._label_to_image(label.cpu().detach().numpy(), image.shape[:2])
)
return return_bbox_image_batch

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from collections import defaultdict
from typing import Callable, TypeVar, Hashable, Dict, Union

import numpy as np
import pandas as pd

from deepchecks import ConditionResult
Expand Down Expand Up @@ -116,12 +115,12 @@ def update(self, context: Context, batch: Batch, dataset_kind: DatasetKind):
if dataset.task_type == TaskType.OBJECT_DETECTION:
for img, labels in zip(batch.images, batch.labels):
for label in labels:
label = np.array(label)
label = label.cpu().detach().numpy()
bbox = label[1:]
cropped_img = crop_image(img, *bbox)
if cropped_img.shape[0] == 0 or cropped_img.shape[1] == 0:
continue
class_id = label[0]
class_id = int(label[0])
imgs += [cropped_img]
target += [class_id]
else:
Expand Down
2 changes: 1 addition & 1 deletion deepchecks/vision/utils/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,4 +143,4 @@ def un_normalize_batch(tensor, mean: Sized, std: Sized, max_pixel_value: int = 2
std = torch.tensor(std, device=tensor.device).reshape(reshape_shape)
tensor = (tensor * std) + mean
tensor = tensor * torch.tensor(max_pixel_value, device=tensor.device).reshape(reshape_shape)
return tensor.cpu().numpy()
return tensor.cpu().detach().numpy()
2 changes: 1 addition & 1 deletion tests/vision/vision_conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _hash_image(image):
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
elif isinstance(image, torch.Tensor):
image = Image.fromarray(image.cpu().numpy().squeeze())
image = Image.fromarray(image.cpu().detach().numpy().squeeze())

image = image.resize((10, 10))
image = image.convert('L')
Expand Down

0 comments on commit 0e7ec77

Please sign in to comment.