In [None]:
%matplotlib inline
import os
import time
from itertools import count

import matplotlib
import matplotlib.pyplot as plt
import numpy as np  
import seaborn as sns
from keras.callbacks import Callback
from keras.initializations import normal
from keras.layers.normalization import BatchNormalization
from keras.models import Graph
from keras.optimizers import SGD, Adam, RMSprop
from more_itertools import take
from skimage.filters import gaussian_filter

from beesgrid import MASK
from beras.gan import sequential_to_gan, gan_binary_crossentropy, gan_linear_losses
from beras.models import asgraph
from beras.util import tile, smooth
from deepdecoder.data import grids_lecture_generator, load_real_hdf5_tags
from deepdecoder.mask_loss import mask_loss, mask_loss_sobel, mask_loss_mse, mask_loss_adaptive_mse
from deepdecoder.networks import dcgan_generator, dcgan_discriminator
from deepdecoder.utils import binary_mask
from deepdecoder.visualise import plt_hist

import pylab

# Large default figure size so the tiled sample grids are legible inline.
FIG_SIZE = (18, 18)
pylab.rcParams['figure.figsize'] = FIG_SIZE

In [None]:
# Model / batch configuration.
n = 32  # size argument of dcgan_generator; the discriminator is built with n // 2
generatr_input_dim = 100  # latent dimension; (sic) name is referenced by later cells
num_batches_per_epoch = 1  # forwarded to load_real_hdf5_tags

# Passed to sequential_to_gan as nb_fake / nb_real.
nb_fake, nb_real = 96, 36
batch_size = nb_fake + nb_real
z_shape = (nb_fake, generatr_input_dim)  # one latent row per fake sample

In [None]:
def normal_002(shape):
    """Weight initialiser: Keras ``normal`` with its scale fixed to 0.02.

    Args:
        shape: weight shape, forwarded unchanged to ``normal``.
    """
    scale = 0.02
    return normal(shape, scale=scale)

