In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from torchvision.datasets import ImageFolder
from torch.utils.data import Subset
import matplotlib.pyplot as plt
import numpy as np
import os
from pathlib import Path
from matplotlib.patches import Patch
from matplotlib.colors import ListedColormap, BoundaryNorm
import random
from torch.utils.data import Subset as TorchSubset

def set_seed(seed=42):
    """Seed every random number generator this file relies on.

    Applies ``seed`` to Python's ``random`` module, NumPy's global generator,
    and PyTorch (``torch.manual_seed`` plus ``torch.cuda.manual_seed_all``).
    Also sets ``torch.backends.cudnn.deterministic = True`` and
    ``torch.backends.cudnn.benchmark = False``.

    Args:
        seed: Integer seed passed unchanged to each seeding function.
    """
    # The cuDNN flags are independent of the seeding calls, so set them first.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Seed all RNGs once at import time so sampling in main() is repeatable.
set_seed(42)

# Mount Google Drive (Colab only); checkpoints, split indices and output
# figures are read from / written to SAVE_DIR by main().
from google.colab import drive
drive.mount('/content/drive', force_remount=False)
SAVE_DIR = '/content/drive/MyDrive/comp576final'

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fetch the dataset through kagglehub; `path` is presumably the local download
# directory (TODO confirm). main() reads images from its 'PlantVillage' subfolder.
import kagglehub
path = kagglehub.dataset_download("emmarex/plantdisease")
data_dir = os.path.join(path, 'PlantVillage')

def create_resnet18_baseline(num_classes, pretrained=True):
    """Build a torchvision ResNet-18 whose final layer emits ``num_classes`` logits.

    Args:
        num_classes: Output size of the replacement ``fc`` layer.
        pretrained: When true, initialise the network from the
            ``IMAGENET1K_V1`` weights; otherwise use random initialisation.

    Returns:
        The ResNet-18 module with ``fc`` swapped for a fresh ``nn.Linear``.
    """
    # Use the ``weights=`` keyword (as ResNet18Plant does) instead of the
    # deprecated ``pretrained=`` keyword, which triggers a deprecation warning
    # on current torchvision releases. ``pretrained=True`` historically mapped
    # to the IMAGENET1K_V1 weights, so the loaded parameters are unchanged.
    weights = "IMAGENET1K_V1" if pretrained else None
    model = models.resnet18(weights=weights)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

