In [1]:
import tensorflow as tf
import numpy as np

In [2]:
input_shape = 32
latent_shape = 8

In [3]:
class Autoencoder:
    def __init__(self, input_shape, latent_shape):
        self.X = tf.placeholder(tf.float32, 
                                shape=(None,input_shape),name = 'X')
        self.W1 = tf.Variable(
                    tf.random_normal(shape=(input_shape,latent_shape)))
        self.b1 = tf.Variable(
                    np.zeros(latent_shape).astype(np.float32))

        self.W2 = tf.Variable(
                    tf.random_normal(shape=(latent_shape,input_shape)))
        self.b2 = tf.Variable(
                    np.zeros(input_shape).astype(np.float32))

        self.Z = tf.nn.relu(tf.matmul(self.X, self.W1) + self.b1)
        logits = tf.matmul(self.Z, self.W2) + self.b2

        self.X_hat = tf.nn.sigmoid(logits)

        self.loss = tf.reduce_sum(
                        tf.nn.sigmoid_cross_entropy_with_logits(
                        labels = self.X,
                        logits = logits
                        )
                    )

        self.optimizer = tf.train.RMSPropOptimizer(learning_rate=0.005
                                                  ).minimize(self.loss)
        self.init_op = tf.global_variables_initializer()
        self.sess = tf.get_default_session()
        if(self.sess == None):
            self.sess = tf.Session()
        self.sess.run(self.init_op)
        
    def fit(self, X, epochs=10, bs=64):
        n_batches = X.shape[0] // bs
        print("Training {} batches".format(n_batches))
        
        for i in range(epochs):
            print("Epoch: {}".format(i))
            X_perm = np.random.permutation(X)
            for j in range(n_batches):
                batch = X_perm[j*bs:(j+1)*bs]
                _, _ = self.sess.run((self.optimizer, self.loss),
                                      feed_dict={self.X: batch})
    
    def save(self,export_dir='./'):
        tf.saved_model.simple_save(self.sess,
                                   export_dir,
                                   inputs={"X":self.X},
                                   outputs={"X_hat":self.X_hat})
    
    def predict(self, X):
        return self.sess.run(self.X_hat, feed_dict={self.X: X})
    
    def encoder(self, X):
        return self.sess.run(self.Z, feed_dict={self.X: X})
    
    def decode(self, Z):
        return self.sess.run(self.X_hat, feed_dict={self.Z: Z})
    
    def terminate(self):
        self.sess.close()
        del self.sess

In [4]:
ae = Autoencoder(input_shape, latent_shape)

W0729 21:44:42.715497 139786816022336 deprecation.py:323] From /home/kyjohnso/projects/mlirad/autoencoders/venv/lib/python3.6/site-packages/tensorflow/python/ops/nn_impl.py:180: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
W0729 21:44:42.781613 139786816022336 deprecation.py:506] From /home/kyjohnso/projects/mlirad/autoencoders/venv/lib/python3.6/site-packages/tensorflow/python/training/rmsprop.py:119: calling Ones.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


In [56]:
n_samples = 10000
X = np.random.uniform(0,1,(n_samples,input_shape))
ae.fit(X)

Training 156 batches
Epoch: 0
Epoch: 1
Epoch: 2
Epoch: 3
Epoch: 4
Epoch: 5
Epoch: 6
Epoch: 7
Epoch: 8
Epoch: 9


In [58]:
print(X.shape)

(10000, 32)


In [25]:
saved_dir = './saved/5/'
ae.save(saved_dir)

In [59]:
def representative_dataset_gen():
    n_samples = 100
    for i in range(n_samples):
        rep = np.random.uniform(0,1,(1,input_shape))
        rep = np.array(rep,dtype=np.float32)
        yield [rep]

In [60]:
converter = tf.lite.TFLiteConverter.from_saved_model(saved_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]

In [64]:
ae_tflite = converter.convert()

In [65]:
open("./ae_tflite.tflite","wb").write(ae_tflite)

2200