In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from keras.layers import Activation,Dense,Input,Lambda
from keras.losses import mean_squared_error,mse
from keras.models import Sequential,Model,save_model,load_model
import keras.backend as K
import matplotlib.pyplot as plt
from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot
from keras.callbacks import LearningRateScheduler,ModelCheckpoint,ReduceLROnPlateau
from keras.preprocessing.image import  ImageDataGenerator
from keras.optimizers import Adam
%matplotlib inline
# Default figure size and colormap for every plot in this notebook.
plt.rcParams['figure.figsize']=(10,10)
plt.rcParams['image.cmap']='viridis'
# Seed the global RNGs so Restart & Run All reproduces weights, sampled z and plots.
# NOTE(review): Keras presumably draws its default initializer/sampling seeds from
# NumPy's global RNG when none is given -- confirm for the installed Keras version.
SEED=42
np.random.seed(SEED)
tf.set_random_seed(SEED)

## Step 1: Prepare the MNIST dataset
* Unlike supervised learning, only the images X are used for training.
* The encoder and decoder are built from Dense layers, so each image X is flattened to 784 dims.

In [None]:
# Directory handed to input_data.read_data_sets by load_dataset.
mnist_data_path = '../AI_database/mnist/MNIST_DATA'
def imshow(X,Y=None,classes=None):
    '''
        Show the first gridSize*gridSize images of a batch in a square grid,
        where gridSize = int(sqrt(m)).

        X: numpy array of images, shape (m,h,w), (m,h,w,1) or (m,h,w,3)
        Y: optional labels, either one-hot/score rows of shape (m,#classes)
           or integer class indices of shape (m,)
        classes: optional sequence mapping class index -> title; when omitted
                 the class index itself is used as the title
    '''
    m=X.shape[0]
    gridSize=int(m**0.5)
    for idx in range(gridSize*gridSize):
        im=X[idx]
        # plt.imshow does not accept (h,w,1); drop a trailing singleton channel.
        if im.ndim==3 and im.shape[-1]==1:
            im=im[...,0]
        plt.subplot(gridSize,gridSize,idx+1)
        plt.axis('off')
        plt.imshow(im)
        if Y is not None:
            y=np.asarray(Y[idx])
            # One-hot / score vectors -> argmax; scalar labels are used as-is.
            label=int(np.argmax(y)) if y.ndim>0 else int(y)
            plt.title(classes[label] if classes is not None else str(label))

def load_dataset(flaten=False,one_hot=True):
    '''
        Read MNIST from mnist_data_path and return (X_train,Y_train,X_test,Y_test).

        flaten: when False (the default) images are reshaped to (m,28,28,1);
                otherwise the flat vectors from input_data are returned unchanged
        one_hot: when truthy, integer labels become one-hot int32 rows of width 10
    '''
    def _make_one_hot(labels,C=10):
        # Compare every label with 0..C-1 via broadcasting -> (m,C) boolean matrix.
        return (labels[:,None]==np.arange(C)).astype(np.int32)

    mnist=input_data.read_data_sets(mnist_data_path)
    X_train=mnist.train.images
    Y_train=mnist.train.labels
    X_test=mnist.test.images
    Y_test=mnist.test.labels

    if flaten==False:
        image_shape=(-1,28,28,1)
        X_train=X_train.reshape(image_shape)
        X_test=X_test.reshape(image_shape)
    if one_hot:
        Y_train,Y_test=_make_one_hot(Y_train),_make_one_hot(Y_test)

    # Summary banner of what was loaded.
    print('\n-------------------------------------------------------------------------')
    print('load %d train Example,%d Test Example'%(X_train.shape[0],X_test.shape[0]))
    print('Train Images  Shape:'+str(X_train.shape))
    print('Train Labels  Shape:' + str(Y_train.shape))
    print('Test  Images  Shape:'+str(X_test.shape))
    print('Test  Labels  Shape:' + str(Y_test.shape))
    print('-------------------------------------------------------------------------')
    return X_train,Y_train,X_test,Y_test

