In [None]:
%matplotlib inline


import os
import time

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pylab
import theano
from dotmap import DotMap
from keras.callbacks import Callback
from keras.initializations import glorot_uniform
from keras.optimizers import SGD, Adam, RMSprop

from beras.data_utils import HDF5Tensor
from beras.gan import GAN
from deepdecoder.utils import visualise_tiles, zip_visualise_tiles, np_binary_mask
from deepdecoder.networks import dcgan_generator, dcgan_discriminator, mogan_learn_bw_grid, \
    dcgan_seperated_generator, dcgan_variational_add_generator
from deepdecoder.data import gen_diff_gan, normalize_grid_params
from deepdecoder.model_utils import add_uniform_noise, plot_weights_histogram
from deepdecoder.grid_curriculum import get_generator_and_callback, reduced_id_lecture, exam, \
    z_rot_lecture, y_rot_lecture, x_rot_lecture, grid_generator

pylab.rcParams['figure.figsize'] = (18, 18)

In [None]:
# NOTE(review): g_nb_out_channels is not referenced anywhere else in this
# notebook and is not passed to the generator constructor below.
g_nb_out_channels = 1
# Build the generator `g` with nb_z=40. NOTE(review): mogan_learn_bw_grid is
# later called with nb_z=19 -- confirm the two values are meant to differ.
g = dcgan_variational_add_generator(nb_z=40)

In [None]:
#g_model.load_weights('g_z025_single_channel_small.hdf5')
#add_uniform_noise(g, 0.06)
def reinitialize_v_deconv(model=None, init=None):
    """Re-draw the slice ``[:, 1:]`` of the first weight array of ``model.layers[-2]``.

    The slice ``[:, :1]`` is left untouched. The new values come from
    ``init(shape).get_value()``. All of the layer's weight arrays are written
    back, so any further arrays (e.g. a bias, if the layer has one) are
    preserved.

    Args:
        model: model whose second-to-last layer is modified; defaults to the
            notebook-global generator ``g``.
        init: callable taking a shape and returning an object with
            ``get_value()``; defaults to Keras' ``glorot_uniform``.

    Raises:
        AssertionError: if the layer does not report the new weights after
            ``set_weights``.
    """
    if model is None:
        model = g
    if init is None:
        init = glorot_uniform
    deconv = model.layers[-2]
    weights = deconv.get_weights()
    kernel = weights[0]
    print(kernel[:, 1:].shape)
    # NOTE(review): presumably axis 1 separates the "v" (variation) part from
    # channel 0 -- confirm against the generator definition.
    kernel[:, 1:] = init(kernel[:, 1:].shape).get_value()
    # Write back the full list: passing only the kernel fails on a length
    # mismatch when the layer holds more than one weight array.
    deconv.set_weights(weights)
    # Check the layer itself, not the local array that was just assigned.
    assert np.array_equal(deconv.get_weights()[0], kernel)
    
#reinitialize_v_deconv()

def add_noise_to_variation(std=0.02, model=None, prefix='var', rng=None):
    """Add zero-mean Gaussian noise to every weight array of selected graph nodes.

    Args:
        std: standard deviation of the noise added to each weight array.
        model: graph model exposing ``nodes`` (a name -> layer mapping);
            defaults to the notebook-global generator ``g``.
        prefix: only nodes whose name starts with this string are perturbed.
        rng: object with a numpy-style ``normal(loc, scale, size)`` method;
            defaults to the global ``np.random`` state. Pass e.g.
            ``np.random.RandomState(seed)`` for reproducible noise.
    """
    if model is None:
        model = g
    if rng is None:
        rng = np.random
    for name, layer in model.nodes.items():
        if not name.startswith(prefix):
            continue
        weights = layer.get_weights()
        for w in weights:
            # In-place add keeps each array's original dtype.
            w += rng.normal(0, std, w.shape)
        layer.set_weights(weights)

def reset_variation(model=None, prefix='var', init=None):
    """Re-initialize every weight array of selected graph nodes from ``init``.

    Each array is replaced by ``init(shape).get_value()`` with the same shape.

    Args:
        model: graph model exposing ``nodes`` (a name -> layer mapping);
            defaults to the notebook-global generator ``g``.
        prefix: only nodes whose name starts with this string are reset.
        init: callable taking a shape and returning an object with
            ``get_value()``; defaults to Keras' ``glorot_uniform``.

    NOTE(review): every array is re-drawn, including any 1-D ones (e.g.
    biases, if these layers have them) -- confirm that is intended.
    """
    if model is None:
        model = g
    if init is None:
        init = glorot_uniform
    for name, layer in model.nodes.items():
        if not name.startswith(prefix):
            continue
        layer.set_weights([init(w.shape).get_value() for w in layer.get_weights()])

# Perturb the freshly built generator's 'var*' nodes once before training.
# Re-running this cell adds another round of noise (it is not idempotent).
add_noise_to_variation(std=0.02)

