<a href="https://colab.research.google.com/github/d-nazli/mobileNet_model_training/blob/main/fowersV2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive (Colab-only API); the input paths
# configured below point into this mount.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import scipy.io
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import json, random, os
from pathlib import Path

In [None]:

# Input locations: the image archive and the label matrix sit in one Drive
# folder, while the id -> name label map is read from the local Colab disk.
data_dir = "/content/drive/MyDrive/staj"
tgz_path = data_dir + "/102flowers.tgz"
labels_path = data_dir + "/imagelabels.mat"
label_map_path = "/content/label_map.json"

In [None]:

# Unpack the image archive into /content/ (IPython shell escape; "$tgz_path"
# is interpolated from the Python variable), then report how many entries
# ended up in /content/jpg.
!tar -xzf "$tgz_path" -C /content/
print("total picture:", len(os.listdir("/content/jpg")))

In [None]:

# Read the label-map JSON in one go and parse it into `label_map`.
label_map = json.loads(Path(label_map_path).read_text())

In [None]:

import os, json, shutil, scipy.io
from pathlib import Path

# Source locations: extracted images, per-image class ids, and the id -> name map.
IMAGES_DIR = Path("/content/jpg")
LABELS_MAT = "/content/drive/MyDrive/staj/imagelabels.mat"
LABEL_MAP = "/content/label_map.json"

# Row 0 of the "labels" entry holds one class id per image.
mat_contents = scipy.io.loadmat(LABELS_MAT)
labels_mat = mat_contents["labels"][0]
with open(LABEL_MAP) as map_file:
    id2name = json.load(map_file)

# Start from an empty output tree so re-running the cell never mixes old
# copies with new ones.
DATA_DIR = Path("/content/data_species_named")
if DATA_DIR.exists():
    shutil.rmtree(DATA_DIR)
DATA_DIR.mkdir(parents=True, exist_ok=True)

# Pre-create one folder per class id 1..102; spaces in names become underscores.
for class_id in range(1, 103):
    class_folder = DATA_DIR / id2name[str(class_id)].replace(" ", "_")
    class_folder.mkdir(parents=True, exist_ok=True)

# The i-th label (1-based) is paired with the file image_{i:05d}.jpg.
for image_no, cls_id in enumerate(labels_mat, start=1):
    fname = "image_{:05d}.jpg".format(image_no)
    folder_name = id2name[str(int(cls_id))].replace(" ", "_")
    shutil.copy(IMAGES_DIR / fname, DATA_DIR / folder_name / fname)

print("The pictures were divided into folders:", DATA_DIR)


In [None]:

import tensorflow as tf
import numpy as np
import os, json
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras.callbacks import *
from tensorflow.keras.layers import *
from tensorflow.keras.applications import *


def setup_gpu():
    """Enable on-demand memory growth on every visible GPU.

    Does nothing when no GPU is present. Configuration failures are reported
    on stdout instead of being raised, so the notebook keeps running on the
    default memory policy.
    """
    gpus = tf.config.list_physical_devices('GPU')
    if not gpus:
        return

    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(" GPU memory growth enabled")
    except (RuntimeError, ValueError) as err:
        # Only the documented configuration errors are caught (RuntimeError:
        # runtime already initialized, ValueError: invalid device). A bare
        # `except:` would also swallow KeyboardInterrupt and hide the cause.
        print(f" GPU error: {err}")

setup_gpu()


class Config:
    """Static settings shared by the dataset, model and training helpers."""

    # Reproducibility: seed used for the train/validation split.
    SEED = 42

    # Data locations and input pipeline.
    DATA_DIR = "/content/data_species_named"
    LABEL_MAP_PATH = "/content/label_map.json"
    IMG_SIZE = (224, 224)
    BATCH_SIZE = 32

    # Two-stage schedule: (epochs, learning rate) for each stage.
    EPOCHS_STAGE1, LR_STAGE1 = 20, 1e-3
    EPOCHS_STAGE2, LR_STAGE2 = 15, 1e-5

    # Class names that receive stronger augmentation and oversampling.
    WEAK_CLASSES = [
        "sweet_pea",
        "mallow",
        "petunia",
        "pink_primrose",
        "desert-rose",
        "yellow_iris",
        "lotus",
        "water_lily",
    ]


