## Train Model

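Train a small CNN image classifier on a directory of labelled leaf images (one sub-folder per class), then export the trained model, the class list, and the images as a single zip archive.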

In [10]:
import logging
import os
import pickle
import shutil
import zipfile

from google.colab import files
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import (
    DirectoryIterator,
    ImageDataGenerator,
)

## Setup Logger

In [11]:
# Configure root logging for the notebook.
# force=True replaces any handlers already attached to the root logger.
# Without it, basicConfig() silently does nothing when a handler is already
# present (which can happen in hosted notebook kernels such as Colab), and the
# INFO messages emitted by this notebook would be dropped.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    force=True,
)
logger = logging.getLogger(__name__)

# File names of the artifacts written by save_trainings() and bundled into
# the output zip archive.
MODEL_FILENAME = "trained_model.h5"
CLASSES_FILENAME = "classes.pkl"

## Create Model

In [12]:
def create_model(num_classes, input_shape=(224, 224, 3)):
    """
    Build and compile a small example CNN classifier.

    It can be customized or replaced with a pretrained model
    (e.g., MobileNet, ResNet, etc.).

    Args:
        num_classes: Number of output classes; sets the width of the final
            softmax layer.
        input_shape: (height, width, channels) of the input images. The
            default matches the 224x224 target size used by the data
            generators in this notebook.

    Returns:
        A compiled Sequential model using the Adam optimizer, categorical
        cross-entropy loss, and the accuracy metric.
    """
    model = models.Sequential([
        # Declare the input explicitly instead of passing input_shape= to the
        # first Conv2D; recent Keras releases warn about the latter form.
        layers.Input(shape=input_shape),
        # Three conv/pool stages with a doubling filter count (32 -> 64 -> 128).
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        # Classification head.
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(num_classes, activation='softmax'),
    ])
    # categorical_crossentropy expects one-hot labels, which is what
    # class_mode='categorical' generators produce.
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model

## Setup Data Generators

In [13]:
def setup_data_generators(
            data_dir: str,
            batch_size: int,
            target_size: tuple[int, int] = (224, 224),
        ) -> tuple[DirectoryIterator, DirectoryIterator]:
    """
    Set up data iterators for training and validation.

    Both iterators read from the same directory and use the same 80/20
    training/validation split. Only the training iterator applies
    augmentation (rotation, shifts, zoom, horizontal flip); both rescale
    pixel values by 1/255.

    Args:
        data_dir: Root directory of the image dataset, passed to
            flow_from_directory.
        batch_size: Number of images per batch for both iterators.
        target_size: (height, width) every image is resized to.

    Returns:
        A (train_iterator, validation_iterator) tuple as returned by
        ImageDataGenerator.flow_from_directory.
    """
    # Must be identical for both generators so the two subsets stay disjoint.
    validation_fraction = 0.2

    train_datagen = ImageDataGenerator(
        rescale=1. / 255,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        validation_split=validation_fraction,
    )
    # No augmentation for validation: only rescaling.
    val_datagen = ImageDataGenerator(
        rescale=1. / 255,
        validation_split=validation_fraction,
    )

    # Arguments shared by both iterators.
    flow_kwargs = dict(
        directory=data_dir,
        target_size=target_size,
        batch_size=batch_size,
        class_mode='categorical',
    )

    train_generator = train_datagen.flow_from_directory(
        subset='training',
        shuffle=True,
        **flow_kwargs,
    )
    # Validation metrics do not depend on sample order, so shuffling is
    # unnecessary; a fixed order also keeps predictions aligned with the
    # iterator's file order for later inspection.
    val_generator = val_datagen.flow_from_directory(
        subset='validation',
        shuffle=False,
        **flow_kwargs,
    )

    return train_generator, val_generator

## Save Trainings

In [14]:
def save_trainings(
            model: models.Sequential,
            output_zip: str,
            classes: list[str],
            images_dir: str = "augmented_images",
        ) -> None:
    """
    Save the trained model, the class list, and the images in a zip file.

    The model is written to MODEL_FILENAME and the class list is pickled to
    CLASSES_FILENAME in the current working directory; both files, plus
    every file found under images_dir, are then added to the archive.

    Args:
        model: Trained model to save.
        output_zip: Path of the zip archive to create (overwritten if it
            already exists).
        classes: Class names to pickle alongside the model.
        images_dir: Directory whose files are added to the archive. If it
            does not exist, a warning is logged and the archive contains
            only the model and the class list.
    """
    model.save(MODEL_FILENAME)
    logger.info(f"Model saved as {MODEL_FILENAME}")

    with open(CLASSES_FILENAME, 'wb') as f:
        pickle.dump(classes, f)
    logger.info(f"Classes saved as {CLASSES_FILENAME}")

    # os.walk() yields nothing for a missing directory, so make that case
    # visible instead of silently producing an archive without images.
    if not os.path.isdir(images_dir):
        logger.warning(
            f"Image directory {images_dir!r} not found; "
            "the archive will contain no images"
        )

    # Archive names are made relative to the parent of images_dir so entries
    # are stored as "<images_dir name>/..." wherever the directory lives.
    archive_root = os.path.dirname(os.path.abspath(images_dir))

    with zipfile.ZipFile(output_zip, 'w', zipfile.ZIP_DEFLATED) as zf:
        zf.write(MODEL_FILENAME)
        zf.write(CLASSES_FILENAME)

        # Named `filenames` rather than `files` to avoid shadowing the
        # google.colab `files` module imported by this notebook.
        for root, _, filenames in os.walk(images_dir):
            for filename in filenames:
                full_path = os.path.join(root, filename)
                relative_path = os.path.relpath(full_path, start=archive_root)
                zf.write(full_path, arcname=relative_path)

    logger.info(f"Model and augmented images saved in {output_zip}")

## Main Function

In [15]:
def main(
            batch_size: int = 64,
            epochs: int = 15,
            output_zip: str = "trained_model_and_augmented.zip"
        ) -> None:
    """
    Main function to train the model and save it in a zip file.

    Steps: obtain the dataset directory, build the training and validation
    iterators, create and fit the CNN, evaluate it on the validation
    iterator, and save the results.

    Args:
        batch_size: Number of images per batch for both iterators.
        epochs: Number of training epochs.
        output_zip: Path of the zip archive that receives the trained
            model and related files.

    Raises:
        NotADirectoryError: If the dataset path is not an existing directory.
    """
    dataset_path = extract_dataset_from_zip()

    if not os.path.isdir(dataset_path):
        raise NotADirectoryError(f'{dataset_path} is not a valid directory.')

    # Resolve once; the iterators receive an absolute path.
    data_dir = os.path.abspath(dataset_path)

    train_generator, val_generator = setup_data_generators(data_dir, batch_size)
    # class_indices maps class name -> label index; the keys are the names.
    classes = list(train_generator.class_indices.keys())
    print(f"Found {len(classes)} classes: {classes}")

    model = create_model(len(classes))
    model.summary()
    print("Start training")
    model.fit(
        train_generator,
        validation_data=val_generator,
        epochs=epochs,
        verbose=1
    )

    val_loss, val_acc = model.evaluate(val_generator)
    print(f"Validation loss: {val_loss}, accuracy: {val_acc}")

    save_trainings(model, output_zip, classes)
    print("Training and saving completed.")

## Upload and Extract Dataset in Google Colab

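Upload a `.zip` dataset file to Google Colab. The archive is extracted into the current working directory and must contain a top-level folder with the same name as the archive (for example, `Apple.zip` must contain `Apple/`); that folder is then used for training.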

In [16]:
def _extract_zip_dataset(zip_file_path):
    """Extract an archive into the working directory and return the folder named after it."""
    base_name = os.path.splitext(os.path.basename(zip_file_path))[0]
    extract_path = os.getcwd()
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(extract_path)

    # The archive is expected to contain a top-level folder with its own name.
    dataset_path = os.path.join(extract_path, base_name)
    if not os.path.isdir(dataset_path):
        raise NotADirectoryError(f'{dataset_path} is not a valid directory.')

    print(f"Dataset extracted to: {dataset_path}")
    return dataset_path


def extract_dataset_from_zip():
    """
    Uploads a zip file in Google Colab, extracts it,
    and returns the path to the extracted dataset.

    If the upload is canceled or the upload call itself fails, the function
    falls back to local data: an existing 'augmented_images' folder is used
    if present, otherwise the default archive
    ('/content/augmented_images.zip') is extracted.

    Problems with a file that WAS uploaded (wrong extension, unreadable
    archive, unexpected layout) are raised to the caller rather than
    triggering the fallback, so a bad upload cannot silently cause a
    different dataset to be used.

    Returns:
        Path to the dataset directory.

    Raises:
        ValueError: If the uploaded file is not a .zip file.
        NotADirectoryError: If the extracted archive does not contain a
            folder named after the archive.
    """
    default_zip = "/content/augmented_images.zip"

    # Only the upload call is guarded: its failure is the one case where
    # falling back to the default data is intended.
    try:
        uploaded = files.upload()
    except Exception as e:
        print(f"Error during upload: {e}")
        print(f"Using default path: {default_zip}")
        uploaded = None
    else:
        if not uploaded:
            print(f"Upload canceled. Using default path: {default_zip}")

    if uploaded:
        # Only the first uploaded file is considered.
        zip_file_path = next(iter(uploaded))
        if not zip_file_path.lower().endswith('.zip'):
            raise ValueError("Uploaded file is not a .zip file.")
        return _extract_zip_dataset(zip_file_path)

    # Fallback: prefer an already-extracted folder over re-extracting.
    if os.path.isdir("augmented_images"):
        print("Using existing 'augmented_images' folder.")
        return "augmented_images"

    return _extract_zip_dataset(default_zip)

## Run Main

In [17]:
# Run the full pipeline: obtain the dataset, train, evaluate, and save the results.
main()

Upload canceled. Using default path: /content/augmented_images.zip
Using existing 'augmented_images' folder.
Found 10496 images belonging to 8 classes.
Found 2624 images belonging to 8 classes.
Found 8 classes: ['Apple_Black_rot', 'Apple_healthy', 'Apple_rust', 'Apple_scab', 'Grape_Black_rot', 'Grape_Esca', 'Grape_healthy', 'Grape_spot']


Start training
Epoch 1/15
[1m164/164[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m143s[0m 826ms/step - accuracy: 0.3767 - loss: 1.8467 - val_accuracy: 0.7382 - val_loss: 0.8639
Epoch 2/15
[1m164/164[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m134s[0m 817ms/step - accuracy: 0.7677 - loss: 0.6503 - val_accuracy: 0.7767 - val_loss: 0.8055
Epoch 3/15
[1m164/164[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m134s[0m 819ms/step - accuracy: 0.8482 - loss: 0.4192 - val_accuracy: 0.8258 - val_loss: 0.5686
Epoch 4/15
[1m164/164[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m133s[0m 814ms/step - accuracy: 0.8965 - loss: 0.2994 - val_accuracy: 0.8735 - val_loss: 0.3976
Epoch 5/15
[1m164/164[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m136s[0m 827ms/step - accuracy: 0.8966 - loss: 0.2853 - val_accuracy: 0.9127 - val_loss: 0.2453
Epoch 6/15
[1m164/164[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m140s[0m 814ms/step - accuracy: 0.9226 - loss: 0.2214 - val_accuracy: 0.9005 - val_lo



Validation loss: 0.16554541885852814, accuracy: 0.9432164430618286
Training and saving completed.


## Cleanup

In [18]:
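# Uncomment to remove the dataset and the generated artifacts from the runtime.
# !rm -fr ./Apple Apple.zip trained_leaf_disease_model.h5 trained_model.h5 classes.pkl trained_model_and_augmented.zip ./augmented_images/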