Skip to content

Commit

Permalink
Example of how to train the lower-dimensional InfoGAN used in the exp…
Browse files Browse the repository at this point in the history
…eriments.
  • Loading branch information
asross committed Dec 28, 2020
1 parent 775aa3c commit 929f1c8
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
19 changes: 15 additions & 4 deletions examples/GAN/InfoGAN-mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,10 @@
NUM_CLASS = 10
NUM_UNIFORM = 2
DIST_PARAM_DIM = NUM_CLASS + NUM_UNIFORM
NOISE_DIM = 62
NOISE_DIM = 3
# prior: the assumption how the latent factors are presented in the dataset
DIST_PRIOR_PARAM = [1.] * NUM_CLASS + [0.] * NUM_UNIFORM


def shapeless_placeholder(x, axis, name):
"""
Make the static shape of a tensor less specific.
Expand Down Expand Up @@ -108,6 +107,7 @@ class Model(GANModelDesc):
def inputs(self):
return [tf.TensorSpec((None, 28, 28), tf.float32, 'input')]

@auto_reuse_variable_scope
def generator(self, z):
l = FullyConnected('fc0', z, 1024, activation=BNReLU)
l = FullyConnected('fc1', l, 128 * 7 * 7, activation=BNReLU)
Expand Down Expand Up @@ -146,6 +146,7 @@ def build_graph(self, real_sample):
z_noise = shapeless_placeholder(
tf.random_uniform([BATCH, NOISE_DIM], -1, 1), 0, name='z_noise')
z = tf.concat([zc, z_noise], 1, name='z')
z2 = tf.placeholder('float', [None, int(z.shape[1])], name='z2')

with argscope([Conv2D, Conv2DTranspose, FullyConnected],
kernel_initializer=tf.truncated_normal_initializer(stddev=0.02)):
Expand All @@ -154,6 +155,9 @@ def build_graph(self, real_sample):
fake_sample_viz = tf.cast((fake_sample) * 255.0, tf.uint8, name='viz')
tf.summary.image('gen', fake_sample_viz, max_outputs=30)

fake_sample2 = self.generator(z2)
fake_sample_viz2 = tf.identity(fake_sample2, name='viz2')

# may need to investigate how bn stats should be updated across two discrim
with tf.variable_scope('discrim'):
real_pred, _ = self.discriminator(real_sample)
Expand Down Expand Up @@ -220,8 +224,8 @@ def sample(model_path):
pred = OfflinePredictor(PredictConfig(
session_init=SmartInit(model_path),
model=Model(),
input_names=['z_code', 'z_noise'],
output_names=['gen/viz']))
input_names=['z2'],
output_names=['gen/viz2']))

# sample all one-hot encodings (10 times)
z_cat = np.tile(np.eye(10), [10, 1])
Expand All @@ -231,6 +235,13 @@ def sample(model_path):

IMG_SIZE = 400

out = tf.reshape(pred.output_tensors[0], [-1, 28*28])
out = tf.identity(out, 'X_out')
decoder = tf.graph_util.convert_variables_to_constants(pred.sess, pred.graph.as_graph_def(), ['X_out'])
with tf.gfile.GFile('InfoGAN-mnist.pb', "wb") as f:
f.write(decoder.SerializeToString())
return None

while True:
# only categorical turned on
z_noise = np.random.uniform(-1, 1, (100, NOISE_DIM))
Expand Down
9 changes: 9 additions & 0 deletions examples/GAN/steps_to_save
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Train the model
python InfoGAN-mnist.py

# Save the model in an alternate format
python InfoGAN-mnist.py --sample --load ./info_train/model-50000

# Copy the model to the right place; from there, run the tensorflowjs converter
# as normal.
cp InfoGAN-mnist.pb /path/to/web/interface

0 comments on commit 929f1c8

Please sign in to comment.