Skip to content
Permalink
Browse files

fix(batching): enable to process three dimension output in batching

  • Loading branch information...
jemmyshin committed Aug 12, 2019
1 parent b0f22d0 commit 64163cb15b614d67c47694b37da20c384754e9c7
Showing with 14 additions and 13 deletions.
  1. +10 −12 gnes/encoder/image/base.py
  2. +4 −1 gnes/helper.py
@@ -74,7 +74,6 @@ def forward(self, x):
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self._model = self._model.to(self._device)

@batching
def encode(self, img: List['np.ndarray'], *args, **kwargs) -> np.ndarray:
import torch
self._model.eval()
@@ -87,13 +86,21 @@ def _padding(img: List['np.ndarray']):
if im.shape[0] < max_lenth else im for im in img]
return img, max_lenth

@batching
# for video
if len(img[0].shape) == 4:
img, max_lenth = _padding(img)
# for image
else:
max_lenth = -1

@batching(chunk_dim=max_lenth)
def _encode(_, img: List['np.ndarray']):
import copy

if len(img[0].shape) == 4:
img_ = copy.deepcopy(img)
img_ = np.concatenate((list(img_[i] for i in range(len(img_)))), axis=0)

img_for_torch = np.array(img_, dtype=np.float32).transpose(0, 3, 1, 2)
else:
img_for_torch = np.array(img, dtype=np.float32).transpose(0, 3, 1, 2)
@@ -110,17 +117,8 @@ def _encode(_, img: List['np.ndarray']):
result_npy.append(encodes.data.cpu().numpy())

output = np.array(result_npy, dtype=np.float32)

if len(img[0].shape) == 4:
output = output.reshape((len(img), max_lenth, -1))
return output

# for video
if len(img[0].shape) == 4:
padding_image, max_lenth = _padding(img)
output = _encode(self, padding_image)
# for image
else:
output = _encode(self, img)
output = _encode(self, img)

return output
@@ -375,7 +375,7 @@ def pooling_torch(data_tensor, mask_tensor, pooling_strategy):

def batching(func: Callable[[Any], np.ndarray] = None, *,
batch_size: Union[int, Callable] = None, num_batch=None,
iter_axis: int = 0, concat_axis: int = 0):
iter_axis: int = 0, concat_axis: int = 0, chunk_dim=-1):
def _batching(func):
@wraps(func)
def arg_wrapper(self, data, label=None, *args, **kwargs):
@@ -418,6 +418,9 @@ def arg_wrapper(self, data, label=None, *args, **kwargs):
if len(final_result) and concat_axis is not None and isinstance(final_result[0], np.ndarray):
final_result = np.concatenate(final_result, concat_axis)

if chunk_dim != -1:
final_result = final_result.reshape((-1, chunk_dim, final_result.shape[1]))

if len(final_result):
return final_result

0 comments on commit 64163cb

Please sign in to comment.
You can’t perform that action at this time.