class ResNet18Plant(nn.Module):
    """ResNet-18 trunk followed by an MLP head over avg+max pooled features.

    The trunk is every child module of torchvision's ``resnet18`` except the
    last two. Its output feature map is reduced with both global average
    pooling and global max pooling; the two vectors are concatenated and fed
    to a two-layer classifier.

    Args:
        num_classes: Number of output logits.
        pretrained: When true, the trunk starts from the ``IMAGENET1K_V1``
            weights; otherwise it is randomly initialised.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        weights = "IMAGENET1K_V1" if pretrained else None
        resnet = models.resnet18(weights=weights)

        # Keep everything except the final two child modules of the backbone.
        trunk_layers = list(resnet.children())[:-2]
        self.features = nn.Sequential(*trunk_layers)

        feature_channels = 512
        # Average- and max-pooled vectors are concatenated, doubling the width.
        pooled_width = feature_channels * 2
        head_layers = [
            nn.Linear(pooled_width, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        fmap = self.features(x)
        batch = fmap.size(0)
        pooled = [
            F.adaptive_avg_pool2d(fmap, 1).view(batch, -1),
            F.adaptive_max_pool2d(fmap, 1).view(batch, -1),
        ]
        return self.classifier(torch.cat(pooled, dim=1))

class CustomCNN(nn.Module):
    """Small four-stage CNN classifier.

    Each stage is ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2)`` with output widths 32, 64, 128 and 256. The final feature
    map is globally average-pooled to a 256-vector and classified by a
    two-layer MLP with dropout.

    Args:
        num_classes: Number of output logits.
        img_size: Accepted for interface compatibility; not used by the
            network (the global pooling fixes the classifier input at 256).
    """

    def __init__(self, num_classes, img_size=224):
        super().__init__()
        channel_plan = (3, 32, 64, 128, 256)

        # Build the stages in order so module indices inside ``features``
        # (and therefore state_dict keys) run 0..15 as conv/bn/relu/pool x4.
        stage_layers = []
        for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:]):
            stage_layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ])
        self.features = nn.Sequential(*stage_layers)

        self.global_pool = nn.AdaptiveAvgPool2d(1)

        head_layers = [
            nn.Linear(256, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        pooled = self.global_pool(self.features(x))
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(flat)

def get_prediction_colormap(num_classes):
    """Return a discrete colormap and norm for integer class indices.

    Colours are taken from matplotlib's ``tab20`` colormap, cycling when
    ``num_classes`` exceeds the number of colours it provides. The norm uses
    bin edges at ``-0.5, 0.5, ..., num_classes - 0.5`` so each integer index
    falls in its own bin.

    Args:
        num_classes: Number of distinct classes to colour.

    Returns:
        A ``(ListedColormap, BoundaryNorm)`` tuple.
    """
    palette = plt.get_cmap('tab20')
    palette_size = palette.N
    class_colors = [palette(class_idx % palette_size) for class_idx in range(num_classes)]

    bin_edges = np.arange(-0.5, num_classes + 0.5, 1.0)
    return (
        ListedColormap(class_colors),
        BoundaryNorm(bin_edges, ncolors=num_classes),
    )

def denormalize_image(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo a per-channel ``(x - mean) / std`` normalisation.

    Computes ``tensor * std + mean`` with ``mean``/``std`` broadcast over the
    channel axis (third from last), so both a single image ``(C, H, W)`` and a
    batch ``(N, C, H, W)`` are handled. The previous per-channel loop iterated
    over the first axis and therefore applied the wrong statistics to batched
    input. The input tensor is left unmodified.

    Args:
        tensor: Normalised image tensor with channels on the third-from-last
            axis.
        mean: Per-channel means used by the original normalisation.
        std: Per-channel standard deviations used by the original
            normalisation.

    Returns:
        A new tensor of the same shape holding the de-normalised values.
    """
    # Integer inputs cannot hold the fractional statistics; compute in float.
    dtype = tensor.dtype if tensor.is_floating_point() else torch.get_default_dtype()
    # Shape (C, 1, 1) broadcasts against (..., C, H, W).
    mean_t = torch.as_tensor(mean, dtype=dtype, device=tensor.device).reshape(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=dtype, device=tensor.device).reshape(-1, 1, 1)
    return tensor.to(dtype) * std_t + mean_t

def _occlusion_positions(length, patch_size, stride):
    """Return patch start offsets along one axis so the whole axis is covered."""
    positions = list(range(0, length - patch_size + 1, stride))
    if not positions:
        # The patch is larger than the image along this axis; one patch at 0
        # (clipped to the image when applied) already covers everything.
        return [0]
    if positions[-1] + patch_size < length:
        # The regular stride left an uncovered strip at the far edge.
        positions.append(length - patch_size)
    return positions


def _to_display_image(image_chw):
    """Convert a normalised (C, H, W) tensor to an H x W x C array in [0, 1]."""
    array = denormalize_image(image_chw).cpu().numpy()
    return np.clip(np.transpose(array, (1, 2, 0)), 0.0, 1.0)


