In [None]:
# ==============================
# Install packages (Colab)
# ==============================
!pip install -q tensorflow tensorflow-datasets streamlit pillow

# ==============================
# Imports
# ==============================
import tensorflow as tf
import tensorflow_datasets as tfds
from collections import Counter
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

# Colab: mount Google Drive. Not actually optional here — the resume
# checkpoint and ModelCheckpoint paths below live under /content/drive/MyDrive.
DRIVE_MOUNT_POINT = '/content/drive'
from google.colab import drive
drive.mount(DRIVE_MOUNT_POINT)

# ==============================
# Load PlantVillage dataset
# ==============================
# as_supervised=True makes each element an (image, label) pair; only the
# "train" split is used, and a validation subset is carved out of it later.
dataset, info = tfds.load("plant_village", with_info=True, as_supervised=True)
label_info = info.features['label']
train_ds_full = dataset['train']

IMG_SIZE = (128, 128)
BATCH_SIZE = 32
NUM_CLASSES = label_info.num_classes  # 38 classes

print(f"Number of classes: {NUM_CLASSES}")
print(f"Class names: {label_info.names}")

# ==============================
# Preprocessing
# ==============================
def preprocess(image, label):
    """Resize ``image`` to ``IMG_SIZE`` and divide its pixels by 255.

    Assumes incoming pixel values are in [0, 255] (TODO confirm the tfds
    image dtype), which puts the output in [0, 1]. ``label`` passes through
    unchanged.

    NOTE(review): Keras's ``mobilenet_v2.preprocess_input`` maps pixels to
    [-1, 1] rather than [0, 1]. Switching would change the input range every
    saved checkpoint was trained on, so confirm before changing it here.
    """
    resized = tf.image.resize(image, IMG_SIZE)
    as_float = tf.cast(resized, tf.float32)
    return as_float / 255.0, label

train_ds_full = train_ds_full.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)

# ==============================
# Train / Validation split
# ==============================
# Positional 80/20 split: the first 80% of examples train, the remainder
# validates. The two subsets stay disjoint only if the source order is the
# same on every iteration (tfds.load above is called without shuffle_files).
total_examples = info.splits["train"].num_examples
train_size = int(0.8 * total_examples)
val_size = total_examples - train_size

