In [1]:
# ===================================================================
# VISION TRANSFORMER WITH TRANSFER LEARNING (RECOMMENDED)
# Fine-tuning pre-trained weights usually gives much better accuracy
# than training from scratch on a small dataset such as this one.
# ===================================================================

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_hub as hub

# Set seeds
# One shared value seeds the Keras, TensorFlow and NumPy random number
# generators; it is also passed to the training generator as its shuffle seed.
SEED = 42
# NOTE(review): keras.utils.set_random_seed is documented as seeding Python's
# `random`, NumPy and the backend in a single call, so the two explicit calls
# below look redundant (but harmless) -- confirm before removing them.
keras.utils.set_random_seed(SEED)
tf.random.set_seed(SEED)
np.random.seed(SEED)

# ===================================================================
# CONFIGURATION
# ===================================================================

NUM_CLASSES = 4  # units in the final softmax Dense layer of every model builder
INPUT_SHAPE = (224, 224, 3)  # shape given to layers.Input in every model builder
BATCH_SIZE = 32  # batch_size for both the training and validation generators
EPOCHS = 50  # epochs passed to fit(); EarlyStopping may end training sooner

# Dataset paths
# Root directories handed to flow_from_directory for the two splits.
train_dir = '/kaggle/input/type-of-plastic-waste-dataset/train'
val_dir = '/kaggle/input/type-of-plastic-waste-dataset/val'

# ===================================================================
# DATA LOADING WITH AUGMENTATION
# ===================================================================

# Training images are rescaled to [0, 1] and randomly perturbed on the fly;
# validation images are only rescaled so the metrics stay comparable.
train_datagen = keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

val_datagen = keras.preprocessing.image.ImageDataGenerator(rescale=1./255)

# Derive the resize target from INPUT_SHAPE so the generators and the model
# input cannot drift apart when the configuration is edited.
target_size = INPUT_SHAPE[:2]

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=target_size,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=True,
    seed=SEED
)

# shuffle=False keeps the validation order fixed across epochs.
val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=target_size,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False
)

print("Class indices:", train_generator.class_indices)

# Fail early with a readable message instead of training against mislabelled
# validation data or hitting a shape error deep inside fit().
if train_generator.class_indices != val_generator.class_indices:
    raise ValueError(
        "Train/validation class folders differ: "
        f"{train_generator.class_indices} vs {val_generator.class_indices}"
    )
if len(train_generator.class_indices) != NUM_CLASSES:
    raise ValueError(
        f"NUM_CLASSES is {NUM_CLASSES} but the dataset has "
        f"{len(train_generator.class_indices)} classes"
    )

# ===================================================================
# BUILD VIT MODEL WITH PRE-TRAINED WEIGHTS
# ===================================================================