def occlusion_sensitivity_analysis(models_info, device, loader, output_dir, class_names,
                                   num_examples=3, patch_size=16, stride=8, baseline_value=0.0):
    """Plot, per sample, how occluding image patches lowers each model's confidence.

    For each of the first ``num_examples`` items from ``loader`` a square patch
    is slid over the image. The patch is filled with ``baseline_value`` and
    every model is re-evaluated. The drop in the softmax probability of the
    true label (unoccluded minus occluded) is drawn as a heat map over the
    image, one panel per model next to the original image. One PNG per sample
    is written to ``output_dir``.

    Args:
        models_info: Sequence of ``(model, model_name)`` pairs. The figure gets
            one panel per entry plus one for the original image.
        device: Device the models live on; images are moved there for inference.
        loader: Iterable yielding ``(data, target)``. ``target.item()`` is
            called, so each batch must contain exactly one sample.
        output_dir: Directory for the PNG files; created if missing.
        class_names: Sequence mapping class index to display name.
        num_examples: Stop collecting samples once this many have been read.
        patch_size: Side length of the occluding square, in pixels.
        stride: Step between consecutive patch positions, in pixels.
        baseline_value: Fill value for the patch. A tensor is reshaped to
            ``(-1, 1, 1)`` to act as a per-channel fill; anything else is
            assigned directly.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    models_info = list(models_info)

    samples = []
    for data, target in loader:
        samples.append((data, int(target.item())))
        if len(samples) >= num_examples:
            break

    # Loop-invariant: the patch fill does not depend on sample, model or position.
    if isinstance(baseline_value, torch.Tensor):
        patch_fill = baseline_value.reshape(-1, 1, 1)
    else:
        patch_fill = baseline_value

    # One column for the original image plus one per model. The previous
    # hard-coded 6 columns broke for any model count other than five.
    n_cols = len(models_info) + 1

    for sample_idx, (image_cpu, label_idx) in enumerate(samples):
        # Ensure image_cpu has a batch dimension.
        if image_cpu.dim() == 3:
            image_cpu = image_cpu.unsqueeze(0)
        image = image_cpu.to(device)

        _, _, height, width = image_cpu.shape
        y_positions = _occlusion_positions(height, patch_size, stride)
        x_positions = _occlusion_positions(width, patch_size, stride)

        # The display image is identical for every panel; compute it once.
        image_display = _to_display_image(image_cpu[0])

        # squeeze=False keeps ``axes`` indexable even with zero models.
        fig, axes = plt.subplots(1, n_cols, figsize=(5 * n_cols, 5), squeeze=False)
        axes = axes[0]

        ax_orig = axes[0]
        ax_orig.imshow(image_display)
        ax_orig.set_title(f'Original Image\nTrue Label: {class_names[label_idx]}',
                         fontsize=12, fontweight='bold')
        ax_orig.axis('off')

        for model_idx, (model, model_name) in enumerate(models_info):
            model.eval()
            with torch.no_grad():
                output_base = model(image)
                probas_base = F.softmax(output_base, dim=1)
                base_prob = probas_base[0, label_idx].item()
                base_pred = output_base.argmax(dim=1).item()
                pred_prob = probas_base[0, base_pred].item()

                prob_map = np.full((len(y_positions), len(x_positions)), np.nan,
                                   dtype=np.float32)
                for iy, y0 in enumerate(y_positions):
                    y_end = min(y0 + patch_size, height)
                    for ix, x0 in enumerate(x_positions):
                        x_end = min(x0 + patch_size, width)
                        occluded = image_cpu.clone()
                        occluded[:, :, y0:y_end, x0:x_end] = patch_fill
                        output_occ = model(occluded.to(device))
                        probas_occ = F.softmax(output_occ, dim=1)
                        prob_map[iy, ix] = probas_occ[0, label_idx].item()

            ax = axes[model_idx + 1]
            prob_drop = base_prob - prob_map
            max_drop = np.nanmax(prob_drop)
            ax.imshow(image_display)
            im_overlay = ax.imshow(prob_drop, cmap='hot', origin='upper',
                                  extent=[0, width, height, 0], interpolation='bilinear',
                                  vmin=0, vmax=1.0, alpha=0.6)
            correct_mark = "correct" if base_pred == label_idx else "incorrect"
            # Report the predicted class with its own probability, and the
            # true-class probability (the quantity the heat map is measured
            # against) separately. Previously the true-class probability was
            # printed next to the predicted class name, which was misleading
            # whenever the prediction was wrong.
            ax.set_title(f'{model_name} {correct_mark}\n'
                         f'Pred: {class_names[base_pred]} ({pred_prob:.1%})\n'
                         f'True-class prob: {base_prob:.1%} | Max Drop: {max_drop:.4f}',
                         fontsize=10, fontweight='bold')
            ax.axis('off')

            plt.colorbar(im_overlay, ax=ax, fraction=0.046, pad=0.04, label='Prob. Drop')

        fig.suptitle(f'Occlusion Sensitivity Comparison - Sample {sample_idx + 1}\n'
                    f'Patch: {patch_size}*{patch_size}, Stride: {stride}',
                    fontsize=14, fontweight='bold')
        fig.tight_layout(rect=[0, 0.01, 1, 0.96])
        output_path = output_dir / f'occlusion_patch{patch_size}_sample_{sample_idx + 1:02d}.png'
        fig.savefig(output_path, dpi=200, bbox_inches='tight')
        plt.close(fig)


def main():
    """Run the occlusion-sensitivity comparison over all saved checkpoints.

    Loads the image folder at ``data_dir``, restricts it to the stored test
    indices (``data_split_indices.pth`` under ``SAVE_DIR``), randomly picks
    ten test images, loads the five trained models from ``SAVE_DIR`` and
    passes everything to ``occlusion_sensitivity_analysis``. Figures are
    written to ``SAVE_DIR/occlusion_sensitivity``.
    """
    raw_dataset = ImageFolder(root=data_dir)
    class_names = raw_dataset.classes
    num_classes = len(class_names)

    split_indices = torch.load(os.path.join(SAVE_DIR, 'data_split_indices.pth'))
    test_subset = Subset(raw_dataset, split_indices['test_indices'])

    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    output_dir = Path(SAVE_DIR) / 'occlusion_sensitivity'
    output_dir.mkdir(exist_ok=True)

    num_samples, patch_size, stride = 10, 56, 28

    # Per-channel fill for occluded patches: standard-normal noise scaled by 2.
    # Drawn before any model is built so the torch RNG sequence is unaffected
    # by model initialisation.
    baseline_tensor = torch.randn(3, 1, 1) * 2.0

    picked = np.random.choice(len(test_subset), num_samples, replace=False)
    selected_samples = TorchSubset(test_subset, picked.tolist())

    class TransformedSubset:
        """Apply ``transform`` to the image of each item of ``subset``."""

        def __init__(self, subset, transform):
            self.subset = subset
            self.transform = transform

        def __len__(self):
            return len(self.subset)

        def __getitem__(self, idx):
            raw_image, target = self.subset[idx]
            return self.transform(raw_image), target

    sample_loader = torch.utils.data.DataLoader(
        TransformedSubset(selected_samples, preprocess),
        batch_size=1,
        shuffle=False,
    )

    # (display name, constructor, extra constructor kwargs, checkpoint file).
    # The ResNet builders take ``pretrained=False`` because their weights come
    # from the checkpoint; CustomCNN takes no extra arguments.
    no_pretrain = {'pretrained': False}
    model_specs = [
        ('ResNet18 Baseline', create_resnet18_baseline, no_pretrain,
         'best_baseline_resnet18.pth'),
        ('ResNet18 Improved', ResNet18Plant, no_pretrain,
         'best_improved_resnet18plant.pth'),
        ('CNN Scratch', CustomCNN, {},
         'best_cnn_model.pth'),
        ('CNN Distilled (Baseline)', CustomCNN, {},
         'best_customcnn_distilled_real_baseline.pth'),
        ('CNN Distilled (Improved)', CustomCNN, {},
         'best_customcnn_distilled.pth'),
    ]

    models_info = []
    for model_name, build_model, build_kwargs, checkpoint_name in model_specs:
        model = build_model(num_classes, **build_kwargs)
        state_dict = torch.load(os.path.join(SAVE_DIR, checkpoint_name),
                                map_location=device)
        model.load_state_dict(state_dict)
        model = model.to(device)
        model.eval()
        models_info.append((model, model_name))

    occlusion_sensitivity_analysis(
        models_info=models_info,
        device=device,
        loader=sample_loader,
        output_dir=output_dir,
        class_names=class_names,
        num_examples=num_samples,
        patch_size=patch_size,
        stride=stride,
        baseline_value=baseline_tensor,
    )


# Run the analysis only when this file is executed directly, not when imported.
if __name__ == '__main__':
    main()


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using Colab cache for faster access to the 'plantdisease' dataset.


