In [1]:
import tensorflow as tf
from tensorflow.keras import layers, Model
import numpy as np

class SimpleKerasTransformer(Model):
    """Small encoder-decoder Transformer assembled from built-in Keras layers.

    Post-LayerNorm encoder and decoder stacks (``norm(x + sublayer(x))``) with
    learned positional embeddings. Token id 0 is treated as padding: padded
    encoder positions are excluded as attention keys in encoder self-attention
    and in decoder cross-attention.

    Args:
        vocab_size: Number of token ids; shared by both token embeddings and
            the output projection.
        d_model: Width of the embeddings and of every hidden state.
        num_heads: Heads per MultiHeadAttention layer (``key_dim`` is
            ``d_model // num_heads``).
        num_layers: Number of encoder layers, and separately of decoder layers.
        max_seq_len: Rows in the learned positional-embedding tables; encoder
            and decoder inputs must not be longer than this.
    """

    def __init__(self, vocab_size, d_model=64, num_heads=4, num_layers=2, max_seq_len=10):
        super().__init__()

        self.d_model = d_model
        self.max_seq_len = max_seq_len

        # Token embeddings. Keras Embedding adds no positional information;
        # that comes from the separate learned tables below.
        self.encoder_embedding = layers.Embedding(vocab_size, d_model, mask_zero=True)
        self.decoder_embedding = layers.Embedding(vocab_size, d_model, mask_zero=True)

        # Learned positional embeddings, indexed by position 0..max_seq_len-1.
        self.encoder_pos_embedding = layers.Embedding(max_seq_len, d_model)
        self.decoder_pos_embedding = layers.Embedding(max_seq_len, d_model)

        # Encoder layers: self-attention + feed-forward, each with residual + norm.
        self.encoder_layers = []
        for _ in range(num_layers):
            encoder_layer = {
                'attention': layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model//num_heads),
                'ffn': tf.keras.Sequential([
                    layers.Dense(d_model * 2, activation='relu'),
                    layers.Dense(d_model)
                ]),
                'norm1': layers.LayerNormalization(),
                'norm2': layers.LayerNormalization(),
                'dropout': layers.Dropout(0.1)
            }
            self.encoder_layers.append(encoder_layer)

        # Decoder layers: masked self-attention, cross-attention, feed-forward.
        self.decoder_layers = []
        for _ in range(num_layers):
            decoder_layer = {
                'self_attention': layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model//num_heads),
                'cross_attention': layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model//num_heads),
                'ffn': tf.keras.Sequential([
                    layers.Dense(d_model * 2, activation='relu'),
                    layers.Dense(d_model)
                ]),
                'norm1': layers.LayerNormalization(),
                'norm2': layers.LayerNormalization(),
                'norm3': layers.LayerNormalization(),
                'dropout': layers.Dropout(0.1)
            }
            self.decoder_layers.append(decoder_layer)

        # Projection to unnormalised next-token logits.
        self.output_layer = layers.Dense(vocab_size)

    def create_causal_mask(self, size):
        """Return a lower-triangular causal mask of shape (1, 1, size, size).

        Entry ``[..., i, j]`` is 1.0 when query position ``i`` may attend to
        key position ``j`` (``j <= i``) and 0.0 otherwise. The leading
        singleton axes broadcast over batch and heads.
        """
        mask = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
        return mask[tf.newaxis, tf.newaxis, :, :]

    @staticmethod
    def _key_padding_mask(key_tokens, query_len):
        """Return a boolean (batch, query_len, key_len) mask, False at pad (id 0) keys."""
        keep = tf.not_equal(key_tokens, 0)[:, tf.newaxis, :]
        # Repeat over queries so the mask has the (B, T, S) layout documented
        # for MultiHeadAttention's ``attention_mask``.
        return tf.repeat(keep, repeats=query_len, axis=1)

    def call(self, inputs, training=False):
        """Return next-token logits of shape (batch, dec_len, vocab_size).

        Args:
            inputs: Pair ``(encoder_input, decoder_input)`` of token-id tensors,
                each (batch, seq_len), using 0 as padding.
            training: Enables dropout when True.
        """
        encoder_input, decoder_input = inputs

        # Get sequence lengths
        enc_seq_len = tf.shape(encoder_input)[1]
        dec_seq_len = tf.shape(decoder_input)[1]

        # Padded encoder positions must never be attended to. The masks are
        # built explicitly instead of relying on the Embedding's mask_zero
        # mask, which is not reliably carried through the plain "+" with the
        # positional embedding below.
        enc_self_mask = self._key_padding_mask(encoder_input, enc_seq_len)
        cross_mask = self._key_padding_mask(encoder_input, dec_seq_len)

        # Encoder: token embeddings + positional embeddings
        enc_positions = tf.range(enc_seq_len)[tf.newaxis, :]
        enc_emb = self.encoder_embedding(encoder_input)
        enc_pos_emb = self.encoder_pos_embedding(enc_positions)
        enc_output = enc_emb + enc_pos_emb

        for layer in self.encoder_layers:
            # Self-attention over non-padding encoder tokens
            attn_output = layer['attention'](
                enc_output, enc_output,
                attention_mask=enc_self_mask,
                training=training
            )
            attn_output = layer['dropout'](attn_output, training=training)
            enc_output = layer['norm1'](enc_output + attn_output)

            # Feed forward
            ffn_output = layer['ffn'](enc_output)
            ffn_output = layer['dropout'](ffn_output, training=training)
            enc_output = layer['norm2'](enc_output + ffn_output)

        # Decoder: token embeddings + positional embeddings
        dec_positions = tf.range(dec_seq_len)[tf.newaxis, :]
        dec_emb = self.decoder_embedding(decoder_input)
        dec_pos_emb = self.decoder_pos_embedding(dec_positions)
        dec_output = dec_emb + dec_pos_emb

        # Causal mask only: with post-padded decoder inputs every real token
        # precedes the padding, so a real query position never sees a pad key.
        causal_mask = self.create_causal_mask(dec_seq_len)

        for layer in self.decoder_layers:
            # Masked self-attention
            self_attn_output = layer['self_attention'](
                dec_output, dec_output,
                attention_mask=causal_mask,
                training=training
            )
            self_attn_output = layer['dropout'](self_attn_output, training=training)
            dec_output = layer['norm1'](dec_output + self_attn_output)

            # Cross-attention over non-padding encoder outputs
            cross_attn_output = layer['cross_attention'](
                dec_output, enc_output,
                attention_mask=cross_mask,
                training=training
            )
            cross_attn_output = layer['dropout'](cross_attn_output, training=training)
            dec_output = layer['norm2'](dec_output + cross_attn_output)

            # Feed forward
            ffn_output = layer['ffn'](dec_output)
            ffn_output = layer['dropout'](ffn_output, training=training)
            dec_output = layer['norm3'](dec_output + ffn_output)

        # Final output
        output = self.output_layer(dec_output)
        return output


