In [None]:
# Theano configuration for this kernel: CUDA root, GPU device, float32 default dtype,
# allow_gc=True and lib.cnmem=0.1 (per Theano docs a CNMeM GPU-memory preallocation
# fraction — verify for your Theano version). Presumably must run before theano is imported.
%env THEANO_FLAGS=cuda.root=/opt/cuda,device=gpu,floatX=float32,allow_gc=True,lib.cnmem=0.1

In [None]:
# Alternative setting (append to THEANO_FLAGS in the cell above if needed): lib.cnmem=1

In [None]:
%matplotlib inline
from deepdecoder.networks import mask_blending_generator, get_mask_driver, get_lighting_generator, \
    get_offset_merge_mask, get_mask_weight_blending, get_offset_back, get_offset_front, \
    get_offset_middle, mask_generator, mask_blending_discriminator, get_mask_postprocess

from deepdecoder.utils import zip_visualise_tiles, visualise_tiles
from deepdecoder.data import normalize_generator, grids_lecture_generator, \
    load_real_hdf5_tags, z_generator, nb_normalized_params, real_generator, np_binary_mask, weight_pyramid
from deepdecoder.mask_loss import to_keras_loss, pyramid_loss
from deepdecoder.grid_curriculum import Lecture, Normal
from beras.gan import GAN, gan_binary_crossentropy, gan_linear_losses, gan_outputs
from beras.callbacks import VisualiseGAN, SaveModels
from keras.models import Sequential, Graph
from keras.layers.core import Dense, Layer
from beras.transform import tile
from keras.optimizers import Adam, SGD
from beesgrid import NUM_MIDDLE_CELLS

from keras.engine.topology import Input

from skimage.transform import pyramid_reduce, pyramid_laplacian
from skimage.filters import gaussian_filter
from scipy.ndimage.interpolation import zoom

import numpy as np
import time
import pylab
import theano.tensor as T
import os
import matplotlib.pyplot as plt
import sys 

# Show the interpreter's current recursion limit, then raise it; deep Theano graphs
# presumably exceed the default (confirm). Also set the default figure size.
_default_recursion_limit = sys.getrecursionlimit()
print(_default_recursion_limit)
sys.setrecursionlimit(10000)
pylab.rcParams.update({'figure.figsize': (16, 16)})

In [None]:
# Batch composition: number of generated (fake) and real samples.
nb_fake = 32
nb_real = 32 // 2

# Network widths.
offset_nb_units = 48
dis_nb_units = 32

# Adam hyper-parameters (used for both optimizers).
lr = 5e-5
beta_1 = 0.5

nb_input_mask_generator = 19 - NUM_MIDDLE_CELLS

# Latent vector sizes; z is laid out as [driver | offset | bits].
z_dim_driver = 20
z_dim_offset = 20
z_dim_bits = 12
z_dim = z_dim_driver + z_dim_offset + z_dim_bits

In [None]:
def g_mask(x):
    """Mask generator with fixed architecture, built with ``trainable=False``."""
    return mask_generator(x, nb_units=32, dense_factor=3, nb_dense_layers=2, trainable=False)


# Passed to mask_blending_generator as ``mask_generator_weights``.
g_mask_weights = "../models/holy/mask_generator_n32_black_white//mask_generator.hdf5"

In [None]:
def d(x):
    """Discriminator wrapped by ``gan_outputs``.

    The (start, end) index ranges passed to ``gan_outputs`` are read from the
    notebook globals ``nb_fake`` / ``nb_real`` each time ``d`` is called.
    """
    discriminator_out = mask_blending_discriminator(x, n=dis_nb_units)
    return gan_outputs(
        discriminator_out,
        fake_for_gen=(0, nb_fake),
        fake_for_dis=(0, nb_fake - nb_real),
        real=(nb_fake, nb_fake + nb_real),
    )

In [None]:
def _merge_mask_factory(namespace, poolings=None):
    """Return a callable ``x -> get_offset_merge_mask(x, ...)`` bound to ``namespace``.

    Shared implementation of ``merge16`` / ``merge32`` / ``merge``: all use
    ``offset_nb_units // 3`` units and two conv layers. ``poolings`` is forwarded
    only when given, so ``None`` keeps ``get_offset_merge_mask``'s own default.
    ``offset_nb_units`` is read when the returned callable runs.
    """
    def call(x):
        extra = {}
        if poolings is not None:
            # Fresh list per call, like the list literals this replaces.
            extra['poolings'] = list(poolings)
        return get_offset_merge_mask(x, nb_units=offset_nb_units // 3, nb_conv_layers=2,
                                     ns=namespace, **extra)
    return call


