In [None]:
!pip install transformers datasets

In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import os
import json
from PIL import Image
from transformers import AutoTokenizer, TFGemmaForCausalLM, GemmaConfig

ImportError: cannot import name 'descriptor' from 'google.protobuf' (unknown location)

In [None]:
# --- Configuration ---
# Hugging Face model id (use "google/gemma-3n-E4B" for the larger variant).
MODEL_NAME = "google/gemma-3n-E2B"

# Input resolution. Gemma 3n material lists 256x256, 512x512 and 768x768 image
# inputs (verify against Google's official Gemma 3n docs); the smallest size is
# the cheapest to run on device.
IMG_HEIGHT, IMG_WIDTH = 256, 256

# Training hyper-parameters.
BATCH_SIZE = 32
EPOCHS = 10  # adjust for your dataset size / convergence
# NOTE(review): 2e-5 is on the low side for training a freshly initialised head
# on a frozen backbone; consider a larger value if accuracy improves slowly.
LEARNING_RATE = 2e-5


In [None]:


# Path to your dataset (e.g., downloaded from Kaggle or collected yourself).
# Expected layout: DATA_DIR/<class_name>/<image>.jpg -- one sub-folder per class.
# Set the WASTE_DATA_DIR environment variable to point at it without editing
# this cell; otherwise the placeholder below is used.
DATA_DIR = os.environ.get('WASTE_DATA_DIR', 'path/to/your/waste_dataset')
OUTPUT_DIR = 'model_output'
TFLITE_MODEL_PATH = os.path.join(OUTPUT_DIR, 'model.tflite')
LABELS_PATH = os.path.join(OUTPUT_DIR, 'labels.txt')  # class names, one per line

# Create the output folder up front so later cells can write into it.
os.makedirs(OUTPUT_DIR, exist_ok=True)


In [None]:
# --- 1. Data Preparation ---

# Load images with tf.keras.utils.image_dataset_from_directory, which infers
# labels from the class sub-folder names (data_dir/<class_name>/<image>).
def load_dataset_from_directory(data_dir, img_height, img_width, batch_size,
                                validation_split=0.2, seed=123, labels_path=None):
    """Load train/validation splits from a class-per-folder image directory.

    Also writes the detected class names to ``labels_path`` (UTF-8, one name
    per line) for the mobile app.

    Args:
        data_dir: Root folder containing one sub-folder per class.
        img_height: Height every image is resized to.
        img_width: Width every image is resized to.
        batch_size: Number of images per batch.
        validation_split: Fraction of images held out for validation.
        seed: Shuffle/split seed. Both subsets are loaded with the same seed so
            the training and validation files do not overlap.
        labels_path: Output file for the class names. Defaults to the
            notebook-level ``LABELS_PATH``.

    Returns:
        ``(train_ds, val_ds, class_names)``: two batched ``tf.data.Dataset``
        objects yielding ``(image, int_label)`` pairs, and the list of class
        names whose order defines the integer labels.
    """
    print(f"Loading dataset from: {data_dir}")
    if labels_path is None:
        labels_path = LABELS_PATH

    def _load_subset(subset):
        # Identical split fraction and seed for both calls keep the subsets disjoint.
        return tf.keras.utils.image_dataset_from_directory(
            data_dir,
            validation_split=validation_split,
            subset=subset,
            seed=seed,
            image_size=(img_height, img_width),
            batch_size=batch_size,
            label_mode='int',  # integer labels, paired with SparseCategoricalCrossentropy
        )

    train_ds = _load_subset("training")
    val_ds = _load_subset("validation")

    class_names = train_ds.class_names
    print(f"Detected class names: {class_names}")

    # Save labels for the Flutter app: line i is the class for label index i.
    # Explicit UTF-8 so non-ASCII folder names are written identically on every OS.
    with open(labels_path, 'w', encoding='utf-8') as f:
        f.writelines(f"{name}\n" for name in class_names)

    return train_ds, val_ds, class_names


In [None]:

train_ds, val_ds, class_names = load_dataset_from_directory(
    DATA_DIR, IMG_HEIGHT, IMG_WIDTH, BATCH_SIZE
)
NUM_CLASSES = len(class_names)
# A softmax classifier over a single class trivially reaches 100% accuracy;
# fail fast if DATA_DIR does not contain at least two class sub-folders.
assert NUM_CLASSES >= 2, f"Expected at least 2 class folders in {DATA_DIR!r}, found {class_names}"


In [None]:

# Normalize pixel values to the [0, 1] range. The exported TFLite model is
# trained on (and therefore expects) pixels in this range.
# NOTE: run this cell once after loading the data; re-running it would rescale
# the already-normalized datasets a second time.
def preprocess_image(image, label):
    """Cast an image batch to float32 and divide by 255; the label is returned unchanged."""
    image = tf.cast(image, tf.float32) / 255.0
    return image, label


AUTOTUNE = tf.data.AUTOTUNE