config = Config()


def load_class_names_from_labelmap(label_map_path):
    """Return the class names from a label-map JSON, ordered by numeric id.

    The file must map the string ids "1".."N" (N = number of entries) to
    names; a missing id raises KeyError.
    """
    with open(label_map_path, "r") as handle:
        mapping = json.load(handle)

    ordered_names = []
    for class_id in range(1, len(mapping) + 1):
        ordered_names.append(mapping[str(class_id)])
    return ordered_names


def create_datasets():
    """Load the training and validation splits from `config.DATA_DIR`.

    Both splits use the same 80/20 split and seed, so they are disjoint.

    Returns:
        (train_ds, val_ds, class_names, num_classes), where `class_names[i]`
        is the name of integer label `i` produced by the datasets.
    """
    print("Dataset loading...")

    def _load_split(subset):
        # Identical arguments for both subsets keep the split consistent.
        return tf.keras.utils.image_dataset_from_directory(
            config.DATA_DIR,
            validation_split=0.2,
            subset=subset,
            seed=config.SEED,
            image_size=config.IMG_SIZE,
            batch_size=config.BATCH_SIZE,
            label_mode='int'
        )

    train_ds = _load_split("training")
    val_ds = _load_split("validation")

    # image_dataset_from_directory numbers the labels by sorted sub-directory
    # name, NOT by the numeric ids of the label map. Taking the names from the
    # dataset itself guarantees that label i <-> class_names[i]; reading them
    # from the label map in id order silently mislabels every prediction and
    # the exported class_names.txt.
    class_names = list(train_ds.class_names)
    num_classes = len(class_names)

    print(f" Total classes: {num_classes}")
    print(f"Weak classes: {len(config.WEAK_CLASSES)}")

    return train_ds, val_ds, class_names, num_classes

#augmentation
def create_augmentation_layers():
    """Build the two augmentation pipelines used during training.

    Returns:
        (normal_aug, strong_aug): the first applies mild random flip,
        rotation, zoom, contrast and brightness; the second uses larger
        factors, flips on both axes and adds random translation.
    """
    mild_layers = [
        RandomFlip("horizontal"),
        RandomRotation(0.15),
        RandomZoom(0.15),
        RandomContrast(0.15),
        RandomBrightness(0.1),
    ]
    normal_aug = tf.keras.Sequential(mild_layers)

    heavy_layers = [
        RandomFlip("horizontal_and_vertical"),
        RandomRotation(0.25),
        RandomZoom(0.25),
        RandomContrast(0.25),
        RandomBrightness(0.2),
        RandomTranslation(0.1, 0.1),
    ]
    strong_aug = tf.keras.Sequential(heavy_layers)

    return normal_aug, strong_aug

def smart_augment(images, labels, weak_indices, normal_aug, strong_aug):
    """Augment a batch image by image, favouring strong augmentation for weak classes.

    Args:
        images: batch of images; the mapped result is declared as float32.
        labels: one integer class id per image.
        weak_indices: 1-D tensor of class ids treated as weak.
        normal_aug: pipeline applied to every non-weak image.
        strong_aug: pipeline applied to weak-class images with probability 0.7
            (they get `normal_aug` otherwise).

    Returns:
        (augmented_images, labels); labels are passed through unchanged.
    """
    def _apply(pipeline, image):
        # The pipelines are called on a batch of one; add the axis, then drop it.
        return pipeline(tf.expand_dims(image, 0))[0]

    def _augment_sample(sample):
        image, label = sample
        belongs_to_weak = tf.reduce_any(tf.equal(label, weak_indices))

        def _weak_branch():
            use_strong = tf.random.uniform([]) < 0.7
            return tf.cond(
                use_strong,
                lambda: _apply(strong_aug, image),
                lambda: _apply(normal_aug, image)
            )

        def _normal_branch():
            return _apply(normal_aug, image)

        return tf.cond(belongs_to_weak, _weak_branch, _normal_branch)

    return (
        tf.map_fn(_augment_sample, (images, labels), fn_output_signature=tf.float32),
        labels,
    )