# Simplified data creation (same as before but cleaner)
def create_simple_data(level=1):
    """Return the toy English->Tamil pairs for a curriculum level and a pad length.

    Levels 1-3 use progressively longer phrases. An unknown level falls back
    to level 1, and the returned ``max_len`` (``4 + level``) is computed for
    the level whose data is actually returned, so it never becomes zero or
    negative.

    Returns:
        ``(examples, max_len)`` where ``examples`` is a list of
        ``(english, tamil)`` string pairs.
    """
    # NOTE(review): some targets look like word-for-word glosses (e.g.
    # "நன்றி நீங்கள்" for "thank you"); confirm with a Tamil speaker.
    data_levels = {
        1: [("hello", "வணக்கம்"), ("good", "நல்ல"), ("thank", "நன்றி"), ("water", "தண்ணீர்"), ("food", "உணவு")],
        2: [("good morning", "காலை வணக்கம்"), ("thank you", "நன்றி நீங்கள்"), ("good night", "இனிய இரவு")],
        3: [("how are you", "நீங்கள் எப்படி இருக்கிறீர்கள்"), ("what is this", "இது என்ன ஆகும்")],
    }

    if level not in data_levels:
        level = 1
    return data_levels[level], 4 + level

def prepare_data_simple(examples, max_len):
    """Build a shared word vocabulary and padded id arrays for teacher forcing.

    Args:
        examples: Sequence of ``(english, tamil)`` sentence pairs; words are
            split on whitespace. It is iterated more than once.
        max_len: Length every sequence is post-padded or truncated to.

    Returns:
        ``(eng, tam_in, tam_out, vocab, reverse_vocab)``. The first three are
        2-D integer arrays of shape ``(len(examples), max_len)``:
        encoder ids, decoder inputs (``<START>`` + Tamil ids) and decoder
        targets (Tamil ids + ``<END>``). ``vocab`` maps word -> id and
        ``reverse_vocab`` maps id -> word. Ids 0-3 are reserved for
        ``<PAD>``, ``<START>``, ``<END>`` and ``<UNK>``.
    """
    vocab = {"<PAD>": 0, "<START>": 1, "<END>": 2, "<UNK>": 3}

    all_words = set()
    for eng, tam in examples:
        all_words.update(eng.split() + tam.split())

    # Sorted so ids are stable across runs (string-set order is not).
    for word in sorted(all_words):
        vocab[word] = len(vocab)

    reverse_vocab = {v: k for k, v in vocab.items()}
    unk_id = vocab["<UNK>"]

    def to_ids(sentence):
        """Map each whitespace-separated word to its id (``<UNK>`` if unseen)."""
        return [vocab.get(w, unk_id) for w in sentence.split()]

    def pad(sequences, truncating):
        """Post-pad/truncate a batch of id lists to ``max_len``."""
        return tf.keras.preprocessing.sequence.pad_sequences(
            sequences, maxlen=max_len, padding='post', truncating=truncating)

    # English keeps pad_sequences' default truncating='pre', matching how
    # translate_simple pads English at inference time.
    eng_seqs = pad([to_ids(eng) for eng, _ in examples], 'pre')

    # Decoder sequences are truncated from the end so <START> is never dropped.
    # Input and target have equal length, so they stay aligned one step apart.
    tam_input_seqs = pad(
        [[vocab["<START>"]] + to_ids(tam) for _, tam in examples], 'post')
    tam_target_seqs = pad(
        [to_ids(tam) + [vocab["<END>"]] for _, tam in examples], 'post')

    return (eng_seqs, tam_input_seqs, tam_target_seqs,
            vocab, reverse_vocab)

