Non-Uniform Illumination Robustness — Caltech-256 Experiment

In [None]:
import os, random, numpy as np, torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt
from tqdm import tqdm

Setup

In [None]:
# Seed the stdlib, NumPy and PyTorch generators with the same value so that
# sampling, shuffling and weight initialisation repeat from run to run.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Prefer the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

1. Load Caltech-256 (Resized to 64×64) - FIXED

In [None]:
# Directory handed to the dataset loader as its root.
data_root = "./Caltech256"


def ensure_3_channel(img):
    """Return ``img`` with three channels when it is grayscale.

    Args:
        img: Either a ``torch.Tensor`` whose first axis is the channel axis,
            or an object with PIL's ``mode`` attribute and ``convert`` method.

    Returns:
        For tensors: the single channel repeated three times when the first
        axis has length 1, otherwise the tensor itself, untouched.
        For PIL-style images: the image itself when its mode is ``'RGB'``,
        otherwise the result of ``convert('RGB')``.
    """
    if not isinstance(img, torch.Tensor):
        # PIL path: any non-RGB mode is converted.
        return img if img.mode == 'RGB' else img.convert('RGB')

    is_single_channel = img.shape[0] == 1
    return img.repeat(3, 1, 1) if is_single_channel else img

# Preprocessing for images loaded from disk: resize to 64x64, force RGB, then
# convert to a tensor with ToTensor.
transform = T.Compose([
    T.Resize((64, 64)),
    T.Lambda(ensure_3_channel),  # Ensure 3 channels
    T.ToTensor(),
])

try:
    dataset_full = torchvision.datasets.Caltech256(
        root=data_root,
        download=True,
        transform=transform
    )
    print(" Total Caltech-256 samples:", len(dataset_full))
except Exception as e:
    # Deliberately broad: any download/extraction/IO failure should drop to
    # the demo fallback instead of aborting the notebook.
    print(f" Error loading dataset: {e}")
    print(" Creating a dummy dataset for demonstration...")
    # Create dummy dataset as fallback
    from torch.utils.data import Dataset

    class DummyCaltech256(Dataset):
        """Stand-in dataset of random 3x64x64 tensors with fixed labels.

        Each index always yields the same image and the same label, so the
        class filtering and label remapping done later see consistent data.
        Pixels are noise unrelated to the label, so this only exercises the
        pipeline; accuracies measured on it are meaningless.
        """

        def __init__(self, size=1000, transform=None, num_classes=10):
            self.size = size
            self.transform = transform
            self.num_classes = num_classes

        def __len__(self):
            return self.size

        def __getitem__(self, idx):
            # A per-index generator keeps the sample stable across accesses
            # and leaves the global RNG state alone.
            gen = torch.Generator().manual_seed(int(idx))
            img = torch.rand(3, 64, 64, generator=gen)  # Always 3 channels
            label = int(idx) % self.num_classes
            if self.transform:
                img = self.transform(img)
            return img, label

    # The samples are already 3x64x64 float tensors. The PIL-oriented
    # `transform` above ends in ToTensor, which rejects tensor input, so it
    # must not be applied here.
    dataset_full = DummyCaltech256(size=2000, transform=None)
    print(" Created dummy dataset with", len(dataset_full), "samples")


2. Strong NUI Mask Generator (High Penetration)

