In [None]:
%matplotlib inline


from beras.data_utils import HDF5Tensor
import matplotlib.pyplot as plt
import numpy as np
from beras.gan import GAN
from deepdecoder.utils import visualise_tiles, zip_visualise_tiles, np_binary_mask
from deepdecoder.networks import dcgan_generator, dcgan_discriminator, diff_gan
from deepdecoder.data import gen_diff_gan
from deepdecoder.model_utils import add_uniform_noise, plot_weights_histogram
from deepdecoder.grid_curriculum import get_generator_and_callback, reduced_id_lecture, exam, \
    z_rot_lecture, y_rot_lecture, x_rot_lecture
from keras.optimizers import SGD, Adam, RMSprop
from keras.callbacks import Callback
import pylab
import time
import h5py
pylab.rcParams['figure.figsize'] = (18, 18)
import theano


In [None]:
# Input width handed to dcgan_generator. Presumably it must match the setting the
# pretrained weights loaded in the next cell were produced with — TODO confirm.
generator_input_dim = 40
g = dcgan_generator(input_dim=generator_input_dim)

In [None]:
# Restore pretrained generator weights, then perturb them with uniform noise.
# Both steps live in one cell so that re-running it starts again from the saved
# weights rather than adding noise on top of earlier noise (assuming
# add_uniform_noise modifies g's weights in place — TODO confirm).
g_weights_fname = 'g_z025_13_01.hdf5'
g_noise_scale = 0.04  # scale argument for add_uniform_noise; see deepdecoder.model_utils
g.load_weights(g_weights_fname)
add_uniform_noise(g, g_noise_scale)

In [None]:
def optimizer():
    """Build and return a fresh Adam optimizer (lr=0.0002, beta_1=0.5).

    This is a factory rather than a shared instance so that ``gan.compile``
    below receives two distinct optimizer objects (presumably one per
    sub-network — confirm against ``GAN.compile``).
    """
    return Adam(lr=0.0002, beta_1=0.5)


nb_z = 19
discriminator = dcgan_discriminator()
gan = diff_gan(g, discriminator, nb_z=nb_z)

# Compilation can be slow, so report how long it took.
start = time.time()
gan.compile(optimizer(), optimizer(), gan_regulizer=GAN.L2Regularizer)
print("Done Compiling in {0:.2f}s".format(time.time() - start))

In [None]:
# HDF5 file holding the real tag images in dataset 'tags'.
# TODO(review): absolute, user-specific path; consider making it configurable.
tags_fname = '/home/leon/data/tags.hdf5'
batch_size = gan.batch_size
# Each training-loop iteration fits on one slice of 50 batches of real tags.
epoch_size = 50*batch_size

# Only the dataset length is needed here. Close the file immediately instead of
# leaving an open handle behind; HDF5Tensor below is given the filename.
with h5py.File(tags_fname, 'r') as h5:
    nb_tags_total = h5['tags'].shape[0]

# Truncate to a whole number of epoch slices so every slice is full-sized.
nb_tags = (nb_tags_total // epoch_size)*epoch_size
# Without this guard, an undersized dataset gives nb_tags == 0 and the training
# loop's `% nb_tags` fails with ZeroDivisionError.
assert nb_tags > 0, "{}: only {} tags, fewer than one epoch of {}".format(
    tags_fname, nb_tags_total, epoch_size)

tags = HDF5Tensor(tags_fname, 'tags', 0, nb_tags)
assert len(tags) % epoch_size == 0

In [None]:
# Conditional-input source for training; the loop below draws one batch per
# iteration (presumably sized epoch_size to pair with each slice of real tags).
# `outputs` must be a tuple: ('grid_idx') without a trailing comma is just the
# string 'grid_idx'. The 1-tuple matches the tuple passed in draw_diff_gan.
generator = gen_diff_gan(epoch_size, outputs=('grid_idx',))

In [None]:
def get_conds(batch):
    """Collect the conditional inputs for ``gan`` from a generated batch.

    Reads ``batch.grid_idx``, ``batch.z_bins`` and ``batch.params`` (in that
    order) and returns them in a dict under the keys ``'grid_idx'``,
    ``'z_rot90'`` and ``'grid_params'``. Note that ``z_bins`` is published
    under the name ``'z_rot90'``; presumably that is the name the GAN's
    conditional input expects — confirm against ``diff_gan``.
    """
    conds = dict(grid_idx=batch.grid_idx)
    conds['z_rot90'] = batch.z_bins
    conds['grid_params'] = batch.params
    return conds

In [None]:
def draw_diff_gan():
    """Visualise ``gan.debug`` outputs for one freshly generated conditional batch.

    The real input is always the first ``batch_size`` tags, so successive
    calls use the same real images. Tiles of ``real``, ``g_out`` and ``fake``
    from the debug result are shown side by side with the batch's
    ``grid_bw`` via ``zip_visualise_tiles``. Relies on the notebook globals
    ``gan``, ``tags`` and ``batch_size``.
    """
    vis_batch = next(gen_diff_gan(batch_size, outputs=('grid_idx', 'grid_bw')))
    real_tags = tags[0:batch_size]
    debug_out = gan.debug(real_tags, conditionals=get_conds(vis_batch))
    zip_visualise_tiles(debug_out.real, debug_out.g_out, debug_out.fake,
                        vis_batch.grid_bw)

In [None]:
# Number of outer training iterations; each fits one epoch_size slice of real tags.
nb_iterations = 200


def should_draw(i):
    """Return True if samples should be visualised after iteration ``i``.

    True for every i < 10, for every multiple of 3 below 30, and for every
    multiple of 15. This gives dense snapshots early and sparse ones later.
    """
    return i % 15 == 0 or i < 10 or (i < 30 and i % 3 == 0)


draw_diff_gan()  # snapshot before the first fit call of this loop
for i in range(nb_iterations):
    print(i)
    # nb_tags is a multiple of epoch_size, so this slice never runs past the
    # end of `tags`, and the loop cycles through the dataset.
    ti = (i*epoch_size) % nb_tags
    batch = next(generator)
    gan.fit(tags[ti:ti+epoch_size],
            get_conds(batch),
            nb_epoch=1, verbose=1)
    if should_draw(i):
        draw_diff_gan()