In [1]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# ---------------------------------
# Reproducibility
# ---------------------------------
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ---------------------------------
# Tiny Efficient CNN (generic)
# ---------------------------------
class TinyEfficientCNN(nn.Module):
    """Small depthwise-separable CNN classifier."""

    def __init__(self, num_classes, in_channels=3):
        super().__init__()

        def conv_dw(inp, oup, stride=1):
            return nn.Sequential(
                nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
                nn.BatchNorm2d(inp),
                nn.ReLU6(inplace=True),

                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU6(inplace=True),
            )

        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU6(inplace=True),

            conv_dw(32, 64, 1),
            conv_dw(64, 128, 2),
            conv_dw(128, 256, 2),
            conv_dw(256, 512, 2),
            nn.AdaptiveAvgPool2d(1),
        )

        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)


# ---------------------------------
# Train / Test functions
# ---------------------------------
def train_epoch(model, loader, optimizer):
    model.train()
    total_loss = 0.0

    for data, target in loader:
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * data.size(0)

    return total_loss / len(loader.dataset)


@torch.no_grad()
def test_epoch(model, loader):
    model.eval()
    total_loss = 0.0
    correct = 0

    for data, target in loader:
        data, target = data.to(device), target.to(device)
        output = model(data)

        total_loss += F.cross_entropy(output, target, reduction='sum').item()
        pred = output.argmax(1)
        correct += (pred == target).sum().item()

    avg_loss = total_loss / len(loader.dataset)
    accuracy = 100.0 * correct / len(loader.dataset)
    return avg_loss, accuracy


# ---------------------------------
# Dataset + Loader
# ---------------------------------
def build_dataloaders(train_root, test_root, img_size=128, batch_size=128):
    # Separate transforms for train/test
    train_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])
    
    test_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])

    train_ds = datasets.ImageFolder(root=train_root, transform=train_transform)
    test_ds = datasets.ImageFolder(root=test_root, transform=test_transform)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False,
                             num_workers=2, pin_memory=True)

    return train_ds, test_ds, train_loader, test_loader

Baseline Models Comparative Study

In [2]:
import torch
import torch.nn as nn
from torchvision import models


# ---------------------------------
# Create model given a name
# ---------------------------------
def build_model(name, num_classes):
    name = name.lower()

    if name == "tinycnn":
        return TinyEfficientCNN(num_classes=num_classes)

    elif name == "mobilenetv3_small":
        model = models.mobilenet_v3_small(pretrained=True)
        model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes)
        return model

    elif name == "mobilenet_v2":
        model = models.mobilenet_v2(pretrained=True)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
        return model

    elif name == "efficientnet_b0":
        model = models.efficientnet_b0(pretrained=True)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
        return model

    elif name == "resnet18":
        model = models.resnet18(pretrained=True)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        return model

    elif name == "convnext_tiny":
        model = models.convnext_tiny(pretrained=True)
        model.classifier[2] = nn.Linear(model.classifier[2].in_features, num_classes)
        return model

    else:
        raise ValueError(f"Unknown model name: {name}")


def main():

    train_root = "/home/ifran/Projects_UBUNTU/emnss_dublin/clouds_train"
    test_root  = "/home/ifran/Projects_UBUNTU/emnss_dublin/clouds_test"

    num_epochs = 100
    batch_size = 256

    # -----------------------------
    # Load dataset ONCE
    # -----------------------------
    train_ds, test_ds, train_loader, test_loader = build_dataloaders(
        train_root, test_root, batch_size=batch_size
    )
    num_classes = len(train_ds.classes)

    # -----------------------------
    # Models to benchmark
    # -----------------------------
    model_names = [
        "convnext_tiny",        # ~28M params
        "resnet18",             # ~11.7M params
        "efficientnet_b0",      # ~5.3M params
        "mobilenet_v2",         # ~3.4M params
        "mobilenetv3_small",    # ~2.5M params
        "tinycnn",              # your custom smallest model
    ]


    # -----------------------------
    # Train each one in sequence
    # -----------------------------
    # Early stopping settings
    patience = 30
    wait = 0

    for name in model_names:
        print("\n" + "="*60)
        print(f"üî• Training model: {name}")
        print("="*60)

        # Build model
        model = build_model(name, num_classes).to(device)

        # Optimizer & Scheduler
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

        best_acc = 0.0
        wait = 0  # reset early-stopping counter

        if not os.path.exists('Checkpoints'):
            os.mkdir('Checkpoints')
        save_path = f"Checkpoints/best_{name}.pth"

        # -------------------------
        # Training loop + Early Stopping
        # -------------------------
        for epoch in range(num_epochs):
            train_loss = train_epoch(model, train_loader, optimizer)
            test_loss, test_acc = test_epoch(model, test_loader)
            scheduler.step()

            print(f"{name} | Epoch {epoch}: "
                f"Train={train_loss:.4f}  Test={test_loss:.4f}  Acc={test_acc:.2f}%")

            # Check for improvement
            if test_acc > best_acc:
                best_acc = test_acc
                torch.save(model.state_dict(), save_path)
                print(f"‚úî Saved best ({best_acc:.2f}%) ‚Üí {save_path}")
                wait = 0  # reset wait counter
            else:
                wait += 1

            # Trigger early stopping
            if wait >= patience:
                print(f"‚õî Early stopping triggered after {patience} epochs without improvement.")
                break

        print(f"üèÅ Finished training {name}. Best Acc = {best_acc:.2f}%")

    print("\nüéâ All models trained successfully!")