def create_balanced_dataset(train_ds, class_names):
    """Build the augmented training stream with weak classes oversampled.

    The result is every training batch (augmented) followed by three extra
    augmented passes over the weak-class samples only, shuffled at batch level.

    Args:
        train_ds: batched (images, labels) dataset with integer labels.
        class_names: names indexed by label id; used only when `train_ds`
            does not carry its own `class_names` attribute.

    Returns:
        A shuffled, prefetched tf.data.Dataset of (images, labels) batches.
    """
    print("The balanced dataset is being created...")

    normal_aug, strong_aug = create_augmentation_layers()

    # The label ids emitted by the dataset are defined by its own class_names
    # (sorted directory names), so prefer those over the caller-supplied list.
    label_names = list(getattr(train_ds, "class_names", None) or class_names)

    # Folder names use underscores while label-map names may use spaces;
    # compare with both spelled the same way so every weak class is found.
    normalized_names = [name.replace(" ", "_") for name in label_names]
    weak_ids = [
        normalized_names.index(weak.replace(" ", "_"))
        for weak in config.WEAK_CLASSES
        if weak.replace(" ", "_") in normalized_names
    ]
    weak_indices = tf.constant(weak_ids, dtype=tf.int32)

    def augment_fn(images, labels):
        return smart_augment(images, labels, weak_indices, normal_aug, strong_aug)

    def is_weak_sample(_, label):
        # `label` is a scalar here, so it broadcasts cleanly against weak_indices.
        return tf.reduce_any(tf.equal(tf.cast(label, tf.int32), weak_indices))

    augmented_ds = train_ds.map(augment_fn, num_parallel_calls=tf.data.AUTOTUNE)

    # Filter per sample, not per batch: comparing a (batch,) label vector with
    # the (k,) weak-id vector is a shape error, and even where it broadcasts it
    # would keep whole batches that merely contain one weak sample.
    weak_ds = (
        train_ds.unbatch()
        .filter(is_weak_sample)
        .batch(config.BATCH_SIZE)
        .map(augment_fn, num_parallel_calls=tf.data.AUTOTUNE)
    )

    final_ds = augmented_ds.concatenate(weak_ds.repeat(3))
    return final_ds.shuffle(1000).prefetch(tf.data.AUTOTUNE)

#model
def create_mobile_model(num_classes):
    """Build a MobileNetV3Small classifier with a frozen ImageNet backbone.

    The model takes raw images with pixel values in [0, 255] (what
    image_dataset_from_directory yields); rescaling happens inside the backbone.

    Args:
        num_classes: size of the softmax output layer.

    Returns:
        (model, base_model): the full classifier and its backbone, so the
        caller can unfreeze backbone layers for fine-tuning.
    """
    print(" Mobile model is being created...")

    # include_preprocessing=True keeps the built-in Rescaling layer. With it
    # disabled the network expects inputs already in [-1, 1], and
    # mobilenet_v3.preprocess_input is a pass-through that does NOT rescale,
    # so the pretrained weights would see raw 0-255 pixels.
    base_model = tf.keras.applications.MobileNetV3Small(
        input_shape=config.IMG_SIZE + (3,),
        include_top=False,
        weights="imagenet",
        alpha=1.0,
        minimalistic=False,
        include_preprocessing=True
    )
    base_model.trainable = False

    inputs = tf.keras.Input(shape=config.IMG_SIZE + (3,))
    # training=False keeps the backbone's BatchNorm layers in inference mode,
    # also after layers are unfrozen for fine-tuning.
    x = base_model(inputs, training=False)
    x = GlobalAveragePooling2D()(x)
    x = BatchNormalization()(x)
    x = Dropout(0.4)(x)
    x = Dense(256, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(1e-4))(x)
    x = BatchNormalization()(x)
    x = Dropout(0.3)(x)
    outputs = Dense(num_classes, activation="softmax", kernel_regularizer=tf.keras.regularizers.l2(1e-4))(x)

    model = tf.keras.Model(inputs, outputs, name="MobileFlowerClassifier")
    return model, base_model

