In [None]:
%matplotlib inline


from beras.data_utils import HDF5Tensor
import matplotlib.pyplot as plt
import numpy as np
from deepdecoder.utils import visualise_tiles, zip_visualise_tiles, np_binary_mask
from deepdecoder.mask_loss import pyramid_loss, to_keras_loss
from deepdecoder.networks import dcgan_generator, dcgan_discriminator, dummy_dcgan_generator
from deepdecoder.grid_curriculum import get_generator_and_callback, reduced_id_lecture, \
    z_rot_lecture, y_rot_lecture, x_rot_lecture, z_rot_lecture_around, AroundPoints
from deepdecoder.data import normalize_generator
from beesgrid import MASK, CONFIG_ROTS, CONFIG_RADIUS, CONFIG_CENTER, TAG_SIZE
from keras.optimizers import SGD, Adam, RMSprop
from keras.callbacks import Callback
from keras.regularizers import l2
import importlib
import h5py
import pylab
import time

# Default figure size (in inches) for every plot rendered in this notebook.
pylab.rcParams['figure.figsize'] = (18, 18)

In [None]:
# Training curriculum: a list of lectures, each built by adding together the
# results of the *_lecture helpers imported from deepdecoder.grid_curriculum.
# The list is handed to get_generator_and_callback further down.
# NOTE(review): the numeric arguments presumably control how far each grid
# parameter (id bits, x/y/z rotation) may vary, with later entries being
# harder -- confirm against deepdecoder.grid_curriculum.
curriculum = [
    reduced_id_lecture(0.03) + z_rot_lecture_around(2),
    reduced_id_lecture(0.15) + z_rot_lecture_around(4),
    reduced_id_lecture(0.03) + x_rot_lecture(0.5) + z_rot_lecture_around(2) ,
    reduced_id_lecture(0.03) + y_rot_lecture(0.5) + z_rot_lecture_around(2),
    x_rot_lecture(0.5) + y_rot_lecture(0.5) + z_rot_lecture_around(2),
    x_rot_lecture(0.5) + y_rot_lecture(0.5) + z_rot_lecture_around(4),
    x_rot_lecture(1.) + y_rot_lecture(1.) + z_rot_lecture_around(8),
    x_rot_lecture(1.) + y_rot_lecture(1.) + z_rot_lecture_around(16),
    x_rot_lecture(1.) + y_rot_lecture(1.) + z_rot_lecture_around(360),
    x_rot_lecture(1.) + y_rot_lecture(1.) + z_rot_lecture(1),
]
# Every lecture gets the same pass_limit.
# NOTE(review): presumably the loss threshold below which the curriculum
# advances to the next lecture -- confirm.
for c in curriculum:
    c.pass_limit = 0.02

In [None]:
# Samples per batch; passed to get_generator_and_callback and used to size an
# epoch in the fit_generator call (samples_per_epoch=100*batch_size).
batch_size = 128
# Width of the generator network's input vector; passed as input_dim to
# dcgan_generator and to grid_generator.
generator_input_dim = 50
# Batch generator driven by `curriculum`, plus the callback that is later
# registered with fit_generator.
curriculum_grids_generator, curriculum_cb = get_generator_and_callback(curriculum, batch_size)

In [None]:
def grid_generator(input_dim=50):
    """Yield ``(net_input, grid_idx)`` batches with inputs padded to `input_dim` columns.

    Every batch coming out of ``normalize_generator(curriculum_grids_generator)``
    supplies a 2-D parameter array; the remaining ``input_dim - n_params``
    columns are filled with uniform noise drawn from [-1, 1).
    """
    batches = normalize_generator(curriculum_grids_generator)
    for grid_params, grid_idx in batches:
        n_rows, n_params = grid_params.shape[0], grid_params.shape[1]
        noise = np.random.uniform(-1, 1, (n_rows, input_dim - n_params))
        net_input = np.concatenate([grid_params, noise], axis=1)
        yield net_input, grid_idx
        
def listify(generator):
    """Re-yield each ``(input, labels)`` pair with the input wrapped in a one-element list."""
    for pair in generator:
        features, labels = pair
        yield [features], labels

In [None]:
# Sanity check: pull one batch and verify that the input width produced by
# grid_generator equals generator_input_dim. Binds `params` and `grids` at
# notebook level.
params, grids = next(grid_generator(generator_input_dim))
print(params.shape)
assert params.shape[1] == generator_input_dim

In [None]:
# Build the DCGAN generator network with inputs of width generator_input_dim.
g = dcgan_generator(input_dim=generator_input_dim)

In [None]:

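# Alternative for quick experiments: swap in the dummy generator instead.
# NOTE(review): input_dim=40 does not match generator_input_dim (50) used by
# the rest of the notebook -- confirm before enabling.
# g = dummy_dcgan_generator(input_dim=40)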

In [None]:
# Compile `g` with Adam (lr=0.003, clipvalue=0.5) and the pyramid loss wrapped
# by to_keras_loss, and report how long compilation took (wall clock).
start = time.time()
g.compile(Adam(lr=0.003, clipvalue=0.5), to_keras_loss(pyramid_loss))
print("Compiling done in {:.2f}s".format(time.time() - start))

In [None]:
def visualise_g():
    """Show one batch of target grids together with the generator's predictions.

    Draws a single batch from `grid_generator`, passes the grid indices through
    `np_binary_mask`, runs `g.predict` on the batch inputs and hands both
    arrays to `zip_visualise_tiles` (presumably rendering them as paired
    tiles -- see deepdecoder.utils).
    """
    # Ask for inputs of `generator_input_dim` columns explicitly -- the width
    # `g` was built with -- rather than relying on grid_generator's own
    # default, which would silently diverge if the config value changed.
    params, grids_idx = next(grid_generator(generator_input_dim))
    grids = np_binary_mask(grids_idx)
    pred_grids = g.predict(params)
    zip_visualise_tiles(grids, pred_grids)

In [None]:
# Generator output before training, for comparison with the post-training view.
visualise_g()

In [None]:
# Train the generator on curriculum batches. The input width is passed
# explicitly so the training data always matches the `generator_input_dim`
# that `g` was built with, instead of grid_generator's hard-coded default.
# `curriculum_cb` is the callback returned by get_generator_and_callback.
history = g.fit_generator(listify(grid_generator(generator_input_dim)),
                          samples_per_epoch=100*batch_size,
                          nb_epoch=10000, verbose=1, callbacks=[curriculum_cb], nb_worker=1)

In [None]:
# Generator output after training.
visualise_g()

In [None]:
# Persist the generator weights, replacing any existing file of the same name.
# NOTE(review): "pyramdi" in the file name looks like a typo for "pyramid";
# left unchanged because other code may load this exact path -- confirm.
g.save_weights("generator_pyramdi_loss.hdf5", overwrite=True)

In [None]:
# `history` is bound by the training cell (`history = g.fit_generator(...)`).
# This cell used to read `_14` from IPython's output cache, which only exists
# in the one interactive session that produced it and raises NameError after
# Restart & Run All. Evaluating the name here displays the object and fails
# loudly if the training cell has not been run.
history

In [None]:
# Natural log of the training loss stored in the Keras History object,
# one point per recorded epoch.
plt.plot(np.log(history.history['loss']))
plt.xlabel('epoch')
plt.ylabel('log(loss)');