In [None]:
%matplotlib inline
from deepdecoder.networks import dcgan_generator, dcgan_discriminator, gan_grid_idx, mogan_pyramid
from deepdecoder.grid_curriculum import exam, grids_from_lecture
from deepdecoder.data import normalize_generator, grids_lecture_generator, load_real_hdf5_tags
from deepdecoder.utils import zip_visualise_tiles, np_binary_mask, visualise_tiles
from deepdecoder.visualise import plot_multi_objective_grads
from beras.gan import sequential_to_gan
from keras.optimizers import Adam
from keras.objectives import binary_crossentropy, mse
from keras.callbacks import Callback
import numpy as np
import matplotlib.pyplot as plt
import pylab
import time
from itertools import combinations
import seaborn

# Default size (width, height in inches) for every figure in this notebook.
pylab.rcParams['figure.figsize'] = 16, 16


In [None]:
# Training-loop sizes.
batch_size = 128                  # samples per batch yielded by grid_generator
num_batches_per_epoch = 100
epoch_size = num_batches_per_epoch * batch_size   # samples per epoch (12800)

# GAN construction parameters.
# NOTE(review): nb_real + nb_fake == 132, while grid_generator feeds batch_size
# (128) real tags per batch — confirm these are meant to differ.
nb_fake = 96
nb_real = 36
generator_input_dim = 50          # input_dim handed to dcgan_generator

In [None]:
# Generator / discriminator pair; mogan_pyramid below is built from the same g and d.
g = dcgan_generator(input_dim=generator_input_dim, n=32)
d = dcgan_discriminator(n=32)
# NOTE(review): this `gan` is rebound to `mogan.gan` in the debug-compile cell
# before it is ever used; kept in case sequential_to_gan mutates g/d.
gan = sequential_to_gan(g, d, nb_fake=nb_fake, nb_real=nb_real)

In [None]:
# g.load_weights("generator_pyramdi_loss.hdf5")

In [None]:
# Real tag images; grid_generator slices them batch_size at a time as the
# "real" input of every training batch.
# NOTE(review): this absolute path is specific to the original training
# machine — point TAGS_HDF5_PATH at your local copy of the HDF5 file.
TAGS_HDF5_PATH = '/mnt/storage/beesbook/season_2015_tags/tags_plain_64x64.hdf5'
tags = load_real_hdf5_tags(TAGS_HDF5_PATH, nb_real, num_batches_per_epoch)
nb_tags = len(tags)
# Fail fast here rather than with a shape mismatch deep inside training:
# every batch needs batch_size real tags.
assert nb_tags >= batch_size, \
    "need at least {} tags, got {}".format(batch_size, nb_tags)

In [None]:
# Preview the first 256 real tags (grid lines disabled so tiles are unobstructed).
plt.grid(False)
visualise_tiles(tags[:256])

In [None]:
# Multi-objective GAN over the g/d pair above. The optimizer is passed as a
# factory (each call builds a fresh Adam), not as a shared instance.
mogan = mogan_pyramid(
    g,
    d,
    lambda: Adam(beta_1=0.5, lr=0.0002),
    nb_z=31,  # NOTE(review): magic number, unrelated to generator_input_dim (50) — confirm
    gan_objective=binary_crossentropy,
    d_loss_grad_weight=0,
)

In [None]:
# Relative weight of each objective in mogan's multi-objective update.
# NOTE(review): the original note read "expected loss g / wanted loss tags",
# i.e. cond_loss appears to be up-weighted to balance it against g_loss — confirm.
weights = dict(cond_loss=12, g_loss=1)
for name in weights:
    weight = weights[name]
    mogan.multi_objectives.set_objective_weight(name, weight)
    

In [None]:
# Compile mogan plus the gradient-inspection function used by plot_grads,
# and report the wall-clock time taken.
start = time.time()
mogan.compile()
mogan.multi_objectives.compile_get_grads()
print("Done Compiling in {}s".format(time.time() - start))

