In [1]:
import sklearn
import tensorflow as tf
from tensorflow import keras
import numpy as np

mnist = tf.keras.datasets.mnist
layer_amt = 3

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

In [2]:
base_model = tf.keras.models.Sequential()
base_model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))

for x in range(layer_amt): 
    base_model.add(tf.keras.layers.Dense(128, activation='relu'))

base_model.add(tf.keras.layers.Dropout(0.2))
base_model.add(tf.keras.layers.Dense(10))


base_model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
base_model.fit(x_train, y_train, epochs=15)

Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15


<keras.src.callbacks.History at 0x29494a020>

In [3]:
base_model.evaluate(x_test,  y_test, verbose=2)

313/313 - 0s - loss: 0.1114 - accuracy: 0.9801 - 210ms/epoch - 672us/step


[0.1114104613661766, 0.9800999760627747]

In [124]:
import tensorflow.keras.backend as K
from tensorflow.keras import activations

@keras.saving.register_keras_serializable()
class PCAProj(tf.keras.layers.Layer):
    def __init__(self, layer, k, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.trainable=False
        self.activation = activations.get(activation)
        self.k = k
        
        # if W is m x n, A is m x k and B is k x n
        w, b = layer.get_weights()
        self.b = K.constant(b)
        if w is not None: 
            u, s, vT = np.linalg.svd(w, full_matrices=False)
            self.A = K.constant(u[:, :k] @ np.diag(s[:k]))
            self.B = K.constant(vT[:k])
        
    def call(self, inputs):  
        return self.activation(tf.matmul(inputs, tf.matmul(self.A, self.B)) + self.b)

    def get_config(self):
        base_config = super().get_config()
        config = {
            "A": keras.saving.serialize_keras_object(self.A),
            "B": keras.saving.serialize_keras_object(self.B),
            "bias": keras.saving.serialize_keras_object(self.b),
            "k": self.k,
            "activation": self.activation
        }
        return {**base_config, **config}
    
    @classmethod
    def from_config(cls, config):
        A = keras.saving.deserialize_keras_object(config.pop("A"))
        B = keras.saving.deserialize_keras_object(config.pop("B"))
        bias = keras.saving.deserialize_keras_object(config.pop("bias"))
        layer = cls([None, bias], **config)
        layer.A = A
        layer.B = B
        return layer

In [127]:
optimized_model = tf.keras.models.Sequential()
optimized_model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))

weights = base_model.get_weights()

ks = [40, 18, 21]
for x in range(layer_amt): 
    optimized_model.add(PCAProj(base_model.layers[x], ks[x], activation='relu'))
    
optimized_model.add(tf.keras.layers.Dropout(0.2))
optimized_model.add(PCAProj(base_model.layers[5], k = 10, activation='relu'))

optimized_model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

ValueError: not enough values to unpack (expected 2, got 0)

In [126]:
optimized_model.evaluate(x_test,  y_test, verbose=2)

313/313 - 0s - loss: 0.1295 - accuracy: 0.9714 - 255ms/epoch - 815us/step


[0.1294977068901062, 0.9714000225067139]

In [82]:
memory = 0

for layer in optimized_model.layers: 
    if isinstance(layer, PCAProj): 
        memory += np.prod(layer.A.shape) + np.prod(layer.B.shape)
        memory += np.prod(layer.b.shape)
    if layer.weights: 
        memory += np.prod(layer.weights[0].shape)
        memory += np.prod(layer.weights[1].shape)
        

bad_mem = 0
for layer in base_model.layers: 
    if layer.weights: 
        bad_mem += np.prod(layer.weights[0].shape)
        bad_mem += np.prod(layer.weights[1].shape)
        
memory, bad_mem

(48238, 134794)

In [86]:
optimized_model.save("optimized_model.keras")

In [99]:
test = []

for layer in optimized_model.layers: 
    if isinstance(layer, PCAProj): 
        test.append(layer.A.numpy().flatten())
        test.append(layer.B.numpy().flatten())
        test.append(layer.b.numpy().flatten())
test = np.concatenate(test)
np.save("test.npy", test)

In [92]:
(np.prod(optimized_model.layers[2].A.shape) + np.prod(optimized_model.layers[2].A.shape))/np.prod(base_model.layers[2].get_weights()[0].shape)

0.625

In [104]:
optimized_model.layers[1].A.shape, optimized_model.layers[1].B.shape, optimized_model.layers[2].A.shape, optimized_model.layers[2].B.shape

((784, 40), (40, 128), (128, 40), (40, 128))

In [9]:
optimized_model.layers

[<keras.src.layers.reshaping.flatten.Flatten at 0x2de9005b0>,
 <__main__.PCAProj at 0x2de902440>,
 <__main__.PCAProj at 0x2de8f6830>,
 <__main__.PCAProj at 0x2de8f6980>,
 <keras.src.layers.regularization.dropout.Dropout at 0x104f36ad0>,
 <__main__.PCAProj at 0x2de8f5db0>]