In [None]:
# Unsupervised setting: flat 784-dim images and integer labels. Labels are not
# used for training; Y_test later colors the latent-space scatter plot.
X_train,Y_train,X_test,Y_test=load_dataset(one_hot=False,flaten=True)

## Define a helper that reshapes flat images for display

In [None]:
def reshaper(X,side=28):
    '''
        Reshape a batch of flattened square grayscale images back to 2-D.

        X: numpy array of shape (m, side*side)
        side: image side length in pixels (28 for MNIST)
        return: array of shape (m, side, side)
    '''
    return X.reshape((X.shape[0],side,side))

## Define an encoder
* ### The structure is X ----> H1 ----> (u_z, logvar_z, z)

In [None]:
def vae_encoder(n_x=784,n_h=512,n_z=2):
    '''
        Build the VAE encoder X ----> H1 ----> (u_z, logvar_z, z).

        n_x: dim of the (flattened) input image
        n_h: number of hidden units
        n_z: dim of the latent variable
        return: a Keras Model with outputs=[u_z,logvar_z,z], all with last dim n_z
    '''
    def sample(args):
        '''
            Reparameterization trick: draw eps ~ N(0,I) and return
            u + exp(0.5*logvar)*eps, i.e. a sample from N(u, exp(logvar)).

            args: [u, logvar], two tensors of the same shape
            return: a tensor with the same shape as u
        '''
        u,logvar=args
        # K.shape (dynamic, includes the batch dim) keeps the sampler on the Keras
        # backend API like the rest of this block, instead of calling tf directly.
        eps=K.random_normal(shape=K.shape(u))
        return u+K.exp(0.5*logvar)*eps

    X_Input=Input(shape=(n_x,),name='encode_input')
    X=Dense(n_h,activation='relu',name='encode_hidden',kernel_initializer='he_normal')(X_Input)

    # Two heads on the shared hidden layer: mean and log-variance of q(z|x).
    u_z=Dense(n_z,name='encode_mean',kernel_initializer='he_normal')(X)
    logvar_z=Dense(n_z,name='encode_log_var',kernel_initializer='he_normal')(X)
    Z=Lambda(sample,name='encoder_z')([u_z,logvar_z])

    model=Model(inputs=X_Input,outputs=[u_z,logvar_z,Z],name='encoder')

    return model
    

In [None]:
# MNIST encoder with its default sizes spelled out: 784 -> 512 -> 2.
encoder=vae_encoder(n_x=784,n_h=512,n_z=2)
encoder.summary()

In [None]:
# Render the encoder graph (with tensor shapes) inline as SVG.
SVG(model_to_dot(encoder,show_shapes=True)
    .create(prog='dot',format='svg'))

## Define a decoder
* ### The structure is Z ----> H1 ----> Xhat

In [None]:
def vae_decoder(n_x=784,n_h=512,n_z=2):
    '''
        Build the VAE decoder Z ----> H1 ----> Xhat.

        n_x: dims of the reconstructed (flattened) image
        n_h: number of hidden units
        n_z: latent dims
        return: a Keras Model mapping z (last dim n_z) to Xhat (last dim n_x)
    '''
    z_in=Input(shape=(n_z,),name='decoder_input')
    hidden=Dense(n_h,activation='relu',kernel_initializer='he_normal',name='decoder_hidden')(z_in)
    # Sigmoid keeps every reconstructed value in (0,1).
    x_hat=Dense(n_x,activation='sigmoid',kernel_initializer='he_normal',name='decoder_reconstruct')(hidden)
    return Model(inputs=z_in,outputs=x_hat,name='decoder')

In [None]:
# MNIST decoder with its default sizes spelled out: 2 -> 512 -> 784.
decoder=vae_decoder(n_x=784,n_h=512,n_z=2)
decoder.summary()

In [None]:
# Render the decoder graph (with tensor shapes) inline as SVG.
SVG(model_to_dot(decoder,show_shapes=True)
    .create(prog='dot',format='svg'))