In [None]:
g = dcgan_generator(n, input_dim=generatr_input_dim, init=normal_002)
# ``d_image_views`` is not defined anywhere in this notebook, so on a fresh
# kernel this cell raised NameError. Forward it only when earlier code has
# defined it; otherwise fall back to dcgan_discriminator's own default.
# NOTE(review): confirm dcgan_discriminator has a default for image_views.
d_kwargs = {'image_views': d_image_views} if 'd_image_views' in globals() else {}
d = dcgan_discriminator(n//2, **d_kwargs)
gan = sequential_to_gan(g, d, nb_real=nb_real, nb_fake=nb_fake)

# gan.load_weights("lapgan/models/dcgan_g64_d32/{}.hdf5")

In [None]:
# HDF5 file of real tags. Override via $DEEPDECODER_TAGS_HDF5 instead of
# editing this machine-specific default path.
tags_path = os.environ.get('DEEPDECODER_TAGS_HDF5',
                           '/home/leon/data/tags_plain_t6_o36.hdf5')
# Number of leading tags averaged into ``mean_image``, which tags_generator
# subtracts from every real batch.
nb_mean_tags = 1248

tags = load_real_hdf5_tags(tags_path, nb_fake, num_batches_per_epoch)
nb_tags = len(tags)
print(nb_tags)
mean_image = (tags[0:nb_mean_tags] / 255).mean(axis=0)

In [None]:
def plot_mean_image(mean_image, tick_step=8):
    """Show one 2-D image in grayscale with a tick grid and a colorbar.

    Args:
        mean_image: 2-D array to display (the notebook passes
            ``mean_image[0]`` or a shifted copy of it).
        tick_step: spacing in pixels between ticks/grid lines; defaults to 8.
    """
    height, width = np.shape(mean_image)[:2]
    fig, ax = plt.subplots()
    # Ticks span the actual image instead of a hard-coded 0..64 range;
    # for 64x64 images this is identical to the previous behaviour.
    ax.set_xticks(np.arange(0, width, tick_step))
    ax.set_yticks(np.arange(0, height, tick_step))
    ax.grid(True)
    artist = ax.imshow(mean_image, cmap='gray')
    fig.colorbar(artist, ax=ax)
    plt.show()
    
plot_mean_image(mean_image[0])

print(mean_image.shape)
# Copy of the first mean image with its content moved t pixels down and to
# the right; the first t rows/columns keep their unshifted values.
t = 3
im = mean_image[0].copy()
im[t:, t:] = mean_image[0, :-t, :-t]

plot_mean_image(im)

In [None]:
def visualise_tiles(images):
    """Tile ``images`` into one mosaic with ``tile`` and show it in grayscale.

    Only index 0 of the tiled result is displayed (presumably the first
    channel -- TODO confirm against ``beras.util.tile``).
    """
    mosaic = tile(images)
    plt.imshow(mosaic[0], cmap='gray')
    plt.grid(False)
    plt.show()
def should_visualise(i):
    """Return True for the epochs at which a sample grid should be drawn.

    Every epoch below 15, every 5th epoch below 100, every 20th below 1000,
    and every 50th epoch regardless of size.
    """
    if i < 15:
        return True
    if i < 100 and i % 5 == 0:
        return True
    if i < 1000 and i % 20 == 0:
        return True
    return i % 50 == 0
s = np.tanh(1)


def tags_generator():
    """Yield an endless stream of training inputs for the GAN.

    Each yielded dict contains:
      * ``'real'``: ``nb_fake`` consecutive tags from ``tags``, scaled by
        1/255 with ``mean_image`` subtracted. Reading wraps around to the
        start of ``tags`` once the end is reached.
      * ``'z'``: noise drawn uniformly from [-1, 1) with shape ``z_shape``.

    Reads the notebook globals ``tags``, ``nb_tags``, ``nb_fake``,
    ``mean_image`` and ``z_shape``.
    """
    tag_bs = nb_fake
    for start in count(step=tag_bs):
        ti = start % nb_tags
        # Collect tag_bs rows starting at ti, wrapping past the end. The old
        # ``tags[ti:ti+tag_bs]`` silently returned a short batch whenever
        # nb_tags was not a multiple of tag_bs, so 'real' no longer matched
        # the nb_fake rows of 'z'. Plain slices (no fancy indexing) keep this
        # working even if ``tags`` is an HDF5 dataset.
        chunks = []
        needed = tag_bs
        while needed > 0:
            stop = min(ti + needed, nb_tags)
            chunks.append(tags[ti:stop])
            needed -= stop - ti
            ti = 0
        raw = chunks[0] if len(chunks) == 1 else np.concatenate(chunks)
        tag_batch = raw / 255 - mean_image
        z = np.random.uniform(-1, 1, z_shape)
        yield {'real': tag_batch, 'z': z}
    
class VisualiseCb(Callback):
    """Keras callback that shows a tiled grid of generated samples.

    At every epoch selected by ``should_visualise`` it calls the
    notebook-global ``gan.generate`` with ``nb_visualise`` latent points,
    prints the output shape and tiles the result.
    """

    # Number of latent points sampled per visualisation.
    nb_visualise = 128

    def on_epoch_end(self, epoch, log=None):
        # ``log`` is unused; default None avoids a shared mutable dict default.
        if not should_visualise(epoch):
            return
        fake = gan.generate(z_shape=(self.nb_visualise, generatr_input_dim))
        print(fake.shape)
        visualise_tiles(fake)

In [None]:
generator = tags_generator()
# Check the value range of ONE real batch. Calling next() separately for
# max() and min() took them from two different batches.
real_batch = next(generator)['real']
print(real_batch.max())
print(real_batch.min())

In [None]:
def optimizer(lr):
    """Adam with beta_1=0.5 at learning rate ``lr``."""
    return Adam(lr=lr, beta_1=0.5)


print("Compiling...")
start = time.time()
# Two independent Adam instances with identical settings.
gan.compile(optimizer(0.0002), optimizer(0.0002), gan_binary_crossentropy)
print("Done Compiling in {0:.2f}s".format(time.time() - start))

In [None]:
# Sample grid from the untrained generator (epoch 0 always passes should_visualise).
VisualiseCb().on_epoch_end(epoch=0)

In [None]:
# Compile gan's debug function; presumably required before gan.debug(...),
# which display_debug calls.
gan.compile_debug()

In [None]:
def display_debug(i=0, prefix="gen"):
    """Plot the 4-D debug outputs of sample ``i`` whose names start with ``prefix``.

    Runs ``gan.debug`` on the first batch of a fresh ``tags_generator()``.
    For each output (in name order) that starts with ``prefix`` and has
    ``ndim == 4``, it shows ``tile(arr[i])[0]`` in grayscale with a colorbar.
    Outputs of any other rank are skipped.

    Args:
        i: index of the sample within the debug batch.
        prefix: name prefix selecting outputs, e.g. "gen" or "dis".
    """
    out = gan.debug(next(tags_generator()))
    for name, arr in sorted(out.items()):
        # Filter before touching pyplot. The old code titled a figure and then
        # skipped non-4-D outputs, and its non-4-D branch after ``continue``
        # was unreachable (it also referenced a nonexistent ``arr[i].r``).
        if not name.startswith(prefix) or arr.ndim != 4:
            continue
        tiled = tile(arr[i])
        plt.title(name)
        plt.grid(False)
        plt.imshow(tiled[0], cmap='gray')
        plt.colorbar()
        plt.show()

In [None]:
# Debug outputs whose names start with "dis", for sample index 1.
display_debug(1, "dis")

In [None]:
# Debug outputs whose names start with "gen", for sample index 1.
display_debug(1, "gen")

In [None]:
# Exact-type match (not isinstance): subclasses of BatchNormalization are excluded.
bns = [layer for layer in g.layers if type(layer) is BatchNormalization]
# For each BN layer in the generator print its running std, then the shapes
# of its beta and gamma parameters.
for i, bn in enumerate(bns):
    beta, gamma = bn.beta.get_value(), bn.gamma.get_value()
    print(bn.running_std.get_value())
    print(beta.shape)
    print(gamma.shape)

In [None]:
# 300 epochs of 50 batches each from tags_generator, drawing sample grids
# on the epochs selected by should_visualise.
gan.fit_generator(
    tags_generator(),
    nb_epoch=300,
    nb_batches_per_epoch=50,
    callbacks=[VisualiseCb()],
    verbose=1,
)

In [None]:
# Debug outputs whose names start with "gen", for sample index 0 (after training).
display_debug(0, "gen")

In [None]:
def weights_histogram(model, bins=50):
    """Plot one histogram per weight array of every layer in ``model``.

    Args:
        model: object exposing ``.layers``, each with ``get_weights()``.
        bins: kept for backward compatibility. NOTE(review): not forwarded to
            ``plt_hist`` because its signature is not visible here; confirm
            and pass it through if supported.

    Returns:
        list of the titles used, one per plotted histogram (previously an
        always-empty list).
    """
    titles = []
    for layer_idx, layer in enumerate(model.layers):
        for weight_idx, weight in enumerate(layer.get_weights()):
            # Include the weight index: several arrays of one layer
            # (e.g. W and b) previously shared one indistinguishable title.
            title = "{}_{}_w{}".format(type(layer).__name__, layer_idx, weight_idx)
            plt_hist(weight, title)
            plt.show()
            titles.append(title)
    return titles


weights_histogram(g);

In [None]:
dense = g.layers[0]
weight = dense.get_weights()[0]
print(weight.shape)
# Fold each row of the 2-D weight matrix into ``fold`` pieces so the heat-map
# is less elongated. This uses its own name rather than rebinding ``n`` (the
# generator size from the config cell). Rebinding ``n`` meant re-running the
# model cell afterwards silently built a different network.
fold = 8
cmap = matplotlib.colors.ListedColormap(sns.hls_palette(256, .33, .85, .6))
# Shape derived from the weight itself instead of the hard-coded (100, 4096).
plt.imshow(weight.reshape((weight.shape[0] * fold, weight.shape[1] // fold)), cmap=cmap)
plt.colorbar()
plt.show()

In [None]:
# Display the generator's layer list (useful for picking layer indices below).
g.layers

In [None]:
# Tile the first parameter array of generator layer 3.
# NOTE(review): index 3 is hard-coded -- confirm it is the intended layer.
visualise_tiles(g.layers[3].params[0].get_value())

In [None]:
# ``X`` is not defined anywhere in this notebook (presumably left over from a
# deleted cell). Fall back to the real tags loaded above when it is missing.
samples = X if 'X' in globals() else tags
# 128 random samples without replacement. Indices are sorted in case
# ``samples`` is an h5py dataset, which only accepts increasing indices.
picked = np.sort(np.random.permutation(len(samples))[:128])
visualise_tiles(samples[picked])

In [None]:
shp = gan.z_shape[1:]


def z_point():
    """One latent point drawn uniformly from [-1, 1) with shape ``shp``."""
    return np.random.uniform(-1, 1, shp)


# Two interpolations, each between a fresh pair of random latent points.
for i in range(2):
    out = gan.interpolate(z_point(), z_point())
    visualise_tiles(out)

In [None]:
def loss_fn(t, p):
    """Return the ``.loss`` attribute of ``mask_loss_adaptive_mse(t, p)``.

    ``t`` / ``p`` are presumably target and prediction -- confirm against
    ``gan.compile_optimize_image``.
    """
    return mask_loss_adaptive_mse(t, p).loss


print("Compiling...")
start = time.time()
gan.compile_optimize_image(Adam(), loss_fn)
print("Done Compiling in {0:.2f}s".format(time.time() - start))

In [None]:
# Iteratively optimise latent vectors with gan.optimize_image against
# ``masks_idx``; ``opt_z`` persists across runs of this cell.
# NOTE(review): ``masks_idx``, ``bw_mask`` and ``zip_visualise_tiles`` are not
# defined anywhere in this notebook -- define/import them before running.
if 'opt_z' not in globals():
    opt_z = np.random.uniform(-1, 1, gan.z_shape)
iterations = 256  # optimisation steps per round
duration = 0
for i in range(20):
    zip_visualise_tiles(bw_mask(masks_idx), gan.generate(opt_z))
    start = time.time()
    opt_images, opt_z = gan.optimize_image(masks_idx, iterations, z_start=opt_z, verbose=1)
    duration += time.time() - start
    # Every 5th round, also show the optimised images. The old ``if i % 5:``
    # selected every round EXCEPT multiples of 5. It was followed by a stray
    # ``jjjjj`` that raised NameError and aborted the loop at i == 1.
    if i % 5 == 0:
        zip_visualise_tiles(bw_mask(masks_idx), opt_images)
zip_visualise_tiles(bw_mask(masks_idx), opt_images)

print("Done Optimizing in {0:.2f}s".format(duration))

In [None]:
z = np.zeros(gan.z_shape, dtype=np.float32)
z[0] = gan.random_z_point()
for i in range(1, len(z), 2):
    z[i] = gan.random_z_point()
    # With an even number of rows the last i is len(z) - 1, so z[i+1] raised
    # IndexError; only fill the partner row when it exists.
    if i + 1 < len(z):
        # FIXME(review): np.max reduces to ONE scalar, so every partner row is
        # the constant max(z[0]) + 0.5. The original line was also missing its
        # closing parenthesis (SyntaxError). Confirm the intended expression,
        # e.g. an element-wise offset of z[0] or z[i].
        z[i+1] = np.max(z[0] + 0.5 * np.ones(z[0].shape))
        # The +0.5 offset can leave [-1, 1]; previously only z[i] was clipped.
        z[i+1] = np.clip(z[i+1], -1, 1)
    z[i] = np.clip(z[i], -1, 1)
visualise_tiles(gan.generate(z))

In [None]:
# Ten tilings of gan.neighborhood(...) around the optimised latent vector
# opt_z[5] with std=0.30 (presumably random perturbations of that point).
for _ in range(10):
    visualise_tiles(gan.neighborhood(opt_z[5], std=0.30))

In [None]:
# For each of the first (up to 50) latent dimensions of opt_z[0], interpolate
# between copies with that single dimension set to -1 and to 1.
# Uses ``z_base`` rather than rebinding ``z_point``, which the interpolation
# cell defines as a sampling function.
z_base = opt_z[0]
nb_dims = min(50, len(z_base))  # avoid IndexError for latent dims < 50
for i in range(nb_dims):
    print(i)
    z_min = np.copy(z_base)
    z_max = np.copy(z_base)
    z_min[i] = -1
    z_max[i] = 1
    visualise_tiles(gan.interpolate(z_min, z_max))

In [None]:
# Number of target masks, then number of optimised images.
for seq in (masks_idx, opt_images):
    print(len(seq))