def create_pretrained_vit():
    """Build a classifier on top of the TF Hub ViT-B/16 feature extractor.

    The backbone is loaded with ``trainable=True``, so it is fine-tuned
    together with the new Dense/Dropout head rather than frozen.

    Inputs are expected in the [0, 1] range produced by the file's
    ``rescale=1./255`` generators and are remapped to [-1, 1] inside the
    model before reaching the backbone.

    Returns:
        keras.Model: uncompiled model taking ``INPUT_SHAPE`` images and
        producing ``NUM_CLASSES`` softmax probabilities.

    Raises:
        Exception: whatever ``hub.KerasLayer`` raises when the module cannot
        be downloaded or wrapped; the caller in this file uses that as the
        signal to fall back to another builder.
    """
    # Input layer
    inputs = layers.Input(shape=INPUT_SHAPE)

    # NOTE(review): the model card for this checkpoint states that pixels
    # should be normalised to [-1, 1]; confirm against the card. The data
    # pipeline delivers [0, 1], so x * 2 - 1 bridges the two ranges.
    normalized = layers.Rescaling(scale=2.0, offset=-1.0)(inputs)

    # Load pre-trained ViT from TensorFlow Hub
    # Using ViT-B16 (Base model, 16x16 patches)
    vit_url = "https://tfhub.dev/sayakpaul/vit_b16_fe/1"
    vit_backbone = hub.KerasLayer(vit_url, trainable=True)

    # Extract features
    features = vit_backbone(normalized)

    # Classification head: two hidden Dense layers with dropout between them.
    x = layers.Dropout(0.3)(features)
    x = layers.Dense(512, activation='relu')(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.Dropout(0.2)(x)
    outputs = layers.Dense(NUM_CLASSES, activation='softmax')(x)

    model = keras.Model(inputs=inputs, outputs=outputs)

    return model


# Alternative: Use Keras Applications ViT (if tensorflow_hub doesn't work)
def create_vit_keras_applications():
    """Build a classifier around ``keras.applications.ViTB16``.

    All backbone layers except the last 20 are frozen, and the backbone is
    called with ``training=False``. A Dropout/Dense head ending in a
    ``NUM_CLASSES``-way softmax is stacked on the pooled features.

    NOTE(review): ``ViTB16`` does not appear in the documented
    ``keras.applications`` model list, so the import below may always raise
    ``ImportError`` -- confirm, or replace with a backbone that exists.

    Returns:
        keras.Model: uncompiled model taking ``INPUT_SHAPE`` images.
    """
    from tensorflow.keras.applications import ViTB16

    # ImageNet-initialised backbone with average-pooled output.
    backbone = ViTB16(
        weights='imagenet',
        include_top=False,
        pooling='avg',
        input_shape=INPUT_SHAPE,
    )

    # Only the final 20 layers stay trainable.
    frozen_layers = backbone.layers[:-20]
    for frozen in frozen_layers:
        frozen.trainable = False

    image_input = layers.Input(shape=INPUT_SHAPE)
    hidden = backbone(image_input, training=False)
    hidden = layers.Dropout(0.3)(hidden)

    # Each head stage is Dense(relu) followed by its own dropout rate.
    for width, drop_rate in ((512, 0.3), (256, 0.2)):
        hidden = layers.Dense(width, activation='relu')(hidden)
        hidden = layers.Dropout(drop_rate)(hidden)

    probabilities = layers.Dense(NUM_CLASSES, activation='softmax')(hidden)

    return keras.Model(inputs=image_input, outputs=probabilities)


# ===================================================================
# SIMPLE BUT EFFECTIVE ALTERNATIVE: USE A LIGHTER ARCHITECTURE
# ===================================================================

def create_efficient_vit_from_scratch():
    """Build a small Vision Transformer with randomly initialised weights.

    Configuration: 16x16 patches over a 224-pixel image (196 tokens),
    192-dimensional embeddings, 3 attention heads, 4 pre-norm encoder blocks
    and a [512, 256] GELU head. Tokens are mean-pooled; there is no class
    token. The image size is fixed locally at 224 rather than read from
    ``INPUT_SHAPE``.

    NOTE(review): flip/rotation/zoom layers are applied inside the model, and
    the file's training generator also augments, so training images are
    augmented twice -- confirm this is intended.

    Returns:
        keras.Model: uncompiled model taking ``INPUT_SHAPE`` images and
        producing ``NUM_CLASSES`` softmax probabilities.
    """
    image_size = 224
    patch_size = 16
    num_patches = (image_size // patch_size) ** 2
    embed_dim = 192
    num_heads = 3
    depth = 4
    head_units = [512, 256]

    def encoder_block(tokens):
        # Self-attention sub-layer with a residual connection.
        normed = layers.LayerNormalization(epsilon=1e-6)(tokens)
        attended = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embed_dim // num_heads,
            dropout=0.1,
        )(normed, normed)
        after_attention = layers.Add()([attended, tokens])

        # Feed-forward sub-layer (2x expansion) with a residual connection.
        hidden = layers.LayerNormalization(epsilon=1e-6)(after_attention)
        hidden = layers.Dense(embed_dim * 2, activation=tf.nn.gelu)(hidden)
        hidden = layers.Dropout(0.1)(hidden)
        hidden = layers.Dense(embed_dim)(hidden)
        hidden = layers.Dropout(0.1)(hidden)
        return layers.Add()([hidden, after_attention])

    image_input = layers.Input(shape=INPUT_SHAPE)

    # Random augmentation built into the model graph.
    augmenter = keras.Sequential([
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.1),
    ])
    augmented = augmenter(image_input)

    # A strided convolution cuts the image into patches and projects them.
    patch_grid = layers.Conv2D(
        embed_dim, kernel_size=patch_size, strides=patch_size
    )(augmented)
    tokens = layers.Reshape((num_patches, embed_dim))(patch_grid)

    # Learned position embedding, one vector per patch index.
    position_ids = tf.range(start=0, limit=num_patches, delta=1)
    position_table = layers.Embedding(
        input_dim=num_patches, output_dim=embed_dim
    )
    tokens = tokens + position_table(position_ids)

    for _ in range(depth):
        tokens = encoder_block(tokens)

    # Normalise, mean-pool over tokens, then run the MLP head.
    pooled = layers.LayerNormalization(epsilon=1e-6)(tokens)
    pooled = layers.GlobalAveragePooling1D()(pooled)
    pooled = layers.Dropout(0.3)(pooled)

    for width in head_units:
        pooled = layers.Dense(width, activation=tf.nn.gelu)(pooled)
        pooled = layers.Dropout(0.3)(pooled)

    probabilities = layers.Dense(NUM_CLASSES, activation='softmax')(pooled)

    return keras.Model(inputs=image_input, outputs=probabilities)


# ===================================================================
# BUILD AND COMPILE MODEL
# ===================================================================

print("Building Vision Transformer model with transfer learning...")

# Pre-trained builders in order of preference (TF Hub first, RECOMMENDED).
# A failure is reported rather than silently swallowed, so it is visible
# why a lower-priority model ended up being trained.
pretrained_builders = (
    (create_pretrained_vit, "✓ Using pre-trained ViT from TensorFlow Hub"),
    (create_vit_keras_applications, "✓ Using pre-trained ViT from Keras Applications"),
)

