In [28]:
import numpy as np
from keras.layers import (Lambda, Input, Reshape,
                          Dense, UpSampling2D,
                          Conv2D, Concatenate,
                          Flatten, MaxPool2D,)
from keras.losses import mse, mae, binary_crossentropy
from keras.models import Model
from keras.optimizers import Adam
from sklearn.model_selection import train_test_split

import keras.backend as K

In [6]:
def settrainable(model, toset):
    for layer in model.layers:
        layer.trainable = toset
    model.trainable = toset

In [20]:
input_shape = (256,8,1) # (nfft,n_lookback,n_channels)
filters = [32,32,32]
layers = 3
kernel_size = [3,3]
pool_size = [2,2]
intermediate = 16
latent_dim = 50
n_classes = 8

inputs = Input(shape=input_shape)
x = inputs
for i in range(layers):
    x = Conv2D(filters[i],
               kernel_size=kernel_size,
               activation='relu',
               padding='same')(x)
    x = MaxPool2D(pool_size=pool_size)(x)


# shape info needed to build decoder model
shape = K.int_shape(x)

x = Flatten()(x)
# generate latent vector Q(z|X)
features = Dense(latent_dim, activation='linear',name='features')(x)
classes = Dense(n_classes, activation='softmax',name='classes')(x)

# no reparameterization trick here, since we will get the distribution
#   through adversarial training against a stocastic process we don't
#   need to sample here