In [None]:
def generate_strong_nui_mask(h, w, strength=1.0, exponent=2.0):
    """Build a random multiplicative non-uniform-illumination (NUI) mask.

    The mask blends a linear gradient along a random direction (weight 0.6)
    with a radial falloff around a random centre (weight 0.4). The blend is
    rescaled to [0, 1], mapped to ``1 +/- strength / 2`` and clipped to
    [0.1, 1.9], so a strength above 1.8 saturates at the clip limits.

    Randomness comes from NumPy's global RNG (one draw for the angle, then
    two for the centre).

    Args:
        h: Mask height in pixels (>= 1).
        w: Mask width in pixels (>= 1).
        strength: Peak-to-peak amplitude of the mask around 1.0.
        exponent: Power applied to the clipped radial distance.

    Returns:
        ``np.float32`` array of shape ``(h, w)``.

    Raises:
        ValueError: If ``h`` or ``w`` is smaller than 1.
    """
    if h < 1 or w < 1:
        raise ValueError(f"mask size must be at least 1x1, got {h}x{w}")

    yy, xx = np.meshgrid(np.linspace(-1, 1, h), np.linspace(-1, 1, w), indexing='ij')

    # Linear ramp along a random direction.
    angle = np.random.uniform(0, 2*np.pi)
    direction = np.cos(angle)*xx + np.sin(angle)*yy

    # Radial falloff around a random centre inside the middle of the frame.
    cx, cy = np.random.uniform(-0.5, 0.5, 2)
    r = np.sqrt((xx - cx)**2 + (yy - cy)**2)
    radial = 1 - np.clip(r, 0, 1)**exponent

    blend = 0.6*direction + 0.4*radial
    lowest = blend.min()
    span = blend.max() - lowest
    if span > 0:
        normalized = (blend - lowest) / span
    else:
        # A constant blend (e.g. a 1x1 mask) has no contrast to rescale;
        # use the midpoint so the mask is neutral (1.0), not 0/0 = NaN.
        normalized = np.full_like(blend, 0.5)

    mask = 1 + strength * (normalized - 0.5)
    return np.clip(mask, 0.1, 1.9).astype(np.float32)

def apply_mask_to_tensor(img_tensor, mask):
    """Multiply an image tensor by an illumination mask.

    The mask is resized with bilinear interpolation to the image's last two
    dimensions and then broadcast over the remaining dimensions. The result
    is not clamped, so values may leave the input range.

    Args:
        img_tensor: Image tensor whose last two dimensions are spatial,
            e.g. ``(C, H, W)``.
        mask: 2-D array-like (NumPy array or tensor) of multiplicative gains.
            Integer masks are converted to float before resizing.

    Returns:
        A new tensor on ``img_tensor``'s device; the input is not modified.
    """
    # Create the mask on the image's device: a CPU mask multiplied with a
    # CUDA image raises a device-mismatch error.
    mask_tensor = torch.as_tensor(mask, device=img_tensor.device)
    if not mask_tensor.is_floating_point():
        # Bilinear interpolation is only implemented for floating dtypes.
        mask_tensor = mask_tensor.float()

    # interpolate() expects (N, C, H, W); pad leading singleton axes.
    while mask_tensor.dim() < 4:
        mask_tensor = mask_tensor.unsqueeze(0)

    mask_tensor = torch.nn.functional.interpolate(
        mask_tensor,
        size=tuple(img_tensor.shape[-2:]),
        mode='bilinear',
        align_corners=False
    ).squeeze(0)
    return img_tensor * mask_tensor

3. Use Subset of Classes for Speed - FIXED

In [None]:
# Number of classes kept for the experiment.
num_classes_to_use = 10


