Skip to content
Permalink
Browse files

fix(encoder): add normalize option in cvae encoder

  • Loading branch information...
Larryjianfeng committed Jul 23, 2019
1 parent eb48779 commit 649ed1314b9c12167a958d6f8e259944ebdf96e3
Showing with 6 additions and 1 deletion.
  1. +6 −1 gnes/encoder/image/cvae.py
@@ -28,6 +28,7 @@ def __init__(self, model_dir: str,
latent_dim: int = 300,
batch_size: int = 64,
select_method: str = 'MEAN',
l2_normalize: bool = False,
use_gpu: bool = True,
*args, **kwargs):
super().__init__(*args, **kwargs)
@@ -36,6 +37,7 @@ def __init__(self, model_dir: str,
self.latent_dim = latent_dim
self.batch_size = batch_size
self.select_method = select_method
self.l2_normalize = l2_normalize
self.use_gpu = use_gpu

def post_init(self):
@@ -69,4 +71,7 @@ def encode(self, img: List['np.ndarray'], *args, **kwargs) -> np.ndarray:
ret.append(_var)
elif self.select_method == 'MEAN_VAR':
ret.append(np.concatenate([_mean, _var]), axis=1)
return np.concatenate(ret, axis=0).astype(np.float32)
v = np.concatenate(ret, axis=0).astype(np.float32)
if self.l2_normalize:
v = v / (v**2).sum(axis=1, keepdims=True)**0.5
return v

0 comments on commit 649ed13

Please sign in to comment.
You can’t perform that action at this time.