In [1]:
import tensorflow as tf
from datasets import load_dataset
import numpy as np
from tensorflow import keras
from tensorflow.keras.applications import ResNet152, MobileNetV2
from tensorflow.keras.preprocessing import image_dataset_from_directory, image
import os
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
import time
import pandas as pd
from tensorflow.keras.models import load_model
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mobilenet_preprocess
from tensorflow.keras.applications.resnet import preprocess_input as resnet_preprocess
from tensorflow.keras.callbacks import ModelCheckpoint

In [2]:
# Set number of images to use (Adjust as needed)
NUM_TRAIN_SAMPLES = 12000   # Change this to the number of images you want for training
NUM_VAL_SAMPLES = 2500     # Change this for validation
NUM_TEST_SAMPLES = 2500    # Change this for testing

IMG_SIZE = (224, 224)
BATCH_SIZE = 32
DATA_PATH = "/kaggle/input/plantvillage-dataset/color"

train_ds_full = image_dataset_from_directory(
    DATA_PATH,
    validation_split=0.2,  # 80% train, 20% temp
    subset="training",
    seed=42,
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE
)

temp_ds = image_dataset_from_directory(
    DATA_PATH,
    validation_split=0.2,  # 20% temp (validation + test)
    subset="validation",
    seed=42,
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE
)

# Get class names
class_names = train_ds_full.class_names
num_classes = len(class_names)
print("Classes:", class_names)
print("Number of classes:", num_classes)

# Further split temp_ds into validation and test
val_size = tf.data.experimental.cardinality(temp_ds).numpy() // 2
val_ds_full = temp_ds.take(val_size)  
test_ds_full = temp_ds.skip(val_size)  

train_ds = train_ds_full.take(NUM_TRAIN_SAMPLES // BATCH_SIZE)
val_ds = val_ds_full.take(NUM_VAL_SAMPLES // BATCH_SIZE)
test_ds = test_ds_full.take(NUM_TEST_SAMPLES // BATCH_SIZE)

def preprocess_both(image, label):

    teacher_image = resnet_preprocess(image)
    student_image = mobilenet_preprocess(image)
    
    return (teacher_image, student_image), label

train_ds = train_ds.map(preprocess_both).cache().shuffle(1000).prefetch(tf.data.AUTOTUNE)
val_ds = val_ds.map(preprocess_both).cache().prefetch(tf.data.AUTOTUNE)
test_ds = test_ds.map(preprocess_both).cache().prefetch(tf.data.AUTOTUNE)


Found 54305 files belonging to 38 classes.
Using 43444 files for training.
Found 54305 files belonging to 38 classes.
Using 10861 files for validation.
Classes: ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy', 'Blueberry___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy', 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_', 'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy', 'Grape___Black_rot', 'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy', 'Orange___Haunglongbing_(Citrus_greening)', 'Peach___Bacterial_spot', 'Peach___healthy', 'Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy', 'Raspberry___healthy', 'Soybean___healthy', 'Squash___Powdery_mildew', 'Strawberry___Leaf_scorch', 'Strawberry___healthy', 'Tomato___Bacterial_spot', 'Tomato__

In [None]:

base_model = ResNet152(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
base_model.trainable = True

model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Dense(num_classes, activation=None)
])

model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"]
    )

teacher_train_ds = train_ds.map(lambda inputs, label: (inputs[0], label))
teacher_val_ds = val_ds.map(lambda inputs, label: (inputs[0], label))

resnet_history = model.fit(
    teacher_train_ds,
    validation_data=teacher_val_ds,
    epochs=12
)


In [None]:
model.evaluate(test_ds.map(lambda inputs, label: (inputs[0], label)))

In [None]:
plt.figure(figsize=(10, 4))
plt.plot(resnet_history.history['loss'], label='Train Loss')
plt.plot(resnet_history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

# Plot training and validation accuracy
plt.figure(figsize=(10, 4))
plt.plot(resnet_history.history['accuracy'], label='Train Accuracy')
plt.plot(resnet_history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

In [None]:
model.save("resnet152_plant_disease.h5")

Distill code start from here!

In [None]:
teacher_model = load_model('/kaggle/input/teacher_model/pytorch/default/1/resnet152_plant_disease.h5')
teacher_model.trainable = False
teacher_model.evaluate(test_ds.map(lambda inputs, label: (inputs[0], label)))

In [None]:
num_classes = 38

base_model = MobileNetV2(weights=None, include_top=False, input_shape=(224, 224, 3))
base_model.trainable = True

student_model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(256, activation='relu'),
    layers.Dropout(0.3),
    layers.Dense(num_classes)
])

In [None]:
class Distiller(tf.keras.Model):
    def __init__(self, student, teacher, temperature, alpha):
        super(Distiller, self).__init__()
        self.student = student
        self.teacher = teacher
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self.kl_loss = tf.keras.losses.KLDivergence()

    def compile(self, optimizer, metrics):
        super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)

    def train_step(self, data):
        (teacher_x, student_x), y = data

        # Teacher makes predictions on its preprocessed input
        teacher_logits = self.teacher(teacher_x, training=False)

        with tf.GradientTape() as tape:
            # Student makes predictions on its own preprocessed input
            student_logits = self.student(student_x, training=True)
            loss_ce = self.ce_loss(y, student_logits)
            loss_kl = self.kl_loss(
                tf.nn.softmax(teacher_logits / self.temperature),
                tf.nn.softmax(student_logits / self.temperature)
            ) * (self.temperature ** 2)

            loss = (1 - self.alpha) * loss_kl + self.alpha * loss_ce

        gradients = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.student.trainable_variables))
        self.compiled_metrics.update_state(y, student_logits)

        return {m.name: m.result() for m in self.metrics} | {'loss': loss}

    def test_step(self, data):
        (teacher_x, student_x), y = data
        student_logits = self.student(student_x, training=False)
        loss_ce = self.ce_loss(y, student_logits)
        self.compiled_metrics.update_state(y, student_logits)
        return {m.name: m.result() for m in self.metrics} | {'loss': loss_ce}