## Combine the encoder and decoder (vae_mlp)
<img src='images/vae_loss.png' />
<img src='images/vae_kl_loss.png' />

In [None]:
def vae_mlp(encoder,decoder,n_x=784,kl_weight=1.0):
    '''
        Chain encoder and decoder into an end-to-end VAE with its loss attached.

        encoder: X(n_x) ----> [u_z, logvar_z, z]
        decoder: z ----> Xhat(n_x)
        n_x: dim of the flattened input
        kl_weight: weight of the KL term. 1.0 gives the standard VAE objective;
                   0.0 drops the KL term, which reduces the model to a plain autoencoder.

        loss = mean over batch of ( sum((X-Xhat)**2) + kl_weight * KL{ q(z|x) || N(0,I) } )
        with KL = -0.5 * sum(1 + logvar_z - u_z**2 - exp(logvar_z)); the prior is P(z) = N(0,I).
        return: a Keras Model X -> Xhat with the loss registered via add_loss,
                so compile it without a `loss` argument
    '''
    X=Input(shape=(n_x,),name='MyInput')
    uz,logvar_z,z=encoder(X)
    Xhat=decoder(z)

    # Reconstruction error summed over pixels: shape (batch,)
    Re_loss=K.sum(K.square(X-Xhat),axis=-1)
    # Closed-form KL(N(uz, exp(logvar_z)) || N(0,I)) summed over latent dims: shape (batch,)
    KL_loss=-0.5*K.sum(1+logvar_z-K.square(uz)-K.exp(logvar_z),axis=-1)
    # The KL term pulls q(z|x) toward N(0,I); without it the latent codes are not
    # confined near the origin that the latent-grid plots sample from.
    loss=K.mean(Re_loss+kl_weight*KL_loss)

    model=Model(inputs=X,outputs=Xhat)
    model.add_loss(loss)
    return model

In [None]:
# End-to-end MNIST VAE over flat 784-dim inputs.
vae=vae_mlp(encoder,decoder,n_x=784)

In [None]:
# Layer-by-layer summary of the combined VAE (encoder and decoder appear as nested models).
vae.summary()

In [None]:
# Render the combined VAE graph (with tensor shapes) inline as SVG.
SVG(model_to_dot(vae,show_shapes=True)
    .create(prog='dot',format='svg'))

In [None]:
model_path='outputs/vae.h5'
def callback(filepath=None,patience=50):
    '''
        Build the training callbacks:
        a per-epoch learning-rate logger, a checkpoint of the best model by val_loss,
        and ReduceLROnPlateau (factor 0.9) on val_loss.

        filepath: where ModelCheckpoint saves the best model; defaults to the global
                  model_path, read at call time
        patience: epochs without val_loss improvement before the LR is reduced
                  (note: with the 50-epoch MNIST run and patience=50 it can never fire)
        return: list of Keras callbacks for Model.fit
    '''
    def myLearnRateScheduler(epoch,lr):
        # Log-only schedule: print the current LR and return it unchanged.
        print('epoch:%d,learn rate %f'%(epoch,lr))
        return lr
    if filepath is None:
        filepath=model_path
    lr_scheduler=LearningRateScheduler(myLearnRateScheduler)
    checkpoint=ModelCheckpoint(filepath,monitor='val_loss',save_best_only=True,verbose=1)
    reduceOnpleateau=ReduceLROnPlateau(monitor='val_loss',min_delta=5e-5,factor=0.9,verbose=1,patience=patience)
    return [lr_scheduler,checkpoint,reduceOnpleateau]

In [None]:
# loss=None: the VAE objective was already attached inside vae_mlp via add_loss.
vae.compile(optimizer='adam',loss=None)

In [None]:
batch=64
epoch=50
# No targets are passed (y=None, validation y=None): the loss is computed
# inside the model from the inputs alone.
vae.fit(X_train,
        batch_size=batch,
        epochs=epoch,
        validation_data=(X_test,None),
        callbacks=callback())

In [None]:
## Save or load the model

