Skip to content
Permalink
Browse files

fix(encoder): fix tf scope error in cvae encoder

  • Loading branch information...
Larryjianfeng committed Jul 23, 2019
1 parent ab6c88c commit eb487799b3e4b602738765d9ad5edea997147930
Showing with 12 additions and 11 deletions.
  1. +12 −11 gnes/encoder/image/cvae.py
@@ -41,19 +41,20 @@ def __init__(self, model_dir: str,
def post_init(self):
import tensorflow as tf
from .cvae_cores.model import CVAE
g = tf.Graph()
with g.as_default():
self._model = CVAE(self.latent_dim)
self.inputs = tf.placeholder(tf.float32,
(None, 120, 120, 3))

self._model = CVAE(self.latent_dim)
self.inputs = tf.placeholder(tf.float32,
(None, 120, 120, 3))
self.mean, self.var = self._model.encode(self.inputs)

self.mean, self.var = self._model.encode(self.inputs)

config = tf.ConfigProto(log_device_placement=False)
if self.use_gpu:
config.gpu_options.allow_growth = True
self.sess = tf.Session(config=config)
self.saver = tf.train.Saver()
self.saver.restore(self.sess, self.model_dir)
config = tf.ConfigProto(log_device_placement=False)
if self.use_gpu:
config.gpu_options.allow_growth = True
self.sess = tf.Session(config=config)
self.saver = tf.train.Saver()
self.saver.restore(self.sess, self.model_dir)

def encode(self, img: List['np.ndarray'], *args, **kwargs) -> np.ndarray:
ret = []

0 comments on commit eb48779

Please sign in to comment.
You can’t perform that action at this time.