In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa


In [None]:
# Model I/O configuration: number of output logits, and the per-sample input
# shape handed to layers.Input (batch axis excluded).
num_classes, input_shape = 10, (160, 14, 11)

In [None]:
import pickle


def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``.

    The file handle is closed as soon as loading finishes (the previous
    version left every handle open).

    NOTE: unpickling can execute arbitrary code -- only load files that come
    from a trusted source.
    """
    with open(path, "rb") as pickle_file:
        return pickle.load(pickle_file)


X_train = _load_pickle("X_train_encoded_full_2D.pickle")
y_train = _load_pickle("y_train_full.pickle")

X_test = _load_pickle("X_test_encoded_full_2D.pickle")
y_test = _load_pickle("y_test_full.pickle")

In [None]:
# Labels come out of the pickles as generic sequences; Keras wants arrays.
y_train = np.array(y_train)
y_test = np.array(y_test)

# Fail fast if the pickled data does not match what the model is built for.
assert X_train.shape[1:] == input_shape, f"unexpected X_train shape {X_train.shape}"
assert X_test.shape[1:] == input_shape, f"unexpected X_test shape {X_test.shape}"
assert len(X_train) == len(y_train), "X_train / y_train length mismatch"
assert len(X_test) == len(y_test), "X_test / y_test length mismatch"

# Channels-last copy of the training data, i.e. (N, 160, 14, 11) -> (N, 14, 11, 160).
# It is only used to adapt the Normalization layer and for the patch plot.
# The model permutes its own inputs, so no transposed copy of X_test is needed.
X_train_pos = np.transpose(X_train, (0, 2, 3, 1))

print(f"x_train shape: {X_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {X_test.shape} - y_test shape: {y_test.shape}")

In [None]:
np.shape(X_train_pos)  # channels-last training tensor: (samples, rows, cols, channels)

In [None]:
# --- Hyperparameters ---------------------------------------------------------
# Optimization.
learning_rate = 1e-3
weight_decay = 1e-4
batch_size = 256
num_epochs = 200

# Patching. Inputs are resized to image_size x image_size and then cut into
# non-overlapping patch_size x patch_size tiles with "VALID" padding.
# That gives 20 // 6 = 3 tiles per side (the last 2 pixels are dropped),
# so 9 patches in total.
image_size = 20
patch_size = 6
num_patches = pow(image_size // patch_size, 2)

# Transformer encoder.
projection_dim = 64
dct_projection_dim = 64  # must equal projection_dim: PatchEncoder sums both projections
num_heads = 4
transformer_units = [2 * projection_dim, projection_dim]  # MLP widths inside each block
transformer_layers = 8

# Classification head: widths of the dense layers before the final logits layer.
mlp_head_units = [2048, 1024]

In [None]:
# Input preprocessing (despite its name, this performs no random augmentation):
# normalize with statistics learned from the training data, then resize
# to image_size x image_size.
data_augmentation = keras.Sequential(name="data_augmentation")
data_augmentation.add(layers.experimental.preprocessing.Normalization())
data_augmentation.add(layers.experimental.preprocessing.Resizing(image_size, image_size))

# Learn the normalization mean/variance from the channels-last training inputs.
data_augmentation.layers[0].adapt(X_train_pos)

In [None]:
# Multilayer perceptron (MLP) helper used in the Transformer blocks and the classifier head.
def mlp(x, hidden_units, dropout_rate):
    """Stack one Dense(GELU) -> Dropout pair per entry of ``hidden_units``.

    Args:
        x: input Keras tensor.
        hidden_units: layer widths, applied in order.
        dropout_rate: dropout rate applied after every Dense layer.

    Returns:
        Output tensor of the last Dropout layer (``x`` itself if
        ``hidden_units`` is empty).
    """
    out = x
    for width in hidden_units:
        hidden = layers.Dense(width, activation=tf.nn.gelu)(out)
        out = layers.Dropout(dropout_rate)(hidden)
    return out

In [None]:
# Patch extraction implemented as a Keras layer.
class Patches(layers.Layer):
    """Split a batch of images into flattened, non-overlapping square patches.

    ``call`` takes a 4-D channels-last tensor (batch, height, width, channels),
    as required by ``tf.image.extract_patches``. It returns a tensor of shape
    (batch, num_patches, patch_size * patch_size * channels).

    "VALID" padding with stride == patch_size means trailing rows/columns
    that do not fill a whole patch are dropped.
    """

    def __init__(self, patch_size, **kwargs):
        # Forward base-layer kwargs (name, trainable, dtype, ...). get_config()
        # returns those keys, and Layer.from_config calls cls(**config), which
        # previously raised TypeError.
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]  # dynamic: batch size is unknown at build time
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Collapse the (patch_rows, patch_cols) grid into a single patch axis.
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'patch_size': self.patch_size
        })
        return config

In [None]:
import matplotlib.pyplot as plt

# Sanity-check the patching on one randomly chosen training sample.
# Only the first 3 channels are kept so the sample can be drawn as an RGB image.
plt.figure(figsize=(4, 4))
image = X_train_pos[np.random.choice(range(X_train_pos.shape[0]))][:,:,0:3]
# NOTE(review): astype("uint8") does not clip. Values outside [0, 255] are not
# represented faithfully, and the encoded inputs may not lie in that range -- confirm.
plt.imshow(image.astype("uint8"))
plt.axis("off")

# Resize to image_size like the model's Resizing layer does (this plot skips
# the model's Normalization step), then split into patches.
resized_image = tf.image.resize(
    tf.convert_to_tensor([image]), size=(image_size, image_size)
)
patches = Patches(patch_size)(resized_image)
print(f"Image size: {image_size} X {image_size}")
print(f"Patch size: {patch_size} X {patch_size}")
print(f"Patches per image: {patches.shape[1]}")
print(f"Elements per patch: {patches.shape[-1]}")

# Lay the patches out on an n x n grid; patches come from a square grid
# because the resized image is square.
n = int(np.sqrt(patches.shape[1]))
plt.figure(figsize=(4, 4))
for i, patch in enumerate(patches[0]):
    ax = plt.subplot(n, n, i + 1)
    # Undo the flattening done by Patches (3 channels were kept above).
    patch_img = tf.reshape(patch, (patch_size, patch_size, 3))
    plt.imshow(patch_img.numpy().astype("uint8"))
    plt.axis("off")

In [None]:
# Patch encoding layer: fuses raw-image patches, DCT patches and positions.

class PatchEncoder(layers.Layer):
    """Project raw and DCT patches to ``projection_dim`` and add position embeddings.

    ``call(patch, dct_patch)`` expects two patch tensors with ``num_patches``
    patches each (the last axis may differ). It returns their summed linear
    projections plus a learned per-position embedding, with last dimension
    ``projection_dim``.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        # Forward base-layer kwargs so Layer.from_config(get_config()) works.
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim)
        # The DCT projection is *added* to self.projection's output, so it must
        # share projection_dim. It previously read the global dct_projection_dim,
        # which broke the sum whenever the two values differed.
        self.dct_projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch, dct_patch):
        # One learned embedding per patch index, broadcast across the batch.
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = (
            self.projection(patch)
            + self.dct_projection(dct_patch)
            + self.position_embedding(positions)
        )
        return encoded

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'num_patches': self.num_patches,
            'projection_dim': self.projection_dim,
        })
        return config