train_ds = (
    train_ds_full
    .take(train_size)
    .shuffle(1000)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
val_ds = (
    train_ds_full
    .skip(train_size)
    .take(val_size)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# ==============================
# Compute class weights
# ==============================
# Count how many training examples carry each label (one pass over train_ds).
counter = Counter()
for _, batch_labels in train_ds:
    counter.update(int(label) for label in batch_labels.numpy())

total_samples = sum(counter.values())
num_seen_classes = len(counter)

# "Balanced" weights: n_samples / (n_classes * count_c). Rarer classes still
# get proportionally larger weights, but the weights now average to 1 over
# the training samples. The previous total/count form multiplied every
# weight by n_classes, inflating the weighted training loss ~n_classes-fold
# relative to the (unweighted) val_loss and making the two incomparable.
class_weights = {
    cls: total_samples / (num_seen_classes * count)
    for cls, count in counter.items()
}

print("Class weights:", class_weights)

# ==============================
# Data augmentation
# ==============================
AUTOTUNE = tf.data.AUTOTUNE

def augment(images, labels, max_delta=0.1):
    """Randomly flip and brightness-shift a batch of images.

    Args:
        images: float image batch. It must be 4-D (batch, height, width,
            channels) because it is mapped over the batched ``train_ds``;
            values are assumed to be in [0, 1] as produced by ``preprocess``.
        labels: returned unchanged.
        max_delta: largest absolute brightness shift applied to any image.

    Returns:
        ``(images, labels)`` with the augmented images clipped to [0, 1].
    """
    images = tf.image.random_flip_left_right(images)
    images = tf.image.random_flip_up_down(images)
    # tf.image.random_brightness draws ONE scalar delta per call, so on a
    # batch every image got the same shift. Draw one delta per image; the
    # (batch, 1, 1, 1) shape broadcasts over height, width and channels.
    deltas = tf.random.uniform(
        [tf.shape(images)[0], 1, 1, 1],
        minval=-max_delta,
        maxval=max_delta,
        dtype=images.dtype,
    )
    images = images + deltas
    # Adding a delta does not clip float images; keep training inputs in the
    # same [0, 1] range the un-augmented validation data uses.
    images = tf.clip_by_value(images, 0.0, 1.0)
    return images, labels

train_data_aug = train_ds.map(augment, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)

# ==============================
# Build MobileNetV2 Functional Model
# ==============================
# ImageNet-pretrained MobileNetV2 backbone (frozen for the first training
# phase) topped with a small pooling -> dense -> dropout -> softmax head.
inputs = Input(shape=(*IMG_SIZE, 3))
base_model = MobileNetV2(weights='imagenet', include_top=False, input_tensor=inputs)
base_model.trainable = False  # only the new head trains at first

pooled = GlobalAveragePooling2D()(base_model.output)
hidden = Dense(512, activation='relu')(pooled)
hidden = Dropout(0.3)(hidden)
outputs = Dense(NUM_CLASSES, activation='softmax')(hidden)

model = Model(inputs=inputs, outputs=outputs)
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=Adam(),
    metrics=['accuracy'],
)
model.summary()

# ==============================
# Callbacks
# ==============================
import os

# Checkpoint from an earlier run to resume from (assumed to have this
# notebook's architecture). When it is absent, the freshly built `model`
# above is trained instead of the cell crashing.
RESUME_PATH = '/content/drive/MyDrive/plant_disease_detection(0).keras'
if os.path.exists(RESUME_PATH):
    # A .keras file also stores the compiled optimizer and its state, so the
    # fit() below continues with that optimizer rather than the Adam()
    # compiled above.
    model = tf.keras.models.load_model(RESUME_PATH)
    # The loaded model owns new layer objects, so the `base_model` built
    # above is no longer part of `model`; toggling its `trainable` flag would
    # silently have no effect. Rebind it to the loaded backbone: the
    # sub-graph feeding the GlobalAveragePooling2D head (layers are shared,
    # not copied).
    pooling_layer = next(
        (layer for layer in model.layers if isinstance(layer, GlobalAveragePooling2D)),
        None,
    )
    if pooling_layer is None:
        raise ValueError(
            f"{RESUME_PATH} has no GlobalAveragePooling2D head; cannot locate the backbone"
        )
    base_model = Model(inputs=model.inputs, outputs=pooling_layer.input)
    print(f"Resuming from {RESUME_PATH}")
else:
    print(f"No checkpoint at {RESUME_PATH}; training the newly built model")

early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-6)
checkpoint = ModelCheckpoint(
    '/content/drive/MyDrive/plant disease models/plant_disease_detection(1).keras',
    save_best_only=True,
    save_weights_only=False,
    verbose=1
)

# ==============================
# Train top layers
# ==============================
history = model.fit(
    train_data_aug,
    validation_data=val_ds,
    epochs=200,
    class_weight=class_weights,
    callbacks=[early_stop, reduce_lr, checkpoint]
)

# ==============================
# Fine-tuning last 20 layers of MobileNetV2
# ==============================
# The backbone is located inside `model` itself (every layer before the
# GlobalAveragePooling2D head) rather than through `base_model`: if `model`
# was replaced via tf.keras.models.load_model, `base_model` can refer to
# layers that are not part of the model being trained, which left the
# backbone frozen during "fine-tuning".
FINE_TUNE_LAST_N = 20

head_start = next(
    (i for i, layer in enumerate(model.layers) if isinstance(layer, GlobalAveragePooling2D)),
    None,
)
if head_start is None:
    raise ValueError("model has no GlobalAveragePooling2D layer; cannot locate the backbone")