#focal loss
def focal_loss(gamma=2.0, alpha=0.25):
    """Create a focal-loss function for integer (sparse) class labels.

    Args:
        gamma: focusing exponent; larger values down-weight easy examples more.
        alpha: constant scale applied to every term.

    Returns:
        A callable `focal_loss_fn(y_true, y_pred)` returning the batch-mean
        loss as a scalar. `y_true` holds integer class ids with shape
        (batch,) or (batch, 1); `y_pred` holds probabilities of shape
        (batch, num_classes).

    NOTE(review): because the loss is reduced to a scalar here, a
    `class_weight` passed to `fit()` presumably scales the batch mean rather
    than individual samples - consider returning per-sample losses; confirm
    against the Keras version in use.
    """
    def focal_loss_fn(y_true, y_pred):
        y_pred = tf.cast(y_pred, tf.float32)
        num_classes = tf.shape(y_pred)[-1]

        # Keras may hand sparse labels over as (batch, 1). Without flattening,
        # one_hot yields (batch, 1, C), which broadcasts against (batch, C)
        # into a (batch, batch, C) cross product of labels and predictions.
        class_ids = tf.reshape(tf.cast(y_true, tf.int32), [-1])
        y_true = tf.one_hot(class_ids, num_classes)

        # Clip to keep log() finite.
        y_pred = tf.clip_by_value(y_pred, 1e-8, 1.0 - 1e-8)
        ce_loss = -y_true * tf.math.log(y_pred)
        # pt is the probability assigned to the "correct" outcome per entry.
        pt = tf.where(tf.equal(y_true, 1), y_pred, 1 - y_pred)
        focal_weight = alpha * tf.pow(1 - pt, gamma)
        per_sample = tf.reduce_sum(focal_weight * ce_loss, axis=1)
        return tf.reduce_mean(per_sample)
    return focal_loss_fn

#class weight
def calculate_class_weights(train_ds, num_classes):
    """Compute 'balanced' class weights from the labels in `train_ds`.

    Args:
        train_ds: iterable of (images, labels) batches whose labels expose
            `.numpy()`; it is iterated once in full.
        num_classes: number of classes; ids are 0..num_classes-1.

    Returns:
        dict mapping class id -> weight (float). Classes with no sample in
        `train_ds` get the neutral weight 1.0.
    """
    print("Class weights is calculated...")

    label_batches = [np.reshape(labels.numpy(), -1) for _, labels in train_ds]
    all_labels = np.concatenate(label_batches) if label_batches else np.array([], dtype=np.int64)

    weights = {class_id: 1.0 for class_id in range(num_classes)}

    # compute_class_weight raises ValueError when `classes` contains an id that
    # never occurs in y (possible after a random split of a small class), so
    # only the ids that are actually present are passed to it.
    present = np.unique(all_labels)
    if present.size:
        balanced = compute_class_weight('balanced', classes=present, y=all_labels)
        for class_id, weight in zip(present, balanced):
            weights[int(class_id)] = float(weight)

    return weights


def create_callbacks():
    """Return the callbacks for training: checkpointing, early stopping,
    learning-rate reduction on plateau, and TensorBoard logging."""
    # Keep only the weights with the best validation accuracy.
    checkpoint = ModelCheckpoint(
        "best_mobile_flower_model.keras",
        monitor="val_accuracy",
        save_best_only=True,
        mode="max",
    )
    # Stop after 7 epochs without val_loss improvement and roll back to the best weights.
    early_stop = EarlyStopping(
        monitor="val_loss",
        patience=7,
        restore_best_weights=True,
    )
    # Halve the learning rate after 4 stagnant epochs, down to 1e-7.
    lr_on_plateau = ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.5,
        patience=4,
        min_lr=1e-7,
    )
    tensorboard = TensorBoard(log_dir="./logs", histogram_freq=1)

    return [checkpoint, early_stop, lr_on_plateau, tensorboard]