if __name__ == "__main__":
    main()



üî• Training model: convnext_tiny




convnext_tiny | Epoch 0: Train=1.7921  Test=2.4181  Acc=14.20%
‚úî Saved best (14.20%) ‚Üí Checkpoints/best_convnext_tiny.pth
convnext_tiny | Epoch 1: Train=2.2456  Test=1.7058  Acc=21.81%
‚úî Saved best (21.81%) ‚Üí Checkpoints/best_convnext_tiny.pth
convnext_tiny | Epoch 2: Train=1.4901  Test=1.0552  Acc=66.26%
‚úî Saved best (66.26%) ‚Üí Checkpoints/best_convnext_tiny.pth
convnext_tiny | Epoch 3: Train=0.9330  Test=1.1322  Acc=59.47%
convnext_tiny | Epoch 4: Train=0.8495  Test=0.5283  Acc=82.10%
‚úî Saved best (82.10%) ‚Üí Checkpoints/best_convnext_tiny.pth
convnext_tiny | Epoch 5: Train=0.5049  Test=0.3188  Acc=89.51%
‚úî Saved best (89.51%) ‚Üí Checkpoints/best_convnext_tiny.pth
convnext_tiny | Epoch 6: Train=0.2691  Test=0.3834  Acc=85.19%
convnext_tiny | Epoch 7: Train=0.2774  Test=0.2558  Acc=89.30%
convnext_tiny | Epoch 8: Train=0.2175  Test=0.1807  Acc=93.42%
‚úî Saved best (93.42%) ‚Üí Checkpoints/best_convnext_tiny.pth
convnext_tiny | Epoch 9: Train=0.1254  Test=0.3309  Acc



resnet18 | Epoch 0: Train=1.3590  Test=1.0582  Acc=70.58%
‚úî Saved best (70.58%) ‚Üí Checkpoints/best_resnet18.pth
resnet18 | Epoch 1: Train=0.2121  Test=1.2023  Acc=72.02%
‚úî Saved best (72.02%) ‚Üí Checkpoints/best_resnet18.pth
resnet18 | Epoch 2: Train=0.0732  Test=1.8390  Acc=72.84%
‚úî Saved best (72.84%) ‚Üí Checkpoints/best_resnet18.pth
resnet18 | Epoch 3: Train=0.0297  Test=1.6225  Acc=77.16%
‚úî Saved best (77.16%) ‚Üí Checkpoints/best_resnet18.pth
resnet18 | Epoch 4: Train=0.0414  Test=1.7686  Acc=74.49%
resnet18 | Epoch 5: Train=0.0349  Test=1.3623  Acc=75.72%
resnet18 | Epoch 6: Train=0.0160  Test=1.1133  Acc=81.07%
‚úî Saved best (81.07%) ‚Üí Checkpoints/best_resnet18.pth
resnet18 | Epoch 7: Train=0.0204  Test=0.9064  Acc=84.57%
‚úî Saved best (84.57%) ‚Üí Checkpoints/best_resnet18.pth
resnet18 | Epoch 8: Train=0.0576  Test=0.5136  Acc=90.33%
‚úî Saved best (90.33%) ‚Üí Checkpoints/best_resnet18.pth
resnet18 | Epoch 9: Train=0.0407  Test=0.7782  Acc=88.07%
resnet18 | Epo



