updated synthesis model to produce 16x16x16 output, see #10
bodokaiser committed Sep 30, 2017
1 parent 0d82515 commit 403eb61
Showing 4 changed files with 12 additions and 9 deletions.
python/mrtoct/model/gan/synthesis.py (7 additions, 5 deletions)

@@ -3,9 +3,10 @@
 from mrtoct.model import layers


-def generator_conv(inputs, kernel_size, num_filters, activation=tf.nn.relu):
+def generator_conv(inputs, kernel_size, num_filters, padding,
+                   activation=tf.nn.relu):
   """Creates a synthesis generator conv layer."""
-  outputs = layers.Conv3D(num_filters, kernel_size)(inputs)
+  outputs = layers.Conv3D(num_filters, kernel_size, padding=padding)(inputs)
   outputs = layers.BatchNorm()(outputs)
   outputs = layers.Activation(activation)(outputs)

@@ -18,17 +19,18 @@ def generator_network(params):
   for i, ks in enumerate([9, 3, 3, 3, 9, 3, 3, 7, 3]):
     with tf.variable_scope(f'conv{i}'):
-      outputs = generator_conv(outputs, ks, 64 if i in [4, 5] else 32)
+      outputs = generator_conv(outputs, ks, 64 if i in [4, 5] else 32,
+                               'valid' if ks == 9 else 'same')

   with tf.variable_scope('final'):
-    outputs = generator_conv(outputs, 3, 1, tf.nn.tanh)
+    outputs = generator_conv(outputs, 3, 1, 'same', activation=tf.nn.tanh)

   return layers.Network(inputs, outputs, name='generator')


 def discriminator_conv(inputs, num_filters):
   """Creates a synthesis discriminator conv layer."""
-  outputs = layers.Conv3D(num_filters, 5, 1)(inputs)
+  outputs = layers.Conv3D(num_filters, 5)(inputs)
   outputs = layers.BatchNorm()(outputs)
   outputs = layers.Activation(tf.nn.relu)(outputs)
   outputs = layers.MaxPool3D(5, 1)(outputs)
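The 16x16x16 output quoted in the commit message follows from this padding change: with stride 1, a 'valid' convolution trims kernel_size - 1 voxels from each spatial dimension, while 'same' padding preserves the size, so the two kernel-size-9 layers reduce a 32x32x32 input to 16x16x16. A minimal sketch of that arithmetic (plain Python, not part of the repository; conv_output_size is a hypothetical helper):

def conv_output_size(size, kernel_size, padding):
  """Spatial size after a stride-1 convolution with the given padding."""
  return size if padding == 'same' else size - (kernel_size - 1)

size = 32
for ks in [9, 3, 3, 3, 9, 3, 3, 7, 3] + [3]:  # generator kernels plus the final conv
  size = conv_output_size(size, ks, 'valid' if ks == 9 else 'same')

print(size)  # 16, matching the updated test expectation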
python/mrtoct/model/gan/synthesis_test.py (1 addition, 1 deletion)

@@ -14,7 +14,7 @@ def test_generator(self):
     x = tf.ones([10, 32, 32, 32, 1])
     y = network(x)

-    self.assertAllEqual(x.shape, y.shape)
+    self.assertAllEqual(tf.TensorShape([10, 16, 16, 16, 1]), y.shape)

   def test_discriminator(self):
     network = model.synthesis.discriminator_network(self.params)
python/mrtoct/model/layers.py (3 additions, 2 deletions)

@@ -13,10 +13,11 @@ def __init__(self, *args, **kwargs):
 class Conv3D(tf.layers.Conv3D):
   """Same as tf.layers.Conv3D but with better defaults."""

-  def __init__(self, *args, **kwargs):
+  def __init__(self, num_filters, kernel_size, padding='same'):
     init = tf.contrib.layers.xavier_initializer()

-    super().__init__(*args, kernel_initializer=init, padding='same', **kwargs)
+    super().__init__(num_filters, kernel_size,
+                     padding=padding, kernel_initializer=init)


 class Conv2DTranspose(tf.layers.Conv2DTranspose):
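With the new explicit signature, Conv3D still defaults to 'same' padding but lets each call site opt into 'valid'. A hypothetical usage sketch, assuming the TensorFlow 1.x graph API used throughout the repository (the placeholder x is illustrative only):

import tensorflow as tf

from mrtoct.model import layers

x = tf.placeholder(tf.float32, [None, 32, 32, 32, 1])

y_same = layers.Conv3D(32, 9)(x)                    # default 'same' keeps 32x32x32
y_valid = layers.Conv3D(32, 9, padding='valid')(x)  # 'valid' trims to 24x24x24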
python/patch.py (1 addition, 1 deletion)

@@ -84,7 +84,7 @@ def main(args):
   tf.logging.set_verbosity(tf.logging.INFO)

   hparams = tf.contrib.training.HParams(
-      sample_num=1000,
+      sample_num=8000,
       seeds=[args.seed,
              args.seed * 2,
              args.seed * 3],