# Simplified loss and metrics using Keras built-ins
def create_masked_loss():
    """Return a padding-aware sparse cross-entropy loss for ``model.compile``.

    The returned ``masked_loss(y_true, y_pred)`` takes target ids (0 = padding)
    and unnormalised logits. It averages the per-token cross-entropy over
    non-padding positions only. A batch consisting solely of padding yields
    0.0 instead of NaN.
    """
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

    def masked_loss(y_true, y_pred):
        mask = tf.cast(tf.not_equal(y_true, 0), tf.float32)
        per_token = loss_fn(y_true, y_pred)
        return tf.math.divide_no_nan(tf.reduce_sum(per_token * mask), tf.reduce_sum(mask))

    return masked_loss

def create_masked_accuracy():
    """Return a token-level accuracy metric that ignores padding (id 0).

    The returned ``masked_accuracy(y_true, y_pred)`` compares the argmax of
    ``y_pred`` with ``y_true`` at non-padding positions. A batch consisting
    solely of padding yields 0.0 instead of NaN.
    """
    def masked_accuracy(y_true, y_pred):
        # Keras may pass y_true as float (tf.keras 2.x metric wrappers cast it
        # to float32); cast to int so the comparison with int32 predictions
        # type-checks.
        y_true = tf.cast(y_true, tf.int32)
        y_pred_class = tf.cast(tf.argmax(y_pred, axis=-1), tf.int32)
        mask = tf.cast(tf.not_equal(y_true, 0), tf.float32)
        matches = tf.cast(tf.equal(y_true, y_pred_class), tf.float32) * mask
        return tf.math.divide_no_nan(tf.reduce_sum(matches), tf.reduce_sum(mask))

    return masked_accuracy

