In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms, models, datasets
from torchvision.datasets import ImageFolder
import os
import numpy as np
import random
import matplotlib.pyplot as plt
import seaborn as sns
import time
from sklearn.metrics import classification_report, confusion_matrix
from tqdm import tqdm
import kagglehub

def set_seed(seed=42):
    """Seed every RNG used by this script and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy's global RNG,
            and PyTorch's CPU and CUDA generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,       # current CUDA device
        torch.cuda.manual_seed_all,   # every visible CUDA device
    )
    for apply_seed in seeders:
        apply_seed(seed)
    # Give up cuDNN autotuning in exchange for run-to-run reproducible kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed all RNGs once at start-up (plot_samples later picks images with `random`).
set_seed(42)
# Prefer the GPU when one is visible; models and batches are moved to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Worker processes for the test DataLoader built in load_test_data.
NUM_WORKERS = 2

# Colab-only: mount Google Drive. SAVE_DIR holds the model checkpoints and the
# saved split indices, and is where the generated plots are written.
from google.colab import drive
drive.mount('/content/drive', force_remount=False)
SAVE_DIR = '/content/drive/MyDrive/comp576final'

# Fetch the Kaggle "emmarex/plantdisease" dataset; kagglehub returns a local
# directory whose PlantVillage/ subfolder is read with ImageFolder.
path = kagglehub.dataset_download("emmarex/plantdisease")
DATA_DIR = os.path.join(path, 'PlantVillage')

# Deterministic test-time preprocessing (no augmentation): fixed 224x224 resize,
# tensor conversion, and per-channel normalization with the standard ImageNet
# mean/std (plot_samples reverses this same normalization for display).
transform_eval = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

class DatasetWrapper(Dataset):
    """Apply an optional per-sample transform on top of an ``(x, y)`` dataset.

    Lets one untransformed split (e.g. a ``Subset`` of an ``ImageFolder``) be
    reused with different preprocessing pipelines. Labels pass through as-is.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset        # underlying indexable yielding (x, y) pairs
        self.transform = transform  # callable applied to x only; falsy => identity

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        sample, target = self.subset[index]
        if not self.transform:
            return sample, target
        return self.transform(sample), target

def load_test_data(data_dir, save_dir):
    """Rebuild the held-out test split and wrap it in a DataLoader.

    Reads ``<save_dir>/data_split_indices.pth`` and uses its ``'test_indices'``
    entry to select images from ``ImageFolder(data_dir)``. The split file is
    presumably written by the training notebook; verify it matches this data_dir.
    Samples go through ``transform_eval``. The loader uses a batch size of 32,
    does not shuffle, and pins memory only on CUDA.

    Returns:
        (test_loader, test_dataset, classes), where ``classes`` is the
        ImageFolder class-name list indexed by label id.
    """
    full_dataset = ImageFolder(root=data_dir)
    split = torch.load(os.path.join(save_dir, 'data_split_indices.pth'))
    test_dataset = DatasetWrapper(
        Subset(full_dataset, split['test_indices']),
        transform=transform_eval,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=(device.type == 'cuda'),
    )
    return test_loader, test_dataset, full_dataset.classes

def create_resnet18_model(num_classes, pretrained=False):
    """Build a randomly initialised torchvision ResNet-18 with a new ``fc`` head.

    Args:
        num_classes: Output size of the replacement linear classifier.
        pretrained: Accepted for signature compatibility but ignored; the
            trained weights are loaded from a checkpoint afterwards.
    """
    net = models.resnet18(weights=None)
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net

