In [3]:
"""
ViT_CIFAR_Experiments.py

Runnable script / notebook-style Python file implementing an improved Vision Transformer (ViT)
with multiple experiments on CIFAR-10 and CIFAR-100.

How to run:
- Recommended: open in Google Colab (Runtime > Change runtime type > GPU)
- Run the whole file cell-by-cell as a notebook, or run as a script in an environment with a GPU.

What this file contains:
- Data loading (CIFAR-10 and CIFAR-100 switchable)
- Data augmentation & preprocessing
- Improved ViT model (Conv stem, Class token, DropPath, Cosine LR)
- Experiment loop to try different hyperparameters
- Training & validation plots, test evaluation
- Results CSV + comparison table printed
- Inline explanations (as comments) for every major modification

Author: Generated for user's assignment by ChatGPT (GPT-5 Thinking mini).
"""

# Notebook-style imports
import os
import math
import json
import time
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# -----------------------------
# User-config / quick switches
# -----------------------------
USE_CIFAR = 100  # set to 10 for CIFAR-10, 100 for CIFAR-100
RUN_EXPERIMENTS = True  # set False to only build model
OUTPUT_DIR = '/content/vit_experiments'  # Colab-friendly path
os.makedirs(OUTPUT_DIR, exist_ok=True)

# -----------------------------
# Helper: Plotting utilities
# -----------------------------