efficientnet_b0 | Epoch 0: Train=1.7604  Test=1.7490  Acc=39.71%
‚úî Saved best (39.71%) ‚Üí Checkpoints/best_efficientnet_b0.pth
efficientnet_b0 | Epoch 1: Train=0.7718  Test=1.3383  Acc=49.79%
‚úî Saved best (49.79%) ‚Üí Checkpoints/best_efficientnet_b0.pth
efficientnet_b0 | Epoch 2: Train=0.3033  Test=0.9530  Acc=70.78%
‚úî Saved best (70.78%) ‚Üí Checkpoints/best_efficientnet_b0.pth
efficientnet_b0 | Epoch 3: Train=0.1257  Test=0.6427  Acc=83.54%
‚úî Saved best (83.54%) ‚Üí Checkpoints/best_efficientnet_b0.pth
efficientnet_b0 | Epoch 4: Train=0.0630  Test=0.4963  Acc=88.48%
‚úî Saved best (88.48%) ‚Üí Checkpoints/best_efficientnet_b0.pth
efficientnet_b0 | Epoch 5: Train=0.0353  Test=0.4038  Acc=90.53%
‚úî Saved best (90.53%) ‚Üí Checkpoints/best_efficientnet_b0.pth
efficientnet_b0 | Epoch 6: Train=0.0187  Test=0.3625  Acc=91.77%
‚úî Saved best (91.77%) ‚Üí Checkpoints/best_efficientnet_b0.pth
efficientnet_b0 | Epoch 7: Train=0.0075  Test=0.3448  Acc=92.80%
‚úî Saved best (92.80%) ‚



mobilenet_v2 | Epoch 0: Train=1.4555  Test=3.1053  Acc=22.63%
‚úî Saved best (22.63%) ‚Üí Checkpoints/best_mobilenet_v2.pth
mobilenet_v2 | Epoch 1: Train=0.2548  Test=3.5626  Acc=30.86%
‚úî Saved best (30.86%) ‚Üí Checkpoints/best_mobilenet_v2.pth
mobilenet_v2 | Epoch 2: Train=0.0814  Test=3.1691  Acc=45.88%
‚úî Saved best (45.88%) ‚Üí Checkpoints/best_mobilenet_v2.pth
mobilenet_v2 | Epoch 3: Train=0.0645  Test=1.9868  Acc=64.61%
‚úî Saved best (64.61%) ‚Üí Checkpoints/best_mobilenet_v2.pth
mobilenet_v2 | Epoch 4: Train=0.0674  Test=1.2558  Acc=76.54%
‚úî Saved best (76.54%) ‚Üí Checkpoints/best_mobilenet_v2.pth
mobilenet_v2 | Epoch 5: Train=0.0173  Test=1.3465  Acc=74.07%
mobilenet_v2 | Epoch 6: Train=0.0533  Test=0.6803  Acc=87.24%
‚úî Saved best (87.24%) ‚Üí Checkpoints/best_mobilenet_v2.pth
mobilenet_v2 | Epoch 7: Train=0.0111  Test=0.6086  Acc=85.80%
mobilenet_v2 | Epoch 8: Train=0.0540  Test=0.5899  Acc=87.24%
mobilenet_v2 | Epoch 9: Train=0.0063  Test=0.6666  Acc=87.65%
‚úî Save



mobilenetv3_small | Epoch 0: Train=1.7575  Test=1.6931  Acc=49.38%
‚úî Saved best (49.38%) ‚Üí Checkpoints/best_mobilenetv3_small.pth
mobilenetv3_small | Epoch 1: Train=0.7740  Test=1.3488  Acc=61.73%
‚úî Saved best (61.73%) ‚Üí Checkpoints/best_mobilenetv3_small.pth
mobilenetv3_small | Epoch 2: Train=0.3557  Test=1.0918  Acc=62.35%
‚úî Saved best (62.35%) ‚Üí Checkpoints/best_mobilenetv3_small.pth
mobilenetv3_small | Epoch 3: Train=0.1591  Test=1.0418  Acc=61.11%
mobilenetv3_small | Epoch 4: Train=0.0839  Test=1.1039  Acc=62.35%
mobilenetv3_small | Epoch 5: Train=0.0424  Test=1.2435  Acc=61.73%
mobilenetv3_small | Epoch 6: Train=0.0173  Test=1.3978  Acc=61.32%
mobilenetv3_small | Epoch 7: Train=0.0107  Test=1.5200  Acc=61.32%
mobilenetv3_small | Epoch 8: Train=0.0033  Test=1.5890  Acc=61.93%
mobilenetv3_small | Epoch 9: Train=0.0029  Test=1.6417  Acc=62.76%
‚úî Saved best (62.76%) ‚Üí Checkpoints/best_mobilenetv3_small.pth
mobilenetv3_small | Epoch 10: Train=0.0010  Test=1.6882  Acc=6

