Skip to content
Permalink
Browse files

fix(image_encoder): enable batching encoding

  • Loading branch information...
numb3r3 committed Jul 17, 2019
1 parent 316c9db commit cba5e1905e272047481c41033474792f00b6da7a
Showing with 4 additions and 0 deletions.
  1. +2 −0 gnes/encoder/image/base.py
  2. +2 −0 gnes/encoder/image/inception.py
@@ -19,6 +19,7 @@
import numpy as np

from ..base import BaseImageEncoder
from ...helper import batching


class BasePytorchEncoder(BaseImageEncoder):
@@ -72,6 +73,7 @@ def forward(self, x):
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self._model = self._model.to(self._device)

@batching
def encode(self, img: List['np.ndarray'], *args, **kwargs) -> np.ndarray:
import torch
self._model.eval()
@@ -17,6 +17,7 @@
import numpy as np
from gnes.helper import batch_iterator
from ..base import BaseImageEncoder
from ...helper import batching
from PIL import Image


@@ -59,6 +60,7 @@ def post_init(self):
self.saver = tf.train.Saver()
self.saver.restore(self.sess, self.model_dir)

@batching
def encode(self, img: List['np.ndarray'], *args, **kwargs) -> np.ndarray:
ret = []
img = [(np.array(Image.fromarray(im).resize((self.inception_size_x,

0 comments on commit cba5e19

Please sign in to comment.
You can’t perform that action at this time.