# Train 6 CNN models with mixed precision
This notebook prepares a plant-disease dataset, enables mixed precision, and trains six models (VGG16, VGG19, InceptionV3, Xception, ResNet50, DenseNet121) using `image_dataset_from_directory`.
Notes: run the first code cell to install dependencies if needed, then run cells in order.

In [14]:
# Install required packages (run if packages missing)
# Uncomment to run installs inside the notebook environment
# !pip install -q tensorflow tensorflow-io kagglehub matplotlib seaborn

Using Colab cache for faster access to the 'new-plant-diseases-dataset' dataset.
Path to dataset files: /kaggle/input/new-plant-diseases-dataset
Path to dataset files: /kaggle/input/new-plant-diseases-dataset


In [18]:
# Imports and mixed precision setup
import os
import pathlib
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy as np
import datetime
# Enable mixed precision for faster training on modern GPUs
try:
    tf.keras.mixed_precision.set_global_policy('mixed_float16')
    print('Mixed precision policy set to:', tf.keras.mixed_precision.global_policy())
except Exception as e:
    print('Could not set mixed precision policy:', e)
AUTOTUNE = tf.data.AUTOTUNE

Mixed precision policy set to: <DTypePolicy "mixed_float16">


In [19]:
# Dataset path: try to download with kagglehub if available, else set `data_dir` manually.
import zipfile
data_dir = None
try:
    import kagglehub
    print('Attempting to download dataset with kagglehub...')
    path = kagglehub.dataset_download('vipoooool/new-plant-diseases-dataset')
    # kagglehub.dataset_download often returns a zip path; try to extract
    if path:
        if path.endswith('.zip') and os.path.exists(path):
            extract_to = os.path.splitext(path)[0] + '_extracted'
            with zipfile.ZipFile(path, 'r') as zf:
                zf.extractall(extract_to)
            data_dir = extract_to
            print('Extracted dataset to', data_dir)
        else:
            # If path is a directory or already extracted
            data_dir = path
            print('Using dataset path:', data_dir)
except Exception as e:
    print('kagglehub not available or download failed:', e)
    # Fallback: user should edit this path to point to local dataset directory
if data_dir is None:
    # EDIT this if your dataset is already available locally
    data_dir = '/path/to/plant-disease-dataset'  # <-- change this to your dataset folder
    print('Please set `data_dir` to your dataset directory. Current value:', data_dir)
data_dir = pathlib.Path(data_dir)

Attempting to download dataset with kagglehub...
Using Colab cache for faster access to the 'new-plant-diseases-dataset' dataset.
Using dataset path: /kaggle/input/new-plant-diseases-dataset
Using Colab cache for faster access to the 'new-plant-diseases-dataset' dataset.
Using dataset path: /kaggle/input/new-plant-diseases-dataset


In [20]:
# Create train/validation/test datasets using image_dataset_from_directory
# Adjust these parameters as needed
batch_size = 32
image_size = (224, 224)
validation_split = 0.2
seed = 123
label_mode = 'categorical'
# If dataset path doesn't exist, this cell will error - set `data_dir` first
if not data_dir.exists():
    raise FileNotFoundError(f'Data directory not found: {data_dir} - please set the correct path')
train_ds = tf.keras.utils.image_dataset_from_directory(str(data_dir),
    validation_split=validation_split, subset='training', seed=seed,
    image_size=image_size, batch_size=batch_size, label_mode=label_mode)
val_ds = tf.keras.utils.image_dataset_from_directory(str(data_dir),
    validation_split=validation_split, subset='validation', seed=seed,
    image_size=image_size, batch_size=batch_size, label_mode=label_mode)
# Optionally create a test set by splitting differently or using a separate folder
class_names = train_ds.class_names
num_classes = len(class_names)
print('Classes:', class_names)
print('Number of classes:', num_classes)

Found 175767 files belonging to 3 classes.
Using 140614 files for training.
Using 140614 files for training.
Found 175767 files belonging to 3 classes.
Found 175767 files belonging to 3 classes.
Using 35153 files for validation.
Using 35153 files for validation.
Classes: ['New Plant Diseases Dataset(Augmented)', 'new plant diseases dataset(augmented)', 'test']
Number of classes: 3
Classes: ['New Plant Diseases Dataset(Augmented)', 'new plant diseases dataset(augmented)', 'test']
Number of classes: 3


