In [58]:
import tensorflow as tf
from tensorflow.keras import layers, Model, Input
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
import random

(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
y_train, y_test = to_categorical(y_train), to_categorical(y_test)


class SimpleCNN(Model):
    def __init__(self):
        super().__init__()
        self.model = self.build_model()
        # dont include the final output layer
        self.activation_model = Model(inputs=self.model.input, outputs=[layer.output for layer in self.model.layers[:-1]])

    def build_model(self):
        inputs = Input(shape=(32, 32, 3))
        x = layers.Conv2D(8, 3, activation='relu')(inputs)
        x = layers.MaxPooling2D(2)(x)
        x = layers.Conv2D(8, 3, activation='relu')(x)
        x = layers.MaxPooling2D(2)(x)
        x = layers.Conv2D(8, 3, activation='relu')(x)
        x = layers.Flatten()(x)
        outputs = layers.Dense(10, activation='softmax')(x)
        return Model(inputs=inputs, outputs=outputs)

    def call(self, x):
        return self.model(x)

    def get_activations(self, x):
        return self.activation_model.predict(x)

    def get_specific_activations(self, x, activation_indices):
        activations = self.get_activations(x)
        flattened_activations = tf.concat([tf.reshape(a, [-1]) for a in activations], axis=0)
        return tf.gather(flattened_activations, activation_indices)


model = SimpleCNN()
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=1000, epochs=2, validation_data=(x_test, y_test), verbose=1)

test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print('\nTest accuracy:', test_acc)

activations = model.get_activations(x_test[0:1])
for i, activation in enumerate(activations):
    print(f'Layer {i+1} activation shape: {activation.shape}')

total_activations = sum(a.size for a in activations)
print('Total activations for a single image:', total_activations)

# 10 random integers between 0 and total_activations
print(model.get_specific_activations(x_test[0:1], [random.randint(0, total_activations) for _ in range(10)]))

class BaseAttention(layers.Layer):
  def __init__(self, **kwargs):
    super().__init__()
    self.mha = layers.MultiHeadAttention(**kwargs)
    self.layernorm = layers.LayerNormalization()
    self.add = layers.Add()

class GSA(BaseAttention):
  def call(self, x):
    attn_output = self.mha(
        query=x,
        value=x,
        key=x)
    x = self.add([x, attn_output])
    x = self.layernorm(x)
    return x

class FeedForward(layers.Layer):
  def __init__(self, d_model, dff):
    super().__init__()
    self.seq = models.Sequential([
      layers.Dense(dff, activation='relu'),
      layers.Dense(d_model)
    ])
    self.add = layers.Add()

  def call(self, x):
    x = self.add([x, self.seq(x)])
    return x

class TransformLayer(layers.Layer):
  def __init__(self,*, d_model, num_heads, dff):
    super().__init__()

    self.self_attention = GSA(
        num_heads=num_heads,
        key_dim=d_model)

    self.ffn = FeedForward(d_model, dff)

  def call(self, x):
    x = self.self_attention(x)
    x = self.ffn(x)
    return x

class Transformer(Model):
  def __init__(self, *, num_layers, d_model, num_heads, dff):
    super().__init__()
    self.embedder = layers.Dense(d_model)
    self.layerstack = [TransformLayer(d_model=d_model, num_heads=num_heads, dff=dff) for _ in range(num_layers)]

  def call(self, x):
    x = self.embedder(x)
    for layer in self.layerstack:
        x = layer(x)
    return x



Epoch 1/2
 1/50 [..............................] - ETA: 14s - loss: 2.3180 - accuracy: 0.1030

2023-04-19 11:40:24.601188: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.


Epoch 2/2
 1/50 [..............................] - ETA: 0s - loss: 2.1698 - accuracy: 0.1930

2023-04-19 11:40:25.718798: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.




2023-04-19 11:40:26.917098: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.


313/313 - 3s - loss: 1.8655 - accuracy: 0.3281 - 3s/epoch - 8ms/step

Test accuracy: 0.3280999958515167
Layer 1 activation shape: (1, 32, 32, 3)
Layer 2 activation shape: (1, 30, 30, 8)
Layer 3 activation shape: (1, 15, 15, 8)
Layer 4 activation shape: (1, 13, 13, 8)
Layer 5 activation shape: (1, 6, 6, 8)
Layer 6 activation shape: (1, 4, 4, 8)
Layer 7 activation shape: (1, 128)
Layer 8 activation shape: (1, 10)
Total activations for a single image: 13978
tf.Tensor(
[0.25490198 0.31761384 0.858051   0.12029882 0.7586678  0.03061792
 0.         0.68235296 0.716138   0.20148279], shape=(10,), dtype=float32)


2023-04-19 11:40:29.507552: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