class ConsistentChannelDataset(torch.utils.data.Dataset):
    """View of a dataset restricted to chosen classes and 3-channel images.

    Construction reads every sample of ``original_dataset`` once and keeps
    the positions whose label is in ``class_indices`` and whose image has a
    first dimension of 3. Samples that raise while being read are skipped
    and counted in ``num_skipped``.

    Labels are remapped to ``0 .. len(class_indices) - 1`` following the
    order of ``class_indices``.

    Args:
        original_dataset: Indexable dataset with ``__len__`` that yields
            ``(image, label)`` pairs; labels must be hashable.
        class_indices: Original labels to keep.
    """

    def __init__(self, original_dataset, class_indices):
        self.original_dataset = original_dataset
        self.class_indices = class_indices
        self.indices = []
        self.labels_map = {cls_idx: i for i, cls_idx in enumerate(class_indices)}
        self.num_skipped = 0

        # Collect indices of samples with 3 channels
        for i in range(len(original_dataset)):
            try:
                img, label = original_dataset[i]
                # Dict lookup gives O(1) membership instead of scanning the list.
                keep = label in self.labels_map and img.shape[0] == 3
            except Exception:
                # Best effort: an unreadable sample is dropped. Catching
                # Exception rather than everything lets Ctrl+C interrupt
                # this potentially long scan.
                self.num_skipped += 1
                continue
            if keep:
                self.indices.append(i)

        if self.num_skipped:
            print(f" Skipped {self.num_skipped} unreadable samples while indexing")

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        original_idx = self.indices[idx]
        img, label = self.original_dataset[original_idx]
        # Remap label to 0-based indexing
        new_label = self.labels_map[label]
        return img, new_label

try:
    # Discover which labels exist by probing roughly 1000 samples spread
    # evenly over the dataset. Probing only the first 1000 would see just the
    # leading classes if samples are stored class by class (torchvision's
    # Caltech256 appears to enumerate them that way - verify).
    total_samples = len(dataset_full)
    probe_step = max(1, total_samples // 1000)
    all_labels = []
    for i in range(0, total_samples, probe_step):
        try:
            _, label = dataset_full[i]
            all_labels.append(label)
        except Exception:
            # Best effort: unreadable samples are ignored while probing.
            continue

    # De-duplicate in first-seen order so the seeded sampling below is
    # reproducible.
    unique_labels = list(dict.fromkeys(all_labels))
    print(f" Found {len(unique_labels)} unique labels in dataset sample")
    if not unique_labels:
        # Route to the default-range fallback instead of selecting 0 classes.
        raise ValueError("no readable samples found while probing labels")

    # Select classes to use
    class_indices = random.sample(unique_labels, min(num_classes_to_use, len(unique_labels)))
    print(" Selected class indices:", class_indices)

except Exception as e:
    print(f"  Error analyzing dataset: {e}")
    print(" Using default class range...")
    class_indices = list(range(num_classes_to_use))

# Create consistent dataset
consistent_dataset = ConsistentChannelDataset(dataset_full, class_indices)
print(f" Found {len(consistent_dataset)} consistent samples in selected classes")

# Number of output classes the model must have; overridden below if the
# synthetic fallback replaces the real data.
num_classes = len(class_indices)

# Ensure we have enough samples
if len(consistent_dataset) < 100:
    print("  Not enough consistent samples, creating synthetic dataset...")
    # Fallback to synthetic data
    from torch.utils.data import Dataset

    class SyntheticCaltech256(Dataset):
        """Random 3x64x64 tensors with a fixed label per index (demo only)."""

        def __init__(self, size=1000, num_classes=10):
            self.size = size
            self.num_classes = num_classes

        def __len__(self):
            return self.size

        def __getitem__(self, idx):
            # Per-index generator: the same index always returns the same
            # image and label, and the global RNG state is left alone.
            gen = torch.Generator().manual_seed(int(idx))
            img = torch.rand(3, 64, 64, generator=gen)  # Always 3 channels
            label = int(idx) % self.num_classes
            return img, label

    consistent_dataset = SyntheticCaltech256(size=1000, num_classes=num_classes_to_use)
    # Synthetic labels span 0..num_classes_to_use-1, so the model head must
    # match that, not the number of classes found in the real data.
    num_classes = num_classes_to_use
    print(" Created synthetic dataset with 1000 samples")

# Use subset for speed. Sample across the whole dataset: taking the first
# 1000 positions would keep only the leading classes when samples are
# grouped by class.
subset_size = min(1000, len(consistent_dataset))
subset_indices = random.sample(range(len(consistent_dataset)), subset_size)

dataset = Subset(consistent_dataset, subset_indices)

# Split dataset
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
trainset, testset = torch.utils.data.random_split(dataset, [train_size, test_size])

trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=0)
testloader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=0)

