Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
fix(encoder): add normalize option in cvae encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
Larryjianfeng committed Jul 23, 2019
1 parent eb48779 commit 649ed13
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion gnes/encoder/image/cvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(self, model_dir: str,
latent_dim: int = 300,
batch_size: int = 64,
select_method: str = 'MEAN',
l2_normalize: bool = False,
use_gpu: bool = True,
*args, **kwargs):
super().__init__(*args, **kwargs)
Expand All @@ -36,6 +37,7 @@ def __init__(self, model_dir: str,
self.latent_dim = latent_dim
self.batch_size = batch_size
self.select_method = select_method
self.l2_normalize = l2_normalize
self.use_gpu = use_gpu

def post_init(self):
Expand Down Expand Up @@ -69,4 +71,7 @@ def encode(self, img: List['np.ndarray'], *args, **kwargs) -> np.ndarray:
ret.append(_var)
elif self.select_method == 'MEAN_VAR':
ret.append(np.concatenate([_mean, _var]), axis=1)
return np.concatenate(ret, axis=0).astype(np.float32)
v = np.concatenate(ret, axis=0).astype(np.float32)
if self.l2_normalize:
v = v / (v**2).sum(axis=1, keepdims=True)**0.5
return v

0 comments on commit 649ed13

Please sign in to comment.