In [None]:
#reset_variation()

In [None]:
discriminator = dcgan_discriminator()
# NOTE(review): the generator above was built with nb_z=40; confirm that 19
# is the intended value for the multi-objective GAN.
nb_z = 19


def optimizer():
    """Return a fresh Adam optimizer (lr=0.0002, beta_1=0.5) on every call."""
    return Adam(lr=0.0002, beta_1=0.5)


mogan, grid_bw_loss_weight = mogan_learn_bw_grid(g, discriminator, optimizer, nb_z=nb_z)
start = time.time()
mogan.compile()
print("Done Compiling in {0:.2f}s".format(time.time() - start))

In [None]:
# Real-tag dataset. Override the location with the DEEPDECODER_TAGS_HDF5
# environment variable instead of editing this absolute path.
tags_fname = os.environ.get('DEEPDECODER_TAGS_HDF5', '/home/leon/data/tags.hdf5')
batch_size = mogan.gan.batch_size
epoch_size = 100*batch_size
# Only the dataset length is needed here, so close the handle right away
# instead of keeping the file open for the rest of the session.
with h5py.File(tags_fname, 'r') as h5:
    nb_tags = h5['tags'].shape[0]
# Truncate to a whole number of epochs so every per-epoch slice is full.
nb_tags = (nb_tags // epoch_size)*epoch_size
# Fail here rather than with a ZeroDivisionError in the training loop.
assert nb_tags > 0, "{} holds fewer than one epoch ({} samples) of tags".format(
    tags_fname, epoch_size)
tags = HDF5Tensor(tags_fname, 'tags', 0, nb_tags)
assert len(tags) % epoch_size == 0
print(nb_tags // epoch_size)

In [None]:
def get_conds(batch):
    """Pick the conditioning inputs out of a generated batch.

    Args:
        batch: object exposing ``grid_idx``, ``z_rot90`` and ``grid_params``
            as attributes (e.g. the DotMaps yielded by the grid generators).

    Returns:
        dict with exactly those three keys, in that order.
    """
    cond_keys = ('grid_idx', 'z_rot90', 'grid_params')
    return {key: getattr(batch, key) for key in cond_keys}

In [None]:
# Curriculum of lectures (presumably ordered by increasing difficulty).
curriculum = [
    # reduced-id lectures only
    reduced_id_lecture(0.03),
    reduced_id_lecture(0.1),
    reduced_id_lecture(0.5),

    # reduced ids combined with growing z rotation
    reduced_id_lecture(0.5) + z_rot_lecture(0.05),
    reduced_id_lecture(0.5) + z_rot_lecture(0.15),
    reduced_id_lecture(0.5) + z_rot_lecture(0.25),

    # reduced ids combined with x or y rotation
    reduced_id_lecture(0.5) + x_rot_lecture(0.5),
    reduced_id_lecture(0.5) + y_rot_lecture(0.5),

    # rotation-only lectures
    z_rot_lecture(0.25),
    x_rot_lecture(0.5) + y_rot_lecture(0.5) + z_rot_lecture(0.25),
    x_rot_lecture(1.) + y_rot_lecture(1.) + z_rot_lecture(0.05),
    x_rot_lecture(1.) + y_rot_lecture(1.) + z_rot_lecture(0.25),
]
# Every lecture shares the same pass_limit.
for lecture in curriculum:
    lecture.pass_limit = 0.02

In [None]:
# Use the compiled GAN's batch size instead of re-hard-coding it.
# `epoch_size` (real tags sliced per epoch) is 100 * mogan.gan.batch_size,
# and draw_diff_gan pairs `batch_size` real tags with one drawn batch. A
# different literal here silently makes those counts disagree.
batch_size = mogan.gan.batch_size
samples_per_epoch = 100*batch_size

nb_channels = 1
grid_raw_generator, curriculum_cb = get_generator_and_callback(curriculum, samples_per_epoch)
draw_raw_generator = grid_generator(curriculum_cb, batch_size)

def _make_grid_batch(params, grid_idx, with_bw=False):
    """Package one raw (params, grid_idx) draw as a DotMap of conditionals.

    Normalizes the grid parameters and draws one random ``z_rot90`` bin from
    {0, 1, 2, 3} per sample (shape ``(len(params), 1)``). When ``with_bw`` is
    true, it also adds ``grid_bw = np_binary_mask(grid_idx)``.
    """
    size = len(params)
    params = normalize_grid_params(params)
    z_bins = np.random.choice(4, (size, 1))
    fields = {
        'grid_params': params.astype(np.float32),
        'grid_idx': grid_idx,
    }
    if with_bw:
        fields['grid_bw'] = np_binary_mask(grid_idx)
    fields['z_rot90'] = z_bins
    return DotMap(fields)


def grid_generator_wrapper(input_dim=40):
    """Endlessly yield training conditionals drawn from ``grid_raw_generator``.

    ``input_dim`` is accepted for backward compatibility and is not used.
    """
    while True:
        params, grid_idx = next(grid_raw_generator)
        yield _make_grid_batch(params, grid_idx)


def draw_generator(input_dim=40):
    """Endlessly yield conditionals, plus their ``grid_bw`` masks, from ``draw_raw_generator``.

    ``input_dim`` is accepted for backward compatibility and is not used.
    """
    while True:
        params, grid_idx = next(draw_raw_generator)
        yield _make_grid_batch(params, grid_idx, with_bw=True)
        
input_dim = 40
generator = grid_generator_wrapper(input_dim)
# Peek at one batch to show the grid-parameter shape. This consumes one draw
# from the shared curriculum generator `grid_raw_generator`.
batch = next(generator)
params = batch.grid_params
grids = batch.grid_idx
print(params.shape)

class CallbackFilterOnTrainBegin(Callback):
    """Wrap a callback and forward only its epoch-begin, epoch-end and batch-end hooks.

    ``on_train_begin``, ``on_train_end`` and ``on_batch_begin`` are not
    forwarded. Wrapping the same callback in a fresh wrapper on every one-epoch
    ``fit`` call therefore never re-runs its train-level hooks. After each
    epoch the wrapper also calls ``add_noise_to_variation(std=0.02)``, unless
    the wrapped callback's model has ``stop_training`` set.

    The wrapped callback's ``model`` attribute is read in ``on_epoch_end`` and
    must be assigned by the caller.
    """

    def __init__(self, cb):
        super(CallbackFilterOnTrainBegin, self).__init__()
        self.cb = cb

    def on_epoch_begin(self, epoch, log=None):
        if log is None:
            log = {}
        self.cb.on_epoch_begin(epoch, log)

    def on_epoch_end(self, epoch, log=None):
        if log is None:
            log = {}
        self.cb.on_epoch_end(epoch, log)
        if not self.cb.model.stop_training:
            add_noise_to_variation(std=0.02)

    def on_batch_end(self, batch, log=None):
        # A None default avoids mutating a shared `{}` default across calls.
        if log is None:
            log = {}
        # Expose the conditional loss as 'loss' for the wrapped callback. This
        # mutates the caller's dict, so later callbacks see it too.
        if 'cond_loss' in log:
            log['loss'] = log['cond_loss']
        self.cb.on_batch_end(batch, log)

In [None]:
def draw_diff_gan():
    """Plot GAN debug outputs for one freshly drawn batch of conditionals.

    Tiles, in order:
    - the ``real`` images reported by ``mogan.gan.debug``;
    - the binary grid masks of the drawn conditionals;
    - slices 0 and 1 along axis 1 of the generator output;
    - the ``fake`` images.

    The real input is always ``tags[0:batch_size]``.
    """
    cond_batch = next(draw_generator())
    debug_out = mogan.gan.debug(
        tags[0:batch_size],
        conditionals=get_conds(cond_batch),
    )
    zip_visualise_tiles(
        debug_out.real,
        cond_batch.grid_bw,
        debug_out.g_out[:, 0],
        debug_out.g_out[:, 1],
        debug_out.fake,
    )

In [None]:
# Visual check of the GAN outputs before the training loop below runs.
draw_diff_gan()

In [None]:
# Set the loss weight returned by mogan_learn_bw_grid to 1.0 (presumably a
# Theano shared variable weighting the black/white grid objective).
grid_bw_loss_weight.set_value(1.0)
# Left disabled; the wrapper callback below does not forward on_train_begin
# either, so the curriculum callback's train-begin hook never runs here.
# curriculum_cb.on_train_begin()
# CallbackFilterOnTrainBegin reads curriculum_cb.model.stop_training after each
# epoch, so the wrapped callback needs its model attribute set by hand.
curriculum_cb.model = mogan
mogan.stop_training = False
# NOTE(review): nothing in this loop checks mogan.stop_training, so all 300
# iterations run even if the curriculum callback requests a stop -- confirm.
for i in range(300):
    print(i)
    # nb_tags is a whole multiple of epoch_size, so this slice is always full
    # and wraps around to the start of the dataset.
    ti = (i*epoch_size) % nb_tags
    batch = next(generator)
    batch['real'] = tags[ti:ti+epoch_size]
    # One epoch per fit call, with a fresh wrapper around curriculum_cb each time.
    mogan.fit(batch, nb_epoch=1, verbose=1, callbacks=[CallbackFilterOnTrainBegin(curriculum_cb)])
    # Draw every iteration below 15, every 2nd below 30, every 5th below 50,
    # and every 15th throughout.
    if i % 15 == 0 or i < 15 or (i < 30 and i % 2 == 0) or (i < 50 and i % 5 == 0): 
        draw_diff_gan()

In [None]:
# Persist the GAN after training. NOTE(review): 'sperated' looks like a typo
# of 'separated'; the name is kept because other code may load this path.
mogan.gan.save('gan_17_1_total_sperated')