In [None]:
distiller = Distiller(student=student_model, teacher=teacher_model, temperature=3.0, alpha=0.2)

distiller.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
)

checkpoint_cb = ModelCheckpoint(
    "best_student_model.keras",
    monitor="val_sparse_categorical_accuracy",
    save_best_only=True,
    mode="max",
    verbose=1
)

# Train the distiller with the callback
student_history = distiller.fit(
    train_ds,
    validation_data=val_ds,
    epochs=80,
    callbacks=[checkpoint_cb]
)

student_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
)

In [None]:
plt.figure(figsize=(10, 4))
plt.plot(student_history.history['loss'], label='Train Loss')
plt.plot(student_history.history['val_loss'], label='Validation Loss')
plt.title('Distill Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

plt.figure(figsize=(10, 4))
plt.plot(student_history.history['sparse_categorical_accuracy'], label='Train Accuracy')
plt.plot(student_history.history['val_sparse_categorical_accuracy'], label='Validation Accuracy')
plt.title('Distill Training and Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

In [None]:
student_model.evaluate(test_ds.map(lambda inputs, label: (inputs[1], label)))

In [None]:
base_model = load_model('/kaggle/working/base_student_model.h5')

In [None]:
base_model.evaluate(test_ds.map(lambda inputs, label: (inputs[1], label)))

In [None]:
student_model.save("distill_student_model.h5")

train without distill

In [None]:
import tensorflow.keras.backend as K
K.clear_session()

In [None]:
num_classes = 38

base_model = MobileNetV2(weights=None, include_top=False, input_shape=(224, 224, 3))
base_model.trainable = True

base_student_model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(256, activation='relu'),
    layers.Dropout(0.3),
    layers.Dense(num_classes)
])

base_train_ds = train_ds.map(lambda inputs, label: (inputs[1], label))
base_val_ds = val_ds.map(lambda inputs, label: (inputs[1], label))

base_student_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
)

In [None]:
history_baseline = base_student_model.fit(base_train_ds, validation_data=base_val_ds, epochs=80)

In [None]:
base_student_model.evaluate(test_ds.map(lambda inputs, label: (inputs[1], label)))