for build_fn, success_message in pretrained_builders:
    try:
        vit_model = build_fn()
    except Exception as exc:
        # `Exception`, not a bare except: Ctrl-C / SystemExit must still abort.
        print(f"⚠ {build_fn.__name__} failed ({type(exc).__name__}: {exc})")
        continue
    print(success_message)
    break
else:
    # No pre-trained backbone could be built; errors from this last builder
    # propagate because there is nothing left to fall back to.
    vit_model = create_efficient_vit_from_scratch()
    print("✓ Using lightweight ViT trained from scratch")

vit_model.summary()

# Compile with appropriate learning rate for fine-tuning
# NOTE(review): the same 1e-4 rate is also used when the from-scratch
# fallback is selected -- confirm that is intended.
vit_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy', keras.metrics.TopKCategoricalAccuracy(k=2, name='top-2-accuracy')]
)

# ===================================================================
# CALLBACKS
# ===================================================================

# Training hooks, registered one at a time in the order Keras will run them.
callbacks = []

# Save the model whenever validation accuracy reaches a new maximum.
callbacks.append(
    keras.callbacks.ModelCheckpoint(
        'vit_best_model.keras',
        monitor='val_accuracy',
        mode='max',
        save_best_only=True,
        verbose=1,
    )
)

# Stop after 15 epochs without a val_accuracy improvement and restore the
# best weights seen.
callbacks.append(
    keras.callbacks.EarlyStopping(
        monitor='val_accuracy',
        patience=15,
        restore_best_weights=True,
        verbose=1,
    )
)

# Halve the learning rate after 5 epochs without a val_loss improvement,
# never going below 1e-7.
callbacks.append(
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        patience=5,
        factor=0.5,
        min_lr=1e-7,
        verbose=1,
    )
)

# Write per-epoch metrics to a CSV file.
callbacks.append(keras.callbacks.CSVLogger('vit_training_log.csv'))

# ===================================================================
# TRAIN THE MODEL
# ===================================================================

# Fit on the training generator, validating on the validation generator
# after every epoch; `history` keeps the object returned by fit().
print("\nStarting training...")
history = vit_model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=EPOCHS,
    callbacks=callbacks,
    verbose=1,
)

# ===================================================================
# EVALUATE
# ===================================================================

print("\nEvaluating on validation set...")
val_results = vit_model.evaluate(val_generator, verbose=1)

# Index 0 of the result is skipped; indices 1 and 2 are reported as accuracy
# and top-2 accuracy, matching the order of `metrics` passed to compile().
val_accuracy = val_results[1]
val_top2_accuracy = val_results[2]

separator = '=' * 60
print(f"\n{separator}")
print("Final Validation Results:")
print(separator)
print(f"Validation Accuracy: {val_accuracy*100:.2f}%")
print(f"Validation Top-2 Accuracy: {val_top2_accuracy*100:.2f}%")
print(separator)

# Save the full model and, separately, a weights-only file.
vit_model.save('vit_plastic_classifier_final.keras')
vit_model.save_weights('vit_weights.weights.h5')

print("\n✅ Vision Transformer training complete!")

2025-10-20 12:52:58.608448: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1760964778.796449      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1760964778.843704      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Found 11252 images belonging to 4 classes.
Found 2812 images belonging to 4 classes.
Class indices: {'HDPE': 0, 'PET': 1, 'PP': 2, 'PS': 3}
Building Vision Transformer model with transfer learning...


I0000 00:00:1760964800.105102      19 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13942 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5
I0000 00:00:1760964800.105750      19 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13942 MB memory:  -> device: 1, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5


✓ Using lightweight ViT trained from scratch



Starting training...


  self._warn_if_super_not_called()


Epoch 1/50


I0000 00:00:1760964825.515053      77 cuda_dnn.cc:529] Loaded cuDNN version 90300


352/352 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.2799 - loss: 1.4253 - top-2-accuracy: 0.5387
Epoch 1: val_accuracy improved from -inf to 0.27632, saving model to vit_best_model.keras
352/352 ━━━━━━━━━━━━━━━━━━━━ 448s 1s/step - accuracy: 0.2799 - loss: 1.4253 - top-2-accuracy: 0.5388 - val_accuracy: 0.2763 - val_loss: 1.3734 - val_top-2-accuracy: 0.5466 - learning_rate: 1.0000e-04
Epoch 2/50
352/352 ━━━━━━━━━━━━━━━━━━━━ 0s 723ms/step - accuracy: 0.3058 - loss: 1.3659 - top-2-accuracy: 0.5911
Epoch 2: val_accuracy improved from 0.27632 to 0.39509, saving model to vit_best_model.keras
352/352 ━━━━━━━━━━━━━━━━━━━━ 294s 836ms/step - accuracy: 0.3059 - loss: 1.3659 - top-2-accuracy: 0.5911 - val_accuracy: 0.3951 - val_loss: 1.2821 - val_top-2-accuracy: 0.6753 - learning_rate: 1.0000e-04
Epoch 3/50
352/352 ━━━━━━━━━━━━━━━━━━━━