# training pipeline
def train_model():
    """Run the two-stage training pipeline.

    Stage 1 trains only the classification head on the frozen backbone;
    stage 2 unfreezes the last 30 backbone layers and fine-tunes with a
    lower learning rate.

    Returns:
        (model, class_names, history1, history2, val_ds), where `val_ds` is
        the cached, prefetched validation dataset.
    """
    train_ds, val_ds, class_names, num_classes = create_datasets()
    train_balanced = create_balanced_dataset(train_ds, class_names)
    val_ds = val_ds.cache().prefetch(tf.data.AUTOTUNE)
    model, base_model = create_mobile_model(num_classes)
    class_weights = calculate_class_weights(train_ds, num_classes)
    callbacks = create_callbacks()

    def _build_metrics():
        # Labels are integer class ids (label_mode='int'), so the sparse top-k
        # metric is required; TopKCategoricalAccuracy expects one-hot targets
        # and reports meaningless values on integer labels. Fresh metric
        # objects are created for every compile() call.
        return [
            'accuracy',
            tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name="top5_acc"),
        ]

    print("\n STAGE 1...")
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=config.LR_STAGE1),
        loss=focal_loss(),
        metrics=_build_metrics()
    )
    history1 = model.fit(train_balanced, validation_data=val_ds,
                         epochs=config.EPOCHS_STAGE1, class_weight=class_weights, callbacks=callbacks)

    print("\n STAGE 2...")
    # Unfreeze the backbone, then re-freeze everything except its last 30 layers.
    base_model.trainable = True
    for layer in base_model.layers[:-30]:
        layer.trainable = False
    # Re-compiling is needed for the changed trainable flags to take effect.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=config.LR_STAGE2),
        loss=focal_loss(gamma=1.5),
        metrics=_build_metrics()
    )
    history2 = model.fit(train_balanced, validation_data=val_ds,
                         epochs=config.EPOCHS_STAGE2, class_weight=class_weights, callbacks=callbacks)
    return model, class_names, history1, history2, val_ds


def export_for_mobile(model, class_names, val_ds):
    """Save the Keras model, convert it to TFLite and write the class list.

    Writes three files to the working directory:
    "mobile_flower_classifier_full.keras", "mobile_flower_classifier.tflite"
    and "class_names.txt" (one name per line, in label order).

    Args:
        model: trained Keras model.
        class_names: names indexed by label id.
        val_ds: batched (images, labels) dataset used as representative data.

    Returns:
        The serialized TFLite model (bytes-like).
    """
    print("\n Mobile export...")
    model.save("mobile_flower_classifier_full.keras")

    def representative_data_gen():
        # Yield validation images one at a time, each as a batch of one.
        for batch, _ in val_ds.take(100):
            for sample in batch:
                yield [tf.expand_dims(sample, 0)]

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float16]
    converter.representative_dataset = representative_data_gen
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
    tflite_model = converter.convert()

    with open("mobile_flower_classifier.tflite", "wb") as model_file:
        model_file.write(tflite_model)
    with open("class_names.txt", "w") as names_file:
        names_file.writelines(f"{name}\n" for name in class_names)

    size_mb = len(tflite_model) / 1024 / 1024
    print(f"tflite model size: {size_mb:.2f} MB")
    return tflite_model


# Entry point: train the model in two stages, then export it for mobile use.
# In a notebook kernel __name__ is "__main__", so this runs whenever the cell
# is executed.
if __name__ == "__main__":
    print(" MOBILE FLOWER CLASSIFIER - PRODUCTION VERSION ")
    print("=" * 60)
    # The results are bound to module-level names (model, class_names, hist1,
    # hist2, val_ds, tflite_model).
    model, class_names, hist1, hist2, val_ds = train_model()
    tflite_model = export_for_mobile(model, class_names, val_ds)
    print("\n model was complated")