In [None]:
# 2-D DCT used as a Lambda layer in the model.
def dct_2d(
        feature_map,
        norm=None  # can also be 'ortho'
):
    """Apply a separable 2-D DCT-II over the last two axes of ``feature_map``.

    ``tf.signal.dct`` only transforms the innermost axis, so the function:
    1. transforms along the last axis,
    2. swaps the last two axes,
    3. transforms again,
    4. swaps the axes back.

    The swap uses ``tf.linalg.matrix_transpose``, so any input of rank >= 2
    works. The old ``perm=[0, 1, 3, 2]`` only supported rank-4 inputs; for
    rank 4 the result is identical.

    Args:
        feature_map: float32/float64 tensor of rank >= 2.
        norm: ``None`` or ``'ortho'``, forwarded to ``tf.signal.dct``.

    Returns:
        Tensor with the same shape as ``feature_map``.
    """
    last_axis_dct = tf.signal.dct(feature_map, type=2, norm=norm)
    both_axes_dct = tf.signal.dct(
        tf.linalg.matrix_transpose(last_axis_dct), type=2, norm=norm
    )
    return tf.linalg.matrix_transpose(both_axes_dct)

In [None]:
def create_STTM_classifier(dct_augmentation=None):
    """Build the two-branch (raw input + 2-D DCT) Transformer classifier.

    Args:
        dct_augmentation: optional preprocessing model for the DCT branch.
            It must output image_size x image_size maps, as ``data_augmentation``
            does. Defaults to the shared ``data_augmentation`` pipeline
            (previous behavior).

    Returns:
        An uncompiled ``keras.Model`` mapping inputs of shape ``input_shape``
        to ``num_classes`` logits.
    """
    if dct_augmentation is None:
        # NOTE(review): the shared pipeline's Normalization was adapted on the
        # raw inputs (X_train_pos), so by default the DCT coefficients are
        # normalized with raw-input statistics. To avoid that, pass a pipeline
        # adapted on DCT-transformed training data.
        dct_augmentation = data_augmentation

    inputs = layers.Input(shape=input_shape)

    # Raw branch: move input axis 1 last (Permute dims are 1-indexed and
    # exclude the batch axis), normalize/resize, then split into patches.
    inputs_pos = layers.Permute((2, 3, 1))(inputs)
    augmented = data_augmentation(inputs_pos)
    patches = Patches(patch_size)(augmented)

    # DCT branch: 2-D DCT over the last two input axes, followed by the same
    # axis permutation, preprocessing and patching as the raw branch.
    dct_2d_map = layers.Lambda(dct_2d)(inputs)
    dct_2d_map_pos = layers.Permute((2, 3, 1))(dct_2d_map)
    dct_augmented = dct_augmentation(dct_2d_map_pos)
    dct_patches = Patches(patch_size)(dct_augmented)

    # Project both patch sets and add learned position embeddings.
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches, dct_patches)

    # Pre-norm Transformer encoder blocks.
    for _ in range(transformer_layers):
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Multi-head self-attention (query and value are both x1).
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        x2 = layers.Add()([attention_output, encoded_patches])  # residual 1
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        encoded_patches = layers.Add()([x3, x2])  # residual 2

    # encoded_patches is [batch_size, num_patches, projection_dim]; Flatten
    # turns it into [batch_size, num_patches * projection_dim].
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    # Classifier head.
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    # Unnormalized class scores; run_experiment compiles the loss with from_logits=True.
    logits = layers.Dense(num_classes)(features)
    model = keras.Model(inputs=inputs, outputs=logits)
    return model

