# **🫁 Lung Cancer Image Classification - Preprocessing Pipeline**

## 📋 Project Overview
**Goal:** Build a deep learning model to classify lung CT scan images as **Normal** or **Malignant** (cancerous)

**This Notebook Covers:**
- ✅ Data loading and preprocessing
- ✅ Image enhancement using CLAHE
- ✅ Data augmentation strategies
- ✅ Train/validation/test datasets (loaded from pre-split folders)
- ✅ Data visualization
- ✅ Baseline CNN training and test-set evaluation

---

## 🎯 Why Does Preprocessing Matter?
Medical images often have:
- Low contrast (hard to see differences)
- Varying sizes and orientations
- Different brightness levels

Proper preprocessing helps the model learn better patterns!

## 📦 Import Libraries

**What each library does:**

- **`torch`** - PyTorch deep learning framework
- **`torchvision`** - Image processing tools for PyTorch (datasets, transforms, models)
- **`cv2` (OpenCV)** - Computer vision library (we use it for CLAHE enhancement)
- **`PIL` (Python Imaging Library)** - Basic image loading and manipulation
- **`numpy`** - Numerical operations on arrays
- **`matplotlib/seaborn`** - Data visualization
- **`tqdm`** - Progress bars for loops (makes waiting less boring!)
- **`pandas`** - Data analysis (if needed)

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tqdm import tqdm
import cv2
from PIL import Image

## ⚙️ Configuration Constants

**Why these specific values?**

- **`DATA_DIR`** - Path to your organized dataset folders (train/val/test)
  
- **`BATCH_SIZE = 32`** - Number of images processed together
  - **Why 32?** Good balance between:
    - **Memory usage** (32 images fit in most GPUs)
    - **Training speed** (processes 32 at once, faster than 1 at a time)
    - **Gradient stability** (averages over 32 samples reduces noise)
  - Common choices: 16, 32, 64, 128
  
- **`IMAGE_SIZE = 224`** - Standard input size for most pretrained models
  - **Why 224?** Most ImageNet-pretrained models (ResNet, VGG, EfficientNet-B0) were trained on 224×224 images
  - Using the standard size makes transfer learning easier later!

In [None]:
# CONSTANTS

# DATA_DIR = "F:/Machine Learning/PyTorch/Lung_Cancer/Final_Split_Data"
DATA_DIR = "/content/drive/MyDrive/Final_Split_Data"
BATCH_SIZE = 32
IMAGE_SIZE = 224 # 224x224 image pixels


In [None]:
DATA_DIR

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
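import time

print(f"🖥️  Device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
print(f"⚡ CUDA Available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"📊 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    # Quick speed test: time one 1000x1000 matrix multiply
    x = torch.randn(1000, 1000, device="cuda")
    _ = x @ x                  # warm-up so CUDA/cuBLAS initialization isn't timed
    torch.cuda.synchronize()
    start = time.time()
    y = x @ x
    torch.cuda.synchronize()   # wait for the GPU to finish before stopping the clock
    print(f"⏱️  GPU Speed Test: {(time.time()-start)*1000:.2f}ms")
else:
    print("⚠️  No GPU found; skipping GPU memory and speed checks")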

## 🔍 Custom CLAHE Transform

### What is CLAHE?
**CLAHE** = Contrast Limited Adaptive Histogram Equalization

### Why do we need it?
- Medical images (CT scans) often have **low contrast**
- Hard to see subtle differences between normal and cancerous tissue
- CLAHE **enhances local contrast** without over-amplifying noise

### How CLAHE works:
1. **Divides image into small tiles** (8×8 grid)
2. **Applies histogram equalization to each tile separately** (enhances local details)
3. **Limits contrast amplification** (`clip_limit=2.0` clips each tile's histogram before equalizing, so noise isn't blown up)
4. **Blends tile boundaries smoothly** (bilinear interpolation between neighboring tiles avoids a checkerboard effect)
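To make steps 2 and 3 concrete, here is a minimal NumPy sketch of clip-limited equalization for a *single* tile. It is a simplified illustration, not OpenCV's implementation: `clip_limited_equalize` is a hypothetical helper, the clip threshold is assumed to be `clip_limit × (average count per bin)` (close to how OpenCV scales `clipLimit`), and the bilinear blending between tiles (step 4) is left out.

```python
import numpy as np

def clip_limited_equalize(tile, clip_limit=2.0, n_bins=256):
    """Equalize one uint8 tile using a clipped histogram (single tile, no blending)."""
    hist = np.bincount(tile.ravel(), minlength=n_bins).astype(np.float64)

    # Step 3: cap every bin at clip_limit times the average bin height...
    limit = max(clip_limit * tile.size / n_bins, 1.0)
    excess = np.sum(np.maximum(hist - limit, 0.0))
    hist = np.minimum(hist, limit)
    # ...and spread the clipped-off counts evenly over all bins
    hist += excess / n_bins

    # Step 2: the cumulative histogram becomes the intensity mapping (lookup table)
    cdf = np.cumsum(hist)
    lut = np.round(cdf / cdf[-1] * (n_bins - 1)).astype(np.uint8)
    return lut[tile]

# A low-contrast 10x10 tile: intensities only span 100..109
tile = np.repeat(np.arange(100, 110, dtype=np.uint8), 10).reshape(10, 10)
mild = clip_limited_equalize(tile, clip_limit=2.0)
full = clip_limited_equalize(tile, clip_limit=1000.0)  # limit never reached = plain equalization

# Clipping still stretches the contrast, but far less than plain equalization
print(np.ptp(tile), np.ptp(mild), np.ptp(full))
```

The comparison shows why the clip limit matters: plain equalization would stretch this nearly flat tile across almost the whole 0-255 range, which in a real scan also amplifies noise.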

### Key Parameters:
- **`clip_limit=2.0`** - Controls maximum contrast enhancement
  - Lower = less enhancement (0.5-1.0 for subtle)
  - Higher = more enhancement (2.0-4.0 for aggressive)
  - We use 2.0 as a balanced middle ground
  
- **`tile_grid_size=(8, 8)`** - Divides the image into an 8×8 grid = 64 tiles (the grid size counts tiles, not pixels)
  - Fewer, larger tiles (4×4 grid) = more global enhancement
  - More, smaller tiles (16×16 grid) = more local enhancement
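Because `tile_grid_size` counts tiles, the pixel size of each tile depends on the image size. In this pipeline CLAHE runs *before* `Resize`, so the original scan resolution is what counts. The sketch below assumes a 512×512 CT slice (an illustrative size, not taken from the dataset) and uses a hypothetical helper `tile_pixels`; OpenCV pads images whose size isn't divisible by the grid, so the numbers are approximate in general.

```python
def tile_pixels(image_size, grid):
    """Approximate pixel size (h, w) of each CLAHE tile for a given grid of tiles."""
    h, w = image_size
    rows, cols = grid
    return h // rows, w // cols

for grid in [(4, 4), (8, 8), (16, 16)]:
    print(grid, "->", tile_pixels((512, 512), grid))
# (4, 4) -> (128, 128)
# (8, 8) -> (64, 64)
# (16, 16) -> (32, 32)
```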

### Why LAB Color Space for RGB images?
- LAB separates **luminance (L)** from **color (A, B)**
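- We only enhance the luminance channel → preserves original colors
- Converts: RGB → LAB → Enhance L → RGB
- In this notebook the images are converted to grayscale before CLAHE, so only the grayscale branch of `ApplyCLAHE` actually runs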

In [None]:
class ApplyCLAHE:
    def __init__(self, clip_limit=2.0, tile_grid_size=(8, 8)):
        self.clip_limit = clip_limit
        self.tile_grid_size = tile_grid_size

    def __call__(self, img):

        # convert PIL image to numpy array
        img_np = np.array(img)

        # create the CLAHE operator (expects an 8-bit or 16-bit single-channel array)
        clahe = cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=self.tile_grid_size)

        # if gray scale
        if len(img_np.shape) == 2:
            img_clahe = clahe.apply(img_np)

        # if RGB, apply CLAHE only to the L (lightness) channel in LAB space
        else:
            img_clahe = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
            img_clahe[:, :, 0] = clahe.apply(img_clahe[:, :, 0])
            img_clahe = cv2.cvtColor(img_clahe, cv2.COLOR_LAB2RGB)

        return Image.fromarray(img_clahe)

## 🔄 Training Data Transforms (with Augmentation)

### Transform Pipeline Explained:

**1. `Grayscale(num_output_channels=1)`** - Convert to grayscale
   - **Why?** Lung CT scans don't need color, tissue structure matters more
   - Reduces data from 3 channels (RGB) to 1 channel

**2. `ApplyCLAHE(clip_limit=2.0)`** - Enhance contrast
   - Makes tissue differences more visible
   - Helps model detect subtle patterns

**3. `Grayscale(num_output_channels=3)`** - Convert back to 3-channel
   - **Why?** Pretrained models expect 3-channel input (RGB)
   - Simply triplicates the grayscale channel: [G] → [G, G, G]

**4. `Resize((234, 234))` + `RandomCrop((224, 224))`** - Augmentation!
   - Resize to slightly larger (234×234)
   - Then randomly crop to 224×224
   - **Why?** Each epoch sees a different crop, so the model learns to tolerate small shifts (up to 10 pixels) of the anatomy
   - Reduces overfitting by adding variety

**5. `ToTensor()`** - Convert PIL Image → PyTorch Tensor
   - Changes range from [0, 255] → [0.0, 1.0]
   - Changes shape from (H, W, C) → (C, H, W)

**6. `Normalize(mean=[0.485, 0.485, 0.485], std=[0.229, 0.229, 0.229])`**
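   - **Why these numbers?** 0.485 and 0.229 are the ImageNet mean and std of the *red* channel (the full ImageNet values are mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]); here they are reused for all three identical grayscale channels
   - Formula: `(pixel - mean) / std`
   - Centers data around 0, which makes training more stable
   - Matching the pretrained model's normalization matters for transfer learning; for a model trained from scratch (like the CNN below), any consistent normalization works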
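Steps 4-6 can be sketched in plain NumPy to show what happens to a single image. This is an illustrative re-implementation under simple assumptions (a 2-D image already scaled to [0, 1], crop corner sampled uniformly); `random_crop` and `normalize` are hypothetical helpers, not torchvision's code.

```python
import numpy as np

def random_crop(img, size, rng):
    """Pick a uniformly random top-left corner and cut out a size x size patch."""
    h, w = img.shape[:2]
    top = rng.integers(0, h - size + 1)   # upper bound is exclusive
    left = rng.integers(0, w - size + 1)
    return img[top:top + size, left:left + size]

def normalize(x, mean=0.485, std=0.229):
    """(x - mean) / std, which transforms.Normalize applies to each channel."""
    return (x - mean) / std

rng = np.random.default_rng(0)
resized = rng.random((234, 234))          # stand-in for a 234x234 image in [0, 1]
crop = random_crop(resized, 224, rng)

print(crop.shape)                         # (224, 224)
print((234 - 224 + 1) ** 2)               # 121 possible crop positions
print(normalize(0.0), normalize(1.0))     # about -2.118 and 2.249
```

So the 10-pixel margin gives 121 different crops of each image, and after normalization the pixel values span roughly -2.12 to 2.25 instead of 0 to 1.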

In [None]:
# transforms (grayscale, CLAHE, resize + random crop, to tensor, normalize)

train_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    ApplyCLAHE(clip_limit=2.0, tile_grid_size=(8, 8)),
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((IMAGE_SIZE + 10, IMAGE_SIZE + 10)),
    transforms.RandomCrop((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.485, 0.485], std=[0.229, 0.229, 0.229])
])

## ✅ Validation/Test Transforms (Deterministic)

### Key Difference from Training Transforms:

**NO Random Augmentation!**
- **Training:** Uses `RandomCrop` → different crops each time
- **Val/Test:** Direct `Resize` → same image every time

### Why No Augmentation for Val/Test?
- **Consistency:** We want to evaluate model performance on same images
- **Fair comparison:** Results should be reproducible
- **Real-world simulation:** During deployment, images go through the same deterministic preprocessing, with no random augmentation

### Transform Pipeline:
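1. **Grayscale** → CLAHE enhancement → **3-channel**
2. **Direct resize to 224×224** (no random crop)
3. **ToTensor** + **Normalize** (same as training)

Apart from the randomness, there is one small difference: training crops 224×224 out of a 234×234 resize, so structures appear about 4.5% larger than in the val/test images, which are resized straight to 224×224. Using `Resize((234, 234))` followed by `CenterCrop((224, 224))` for val/test would match the training scale exactly.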

In [None]:
# Val/Test transforms (deterministic)
val_test_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    ApplyCLAHE(clip_limit=2.0, tile_grid_size=(8, 8)),
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),  # Direct resize, no crop
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.485, 0.485], std=[0.229, 0.229, 0.229])
])

## 📂 Load Datasets

### What is `ImageFolder`?
PyTorch's convenient dataset loader that expects this structure:
```
Final_Split_Data/
├── train/
│   ├── Malignant/   (all cancer images here)
│   └── Normal/      (all normal images here)
├── val/
│   ├── Malignant/
│   └── Normal/
└── test/
    ├── Malignant/
    └── Normal/
```

### How it Works:
- **Automatically assigns labels** based on folder names, sorted alphabetically
  - Malignant = class 0
  - Normal = class 1
- **Applies transforms** to each image when loading
- **Returns:** (image_tensor, label) pairs
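The label assignment can be reproduced with the standard library: `ImageFolder` sorts the class-folder names and numbers them in that order (the result is exposed as `train_dataset.class_to_idx`). The sketch below builds a throwaway folder tree in a temporary directory; `folder_class_to_idx` is a hypothetical helper that mirrors this behavior, not torchvision's own code.

```python
import os
import tempfile

def folder_class_to_idx(split_dir):
    """Map each sub-folder name to an index, in sorted (alphabetical) order."""
    classes = sorted(entry.name for entry in os.scandir(split_dir) if entry.is_dir())
    return {name: idx for idx, name in enumerate(classes)}

with tempfile.TemporaryDirectory() as root:
    # Folders are created in "Normal first" order on purpose: creation order doesn't matter
    for name in ["Normal", "Malignant"]:
        os.makedirs(os.path.join(root, "train", name))
    mapping = folder_class_to_idx(os.path.join(root, "train"))

print(mapping)  # {'Malignant': 0, 'Normal': 1}
```

Because the order comes from a plain, case-sensitive string sort, renaming a folder can silently change which index each class gets, so checking `class_to_idx` before interpreting any metric is a cheap safeguard.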

### Why Separate Datasets?
- **Training set:** Used to learn patterns (largest split, ~70-80%)
- **Validation set:** Tune hyperparameters, check overfitting (~10-15%)
- **Test set:** Final evaluation, never seen during training (~10-15%)

In [None]:
# load datasets

train_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, "train"), transform=train_transforms)
test_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, "test"), transform=val_test_transforms)
val_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, "val"), transform=val_test_transforms)

## 🔄 Create DataLoaders

### What is a DataLoader?
A DataLoader wraps a dataset and provides:
- **Batching:** Groups images into batches
- **Shuffling:** Randomizes order (for training only)
- **Parallel loading:** Loads data in background while model trains
- **Memory management:** Efficient data transfer to GPU

### Parameter Explanations:

**`batch_size=32`**
- Processes 32 images at once
- GPU computes gradients for all 32, then averages them

**`shuffle=True` (training only)**
- **Training:** `shuffle=True` → random order each epoch (prevents learning order patterns)
- **Val/Test:** `shuffle=False` → same order (consistency)

**`num_workers=3`**
- Uses 3 worker processes to load and transform images in the background
- **Why 3?** A middle ground; keep it at or below the number of available CPU cores
  - 0 = loading runs in the main process (slow, blocks training)
  - 2-4 = parallel loading (faster, keeps GPU busy)
  - Too many = memory overhead

**`pin_memory=True`**
- `True` = batches go into page-locked (pinned) host memory, which speeds up CPU → GPU copies (but uses more RAM)
- `False` = slower transfer (but safer for limited RAM)

**`persistent_workers=True`**
- Keeps workers alive between epochs
- **Benefit:** Faster epoch transitions (no worker restart overhead)
- **Cost:** Uses more memory

### Output Explanation:
- **Dataset sizes:** Total number of images
- **Loader sizes:** Number of batches, i.e. images ÷ batch_size rounded up (the last batch may be smaller)
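Because `drop_last` is left at its default (`False`), the final partial batch is kept, so the loader length is a ceiling division. A quick check with made-up dataset sizes (`num_batches` is a hypothetical helper; the notebook's `print(len(train_loader), ...)` cell shows the real values):

```python
import math

def num_batches(num_images, batch_size, drop_last=False):
    """Number of batches a DataLoader yields per epoch."""
    if drop_last:
        return num_images // batch_size
    return math.ceil(num_images / batch_size)

# Hypothetical sizes, for illustration only
print(num_batches(1000, 32))                  # 32 (31 full batches + one batch of 8)
print(num_batches(1000, 32, drop_last=True))  # 31
```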

In [None]:
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=3,
    pin_memory=True,
    persistent_workers=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # fixed order for reproducible evaluation
    num_workers=3,
    pin_memory=True,
    persistent_workers=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # fixed order for reproducible evaluation
    num_workers=3,
    pin_memory=True,
    persistent_workers=True
)

print(f"✅ Data loaded successfully!")
print("✅ Classes : ", train_dataset.classes)
print("✅ Dataset sizes : Train", len(train_dataset))
print("✅ Dataset sizes : Validation", len(val_dataset))
print("✅ Dataset sizes : Test", len(test_dataset))

In [None]:
print(len(train_loader), len(val_loader), len(test_loader))

## 🛠️ Error Handling for Corrupted Images

### What does this do?
**`LOAD_TRUNCATED_IMAGES = True`** allows PIL to load truncated (incomplete) image files

### Why needed?
- Sometimes image files get corrupted during download/transfer
- Without this, training crashes with "image file truncated" error
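- With this, PIL decodes whatever data is present and leaves the rest of the image incomplete instead of raising an error, so damaged images still reach the model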

### When to use:
- Large datasets downloaded from internet
- Medical imaging datasets (often have file issues)
- Any dataset where you can't manually verify every image

In [None]:
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
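Since the flag hides damaged files rather than fixing them, it can be worth listing them once and cleaning the dataset. A minimal sketch, assuming Pillow is installed; `find_broken_images` is a hypothetical helper. It temporarily turns the flag off so that truncated files raise an error during `load()`, then restores the previous setting. The demo at the bottom builds one valid PNG and one cut in half inside a temporary folder.

```python
import os
import tempfile
from PIL import Image, ImageFile

def find_broken_images(root, exts=(".png", ".jpg", ".jpeg")):
    """Return paths under `root` that Pillow cannot fully decode."""
    broken = []
    previous = ImageFile.LOAD_TRUNCATED_IMAGES
    ImageFile.LOAD_TRUNCATED_IMAGES = False  # make truncation raise instead of passing silently
    try:
        for dirpath, _, filenames in os.walk(root):
            for name in filenames:
                if not name.lower().endswith(exts):
                    continue
                path = os.path.join(dirpath, name)
                try:
                    with Image.open(path) as img:
                        img.load()  # force decoding of all pixel data
                except (OSError, SyntaxError, ValueError):
                    broken.append(path)
    finally:
        ImageFile.LOAD_TRUNCATED_IMAGES = previous  # restore the notebook's setting
    return broken

# Demo: a valid PNG and a copy cut in half, in a throwaway folder
with tempfile.TemporaryDirectory() as tmp:
    good_path = os.path.join(tmp, "good.png")
    Image.frombytes("RGB", (64, 64), os.urandom(64 * 64 * 3)).save(good_path)
    with open(good_path, "rb") as f:
        data = f.read()
    with open(os.path.join(tmp, "bad.png"), "wb") as f:
        f.write(data[: len(data) // 2])  # simulate an interrupted download
    flagged = [os.path.basename(p) for p in find_broken_images(tmp)]

print(flagged)
```

In the notebook this could be run once on `DATA_DIR`, and the reported files removed or re-downloaded before training.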

## 📊 Visualize Class Distribution

### Why check class distribution?

**Class Imbalance Problem:**
- If dataset has 900 Normal and 100 Malignant images
- Model might just predict "Normal" for everything → 90% accuracy!
- But it never learned to detect cancer (terrible for medical use)

### What to look for:
- ✅ **Balanced:** Both classes have similar counts (~50/50)
- ⚠️ **Slightly imbalanced:** 60/40 or 70/30 (often okay)
- ❌ **Severely imbalanced:** 90/10 or worse (needs special handling)

### If imbalanced, solutions:
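1. **Data augmentation** (generate more samples for minority class) ← I did this offline!
2. **Class weights** (penalize model more for minority class errors)
3. **Oversampling/Undersampling**
4. **Judge the model by F1-score and recall instead of accuracy** (this doesn't fix the imbalance, but stops it from hiding a model that ignores the minority class)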
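Option 2 can be sketched with the "balanced" inverse-frequency formula, `total / (num_classes × count_c)` (the same formula scikit-learn uses for `class_weight="balanced"`). The counts are the hypothetical 900 Normal / 100 Malignant example above, listed in class-index order (Malignant = 0, Normal = 1); `balanced_class_weights` is a hypothetical helper.

```python
import numpy as np

def balanced_class_weights(counts):
    """Inverse-frequency weights: total / (num_classes * count_c)."""
    counts = np.asarray(counts, dtype=np.float64)
    return counts.sum() / (len(counts) * counts)

weights = balanced_class_weights([100, 900])  # [Malignant, Normal]
print(weights)  # about [5.0, 0.556]: a missed Malignant image costs 9x more
```

Wiring this into training would mean passing the array to the loss, e.g. `nn.CrossEntropyLoss(weight=torch.tensor(weights, dtype=torch.float32).to(device))`; that is an assumption about how one might use it, not something this notebook currently does.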

### This Plot Shows:
- Red bar = Malignant images count
- Green bar = Normal images count
- Ideally should be roughly equal!

In [None]:
from collections import Counter
import matplotlib.pyplot as plt

labels = train_dataset.targets
label_counts = Counter(labels)

# Order bars by class index so the colors match the classes (Malignant = 0 → red, Normal = 1 → green)
class_names = train_dataset.classes
class_labels = list(class_names)
counts = [label_counts[i] for i in range(len(class_names))]

plt.figure(figsize=(8, 6))
plt.bar(class_labels, counts, color=['red', 'green'])
plt.title("Class Distribution in Training Set")
plt.xlabel("Classes")
plt.ylabel("Number of Images")
plt.show()

## 🖼️ Visualize Sample Images

### Why visualize?

**Quality Control:**
1. ✅ **Verify transforms work:** Are images properly enhanced?
2. ✅ **Check labels:** Do labels match images?
3. ✅ **Spot errors:** Are there any corrupted/wrong images?
4. ✅ **Understand data:** What does the model actually see?

### Function Breakdown:

**`show_batch()` function:**
- **Randomly samples** images from dataset
- **Unnormalizes** them (reverses normalization to display properly)
- **Displays in grid** (3 rows × 5 cols = 15 images)

### Why Unnormalize?
- Training images are normalized: `(pixel - 0.485) / 0.229`
- This pushes pixel values outside [0, 1] (roughly -2.12 to 2.25), so `imshow` can't display them correctly
- To display properly: `pixel = normalized * 0.229 + 0.485`
- Then clamp to [0, 1] range
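The round trip can be checked in NumPy: normalizing and then unnormalizing recovers the original pixels (up to floating-point error). This mirrors `img * std + mean` in `show_batch()`, with a small array standing in for an image tensor.

```python
import numpy as np

MEAN, STD = 0.485, 0.229

pixels = np.array([0.0, 0.25, 0.5, 1.0])
normalized = (pixels - MEAN) / STD        # what the model sees
restored = normalized * STD + MEAN        # what show_batch() computes before clamping
restored = np.clip(restored, 0.0, 1.0)    # same role as torch.clamp(img, 0, 1)

print(normalized.round(3))  # roughly -2.118, -1.026, 0.066, 2.249
print(restored)             # back to 0.0, 0.25, 0.5, 1.0
```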

### What to Check:
- ✅ CLAHE enhancement working? (good contrast)
- ✅ Images clearly visible?
- ✅ Labels match image content?
- ✅ Any obvious corrupted images?

**Pro Tip:** Run this multiple times to see different random samples!

In [None]:
def show_batch(dataset, class_names, num_images=15):  # the grid below shows 3 x 5 = 15 images

    # Get images directly from dataset (much faster)
    indices = np.random.choice(len(dataset), min(num_images, len(dataset)), replace=False)

    rows = 3
    cols = 5
    fig, axes = plt.subplots(rows, cols, figsize=(15, 6))

    # Unnormalize parameters
    mean = torch.tensor([0.485, 0.485, 0.485]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.229, 0.229]).view(3, 1, 1)

    for i, ax in enumerate(axes.flatten()):
        if i < len(indices):
            img, label = dataset[indices[i]]

            # Unnormalize
            img = img * std + mean
            img = torch.clamp(img, 0, 1)

            # Convert to numpy
            img = img.numpy().transpose((1, 2, 0))

            ax.imshow(img, cmap='gray')
            ax.set_title(class_names[label], fontsize=10)
            ax.axis('off')
        else:
            ax.axis('off')

    plt.suptitle("Sample Images from Training Set", fontsize=16)
    plt.tight_layout()
    plt.show()

# Call with dataset instead of dataloader
show_batch(train_dataset, train_dataset.classes)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# 🧠 CNN Model Architecture

## Architecture Overview:

**Convolutional Blocks:**
1. **Block 1:** Conv(32 filters) → BatchNorm → ReLU → MaxPool → Dropout(0.25)
2. **Block 2:** Conv(64 filters) → BatchNorm → ReLU → MaxPool → Dropout(0.25)
3. **Block 3:** Conv(128 filters) → BatchNorm → ReLU → MaxPool → Dropout(0.3)

**Fully Connected Layers:**
4. **Flatten** → Converts feature maps to 1D vector
5. **Dense(512)** → ReLU → Dropout(0.5) → High-level feature learning
6. **Dense(2)** → Output layer (Normal vs Malignant)

### Why This Architecture?

**Increasing Filter Depth (32 → 64 → 128):**
- Early layers detect simple patterns (edges, textures)
- Deeper layers combine patterns into complex features (tissue structures)
- More filters = more feature detectors

**MaxPooling:**
- Reduces spatial dimensions (224×224 → 112×112 → 56×56 → 28×28)
- Makes the model more robust to small translations
- Reduces computation

**Dropout Strategy:**
- Lower dropout in conv layers (0.25-0.3) - preserve spatial features
- Higher dropout in dense layer (0.5) - prevent overfitting on high-level features
- Critical for medical images (limited data, high overfitting risk)
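The shape arithmetic behind the 28×28×128 flatten size, and the parameter count it implies, can be checked with the standard output-size formula `out = floor((in + 2*padding - kernel) / stride) + 1`. This pure-Python sketch assumes the layer settings used in `LungCancerCNN` below (3×3 convs with padding 1, 2×2 max-pools with stride 2, BatchNorm after each conv) and counts only trainable parameters (BatchNorm's running statistics are buffers, not parameters).

```python
def out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

size, channels = 224, 3
params = 0
for out_ch in (32, 64, 128):
    size = out_size(size, kernel=3, padding=1)     # conv keeps the size
    params += channels * out_ch * 3 * 3 + out_ch   # conv weights + bias
    params += 2 * out_ch                           # BatchNorm weight + bias
    size = out_size(size, kernel=2, stride=2)      # max-pool halves it
    channels = out_ch

flatten = size * size * channels
fc1 = flatten * 512 + 512
fc2 = 512 * 2 + 2
params += fc1 + fc2

print(size, flatten)                               # 28 100352
print(f"{params:,} total, fc1 share {fc1 / params:.1%}")  # 51,475,458 total, fc1 share 99.8%
```

So nearly all of the ~51.5M trainable parameters sit in `fc1`, which is consistent with placing the heaviest dropout (0.5) right after it.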

In [None]:
# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🖥️  Using device: {device}")

if torch.cuda.is_available():
    print(f"✅ GPU: {torch.cuda.get_device_name(0)}")
    print(f"📊 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠️  Running on CPU (training will be slower)")

In [None]:
class LungCancerCNN(nn.Module):
    def __init__(self, num_classes=2):
        super(LungCancerCNN, self).__init__()
        
        # Block 1: Conv(32) → BatchNorm → ReLU → MaxPool → Dropout(0.25)
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout1 = nn.Dropout2d(0.25)
        
        # Block 2: Conv(64) → BatchNorm → ReLU → MaxPool → Dropout(0.25)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout2 = nn.Dropout2d(0.25)
        
        # Block 3: Conv(128) → BatchNorm → ReLU → MaxPool → Dropout(0.3)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout3 = nn.Dropout2d(0.3)
        
        # Calculate flattened size: 224 / 2 / 2 / 2 = 28
        # So feature maps are 28x28x128
        self.flatten_size = 28 * 28 * 128
        
        # Fully connected layers
        self.fc1 = nn.Linear(self.flatten_size, 512)
        self.dropout4 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, num_classes)
        
    def forward(self, x):
        # Block 1
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.pool1(x)
        x = self.dropout1(x)
        
        # Block 2
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.pool2(x)
        x = self.dropout2(x)
        
        # Block 3
        x = self.conv3(x)
        x = self.bn3(x)
        x = F.relu(x)
        x = self.pool3(x)
        x = self.dropout3(x)
        
        # Flatten
        x = x.view(x.size(0), -1)
        
        # Fully connected layers
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout4(x)
        x = self.fc2(x)
        
        return x

# Initialize model
model = LungCancerCNN(num_classes=2).to(device)
print("✅ Model created successfully!")
print(f"📊 Total parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Model summary
print("=" * 70)
print("🏗️  MODEL ARCHITECTURE")
print("=" * 70)
print(model)
print("=" * 70)

# ⚙️ Training Configuration

## Loss Function: CrossEntropyLoss
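- Standard loss for classification with one output logit per class (here 2 outputs: Malignant, Normal)
- Combines log-softmax + negative log likelihood
- Expects raw logits, which is why the model has no softmax layer at the end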

## Optimizer: Adam
- **Learning rate = 0.001** (default, good starting point)
- Adaptive learning rates per parameter
- Works well for medical image classification

## Learning Rate Scheduler: ReduceLROnPlateau
- Reduces LR when validation loss plateaus
- **Factor = 0.5** (halves LR)
- **Patience = 3** (waits 3 epochs before reducing)
- **Min LR = 1e-6** (prevents LR from getting too small)
- Helps model converge to better minima
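A simplified, pure-Python model of this schedule shows when the learning rate actually drops. It approximates `ReduceLROnPlateau` rather than reproducing its source: any strictly lower loss counts as an improvement (PyTorch by default requires a small relative improvement, `threshold=1e-4`) and `cooldown` is ignored. `simulate_plateau` is a hypothetical helper.

```python
def simulate_plateau(val_losses, lr=1e-3, factor=0.5, patience=3, min_lr=1e-6):
    """Return the learning rate in effect after each epoch's scheduler step."""
    best = float("inf")
    bad_epochs = 0
    lrs = []
    for loss in val_losses:
        if loss < best:            # simplification: any strict decrease is an improvement
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:  # reduce on the (patience + 1)-th bad epoch in a row
            lr = max(lr * factor, min_lr)
            bad_epochs = 0
        lrs.append(lr)
    return lrs

print(simulate_plateau([1.0, 0.9, 0.9, 0.9, 0.9, 0.9]))
# [0.001, 0.001, 0.001, 0.001, 0.001, 0.0005]
```

With `patience=3`, the LR is halved on the fourth consecutive epoch without improvement, not the third.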

In [None]:
# Training configuration
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3, min_lr=1e-6
)  # `verbose` is deprecated in recent PyTorch; the training loop prints the LR each epoch instead

# Training parameters
NUM_EPOCHS = 30
EARLY_STOPPING_PATIENCE = 7

print("✅ Training configuration set!")
print(f"📊 Epochs: {NUM_EPOCHS}")
print(f"⏹️  Early stopping patience: {EARLY_STOPPING_PATIENCE}")

# 🏋️ Training & Validation Functions

## Training Function:
1. **Sets model to training mode** (`model.train()`)
2. **Iterates through batches** with progress bar
3. **Forward pass** → compute loss
4. **Backward pass** → compute gradients
5. **Optimizer step** → update weights
6. **Tracks metrics** (loss, accuracy)

## Validation Function:
1. **Sets model to evaluation mode** (`model.eval()`)
2. **Disables gradient computation** (`torch.no_grad()`)
3. **Evaluates on validation set**
4. **Returns metrics** for early stopping decisions

## Key Features:
- ✅ Real-time progress bars (tqdm)
- ✅ Batch-level accuracy tracking
- ✅ Loss averaging weighted by batch size
- ✅ No gradient tracking during validation (`torch.no_grad()` saves GPU memory)
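The loss averaging in `train_epoch()` and `validate()` multiplies each batch's mean loss by the batch size before summing, so a smaller final batch doesn't count as much as a full one. A short pure-Python check with made-up numbers:

```python
batch_losses = [0.5, 1.0]   # mean loss of each batch (hypothetical values)
batch_sizes = [32, 8]       # the last batch is partial

# What train_epoch()/validate() compute: sum(loss * size) / total images
weighted = sum(l * n for l, n in zip(batch_losses, batch_sizes)) / sum(batch_sizes)
# Naive average of batch means, which over-weights the small last batch
naive = sum(batch_losses) / len(batch_losses)

print(weighted, naive)  # 0.6 0.75
```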

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train for one epoch"""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    pbar = tqdm(dataloader, desc="Training", leave=False)
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Statistics
        running_loss += loss.item() * images.size(0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        # Update progress bar
        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{100 * correct / total:.2f}%'
        })
    
    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc


def validate(model, dataloader, criterion, device):
    """Validate the model"""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        pbar = tqdm(dataloader, desc="Validation", leave=False)
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100 * correct / total:.2f}%'
            })
    
    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc

print("✅ Training functions defined!")

# 🚀 Training Loop with Early Stopping

## What Happens Here:

**For Each Epoch:**
1. **Train** on training set
2. **Validate** on validation set
3. **Update learning rate** (scheduler)
4. **Track best model** (save if validation loss improves)
5. **Early stopping** (stop if no improvement for 7 epochs)

## Early Stopping:
- Prevents overfitting by stopping when model stops improving
- **Patience = 7** means we wait 7 epochs without improvement
- Saves training time
- Returns best model (not last model!)
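The bookkeeping can be sketched in pure Python and NumPy. The first part replays the patience rule on a made-up loss curve; the second shows why the saved "best" weights must be a real copy: a plain `dict.copy()` keeps references to the same arrays, so later in-place updates (like `optimizer.step()`) would change the "saved" weights too. `early_stopping` is a hypothetical helper, and NumPy arrays stand in for tensors.

```python
import copy
import numpy as np

def early_stopping(val_losses, patience):
    """Return (best_epoch, stop_epoch), both 0-based, for a sequence of validation losses."""
    best_loss, best_epoch, bad_epochs = float("inf"), -1, 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch, bad_epochs = loss, epoch, 0
        else:
            bad_epochs += 1
        if bad_epochs >= patience:
            return best_epoch, epoch
    return best_epoch, len(val_losses) - 1

print(early_stopping([1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.69], patience=3))  # (2, 5)

# Snapshotting: shallow vs deep copy of a "state dict"
weights = {"w": np.zeros(3)}
shallow = weights.copy()          # new dict, same array object
deep = copy.deepcopy(weights)     # independent array
weights["w"] += 1.0               # in-place update, like optimizer.step()
print(shallow["w"], deep["w"])    # [1. 1. 1.] [0. 0. 0.]
```

Note that training stops at epoch 5 even though epoch 6 would have improved; that is the trade-off a finite patience makes.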

## Metrics Tracked:
- ✅ Training loss & accuracy
- ✅ Validation loss & accuracy
- ✅ Learning rate changes
- ✅ Best model checkpoint

In [None]:
# Training loop with early stopping
history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': [],
    'lr': []
}

best_val_loss = float('inf')
best_model_state = None
patience_counter = 0

print("🚀 Starting training...\n")
print("=" * 70)

for epoch in range(NUM_EPOCHS):
    print(f"\n📅 Epoch {epoch+1}/{NUM_EPOCHS}")
    print("-" * 70)
    
    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    
    # Validate
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    
    # Update learning rate
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']
    
    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['lr'].append(current_lr)
    
    # Print epoch results
    print(f"\n📊 Results:")
    print(f"   Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"   Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.2f}%")
    print(f"   Learning Rate: {current_lr:.6f}")
    
    # Check for best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
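        # Clone the tensors: state_dict().copy() only copies the dict, and its tensors
        # would keep changing as training continues
        best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        patience_counter = 0
        print(f"   ✅ New best model! (Val Loss: {val_loss:.4f})")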
    else:
        patience_counter += 1
        print(f"   ⏳ No improvement ({patience_counter}/{EARLY_STOPPING_PATIENCE})")
    
    # Early stopping
    if patience_counter >= EARLY_STOPPING_PATIENCE:
        print(f"\n⏹️  Early stopping triggered after {epoch+1} epochs!")
        print(f"   Best validation loss: {best_val_loss:.4f}")
        break

print("\n" + "=" * 70)
print("✅ Training completed!")

# Load best model
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print(f"✅ Best model loaded (Val Loss: {best_val_loss:.4f})")

# 📈 Visualize Training History

## What to Look For:

**Training vs Validation Loss:**
- ✅ Both decreasing → model learning well
- ⚠️ Val loss increases while train loss decreases → overfitting
- ⚠️ Both high and flat → underfitting (model too simple)

**Training vs Validation Accuracy:**
- ✅ Both increasing together → good generalization
- ⚠️ Large gap (train >> val) → overfitting
- ⚠️ Both low → model not learning

**Learning Rate Schedule:**
- Shows when LR was reduced (should align with validation plateaus)

## Ideal Pattern:
- Both curves smooth and converging
- Small gap between train and validation
- Steady improvement over time

In [None]:
# Plot training history
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Loss plot
axes[0].plot(history['train_loss'], label='Train Loss', marker='o', linewidth=2)
axes[0].plot(history['val_loss'], label='Val Loss', marker='s', linewidth=2)
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training & Validation Loss', fontsize=14, fontweight='bold')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy plot
axes[1].plot(history['train_acc'], label='Train Acc', marker='o', linewidth=2)
axes[1].plot(history['val_acc'], label='Val Acc', marker='s', linewidth=2)
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Accuracy (%)', fontsize=12)
axes[1].set_title('Training & Validation Accuracy', fontsize=14, fontweight='bold')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Learning rate plot
axes[2].plot(history['lr'], marker='o', linewidth=2, color='red')
axes[2].set_xlabel('Epoch', fontsize=12)
axes[2].set_ylabel('Learning Rate', fontsize=12)
axes[2].set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
axes[2].set_yscale('log')
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print final metrics
print("\n" + "=" * 70)
print("📊 FINAL TRAINING METRICS")
print("=" * 70)
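best_epoch = int(np.argmin(history['val_loss']))  # epoch of the restored best model
print(f"Best Validation Loss:      {history['val_loss'][best_epoch]:.4f} (epoch {best_epoch + 1})")
print(f"Val Accuracy at Best Loss: {history['val_acc'][best_epoch]:.2f}%")
print(f"Best Validation Accuracy:  {max(history['val_acc']):.2f}%")
print(f"Final Train Accuracy:      {history['train_acc'][-1]:.2f}%")
print(f"Final Val Accuracy:        {history['val_acc'][-1]:.2f}%")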
print("=" * 70)

# 🎯 Test Set Evaluation

## Why Test Set?
- **Never seen during training** (unbiased evaluation)
- **Simulates real-world performance**
- **Final model assessment**

## Metrics We'll Calculate:
1. **Accuracy** - Overall correctness (TP + TN) / Total
2. **Precision** - Of predicted cancers, how many were correct? TP / (TP + FP)
3. **Recall (Sensitivity)** - Of actual cancers, how many did we catch? TP / (TP + FN)
4. **F1-Score** - Harmonic mean of precision and recall
5. **Confusion Matrix** - Visual breakdown of predictions
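The formulas above, written out for the binary case with Malignant treated as the positive class. The counts are invented for illustration and `binary_metrics` is a hypothetical helper; in the notebook, `classification_report` computes the same quantities per class.

```python
def binary_metrics(tp, fp, fn, tn):
    """Accuracy, precision, recall and F1 with Malignant as the positive class."""
    accuracy = (tp + tn) / (tp + fp + fn + tn)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return accuracy, precision, recall, f1

# Hypothetical test-set counts, for illustration only
acc, prec, rec, f1 = binary_metrics(tp=90, fp=20, fn=10, tn=80)
print(f"acc={acc:.3f} precision={prec:.3f} recall={rec:.3f} f1={f1:.3f}")
# acc=0.850 precision=0.818 recall=0.900 f1=0.857
```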

## For Medical Diagnosis:
- **High Recall** is critical! (Don't miss cancer cases)
- **False Negatives** are dangerous (cancer labeled as normal)
- **False Positives** are less critical (they lead to extra follow-up tests, which have a cost but are far less dangerous than a missed cancer)

In [None]:
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

def evaluate_model(model, dataloader, device):
    """Comprehensive model evaluation"""
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            images = images.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.numpy())
    
    return np.array(all_preds), np.array(all_labels)

# Evaluate on test set
print("🧪 Evaluating on test set...\n")
test_preds, test_labels = evaluate_model(model, test_loader, device)

# Calculate metrics
test_accuracy = accuracy_score(test_labels, test_preds)

print("=" * 70)
print("🎯 TEST SET RESULTS")
print("=" * 70)
print(f"\n✅ Test Accuracy: {test_accuracy * 100:.2f}%\n")

# Classification report
print("📊 Detailed Classification Report:")
print("-" * 70)
print(classification_report(test_labels, test_preds, 
                          target_names=train_dataset.classes,
                          digits=4))
print("=" * 70)

# 🔲 Confusion Matrix Visualization

## How to Read:

**Matrix Layout** (rows and columns follow the class indices, Malignant = 0 and Normal = 1, with Malignant as the positive class):
```
                     Predicted
                 Malignant | Normal
Actual Malignant     TP    |   FN
       Normal        FP    |   TN
```

**What Each Cell Means:**
- **TP (Top-Left):** Correctly identified cancer cases ✅
- **TN (Bottom-Right):** Correctly identified normal cases ✅
- **FP (Bottom-Left):** Normal classified as cancer ⚠️ (False Alarm)
- **FN (Top-Right):** Cancer classified as normal ❌ (Dangerous!)

## Goal:
- **Maximize diagonal** (TN and TP)
- **Minimize off-diagonal** (FP and FN)
- **Especially minimize FN** (missed cancer cases)
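To see how the cells are indexed, here is a NumPy sketch that builds the same 2×2 matrix `confusion_matrix` returns: rows are true labels, columns are predictions, both in class-index order (Malignant = 0, Normal = 1). The label arrays are made up for illustration.

```python
import numpy as np

MALIGNANT, NORMAL = 0, 1  # ImageFolder's alphabetical class indices

y_true = np.array([0, 0, 0, 1, 1, 1, 1])
y_pred = np.array([0, 0, 1, 1, 1, 0, 1])

cm = np.zeros((2, 2), dtype=int)
np.add.at(cm, (y_true, y_pred), 1)  # cm[true, pred] += 1 for every sample

tp = cm[MALIGNANT, MALIGNANT]  # cancer caught
fn = cm[MALIGNANT, NORMAL]     # cancer missed
fp = cm[NORMAL, MALIGNANT]     # false alarm
tn = cm[NORMAL, NORMAL]        # normal confirmed
print(cm.tolist(), tp, fn, fp, tn)  # [[2, 1], [1, 3]] 2 1 1 3
```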

In [None]:
# Confusion Matrix
cm = confusion_matrix(test_labels, test_preds)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Raw counts
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=train_dataset.classes,
            yticklabels=train_dataset.classes,
            ax=axes[0], cbar_kws={'label': 'Count'})
axes[0].set_title('Confusion Matrix (Counts)', fontsize=14, fontweight='bold')
axes[0].set_ylabel('True Label', fontsize=12)
axes[0].set_xlabel('Predicted Label', fontsize=12)

# Normalized (percentages)
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
sns.heatmap(cm_normalized, annot=True, fmt='.2%', cmap='Blues',
            xticklabels=train_dataset.classes,
            yticklabels=train_dataset.classes,
            ax=axes[1], cbar_kws={'label': 'Percentage'})
axes[1].set_title('Confusion Matrix (Normalized by True Class)', fontsize=14, fontweight='bold')
axes[1].set_ylabel('True Label', fontsize=12)
axes[1].set_xlabel('Predicted Label', fontsize=12)

plt.tight_layout()
plt.show()

# Print confusion matrix interpretation
print("\n📊 Confusion Matrix Breakdown:")
print("=" * 70)
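# cm[true][pred], rows/columns in class-index order (ImageFolder sorts: Malignant = 0, Normal = 1)
mal = train_dataset.class_to_idx['Malignant']
nor = train_dataset.class_to_idx['Normal']
print(f"True Negatives (TN):  {cm[nor][nor]:4d} - Normal correctly identified")
print(f"False Positives (FP): {cm[nor][mal]:4d} - Normal wrongly labeled as Malignant")
print(f"False Negatives (FN): {cm[mal][nor]:4d} - Malignant wrongly labeled as Normal ⚠️")
print(f"True Positives (TP):  {cm[mal][mal]:4d} - Malignant correctly identified")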
print("=" * 70)

# 🖼️ Visualize Predictions

## What This Shows:
- **Random sample** of test images
- **True labels** vs **predicted labels**
- **Correct predictions** in green ✅
- **Incorrect predictions** in red ❌

## Analysis:
- Look for patterns in errors
- Are certain types of images harder to classify?
- Do misclassifications make visual sense?
- Quality control for model behavior
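Error patterns are easier to spot when misclassified images are grouped by error type. A minimal sketch, assuming `test_labels` and `test_preds` from the evaluation cell are in the same order as `test_dataset` (true if the test `DataLoader` was created with `shuffle=False`); `group_errors` is a hypothetical helper.

```python
from collections import defaultdict

import numpy as np

def group_errors(true_labels, pred_labels, class_names):
    """Group indices of misclassified samples by (true class, predicted class)."""
    true_labels = np.asarray(true_labels)
    pred_labels = np.asarray(pred_labels)
    groups = defaultdict(list)
    # np.flatnonzero gives the positions where the prediction is wrong
    for idx in np.flatnonzero(true_labels != pred_labels):
        key = (class_names[true_labels[idx]], class_names[pred_labels[idx]])
        groups[key].append(int(idx))
    return dict(groups)

# Toy labels for illustration (0 = Normal, 1 = Malignant)
errors = group_errors([0, 1, 1, 0, 1], [0, 0, 1, 1, 0], ["Normal", "Malignant"])
for (true_name, pred_name), idxs in errors.items():
    print(f"True {true_name} -> Pred {pred_name}: {len(idxs)} images, e.g. {idxs[:5]}")
```

The returned indices can be passed straight to `test_dataset[idx]` to display, for example, only the missed Malignant cases.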

In [None]:
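def visualize_predictions(model, dataset, class_names, num_images=15, device='cuda'):
    """Visualize model predictions on randomly sampled images from `dataset`"""
    model.eval()
    
    # Random sample (cannot draw more images than the dataset holds)
    num_images = min(num_images, len(dataset))
    indices = np.random.choice(len(dataset), num_images, replace=False)
    
    # Grid with 5 columns and as many rows as needed
    cols = 5
    rows = int(np.ceil(num_images / cols))
    fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)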
    
    # Undo Normalize(); these values must match the mean/std used in the test transforms
    mean = torch.tensor([0.485, 0.485, 0.485]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.229, 0.229]).view(3, 1, 1)
    
    with torch.no_grad():
        for i, ax in enumerate(axes.flatten()):
            if i < len(indices):
                img, label = dataset[indices[i]]
                
                # Predict
                img_batch = img.unsqueeze(0).to(device)
                output = model(img_batch)
                _, predicted = torch.max(output, 1)
                pred_class = predicted.item()
                
                # Unnormalize image
                img = img * std + mean
                img = torch.clamp(img, 0, 1)
                img = img.numpy().transpose((1, 2, 0))
                
                # Display
                ax.imshow(img, cmap='gray')
                
                # Color: green if correct, red if wrong
                is_correct = (pred_class == label)
                color = 'green' if is_correct else 'red'
                symbol = '✓' if is_correct else '✗'
                
                title = f"{symbol} True: {class_names[label]}\nPred: {class_names[pred_class]}"
                ax.set_title(title, fontsize=9, color=color, fontweight='bold')
                ax.axis('off')
            else:
                ax.axis('off')
    
    plt.suptitle("🖼️ Sample Predictions on Test Set", fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Visualize predictions
visualize_predictions(model, test_dataset, train_dataset.classes, num_images=15, device=device)

# 💾 Save Model

## What Gets Saved:

**1. Model State Dict (lung_cancer_cnn.pth):**
- Model weights and biases
- Use for inference/deployment
- Requires model architecture to load

**2. Complete Checkpoint (lung_cancer_checkpoint.pth):**
- Model state
- Optimizer state
- Training history
- Test accuracy and class names
- Hyperparameters (image size, batch size)
- Use to resume training

## File Locations:
- Saved to the current working directory (`save_dir = "."`), which in Jupyter is usually the notebook's folder
- Can be uploaded to cloud storage
- Use for deployment or further training
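Resuming from the complete checkpoint means restoring both state dicts into freshly built objects. A minimal sketch, assuming the checkpoint keys used in the save cell (`model_state_dict`, `optimizer_state_dict`, ...); `load_checkpoint` is a hypothetical helper, and the tiny `nn.Linear` stands in for the notebook's model so the example runs on its own.

```python
import io

import torch
import torch.nn as nn

def load_checkpoint(path_or_buffer, model, optimizer=None, device="cpu"):
    """Restore model (and optionally optimizer) state from a checkpoint dict."""
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
    # weights_only=False because the checkpoint also stores plain Python objects
    # (history, class names); only use it for files you trust.
    checkpoint = torch.load(path_or_buffer, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"])
    # Move the model before loading the optimizer so its state lands on `device`
    model.to(device)
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint

# Round trip with a tiny stand-in model (not the notebook's LungCancerCNN)
src = nn.Linear(4, 2)
src_opt = torch.optim.Adam(src.parameters(), lr=1e-3)
buffer = io.BytesIO()
torch.save({"model_state_dict": src.state_dict(),
            "optimizer_state_dict": src_opt.state_dict(),
            "class_names": ["Normal", "Malignant"]}, buffer)
buffer.seek(0)

dst = nn.Linear(4, 2)
dst_opt = torch.optim.Adam(dst.parameters(), lr=1e-3)
restored = load_checkpoint(buffer, dst, dst_opt)
print(restored["class_names"])
```

With the real files, one would first create `model = LungCancerCNN(num_classes=2)` and an optimizer over `model.parameters()` with the training settings, then call `ckpt = load_checkpoint('lung_cancer_checkpoint.pth', model, optimizer, device=device)`; `ckpt['history']` and `ckpt['class_names']` are then available as saved.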

In [None]:
# Save model
save_dir = "."  # Current directory

# Save model state dict
model_path = os.path.join(save_dir, "lung_cancer_cnn.pth")
torch.save(model.state_dict(), model_path)
print(f"✅ Model saved to: {model_path}")

# Save complete checkpoint
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'history': history,
    'test_accuracy': test_accuracy,
    'class_names': train_dataset.classes,
    'image_size': IMAGE_SIZE,
    'batch_size': BATCH_SIZE
}

checkpoint_path = os.path.join(save_dir, "lung_cancer_checkpoint.pth")
torch.save(checkpoint, checkpoint_path)
print(f"✅ Checkpoint saved to: {checkpoint_path}")

print("\n" + "=" * 70)
print("🎉 Training Pipeline Complete!")
print("=" * 70)
print(f"✅ Final Test Accuracy: {test_accuracy * 100:.2f}%")
print("✅ Model files saved successfully")
print("=" * 70)

# 🔄 How to Load and Use Saved Model

## Loading the Model:

Use this code to load the trained model for inference or deployment:

```python
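# LungCancerCNN must be defined first (same architecture as during training)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create model instance
model = LungCancerCNN(num_classes=2)

# Load weights; map_location lets GPU-trained weights load on a CPU-only machine
model.load_state_dict(torch.load('lung_cancer_cnn.pth', map_location=device))
model.to(device)
model.eval()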

# Now ready for predictions!
```

## Making Predictions on New Images:

```python
from PIL import Image

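# Load as 3-channel RGB, matching torchvision's default ImageFolder loader
img = Image.open('new_ct_scan.png').convert('RGB')
# Apply the same evaluation transforms used for the test set
img_tensor = val_test_transforms(img).unsqueeze(0).to(device)

# Predict
with torch.no_grad():
    output = model(img_tensor)
    _, predicted = torch.max(output, 1)

# Outside this notebook, take class names from checkpoint['class_names'] instead
print(f"Prediction: {train_dataset.classes[predicted.item()]}")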
```
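`torch.max` only returns the winning class. Because a missed cancer costs more than a false alarm, it can help to inspect the softmax probabilities and, if needed, flag Malignant at a probability below 0.5. A minimal sketch; `predict_with_threshold`, its `positive_class` and `threshold` parameters, and the `FixedLogits` stand-in model are all hypothetical, and any threshold should be tuned on the validation set rather than the test set.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def predict_with_threshold(model, img_tensor, class_names, positive_class="Malignant",
                           threshold=0.5, device="cpu"):
    """Return (label, {class: probability}) for one preprocessed tensor [C, H, W].

    The positive class is predicted whenever its softmax probability reaches
    `threshold`; lowering it trades more false alarms (FP) for fewer misses (FN)."""
    model.eval()
    with torch.no_grad():
        logits = model(img_tensor.unsqueeze(0).to(device))
        probs = F.softmax(logits, dim=1)[0].cpu()
    probs_by_name = {name: float(p) for name, p in zip(class_names, probs)}
    if probs_by_name[positive_class] >= threshold:
        return positive_class, probs_by_name
    # Otherwise fall back to the most probable remaining class
    others = {k: v for k, v in probs_by_name.items() if k != positive_class}
    return max(others, key=others.get), probs_by_name

class FixedLogits(nn.Module):
    """Hypothetical stand-in model that always outputs the same logits."""
    def __init__(self, logits):
        super().__init__()
        self.register_buffer("logits", torch.tensor([logits], dtype=torch.float32))

    def forward(self, x):
        return self.logits.expand(x.shape[0], -1)

demo_model = FixedLogits([0.0, -1.0])  # classes: ["Normal", "Malignant"]
label, probs = predict_with_threshold(demo_model, torch.zeros(3, 8, 8),
                                      ["Normal", "Malignant"], threshold=0.25)
print(label, probs)
```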

# 🔥 Grad-CAM Visualization

## What is Grad-CAM?
**Grad-CAM** = Gradient-weighted Class Activation Mapping

### Why Use Grad-CAM?
- **Explainability:** Shows which image regions the model focuses on
- **Trust:** Verify model looks at tissue (not artifacts/background)
- **Debugging:** Identify if model learns spurious correlations
- **Medical validation:** Critical for clinical applications

### How Grad-CAM Works:
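1. Forward pass → get class scores (logits)
2. Backward pass → compute gradients of the target class score w.r.t. the target layer's feature maps
3. Average each channel's gradients over all spatial positions to get one weight per feature map
4. Sum the feature maps weighted by these values, apply ReLU, and normalize → heatmap of important regions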
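Steps 3 and 4 are all of Grad-CAM's arithmetic. With feature maps $A^k$ of size $H \times W$ and target class score $y^c$, the channel weights are $\alpha_k^c = \frac{1}{HW}\sum_{i,j} \frac{\partial y^c}{\partial A^k_{ij}}$ and the heatmap is $\mathrm{ReLU}\left(\sum_k \alpha_k^c A^k\right)$. The NumPy sketch below applies these formulas to arrays already captured from the target layer; `gradcam_from_arrays` is a hypothetical name, and the min-max scaling at the end mirrors what `GradCAM.generate_cam` does with tensors.

```python
import numpy as np

def gradcam_from_arrays(activations, gradients):
    """Grad-CAM steps 3-4 for one image.

    activations, gradients: arrays of shape [C, H, W] from the target layer
    (feature maps A^k and d(class score)/dA^k)."""
    # Step 3: global-average-pool the gradients -> one weight per channel
    weights = gradients.mean(axis=(1, 2))                   # [C]
    # Step 4a: weighted sum of feature maps over the channel axis
    cam = np.tensordot(weights, activations, axes=(0, 0))   # [H, W]
    # Step 4b: keep only positive evidence, then scale to [0, 1]
    cam = np.maximum(cam, 0)
    cam = cam - cam.min()
    return cam / (cam.max() + 1e-8)

# Toy example: channel 0 pushes the score up, channel 1 pushes it down
acts = np.array([[[1.0, 0.0], [0.0, 0.0]],
                 [[0.0, 0.0], [0.0, 1.0]]])
grads = np.array([np.ones((2, 2)), -np.ones((2, 2))])
print(gradcam_from_arrays(acts, grads))
```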

### What to Look For:
- ✅ Model focuses on lung tissue
- ✅ Different attention for Normal vs Malignant
- ❌ Model focuses on borders/artifacts (bad!)
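Checking for border focus by eye works for a few images; across the whole test set a rough number helps. The sketch below is an ad-hoc heuristic, not a standard metric: it measures the share of a normalized heatmap's mass that falls in a frame around the image edges. As a reference point, a perfectly uniform map puts about 36% of its mass in a 10% frame (1 - 0.8²), so values well above that suggest edge focus. `border_attention_fraction` and its `border` parameter are hypothetical.

```python
import numpy as np

def border_attention_fraction(cam, border=0.1):
    """Share of total heatmap mass inside a frame of width `border`
    (a fraction of the image size) around the image edges."""
    cam = np.asarray(cam, dtype=np.float64)
    h, w = cam.shape
    bh, bw = max(1, int(h * border)), max(1, int(w * border))
    inner = cam[bh:h - bh, bw:w - bw].sum()
    total = cam.sum()
    # An all-zero heatmap has no attention anywhere
    return 0.0 if total == 0 else float((total - inner) / total)

uniform = np.ones((10, 10))
print(border_attention_fraction(uniform))
```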

In [None]:
import torch.nn.functional as F
from matplotlib.colors import LinearSegmentedColormap

class GradCAM:
    """Grad-CAM implementation for CNN visualization"""
    
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        
        # Register hooks and keep the handles so they can be removed later.
        # register_full_backward_hook replaces the deprecated register_backward_hook;
        # it may raise an error if the layer's output is modified in place
        # afterwards (e.g. nn.ReLU(inplace=True)), so prefer out-of-place ops there.
        self.handles = [
            self.target_layer.register_forward_hook(self.save_activation),
            self.target_layer.register_full_backward_hook(self.save_gradient),
        ]
    
    def save_activation(self, module, input, output):
        """Save forward pass activations"""
        self.activations = output.detach()
    
    def save_gradient(self, module, grad_input, grad_output):
        """Save gradients of the class score w.r.t. the layer's output"""
        self.gradients = grad_output[0].detach()
    
    def remove(self):
        """Remove the hooks from the model"""
        for handle in self.handles:
            handle.remove()
    
    def generate_cam(self, input_image, target_class=None):
        """Generate a Grad-CAM heatmap in [0, 1] for a single-image batch"""
        # Forward pass (needs gradients, so do not call inside torch.no_grad())
        self.model.eval()
        output = self.model(input_image)
        
        # If no target class specified, use predicted class
        if target_class is None:
            target_class = output.argmax(dim=1).item()
        
        # Backward pass
        self.model.zero_grad()
        class_score = output[0, target_class]
        class_score.backward()
        
        # Generate CAM
        gradients = self.gradients[0]  # [C, H, W]
        activations = self.activations[0]  # [C, H, W]
        
        # Global average pooling on gradients
        weights = gradients.mean(dim=(1, 2))  # [C]
        
        # Weighted combination of activation maps; broadcasting keeps the result
        # on the same device as the activations (CPU or GPU)
        cam = (weights[:, None, None] * activations).sum(dim=0)  # [H, W]
        
        # ReLU (only positive contributions)
        cam = F.relu(cam)
        
        # Normalize to [0, 1]
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)
        
        return cam.cpu().numpy(), target_class


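def visualize_gradcam(model, dataset, class_names, num_images=9, device='cuda'):
    """Visualize Grad-CAM for sample images"""
    
    # Create Grad-CAM object (assumes model.conv3 is the last conv layer)
    gradcam = GradCAM(model, target_layer=model.conv3)
    
    # Random sample (cannot draw more images than the dataset holds)
    num_images = min(num_images, len(dataset))
    indices = np.random.choice(len(dataset), num_images, replace=False)
    
    # Each image gets 3 panels (original | heatmap | overlay), 3 images per row
    cols = 3
    rows = int(np.ceil(num_images / cols))
    fig, axes = plt.subplots(rows, cols * 3, figsize=(18, 3 * rows), squeeze=False)
    for ax in axes.flatten():
        ax.axis('off')  # unused panels stay blank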
    
    # Unnormalize parameters
    mean = torch.tensor([0.485, 0.485, 0.485]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.229, 0.229]).view(3, 1, 1)
    
    for idx, img_idx in enumerate(indices):
        img, label = dataset[img_idx]
        
        # Get Grad-CAM
        img_batch = img.unsqueeze(0).to(device)
        cam, pred_class = gradcam.generate_cam(img_batch)
        
        # Unnormalize image
        img_display = img * std + mean
        img_display = torch.clamp(img_display, 0, 1)
        img_display = img_display.numpy().transpose((1, 2, 0))
        
        # Resize CAM to the image size (cv2.resize takes (width, height))
        cam_resized = cv2.resize(cam, (img.shape[2], img.shape[1]))
        
        # Get row and column for subplots
        row = idx // cols
        col_base = (idx % cols) * 3
        
        # 1. Original Image
        axes[row, col_base].imshow(img_display, cmap='gray')
        axes[row, col_base].set_title(f'True: {class_names[label]}', fontsize=9)
        axes[row, col_base].axis('off')
        
        # 2. Grad-CAM Heatmap
        axes[row, col_base + 1].imshow(cam_resized, cmap='jet')
        axes[row, col_base + 1].set_title(f'Pred: {class_names[pred_class]}', fontsize=9)
        axes[row, col_base + 1].axis('off')
        
        # 3. Overlay
        axes[row, col_base + 2].imshow(img_display, cmap='gray')
        axes[row, col_base + 2].imshow(cam_resized, cmap='jet', alpha=0.5)
        
        # Color: green if correct, red if wrong
        is_correct = (pred_class == label)
        color = 'green' if is_correct else 'red'
        symbol = '✓' if is_correct else '✗'
        axes[row, col_base + 2].set_title(f'{symbol} Overlay', fontsize=9, color=color)
        axes[row, col_base + 2].axis('off')
    
    plt.suptitle('üî• Grad-CAM Visualization: Original | Heatmap | Overlay', 
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()


# Generate Grad-CAM visualizations
print("🔥 Generating Grad-CAM visualizations...\n")
visualize_gradcam(model, test_dataset, train_dataset.classes, num_images=9, device=device)

In [None]:
# Advanced: Compare Grad-CAM for Correct vs Incorrect Predictions

def visualize_gradcam_comparison(model, dataset, class_names, device='cuda', n_per_row=3):
    """Compare Grad-CAM for correct and incorrect predictions"""
    
    gradcam = GradCAM(model, target_layer=model.conv3)
    
    # Find correct and incorrect predictions (stop once both lists are full)
    correct_indices = []
    incorrect_indices = []
    
    model.eval()
    with torch.no_grad():
        for idx in range(len(dataset)):
            img, label = dataset[idx]
            img_batch = img.unsqueeze(0).to(device)
            output = model(img_batch)
            pred = output.argmax(dim=1).item()
            
            if pred == label:
                correct_indices.append(idx)
            else:
                incorrect_indices.append(idx)
            
            if len(correct_indices) >= n_per_row and len(incorrect_indices) >= n_per_row:
                break
    
    # Plot comparison: each image gets 2 panels (original | overlay)
    fig, axes = plt.subplots(2, n_per_row * 2, figsize=(6 * n_per_row, 6), squeeze=False)
    for ax in axes.flatten():
        ax.axis('off')  # panels stay blank if fewer than n_per_row errors exist
    
    mean = torch.tensor([0.485, 0.485, 0.485]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.229, 0.229]).view(3, 1, 1)
    
    # Top row: correct predictions (green); bottom row: incorrect (red)
    row_specs = [(correct_indices, 'green', '✓'), (incorrect_indices, 'red', '✗')]
    for row, (row_indices, color, symbol) in enumerate(row_specs):
        for i, idx in enumerate(row_indices[:n_per_row]):
            img, label = dataset[idx]
            img_batch = img.unsqueeze(0).to(device)
            # Grad-CAM needs gradients, so this runs outside torch.no_grad()
            cam, pred_class = gradcam.generate_cam(img_batch)
            
            img_display = img * std + mean
            img_display = torch.clamp(img_display, 0, 1).numpy().transpose((1, 2, 0))
            cam_resized = cv2.resize(cam, (img.shape[2], img.shape[1]))
            
            # Original
            axes[row, i*2].imshow(img_display, cmap='gray')
            axes[row, i*2].set_title(f'{symbol} True: {class_names[label]}', fontsize=9, color=color)
            
            # Overlay
            axes[row, i*2+1].imshow(img_display, cmap='gray')
            axes[row, i*2+1].imshow(cam_resized, cmap='jet', alpha=0.5)
            axes[row, i*2+1].set_title(f'Pred: {class_names[pred_class]}', fontsize=9, color=color)
    
    if not incorrect_indices:
        print("No misclassified images found; the bottom row is empty.")
    
    plt.suptitle('🔥 Grad-CAM: Correct ✓ vs Incorrect ✗ Predictions', 
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()
    gradcam.remove()  # detach hooks so they don't accumulate on the model

# Compare correct vs incorrect
print("\n🔍 Comparing Grad-CAM for correct vs incorrect predictions...\n")
visualize_gradcam_comparison(model, test_dataset, train_dataset.classes, device=device)

# 📊 Grad-CAM Interpretation Guide

## What Good Grad-CAM Looks Like:
- ✅ **Focuses on lung tissue** (not borders/artifacts)
- ✅ **Different patterns** for Normal vs Malignant
- ✅ **Consistent attention** across similar cases
- ✅ **Localized hotspots** on suspicious regions

## What Bad Grad-CAM Looks Like:
- ❌ Focuses on image corners/edges
- ❌ Highlights background/artifacts
- ❌ Random scattered attention
- ❌ Same pattern for all classes
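The last warning sign, the same pattern for all classes, can be checked by averaging the heatmaps per class and correlating the two averages. This is an ad-hoc sketch, not an established test: a correlation close to 1 means the model looks at the same places regardless of class. `class_cam_similarity` is a hypothetical helper; `cams` would be a stack of normalized, equally sized heatmaps (for example the `cam_resized` arrays), and `labels` can hold either true or predicted class indices. If either mean map is constant, the correlation is undefined and NumPy returns `nan`.

```python
import numpy as np

def class_cam_similarity(cams, labels, class_a=0, class_b=1):
    """Pearson correlation between the mean Grad-CAM maps of two classes.

    cams: array [N, H, W] of normalized heatmaps; labels: N class indices."""
    cams = np.asarray(cams, dtype=np.float64)
    labels = np.asarray(labels)
    mean_a = cams[labels == class_a].mean(axis=0).ravel()
    mean_b = cams[labels == class_b].mean(axis=0).ravel()
    return float(np.corrcoef(mean_a, mean_b)[0, 1])

# Toy 1x4 "heatmaps": class 0 attends left, class 1 attends right
toy_cams = np.array([[[1.0, 0.0, 0.0, 0.0]],
                     [[0.0, 0.0, 0.0, 1.0]]])
print(class_cam_similarity(toy_cams, [0, 1]))
```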

## Medical Insights:
For **Malignant cases**, model should focus on:
- Irregular tissue patterns
- Dense nodules or masses
- Texture abnormalities

For **Normal cases**, model should recognize:
- Regular tissue structure
- Uniform density
- Absence of abnormalities

## Next Steps:
If Grad-CAM shows problems:
1. Add more data augmentation
2. Use weighted loss (focus on tissue regions)
3. Try attention mechanisms in architecture
4. Use segmentation masks if available
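Where lung segmentation masks exist, they also allow a direct check of the first item under What Good Grad-CAM Looks Like: the share of heatmap mass that falls inside the lungs. A minimal sketch; the notebook has no masks, so `cam_mass_inside_mask` and the binary `mask` input are assumptions, and the mask must first be brought to the heatmap's resolution (for example a `uint8` mask resized with `cv2.resize(..., interpolation=cv2.INTER_NEAREST)` so it stays binary).

```python
import numpy as np

def cam_mass_inside_mask(cam, mask):
    """Share of Grad-CAM mass inside a binary lung mask of the same shape."""
    cam = np.asarray(cam, dtype=np.float64)
    mask = np.asarray(mask, dtype=bool)
    if cam.shape != mask.shape:
        raise ValueError(f"shape mismatch: cam {cam.shape} vs mask {mask.shape}")
    total = cam.sum()
    # An all-zero heatmap has no mass to attribute
    return 0.0 if total == 0 else float(cam[mask].sum() / total)

toy_cam = np.array([[1.0, 1.0],
                    [2.0, 0.0]])
toy_mask = np.array([[1, 0],
                     [1, 0]])
print(cam_mass_inside_mask(toy_cam, toy_mask))
```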