In [None]:
# Persist encoder and decoder separately so each can be reloaded on its own.
encoder_path = 'outputs/vae_encoder'
decoder_path = 'outputs/vae_decoder'
encoder.save(encoder_path)
decoder.save(decoder_path)

In [None]:
# Reload both sub-models; `tf` is supplied via custom_objects in case the
# serialized sampling Lambda refers to it.
encoder=load_model(encoder_path,custom_objects=dict(tf=tf))
decoder=load_model(decoder_path,custom_objects=dict(tf=tf))

## View the latent-space distribution of the test data

In [None]:
def plot_distribution(encoder,X,Y):
    '''
        Scatter the latent means u_z of X (first two latent dims), colored by label Y.

        encoder: model whose first output is u_z (as built by vae_encoder)
        X: inputs accepted by encoder.predict
        Y: per-sample labels, used as scatter colors
    '''
    # Set figure size and colormap locally rather than through plt.rcParams, so
    # later imshow calls keep the notebook's default colormap.
    plt.figure(figsize=(15,15))
    u_z=encoder.predict(X,batch_size=512)[0]
    plt.scatter(u_z[:,0],u_z[:,1],c=Y,cmap='hsv')
    plt.colorbar()
    plt.xlabel('z1')
    plt.ylabel('z2')
    plt.title('Latent means u_z colored by label')
plot_distribution(encoder,X_test,Y_test)

## View how the latent z affects the generated image

