In [1]:
import numpy as np
import tensorflow as tf
from keras.layers import (Lambda, Input, Reshape,
                          Dense, UpSampling2D,
                          Conv2D, Concatenate,
                          Flatten, MaxPool2D,
                          Subtract)
from keras.losses import mse, mae, binary_crossentropy
from keras.models import Model, save_model
from keras.optimizers import Adam
from sklearn.model_selection import train_test_split

import keras.backend as K

Using TensorFlow backend.


In [2]:
def settrainable(model, toset):
    for layer in model.layers:
        layer.trainable = toset
    model.trainable = toset

In [3]:
input_shape = (256,8,1) # (nfft,n_lookback,n_channels)
filters = [32,32,32]
layers = 3
kernel_size = [3,3]
pool_size = [2,2]
intermediate = 16
latent_dim = 50
n_classes = 8

inputs = Input(shape=input_shape)
x = inputs
for i in range(layers):
    x = Conv2D(filters[i],
               kernel_size=kernel_size,
               activation='relu',
               padding='same')(x)
    x = MaxPool2D(pool_size=pool_size)(x)


# shape info needed to build decoder model
shape = K.int_shape(x)

x = Flatten()(x)
# generate latent vector Q(z|X)
features = Dense(latent_dim, activation='linear',name='features')(x)

# no reparameterization trick here, since we will get the distribution
#   through adversarial training against a stocastic process we don't
#   need to sample here

encoder = Model(inputs, features, name='encoder')
encoder.summary()

Instructions for updating:
Colocations handled automatically by placer.
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 256, 8, 1)         0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 256, 8, 32)        320       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 128, 4, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 128, 4, 32)        9248      
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 64, 2, 32)         0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 64, 2, 32)         9248      
_________________________________________________________________
max_

In [4]:
filters = [32,64,64]
upsampling_size = pool_size

latent_inputs = Input(shape=(latent_dim,), name='decoder_input')
# latent_inputs = Input(shape=(latent_dim+n_classes,), name='decoder_input')
x = Dense(shape[1] * shape[2] * shape[3], activation='relu')(latent_inputs)
x = Reshape((shape[1], shape[2], shape[3]))(x)

for i in range(layers):
    x = Conv2D(filters=filters[i],
               kernel_size=kernel_size,
               activation='relu',
               padding='same')(x)
    x = UpSampling2D(size=upsampling_size)(x)

x = Conv2D(filters=1,
           kernel_size=kernel_size,
           activation='relu',
           padding='same',
           name='decoder_output')(x)

outputs=x

decoder = Model(latent_inputs,outputs)
decoder.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
decoder_input (InputLayer)   (None, 50)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 1024)              52224     
_________________________________________________________________
reshape_1 (Reshape)          (None, 32, 1, 32)         0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 32, 1, 32)         9248      
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 64, 2, 32)         0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 64, 2, 64)         18496     
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 128, 4, 64)        0         
__________

Build the discriminator

In [5]:
disc_dim = 256
feat_disc_inputs = Input(shape=(latent_dim,), name='disc_input')
x = feat_disc_inputs
x = Dense(disc_dim, activation='relu')(x)
x = Dense(disc_dim, activation='relu')(x)
feat_disc_outputs = Dense(1,activation='sigmoid')(x)

feat_disc = Model(feat_disc_inputs,feat_disc_outputs,name='feat_disc')
feat_disc.compile(optimizer=Adam(lr=1e-4), 
                      loss="binary_crossentropy")
feat_disc.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
disc_input (InputLayer)      (None, 50)                0         
_________________________________________________________________
dense_2 (Dense)              (None, 256)               13056     
_________________________________________________________________
dense_3 (Dense)              (None, 256)               65792     
_________________________________________________________________
dense_4 (Dense)              (None, 1)                 257       
Total params: 79,105
Trainable params: 79,105
Non-trainable params: 0
_________________________________________________________________


Define composite models

In [6]:
outputs = decoder(encoder(inputs))
# outputs = decoder(Concatenate(axis=-1)(encoder(inputs)))
ae = Model(inputs, outputs, name='ae')
ae.compile(optimizer=Adam(lr=1e-4), 
           loss="binary_crossentropy")
ae.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 256, 8, 1)         0         
_________________________________________________________________
encoder (Model)              (None, 50)                70066     
_________________________________________________________________
model_1 (Model)              (None, 256, 8, 1)         117473    
Total params: 187,539
Trainable params: 187,539
Non-trainable params: 0
_________________________________________________________________