# cache() replays the element order of the first epoch, so the training set is
# shuffled *after* caching to get a fresh batch order every epoch (the pattern
# used in the TensorFlow image-classification tutorial). Validation order does
# not affect the metrics, so it is not reshuffled.
train_ds = (
    train_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
    .cache()
    .shuffle(1000, seed=123)
    .prefetch(buffer_size=AUTOTUNE)
)
val_ds = (
    val_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
    .cache()
    .prefetch(buffer_size=AUTOTUNE)
)


In [None]:

# --- 2. Build the classifier: pre-trained visual backbone + custom head ---

# No Gemma 3n weights are loaded in this cell. MobileNetV2 with ImageNet
# weights stands in as the visual encoder; replace `base_model` with Gemma 3n's
# vision encoder if an official Keras/TensorFlow version becomes available.

base_model = tf.keras.applications.MobileNetV2(
    input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
    include_top=False,   # drop MobileNetV2's own ImageNet classification head
    weights='imagenet',  # start from pre-trained ImageNet features
)
base_model.trainable = False  # freeze the backbone; only the new head is trained

# The input pipeline feeds pixels scaled to [0, 1], and the exported TFLite model
# keeps that input contract. MobileNetV2's ImageNet weights expect inputs in
# [-1, 1] (see mobilenet_v2.preprocess_input), so convert inside the model:
# x * 2.0 - 1.0 maps [0, 1] onto [-1, 1].
inputs = keras.Input(shape=(IMG_HEIGHT, IMG_WIDTH, 3))
x = layers.Rescaling(scale=2.0, offset=-1.0)(inputs)
x = base_model(x, training=False)  # keep BatchNorm layers in inference mode
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.2)(x)  # regularize the new head
outputs = layers.Dense(NUM_CLASSES, activation='softmax')(x)  # classification head

model = keras.Model(inputs, outputs)

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),  # integer labels (label_mode='int')
    metrics=['accuracy'],
)

model.summary()

# --- 3. Train the Model ---
# The backbone was frozen when the model was built (base_model.trainable = False),
# so fit() updates only the new Dropout/Dense head. Unfreeze the backbone
# afterwards for true fine-tuning.

print("\nStarting model training...")
history = model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS)
print("Model training complete!")

# Persist the trained Keras model (HDF5 format) next to the other artifacts.
keras_model_path = os.path.join(OUTPUT_DIR, 'waste_classifier_model.h5')
model.save(keras_model_path)
print(f"Keras model saved to {keras_model_path}")


In [None]:
# --- 4. Convert to TensorFlow Lite ---

# Post-training dynamic-range quantization stores weights as int8 (roughly 4x
# smaller) while the model's inputs and outputs stay float32, so the app still
# feeds float pixels in [0, 1]. Set to False to export an un-quantized model.
QUANTIZE_WEIGHTS = True

print("\nConverting model to TensorFlow Lite...")
converter = tf.lite.TFLiteConverter.from_keras_model(model)

if QUANTIZE_WEIGHTS:
    # Note: target_spec.supported_types sets the smallest type weights may be
    # quantized to; listing only tf.float32 there would disable this quantization.
    # Input/output dtypes are controlled by inference_input_type /
    # inference_output_type, which default to tf.float32.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

tflite_model = converter.convert()

with open(TFLITE_MODEL_PATH, 'wb') as f:
    f.write(tflite_model)

print(f"TensorFlow Lite model saved to {TFLITE_MODEL_PATH}")

# --- 5. Verify the TFLite Model ---
print("\nVerifying TFLite model...")
interpreter = tf.lite.Interpreter(model_path=TFLITE_MODEL_PATH)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

print("Input details:", input_details)
print("Output details:", output_details)

# Check the contract the on-device app relies on: a single float32 image in,
# a float32 probability vector over NUM_CLASSES out.
assert input_details[0]['dtype'] == np.float32, input_details[0]['dtype']
assert tuple(input_details[0]['shape']) == (1, IMG_HEIGHT, IMG_WIDTH, 3), input_details[0]['shape']
assert output_details[0]['dtype'] == np.float32, output_details[0]['dtype']
assert tuple(output_details[0]['shape']) == (1, NUM_CLASSES), output_details[0]['shape']


In [None]:

# Smoke-test the TFLite model on a random image in the same [0, 1] range the
# training pipeline produces. The seeded RNG keeps the printed values
# reproducible across runs.
smoke_rng = np.random.default_rng(0)
dummy_image = smoke_rng.random((1, IMG_HEIGHT, IMG_WIDTH, 3), dtype=np.float32)

interpreter.set_tensor(input_details[0]['index'], dummy_image)
interpreter.invoke()
tflite_output = interpreter.get_tensor(output_details[0]['index'])

# Reference prediction from the in-memory Keras model. Outputs may differ
# slightly when the TFLite weights are quantized.
keras_output = model.predict(dummy_image, verbose=0)
assert tflite_output.shape == keras_output.shape == (1, NUM_CLASSES), (tflite_output.shape, keras_output.shape)

print("TFLite output shape:", tflite_output.shape)
print("TFLite output (first 5 values):", tflite_output[0, :5])
print("Max |TFLite - Keras| difference:", float(np.max(np.abs(tflite_output - keras_output))))

print("\nBackend model setup complete!")