In [3]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt

In [4]:
def preprocess(image, label):
    image = tf.image.resize(image, (224, 224))
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

ds_train, ds_info = tfds.load("eurosat/rgb", split="train[:80%]", as_supervised=True, with_info=True)
ds_val = tfds.load("eurosat/rgb", split="train[80%:]", as_supervised=True)

ds_train = ds_train.map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)
ds_val = ds_val.map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)

num_classes = ds_info.features['label'].num_classes



Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/eurosat/rgb/2.0.0...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/1 [00:00<?, ? splits/s]

Generating train examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/eurosat/rgb/incomplete.FDP0T8_2.0.0/eurosat-train.tfrecord*...:   0%|     …

Dataset eurosat downloaded and prepared to /root/tensorflow_datasets/eurosat/rgb/2.0.0. Subsequent calls will reuse this data.


In [5]:
base_model = tf.keras.applications.VGG16(include_top=False, input_shape=(224, 224, 3), weights='imagenet')
base_model.trainable = False  # Freeze convolutional base

teacher_model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation='relu'),  # updated hidden layer
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Dense(num_classes)  # logits (no softmax)
])

teacher_model.compile(optimizer='adam',
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])

print("\n🔧 Training Teacher Model...")
teacher_model.fit(ds_train, validation_data=ds_val, epochs=5)


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m58889256/58889256[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 0us/step

🔧 Training Teacher Model...
Epoch 1/5
[1m675/675[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m168s[0m 228ms/step - accuracy: 0.5342 - loss: 1.3930 - val_accuracy: 0.7524 - val_loss: 0.7126
Epoch 2/5
[1m675/675[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m195s[0m 235ms/step - accuracy: 0.7104 - loss: 0.7979 - val_accuracy: 0.8170 - val_loss: 0.5452
Epoch 3/5
[1m675/675[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m159s[0m 235ms/step - accuracy: 0.7702 - loss: 0.6522 - val_accuracy: 0.8331 - val_loss: 0.4888
Epoch 4/5
[1m675/675[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m190s[0m 217ms/step - accuracy: 0.7825 - loss: 0.6004 - val_accuracy: 0.8363 - val_loss: 0.4608
Epoch 5/5
[1m675/675[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m214s[0m 235ms/ste

<keras.src.callbacks.history.History at 0x7af3f23b8e50>

In [6]:
def get_student_model():
    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(224, 224, 3)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D(),

        tf.keras.layers.Conv2D(128, 3, activation='relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D(),

        tf.keras.layers.Conv2D(256, 3, activation='relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.GlobalAveragePooling2D(),

        tf.keras.layers.Dense(256, activation='relu'),  # match with teacher
        tf.keras.layers.Dropout(0.3),
        tf.keras.layers.Dense(num_classes)  # logits (no softmax)
    ])

student_model = get_student_model()


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [7]:
loss_fn_ce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
loss_fn_msc = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam()

epochs = 5
alpha = 0.5  # Student CE loss weight
beta = 1.0   # Teacher guidance loss weight (MSE)

print("\n🎓 Training Student with Knowledge Distillation (MSE)...")
for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}/{epochs}")
    for step, (x_batch, y_batch) in enumerate(ds_train):
        with tf.GradientTape() as tape:
            student_logits = student_model(x_batch, training=True)
            teacher_logits = teacher_model(x_batch, training=False)

            # Cross entropy loss (student vs ground truth)
            loss_l2 = loss_fn_ce(y_batch, student_logits)

            # MSE loss between teacher and student soft predictions
            student_probs = tf.nn.softmax(student_logits)
            teacher_probs = tf.nn.softmax(teacher_logits)
            loss_l1 = loss_fn_msc(teacher_probs, student_probs)

            # Total loss (no alignment)
            total_loss = alpha * loss_l2 + beta * loss_l1

        grads = tape.gradient(total_loss, student_model.trainable_weights)
        optimizer.apply_gradients(zip(grads, student_model.trainable_weights))

        if step % 50 == 0:
            print(f"Step {step}: Total Loss = {total_loss:.4f}, "
                  f"L1 (MSC) = {loss_l1:.4f}, L2 (CE) = {loss_l2:.4f}")


🎓 Training Student with Knowledge Distillation (MSE)...

Epoch 1/5
Step 0: Total Loss = 1.1083, L1 (MSC) = 0.0686, L2 (CE) = 2.0795
Step 50: Total Loss = 1.0549, L1 (MSC) = 0.0585, L2 (CE) = 1.9929
Step 100: Total Loss = 0.5119, L1 (MSC) = 0.0318, L2 (CE) = 0.9602
Step 150: Total Loss = 0.5897, L1 (MSC) = 0.0321, L2 (CE) = 1.1152
Step 200: Total Loss = 0.4071, L1 (MSC) = 0.0270, L2 (CE) = 0.7601
Step 250: Total Loss = 0.6219, L1 (MSC) = 0.0330, L2 (CE) = 1.1778
Step 300: Total Loss = 0.5237, L1 (MSC) = 0.0385, L2 (CE) = 0.9703
Step 350: Total Loss = 0.5651, L1 (MSC) = 0.0329, L2 (CE) = 1.0644
Step 400: Total Loss = 0.5232, L1 (MSC) = 0.0326, L2 (CE) = 0.9812
Step 450: Total Loss = 0.6926, L1 (MSC) = 0.0418, L2 (CE) = 1.3015
Step 500: Total Loss = 0.3021, L1 (MSC) = 0.0267, L2 (CE) = 0.5508
Step 550: Total Loss = 0.5280, L1 (MSC) = 0.0347, L2 (CE) = 0.9866
Step 600: Total Loss = 0.3569, L1 (MSC) = 0.0201, L2 (CE) = 0.6736
Step 650: Total Loss = 0.2529, L1 (MSC) = 0.0210, L2 (CE) = 0.46

In [8]:
# -----------------------------
# 5. 🎯 Evaluate Student
# -----------------------------
print("\n✅ Evaluating Student Model on Validation Set...")
acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()

for x_batch, y_batch in ds_val:
    preds = student_model(x_batch, training=False)
    acc_metric.update_state(y_batch, preds)

final_acc = acc_metric.result().numpy()
print(f"\n🎓 Student Accuracy on Validation Set: {final_acc * 100:.2f}%")


✅ Evaluating Student Model on Validation Set...

🎓 Student Accuracy on Validation Set: 86.85%


In [9]:
print("\n✅ Evaluating Teacher Model on Validation Set...")
acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()

for x_batch, y_batch in ds_val:
    preds = teacher_model(x_batch, training=False)
    acc_metric.update_state(y_batch, preds)

final_acc = acc_metric.result().numpy()
print(f"\n🎓 Teacher Accuracy on Validation Set: {final_acc * 100:.2f}%")


✅ Evaluating Student Model on Validation Set...

🎓 Student Accuracy on Validation Set: 83.69%