In [None]:
plt.figure(figsize=(10, 4))
plt.plot(history_baseline.history['loss'], label='Train Loss')
plt.plot(history_baseline.history['val_loss'], label='Validation Loss')
plt.title('Base Model Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

plt.figure(figsize=(10, 4))
plt.plot(history_baseline.history['sparse_categorical_accuracy'], label='Train Accuracy')
plt.plot(history_baseline.history['val_sparse_categorical_accuracy'], label='Validation Accuracy')
plt.title('Base Model Training and Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

In [None]:
base_studen_model.summary()

In [None]:
base_student_model.save("base_student_model.h5")

inference

In [21]:
base_model = load_model('/kaggle/input/based_mobilenet/tensorflow2/default/1/base_student_model.h5')
distilled_model = load_model('/kaggle/input/distill_student_model/tensorflow2/default/1/distill_student_model.h5')
teacher_model = load_model('/kaggle/input/teacher_model/pytorch/default/1/resnet152_plant_disease.h5')

In [22]:
base_model.summary()

In [23]:
distilled_model.summary()

In [24]:
teacher_model.summary()

In [9]:
baseline_eval = base_model.evaluate(test_ds.map(lambda inputs, label: (inputs[1], label)), return_dict=True)
distilled_eval = distilled_model.evaluate(test_ds.map(lambda inputs, label: (inputs[1], label)), return_dict=True)
teacher_eval = teacher_model.evaluate(test_ds.map(lambda inputs, label: (inputs[0], label)), return_dict=True)

[1m78/78[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 29ms/step - loss: 0.8844 - sparse_categorical_accuracy: 0.8284
[1m78/78[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 28ms/step - loss: 0.7393 - sparse_categorical_accuracy: 0.8744
[1m78/78[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 268ms/step - accuracy: 0.9728 - loss: 0.1249


In [25]:


# Prepare a batch of test images (1 or more)
test_images = next(iter(test_ds.take(1)))[0]  # Get images only (without labels)

# Measure teacher model inference time
start_time = time.time()
_ = teacher_model.predict(test_images[0])
teacher_time = time.time() - start_time

# Measure student model inference time
start_time = time.time()
_ = distilled_model.predict(test_images[1])
distilled_model_time = time.time() - start_time

start_time = time.time()
_ = base_model.predict(test_images[1])
base_time = time.time() - start_time


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 7s/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step


In [26]:
# Create table data
data = [
    {
        "Model": "Teacher (ResNet152)",
        "Train Samples": 12000,
        "Epochs": 12,
        "Input Size": "224x224",
        "Accuracy": round(teacher_eval['accuracy'], 4),
        "Loss": round(teacher_eval['loss'], 4),
        "Model Size (# Parameters)": 58638120,
        "Inference Time": teacher_time
    },
    {
        "Model": "Base Mobilenetv2",
        "Train Samples": 12000,
        "Epochs": 80,
        "Input Size": "224x224",
        "Accuracy": round(baseline_eval['sparse_categorical_accuracy'], 4),
        "Loss": round(baseline_eval['loss'], 4),
        "Model Size (# Parameters)": 2595688,
        "Inference Time": base_time
    },
    {
        "Model": "Student (Distilled_Mobilenetv2)",
        "Train Samples": 12000,
        "Epochs": 80,
        "Input Size": "224x224",
        "Accuracy": round(distilled_eval['sparse_categorical_accuracy'], 4),
        "Loss": round(distilled_eval['loss'], 4),
        "Model Size (# Parameters)": 2595688,
        "Inference Time": distilled_model_time
    }
]

# Create DataFrame
df = pd.DataFrame(data)
df


Unnamed: 0,Model,Train Samples,Epochs,Input Size,Accuracy,Loss,Model Size (# Parameters),Inference Time
0,Teacher (ResNet152),12000,12,224x224,0.9724,0.1417,58638120,7.110332
1,Base Mobilenetv2,12000,80,224x224,0.8277,0.9454,2595688,3.500432
2,Student (Distilled_Mobilenetv2),12000,80,224x224,0.8802,0.737,2595688,2.400066