def merge16(namespace):
    """Merge-mask builder with ``poolings=[True, True]``."""
    return _merge_mask_factory(namespace, poolings=[True, True])


def merge32(namespace):
    """Merge-mask builder with ``poolings=[True, False]``."""
    return _merge_mask_factory(namespace, poolings=[True, False])


def merge(namespace):
    """Merge-mask builder using ``get_offset_merge_mask``'s default poolings."""
    return _merge_mask_factory(namespace)

In [None]:
# Assemble the blending generator. z is split into consecutive (start, end)
# slices: [driver | offset | bits].
# NOTE(review): merge32 is defined above but unused; offset_merge_mask32 uses merge() — confirm intended.
g = mask_blending_generator(
    # Mask components.
    mask_driver=lambda x: get_mask_driver(x, nb_units=offset_nb_units,
                                          nb_output_units=nb_input_mask_generator),
    mask_generator=g_mask,
    mask_generator_weights=g_mask_weights,
    mask_postprocess=lambda x: get_mask_postprocess(x, offset_nb_units // 3),
    # Lighting components.
    lighting_generator=lambda x: get_lighting_generator(x, offset_nb_units // 2),
    light_merge_mask16=merge('light_merge16'),
    # Offset components.
    offset_front=lambda x: get_offset_front(x, offset_nb_units),
    offset_middle=lambda x: get_offset_middle(x, offset_nb_units),
    offset_back=lambda x: get_offset_back(x, offset_nb_units),
    offset_merge_light16=merge16('offset_merge_light16'),
    offset_merge_mask16=merge('offset_merge16'),
    offset_merge_mask32=merge('offset_merge32'),
    # Mask weight blending (32 / 64 variants).
    mask_weight_blending32=lambda x: get_mask_weight_blending(x, min=0.15),
    mask_weight_blending64=get_mask_weight_blending,
    # Latent slices.
    z_for_driver=(0, z_dim_driver),
    z_for_offset=(z_dim_driver, z_dim_driver + z_dim_offset),
    z_for_bits=(z_dim_driver + z_dim_offset,
                z_dim_driver + z_dim_offset + z_dim_bits),
)

In [None]:
# Wrap generator `g` and discriminator `d`; z has z_dim entries, real samples are (1, 64, 64).
gan = GAN(g, d, z_shape=(z_dim,), real_shape=(1, 64, 64))
# Stop regularizer with high=3.5 (semantics defined in beras.gan.GAN.StopRegularizer).
gan.add_gan_regularizer(GAN.StopRegularizer(high=3.5))

In [None]:
# NOTE(review): overrides the high=3.5 passed to StopRegularizer above, via the private
# `_gan_regularizers` list — consider passing the intended value to the constructor instead.
gan._gan_regularizers[0].high.set_value(100)

In [None]:
# Separate Adam optimizers with the same (lr, beta_1) for generator and discriminator.
g_optimizer = Adam(lr, beta_1)
d_optimizer = Adam(lr, beta_1)
# Build the GAN with the two optimizers and the binary cross-entropy GAN loss.
gan.build(g_optimizer, d_optimizer, gan_binary_crossentropy)

In [None]:
# List every available debug output name, alphabetically.
# debug_dict() is fetched once instead of twice.
debug_keys = gan.debug_dict().keys()
for k in sorted(debug_keys):
    print(k)

In [None]:
print("Compiling...")
start = time.time()
# Debug-output names in the GAN graph; later cells reuse these constants.
mask_gen_layer = 'mask_gen.22_activation'
driver_layer = 'driver.10_linearinbounds'
gen_out_layer = 'blending_post'
light_layer = 'lighting.16_gaussianblur'
post_layer = 'mask_post_high'
# NOTE(review): 'addlight' is requested here, but a later cell reads out['addlighting_1'] — confirm the key.
debug_keys = [
    'selection', 'blending', 'addlight',
    gen_out_layer, driver_layer, mask_gen_layer,
    'offset.back_out.01_linearinbounds', light_layer, post_layer,
]
gan.compile_debug(debug_keys)
elapsed = time.time() - start
print("Done Compiling in {0:.2f}s".format(elapsed))

In [None]:
# Compile the GAN's main (training) functions and report how long it took.
print("Compiling...")
start = time.time()
gan.compile()
print("Done Compiling in {0:.2f}s".format(time.time() - start))

In [None]:
nb_visualise = 50
# Latent batch drawn uniformly from [-1, 1).
# NOTE(review): sample_z is not referenced by any later cell in the visible notebook.
sample_z = np.random.uniform(low=-1, high=1, size=(nb_visualise, z_dim)).astype('float32')

In [None]:
class VisualiseMasks(VisualiseGAN):
    """Show mask-generator output next to the final generator output.

    Runs ``self.model.debug`` on ``inputs`` plus the callback's ``self.z`` and
    tiles the two selected debug outputs pairwise (mask, output, mask, ...).

    Args:
        inputs: dict of extra debug inputs (e.g. a placeholder ``'real'`` batch).
            It is copied on every call and never mutated.
        mask_layer: debug key for the mask image; ``None`` means use the global
            ``mask_gen_layer`` (looked up at call time, as before).
        out_layer: debug key for the generator output; ``None`` means use the
            global ``gen_out_layer`` (looked up at call time).
        **kwargs: forwarded to ``VisualiseGAN.__init__``.
    """
    def __init__(self, inputs, mask_layer=None, out_layer=None, **kwargs):
        self.inputs = inputs
        self.mask_layer = mask_layer
        self.out_layer = out_layer
        super().__init__(**kwargs)

    def __call__(self):
        mask_key = mask_gen_layer if self.mask_layer is None else self.mask_layer
        out_key = gen_out_layer if self.out_layer is None else self.out_layer
        # Copy so the caller's dict is not mutated with the current z batch.
        inputs = dict(self.inputs)
        inputs['z'] = self.z
        outs = self.model.debug(inputs)
        masks = self.preprocess(outs[mask_key])
        generated = self.preprocess(outs[out_key])
        tiles = []
        for mask, blended in zip(masks, generated):
            tiles.extend((mask, blended))
        tiled = tile(tiles, columns_must_be_multiple_of=2)
        plt.imshow(tiled[0], cmap='gray')
        if self.show:
            plt.show()

In [None]:
# Visualisation callback; 'real' is an all-zero placeholder batch of (1, 64, 64) samples.
vis = VisualiseMasks(
    inputs=dict(real=np.zeros((nb_visualise, 1, 64, 64), dtype='float32')),
    nb_samples=nb_visualise,
    output_dir='visualise/',
    show=True,
    preprocess=lambda x: np.clip(x, -1, 1),
)

In [None]:
# Attach the GAN to the callback by hand and run its train-begin hook outside of fit_generator
# (presumably this initialises the callback's `z` — see beras VisualiseGAN).
vis.model = gan
vis.on_train_begin(0)

In [None]:
# Render one visualisation with the current generator weights.
vis()

In [None]:
def generator(nb_real, nb_fake, path="/home/leon/data/tags_plain_t6.hdf5"):
    """Yield batches that pair real tag images with fresh latent noise.

    Args:
        nb_real: batch-size argument passed to ``real_generator``.
        nb_fake: number of latent vectors drawn per batch.
        path: HDF5 file passed to ``real_generator``. Defaults to the previously
            hard-coded location; override it on other machines.

    Yields:
        dict with ``'real'`` (each real batch mapped through ``2*real - 1``) and
        ``'z'`` (float32 array of shape ``(nb_fake, z_dim)`` drawn uniformly from [-1, 1)).
    """
    for real in real_generator(path, nb_real):
        yield {
            # Presumably rescales [0, 1] data to [-1, 1] — verify real_generator's range.
            'real': 2*real - 1,
            'z': np.random.uniform(-1, 1, (nb_fake, z_dim)).astype(np.float32),
        }

In [None]:
# One debug batch: real_generator batch size 1, 48 latent vectors.
debug_in = next(generator(1, 48))

In [None]:
# Evaluate the compiled debug outputs on the debug batch.
out = gan.debug(debug_in)

In [None]:
# Collect selected debug outputs, clip them for display, and tile them side by side.
off = np.clip(out['offset.back_out.01_linearinbounds'], -1, 1)
# Mask output is clipped to [0, 1] and mapped to [-1, 1]; reuse the shared layer-name constant.
mask = 2 * np.clip(out[mask_gen_layer], 0, 1) - 1

blending = out['blending']
gen_out = np.clip(out[gen_out_layer], -1, 1)

# NOTE(review): debug_keys requests 'addlight', but this reads 'addlighting_1' — confirm the key exists.
light = out['addlighting_1']

high_frq = np.clip(out[post_layer], -1, 1)
selection = out['selection']

pylab.rcParams['figure.figsize'] = (32, 32)
zip_visualise_tiles(light, off, mask, selection, blending, high_frq, gen_out)

In [None]:
# Training batch stream with the configured real/fake batch sizes.
gen = generator(nb_real, nb_fake)

In [None]:
# Sanity check: tile 64 real samples (nb_fake=0, so 'z' is an empty (0, z_dim) array).
visualise_tiles(next(generator(8*8, 0))['real'])

In [None]:
# Reset both optimizers' learning rate to `lr`.
# np.cast is deprecated (removed in NumPy 2.0); np.asarray(..., dtype=float32)
# yields the same 0-d float32 array.
g_optimizer.lr.set_value(np.asarray(lr, dtype=np.float32))
d_optimizer.lr.set_value(np.asarray(lr, dtype=np.float32))

In [None]:
# Train: 500 epochs of 100 batches, visualising via `vis`; `hist` is overwritten by later training cells.
hist = gan.fit_generator(gen, nb_batches_per_epoch=100, nb_epoch=500, batch_size=128, verbose=1, callbacks=[vis])

In [None]:
# Another 500-epoch run with the same settings (presumably continues from the current weights).
hist = gan.fit_generator(gen, nb_batches_per_epoch=100, nb_epoch=500, batch_size=128, verbose=1, callbacks=[vis])

In [None]:
# Shorter 250-epoch run with otherwise identical settings.
hist = gan.fit_generator(
    gen,
    nb_batches_per_epoch=100,
    nb_epoch=250,
    batch_size=128,
    verbose=1,
    callbacks=[vis],
)

In [None]:
# Another 500-epoch run with the same settings.
hist = gan.fit_generator(gen, nb_batches_per_epoch=100, nb_epoch=500, batch_size=128, verbose=1, callbacks=[vis])

In [None]:
# Debug outputs on a training batch; train=True presumably selects training-mode behaviour — confirm in beras.
out = gan.debug(next(gen), train=True)

In [None]:
def _show_tiled(arr):
    """Tile ``arr`` with beras ``tile`` and show channel 0 in gray with a colorbar."""
    plt.imshow(tile(arr)[0], cmap='gray')
    plt.colorbar()
    plt.show()


def show(name, n=0, outputs=None):
    """Show sample ``n`` of debug output ``name``.

    Args:
        name: key into the debug-output dict.
        n: index along the first axis of that output.
        outputs: debug-output dict; ``None`` uses the notebook-global ``out``.
    """
    if outputs is None:
        outputs = out
    _show_tiled(outputs[name][n])


def show_batches(name, outputs=None):
    """Show the whole batch of debug output ``name`` as one tiled image.

    Args:
        name: key into the debug-output dict.
        outputs: debug-output dict; ``None`` uses the notebook-global ``out``.
    """
    if outputs is None:
        outputs = out
    _show_tiled(outputs[name])

In [None]:
# Print "name: shape" for every debug output, sorted by name.
for n, arr in sorted(out.items()):
    print(n, arr.shape, sep=": ")

In [None]:
# FIXME(review): `selection_layer`, `blend_layer` and `lighting_layer` are not defined anywhere in
# this notebook, so this cell raises NameError on a fresh kernel. It looks like a leftover from an
# older layer API (the lookup cell below binds `selection` / `blending` / `add_lighting` and uses
# `mask_weights` / `offset_weights` rather than `weights`). Confirm or delete.
selection_layer.threshold.set_value(-0.08)
selection_layer.smooth_threshold.set_value(0.2)
blend_layer.min_mask_blendings[-1].set_value(0)
blend_layer.min_mask_blendings[-2].set_value(0)
lighting_layer.shift_factor.set_value(1.)
lighting_layer.scale_factor.set_value(0.75)
blend_layer.use_blending.set_value(1)

blend_layer.weights[0].set_value(0)
blend_layer.weights[1].set_value(0)
blend_layer.weights[2].set_value(1)


In [None]:
# Fresh debug batch (real_generator batch size 1, 48 latent vectors).
debug_in = next(generator(1, 48))

In [None]:
# Locate the layers whose parameters are tuned interactively in the following cells.
# NOTE: this rebinds `blending` / `selection`, which an earlier cell used for debug arrays;
# reset them first so a missing layer cannot leave a stale array behind.
blending = selection = add_lighting = None
for layer in gan.layers:
    if layer.name == 'pyramidblending_1':
        blending = layer
    elif layer.name.startswith('selection'):
        selection = layer
    elif layer.name.startswith('addligh'):
        add_lighting = layer

for _pattern, _found in (("'pyramidblending_1'", blending),
                         ("'selection*'", selection),
                         ("'addligh*'", add_lighting)):
    if _found is None:
        raise LookupError("no layer matching {} in gan.layers".format(_pattern))

In [None]:
# Set the add-lighting layer's shift and scale factors.
add_lighting.shift_factor.set_value(1)
add_lighting.scale_factor.set_value(0.7)

In [None]:
# Set the selection layer's threshold and smooth threshold.
selection.threshold.set_value(-0.08)
selection.smooth_threshold.set_value(0.2)

In [None]:
# Set the blending layer's first mask weight to 1.
blending.mask_weights[0].set_value(1)

In [None]:
# Set every offset weight of the blending layer to 1.
# A plain loop: the former list comprehension only built (and displayed) a list of return values.
for w in blending.offset_weights:
    w.set_value(1)

In [None]:
def print_bits(driver_out, nb_bits=12):
    """Print the sign pattern of the first ``nb_bits`` values of ``driver_out``.

    First line: one character per value, '.' if the value is > 0 and '#' otherwise.
    Then one line per value: its symbol, three spaces, and the raw value.

    Args:
        driver_out: sliceable sequence of values for one sample.
        nb_bits: number of leading values to show; defaults to 12, the
            previously hard-coded count.
    """
    def symbol(bit):
        return '.' if bit > 0 else '#'

    bits = driver_out[:nb_bits]
    print(''.join(symbol(bit) for bit in bits))
    for bit in bits:
        print("{}   {}".format(symbol(bit), bit))

In [None]:
# Show the driver-output sign pattern of sample 10 (index along the first axis).
n = 10
print_bits(out[driver_layer][n])

In [None]:
# Mean and std of the driver output over axis 0 (displayed as a tuple).
out[driver_layer].mean(axis=0), out[driver_layer].std(axis=0)

In [None]:
# Clip the mask-generator conv output to [0, 1], zoom it 2x along the last two axes,
# map it to [-1, 1], and tile it next to the 'merge_6' debug output.
# NOTE(review): neither 'mask_gen.convolution2d_6' nor 'merge_6' is in the debug_keys list
# compiled earlier — confirm these keys exist in `out`.
mask = 2 * zoom(np.clip(out['mask_gen.convolution2d_6'], 0, 1), (1, 1, 2, 2)) - 1


gen_out = out['merge_6']

pylab.rcParams['figure.figsize'] = (32, 32)
print(len(gen_out))

zip_visualise_tiles(mask, gen_out)

In [None]:
# List 'merge*' debug outputs whose per-sample shape is (1, 64, 64), with their full shapes.
for name, val in out.items():
    if name.startswith('merge') and val.shape[1:] == (1, 64, 64):
        print(name, val.shape, sep='\n')

In [None]:
# Inspect sample 6 of the 'driver.batchnormalization_3' debug output.
out['driver.batchnormalization_3'][6]

In [None]:
# Mean of 'driver.batchnormalization_3' over axis 0.
out['driver.batchnormalization_3'].mean(axis=0)

In [None]:
# Standard deviation of 'driver.batchnormalization_3' over axis 0.
out['driver.batchnormalization_3'].std(axis=0)

In [None]:
# Show sample 6 of the mask-generator conv output.
show('mask_gen.convolution2d_6', 6)

In [None]:
output_dir = '../models/blendgan_working_blending_should_be_improved_to_merge_32/'
os.makedirs(output_dir, exist_ok=True)
# FIXME(review): `gan_graph` is not defined anywhere in this notebook, so this raises NameError
# on a fresh kernel — confirm which model object's weights are meant to be saved.
gan_graph.save_weights(os.path.join(output_dir, "gan.hdf5"))

In [None]:
# Write the model architecture as JSON next to the weights; "w" suffices (nothing is read back).
# FIXME(review): `gan_graph` is not defined anywhere in this notebook — confirm the intended object.
with open(os.path.join(output_dir, "gan.json"), "w") as f:
    f.write(gan_graph.to_json())