In [1]:
# Cell 1 - Imports
# !pip install -U tensorflow pillow opencv-python albumentations

import os
from pathlib import Path
import random
import numpy as np
import tensorflow as tf
from PIL import Image, ImageFilter, ImageEnhance
import cv2
import matplotlib.pyplot as plt

print("TensorFlow version:", tf.__version__)


TensorFlow version: 2.12.0


In [2]:
# Cell 2 - paths and constants
ROOT = Path.cwd().parent  # adjust if notebook is in a subfolder
processed_root = ROOT / "data_processed"  # layout: data_processed/<class>/<split>/<image>
models_dir = ROOT / "models"
augment_samples_dir = models_dir / "augment_samples"
# parents=True so a fresh checkout without a models/ tree still works.
models_dir.mkdir(parents=True, exist_ok=True)
augment_samples_dir.mkdir(parents=True, exist_ok=True)

IMG_SIZE = 224
BATCH_SIZE = 32
AUTOTUNE = tf.data.AUTOTUNE

# Fail early with a readable message instead of an opaque iterdir() error.
if not processed_root.is_dir():
    raise FileNotFoundError(f"Dataset root not found: {processed_root}")

# Every non-hidden sub-directory of the dataset root is one class. Hidden
# directories (e.g. ".ipynb_checkpoints" created by Jupyter) are skipped so
# they do not silently become an extra class and shift the label indices.
classes = sorted(
    p.name
    for p in processed_root.iterdir()
    if p.is_dir() and not p.name.startswith(".")
)
num_classes = len(classes)
print("Classes found:", num_classes, classes)