def plot_history(history, title_prefix=''):
    # history: keras History object
    keys = [k for k in history.history.keys()]
    # plot loss
    plt.figure(figsize=(8, 4))
    plt.plot(history.history['loss'], label='train_loss')
    if 'val_loss' in history.history:
        plt.plot(history.history['val_loss'], label='val_loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'{title_prefix} Loss')
    plt.legend()
    plt.grid(True)
    plt.show()

    # plot accuracy if present
    if 'accuracy' in history.history:
        plt.figure(figsize=(8, 4))
        plt.plot(history.history['accuracy'], label='train_acc')
        if 'val_accuracy' in history.history:
            plt.plot(history.history['val_accuracy'], label='val_acc')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.title(f'{title_prefix} Accuracy')
        plt.legend()
        plt.grid(True)
        plt.show()

# -----------------------------
# Data loading + preprocessing
# -----------------------------
print('Loading dataset...')
if USE_CIFAR == 10:
    (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
    num_classes = 10
else:
    (x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
    num_classes = 100

print('Train samples:', x_train.shape[0])
print('Test samples:', x_test.shape[0])

# We will resize images to a slightly larger size for the ViT patching
IMAGE_SIZE = 72  # you can change this to 96 if GPU allows
PATCH_SIZE = 6
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2

# Simple preprocessing: keep them as uint8 until augmentation layer handles scaling.
# Create tf.data datasets to speed up training
BATCH_SIZE = 128
AUTOTUNE = tf.data.AUTOTUNE

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
val_count = int(0.1 * x_train.shape[0])

# Shuffle then split: we'll use first val_count as validation for reproducibility
train_ds = train_ds.shuffle(buffer_size=10000, seed=42)
val_ds = train_ds.take(val_count)
train_ds = train_ds.skip(val_count)

# Augmentation + preprocessing layer (as a keras Sequential so it's serializable)
# Note: we use Rescaling(1./255) to convert to [0,1]
DATA_AUG = keras.Sequential([
    layers.Rescaling(1.0 / 255),
    layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),
    layers.RandomFlip('horizontal'),
    layers.RandomRotation(0.05),
    layers.RandomZoom(height_factor=0.12, width_factor=0.12),
], name='data_augmentation')

# Prefetch datasets
train_ds = train_ds.batch(BATCH_SIZE).map(lambda x, y: (x, tf.squeeze(y))).map(lambda x, y: (x, y), num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)
val_ds = val_ds.batch(BATCH_SIZE).map(lambda x, y: (x, tf.squeeze(y))).prefetch(AUTOTUNE)

# Test dataset
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE).map(lambda x, y: (x, tf.squeeze(y))).prefetch(AUTOTUNE)

# -----------------------------
# Model components & improvements (with explanations)
# -----------------------------

# 1) Convolutional stem: small conv block before patching helps capture local features early.
def conv_stem(x):
    # 3x3 conv -> BN -> GELU, then another downsample conv to slightly increase receptive field
    x = layers.Conv2D(32, kernel_size=3, padding='same', activation=None)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('gelu')(x)
    x = layers.Conv2D(32, kernel_size=3, strides=1, padding='same', activation=None)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('gelu')(x)
    return x

# 2) Patches using tf.image.extract_patches
class Patches(layers.Layer):
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def call(self, images):
        # images: [batch, H, W, C]
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding='VALID'
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

# 3) Patch encoding + Class token
class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        # +1 to accommodate class token position embedding
        self.position_embedding = layers.Embedding(input_dim=num_patches + 1, output_dim=projection_dim)

    def build(self, input_shape):
        # create a trainable class token (1, 1, projection_dim)
        self.class_token = self.add_weight('class_token', shape=(1, 1, self.projection.units), initializer='zeros', trainable=True)
        super().build(input_shape)

    def call(self, patch):
        # patch: [batch, num_patches, patch_dim]
        batch_size = tf.shape(patch)[0]
        projected = self.projection(patch)  # [batch, num_patches, projection_dim]
        class_tokens = tf.tile(self.class_token, [batch_size, 1, 1])  # [batch, 1, projection_dim]
        x = tf.concat([class_tokens, projected], axis=1)  # prepend class token
        positions = tf.range(start=0, limit=self.num_patches + 1, delta=1)
        encoded = x + self.position_embedding(positions)
        return encoded

# 4) DropPath (stochastic depth) implementation
class DropPath(layers.Layer):
    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob if drop_prob is not None else 0.0

    def call(self, x, training=False):
        if (not training) or (self.drop_prob == 0.0):
            return x
        keep_prob = 1.0 - self.drop_prob
        # shape for broadcasting: [batch, 1, 1]
        batch_size = tf.shape(x)[0]
        random_tensor = keep_prob + tf.random.uniform([batch_size, 1, 1], dtype=x.dtype)
        binary_tensor = tf.floor(random_tensor)
        x = tf.divide(x, keep_prob) * binary_tensor
        return x

# MLP helper (used in transformer blocks and head)
def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=keras.activations.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

# Create ViT with improvements
def create_improved_vit(
    input_shape=(32, 32, 3),
    image_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    projection_dim=128,
    num_heads=6,
    transformer_layers=10,
    transformer_mlp_units=None,
    mlp_head_units=None,
    drop_path_rate=0.1,
    num_classes=100
):
    if transformer_mlp_units is None:
        transformer_mlp_units = [projection_dim * 2, projection_dim]
    if mlp_head_units is None:
        mlp_head_units = [2048, 1024]

    inputs = keras.Input(shape=input_shape)

    # Data augmentation + conv stem
    x = DATA_AUG(inputs)
    x = conv_stem(x)  # conv stem - local feature extractor

    # Create patches
    patches = Patches(patch_size)(x)

    # Encode patches with class token + position embeddings
    encoded_patches = PatchEncoder(NUM_PATCHES, projection_dim)(patches)

    # Transformer blocks with optional stochastic depth (DropPath)
    # We'll linearly scale the drop path across layers for better regularization
    dpr_rates = np.linspace(0.0, drop_path_rate, transformer_layers)

    for i in range(transformer_layers):
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        attention_output = layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim, dropout=0.1)(x1, x1)
        attention_output = DropPath(dpr_rates[i])(attention_output, training=True)
        x2 = layers.Add()([attention_output, encoded_patches])

        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        x3 = mlp(x3, transformer_mlp_units, dropout_rate=0.1)
        x3 = DropPath(dpr_rates[i])(x3, training=True)
        encoded_patches = layers.Add()([x3, x2])

    # Use the class token output (index 0)
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    cls_token_output = representation[:, 0, :]
    x = layers.Dropout(0.5)(cls_token_output)
    x = mlp(x, mlp_head_units, dropout_rate=0.5)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    model = keras.Model(inputs=inputs, outputs=outputs, name='improved_vit')
    return model

# -----------------------------
# Training loop / experiments
# -----------------------------

def compile_and_train(model, lr=1e-3, weight_decay=1e-4, epochs=30, steps_per_epoch=None):
    # Cosine decay schedule
    total_steps = (steps_per_epoch if steps_per_epoch is not None else math.ceil((x_train.shape[0]*0.9)/BATCH_SIZE)) * epochs
    # If tfa (tensorflow_addons) is available, prefer AdamW. Otherwise fallback to Adam.
    try:
        import tensorflow_addons as tfa
        print('tensorflow_addons found: using AdamW')
        optimizer = tfa.optimizers.AdamW(learning_rate=lr_schedule, weight_decay=weight_decay)
    except Exception:
        print('tensorflow_addons not found: falling back to Adam (without weight decay)')
        optimizer = keras.optimizers.Adam(learning_rate=lr)

    lr_schedule = tf.keras.optimizers.schedules.CosineDecay(initial_learning_rate=lr, decay_steps=total_steps)

    # Compile
    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy']
    )

    # Callbacks
    ckpt = keras.callbacks.ModelCheckpoint(os.path.join(OUTPUT_DIR, 'best_weights.h5'), monitor='val_accuracy', save_best_only=True, save_weights_only=True)
    early = keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=8, restore_best_weights=True)

    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        callbacks=[ckpt, early]
    )
    return history