# Simplified translation function
def translate_simple(model, sentence, vocab, reverse_vocab, max_len):
    """Greedily decode a translation of ``sentence`` one token at a time.

    The English words are mapped to ids (``<UNK>`` for unseen words) and
    post-padded to ``max_len``. Starting from ``<START>``, each step runs the
    full model, takes the argmax token at the last filled position, and stops
    on ``<END>``/``<PAD>`` or after ``max_len - 1`` generated tokens.

    Returns:
        The generated words joined by spaces, with special tokens and ids
        missing from ``reverse_vocab`` dropped.
    """
    source_ids = [vocab.get(token, vocab["<UNK>"]) for token in sentence.split()]
    eng_input = tf.keras.preprocessing.sequence.pad_sequences(
        [source_ids], maxlen=max_len, padding='post')

    generated = [vocab["<START>"]]

    for _ in range(max_len - 1):
        dec_input = tf.keras.preprocessing.sequence.pad_sequences(
            [generated], maxlen=max_len, padding='post')
        logits = model([eng_input, dec_input], training=False)

        # The logits at the last filled position predict the next token.
        last_pos = len(generated) - 1
        next_token = tf.argmax(logits[0, last_pos, :]).numpy()

        if next_token in (vocab["<END>"], vocab["<PAD>"]):
            break
        generated.append(next_token)

    hidden = ["<START>", "<END>", "<PAD>", "<UNK>", ""]
    decoded = (reverse_vocab.get(token, "") for token in generated[1:])
    return " ".join(word for word in decoded if word not in hidden)

