### IMPORTS

In [1]:
!pip install keras



In [2]:
import keras
import tensorflow as tf
from tensorflow.keras import layers
from keras import regularizers

In [3]:
from keras.callbacks import EarlyStopping

early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

### DOWNLOADING AND PREPARING DATASET

In [4]:
vocab_size = 20000  # Only consider the top 20k words
num_tokens_per_example = 200  # Only consider the first 200 words of each movie review
(x_train, y_train), (x_val, y_val) = tf.keras.datasets.imdb.load_data(num_words=vocab_size)
print(len(x_train), "Training sequences")
print(len(x_val), "Validation sequences")
x_train = tf.keras.preprocessing.sequence.pad_sequences(x_train, maxlen=num_tokens_per_example)
x_val = tf.keras.preprocessing.sequence.pad_sequences(x_val, maxlen=num_tokens_per_example)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz
25000 Training sequences
25000 Validation sequences


### DEFINING HYPERPARAMETERS

In [14]:
embed_dim = 16  # Embedding size for each token.
num_heads = 2  # Number of attention heads
ff_dim = 16  # Hidden layer size in feedforward network.
num_experts = 5  # Number of experts used in the Switch Transformer.
batch_size = 50  # Batch size.
learning_rate = 0.001  # Learning rate.
dropout_rate = 0.5  # Increase dropout rate
num_epochs = 3  # Number of epochs.
num_tokens_per_batch = (
    batch_size * num_tokens_per_example
)  # Total number of tokens per batch.
print(f"Number of tokens per batch: {num_tokens_per_batch}")

Number of tokens per batch: 10000


### IMPLEMENTING TOKEN & POSITIONING EMBEDDING LAYER

In [15]:
class TokenAndPositionEmbedding(layers.Layer):
    def __init__(self, maxlen, vocab_size, embed_dim):
        super().__init__()
        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
      maxlen = tf.shape(x)[-1]
      positions = tf.range(start=0, limit=maxlen, delta=1)
      positions = self.pos_emb(positions)
      x = self.token_emb(x)
      return x + positions

### IMPLEMENTING FEEDFORWARD NETWORK

In [16]:
def create_feedforward_network(ff_dim, embed_dim, name=None):
    return keras.Sequential(
        [
            layers.Dense(ff_dim, activation="relu", kernel_regularizer=regularizers.l2(0.01)),
            layers.Dense(embed_dim, kernel_regularizer=regularizers.l2(0.01))
        ],
        name=name
    )

### IMPLEMENTING LOAD-BALANCED LOSS

In [17]:
def load_balanced_loss(router_probs, expert_mask):
    num_experts = tf.shape(expert_mask)[-1]
    density = tf.reduce_mean(expert_mask, axis=0)
    density_proxy = tf.reduce_mean(router_probs, axis=0)
    loss = tf.reduce_mean(density_proxy * density) * tf.cast((num_experts**2), tf.float32)
    return loss

### IMPLEMENTING ROUTER AS LAYER

In [18]:
class Router(layers.Layer):
    def __init__(self, num_experts, expert_capacity):
        self.num_experts = num_experts
        self.route = layers.Dense(units=num_experts)
        self.expert_capacity = expert_capacity
        super().__init__()

    def call(self, inputs, training=False):
        router_logits = self.route(inputs)

        if training:
            router_logits += tf.random.uniform(shape=tf.shape(router_logits), minval=0.9, maxval=1.1)
        router_probs = tf.nn.softmax(router_logits, axis=-1)
        expert_gate, expert_index = tf.math.top_k(router_probs, k=1)
        expert_mask = tf.one_hot(expert_index, self.num_experts)
        aux_loss = load_balanced_loss(router_probs, expert_mask)
        self.add_loss(aux_loss)

        position_in_expert = tf.cast(tf.cumsum(expert_mask, axis=0) * expert_mask, tf.int32)
        expert_mask *= tf.cast(tf.less(position_in_expert, self.expert_capacity), tf.float32)
        expert_mask_flat = tf.reduce_sum(expert_mask, axis=-1)
        expert_gate *= expert_mask_flat

        combined_tensor = tf.expand_dims(
            expert_gate * expert_mask_flat * tf.squeeze(tf.one_hot(expert_index, self.num_experts), 1),
            -1,
        ) * tf.squeeze(tf.one_hot(position_in_expert, self.expert_capacity), 1)
        dispatch_tensor = tf.cast(combined_tensor, tf.float32)

        return dispatch_tensor, combined_tensor

### IMPLEMENTING SWITCH LAYER

In [19]:
class Switch(layers.Layer):
    def __init__(
        self, num_experts, embed_dim, ff_dim, num_tokens_per_batch, capacity_factor=1
    ):
        self.num_experts = num_experts
        self.embed_dim = embed_dim
        self.experts = [
            create_feedforward_network(ff_dim, embed_dim) for _ in range(num_experts)
        ]

        self.expert_capacity = num_tokens_per_batch // self.num_experts
        self.router = Router(self.num_experts, self.expert_capacity)
        super().__init__()

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        num_tokens_per_example = tf.shape(inputs)[1]

        inputs = tf.reshape(inputs, [num_tokens_per_batch, self.embed_dim])
        dispatch_tensor, combine_tensor = self.router(inputs)
        expert_inputs = tf.einsum("ab,acd->cdb", inputs, dispatch_tensor)
        expert_inputs = tf.reshape(
            expert_inputs, [self.num_experts, self.expert_capacity, self.embed_dim]
        )
        expert_input_list = tf.unstack(expert_inputs, axis=0)
        expert_output_list = [
            self.experts[idx](expert_input)
            for idx, expert_input in enumerate(expert_input_list)
        ]
        expert_outputs = tf.stack(expert_output_list, axis=1)
        expert_outputs_combined = tf.einsum(
            "abc,xba->xc", expert_outputs, combine_tensor
        )
        outputs = tf.reshape(
            expert_outputs_combined,
            [batch_size, num_tokens_per_example, self.embed_dim],
        )
        return outputs

### IMPLEMENT TRANSFORMER BLOCK LAYER

In [20]:
class TransformerBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, ffn, dropout_rate=0.1):
        super().__init__()
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = ffn
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout_rate)
        self.dropout2 = layers.Dropout(dropout_rate)

    def call(self, inputs, training=False):
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

### IMPLEMENTING CLASSIFIER

In [21]:
def create_classifier():
    switch = Switch(num_experts, embed_dim, ff_dim, num_tokens_per_batch)
    transformer_block = TransformerBlock(embed_dim // num_heads, num_heads, switch)

    inputs = layers.Input(shape=(num_tokens_per_example,))
    embedding_layer = TokenAndPositionEmbedding(
        num_tokens_per_example, vocab_size, embed_dim
    )
    x = embedding_layer(inputs)
    x = transformer_block(x)
    x = layers.GlobalAveragePooling1D()(x)
    x = layers.Dropout(dropout_rate)(x)
    x = layers.Dense(ff_dim, activation="relu")(x)
    x = layers.Dropout(dropout_rate)(x)
    outputs = layers.Dense(2, activation="softmax")(x)

    classifier = tf.keras.Model(inputs=inputs, outputs=outputs)
    return classifier

### TRAINING AND EVALUATION

In [22]:
def run_experiment(classifier):
    classifier.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    history = classifier.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_data=(x_val, y_val),
        callbacks=[early_stopping]
    )
    return history

classifier = create_classifier()
run_experiment(classifier)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.src.callbacks.History at 0x794095615e40>