Skip to content

Commit

Permalink
axis_incremental fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
hans committed Mar 15, 2015
1 parent d6ee75e commit b1f41b5
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions sampler/axis_incremental.py
Expand Up @@ -37,6 +37,7 @@

condition_dim = embeddings.shape[1]
m, n = len(args.axes), 10
shift = 7.5


# Prepare generator
Expand All @@ -48,26 +49,28 @@

# Sample some noise data -- this needs to be shared between orig and mod
# sample pairs
noise_data = generator.get_noise((m * n, generator.noise_dim)).eval()
noise_data = generator.get_noise((n, generator.noise_dim)).eval()


# Begin modifying axes
base_conditional_data = args.conditional_sampler(generator, n, 1,
embedding_file=args.embedding_file)
print 'Mean for each axis:'
pprint.pprint(zip(args.axes, base_conditional_data[:, args.axes].mean(axis=1)))
base_conditional_data[:, args.axes] -= 0.5 * shift

mod_conditional_data = base_conditional_data.copy()

# Build up a flat array of modified conditional data
mod_conditional_steps = []
for axis in args.axes:
# TODO customize
shift = 5.

mod_conditional_data[:, axis] += shift
mod_conditional_steps.extend(mod_conditional_data.copy())

mod_conditional_steps = np.array(mod_conditional_steps)

samples_orig = topo_sample_f(noise_data, base_conditional_data).swapaxes(0, 3)
samples_mod = topo_sample_f(noise_data, mod_conditional_data).swapaxes(0, 3)
samples_mod = topo_sample_f(np.tile(noise_data, (m, 1)), mod_conditional_steps).swapaxes(0, 3)

pv = PatchViewer(grid_shape=(m + 1, n), patch_shape=(32,32),
is_color=True)
Expand Down

0 comments on commit b1f41b5

Please sign in to comment.