# Simplified training function
def train_simple_level(level=1):
    """Train a fresh SimpleKerasTransformer on one curriculum level and test it.

    Args:
        level: Curriculum level passed to ``create_simple_data``.

    Returns:
        ``(passed, model, vocab, reverse_vocab, max_len)``. ``passed`` is True
        when at least 50% of the level's examples produce a translation that
        shares at least one whole word with the expected one.
    """
    print(f"\n=== Training Level {level} (Keras Built-in) ===")

    # Get data
    examples, max_len = create_simple_data(level)
    eng_data, tam_input, tam_target, vocab, reverse_vocab = prepare_data_simple(examples, max_len)

    print(f"Level {level}: {len(examples)} examples, vocab size: {len(vocab)}")

    model = SimpleKerasTransformer(
        vocab_size=len(vocab),
        d_model=64,
        num_heads=4,
        num_layers=2,
        max_seq_len=max_len
    )

    model.compile(
        optimizer=tf.keras.optimizers.Adam(0.001),
        loss=create_masked_loss(),
        metrics=[create_masked_accuracy()]
    )

    # The dataset is tiny, so tile it to give each epoch several batches.
    repetitions = max(10, 50 // len(examples))
    eng_expanded = np.tile(eng_data, (repetitions, 1))
    tam_input_expanded = np.tile(tam_input, (repetitions, 1))
    tam_target_expanded = np.tile(tam_target, (repetitions, 1))

    print(f"Training with {len(eng_expanded)} examples...")

    # NOTE: validation_split takes the last 20% of these tiled arrays, which
    # are exact copies of training rows, so val_* metrics reflect
    # memorisation rather than generalisation.
    model.fit(
        [eng_expanded, tam_input_expanded],
        tam_target_expanded,
        epochs=50,
        batch_size=8,
        verbose=1,
        validation_split=0.2
    )

    print(f"\n=== Testing Level {level} ===")
    correct = 0

    for eng_sentence, expected_tam in examples:
        predicted_tam = translate_simple(model, eng_sentence, vocab, reverse_vocab, max_len)

        print(f"'{eng_sentence}' -> '{predicted_tam}' (expected: '{expected_tam}')")

        # Lenient check: a hit if any expected word appears as a whole word in
        # the prediction (substring matching on the string gave false hits).
        predicted_words = set(predicted_tam.split())
        if any(word in predicted_words for word in expected_tam.split()):
            correct += 1

    accuracy = (correct / len(examples)) * 100
    print(f"Level {level} Accuracy: {accuracy:.1f}%")

    return accuracy >= 50, model, vocab, reverse_vocab, max_len

# Super simple one-liner approach
def create_minimal_transformer(vocab_size):
    """Build an untrained single-block encoder-decoder with the functional API.

    It has one attention sublayer in the encoder and two (causal self-attention
    and cross-attention) in the decoder, each followed by residual +
    LayerNormalization. There are no positional embeddings and no
    feed-forward sublayers.

    Returns:
        A ``Model`` mapping ``[encoder_ids, decoder_ids]`` to per-position
        logits over ``vocab_size`` tokens.
    """
    width, heads, head_dim = 64, 4, 16

    # Encoder: embedding -> self-attention -> residual + norm.
    source = layers.Input(shape=(None,))
    source_emb = layers.Embedding(vocab_size, width, mask_zero=True)(source)
    source_attn = layers.MultiHeadAttention(num_heads=heads, key_dim=head_dim)(source_emb, source_emb)
    memory = layers.LayerNormalization()(source_attn + source_emb)

    # Decoder: embedding -> causal self-attention -> cross-attention to memory.
    target = layers.Input(shape=(None,))
    target_emb = layers.Embedding(vocab_size, width, mask_zero=True)(target)
    self_attn = layers.MultiHeadAttention(num_heads=heads, key_dim=head_dim)(
        target_emb, target_emb, use_causal_mask=True)
    hidden = layers.LayerNormalization()(self_attn + target_emb)
    cross_attn = layers.MultiHeadAttention(num_heads=heads, key_dim=head_dim)(hidden, memory)
    hidden = layers.LayerNormalization()(cross_attn + hidden)

    logits = layers.Dense(vocab_size)(hidden)
    return Model([source, target], logits)

# Run the simplified version
def run_simple_training():
    """Train curriculum levels 1-3 in order, then build and summarise the minimal model.

    Each level's pass/fail status is printed; the trained models are discarded.
    """
    print("=== Simplified Transformer with Keras Built-ins ===\n")

    for level in (1, 2, 3):
        passed, _model, _vocab, _reverse_vocab, _max_len = train_simple_level(level)
        status = f"✅ Level {level} passed!" if passed else f"❌ Level {level} needs work"
        print(status)
        print("-" * 50)

    print("\n🎯 Simplified training complete!")

    # Show the untrained minimal functional-API variant for comparison.
    print("\n=== Minimal Transformer (One-liner style) ===")
    minimal = create_minimal_transformer(vocab_size=100)
    print(f"Minimal model created with {minimal.count_params():,} parameters")
    minimal.summary()

# Run the full curriculum only when executed as a script, not on import.
if __name__ == "__main__":
    run_simple_training()

=== Simplified Transformer with Keras Built-ins ===


=== Training Level 1 (Keras Built-in) ===
Level 1: 5 examples, vocab size: 14
Training with 50 examples...
Epoch 1/50
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m22s[0m 422ms/step - loss: 1.8219 - masked_accuracy: 0.3401 - val_loss: 0.8744 - val_masked_accuracy: 0.8125
Epoch 2/50
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step - loss: 0.8591 - masked_accuracy: 0.7155 - val_loss: 0.7777 - val_masked_accuracy: 0.6250
Epoch 3/50
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 90ms/step - loss: 0.7275 - masked_accuracy: 0.6977 - val_loss: 0.5015 - val_masked_accuracy: 1.0000
Epoch 4/50
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 63ms/step - loss: 0.4735 - masked_accuracy: 0.9483 - val_loss: 0.2002 - val_masked_accuracy: 1.0000
Epoch 5/50
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 62ms/step - loss: 0.1682 - masked_accuracy: 0.9845 - val_loss: 0.0450 - 