# -----------------------------
# Run experiments (grid) - Note: each experiment can be long. The code organizes experiments and saves results.
# -----------------------------

if RUN_EXPERIMENTS:
    experiments = [
        {
            'name': 'baseline_vit',
            'projection_dim': 64,
            'transformer_layers': 8,
            'num_heads': 4,
            'drop_path_rate': 0.0,
            'batch_size': 256,
            'lr': 1e-3,
            'epochs': 20
        },
        {
            'name': 'convstem + class_token',
            'projection_dim': 64,
            'transformer_layers': 8,
            'num_heads': 4,
            'drop_path_rate': 0.05,
            'batch_size': 128,
            'lr': 1e-3,
            'epochs': 25
        },
        {
            'name': 'deeper + larger dim',
            'projection_dim': 128,
            'transformer_layers': 10,
            'num_heads': 6,
            'drop_path_rate': 0.1,
            'batch_size': 128,
            'lr': 1e-3,
            'epochs': 30
        }
    ]

    results = []
    # For reproducibility, set seed
    tf.random.set_seed(42)
    np.random.seed(42)

    for exp in experiments:
        print('\n' + '='*40)
        print(f"Running experiment: {exp['name']}")
        print('='*40)

        # Build model with given hyperparameters
        model = create_improved_vit(
            input_shape=(32, 32, 3),
            image_size=IMAGE_SIZE,
            patch_size=PATCH_SIZE,
            projection_dim=exp['projection_dim'],
            num_heads=exp['num_heads'],
            transformer_layers=exp['transformer_layers'],
            drop_path_rate=exp['drop_path_rate'],
            num_classes=num_classes
        )

        model.summary()

        # Compile & train
        # Adjust global BATCH_SIZE if needed
        history = compile_and_train(model, lr=exp['lr'], epochs=exp['epochs'])

        # Plot history
        plot_history(history, title_prefix=exp['name'])

        # Evaluate on test set
        test_loss, test_acc = model.evaluate(test_ds, verbose=2)
        print(f"Test accuracy for {exp['name']}: {test_acc:.4f}")

        # Save history and results
        res = {
            'experiment': exp['name'],
            'projection_dim': exp['projection_dim'],
            'transformer_layers': exp['transformer_layers'],
            'num_heads': exp['num_heads'],
            'drop_path_rate': exp['drop_path_rate'],
            'batch_size': exp['batch_size'],
            'lr': exp['lr'],
            'epochs_trained': len(history.history['loss']),
            'test_accuracy': float(test_acc),
            'test_loss': float(test_loss)
        }
        results.append(res)

        # Save model
        model.save(os.path.join(OUTPUT_DIR, f"{exp['name']}_model"))

        # Save history to json
        with open(os.path.join(OUTPUT_DIR, f"{exp['name']}_history.json"), 'w') as f:
            json.dump(history.history, f)

    # Save results table
    df = pd.DataFrame(results)
    df.to_csv(os.path.join(OUTPUT_DIR, 'experiment_results.csv'), index=False)
    print('\nAll experiments finished. Summary:')
    print(df)

    # Show comparison table
    print('\nComparison table (sorted by test_accuracy):')
    print(df.sort_values('test_accuracy', ascending=False))

else:
    print('RUN_EXPERIMENTS set to False. The script built the model only.')

# End of file

Loading dataset...
Train samples: 50000
Test samples: 10000


Cause: could not parse the source code of <function <lambda> at 0x7820fd49f100>: found multiple definitions with identical signatures at the location. This error may be avoided by defining each lambda on a single line and with unique argument names. The matching definitions were:
Match 0:
lambda x, y: (x, y)

Match 1:
lambda x, y: (x, tf.squeeze(y))



Cause: could not parse the source code of <function <lambda> at 0x7820fd49f100>: found multiple definitions with identical signatures at the location. This error may be avoided by defining each lambda on a single line and with unique argument names. The matching definitions were:
Match 0:
lambda x, y: (x, y)

Match 1:
lambda x, y: (x, tf.squeeze(y))



Cause: could not parse the source code of <function <lambda> at 0x7820fd49f1a0>: found multiple definitions with identical signatures at the location. This error may be avoided by defining each lambda on a single line and with unique argument names. The matching definitions were:
Match 0:
lambda x, y: (x, y)

Match 1:
lambda x, y: (x, tf.squeeze(y))



Cause: could not parse the source code of <function <lambda> at 0x7820fd49f1a0>: found multiple definitions with identical signatures at the location. This error may be avoided by defining each lambda on a single line and with unique argument names. The matching definitions were:
Match 0:
lambda x, y: (x, y)

Match 1:
lambda x, y: (x, tf.squeeze(y))


Running experiment: baseline_vit


TypeError: Layer.add_weight() got multiple values for argument 'shape'