In [None]:
def plt_hist(x, label, num_bins=50, ax=None, **kwargs):
    """Draw a histogram of ``x`` as a translucent bar chart.

    Args:
        x: array-like of values (flattened by ``np.histogram``).
        label: legend label for the bars.
        num_bins: number of equal-width bins.
        ax: matplotlib Axes to draw on; defaults to the current axes.
        **kwargs: forwarded to ``Axes.bar`` (e.g. ``color``). ``alpha``
            defaults to 0.2 and may be overridden.

    Returns:
        The ``BarContainer`` created by ``Axes.bar``.
    """
    hist, bins = np.histogram(x, bins=num_bins)
    centers = (bins[:-1] + bins[1:]) / 2
    # Bars fill 70% of a bin so neighbouring bins stay visually separate.
    width = 0.7 * (bins[1] - bins[0])
    # setdefault instead of a hard-coded keyword: passing alpha=... through
    # **kwargs used to raise "got multiple values for keyword argument 'alpha'".
    kwargs.setdefault('alpha', 0.2)
    if ax is None:
        ax = plt.gca()
    return ax.bar(centers, hist, align='center', width=width, label=label, **kwargs)

def plot_multi_objective_grads(params, grads):
    """Plot gradient histograms for every parameter of a multi-objective model.

    For each parameter, one figure with two panels is drawn:
    the top panel holds one histogram per objective's gradient; the bottom
    panel holds one histogram per pair of objectives, of the element-wise
    product of their gradients.

    This definition shadows ``plot_multi_objective_grads`` imported from
    ``deepdecoder.visualise``.

    Args:
        params: sequence of parameters; ``params[i]`` is printed as a header
            for ``grads[i]``.
        grads: sequence aligned with ``params``; each item maps an objective
            name to that objective's gradient for the parameter.
    """
    from itertools import cycle  # local import keeps this cell self-contained

    # Cycled, so objectives/pairs beyond the palette size are still plotted
    # (a plain zip with a fixed tuple silently dropped them).
    palette = ('r', 'g', 'b', 'c', 'm', 'y', 'k')
    for i, grad_dict in enumerate(grads):
        print("{}: {}".format(i, params[i]))
        # One figure, two panels, a single show(): previously both panels used
        # add_subplot(2, 1, 2), and the second was drawn after plt.show() had
        # already closed the figure, landing on a new implicit figure.
        plt.figure()

        plt.subplot(2, 1, 1)
        plt.title("{} grad histogram".format(i))
        items = list(grad_dict.items())
        for color, (name, grad) in zip(cycle(palette), items):
            plt_hist(grad, name, color=color)
        if items:
            plt.legend()

        plt.subplot(2, 1, 2)
        plt.title("{} multi histogram".format(i))
        pairs = list(combinations(items, 2))
        for color, ((name_a, a), (name_b, b)) in zip(cycle(palette), pairs):
            plt_hist(a * b, "{}-{}".format(name_a, name_b), color=color)
        if pairs:
            plt.legend()

        plt.show()

def plot_grads():
    """Plot per-objective gradient histograms for one batch.

    Uses the first ``batch_size`` real tags together with one freshly drawn
    batch of grid parameters/indices, evaluates
    ``mogan.multi_objectives.get_grads`` on them and plots the result with
    ``plot_multi_objective_grads``.
    """
    real_batch = tags[:batch_size]
    grid_params, grid_idx = next(grids_lecture_generator(batch_size))
    # get_grads receives [real tags, grid indices, grid params], all float32.
    inputs = [arr.astype(np.float32) for arr in (real_batch, grid_idx, grid_params)]
    multi = mogan.multi_objectives
    grads = multi.get_grads(inputs)
    plot_multi_objective_grads(multi.params, grads)

In [None]:
def should_visualise(i):
    """Return True if epoch ``i`` should be visualised.

    The schedule is dense early and sparse later: every epoch below 15,
    every 5th below 100, every 20th below 1000, and every 50th throughout.
    """
    if i < 15:
        return True
    if i < 100 and not i % 5:
        return True
    if i < 1000 and not i % 20:
        return True
    return not i % 50
def visualise():
    """Show generated tags next to the binary masks of their grid indices.

    Draws one batch from a fresh lecture generator. It renders the binary mask
    of the grid indices together with mogan.gan's output for the same grid
    parameters, via ``zip_visualise_tiles`` (presumably paired tile by tile).
    """
    grid_params, grid_idx = next(grids_lecture_generator(batch_size))
    masks = np_binary_mask(grid_idx)
    generated = mogan.gan.generate(conditionals={'grid_params': grid_params})
    zip_visualise_tiles(masks, generated)