In [7]:
outputs = decoder(encoder(inputs))
reconst_loss = Subtract()([outputs,inputs])
# outputs = decoder(Concatenate(axis=-1)(encoder(inputs)))
ae_re = Model(inputs, reconst_loss, name='ae_re')
ae_re.compile(optimizer=Adam(lr=1e-4), 
              loss="binary_crossentropy")
ae_re.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 256, 8, 1)    0                                            
__________________________________________________________________________________________________
encoder (Model)                 (None, 50)           70066       input_1[0][0]                    
__________________________________________________________________________________________________
model_1 (Model)                 (None, 256, 8, 1)    117473      encoder[2][0]                    
__________________________________________________________________________________________________
subtract_1 (Subtract)           (None, 256, 8, 1)    0           model_1[2][0]                    
                                                                 input_1[0][0]                    
Total para

In [8]:
feat_disc_output = feat_disc(encoder(inputs))
enc_feat_disc = Model(inputs,feat_disc_output,name='enc_feat_disc')
enc_feat_disc.compile(optimizer=Adam(lr=1e-4), 
                      loss="binary_crossentropy")
enc_feat_disc.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 256, 8, 1)         0         
_________________________________________________________________
encoder (Model)              (None, 50)                70066     
_________________________________________________________________
feat_disc (Model)            (None, 1)                 79105     
Total params: 149,171
Trainable params: 149,171
Non-trainable params: 0
_________________________________________________________________


define sampling procedures for the latent feature and class distributions

In [9]:
def sample_classes(labels,n_classes=None):
    if n_classes == None:
        n_classes = np.max(labels)
    ulabel = labels == -1
    labels[ulabel] = np.random.randint(0,n_classes,np.sum(ulabel))
    labels = labels.astype(int)
    oh = np.zeros((labels.shape[0],n_classes),dtype=int)
    oh[range(labels.shape[0]),labels]=1
    return oh

def sample_features(n_samples,n_dimensions):
    return np.random.multivariate_normal(np.zeros(n_dimensions),np.eye(n_dimensions),n_samples)

In [10]:
n_samples = 10000
X = np.random.uniform(0,1,(n_samples,)+input_shape)
X_train, X_test = train_test_split(X,train_size=0.5)

In [11]:
epochs = 1
batch_size = 1000
for i_epoch in range(epochs):
    np.random.shuffle(X_train)
    
    for i_batch in range(int(X_train.shape[0]/batch_size)):
        settrainable(ae, True)
        settrainable(encoder, True)
        settrainable(decoder, True)
        
        batch = X_train[i_batch*batch_size:(i_batch+1)*batch_size]
 
        # first perform reconstruction training
        ae.train_on_batch(batch,batch)
        
        # now perform regularization training
        settrainable(feat_disc, True)
        batch_features = encoder.predict(batch)
        fake_features = sample_features(batch_size,latent_dim,)
        
        # now train the feat_disc giving it ones for true, and 
        #     zeros for fake
        discbatch_x = np.concatenate([batch_features,fake_features])
        discbatch_y = np.concatenate([np.ones(batch_size),
                                      np.zeros(batch_size)])
        feat_disc.train_on_batch(discbatch_x,discbatch_y)
        
        # now train the enc_feat_disc but only update the 
        #     encoder weights and try to fool the discriminator
        settrainable(enc_feat_disc, True)
        settrainable(encoder, True)
        settrainable(feat_disc, False)
        enc_feat_disc.train_on_batch(batch, np.ones(batch_size))
        
        print("Reconstruction Loss:", 
                  ae.evaluate(X_train, X_train, verbose=0))
        
        print("Feature Adversarial Loss:", 
                  enc_feat_disc.evaluate(X_train, 
                                         np.ones(X_train.shape[0]),
                                         verbose=0))

Instructions for updating:
Use tf.cast instead.


  'Discrepancy between trainable weights and collected trainable'


Reconstruction Loss: 5.310517658996582
Feature Adversarial Loss: 0.6490545263290405
Reconstruction Loss: 4.17730817565918
Feature Adversarial Loss: 0.6358655561447144
Reconstruction Loss: 3.3097448795318605
Feature Adversarial Loss: 0.6217564010620117
Reconstruction Loss: 2.8004827194213866
Feature Adversarial Loss: 0.6092704731941223
Reconstruction Loss: 2.55276851234436
Feature Adversarial Loss: 0.5974244488716125


In [None]:
save_model(ae, 'ae.h5')

In [None]:
converter = tf.lite.TFLiteConverter.from_keras_model_file('ae.h5')

In [None]:
ae_tflite = converter.convert()

In [None]:
open("ae_tflite.tflite","wb").write(ae_tflite)