class ResNet18Plant(nn.Module):
    """ResNet-18 convolutional trunk with a concatenated GAP+GMP MLP head.

    Submodule names (``features``, ``classifier``) and their Sequential
    indices define the checkpoint's state_dict keys and must not change.
    """

    def __init__(self, num_classes, pretrained=False):
        super().__init__()
        # ``pretrained`` is ignored; weights are loaded from a checkpoint later.
        trunk = models.resnet18(weights=None)
        # Keep every child except the final avgpool and fc layers.
        *conv_layers, _avgpool, _fc = trunk.children()
        self.features = nn.Sequential(*conv_layers)
        pooled_dim = 512 * 2  # 512 trunk channels, average- and max-pooled
        self.classifier = nn.Sequential(
            nn.Linear(pooled_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        fmap = self.features(x)
        batch = fmap.size(0)
        avg_pooled = F.adaptive_avg_pool2d(fmap, 1).view(batch, -1)
        max_pooled = F.adaptive_max_pool2d(fmap, 1).view(batch, -1)
        return self.classifier(torch.cat([avg_pooled, max_pooled], dim=1))

class PlantCNN(nn.Module):
    """Compact four-stage CNN: conv/BN/ReLU/max-pool x4, global average pool, MLP.

    The flat ``features`` Sequential (indices 0-15) and the ``classifier``
    layout must stay as-is so saved checkpoints load with matching keys.
    """

    def __init__(self, num_classes):
        super().__init__()
        stages = []
        for c_in, c_out in ((3, 32), (32, 64), (64, 128), (128, 256)):
            # 3x3 "same" conv -> batch norm -> ReLU -> 2x spatial downsample.
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ]
        self.features = nn.Sequential(*stages)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Linear(256, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        pooled = self.global_pool(self.features(x))
        return self.classifier(pooled.flatten(1))

def measure_latency(model, device, dataset, num_samples=100, warmup_iters=10):
    """Return the mean batch-size-1 forward-pass latency of ``model`` in ms.

    Only the forward call is timed. Loading each sample and copying it to
    ``device`` happen outside the timed region. On CUDA the device is
    synchronised before and after each call so asynchronous kernels are counted.

    Args:
        model: Network to time; switched to eval mode.
        device: Device the model lives on (samples are moved there).
        dataset: Indexable yielding ``(input, label)`` pairs.
        num_samples: Maximum number of leading samples to time.
        warmup_iters: Untimed forward passes run first, using the first sample.

    Raises:
        ValueError: If there are no samples to time.
    """
    n = min(num_samples, len(dataset))
    if n <= 0:
        raise ValueError("measure_latency needs at least one sample to time")

    model.eval()
    # num_workers=0: samples are loaded in this process between timed regions,
    # so no background worker competes for CPU while the forward pass is timed.
    loader = DataLoader(Subset(dataset, range(n)), batch_size=1, shuffle=False,
                        num_workers=0, pin_memory=False)
    is_cuda = device.type == 'cuda'

    with torch.no_grad():
        # Warm up on a real sample so the input shape always matches the data
        # (the previous hard-coded 1x3x224x224 tensor assumed a fixed resolution).
        warmup_input = next(iter(loader))[0].to(device)
        for _ in range(warmup_iters):
            model(warmup_input)

        total_time = 0.0
        for inputs, _ in loader:
            inputs = inputs.to(device)
            if is_cuda:
                torch.cuda.synchronize()
            # perf_counter is monotonic and high-resolution, unlike time.time.
            start = time.perf_counter()
            model(inputs)
            if is_cuda:
                torch.cuda.synchronize()
            total_time += time.perf_counter() - start

    return (total_time / n) * 1000.0  # seconds -> ms

def evaluate_model(model_path, model, dataloader, dataset, device, classes):
    """Load a checkpoint into ``model``, time it, and score it on the test set.

    Args:
        model_path: Path to a saved ``state_dict`` for ``model``.
        model: Freshly constructed network matching the checkpoint layout.
        dataloader: Test loader yielding ``(inputs, labels)`` batches.
        dataset: Test dataset used for the batch-size-1 latency measurement.
        device: Device to run on.
        classes: Class names indexed by label id.

    Returns:
        dict with ``model_name`` (checkpoint basename), ``inference_time_ms``,
        macro-averaged ``precision`` and ``recall``, ``accuracy``, the
        ``confusion_matrix`` (one row/column per entry of ``classes``), the raw
        ``all_labels`` / ``all_preds`` lists, and sklearn's ``full_report`` dict.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    latency_ms = measure_latency(model, device, dataset)

    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Inference"):
            outputs = model(inputs.to(device))
            _, predicted = torch.max(outputs, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.numpy())

    # Report on the label ids that actually occur, with names aligned to those ids.
    # Passing all of `classes` as target_names fails whenever some class is missing
    # from both labels and predictions (length mismatch). When every class is
    # present, this is identical to the default behaviour.
    present = sorted({int(v) for v in all_labels} | {int(v) for v in all_preds})
    report = classification_report(
        all_labels,
        all_preds,
        labels=present,
        target_names=[classes[i] for i in present],
        output_dict=True,
    )
    # Always one row/column per class so the matrix lines up with the
    # class-name tick labels used when plotting it.
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(classes))))

    return {
        'model_name': os.path.basename(model_path),
        'inference_time_ms': latency_ms,
        'precision': report['macro avg']['precision'],
        'recall': report['macro avg']['recall'],
        'accuracy': report['accuracy'],
        'confusion_matrix': cm,
        'all_labels': all_labels,
        'all_preds': all_preds,
        'full_report': report
    }

def plot_confusion_matrix(cm, classes, model_name):
    """Draw ``cm`` as a heatmap and save it to ``SAVE_DIR``.

    Cells are annotated with counts only for fewer than 20 classes. Class-name
    tick labels are shown only for fewer than 50 classes, to keep large
    matrices legible. Output: ``SAVE_DIR/confusion_matrix_<model_name>.png``.
    """
    n_classes = len(classes)
    tick_labels = classes if n_classes < 50 else []
    fig, ax = plt.subplots(figsize=(12, 10))
    sns.heatmap(cm, annot=n_classes < 20, fmt='d', cmap='Blues',
                xticklabels=tick_labels, yticklabels=tick_labels, ax=ax)
    ax.set_title(f'Confusion Matrix - {model_name}')
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    fig.tight_layout()
    fig.savefig(os.path.join(SAVE_DIR, f'confusion_matrix_{model_name}.png'))
    plt.close(fig)

def plot_samples(model, dataloader, classes, model_name, device, num_images=5):
    """Save a row of randomly chosen test images with true vs. predicted labels.

    Images come from the first batch of ``dataloader``. They are picked with
    Python's global ``random`` module, so the selection depends on its current
    state. Titles are green for correct predictions and red for mistakes. The
    figure is written to ``SAVE_DIR/samples_<model_name>.png``.

    Args:
        model: Network already on ``device``; switched to eval mode here.
        dataloader: Loader yielding ``(images, labels)`` batches of normalized tensors.
        classes: Class names indexed by label id.
        model_name: Used in the figure title and the output file name.
        device: Device the model lives on.
        num_images: Number of images to show; clamped to the first batch's size.
    """
    model.eval()
    images, labels = next(iter(dataloader))
    # Clamp so a short batch does not make random.sample raise ValueError.
    count = min(num_images, len(images))
    if count <= 0:
        return
    indices = random.sample(range(len(images)), count)
    selected_images = images[indices]
    selected_labels = labels[indices]

    with torch.no_grad():
        outputs = model(selected_images.to(device))
        _, preds = torch.max(outputs, 1)
    preds = preds.cpu()

    # Inverse of transform_eval's Normalize, hoisted out of the per-image loop.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    plt.figure(figsize=(3 * count, 5))
    for i in range(count):
        plt.subplot(1, count, i + 1)
        img = selected_images[i].cpu().permute(1, 2, 0).numpy()  # CHW -> HWC
        img = np.clip(std * img + mean, 0, 1)
        true_name = classes[int(selected_labels[i])]
        pred_name = classes[int(preds[i])]
        color = 'green' if true_name == pred_name else 'red'
        plt.imshow(img)
        plt.title(f"True: {true_name}\nPred: {pred_name}", color=color, fontsize=10)
        plt.axis('off')

    plt.suptitle(f'Sample Predictions - {model_name}', fontsize=16)
    plt.tight_layout()
    save_path = os.path.join(SAVE_DIR, f'samples_{model_name}.png')
    plt.savefig(save_path)
    plt.close()

def main():
    """Evaluate every saved checkpoint on the test split and print a summary.

    For each checkpoint found in SAVE_DIR, this runs evaluate_model, saves a
    confusion-matrix plot and a sample-prediction plot, and records the
    parameter count and in-memory parameter size. Missing checkpoints are
    reported and skipped rather than aborting the whole comparison.
    """
    test_loader, test_dataset, classes = load_test_data(DATA_DIR, SAVE_DIR)
    num_classes = len(classes)

    models_to_eval = [
        {
            'path': 'best_baseline_resnet18.pth',
            'model': create_resnet18_model(num_classes)
        },
        {
            'path': 'best_improved_resnet18plant.pth',
            'model': ResNet18Plant(num_classes)
        },
        {
            'path': 'best_cnn_model.pth',
            'model': PlantCNN(num_classes)
        },
        {
            'path': 'best_customcnn_distilled.pth',
            'model': PlantCNN(num_classes)
        },
        {
            'path': 'best_customcnn_distilled_real_baseline.pth',
            'model': PlantCNN(num_classes)
        }
    ]

    results = []
    for item in models_to_eval:
        full_path = os.path.join(SAVE_DIR, item['path'])
        if not os.path.isfile(full_path):
            print(f"Skipping {item['path']}: checkpoint not found at {full_path}")
            continue

        model = item['model']
        res = evaluate_model(full_path, model, test_loader, test_dataset, device, classes)
        # Record size alongside the metrics instead of re-matching by file name
        # later. Size uses each parameter's real element size rather than
        # assuming float32.
        params = list(model.parameters())
        res['num_params'] = sum(p.numel() for p in params)
        res['size_mb'] = sum(p.numel() * p.element_size() for p in params) / (1024 * 1024)
        results.append(res)

        plot_confusion_matrix(res['confusion_matrix'], classes, res['model_name'])
        plot_samples(model, test_loader, classes, res['model_name'], device)

    for res in results:
        print(f"Model: {res['model_name']}")
        print(f"Accuracy: {res['accuracy']*100:.2f}%")
        print(f"Precision: {res['precision']:.4f}")
        print(f"Recall: {res['recall']:.4f}")
        print(f"Latency (BS=1): {res['inference_time_ms']:.2f} ms")
        print(f"Params: {res['num_params'] / 1e6:.2f} M")
        print(f"Size: {res['size_mb']:.2f} MB")

# Run the evaluation when executed as a script or as the notebook cell itself.
if __name__ == "__main__":
    main()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using Colab cache for faster access to the 'plantdisease' dataset.


Inference: 100%|██████████| 65/65 [03:27<00:00,  3.19s/it]
Inference: 100%|██████████| 65/65 [03:22<00:00,  3.12s/it]
Inference: 100%|██████████| 65/65 [03:00<00:00,  2.77s/it]
Inference: 100%|██████████| 65/65 [02:30<00:00,  2.32s/it]
Inference: 100%|██████████| 65/65 [02:25<00:00,  2.24s/it]


Model: best_baseline_resnet18.pth
Accuracy: 99.95%
Precision: 0.9995
Recall: 0.9996
Latency (BS=1): 107.93 ms
Params: 11.18 M
Size: 42.66 MB
Model: best_improved_resnet18plant.pth
Accuracy: 99.85%
Precision: 0.9984
Recall: 0.9989
Latency (BS=1): 103.49 ms
Params: 11.44 M
Size: 43.65 MB
Model: best_cnn_model.pth
Accuracy: 98.40%
Precision: 0.9826
Recall: 0.9827
Latency (BS=1): 53.64 ms
Params: 0.46 M
Size: 1.75 MB
Model: best_customcnn_distilled.pth
Accuracy: 98.50%
Precision: 0.9832
Recall: 0.9726
Latency (BS=1): 64.16 ms
Params: 0.46 M
Size: 1.75 MB
Model: best_customcnn_distilled_real_baseline.pth
Accuracy: 98.25%
Precision: 0.9805
Recall: 0.9697
Latency (BS=1): 59.69 ms
Params: 0.46 M
Size: 1.75 MB