encoder = Model(inputs, [features,classes], name='encoder')
encoder.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_13 (InputLayer)           (None, 256, 8, 1)    0                                            
__________________________________________________________________________________________________
conv2d_28 (Conv2D)              (None, 256, 8, 32)   320         input_13[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_21 (MaxPooling2D) (None, 128, 4, 32)   0           conv2d_28[0][0]                  
__________________________________________________________________________________________________
conv2d_29 (Conv2D)              (None, 128, 4, 32)   9248        max_pooling2d_21[0][0]           
__________________________________________________________________________________________________
max_poolin

(None, 32, 1, 32)

In [23]:
filters = [32,64,64]
upsampling_size = pool_size

latent_inputs = Input(shape=(latent_dim+n_classes,), name='decoder_input')
x = Dense(shape[1] * shape[2] * shape[3], activation='relu')(latent_inputs)
x = Reshape((shape[1], shape[2], shape[3]))(x)

for i in range(layers):
    x = Conv2D(filters=filters[i],
               kernel_size=kernel_size,
               activation='relu',
               padding='same')(x)
    x = UpSampling2D(size=upsampling_size)(x)

x = Conv2D(filters=1,
           kernel_size=kernel_size,
           activation='relu',
           padding='same',
           name='decoder_output')(x)

outputs=x

decoder = Model(latent_inputs,outputs)
decoder.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
decoder_input (InputLayer)   (None, 58)                0         
_________________________________________________________________
dense_11 (Dense)             (None, 1024)              60416     
_________________________________________________________________
reshape_3 (Reshape)          (None, 32, 1, 32)         0         
_________________________________________________________________
conv2d_36 (Conv2D)           (None, 32, 1, 32)         9248      
_________________________________________________________________
up_sampling2d_5 (UpSampling2 (None, 64, 2, 32)         0         
_________________________________________________________________
conv2d_37 (Conv2D)           (None, 64, 2, 64)         18496     
_________________________________________________________________
up_sampling2d_6 (UpSampling2 (None, 128, 4, 64)        0         
__________

Build the discriminators

In [35]:
disc_dim = 256
feat_disc_inputs = Input(shape=(latent_dim,), name='disc_input')
x = feat_disc_inputs
x = Dense(disc_dim, activation='relu')(x)
x = Dense(disc_dim, activation='relu')(x)
feat_disc_outputs = Dense(1,activation='sigmoid')(x)

feat_disc = Model(feat_disc_inputs,feat_disc_outputs,name='feat_disc')
feat_disc.compile(optimizer=Adam(lr=1e-4), 
                      loss="binary_crossentropy")
feat_disc.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
disc_input (InputLayer)      (None, 50)                0         
_________________________________________________________________
dense_21 (Dense)             (None, 256)               13056     
_________________________________________________________________
dense_22 (Dense)             (None, 256)               65792     
_________________________________________________________________
dense_23 (Dense)             (None, 1)                 257       
Total params: 79,105
Trainable params: 79,105
Non-trainable params: 0
_________________________________________________________________


In [36]:
disc_dim = 256
class_disc_inputs = Input(shape=(n_classes,), name='disc_input')
x = class_disc_inputs
x = Dense(disc_dim, activation='relu')(x)
x = Dense(disc_dim, activation='relu')(x)
class_disc_outputs = Dense(1,activation='sigmoid')(x)

class_disc = Model(class_disc_inputs,class_disc_outputs,name='class_disc')
class_disc.compile(optimizer=Adam(lr=1e-4), 
                      loss="binary_crossentropy")
class_disc.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
disc_input (InputLayer)      (None, 8)                 0         
_________________________________________________________________
dense_24 (Dense)             (None, 256)               2304      
_________________________________________________________________
dense_25 (Dense)             (None, 256)               65792     
_________________________________________________________________
dense_26 (Dense)             (None, 1)                 257       
Total params: 68,353
Trainable params: 68,353
Non-trainable params: 0
_________________________________________________________________


Define composite models

In [37]:
outputs = decoder(Concatenate(axis=-1)(encoder(inputs)))
ae = Model(inputs, outputs, name='ae')
ae.compile(optimizer=Adam(lr=1e-4), 
           loss="binary_crossentropy")
ae.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_13 (InputLayer)           (None, 256, 8, 1)    0                                            
__________________________________________________________________________________________________
encoder (Model)                 [(None, 50), (None,  78266       input_13[0][0]                   
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 58)           0           encoder[5][0]                    
                                                                 encoder[5][1]                    
__________________________________________________________________________________________________
model_1 (Model)                 (None, 256, 8, 1)    125665      concatenate_3[0][0]              
Total para

In [38]:
feat_disc_output = feat_disc(encoder(inputs)[0])
enc_feat_disc = Model(inputs,feat_disc_output,name='enc_feat_disc')
enc_feat_disc.compile(optimizer=Adam(lr=1e-4), 
                      loss="binary_crossentropy")
enc_feat_disc.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_13 (InputLayer)        (None, 256, 8, 1)         0         
_________________________________________________________________
encoder (Model)              [(None, 50), (None, 8)]   78266     
_________________________________________________________________
feat_disc (Model)            (None, 1)                 79105     
Total params: 157,371
Trainable params: 157,371
Non-trainable params: 0
_________________________________________________________________


In [40]:
class_disc_output = class_disc(encoder(inputs)[1])
enc_class_disc = Model(inputs,class_disc_output,name='enc_class_disc')
enc_class_disc.compile(optimizer=Adam(lr=1e-4), 
                      loss="binary_crossentropy")
enc_class_disc.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_13 (InputLayer)        (None, 256, 8, 1)         0         
_________________________________________________________________
encoder (Model)              [(None, 50), (None, 8)]   78266     
_________________________________________________________________
class_disc (Model)           (None, 1)                 68353     
Total params: 146,619
Trainable params: 146,619
Non-trainable params: 0
_________________________________________________________________


define sampling procedures for the latent feature and class distributions

In [52]:
def sample_classes(labels,n_classes=None):
    if n_classes == None:
        n_classes = np.max(labels)
    ulabel = labels == -1
    labels[ulabel] = np.random.randint(0,n_classes,np.sum(ulabel))
    return labels

def sample_features(n_samples,n_dimensions):
    return np.random.multivariate_normal(np.zeros(n_dimensions),np.eye(n_dimensions),n_samples)

In [8]:
n_samples = 10000
X = np.random.uniform(0,1,(n_samples,)+input_shape)
X_train, X_test = train_test_split(X,train_size=0.5)

In [9]:
epochs = 1
batch_size = 1000
for i_epoch in range(epochs):
    np.random.shuffle(X_train)
    
    for i_batch in range(int(X_train.shape[0]/batch_size)):
        settrainable(ae, True)
        settrainable(encoder, True)
        settrainable(decoder, True)
        
        batch = X_train[i_batch*batch_size:(i_batch+1)*batch_size]
 
        # first train the autoencoder
        ae.train_on_batch(batch,batch)
        
        settrainable(discriminator, True)
        batchpred = encoder.predict(batch)
        fakepred = np.random.normal(0,1,(batch_size,latent_dim,))
        
        # now train the discriminator giving it ones for true, and 
        #     zeros for fake
        discbatch_x = np.concatenate([batchpred,fakepred])
        discbatch_y = np.concatenate([np.ones(batch_size),
                                      np.zeros(batch_size)])
        discriminator.train_on_batch(discbatch_x,discbatch_y)
        
        # now train the encoder descriminator but only update the 
        #     encoder weights and try to fool the discriminator
        settrainable(enc_disc, True)
        settrainable(encoder, True)
        settrainable(discriminator, False)
        enc_disc.train_on_batch(batch, np.ones(batch_size))
        
        print("Reconstruction Loss:", 
                  ae.evaluate(X_train, X_train, verbose=0))
        print("Adversarial Loss:", 
                  enc_disc.evaluate(X_train, 
                                    np.ones(X_train.shape[0]),
                                    verbose=0))

Instructions for updating:
Use tf.cast instead.


  'Discrepancy between trainable weights and collected trainable'


Reconstruction Loss: 2.991969310506185
Adversarial Loss: 0.6893415473620097
Reconstruction Loss: 2.715244542312622
Adversarial Loss: 0.6873656316121419
Reconstruction Loss: 2.562117127609253
Adversarial Loss: 0.6854145938555399
Reconstruction Loss: 2.4570766058603923
Adversarial Loss: 0.6833689728101094
Reconstruction Loss: 2.3773221271514893
Adversarial Loss: 0.6812911276181539
Reconstruction Loss: 2.3129668285369873
Adversarial Loss: 0.6791802398999532
Reconstruction Loss: 2.2588912197113036
Adversarial Loss: 0.677025761381785


In [9]:
ae.save('./autoencoder.h5')