print(f" Train: {len(trainset)} | Test: {len(testset)} | Classes: {num_classes}")


4. CNN Model (SqueezeNet 1.1)

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

class SqueezeNetCustom(nn.Module):
    """SqueezeNet 1.1 with its classification head resized to ``num_classes``.

    The network starts from random initialisation; no pretrained weights are
    loaded.

    Args:
        num_classes: Number of output classes (channels of the final 1x1
            convolution).
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # Build the architecture without the `pretrained` keyword, which
        # newer torchvision releases deprecate in favour of `weights`. The
        # bare call loads no weights under either API.
        self.model = models.squeezenet1_1()

        # Replace the final classifier
        self.model.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=1)
        # NOTE(review): some torchvision versions appear to read this
        # attribute in forward(); kept in sync with the new head - confirm.
        self.model.num_classes = num_classes

    def forward(self, x):
        """Return the wrapped SqueezeNet's output for batch ``x``."""
        return self.model(x)


5. Training & Evaluation

In [None]:
def train_model(model, loader, optimizer, criterion, epochs=3, apply_nui=False):
    """Train ``model`` in place, optionally with random NUI augmentation.

    Args:
        model: Network to optimise; switched to training mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimiser that updates the model's parameters.
        criterion: Loss taking ``(outputs, labels)``.
        epochs: Number of passes over ``loader``.
        apply_nui: When true, each image independently gets a random
            illumination mask with probability 0.7 (strength drawn from
            [2, 4), exponent from [1, 3)).

    A batch that raises is reported and skipped; the average loss of the
    batches that succeeded is printed after every epoch.
    """
    model.train()
    for epoch_no in range(1, epochs + 1):
        loss_sum = 0
        batches_done = 0
        progress = tqdm(loader, desc=f"Epoch {epoch_no}/{epochs}", leave=False)

        for batch_idx, (imgs, labels) in enumerate(progress):
            try:
                if apply_nui:
                    augmented = []
                    for sample in imgs:
                        # 70% chance to apply NUI
                        if np.random.rand() < 0.7:
                            # Draw order (strength, then exponent) matters
                            # for reproducibility under a fixed seed.
                            nui_strength = np.random.uniform(2.0, 4.0)
                            nui_exponent = np.random.uniform(1.0, 3.0)
                            illumination = generate_strong_nui_mask(
                                sample.shape[1],
                                sample.shape[2],
                                strength=nui_strength,
                                exponent=nui_exponent,
                            )
                            sample = apply_mask_to_tensor(sample, illumination)
                        augmented.append(sample)
                    imgs = torch.stack(augmented)

                imgs = imgs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                loss = criterion(model(imgs), labels)
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                loss_sum += batch_loss
                batches_done += 1
                progress.set_postfix({'loss': f'{batch_loss:.4f}'})
            except Exception as e:
                print(f"  Skipping batch {batch_idx} due to error: {e}")
                continue

        if batches_done == 0:
            print(f" Epoch {epoch_no}: No valid batches processed")
            continue
        avg_loss = loss_sum / batches_done
        print(f" Epoch {epoch_no}: Avg Loss = {avg_loss:.4f}")