In [None]:
def grid_generator():
    """Yield training input dicts that pair real tags with curriculum grids.

    Each item is ``{'real': tag_batch, 'cond_true': grid_idx,
    'grid_params': params}``:

    - ``params, grid_idx`` come from ``grids_lecture_generator(batch_size)``.
    - ``tag_batch`` is the next consecutive, non-overlapping slice of
      ``batch_size`` tags.

    The tag cursor wraps to the start once no full batch is left, so every
    batch has exactly ``batch_size`` tags. Any trailing partial batch of tags
    is skipped.

    Raises:
        ValueError: on the first iteration, if fewer than ``batch_size`` tags
            are loaded.
    """
    nb_full_batches = len(tags) // batch_size
    if nb_full_batches == 0:
        raise ValueError("need at least {} tags for one batch, got {}"
                         .format(batch_size, len(tags)))
    for i, (params, grid_idx) in enumerate(grids_lecture_generator(batch_size)):
        # The offset used to be `i % nb_tags`, advancing one tag per batch.
        # Consecutive batches then shared batch_size - 1 tags, and batches
        # near the end were shorter than grid_idx/params.
        start = (i % nb_full_batches) * batch_size
        tag_batch = tags[start:start + batch_size]
        yield {'real': tag_batch, 'cond_true': grid_idx, 'grid_params': params}

In [None]:
class VisualiseCallback(Callback):
    """Keras callback that calls ``visualise()`` on epochs picked by ``should_visualise``."""

    def on_epoch_end(self, epoch, logs={}):
        # `logs` follows the Keras callback signature; it is not read here.
        if not should_visualise(epoch):
            return
        visualise()
    

In [None]:
# NOTE(review): scratch expression unrelated to the training pipeline; it only
# displays [(1, 3)] and can be deleted.
list({1: 3}.items())

In [None]:
# Compile mogan.gan's debug function against mogan's build dict and time it.
# This rebinds the notebook-level `gan` (previously the sequential_to_gan
# result) to mogan's GAN, which the debug cells below use.
start = time.time()
gan = mogan.gan
# NOTE(review): relies on the private beras method `_compile_debug`.
gan._compile_debug(mogan.build_dict)
print("Done Compiling in {}s".format(time.time() - start))

In [None]:
# Display gan.debug_output for mogan's build dict (the return value is the cell output).
gan.debug_output(mogan.build_dict)

In [None]:
# List the (name, value) pairs in mogan.build_dict.conditionals_dict for inspection.
list(mogan.build_dict.conditionals_dict.items())

In [None]:
class PrintLossGrad(Callback):
    """Callback that runs ``gan.debug`` on one batch.

    It prints ``debug.d_loss_grad`` and tiles ``debug.fake_grad``.

    NOTE(review): a fresh ``grid_generator()`` is created on every call, so
    this always inspects the generator's first batch.
    """

    def on_epoch_end(self, epoch, logs={}):
        batch = next(grid_generator())
        real = batch.pop('real')
        batch.pop('cond_true')
        # Whatever remains in the dict (grid_params) is passed as conditionals.
        print(len(batch))
        debug = gan.debug(real, conditionals=batch)
        print("d_loss_grad: {}".format(float(debug.d_loss_grad)))
        visualise_tiles(debug.fake_grad)


# Run once immediately as a smoke test, outside any training loop.
PrintLossGrad().on_epoch_end(0)

In [None]:
# Pre-training sanity check: generated samples, then gradient histograms.
visualise()
plot_grads()

In [None]:
# Train mogan for 250 epochs of epoch_size samples drawn from grid_generator;
# VisualiseCallback shows samples on the epochs selected by should_visualise.
mogan.fit_generator(grid_generator(), samples_per_epoch=epoch_size, nb_epoch=250, 
                    verbose=1, callbacks=[VisualiseCallback()])

In [None]:
# Post-training gradient histograms, for comparison with the pre-training plot.
plot_grads()