In [None]:
def plot_model(decoder,n=15,dim=28,reshaper=None,channels=1,z_range=4.0):
    '''
        Decode an n x n grid of 2-D latent points and tile the images into one figure.

        decoder: model mapping latent points of shape (k,2) to flattened images
                 (assumes a 2-D latent space)
        n: number of grid points per latent axis
        dim: side length (pixels) of one decoded image
        reshaper: callable turning the decoder output into images of shape
                  (k,dim,dim) when channels==1, else (k,dim,dim,channels);
                  when None, a plain np.reshape to that shape is used
        channels: channels per image (1 = grayscale, 3 = RGB)
        z_range: the grid spans [-z_range, z_range] on both latent axes
    '''
    if reshaper is None:
        img_shape=(-1,dim,dim) if channels==1 else (-1,dim,dim,channels)
        reshaper=lambda flat:np.reshape(flat,img_shape)

    x_vals=np.linspace(-z_range,z_range,n)
    # z2 decreases top-to-bottom so the tiled image reads like a normal z1/z2 plot.
    y_vals=np.linspace(z_range,-z_range,n)
    zx,zy=np.meshgrid(x_vals,y_vals)
    z=np.stack([zx.ravel(),zy.ravel()],axis=1)  # shape (n*n,2); row i*n+j is (x_vals[j], y_vals[i])

    I=reshaper(decoder.predict(z))
    # Canvas is 2-D for grayscale (channels==1 is squeezed away), 3-D otherwise.
    F=np.squeeze(np.zeros((n*dim,n*dim,channels),dtype=np.float32))
    for i in range(n):
        for j in range(n):
            F[i*dim:(i+1)*dim,j*dim:(j+1)*dim]=I[i*n+j]

    ticks=np.arange(dim//2,dim*n+dim//2,dim)  # centre pixel of each tile
    plt.figure(figsize=(15,15))
    plt.xticks(ticks,np.round(x_vals,2),size='large',rotation=45)
    plt.yticks(ticks,np.round(y_vals,2),size='large',rotation=45)
    plt.xlabel('z1')
    plt.ylabel('z2')
    plt.imshow(F)


In [None]:
# 15 x 15 grid of decoded MNIST digits across the 2-D latent space.
plot_model(decoder,reshaper=reshaper,dim=28,n=15)

### Now try the CIFAR-10 dataset

In [None]:
from cifar.CIFAR10Utils import load_dataset

In [None]:
# CIFAR-10 via the load_dataset imported from cifar.CIFAR10Utils (it shadows the
# MNIST loader defined earlier) and overwrites the MNIST X_/Y_ arrays.
# The fifth return value is unused (NOTE(review): presumably class names -- confirm).
X_train,Y_train,X_test,Y_test,_=load_dataset(filename='../AI_database/cifar/CIFAR10_DATA',
                                             flaten=True,one_hot=False)

In [None]:
# Same MLP encoder for CIFAR-10 images flattened to 32*32*3 = 3072 dims.
encoder_cifar=vae_encoder(n_x=32*32*3,n_h=512,n_z=2)
SVG(model_to_dot(encoder_cifar,show_shapes=True)
    .create(prog='dot',format='svg'))

In [None]:
# Matching decoder producing flattened 32*32*3 = 3072-dim images.
decoder_cifar=vae_decoder(n_x=32*32*3,n_h=512,n_z=2)
SVG(model_to_dot(decoder_cifar,show_shapes=True)
    .create(prog='dot',format='svg'))

In [None]:
# End-to-end CIFAR-10 VAE over 3072-dim flattened inputs.
vae_cifar=vae_mlp(encoder_cifar,decoder_cifar,n_x=32*32*3)
SVG(model_to_dot(vae_cifar,show_shapes=True)
    .create(prog='dot',format='svg'))

In [None]:
vae_cifar.compile(optimizer=Adam(4e-4))

# Augmented image stream over the CIFAR training set, reshaped to (m,32,32,3).
gen=ImageDataGenerator(
    width_shift_range=0,
    height_shift_range=0,
    horizontal_flip=True,
    zoom_range=[0.5,1.0],  # values < 1 zoom in, > 1 zoom out
    rescale=1.0,           # pixel multiplier; 1.0 leaves values unchanged
).flow(X_train.reshape(-1,32,32,3),None)

def myGen(gen):
    '''
        Endlessly re-yield batches from `gen`, flattened back to 3*32*32 vectors
        and paired with None (the add_loss VAE takes no targets).
    '''
    while True:
        images=next(gen)
        yield images.reshape(-1,3*32*32),None

generator=myGen(gen)

In [None]:
batch=64
epoch=5000
# NOTE(review): only the first N_TRAIN images are used, which looks like a debug
# run (overfitting one batch). Set N_TRAIN=None to train on all of X_train.
N_TRAIN=64
X_fit=X_train if N_TRAIN is None else X_train[:N_TRAIN]

# callback() checkpoints to the global `model_path` ('outputs/vae.h5', shared with
# the MNIST run). Point it at a CIFAR-specific file while building the callbacks
# so this run does not overwrite the MNIST checkpoint, then restore it.
mnist_model_path,model_path=model_path,'outputs/vae_cifar.h5'
cifar_callbacks=callback()
model_path=mnist_model_path

vae_cifar.fit(X_fit,batch_size=batch,epochs=epoch,validation_data=(X_test,None),callbacks=cifar_callbacks)
# Alternative: train on augmented batches from `generator` (defined above):
# vae_cifar.fit_generator(generator,epochs=epoch,steps_per_epoch=50000//batch,
#                         validation_data=(X_test,None),callbacks=cifar_callbacks)

In [None]:
# Persist the CIFAR encoder and decoder next to the MNIST ones.
cifar_encoder = 'outputs/cifar_encoder'
cifar_decoder = 'outputs/cifar_decoder'
encoder_cifar.save(cifar_encoder)
decoder_cifar.save(cifar_decoder)

In [None]:
# Latent means of the CIFAR-10 test images, colored by class label.
plot_distribution(encoder=encoder_cifar,X=X_test,Y=Y_test)

In [None]:
# Decoded CIFAR images are RGB, so each 3072-vector is reshaped to (32,32,3).
plot_model(decoder_cifar,n=5,dim=32,channels=3,
           reshaper=lambda flat:np.reshape(flat,(-1,32,32,3)))

In [None]:
# Show one CIFAR test image at a small size. A local figure size avoids changing
# the global default for any later cells.
SAMPLE_IDX=11
plt.figure(figsize=(3,3))
plt.title('X_test[%d]'%SAMPLE_IDX)
plt.imshow(np.reshape(X_test[SAMPLE_IDX],(32,32,3)));