def evaluate(model, loader, apply_nui=False):
    """Return the top-1 accuracy of ``model`` over ``loader``.

    Args:
        model: Network to evaluate; switched to eval mode and left there.
        loader: Iterable yielding ``(images, labels)`` batches.
        apply_nui: When true, every image is multiplied by a fresh random
            illumination mask (strength 3.5, exponent 2.0) before inference.

    Returns:
        Fraction of correctly classified samples, or 0 when no batch could be
        evaluated. Batches that raise are reported and skipped.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for batch_idx, (imgs, labels) in enumerate(loader):
            try:
                if apply_nui:
                    imgs = torch.stack([
                        apply_mask_to_tensor(
                            sample,
                            generate_strong_nui_mask(
                                sample.shape[1], sample.shape[2], strength=3.5, exponent=2.0
                            ),
                        )
                        for sample in imgs
                    ])

                imgs = imgs.to(device)
                labels = labels.to(device)
                # max(1) returns (values, indices); the indices are the
                # predicted classes.
                predicted = model(imgs).max(1)[1]
                seen += labels.size(0)
                hits += predicted.eq(labels).sum().item()
            except Exception as e:
                print(f" Skipping evaluation batch {batch_idx} due to error: {e}")
                continue

    if seen > 0:
        return hits / seen
    return 0

6. Visualization of NUI Effect

In [None]:

def visualize_nui_effect(num_images=4):
    """Plot test images next to their illumination mask and masked version.

    Draws one row per image with three panels: the original, the mask as a
    heat map on a 0-2 scale, and the image after the mask is applied (both
    images clamped to [0, 1] for display). Any failure is reported with a
    printed message instead of being raised.

    Args:
        num_images: Number of test images (rows) to draw.
    """
    try:
        sample_loader = DataLoader(testset, batch_size=num_images, shuffle=True, num_workers=0)
        imgs, labels = next(iter(sample_loader))

        # squeeze=False keeps `axes` two-dimensional even for a single row.
        fig, axes = plt.subplots(
            num_images, 3, figsize=(9, 3*num_images), squeeze=False
        )

        for row in range(num_images):
            original = imgs[row]
            mask = generate_strong_nui_mask(
                original.shape[1], original.shape[2], strength=3.5, exponent=2.0
            )
            shaded = apply_mask_to_tensor(original, mask)

            # (title, data, extra imshow options) for the three panels;
            # images go channel-last and are clamped for display.
            panels = (
                ("Original", torch.clamp(original.permute(1, 2, 0), 0, 1), {}),
                ("Illumination Mask", mask, {"cmap": "hot", "vmin": 0, "vmax": 2}),
                ("With NUI", torch.clamp(shaded.permute(1, 2, 0), 0, 1), {}),
            )
            for ax, (title, data, options) in zip(axes[row], panels):
                ax.imshow(data, **options)
                ax.set_title(title)
                ax.axis("off")

        plt.tight_layout()
        plt.show()
    except Exception as e:
        print(f" Visualization error: {e}")

# Draw the original / mask / masked comparison for a few test images using
# the function's default of four rows.
print("\nVisualizing NUI Effect...")
visualize_nui_effect()


7. Baseline vs NUI-Augmented Training

In [None]:
def apply_nui_to_img(img_tensor):
    """Return ``img_tensor`` multiplied by a fresh random illumination mask.

    The mask is generated at the image's own resolution (dimensions 1 and 2
    of ``img_tensor``) with the generator's default strength and exponent.
    """
    dims = img_tensor.shape
    illumination = generate_strong_nui_mask(dims[1], dims[2])
    return apply_mask_to_tensor(img_tensor, illumination)

def train_model_mixed(model, trainloader, optimizer, criterion, epochs, nui_ratio=0.2):
    """Train ``model`` on batches in which a random share of images get NUI.

    Args:
        model: Network to optimise; switched to training mode.
        trainloader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimiser that updates ``model``'s parameters. It must be
            built from this model, or no weights will change.
        criterion: Loss taking ``(outputs, labels)``.
        epochs: Number of passes over ``trainloader``.
        nui_ratio: Probability that any given image is augmented.

    The average loss is printed after each epoch.
    """
    model.train()
    for epoch in range(epochs):
        total_loss, total_batches = 0.0, 0
        for imgs, labels in trainloader:
            # Pick the images to augment (each with probability nui_ratio).
            use_nui = torch.rand(len(imgs)) < nui_ratio

            # Augment BEFORE moving to the device: the masks are built on the
            # CPU, so multiplying them with CUDA tensors would raise a
            # device-mismatch error. Assumes the loader yields CPU tensors.
            imgs_aug = torch.stack([
                apply_nui_to_img(img) if flag else img
                for img, flag in zip(imgs, use_nui)
            ])

            imgs_aug = imgs_aug.to(device)
            labels = labels.to(device)

            # Forward, backward, optimize
            optimizer.zero_grad()
            outputs = model(imgs_aug)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_batches += 1

        if total_batches > 0:
            print(f" Epoch {epoch+1}: Avg Loss = {total_loss / total_batches:.4f}")
        else:
            print(f" Epoch {epoch+1}: No valid batches processed")


In [None]:
criterion = nn.CrossEntropyLoss()

print("\n" + "="*50)
print(" STARTING EXPERIMENT")
print("="*50)

# --- Baseline Model ---
# Trained on unmodified images only.
print("\n=== Training Baseline Model (Clean) ===")
model_clean = SqueezeNetCustom(num_classes).to(device)
optimizer_clean = optim.Adam(model_clean.parameters(), lr=0.001)
train_model(model_clean, trainloader, optimizer_clean, criterion, epochs=3, apply_nui=False)

# Evaluate baseline on the clean test set and on the NUI-distorted test set.
acc_clean_clean = evaluate(model_clean, testloader, apply_nui=False)
acc_clean_nui = evaluate(model_clean, testloader, apply_nui=True)

print(f" Baseline Model - Clean: {acc_clean_clean*100:.2f}% | NUI: {acc_clean_nui*100:.2f}%")

# --- NUI-Augmented Model ---
# Same architecture and schedule, but 80% of the training images get NUI.
print("\n=== Training Robust Model (NUI-Augmented) ===")
model_nui = SqueezeNetCustom(num_classes).to(device)
optimizer_nui = optim.Adam(model_nui.parameters(), lr=0.001)
# Must use optimizer_nui: optimizer_clean holds model_clean's parameters, so
# passing it here would leave model_nui at its random initialisation.
train_model_mixed(model_nui, trainloader, optimizer_nui, criterion, epochs=3, nui_ratio=0.8)

# Evaluate robust model
acc_nui_clean = evaluate(model_nui, testloader, apply_nui=False)
acc_nui_nui = evaluate(model_nui, testloader, apply_nui=True)

print(f"Robust Model - Clean: {acc_nui_clean*100:.2f}% | NUI: {acc_nui_nui*100:.2f}%")


## Results Summary

In [None]:

# Accuracy drop = clean accuracy minus NUI accuracy, in percentage points.
# The sign is kept on purpose: a negative value means the model scored
# HIGHER under NUI, which abs() would have reported as a drop.
drop_before = (acc_clean_clean - acc_clean_nui) * 100
drop_after = (acc_nui_clean - acc_nui_nui) * 100
# Positive when the augmented model loses less accuracy under NUI.
improvement = drop_before - drop_after

print("\n" + "="*50)
print(" FINAL RESULTS (Caltech-256 Subset)")
print("="*50)

print("\n Baseline Model (Trained on Clean Data):")
print(f"   Clean Test Accuracy: {acc_clean_clean*100:.2f}%")
print(f"   NUI Test Accuracy:   {acc_clean_nui*100:.2f}%")
print(f"   Accuracy Drop:       {drop_before:.2f}%")

print("\n Robust Model (NUI-Augmented Training):")
print(f"   Clean Test Accuracy: {acc_nui_clean*100:.2f}%")
print(f"   NUI Test Accuracy:   {acc_nui_nui*100:.2f}%")
print(f"   Accuracy Drop:       {drop_after:.2f}%")

print("\n Improvement Summary:")
print(f"   NUI Robustness Improvement: {improvement:.2f}%")

if improvement > 0:
    print("   NUI-Augmented training successfully improved robustness!")
else:
    print("    No significant improvement detected.")