# Fruit Quality Inspection and Generation

Automated food inspection requires not just accurate models, but *interpretable* ones. Before deploying a system to sort fresh produce from spoiled, we need to understand why it makes decisions. Does the model classify a fruit as "Rotten" because of visible mold, or is it picking up on background artifacts?

This notebook explores the internals of a deep learning model trained to inspect fruit. We'll build three visualization tools to examine its internal features and explain its decisions, then experiment with diffusion to generate synthetic training data. The notebook covers:

- **Visualizing Feature Hierarchy**: Observe how a CNN transforms pixels into abstract features
- **Saliency Maps**: Identify which pixels most influence predictions
- **Class Activation Maps (CAM)**: Generate heatmaps highlighting decision-driving regions
- **Generating Synthetic Data**: Use Stable Diffusion to synthesize rare defect images

## Table of Contents
- [Imports](#0)
- [1 - Setting the Stage: Data and Model](#1)
    - [1.1 - The Fruit Dataset](#1-1)
    - [1.2 - Loading the Pre-trained Inspector](#1-2)
    - [1.3 - Making a Prediction](#1-3)
- [2 - Visualizing Internal Representations](#2)
    - [2.1 - Hooking into the Hierarchy](#2-1)
    - [2.2 - Capturing the Hierarchy](#2-2)
    - [2.3 - Processing Feature Maps](#2-3)
- [3 - Pixel Level Scrutiny: Saliency Maps](#3)
- [4 - Regional Attention: Class Activation Maps](#4)
- [5 - Comparison of Interpretability Techniques](#5)
- [6 - (Optional) Generative AI for Synthetic Data](#6)
    - [6.1 - Setting up Stable Diffusion](#6-1)
    - [6.2 - Generating Synthetic Data](#6-2)
    - [6.3 - Peeking into the Diffusion Process](#6-3)
- [7 - Conclusion](#7)

<a name='0'></a>
## Imports

In [None]:
import torch
from torch.nn import functional as F

In [None]:
import gc
from pathlib import Path

from diffusers import StableDiffusionPipeline
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms

import helper_utils
import unittests

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

<a name='1'></a>
## 1 - Setting the Stage: Data and Model

Before auditing an AI's decision-making process, we need to establish the environment: the data representing the real-world problem and the pre-trained model that performs the inspection.

<a name='1-1'></a>
### 1.1 - The Fruit Dataset

We work with a curated subset of the [Fruit and Vegetable Disease (Healthy vs Rotten)](https://www.kaggle.com/datasets/muhammad0subhan/fruit-and-vegetable-disease-healthy-vs-rotten) dataset, focusing on three fruit categories: **Apples**, **Mangoes**, and **Tomatoes**, each with **Healthy** and **Rotten** conditions.

- **Total images**: 60 (3 fruits × 2 conditions × 10 images each)
- **Classes**: 2 (Fresh = 0, Rotten = 1)
- **Structure**: `./fruits_subset/{Fruit}_{Condition}/`
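
Given that layout, class labels can be read straight from the folder names. The sketch below assumes folders named `{Fruit}_{Condition}` with `Healthy` or `Rotten` as the condition (as in `Apple_Healthy`); `label_from_folder` and `CONDITION_TO_LABEL` are hypothetical names for illustration, not part of `helper_utils`.

```python
from pathlib import Path

# Assumed mapping, following the class indices listed above (Fresh/Healthy = 0, Rotten = 1).
CONDITION_TO_LABEL = {"Healthy": 0, "Rotten": 1}

def label_from_folder(image_path):
    """Return (fruit, label) parsed from the parent folder name, e.g. 'Apple_Rotten'."""
    folder = Path(image_path).parent.name
    fruit, condition = folder.rsplit("_", 1)
    return fruit, CONDITION_TO_LABEL[condition]

fruit, label = label_from_folder("./fruits_subset/Apple_Rotten/rottenApple_7.jpg")
print(fruit, label)
```

Splitting on the last underscore keeps the parsing stable even if a fruit name were to contain an underscore itself.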

In [None]:
dataset_path = "./fruits_subset/"

In [None]:
helper_utils.plot_samples_from_dataset(dataset_path)

<a name='1-2'></a>
### 1.2 - Loading the Pre-trained Inspector

We use a **ResNet-50** architecture adapted for binary classification (fresh vs rotten). The model comes pre-trained on this task, allowing us to focus directly on interpretability.

In [None]:
fruits_model = helper_utils.load_model("./models/fruits_quality_model.pth", device)
fruits_model = fruits_model.to(device)

The architecture summary shows how the network is structured—early convolutional layers capture low-level details, while deeper sequential blocks combine features into abstract patterns. Understanding this structure helps us choose which layers to probe for visualization.

In [None]:
helper_utils.display_model_architecture(fruits_model)

<a name='1-3'></a>
### 1.3 - Making a Prediction

Let's see the model in action. The interactive tool below feeds images into the model and displays predictions in real time.

In [None]:
helper_utils.predict_fruit_quality(fruits_model, dataset_path, device)

The model gives answers, but no explanations. How did it know? Did it recognize skin texture, spot a defect, or just guess based on color? Right now it's a "black box"—in the next sections, we'll build tools to look inside.

<a name='2'></a>
## 2 - Visualizing Internal Representations

A CNN doesn't see an "apple" all at once; it builds understanding hierarchically. Early layers detect edges and textures, while deeper layers combine these into complex patterns like stems, bruises, or mold patches. Here we'll observe this transformation layer by layer.

<a name='2-1'></a>
### 2.1 - Hooking into the Hierarchy

PyTorch does not keep intermediate feature maps accessible once the forward pass has finished. To peek inside, we use **hooks**: functions registered on specific layers that intercept and save outputs as data flows through.
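
As a minimal sketch of the mechanism, the toy network below (an assumption for illustration, not the fruit model) registers a forward hook on its first layer, runs one forward pass, and then removes the hook. The names `toy`, `captured`, and `save_output` are hypothetical.

```python
import torch
from torch import nn

# Toy network standing in for the real model (illustration only).
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),
)

captured = {}

def save_output(module, inputs, output):
    # A forward hook receives the module, its inputs, and its output.
    captured["first_conv"] = output.detach()

handle = toy[0].register_forward_hook(save_output)

with torch.no_grad():
    final = toy(torch.randn(1, 3, 32, 32))

handle.remove()  # remove the hook so later forward passes are unaffected

print(captured["first_conv"].shape)  # intermediate output of the first conv
print(final.shape)                   # final output of the network
```

Keeping the handle returned by `register_forward_hook` is what makes cleanup possible; a hook that is never removed keeps firing on every later forward pass.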

The `grab` helper creates these hooks. It returns a closure that, when attached to a layer, saves that layer's output (detached from the computation graph) into a dictionary.

In [None]:
def grab(activations, name):
    """
    Creates a forward hook function to capture and store the output of a specific layer.

    Arguments:
        activations: A dictionary where the captured layer output will be stored.
        name: The key under which the output tensor will be saved in the dictionary.

    Returns:
        _hook: The closure function to be registered as a hook.
    """
    def _hook(_, __, out): 
        activations[name] = out.detach()
    return _hook

<a name='2-2'></a>
### 2.2 - Capturing the Hierarchy

This function captures feature maps from five key points in the ResNet: `conv1` and the first convolution of each residual layer (`layer1` through `layer4`). This gives us a "fingerprint" of how the image is represented at different depths.
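
To anticipate the spatial size of the earliest captured maps, the standard convolution output-size formula can be applied by hand. The sketch assumes a 224×224 input and the usual ResNet stem (a 7×7 convolution with stride 2 and padding 3, followed by 3×3 max pooling with stride 2 and padding 1); `conv_out` is a hypothetical helper, and the loaded model may differ from these assumptions.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a convolution or pooling layer (floor mode, no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1

after_conv1 = conv_out(224, kernel=7, stride=2, padding=3)
after_pool = conv_out(after_conv1, kernel=3, stride=2, padding=1)
print(after_conv1, after_pool)  # 112 56
```

Comparing these numbers with the shapes printed for the real model is a quick sanity check that the hooks are attached where expected.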

In [None]:
def cnn_feature_hierarchy(img, model):
    """
    Visualizes the feature hierarchy of a CNN by capturing feature maps 
    from specific layers during a forward pass.

    Arguments:
        img: The input tensor (image) to process.
        model: The pretrained neural network module to use for feature extraction.

    Returns:
        activations: A dictionary mapping layer names to their captured 
                     feature-map tensors.
    """
    activations = {}
    
    layers = { 
        "conv1": model.conv1,
        "layer1": model.layer1[0].conv1,
        "layer2": model.layer2[0].conv1,
        "layer3": model.layer3[0].conv1,
        "layer4": model.layer4[0].conv1
    } 

    hooks = []
    for name, layer in layers.items():
        hook_function = grab(activations, name)
        hook_handle = layer.register_forward_hook(hook_function)
        hooks.append(hook_handle)

    with torch.no_grad():
        _ = model(img) 

    for h in hooks:  
        h.remove() 

    return activations

In [None]:
image_path = "./fruits_subset/Apple_Healthy/FreshApple_3.jpg"
img = helper_utils.preprocess_image(image_path, device)

activations = cnn_feature_hierarchy(img=img, model=fruits_model)

print("Activations Keys and Shapes:\n")
for name, tensor in activations.items():
    print(f"{name}:\t{tensor.shape}")

In [None]:
image_path = "./fruits_subset/Apple_Healthy/FreshApple_3.jpg"
img = helper_utils.preprocess_image(image_path, device)
activations = cnn_feature_hierarchy(img=img, model=fruits_model)
helper_utils.display_feature_hierarchy(activations, img)

<a name='2-3'></a>
### 2.3 - Processing Feature Maps

Raw feature maps are hard to interpret—early layers have high resolution but few channels, while deeper layers have low resolution but hundreds of channels. This function creates a standardized "visual strip" by selecting the most active channel at each depth and resizing it to match the original image.

In [None]:
def feature_map_strip(img, model):
    """
    Processes an image through a model to extract, select, and upsample 
    representative feature maps from specific layers.

    Arguments:
        img: The input image tensor.
        model: The pretrained neural network module used for feature extraction.

    Returns:
        upsampled: A list of tensors, each representing the most active 
                   channel from a specific layer, resized to 224x224 and 
                   normalized to [0, 1].
    """
    feats = cnn_feature_hierarchy(img, model)
    upsampled = [] 

    for name in ["conv1", "layer1", "layer2", "layer3", "layer4"]:
        fm = feats[name]
        avg_activation = torch.mean(fm, dim=(2, 3))
        idx = torch.argmax(avg_activation)
        sel = fm[:, idx:idx+1] 
        sel = F.interpolate(sel, size=(224, 224), mode="bilinear", align_corners=False)
        sel = (sel - sel.min()) / (sel.max() - sel.min() + 1e-8)
        upsampled.append(sel)

    return upsampled

In [None]:
image_path = "./fruits_subset/Tomato_Rotten/rottenTomato_8.jpg"
img = helper_utils.preprocess_image(image_path, device)

upsampled = feature_map_strip(img=img, model=fruits_model)

print("Shape of the upsampled feature maps:\n")
for i, name in enumerate(["conv1", "layer1", "layer2", "layer3", "layer4"]):
    print(f"{name}:  {upsampled[i].shape}")

In [None]:
image_path = "./fruits_subset/Tomato_Rotten/rottenTomato_8.jpg"
img = helper_utils.preprocess_image(image_path, device)
upsampled = feature_map_strip(img=img, model=fruits_model)
helper_utils.visual_strip(upsampled)

<a name='3'></a>
## 3 - Pixel Level Scrutiny: Saliency Maps

Feature maps show *what* patterns the network detects, but not *which* ones matter for the decision. **Saliency Maps** answer this by computing the gradient of the prediction with respect to input pixels—essentially asking: "If I slightly change this pixel, how much does your confidence change?"

This provides pixel-level sensitivity, helping verify that attention focuses on the fruit rather than background clutter.
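
A toy case makes the idea concrete. For a purely linear score, the gradient with respect to each input value is exactly that value's weight, so the saliency map reduces to the absolute weights. The sketch below uses a made-up weight tensor `w` as the "model" and sums absolute gradients over the channel dimension.

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, 4, 4)                          # weights of a toy linear scorer (assumed)
x = torch.randn(1, 3, 4, 4, requires_grad=True)   # toy "image" with 3 channels

score = (w * x[0]).sum()   # a single logit: f(x) = sum of w * x
score.backward()           # fills x.grad with d(score)/dx

# Absolute gradient, combined over the channel dimension.
saliency = x.grad[0].abs().sum(dim=0)
print(torch.allclose(saliency, w.abs().sum(dim=0)))
```

For a deep network the gradient is no longer constant, so the map depends on the specific input image, but the reading is the same: large values mark pixels where a small change moves the logit the most.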

In [None]:
def saliency_map(model, image_tensor, class_idx):
    """
    Generate a saliency map for a single image and class.

    Arguments:
        model: A trained CNN model instance; should be in evaluation mode.
        image_tensor: Input image tensor with shape (1, 3, H, W).
        class_idx: The integer index of the specific target class logit to explain.

    Returns:
        heatmap: A 2-D saliency heat-map normalised to [0, 1] with shape (H, W).
    """ 
    image_tensor = image_tensor.clone()
    image_tensor = image_tensor.detach()
    image_tensor.requires_grad_()

    output = model(image_tensor)
    target_logit = output[0, class_idx]

    model.zero_grad()
    target_logit.backward()

    grads = image_tensor.grad[0].abs().sum(dim=0)
    grads = (grads - grads.min()) / (grads.max() - grads.min() + 1e-8)

    heatmap = grads.detach()
    return heatmap

In [None]:
image_path = "./fruits_subset/Apple_Rotten/rottenApple_7.jpg"
img = helper_utils.preprocess_image(image_path, device)
class_idx = 1  # 0 = fresh, 1 = rotten

heatmap = saliency_map(model=fruits_model, image_tensor=img, class_idx=class_idx)

print(f"Shape: {heatmap.shape}")
print(f"Range: min = {heatmap.min()}, max = {heatmap.max()}")

In [None]:
image_path = "./fruits_subset/Apple_Rotten/rottenApple_7.jpg"
img = helper_utils.preprocess_image(image_path, device)
class_idx = 1

heatmap = saliency_map(model=fruits_model, image_tensor=img, class_idx=class_idx)
helper_utils.display_saliency(image_tensor=img, heatmap=heatmap)

**Interpreting Saliency Maps:**
- Bright regions indicate pixels that strongly influence the prediction
- A well-trained model should highlight defects (brown spots, mold) rather than background
- Saliency maps can be noisy—focus on coherent clusters rather than individual pixels
- Sharp edges often appear salient simply due to high-frequency changes, not semantic importance

<a name='4'></a>
## 4 - Regional Attention: Class Activation Maps

Saliency maps are powerful but noisy. Sometimes we want a broader answer: not "which pixel?" but "which **region**?"

**Class Activation Maps (CAM)** combine the final convolutional layer's feature maps with classification weights to produce smooth heatmaps highlighting entire objects or regions the model focuses on.

Since our ResNet-50 uses Global Average Pooling → FC, we can compute CAM directly by mapping FC weights back onto feature maps—no backpropagation needed.
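
The reason this works is that averaging and the linear layer commute: the class logit equals the spatial mean of the weighted feature maps plus the bias. The sketch below checks that identity on random stand-in tensors (`feats` and `fc` are toy assumptions, not the real model). It holds exactly only for the raw map, before ReLU and normalization, and only when the feature maps are the ones actually fed to the pooling layer.

```python
import torch
from torch import nn

torch.manual_seed(0)
C, H, W = 8, 7, 7                  # toy sizes; a real ResNet-50 has far more channels
feats = torch.randn(C, H, W)       # stand-in for the feature maps before pooling
fc = nn.Linear(C, 2)               # stand-in classifier head

class_idx = 1
with torch.no_grad():
    pooled = feats.mean(dim=(1, 2))                               # global average pooling
    logit = fc(pooled)[class_idx]                                 # usual forward path
    cam = torch.einsum("c,chw->hw", fc.weight[class_idx], feats)  # raw CAM
    rebuilt = cam.mean() + fc.bias[class_idx]                     # logit recovered from CAM

print(torch.allclose(logit, rebuilt, atol=1e-5))
```

Each cell of the raw map can therefore be read as that location's contribution to the class score.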

In [None]:
def simplified_cam(model, image_tensor, class_idx):
    """
    Generates a simplified Class Activation Map (CAM) for a specific image and class.

    Arguments:
        model: A trained ResNet-style neural network module.
        image_tensor: The input image tensor (1, 3, H, W), normalized for the model.
        class_idx: The integer index of the target class to explain.

    Returns:
        heatmap: A 2-D tensor representing the class activation heatmap, 
                 scaled to [0, 1] with the same spatial dimensions as the input.
    """
    fmap_holder = {}

    def save_fmap(_, __, output): 
        fmap_holder["feat"] = output.detach()

    hook = model.layer4[-1].conv3.register_forward_hook(save_fmap)

    with torch.no_grad():
        _ = model(image_tensor) 

    hook.remove() 

    feats = fmap_holder["feat"]
    weight_vec = model.fc.weight[class_idx]

    cam = torch.einsum("c,chw->hw", weight_vec, feats.squeeze(0))
    cam = F.relu(cam) 
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)  # min-max scale to [0, 1]

    H, W = image_tensor.shape[2:]
    cam_up = F.interpolate( 
        cam.unsqueeze(0).unsqueeze(0),
        size=(H, W),
        mode="bilinear", 
        align_corners=False, 
    )[0, 0] 

    heatmap = cam_up.cpu().detach()
    return heatmap

In [None]:
image_path = "./fruits_subset/Apple_Rotten/rottenApple_5.jpg"
img = helper_utils.preprocess_image(image_path, device)
class_idx = 1

heatmap = simplified_cam(model=fruits_model, image_tensor=img, class_idx=class_idx)

print(f"Shape: {heatmap.shape}")
print(f"Range: min = {heatmap.min()}, max = {heatmap.max()}")

In [None]:
image_path = "./fruits_subset/Apple_Rotten/rottenApple_5.jpg"
img = helper_utils.preprocess_image(image_path, device)
class_idx = 1

heatmap = simplified_cam(model=fruits_model, image_tensor=img, class_idx=class_idx)
helper_utils.display_cam(img, heatmap)

<a name='5'></a>
## 5 - Comparison of Interpretability Techniques

| **Aspect** | **Feature Hierarchy** | **Saliency Maps** | **CAM** |
|------------|----------------------|-------------------|---------|
| **Shows** | Feature evolution across layers | Pixel-level gradient sensitivity | Region-level class evidence |
| **Granularity** | Layer-by-layer | Individual pixels | Coarse spatial regions |
| **Speed** | Fast | Medium | Fast |
| **Requires gradients?** | No | Yes | No |
| **Best for** | Model debugging | Adversarial analysis, fine details | Trustworthiness, localization |

**Combining techniques:** Use CAM to verify focus on the fruit (not background), Saliency to pinpoint exact pixels driving decisions, and Feature Hierarchy to understand learned representations at each depth.

<a name='6'></a>
## 6 - (Optional) Generative AI for Synthetic Data

What happens when you need to train on rare defects but lack enough real photos? This is **data scarcity**—a common bottleneck in industrial AI.

Here we flip the script: instead of analyzing how models *interpret* images, we explore how they can *create* new ones using **Stable Diffusion** to synthesize realistic training data from scratch.

<a name='6-1'></a>
### 6.1 - Setting up Stable Diffusion

We use the Hugging Face `diffusers` library to load a pre-trained Stable Diffusion model, which bundles the text encoder, UNet, and VAE into a single pipeline.

In [None]:
def load_sd_pipeline(device, model_id="stabilityai/stable-diffusion-2-base"):
    """
    Initializes the Stable Diffusion pipeline from a pretrained model identifier 
    and transfers it to the specified computing device.

    Arguments:
        device: The target device (e.g., 'cuda', 'mps', 'cpu') for model execution.
        model_id: The repository ID of the pretrained model to load.

    Returns:
        pipe: The loaded StableDiffusionPipeline, moved to `device`.
    """
    pipe = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path=model_id,
        torch_dtype=torch.float16,
        variant="fp16",
        cache_dir="./models",
        local_files_only=True 
    ).to(device) 
    
    return pipe

In [None]:
if "pipe" in globals():
    del pipe
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

try:
    helper_utils.check_model_snapshot()
    pipe = load_sd_pipeline(device)
    print("\nLoading Complete!")
except Exception as e:
    print(f"Error loading pipeline: {e}")
    if "pipe" in globals():
        del pipe
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    pipe = None

<a name='6-2'></a>
### 6.2 - Generating Synthetic Data

This function takes a text description (e.g., "A mango with a wormhole") and generates an image. Fixing the random seed makes a run repeatable on the same hardware and library versions, which matters in scientific and industrial contexts.
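
The role of the seed can be seen in isolation. In this sketch, `sample_noise` is a hypothetical helper that draws Gaussian noise from a freshly seeded `torch.Generator`; the shape is arbitrary and the CPU device is chosen only so the example runs anywhere.

```python
import torch

def sample_noise(seed, shape=(1, 4, 8, 8)):
    """Draw Gaussian noise from a freshly seeded generator (CPU for portability)."""
    generator = torch.Generator(device="cpu").manual_seed(seed)
    return torch.randn(shape, generator=generator)

same_a = sample_noise(42)
same_b = sample_noise(42)
different = sample_noise(43)

print(torch.equal(same_a, same_b))     # identical seeds give identical noise
print(torch.equal(same_a, different))  # a different seed gives different noise
```

Passing a dedicated generator, rather than seeding the global random state, keeps the image generation repeatable without affecting other random operations in the notebook.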

In [None]:
def generate_sd_image(pipe, prompt, negative_prompt, seed, steps, save_dir="synthetic"):
    """
    Generates a single image from a text prompt using a pre-loaded Stable Diffusion pipeline.

    Arguments:
        pipe: The initialized Stable Diffusion pipeline instance.
        prompt: The positive text description of the desired image.
        negative_prompt: Text description of elements to exclude from the image.
        seed: An integer value to initialize the random number generator.
        steps: The number of denoising steps to perform during inference.
        save_dir: The root directory path where the generated image will be saved.

    Returns:
        image: The generated PIL Image object.
    """
    device = pipe.device
    generator = torch.Generator(device=device).manual_seed(seed)
    
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=steps,
        generator=generator,
    ).images[0]

    slug = "_".join(prompt.lower().split()[:3]) 
    out_dir = Path(save_dir) / slug 
    out_dir.mkdir(parents=True, exist_ok=True) 
    out_path = out_dir / f"img_{seed}.png" 
    image.save(out_path)
    print(f"\nImage saved to {out_path}\n")

    return image

In [None]:
prompt = "A mango with a small hole made by a worm in the middle."
negative_prompt = "Fresh, intact."
seed = 42
steps = 50

In [None]:
try:
    img = generate_sd_image(
        pipe=pipe,
        prompt=prompt, 
        negative_prompt=negative_prompt,
        seed=seed, 
        steps=steps
    )
    plt.axis('off')
    plt.imshow(img)
except Exception as e:
    print(f"Error during generation: {e}")
    if "pipe" in globals():
        del pipe
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    pipe = None

<a name='6-3'></a>
### 6.3 - Peeking into the Diffusion Process

Stable Diffusion works by iteratively removing noise. It starts from random noise in a compressed latent space and gradually refines that latent, guided by the text prompt, and the VAE then decodes the result into pixels. Visualizing this **denoising** process helps us see *when* the model decides on shapes versus when it refines textures.

We use a callback mechanism to intercept intermediate latents, decode them into viewable images, and assemble them into a grid.
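
The capture pattern itself is independent of diffusion. The sketch below uses a hypothetical `run_steps` loop as a stand-in for the pipeline (it is not the `diffusers` API) to show how a per-step callback can keep selected intermediate states, and how frame indices map onto cells of a two-column grid with `divmod`.

```python
def run_steps(num_steps, on_step_end):
    """Toy stand-in for an iterative sampler: the 'noise level' halves at every step."""
    state = 1.0
    for step_idx in range(num_steps):
        state *= 0.5
        on_step_end(step_idx, state)
    return state

capture_steps = [0, 2, 4, 5]
frames = {}

def grab_state(step_idx, state):
    # Keep only the requested steps; everything else is discarded.
    if step_idx in capture_steps:
        frames[step_idx] = state

run_steps(6, grab_state)
ordered = [frames[s] for s in capture_steps]

# Position of each frame in a two-column grid as (row, col).
cells = [divmod(i, 2) for i in range(len(ordered))]
print(ordered)
print(cells)
```

Storing frames in a dictionary keyed by step index and ordering them afterwards means the result does not depend on the order in which the callback happens to fire.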

In [None]:
def denoising_movie(pipe, prompt, seed, steps, capture_steps, save_grid_path="timelapse.png"):
    """
    Captures intermediate denoising frames from the Stable Diffusion process and 
    assembles them into a grid image.

    Arguments:
        pipe: The pre-loaded Stable Diffusion pipeline instance.
        prompt: The positive text description for generation.
        seed: An integer value for deterministic random noise generation.
        steps: The total number of inference steps to perform.
        capture_steps: A list of integer indices specifying which steps to capture.
        save_grid_path: The file path where the final grid image will be saved.

    Returns:
        ordered_frames: A list of PIL Image objects corresponding to the captured steps.
    """
    frames = {}

    def grab_frame(pipeline, step_idx, timestep, callback_kwargs): 
        if step_idx in capture_steps:
            latents = callback_kwargs["latents"]
            with torch.no_grad():
                img = pipe.vae.decode(
                    latents / pipe.vae.config.scaling_factor,
                    return_dict=False
                )[0] 
            pil = pipe.image_processor.postprocess(img, output_type="pil")[0]
            frames[step_idx] = pil
        return callback_kwargs

    generator = torch.Generator(pipe.device).manual_seed(seed)

    _ = pipe( 
        prompt=prompt,
        num_inference_steps=steps,
        generator=generator,
        callback_on_step_end=grab_frame,
    ) 

    ordered_frames = [frames[step] for step in capture_steps]

    w, h = ordered_frames[0].size 
    cols = 2
    rows = (len(ordered_frames) + cols - 1) // cols  # enough rows to fit every frame
    grid = Image.new("RGB", (w * cols, h * rows)) 
    for idx, frame in enumerate(ordered_frames): 
        row, col = divmod(idx, cols) 
        grid.paste(frame, (col * w, row * h)) 

    grid.save(save_grid_path) 
    print(f"Timelapse grid saved to {save_grid_path}") 

    return ordered_frames

In [None]:
prompt = "A healthy mango."
seed = 42
steps = 50
capture_steps = [0, 15, 30, 49]

In [None]:
try:
    ordered_frames = denoising_movie(
        pipe=pipe,
        prompt=prompt,
        seed=seed, 
        steps=steps,
        capture_steps=capture_steps
    )
    grid_image = plt.imread("timelapse.png")
    plt.axis('off')
    plt.imshow(grid_image)
except Exception as e:
    print(f"Error during denoising: {e}")
    if "pipe" in globals():
        del pipe
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    pipe = None

<a name='7'></a>
## 7 - Conclusion

In this notebook, we explored interpretability techniques for understanding how a deep learning model makes predictions, and we used a generative model to synthesize new images:

- **Feature Hierarchy** revealed how the network builds understanding from edges to complex objects
- **Saliency Maps** provided pixel-level sensitivity to verify focus on relevant defects
- **Class Activation Maps** showed region-level attention for human-interpretable explanations
- **Stable Diffusion** demonstrated how generative AI can address data scarcity by synthesizing realistic training samples

Together, these tools form a practical toolkit for explaining, debugging, and improving computer vision systems in real-world applications.