backbone_layers = model.layers[:head_start]
# Same semantics as freezing layers[:-N]: if the backbone has fewer than N
# layers, every layer is unfrozen.
first_unfrozen = max(len(backbone_layers) - FINE_TUNE_LAST_N, 0)
for index, layer in enumerate(backbone_layers):
    # BatchNormalization layers stay frozen: a non-trainable BN layer runs in
    # inference mode, so its pretrained moving statistics are not overwritten
    # by small fine-tuning batches (Keras transfer-learning guide).
    layer.trainable = (
        index >= first_unfrozen
        and not isinstance(layer, tf.keras.layers.BatchNormalization)
    )
# The classification head keeps training.
for layer in model.layers[head_start:]:
    layer.trainable = True

# Recompile so the new trainable flags take effect, with a small LR to avoid
# wrecking the pretrained backbone weights.
model.compile(
    optimizer=Adam(1e-5),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

history_fine = model.fit(
    train_data_aug,
    validation_data=val_ds,
    epochs=200,
    class_weight=class_weights,
    callbacks=[early_stop, reduce_lr, checkpoint]
)

# ==============================
# Save model
# ==============================
# Relative path, so the file lands in the current working directory. On Colab
# that is presumably /content (not Drive) — copy it if it must outlive the runtime.
final_model_path = "plant_disease_detection_model.keras"
model.save(final_model_path)
print("✅ Model saved successfully")


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.2/10.2 MB[0m [31m79.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m113.6 MB/s[0m eta [36m0:00:00[0m
[?25hDrive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).




Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/plant_village/1.0.2...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/1 [00:00<?, ? splits/s]

Generating train examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/plant_village/incomplete.TAYN5M_1.0.2/plant_village-train.tfrecord*...:   …

Dataset plant_village downloaded and prepared to /root/tensorflow_datasets/plant_village/1.0.2. Subsequent calls will reuse this data.
Number of classes: 38
Class names: ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy', 'Blueberry___healthy', 'Cherry___healthy', 'Cherry___Powdery_mildew', 'Corn___Cercospora_leaf_spot Gray_leaf_spot', 'Corn___Common_rust', 'Corn___healthy', 'Corn___Northern_Leaf_Blight', 'Grape___Black_rot', 'Grape___Esca_(Black_Measles)', 'Grape___healthy', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Orange___Haunglongbing_(Citrus_greening)', 'Peach___Bacterial_spot', 'Peach___healthy', 'Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy', 'Potato___Early_blight', 'Potato___healthy', 'Potato___Late_blight', 'Raspberry___healthy', 'Soybean___healthy', 'Squash___Powdery_mildew', 'Strawberry___healthy', 'Strawberry___Leaf_scorch', 'Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___healthy', 'Tomato___Late_blight'

  base_model = MobileNetV2(weights='imagenet', include_top=False, input_tensor=inputs)


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
[1m9406464/9406464[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


Epoch 1/200
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 320ms/step - accuracy: 0.9786 - loss: 2.0771
Epoch 1: val_loss improved from inf to 0.16722, saving model to /content/drive/MyDrive/plant disease models/plant_disease_detection(1).keras
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m562s[0m 410ms/step - accuracy: 0.9786 - loss: 2.0771 - val_accuracy: 0.9550 - val_loss: 0.1672 - learning_rate: 3.1250e-05
Epoch 2/200
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 305ms/step - accuracy: 0.9794 - loss: 1.9763
Epoch 2: val_loss did not improve from 0.16722
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m533s[0m 389ms/step - accuracy: 0.9794 - loss: 1.9764 - val_accuracy: 0.9551 - val_loss: 0.1674 - learning_rate: 3.1250e-05
Epoch 3/200
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 293ms/step - accuracy: 0.9764 - loss: 2.1413
Epoch 3: val_loss improved from 0.16722 to 0.16719, saving model to