In [None]:
def run_experiment(model, checkpoint_filepath="STTM_2D_DCT_v4.weights.best.hdf5"):
    """Compile, train and evaluate ``model`` on the global train/test arrays.

    Training uses ``X_train``/``y_train``; Keras holds out the last 10% as
    validation data. The weights with the best ``val_accuracy`` are written to
    ``checkpoint_filepath`` and restored into ``model`` after training. Test
    accuracy and top-5 accuracy on ``X_test``/``y_test`` are then printed.

    Args:
        model: an uncompiled Keras model that outputs ``num_classes`` logits.
        checkpoint_filepath: HDF5 path for the best weights. Defaults to the
            previously hard-coded name.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    optimizer = tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    # Only weights are reloaded below (load_weights), and the file is named
    # ".weights". Save just the weights instead of the whole model plus
    # optimizer state, which also avoids serializing the custom/Lambda layers
    # on every improvement.
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        verbose=1,
        save_best_only=True,
        save_weights_only=True,
        mode='max',
    )

    # NOTE(review): validation_split takes the *last* 10% of the arrays without
    # shuffling. If the pickled data is ordered (e.g. by class), the
    # validation set is not representative -- confirm or shuffle first.
    history = model.fit(
        x=X_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[checkpoint_callback],
    )

    # Restore the best-validation weights before testing.
    model.load_weights(checkpoint_filepath)
    _, accuracy, top_5_accuracy = model.evaluate(X_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    return history


STTM_classifier = create_STTM_classifier()
history = run_experiment(STTM_classifier)

In [None]:
# Serialize the architecture to JSON.
# NOTE(review): reloading with keras.models.model_from_json presumably needs
# custom_objects for Patches/PatchEncoder (and dct_2d for the Lambda layer) -- confirm.
model_json = STTM_classifier.to_json()
with open("STTM_2D_DCT_v4.json", "w") as json_file:
    json_file.write(model_json)
# Serialize the weights to HDF5. run_experiment() already restored the best
# val_accuracy checkpoint into the model, so these are those best weights.
STTM_classifier.save_weights("STTM_2D_DCT_v4.h5")
print("Saved model to disk")

In [None]:
# Inspect the DCT branch's layout on a single training sample. (The previous
# `dct_2d_map_pos` is a local inside create_STTM_classifier(), so referencing
# it here raised NameError on a fresh kernel.) The transpose mirrors the
# model's Permute((2, 3, 1)).
dct_sample_pos = tf.transpose(
    dct_2d(tf.cast(X_train[:1], tf.float32)), perm=[0, 2, 3, 1]
)
dct_sample_pos.shape