Classes found: 47 ['african_violet_saintpaulia_ionantha', 'aloe_vera', 'anthurium_anthurium_andraeanum', 'areca_palm_dypsis_lutescens', 'asparagus_fern_asparagus_setaceus', 'begonia_begonia_spp', 'bird_of_paradise_strelitzia_reginae', 'birds_nest_fern_asplenium_nidus', 'boston_fern_nephrolepis_exaltata', 'calathea', 'cast_iron_plant_aspidistra_elatior', 'chinese_evergreen_aglaonema', 'chinese_money_plant_pilea_peperomioides', 'christmas_cactus_schlumbergera_bridgesii', 'chrysanthemum', 'ctenanthe', 'daffodils_narcissus_spp', 'dracaena', 'dumb_cane_dieffenbachia_spp', 'elephant_ear_alocasia_spp', 'english_ivy_hedera_helix', 'hyacinth_hyacinthus_orientalis', 'iron_cross_begonia_begonia_masoniana', 'jade_plant_crassula_ovata', 'kalanchoe', 'lilium_hemerocallis', 'lily_of_the_valley_convallaria_majalis', 'money_tree_pachira_aquatica', 'monstera_deliciosa_monstera_deliciosa', 'orchid', 'parlor_palm_chamaedorea_elegans', 'peace_lily', 'poinsettia_euphorbia_pulcherrima', 'polka_dot_plant_hypo

In [3]:
# Cell 3 - dataset cleaning (quick)
from PIL import UnidentifiedImageError

def is_image_ok(path):
    """Return True if Pillow can open `path` and `verify()` raises nothing.

    Any exception raised while opening or verifying is treated as "not ok".
    The file handle is closed deterministically via the context manager, so
    scanning a large dataset does not accumulate open descriptors.
    """
    try:
        with Image.open(path) as img:
            img.verify()
        return True
    except Exception:
        return False

def convert_to_jpeg(path):
    """Re-encode the image at `path` as an RGB JPEG next to it.

    The result is written to `path` with its suffix replaced by ".jpg"
    (quality 95). When that differs from `path`, the original file is removed.

    Returns the new path on success, or None if anything fails (best effort:
    callers use None to record the file as unsupported).
    """
    try:
        # Decode inside a `with` so the source handle is closed before the
        # original is deleted; an open handle makes os.remove fail on Windows
        # for formats Pillow keeps open after loading.
        with Image.open(path) as src:
            img = src.convert("RGB")
        new_path = path.with_suffix(".jpg")
        img.save(new_path, "JPEG", quality=95)
        if new_path != path:
            os.remove(path)
        return new_path
    except Exception:
        return None

# run cleaning
# Walks data_processed/<class>/<split>/ and, for every regular file:
#   1. converts files with an unexpected suffix to JPEG,
#   2. verifies the (possibly converted) file with Pillow,
#   3. tries to repair files that fail verification by re-encoding them,
#   4. records and deletes files that cannot be repaired.
bad_files = []   # (path, reason) tuples; reason is "unsupported" or "corrupt"
converted = 0    # number of files successfully converted to JPEG
checked = 0      # number of regular files visited
for cls in classes:
    for split in ["train", "val", "test"]:
        folder = processed_root / cls / split
        if not folder.exists():
            continue
        # Snapshot the listing: files are created/removed while iterating.
        for f in list(folder.iterdir()):
            if not f.is_file():
                continue
            checked += 1
            if f.suffix.lower() not in [".jpg", ".jpeg", ".png", ".bmp"]:
                new = convert_to_jpeg(f)
                if new is None:
                    bad_files.append((str(f), "unsupported"))
                    continue
                converted += 1
                # The original was removed by convert_to_jpeg; keep checking
                # the converted file rather than the now-missing source
                # (otherwise every converted file is reported as corrupt).
                f = new
            if is_image_ok(f):
                continue
            # verify() failed: try to salvage by decoding and re-encoding in place.
            try:
                with Image.open(f) as src:
                    img = src.convert("RGB")
                if f.suffix.lower() in (".jpg", ".jpeg"):
                    img.save(f, "JPEG", quality=90)
                else:
                    # Let Pillow pick the format from the suffix so a .png/.bmp
                    # file does not end up holding JPEG bytes.
                    img.save(f)
            except Exception:
                bad_files.append((str(f), "corrupt"))
                try:
                    os.remove(f)
                except OSError:
                    # Best effort: the file may already be gone or be locked.
                    pass

print("Checked:", checked, "Converted:", converted, "Bad found:", len(bad_files))
for p, reason in bad_files[:10]:
    print("Bad:", p, reason)


KeyboardInterrupt: 

In [None]:
# Cell 4 - augmentation function
import numpy as np
from io import BytesIO

def augment_numpy(image_np, rng=None):
    """
    Apply a random chain of camera-style degradations to one image.

    image_np: uint8 HWC (0..255)
    rng: optional numpy Generator; a fresh `np.random.default_rng()` is
        created when omitted, so results are not reproducible by default.
    returns augmented uint8 HWC, always resized to IMG_SIZE x IMG_SIZE

    Each stage below fires independently with its own probability. The order
    of the `rng` draws determines the result for a given generator state.
    NOTE(review): the vignette stage unpacks a 3-D shape, so a 3-channel
    input is presumably expected - callers in this notebook pass RGB.
    """
    if rng is None:
        rng = np.random.default_rng()
    img = Image.fromarray(image_np)

    # Random close crop (simulate zoom/partial plant framing)
    # p=0.8: keep 85-100% of each side at a random offset, then resize.
    if rng.random() < 0.8:
        w, h = img.size
        scale = rng.uniform(0.85, 1.0)
        new_w, new_h = int(w*scale), int(h*scale)
        # max(1, ...) keeps the upper bound valid when the crop is full-size.
        # rng.integers excludes the upper bound, so the largest possible
        # offset is never drawn.
        left = rng.integers(0, max(1, w-new_w))
        top = rng.integers(0, max(1, h-new_h))
        img = img.crop((left, top, left+new_w, top+new_h))
        img = img.resize((IMG_SIZE, IMG_SIZE), Image.BILINEAR)
    else:
        img = img.resize((IMG_SIZE, IMG_SIZE), Image.BILINEAR)

    # Brightness & contrast
    # Two independent p=0.6 draws, each with a factor in [0.7, 1.25).
    if rng.random() < 0.6:
        img = ImageEnhance.Brightness(img).enhance(rng.uniform(0.7, 1.25))
    if rng.random() < 0.6:
        img = ImageEnhance.Contrast(img).enhance(rng.uniform(0.7, 1.25))

    # Slight green tint (Pi camera bias)
    # p=0.5: scale channel index 1 by a factor in [0.98, 1.06), clipped to 0..255.
    if rng.random() < 0.5:
        arr = np.array(img).astype(np.float32)
        arr[:,:,1] = np.clip(arr[:,:,1] * rng.uniform(0.98, 1.06), 0, 255)
        img = Image.fromarray(arr.astype(np.uint8))

    # JPEG compression artifacts
    # p=0.5: round-trip through an in-memory JPEG at quality 35..84.
    if rng.random() < 0.5:
        q = int(rng.uniform(35, 85))
        buff = BytesIO()
        img.save(buff, format='JPEG', quality=q)
        buff.seek(0)
        img = Image.open(buff).convert('RGB')

    # From here on the image is processed as a float32 array.
    arr = np.array(img).astype(np.float32)

    # Gaussian noise
    # p=0.5: zero-mean noise with a standard deviation drawn from [5, 20).
    if rng.random() < 0.5:
        noise = rng.normal(0, rng.uniform(5, 20), arr.shape)
        arr = np.clip(arr + noise, 0, 255)

    # Motion blur
    # p=0.35: a k x k kernel (k in 3..8) with a single horizontal or vertical
    # line of weights 1/k. The array is cast to uint8 for cv2.filter2D, so
    # `arr` is uint8 after this branch.
    if rng.random() < 0.35:
        k = int(rng.integers(3, 9))
        kernel = np.zeros((k, k))
        if rng.random() < 0.5:
            kernel[k//2, :] = 1.0 / k
        else:
            kernel[:, k//2] = 1.0 / k
        arr = cv2.filter2D(arr.astype(np.uint8), -1, kernel)

    # Vignette / shadow
    # p=0.35: darken pixels in proportion to their distance from the centre.
    if rng.random() < 0.35:
        h, w, _ = arr.shape
        Y, X = np.ogrid[:h, :w]
        dist = np.sqrt((X - w/2)**2 + (Y - h/2)**2)
        # dist is divided by the full diagonal while its maximum is half the
        # diagonal, so the mask bottoms out near 0.7 and the 0.4 floor in the
        # clip below is never reached.
        mask = 1 - 0.6 * (dist / np.sqrt(w*w + h*h))
        mask = np.clip(mask, 0.4, 1).astype(np.float32)
        # Broadcast the 2-D mask across the channel axis.
        arr = (arr * mask[..., None])

    # Small median smoothing sometimes
    # p=0.25: 3x3 median filter on a uint8 cast of the array.
    if rng.random() < 0.25:
        arr = cv2.medianBlur(arr.astype(np.uint8), 3)

    # Final clamp and cast, whatever dtype the last stage left behind.
    out = np.clip(arr, 0, 255).astype(np.uint8)
    return out


In [None]:
# Cell 5 - tf wrappers for loading and augment (SAFE VERSION for TF 2.12)

def load_image_simple(path, label):
    """Load one image for validation/testing inside a tf.data pipeline.

    Args:
        path: scalar string tensor holding the image path.
        label: scalar integer tensor holding the class index.

    Returns:
        (image, label): image is float32 in [0, 1] with static shape
        (IMG_SIZE, IMG_SIZE, 3); label is int32. Files that cannot be read
        are not skipped here: they come back as an all-zero image with
        label -1, which the pipeline is expected to filter out.
    """
    def _load(p, lbl):
        try:
            # tf.py_function hands over EagerTensors, not bytes/str. Calling
            # str() on the tensor would give its repr ("tf.Tensor(b'...'")
            # rather than the path, so unwrap with .numpy() first.
            raw = p.numpy() if hasattr(p, "numpy") else p
            path_str = raw.decode() if isinstance(raw, bytes) else str(raw)
            # The context manager closes the file once the pixels are decoded.
            with Image.open(path_str) as src:
                img = src.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
            arr = np.array(img, dtype=np.float32) / 255.0
            return arr, np.int32(lbl)
        except Exception as e:
            print("⚠️ Skipping bad file:", p, e)
            return np.zeros((IMG_SIZE, IMG_SIZE, 3), np.float32), np.int32(-1)
    img, lbl = tf.py_function(_load, [path, label], (tf.float32, tf.int32))
    # py_function loses static shape information; restore it for Keras.
    img.set_shape([IMG_SIZE, IMG_SIZE, 3])
    lbl.set_shape([])
    return img, lbl


def load_and_augment_train(path, label):
    """Load one training image and run it through `augment_numpy`.

    Args:
        path: scalar string tensor holding the image path.
        label: scalar integer tensor holding the class index.

    Returns:
        (image, label): image is the augmented picture as float32 in [0, 1]
        with static shape (IMG_SIZE, IMG_SIZE, 3); label is int32. Files that
        cannot be read or augmented come back as an all-zero image with
        label -1, which the pipeline is expected to filter out.
    """
    def _process(p, lbl):
        try:
            # tf.py_function hands over EagerTensors, not bytes/str. Calling
            # str() on the tensor would give its repr ("tf.Tensor(b'...'")
            # rather than the path, so unwrap with .numpy() first.
            raw = p.numpy() if hasattr(p, "numpy") else p
            path_str = raw.decode() if isinstance(raw, bytes) else str(raw)
            # The context manager closes the file once the pixels are decoded.
            with Image.open(path_str) as src:
                img = src.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
            arr = np.array(img, dtype=np.uint8)
            arr = augment_numpy(arr)
            arr = arr.astype(np.float32) / 255.0
            return arr, np.int32(lbl)
        except Exception as e:
            print("⚠️ Skipping bad file:", p, e)
            return np.zeros((IMG_SIZE, IMG_SIZE, 3), np.float32), np.int32(-1)
    img, lbl = tf.py_function(_process, [path, label], (tf.float32, tf.int32))
    # py_function loses static shape information; restore it for Keras.
    img.set_shape([IMG_SIZE, IMG_SIZE, 3])
    lbl.set_shape([])
    return img, lbl


In [None]:
# Cell 6 - build path lists and datasets (SAFE VERSION)

def get_paths_and_labels(class_names, split):
    """Collect file paths and integer labels for one dataset split.

    Args:
        class_names: ordered class directory names; the position of a name in
            this sequence is the label assigned to its files.
        split: sub-directory name under each class (e.g. "train").

    Returns:
        (paths, labels): two parallel lists. `paths` holds string paths of the
        regular files in processed_root/<class>/<split>, `labels` the matching
        class indices. Classes without that split directory contribute nothing.
    """
    paths, labels = [], []
    for idx, cls in enumerate(class_names):
        split_dir = processed_root / cls / split
        if not split_dir.exists():
            continue
        # iterdir() yields entries in arbitrary, filesystem-dependent order;
        # sorting makes the lists reproducible across runs and machines.
        for f in sorted(split_dir.iterdir()):
            if f.is_file():
                paths.append(str(f))
                labels.append(idx)
    return paths, labels


train_paths, train_labels = get_paths_and_labels(classes, "train")
val_paths, val_labels = get_paths_and_labels(classes, "val")
test_paths, test_labels = get_paths_and_labels(classes, "test")

print("Train/Val/Test sizes:", len(train_paths), len(val_paths), len(test_paths))

# tf.data pipelines. Loaders mark unreadable files with label -1; the filter
# step drops those samples before batching.
#
# The path list is grouped by class, so a shuffle buffer smaller than the
# dataset would only mix neighbouring classes and produce class-correlated
# batches. Shuffling happens on (path, label) pairs before any image is
# decoded, so a buffer covering the whole list is cheap. max(1, ...) keeps
# the buffer size valid when the list is empty.
train_ds = (
    tf.data.Dataset.from_tensor_slices((train_paths, train_labels))
    .shuffle(max(1, len(train_paths)), reshuffle_each_iteration=True)
    .map(load_and_augment_train, num_parallel_calls=tf.data.AUTOTUNE)
    .filter(lambda img, lbl: tf.not_equal(lbl, -1))
    .batch(BATCH_SIZE, drop_remainder=True)   # ensures full, non-empty batches
    .prefetch(tf.data.AUTOTUNE)
)

# Validation and test data are loaded without augmentation or shuffling.
val_ds = (
    tf.data.Dataset.from_tensor_slices((val_paths, val_labels))
    .map(load_image_simple, num_parallel_calls=tf.data.AUTOTUNE)
    .filter(lambda img, lbl: tf.not_equal(lbl, -1))
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

test_ds = (
    tf.data.Dataset.from_tensor_slices((test_paths, test_labels))
    .map(load_image_simple, num_parallel_calls=tf.data.AUTOTUNE)
    .filter(lambda img, lbl: tf.not_equal(lbl, -1))
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)


In [None]:
# Cell 7 - save sample images (before and after augmentation)
import random
from pathlib import Path
# Folder that receives the before/after augmentation preview images.
sample_dir = augment_samples_dir.joinpath("sample_images")
sample_dir.mkdir(exist_ok=True)

def save_random_originals(n=8):
    """Save up to `n` randomly chosen, un-augmented training images.

    Each picked file from `train_paths` is converted to RGB, resized to
    IMG_SIZE x IMG_SIZE and written to `sample_dir` as orig_XX.jpg. Files
    that cannot be read or written are reported and skipped (best effort).
    """
    # Clamp to [0, len(train_paths)] so random.sample never sees a negative
    # or oversized sample size.
    count = max(0, min(n, len(train_paths)))
    picked = random.sample(train_paths, count)
    for i, p in enumerate(picked):
        try:
            # The context manager closes the source file after decoding.
            with Image.open(p) as src:
                img = src.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
            img.save(sample_dir / f"orig_{i:02d}.jpg")
        except Exception as e:
            # Catch Exception rather than using a bare except, so that
            # Ctrl-C still interrupts and failures are visible.
            print("Could not save original sample:", p, e)
    print("Saved originals to:", sample_dir)

def save_augmented_from_pipeline(n=8):
    """Save the first `n` augmented images produced by `train_ds`.

    Images arrive as float32 in [0, 1]; they are scaled back to uint8 and
    written to `sample_dir` as aug_XXX.jpg. Nothing is read or written when
    `n` is not positive. If the dataset yields fewer than `n` images, the
    shortfall is reported instead of returning silently.
    """
    if n <= 0:
        return
    saved = 0
    for imgs, _labels in train_ds:
        # Slice so that no more than the remaining quota is written.
        for img in imgs.numpy()[: n - saved]:
            # Round instead of truncating: v/255*255 can land just below v in
            # float32, which would darken the pixel by one level.
            arr = np.clip(np.rint(img * 255.0), 0, 255).astype(np.uint8)
            Image.fromarray(arr).save(sample_dir / f"aug_{saved:03d}.jpg")
            saved += 1
        if saved >= n:
            print("Saved augmented samples to:", sample_dir)
            return
    print(f"Dataset exhausted: saved only {saved} of {n} augmented samples to:", sample_dir)

# Write eight untouched originals, then eight augmented images from train_ds.
save_random_originals(n=8)
save_augmented_from_pipeline(n=8)


In [None]:
# Cell 8 - show a few saved samples
from IPython.display import display

# Alphabetical order of the saved JPEGs; only the first eight are shown,
# each scaled to a 200x200 thumbnail.
files = sorted(sample_dir.glob("*.jpg"))
for f in files[:8]:
    thumbnail = Image.open(f).resize((200, 200))
    display(thumbnail)


In [None]:
# Cell 9 - build model
from tensorflow.keras import layers, models

# Option A (recommended): MobileNetV3Small (fast + small)
base = tf.keras.applications.MobileNetV3Small(
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
    include_top=False,
    weights="imagenet",
    minimalistic=True
)

# Alternative Option B (if you prefer MobileNetV2 as used in 03):
# base = tf.keras.applications.MobileNetV2(input_shape=(IMG_SIZE,IMG_SIZE,3), include_top=False, weights='imagenet')
# NOTE: MobileNetV2 has no built-in preprocessing and expects inputs in
# [-1, 1]; with Option B replace the Rescaling layer below by
# layers.Rescaling(2.0, offset=-1.0).

base.trainable = False  # freeze base initially

inputs = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
# The data pipeline yields float32 pixels in [0, 1], but Keras' MobileNetV3
# ships with its own preprocessing layer and expects raw pixel values in
# [0, 255]. Scale back up so the pretrained weights see the range they were
# trained on.
x = layers.Rescaling(255.0)(inputs)
# training=False keeps BatchNorm layers in inference mode, also once the
# base is unfrozen for fine-tuning.
x = base(x, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.2)(x)
outputs = layers.Dense(num_classes, activation="softmax")(x)

model = models.Model(inputs, outputs)

# Labels are integer class indices, hence the sparse loss.
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)

model.summary()


In [None]:
# Cell 10 - training stage 1 (train head, stable version)
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, Callback

EPOCHS_HEAD = 10

# Halve the learning rate after 2 epochs without val_loss improvement.
lr_schedule = ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.5,
    patience=2,
    min_lr=1e-7,
    verbose=1,
)

# Stop after 3 stagnant epochs and roll back to the best weights seen.
early_stop = EarlyStopping(
    monitor="val_loss",
    patience=3,
    restore_best_weights=True,
    verbose=1,
)

class SaveAugmentSamples(Callback):
    """Keras callback that writes one freshly augmented image per epoch.

    The reference image is loaded once at construction (RGB, resized to
    IMG_SIZE x IMG_SIZE). After every epoch it is passed through
    `augment_numpy` and saved as epoch_XX_aug.jpg in `out_dir`.
    """

    def __init__(self, out_dir, sample_path):
        super().__init__()
        self.out_dir = Path(out_dir)
        self.out_dir.mkdir(parents=True, exist_ok=True)
        self.sample_path = sample_path
        rgb_image = Image.open(sample_path).convert("RGB")
        self.original = rgb_image.resize((IMG_SIZE, IMG_SIZE))

    def on_epoch_end(self, epoch, logs=None):
        # `epoch` is zero-based; file names start at 01.
        augmented = augment_numpy(np.array(self.original))
        target = self.out_dir / f"epoch_{epoch+1:02d}_aug.jpg"
        Image.fromarray(augmented).save(target)

# Reference image for the per-epoch augmentation preview: the first
# validation file, or the first training file when there is no validation set.
sample_val = val_paths[0] if val_paths else train_paths[0]
save_aug_cb = SaveAugmentSamples(
    out_dir=augment_samples_dir / "epochs",
    sample_path=sample_val,
)

# Compile again with the same settings as in the model cell before training.
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# Stage 1: train the classification head with the scheduled callbacks.
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS_HEAD,
    callbacks=[lr_schedule, early_stop, save_aug_cb],
    verbose=1,
)


In [None]:
# Cell 11 - fine-tuning
# Stage 2: unfreeze the backbone and continue training at a much lower
# learning rate.
base.trainable = True
# Optionally unfreeze only top layers:
# for layer in base.layers[:-30]:
#     layer.trainable = False

# A recompile is required for the changed `trainable` flag to take effect.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

history_fine = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=5,
    verbose=1,
)


