In [None]:
# --- 0. Install Required Libraries ---
# The leading `!` runs a shell command from a Jupyter/IPython cell; this line
# is notebook-only syntax and is not valid in a plain .py script.
!pip install tensorflow numpy matplotlib

# --- 1. Import Libraries ---
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt

# --- 2. Define Paths and Parameters ---
# Root folder read by flow_from_directory (one sub-directory per class).
# Update this to your dataset path.
DATASET_PATH = "/content/dataset"

# Every image is resized to this (height, width) before batching.
IMG_SIZE = (224, 224)
BATCH_SIZE = 32

# Expected class folders: Paper, Plastic, Metal, Glass, Trash, Cardboard.
NUM_CLASSES = 6

# --- 3. Data Loading & Train/Validation Split ---
# NOTE: no augmentation is applied here -- images are only rescaled from
# [0, 255] to [0, 1]. Both subsets come from this same generator, so any
# augmentation added to it would also be applied to validation images.
datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2  # reserve 20% of the images for subset='validation'
)

# Settings shared by both subsets; only `subset` differs between them.
flow_kwargs = dict(
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
)

train_gen = datagen.flow_from_directory(DATASET_PATH, subset='training', **flow_kwargs)
val_gen = datagen.flow_from_directory(DATASET_PATH, subset='validation', **flow_kwargs)

# Fail fast with a clear message when the folder layout does not match the
# classifier head sized by NUM_CLASSES (this also catches a wrong or empty
# DATASET_PATH, which yields 0 classes). Otherwise fit() would fail later
# with an opaque shape-mismatch error.
if train_gen.num_classes != NUM_CLASSES:
    raise ValueError(
        f"Expected {NUM_CLASSES} class sub-directories in {DATASET_PATH!r}, "
        f"found {train_gen.num_classes}: {sorted(train_gen.class_indices)}"
    )

# --- 4. Build MobileNetV2 Model ---
# ImageNet-pretrained MobileNetV2 feature extractor without its classifier.
# The input shape is derived from IMG_SIZE so the model and the data pipeline
# cannot drift apart.
base_model = tf.keras.applications.MobileNetV2(
    input_shape=IMG_SIZE + (3,),
    include_top=False,
    weights='imagenet'
)
base_model.trainable = False  # Freeze base layers; only the new head is trained

model = tf.keras.Sequential([
    tf.keras.Input(shape=IMG_SIZE + (3,)),
    # The data pipeline (rescale=1./255) yields pixels in [0, 1], but
    # MobileNetV2's ImageNet weights expect [-1, 1] (what
    # mobilenet_v2.preprocess_input produces). Mapping x -> 2x - 1 inside the
    # model keeps the exported models' input contract ([0, 1] floats)
    # unchanged while feeding the backbone the range it was trained on.
    tf.keras.layers.Rescaling(2.0, offset=-1.0),
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
])

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# --- 5. Train the Model ---
# Fit for a fixed 5 epochs, scoring on the validation subset after each one.
# history.history maps each metric name (e.g. 'accuracy', 'val_accuracy')
# to its list of per-epoch values.
history = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=5,
)

# --- 6. Plot Training Accuracy ---
# Keras numbers epochs from 1 in its training log ("Epoch 1/5"), so plot
# against 1..N instead of the implicit 0-based list index to keep the
# "Epoch" axis consistent with that log.
epochs_range = range(1, len(history.history['accuracy']) + 1)

plt.figure()  # start a fresh figure so earlier plots in the session don't overlap
plt.plot(epochs_range, history.history['accuracy'], label='Train Accuracy')
plt.plot(epochs_range, history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Model Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.xticks(epochs_range)  # one tick per whole epoch
plt.legend()
plt.show()

# --- 7. Save Model ---
# Native Keras format (.keras): the format Keras recommends for saving
# whole models.
model.save("model.keras")
# Legacy HDF5 copy kept under the original filename so existing consumers
# that load model.h5 keep working.
model.save("model.h5")

# --- 8. Convert to TensorFlow Lite ---
# Build a converter from the in-memory Keras model and serialize it to the
# TFLite flatbuffer (raw bytes).
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Binary mode: the converted model is bytes, not text.
with open('recycle_classifier.tflite', mode='wb') as tflite_file:
    tflite_file.write(tflite_model)

print("✅ TFLite model saved as 'recycle_classifier.tflite'")

# --- 9. Test Inference with TFLite Interpreter ---
# Load the exported .tflite file and run one validation image through it as
# a quick sanity check that the converted model produces sensible output.
interpreter = tf.lite.Interpreter(model_path="recycle_classifier.tflite")
interpreter.allocate_tensors()

# get_input_details()/get_output_details() return a *list* of dicts, one per
# tensor. This model has a single input and a single output, so use entry [0];
# indexing the list directly with 'dtype'/'index' raises TypeError.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_info = input_details[0]
output_info = output_details[0]

# Get a sample image and its one-hot label (class_mode='categorical') from the
# validation set.
val_images, val_labels = next(val_gen)
test_image = val_images[0:1]  # slice (not index) to keep the batch axis
true_class = val_labels[0].argmax()
# Resize to IMG_SIZE and cast to the dtype the interpreter expects.
test_image = tf.image.resize(test_image, IMG_SIZE).numpy().astype(input_info['dtype'])

interpreter.set_tensor(input_info['index'], test_image)
interpreter.invoke()

output_data = interpreter.get_tensor(output_info['index'])
predicted_class = output_data.argmax()  # single row, so flat argmax == class index
# Order class names by their integer index so position i labels output unit i.
class_indices = train_gen.class_indices
class_labels = sorted(class_indices, key=class_indices.get)

print(f"\n🔍 Predicted Class: {class_labels[predicted_class]}")
print(f"True Class: {class_labels[true_class]}")
print(f"Raw Output: {output_data}")