Efficiency-Accuracy Results

In [None]:
import torch
import torch.nn as nn
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader
from ptflops import get_model_complexity_info
import time


device_gpu = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_cpu = torch.device("cpu")


# -------------------------------------------------
# Load models (same as training)
# -------------------------------------------------
def build_model(name, num_classes):
    name = name.lower()

    if name == "tinycnn":
        return TinyEfficientCNN(num_classes=num_classes)

    elif name == "mobilenetv3_small":
        model = models.mobilenet_v3_small(pretrained=False)
        model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes)
        return model

    elif name == "mobilenet_v2":
        model = models.mobilenet_v2(pretrained=False)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
        return model

    elif name == "efficientnet_b0":
        model = models.efficientnet_b0(pretrained=False)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
        return model

    elif name == "resnet18":
        model = models.resnet18(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        return model

    elif name == "convnext_tiny":
        model = models.convnext_tiny(pretrained=False)
        model.classifier[2] = nn.Linear(model.classifier[2].in_features, num_classes)
        return model

    else:
        raise ValueError("Unknown model: " + name)


# -------------------------------------------------
# Accuracy Evaluation
# -------------------------------------------------
@torch.no_grad()
def compute_accuracy(model, dataloader, device):
    model.eval().to(device)
    correct, total = 0, 0

    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        pred = model(x).argmax(1)
        correct += (pred == y).sum().item()
        total += y.size(0)

    return 100 * correct / total


# -------------------------------------------------
# Latency measurement
# -------------------------------------------------
def measure_latency(model, device, input_shape=(1, 3, 128, 128), runs=50):
    model.to(device).eval()

    dummy = torch.randn(*input_shape).to(device)

    # Warm-up
    for _ in range(10):
        _ = model(dummy)

    # Timed runs
    start = time.time()
    for _ in range(runs):
        _ = model(dummy)
    end = time.time()

    return (end - start) / runs  # seconds per inference


# -------------------------------------------------
# Main Evaluation
# -------------------------------------------------
def main():

    # --------------------------
    # Dataset
    # --------------------------
    test_root = "/home/ifran/Projects_UBUNTU/emnss_dublin/clouds_test"

    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])

    test_ds = datasets.ImageFolder(test_root, transform=transform)
    test_loader = DataLoader(test_ds, batch_size=64, shuffle=False)

    num_classes = len(test_ds.classes)

    # --------------------------
    # Models to evaluate
    # --------------------------
    model_names = [
        "tinycnn",
        "mobilenetv3_small",
        "mobilenet_v2",
        "efficientnet_b0",
        "resnet18",
        "convnext_tiny",
    ]

    # --------------------------
    # Loop over all models
    # --------------------------
    for name in model_names:

        print("\n" + "="*60)
        print(f" Evaluating Model: {name}")
        print("="*60)

        # Build + load checkpoint
        model = build_model(name, num_classes)
        ckpt = f"Checkpoints/best_{name}.pth"
        model.load_state_dict(torch.load(ckpt, map_location="cpu"))

        # ------------------------------
        # Params
        # ------------------------------
        num_params = sum(p.numel() for p in model.parameters())
        print(f"Parameters: {num_params/1e6:.3f} M")

        # ------------------------------
        # FLOPs
        # ------------------------------
        with torch.cuda.device(0):
            macs, params = get_model_complexity_info(
                model, (3, 128, 128),
                as_strings=False,
                print_per_layer_stat=False
            )
        print(f"FLOPs: {macs/1e6:.2f} MFLOPs")

        # ------------------------------
        # Accuracy (GPU if available)
        # ------------------------------
        acc_gpu = compute_accuracy(model, test_loader, device_gpu)
        print(f"Accuracy (GPU): {acc_gpu:.2f}%")

        # ------------------------------
        # Latency
        # ------------------------------
        lat_gpu = measure_latency(model, device_gpu)
        lat_cpu = measure_latency(model, device_cpu)

        print(f"Latency (GPU): {lat_gpu*1000:.2f} ms")
        print(f"Latency (CPU ‚Äì edge-like): {lat_cpu*1000:.2f} ms")

        print("="*60)