In [None]:
# Cell 12 - evaluation & save TF model
# Score the model on the held-out test split.
test_loss, test_acc = model.evaluate(test_ds, verbose=1)
print(f"Test accuracy: {test_acc:.4f}")

# Persist the trained Keras model under the models directory.
tf_model_path = models_dir.joinpath("plant_classifier_tf")
model.save(tf_model_path)
print("Saved Keras model to:", tf_model_path)


In [None]:
# Cell 13 - TFLite full-integer quantization (uint8 input/output)
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Full integer quantization: restrict the converter to integer kernels.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# The converted model takes and returns uint8 tensors.
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8


def representative_dataset_gen():
    """Yield single-image calibration samples for the TFLite converter.

    Calibration runs on the float model, so samples must match its input
    exactly: float32, shape [1, H, W, C], in the same [0, 1] range the model
    was trained on. Feeding uint8 [0, 255] here would either be rejected for
    its dtype or calibrate the activation ranges on the wrong value range.
    """
    for image_batch, _ in val_ds.take(100):
        # image_batch is float32 [0,1], shape [B, H, W, C]
        images = image_batch.numpy().astype(np.float32)
        for i in range(images.shape[0]):
            # Slicing i:i+1 keeps the leading batch dimension of size 1.
            yield [images[i:i+1]]


converter.representative_dataset = representative_dataset_gen

tflite_model = converter.convert()
tflite_path = models_dir / "plant_classifier_int8.tflite"
tflite_path.write_bytes(tflite_model)
print("Saved quantized TFLite to:", tflite_path, "size KB:", len(tflite_model)/1024)