In [21]:
# Performance: cache and prefetch
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

In [23]:
# Model factory and training loop
from tensorflow.keras import applications
from tensorflow.keras import Model
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout, Input
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
models_to_train = {
    'VGG16': applications.VGG16,
    'VGG19': applications.VGG19,
    'InceptionV3': applications.InceptionV3,
    'Xception': applications.Xception,
    'ResNet50': applications.ResNet50,
    'DenseNet121': applications.DenseNet121,
}
# Training hyperparameters
base_learning_rate = 1e-4
head_epochs = 3
fine_tune_epochs = 2
models_dir = pathlib.Path('models')
models_dir.mkdir(parents=True, exist_ok=True)
histories = {}
for name, constructor in models_to_train.items():
    print('== Training', name, '==')
    tf.keras.backend.clear_session()
    # Build base model
    try:
        base = constructor(weights='imagenet', include_top=False, input_shape=(image_size[0], image_size[1], 3))
    except Exception as e:
        print(f'Failed to construct {name} with imagenet weights: {e} - trying without weights')
        base = constructor(weights=None, include_top=False, input_shape=(image_size[0], image_size[1], 3))
    base.trainable = False
    inputs = Input(shape=(image_size[0], image_size[1], 3))
    x = base(inputs, training=False)
    x = GlobalAveragePooling2D()(x)
    x = Dropout(0.3)(x)
    # Ensure final dense is float32 to avoid numeric issues with mixed precision
    outputs = Dense(num_classes, activation='softmax', dtype='float32')(x)
    model = Model(inputs, outputs, name=name)
    # Optimizer with loss scaling for mixed precision
    base_optimizer = tf.keras.optimizers.Adam(learning_rate=base_learning_rate)
    try:
        optimizer = tf.keras.mixed_precision.LossScaleOptimizer(base_optimizer)
    except Exception:
        optimizer = base_optimizer
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
    model.summary()
    # Callbacks
    timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    ckpt = ModelCheckpoint(models_dir / f'{name}_best.h5', monitor='val_accuracy', save_best_only=True, verbose=1)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, verbose=1)
    early = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True, verbose=1)
    # Train head
    history_head = model.fit(train_ds, validation_data=val_ds, epochs=head_epochs, callbacks=[ckpt, reduce_lr, early])
    # Optionally fine-tune: unfreeze last block and continue training
    base.trainable = True
    # Recompile with lower LR
    try:
        base_optimizer = tf.keras.optimizers.Adam(learning_rate=base_learning_rate/10)
        optimizer = tf.keras.mixed_precision.LossScaleOptimizer(base_optimizer)
    except Exception:
        optimizer = tf.keras.optimizers.Adam(learning_rate=base_learning_rate/10)
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
    history_fine = model.fit(train_ds, validation_data=val_ds, epochs=head_epochs + fine_tune_epochs, initial_epoch=history_head.epoch[-1]+1 if hasattr(history_head, 'epoch') and len(history_head.epoch)>0 else 0, callbacks=[ckpt, reduce_lr, early])
    # Save final model
    model.save(models_dir / f'{name}_final')
    # Combine histories for plotting later
    h = {}
    for k,v in history_head.history.items():
        h[k] = v.copy()
    for k,v in history_fine.history.items():
        # append fine tuning metrics
        if k in h:
            h[k].extend(v)
        else:
            h[k] = v.copy()
    histories[name] = h
    print(f'Finished training {name}. Saved to {models_dir}')

== Training VGG16 ==


Epoch 1/3
[1m 432/4395[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m10:17[0m 156ms/step - accuracy: 0.4410 - loss: 3.2106

: 

: 

: 

In [None]:
# Plot training curves for all models
import seaborn as sns
for name, h in histories.items():
    epochs = range(1, len(h['loss'])+1)
    plt.figure(figsize=(10,4))
    plt.subplot(1,2,1)
    plt.plot(epochs, h['loss'], label='train_loss')
    plt.plot(epochs, h['val_loss'], label='val_loss')
    plt.title(f'{name} Loss')
    plt.legend()
    plt.subplot(1,2,2)
    plt.plot(epochs, h['accuracy'], label='train_acc')
    plt.plot(epochs, h['val_accuracy'], label='val_acc')
    plt.title(f'{name} Accuracy')
    plt.legend()
    plt.show()