if __name__ == "__main__":
    main()



 Evaluating Model: tinycnn
Parameters: 0.186 M
FLOPs: 41.98 MFLOPs


  model.load_state_dict(torch.load(ckpt, map_location="cpu"))


Accuracy (GPU): 83.95%
Latency (GPU): 1.52 ms
Latency (CPU ‚Äì edge-like): 2.18 ms

 Evaluating Model: mobilenetv3_small
Parameters: 1.525 M
FLOPs: 20.27 MFLOPs




Accuracy (GPU): 88.07%
Latency (GPU): 5.73 ms
Latency (CPU ‚Äì edge-like): 3.74 ms

 Evaluating Model: mobilenet_v2
Parameters: 2.233 M
FLOPs: 104.18 MFLOPs
Accuracy (GPU): 93.62%
Latency (GPU): 6.11 ms
Latency (CPU ‚Äì edge-like): 7.75 ms

 Evaluating Model: efficientnet_b0
Parameters: 4.017 M
FLOPs: 133.96 MFLOPs
Accuracy (GPU): 95.88%
Latency (GPU): 9.35 ms
Latency (CPU ‚Äì edge-like): 11.72 ms

 Evaluating Model: resnet18
Parameters: 11.180 M
FLOPs: 595.86 MFLOPs
Accuracy (GPU): 93.62%
Latency (GPU): 2.62 ms
Latency (CPU ‚Äì edge-like): 7.57 ms

 Evaluating Model: convnext_tiny
Parameters: 27.826 M
FLOPs: 1465.34 MFLOPs
Accuracy (GPU): 94.44%
Latency (GPU): 5.70 ms
Latency (CPU ‚Äì edge-like): 18.91 ms


Self-Distillation

In [2]:
import torch
import torch.nn as nn
from torchvision import models

# --------------------------------------------------------
# Simple Knowledge Distillation Loss
# --------------------------------------------------------
def kd_loss(student_logits, teacher_logits, T=4.0, alpha=0.5):
    """
    student_logits: raw outputs of student
    teacher_logits: raw outputs of teacher
    """
    KD = nn.KLDivLoss(reduction="batchmean")
    log_p_s = nn.functional.log_softmax(student_logits / T, dim=1)
    p_t = nn.functional.softmax(teacher_logits / T, dim=1)

    soft_loss = KD(log_p_s, p_t) * (T * T)
    return alpha * soft_loss


# --------------------------------------------------------
# Build model (unchanged)
# --------------------------------------------------------
def build_model(name, num_classes):
    name = name.lower()

    if name == "tinycnn":
        return TinyEfficientCNN(num_classes=num_classes)

    elif name == "mobilenetv3_small":
        model = models.mobilenet_v3_small(pretrained=True)
        model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes)
        return model

    elif name == "mobilenet_v2":
        model = models.mobilenet_v2(pretrained=True)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
        return model

    elif name == "efficientnet_b0":
        model = models.efficientnet_b0(pretrained=True)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
        return model

    elif name == "resnet18":
        model = models.resnet18(pretrained=True)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        return model

    elif name == "convnext_tiny":
        model = models.convnext_tiny(pretrained=True)
        model.classifier[2] = nn.Linear(model.classifier[2].in_features, num_classes)
        return model

    else:
        raise ValueError(f"Unknown model name: {name}")


# =========================================================
# Main
# =========================================================
def main():

    train_root = "/home/ifran/Projects_UBUNTU/emnss_dublin/clouds_train"
    test_root  = "/home/ifran/Projects_UBUNTU/emnss_dublin/clouds_test"

    num_epochs = 100
    batch_size = 256
    patience = 30

    # -----------------------------
    # Load dataset ONCE
    # -----------------------------
    train_ds, test_ds, train_loader, test_loader = build_dataloaders(
        train_root, test_root, batch_size=batch_size
    )
    num_classes = len(train_ds.classes)

    # -----------------------------
    # Models to benchmark
    # -----------------------------
    model_names = [
        "convnext_tiny",
        "resnet18",
        "efficientnet_b0",
        "mobilenet_v2",
        "mobilenetv3_small",
        "tinycnn",
    ]


    # =====================================================
    # Part 1 ‚Äî For each model: TRAIN teacher model
    # =====================================================
    for name in model_names:

        print("\n" + "="*70)
        print(f"üî• Training TEACHER model: {name}")
        print("="*70)

        teacher = build_model(name, num_classes).to(device)

        optimizer = torch.optim.Adam(teacher.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

        best_acc = 0.0
        wait = 0

        if not os.path.exists('Checkpoints'):
            os.mkdir('Checkpoints')
        save_path = f"Checkpoints/best_{name}.pth"

        for epoch in range(num_epochs):
            train_loss = train_epoch(teacher, train_loader, optimizer)
            test_loss, test_acc = test_epoch(teacher, test_loader)
            scheduler.step()

            print(f"{name} Teacher | Epoch {epoch} "
                  f"Train={train_loss:.4f}  Test={test_loss:.4f}  Acc={test_acc:.2f}%")

            if test_acc > best_acc:
                best_acc = test_acc
                torch.save(teacher.state_dict(), save_path)
                print(f"‚úî Saved TEACHER checkpoint ({best_acc:.2f}%)")
                wait = 0
            else:
                wait += 1

            if wait >= patience:
                print("‚õî Early stopping (teacher)")
                break

        print(f"üèÅ Finished TEACHER for {name}. Best Acc={best_acc:.2f}%")



    # =====================================================
    # Part 2 ‚Äî For each model: TRAIN student with KD
    # =====================================================
    for name in model_names:

        print("\n" + "="*70)
        print(f"üéì Training STUDENT model with KD: {name}")
        print("="*70)

        # Load teacher
        teacher = build_model(name, num_classes).to(device)
        teacher.load_state_dict(torch.load(f"Checkpoints/best_{name}.pth"))
        teacher.eval()

        # Build fresh student
        student = build_model(name, num_classes).to(device)

        optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

        best_acc = 0.0
        wait = 0

        save_path = f"Checkpoints/student_{name}.pth"

        for epoch in range(num_epochs):
            student.train()

            total_loss = 0
            for imgs, labels in train_loader:
                imgs, labels = imgs.to(device), labels.to(device)

                optimizer.zero_grad()

                # Forward
                s_logits = student(imgs)
                with torch.no_grad():
                    t_logits = teacher(imgs)

                # Combined KD + CE loss
                ce_loss = nn.CrossEntropyLoss()(s_logits, labels)
                distill = kd_loss(s_logits, t_logits)

                loss = ce_loss + distill

                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            # Evaluate student
            test_loss, test_acc = test_epoch(student, test_loader)
            scheduler.step()

            print(f"{name} Student | Epoch {epoch}: "
                  f"TrainLoss={total_loss/len(train_loader):.4f}  Acc={test_acc:.2f}%")

            if test_acc > best_acc:
                best_acc = test_acc
                torch.save(student.state_dict(), save_path)
                print(f"‚úî Saved STUDENT checkpoint ({best_acc:.2f}%)")
                wait = 0
            else:
                wait += 1

            if wait >= patience:
                print("‚õî Early stopping (student)")
                break

        print(f"üèÅ Finished STUDENT for {name}. Best Acc={best_acc:.2f}%")


    print("\nüéâ All teacher‚Äìstudent KD training done!")


# --------------------------------------------------------
if __name__ == "__main__":
    main()



üî• Training TEACHER model: convnext_tiny




convnext_tiny Teacher | Epoch 0 Train=1.7921  Test=2.4181  Acc=14.20%
‚úî Saved TEACHER checkpoint (14.20%)
convnext_tiny Teacher | Epoch 1 Train=2.2456  Test=1.7058  Acc=21.81%
‚úî Saved TEACHER checkpoint (21.81%)
convnext_tiny Teacher | Epoch 2 Train=1.4901  Test=1.0552  Acc=66.26%
‚úî Saved TEACHER checkpoint (66.26%)
convnext_tiny Teacher | Epoch 3 Train=0.9330  Test=1.1322  Acc=59.47%
convnext_tiny Teacher | Epoch 4 Train=0.8495  Test=0.5283  Acc=82.10%
‚úî Saved TEACHER checkpoint (82.10%)
convnext_tiny Teacher | Epoch 5 Train=0.5049  Test=0.3188  Acc=89.51%
‚úî Saved TEACHER checkpoint (89.51%)
convnext_tiny Teacher | Epoch 6 Train=0.2691  Test=0.3834  Acc=85.19%
convnext_tiny Teacher | Epoch 7 Train=0.2774  Test=0.2558  Acc=89.30%
convnext_tiny Teacher | Epoch 8 Train=0.2175  Test=0.1807  Acc=93.42%
‚úî Saved TEACHER checkpoint (93.42%)
convnext_tiny Teacher | Epoch 9 Train=0.1254  Test=0.3309  Acc=87.45%
convnext_tiny Teacher | Epoch 10 Train=0.1288  Test=0.2531  Acc=90.95%
c



resnet18 Teacher | Epoch 0 Train=1.3590  Test=1.0582  Acc=70.58%
‚úî Saved TEACHER checkpoint (70.58%)
resnet18 Teacher | Epoch 1 Train=0.2121  Test=1.2023  Acc=72.02%
‚úî Saved TEACHER checkpoint (72.02%)
resnet18 Teacher | Epoch 2 Train=0.0732  Test=1.8390  Acc=72.84%
‚úî Saved TEACHER checkpoint (72.84%)
resnet18 Teacher | Epoch 3 Train=0.0297  Test=1.6225  Acc=77.16%
‚úî Saved TEACHER checkpoint (77.16%)
resnet18 Teacher | Epoch 4 Train=0.0414  Test=1.7686  Acc=74.49%
resnet18 Teacher | Epoch 5 Train=0.0349  Test=1.3623  Acc=75.72%
resnet18 Teacher | Epoch 6 Train=0.0160  Test=1.1133  Acc=81.07%
‚úî Saved TEACHER checkpoint (81.07%)
resnet18 Teacher | Epoch 7 Train=0.0204  Test=0.9064  Acc=84.57%
‚úî Saved TEACHER checkpoint (84.57%)
resnet18 Teacher | Epoch 8 Train=0.0576  Test=0.5136  Acc=90.33%
‚úî Saved TEACHER checkpoint (90.33%)
resnet18 Teacher | Epoch 9 Train=0.0407  Test=0.7782  Acc=88.07%
resnet18 Teacher | Epoch 10 Train=0.0316  Test=2.0481  Acc=77.37%
resnet18 Teacher |



efficientnet_b0 Teacher | Epoch 0 Train=1.7604  Test=1.7490  Acc=39.71%
‚úî Saved TEACHER checkpoint (39.71%)
efficientnet_b0 Teacher | Epoch 1 Train=0.7718  Test=1.3383  Acc=49.79%
‚úî Saved TEACHER checkpoint (49.79%)
efficientnet_b0 Teacher | Epoch 2 Train=0.3033  Test=0.9530  Acc=70.78%
‚úî Saved TEACHER checkpoint (70.78%)
efficientnet_b0 Teacher | Epoch 3 Train=0.1257  Test=0.6427  Acc=83.54%
‚úî Saved TEACHER checkpoint (83.54%)
efficientnet_b0 Teacher | Epoch 4 Train=0.0630  Test=0.4963  Acc=88.48%
‚úî Saved TEACHER checkpoint (88.48%)
efficientnet_b0 Teacher | Epoch 5 Train=0.0353  Test=0.4038  Acc=90.53%
‚úî Saved TEACHER checkpoint (90.53%)
efficientnet_b0 Teacher | Epoch 6 Train=0.0187  Test=0.3625  Acc=91.77%
‚úî Saved TEACHER checkpoint (91.77%)
efficientnet_b0 Teacher | Epoch 7 Train=0.0075  Test=0.3448  Acc=92.80%
‚úî Saved TEACHER checkpoint (92.80%)
efficientnet_b0 Teacher | Epoch 8 Train=0.0075  Test=0.3342  Acc=92.80%
efficientnet_b0 Teacher | Epoch 9 Train=0.0014  



mobilenet_v2 Teacher | Epoch 0 Train=1.4555  Test=3.1053  Acc=22.63%
‚úî Saved TEACHER checkpoint (22.63%)
mobilenet_v2 Teacher | Epoch 1 Train=0.2548  Test=3.5626  Acc=30.86%
‚úî Saved TEACHER checkpoint (30.86%)
mobilenet_v2 Teacher | Epoch 2 Train=0.0814  Test=3.1691  Acc=45.88%
‚úî Saved TEACHER checkpoint (45.88%)
mobilenet_v2 Teacher | Epoch 3 Train=0.0645  Test=1.9868  Acc=64.61%
‚úî Saved TEACHER checkpoint (64.61%)
mobilenet_v2 Teacher | Epoch 4 Train=0.0674  Test=1.2558  Acc=76.54%
‚úî Saved TEACHER checkpoint (76.54%)
mobilenet_v2 Teacher | Epoch 5 Train=0.0173  Test=1.3465  Acc=74.07%
mobilenet_v2 Teacher | Epoch 6 Train=0.0533  Test=0.6803  Acc=87.24%
‚úî Saved TEACHER checkpoint (87.24%)
mobilenet_v2 Teacher | Epoch 7 Train=0.0111  Test=0.6086  Acc=85.80%
mobilenet_v2 Teacher | Epoch 8 Train=0.0540  Test=0.5899  Acc=87.24%
mobilenet_v2 Teacher | Epoch 9 Train=0.0063  Test=0.6666  Acc=87.65%
‚úî Saved TEACHER checkpoint (87.65%)
mobilenet_v2 Teacher | Epoch 10 Train=0.0073



mobilenetv3_small Teacher | Epoch 0 Train=1.7575  Test=1.6931  Acc=49.38%
‚úî Saved TEACHER checkpoint (49.38%)
mobilenetv3_small Teacher | Epoch 1 Train=0.7740  Test=1.3488  Acc=61.73%
‚úî Saved TEACHER checkpoint (61.73%)
mobilenetv3_small Teacher | Epoch 2 Train=0.3557  Test=1.0918  Acc=62.35%
‚úî Saved TEACHER checkpoint (62.35%)
mobilenetv3_small Teacher | Epoch 3 Train=0.1591  Test=1.0418  Acc=61.11%
mobilenetv3_small Teacher | Epoch 4 Train=0.0839  Test=1.1039  Acc=62.35%
mobilenetv3_small Teacher | Epoch 5 Train=0.0424  Test=1.2435  Acc=61.73%
mobilenetv3_small Teacher | Epoch 6 Train=0.0173  Test=1.3978  Acc=61.32%
mobilenetv3_small Teacher | Epoch 7 Train=0.0107  Test=1.5200  Acc=61.32%
mobilenetv3_small Teacher | Epoch 8 Train=0.0033  Test=1.5890  Acc=61.93%
mobilenetv3_small Teacher | Epoch 9 Train=0.0029  Test=1.6417  Acc=62.76%
‚úî Saved TEACHER checkpoint (62.76%)
mobilenetv3_small Teacher | Epoch 10 Train=0.0010  Test=1.6882  Acc=63.99%
‚úî Saved TEACHER checkpoint (63.

  teacher.load_state_dict(torch.load(f"Checkpoints/best_{name}.pth"))


convnext_tiny Student | Epoch 0: TrainLoss=8.4438  Acc=48.56%
‚úî Saved STUDENT checkpoint (48.56%)
convnext_tiny Student | Epoch 1: TrainLoss=7.6470  Acc=70.78%
‚úî Saved STUDENT checkpoint (70.78%)
convnext_tiny Student | Epoch 2: TrainLoss=6.3423  Acc=44.03%
convnext_tiny Student | Epoch 3: TrainLoss=8.1730  Acc=69.55%
convnext_tiny Student | Epoch 4: TrainLoss=4.3537  Acc=79.01%
‚úî Saved STUDENT checkpoint (79.01%)
convnext_tiny Student | Epoch 5: TrainLoss=3.3909  Acc=81.28%
‚úî Saved STUDENT checkpoint (81.28%)
convnext_tiny Student | Epoch 6: TrainLoss=2.5318  Acc=87.24%
‚úî Saved STUDENT checkpoint (87.24%)
convnext_tiny Student | Epoch 7: TrainLoss=1.6634  Acc=90.53%
‚úî Saved STUDENT checkpoint (90.53%)
convnext_tiny Student | Epoch 8: TrainLoss=1.1182  Acc=91.77%
‚úî Saved STUDENT checkpoint (91.77%)
convnext_tiny Student | Epoch 9: TrainLoss=0.8372  Acc=89.51%
convnext_tiny Student | Epoch 10: TrainLoss=0.7950  Acc=91.15%
convnext_tiny Student | Epoch 11: TrainLoss=0.3489 