# Lesson 4: Pixel Neighborhoods and Connectivity
## Biomedical Image Processing - Basic Concepts

### Topics:
- 4-connectivity vs 8-connectivity
- Neighborhood operations
- Connected component analysis
- Practical applications in medical imaging

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from collections import deque

# Confirm the imports above ran without raising an error.
print("Libraries imported successfully!")

## 1. What Are Pixel Neighbors?

Every pixel in an image has neighbors — the pixels immediately surrounding it. (Pixels on the image border have fewer neighbors; Section 8 shows how to handle them.)
There are two main ways to define neighbors:

- **4-connectivity (N4)**: Only up, down, left, right
- **8-connectivity (N8)**: All 8 surrounding pixels (including diagonals)

In [None]:
# Visualize the three neighborhood types on a 5x5 grid.
# Cell codes: 0 = not a neighbor, 1 = neighbor ('N'), 2 = the center pixel ('P').
from matplotlib.colors import ListedColormap

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# 4-connectivity: only the pixels that share an edge with P
n4 = np.zeros((5, 5))
n4[2, 2] = 2  # Center pixel (red)
n4[1, 2] = 1  # Top
n4[3, 2] = 1  # Bottom
n4[2, 1] = 1  # Left
n4[2, 3] = 1  # Right

# 8-connectivity: the whole 3x3 block around P (edges and corners)
n8 = np.zeros((5, 5))
n8[1:4, 1:4] = 1  # All 8 neighbors (this slice also covers the center...)
n8[2, 2] = 2      # ...so mark the center pixel (red) afterwards

# Diagonal neighbors only (for comparison): the corners N8 adds on top of N4
nd = np.zeros((5, 5))
nd[2, 2] = 2  # Center pixel
nd[1, 1] = 1
nd[1, 3] = 1
nd[3, 1] = 1
nd[3, 3] = 1

# Codes 0/1/2 map to white/lightblue/red; vmin/vmax below pin that mapping.
cmap = ListedColormap(['white', 'lightblue', 'red'])

titles = ['4-Connectivity (N4)', '8-Connectivity (N8)', 'Diagonal Only (ND)']
for ax, data, title in zip(axes, [n4, n8, nd], titles):
    ax.imshow(data, cmap=cmap, vmin=0, vmax=2)
    ax.set_title(title, fontsize=14)

    # Grid lines on pixel boundaries (pixel centers sit at integer coordinates)
    for i in range(6):
        ax.axhline(i - 0.5, color='gray', linewidth=0.5)
        ax.axvline(i - 0.5, color='gray', linewidth=0.5)

    # Label the center pixel 'P' and each neighbor 'N'
    for i in range(5):
        for j in range(5):
            if data[i, j] == 2:
                ax.text(j, i, 'P', ha='center', va='center', fontsize=12, fontweight='bold')
            elif data[i, j] == 1:
                ax.text(j, i, 'N', ha='center', va='center', fontsize=10)

    ax.set_xticks([])
    ax.set_yticks([])

plt.suptitle('Pixel Neighborhood Types', fontsize=16, y=1.02)
plt.tight_layout()
plt.show()

print("P = Center Pixel, N = Neighbor")
print("N4 has 4 neighbors, N8 has 8 neighbors")

## 2. Getting Neighbors Programmatically

In [None]:
def get_neighbors_4(img, row, col):
    """Return the in-bounds 4-connected neighbors of pixel (row, col).

    Candidates are visited in the order up, down, left, right. Any candidate
    that falls outside the image is dropped, so border pixels get fewer
    than four neighbors.

    Returns:
        A list of ``(r, c, value)`` tuples where ``value`` is ``img[r, c]``.
    """
    n_rows, n_cols = img.shape[:2]
    candidates = (
        (row - 1, col),  # up
        (row + 1, col),  # down
        (row, col - 1),  # left
        (row, col + 1),  # right
    )
    return [
        (r, c, img[r, c])
        for r, c in candidates
        if 0 <= r < n_rows and 0 <= c < n_cols
    ]

def get_neighbors_8(img, row, col):
    """Return the in-bounds 8-connected neighbors of pixel (row, col).

    The 3x3 window around the pixel is scanned row by row, from top-left to
    bottom-right. The center pixel itself is skipped, as is any position
    outside the image.

    Returns:
        A list of ``(r, c, value)`` tuples where ``value`` is ``img[r, c]``.
    """
    n_rows, n_cols = img.shape[:2]
    steps = (-1, 0, 1)
    found = []
    for r in (row + d for d in steps):
        for c in (col + d for d in steps):
            is_center = r == row and c == col
            in_bounds = 0 <= r < n_rows and 0 <= c < n_cols
            if is_center or not in_bounds:
                continue
            found.append((r, c, img[r, c]))
    return found

# Try both neighbor functions on a tiny 3x3 image whose values encode position.
test_img = np.array([
    [10, 20, 30],
    [40, 50, 60],
    [70, 80, 90],
], dtype=np.uint8)

print("Test Image:")
print(test_img)
print()

# The center pixel (1,1) is interior, so it has all 4 / all 8 neighbors.
for heading, neighbor_fn in (
    ("4-neighbors of center pixel (1,1):", get_neighbors_4),
    ("\n8-neighbors of center pixel (1,1):", get_neighbors_8),
):
    print(heading)
    for r, c, val in neighbor_fn(test_img, 1, 1):
        print(f"  Position ({r},{c}) = {val}")

## 3. Why Does Connectivity Matter?

The choice of connectivity affects how we count and identify objects in an image!
Let's see a surprising example.

In [None]:
# Create a diagonal pattern: an "X" of 9 foreground pixels. No two of them
# share an edge, but consecutive pixels along each arm touch at a corner.
diagonal_pattern = np.array([
    [1, 0, 0, 0, 1],
    [0, 1, 0, 1, 0],
    [0, 0, 1, 0, 0],
    [0, 1, 0, 1, 0],
    [1, 0, 0, 0, 1]
], dtype=np.uint8)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Both panels show the same image; only the question in the title differs.
for ax in axes:
    ax.imshow(diagonal_pattern, cmap='gray_r', interpolation='nearest')
    # 6 lines bound the 5 pixels; pixel centers are at integers, so the
    # boundaries fall at k - 0.5.
    for i in range(6):
        ax.axhline(i - 0.5, color='lightgray', linewidth=0.5)
        ax.axvline(i - 0.5, color='lightgray', linewidth=0.5)
    ax.set_xticks([])
    ax.set_yticks([])

axes[0].set_title('With 4-connectivity:\nHow many objects?', fontsize=12)
axes[1].set_title('With 8-connectivity:\nHow many objects?', fontsize=12)

plt.tight_layout()
plt.show()

# Answer: with 4-connectivity no two foreground pixels touch, so each is its
# own object; with 8-connectivity the corner contacts join them all.
print("\nANSWER:")
print("4-connectivity: 9 separate objects (diagonals don't connect)")
print("8-connectivity: 1 connected object (diagonals DO connect)")
print("\nThe same image - completely different interpretation!")

## 4. Connected Component Labeling

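This fundamental image-processing algorithm finds and labels separate objects. Every foreground pixel that can be reached from another by stepping through neighboring foreground pixels gets the same label. Because "neighboring" depends on the connectivity chosen, so does the result.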

In [None]:
def connected_components(binary_img, connectivity=4):
    """Find connected components using flood fill.

    The image is scanned in row-major order. Every foreground pixel that is
    not yet labeled seeds a breadth-first flood fill. The fill gives one new
    label to everything reachable from that pixel through foreground
    neighbors.

    Args:
        binary_img: 2-D array. Any nonzero pixel counts as foreground, so
            0/1, boolean, and 0/255 masks all work.
        connectivity: 4 joins only horizontal/vertical neighbors. Any other
            value uses 8-connectivity (diagonals included).

    Returns:
        ``(labels, count)``. ``labels`` is an int32 array with the shape of
        ``binary_img``. Background is 0, and components are numbered
        1..count in the order their first pixel is met in the scan. ``count``
        is the number of components found.
    """
    height, width = binary_img.shape
    labels = np.zeros_like(binary_img, dtype=np.int32)
    # Nonzero = foreground. The original `== 1` test silently found nothing
    # in 0/255 masks. The mask is computed once so the fill does cheap lookups.
    foreground = binary_img != 0
    current_label = 0

    # Define neighbor offsets based on connectivity
    if connectivity == 4:
        offsets = [(-1, 0), (1, 0), (0, -1), (0, 1)]
    else:  # any other value: 8-connectivity (adds the diagonals)
        offsets = [(-1, -1), (-1, 0), (-1, 1), (0, -1),
                   (0, 1), (1, -1), (1, 0), (1, 1)]

    for i in range(height):
        for j in range(width):
            # Only foreground pixels not yet reached start a new component
            if not foreground[i, j] or labels[i, j] != 0:
                continue
            current_label += 1

            # BFS flood fill. Pixels are labeled when enqueued (not when
            # dequeued), so no pixel is ever queued twice.
            queue = deque([(i, j)])
            labels[i, j] = current_label
            while queue:
                r, c = queue.popleft()
                for dr, dc in offsets:
                    nr, nc = r + dr, c + dc
                    if (0 <= nr < height and 0 <= nc < width
                            and foreground[nr, nc] and labels[nr, nc] == 0):
                        labels[nr, nc] = current_label
                        queue.append((nr, nc))

    return labels, current_label

# Test on the diagonal pattern: label it with both connectivities
labels_4, count_4 = connected_components(diagonal_pattern, connectivity=4)
labels_8, count_8 = connected_components(diagonal_pattern, connectivity=8)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

axes[0].imshow(diagonal_pattern, cmap='gray_r', interpolation='nearest')
axes[0].set_title('Original Binary Image')
axes[0].axis('off')

# Label 0 (background) is colored by the colormap like any other label.
# NOTE(review): im1 / im2 are not used anywhere below.
im1 = axes[1].imshow(labels_4, cmap='tab10', interpolation='nearest')
axes[1].set_title(f'4-Connected: {count_4} Objects')
axes[1].axis('off')

im2 = axes[2].imshow(labels_8, cmap='tab10', interpolation='nearest')
axes[2].set_title(f'8-Connected: {count_8} Object(s)')
axes[2].axis('off')

plt.tight_layout()
plt.show()

print("Each color represents a different connected component.")

## 5. Real-World Example: Cell Counting

Connected component analysis is used to count cells in microscopy images!

In [None]:
# Create a simulated cell image (binary: 0 = background, 1 = cell)
# NOTE(review): no random numbers are drawn in this cell (the cell positions
# below are fixed), so this seed does not affect the image here.
np.random.seed(42)
cell_image = np.zeros((100, 100), dtype=np.uint8)
# Add some "cells" as circles
def add_cell(img, center_r, center_c, radius):
    """Paint a filled disk of 1s into ``img`` in place.

    Every pixel whose (row, col) lies within ``radius`` (inclusive) of
    ``(center_r, center_c)`` is set to 1. The disk is clipped to the image
    bounds. Returns None.
    """
    rows = np.arange(img.shape[0])[:, None]   # column vector of row indices
    cols = np.arange(img.shape[1])[None, :]   # row vector of column indices
    inside = (cols - center_c) ** 2 + (rows - center_r) ** 2 <= radius ** 2
    img[inside] = 1

# Add cells at fixed (row, col, radius) positions. The disks are far apart,
# so each should come out as its own connected component.
cell_positions = [(20, 25, 8), (20, 70, 6), (45, 45, 10), 
                  (60, 20, 7), (70, 80, 9), (85, 50, 6)]

for r, c, rad in cell_positions:
    add_cell(cell_image, r, c, rad)

# Count cells: each 8-connected blob of foreground pixels is one cell
labels, num_cells = connected_components(cell_image, connectivity=8)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

axes[0].imshow(cell_image, cmap='gray')
axes[0].set_title('Simulated Microscopy Image')
axes[0].axis('off')

axes[1].imshow(labels, cmap='nipy_spectral')
axes[1].set_title(f'Labeled Cells: {num_cells} Found')
axes[1].axis('off')

# Show cell statistics: the area (pixel count) of each labeled cell, drawn as
# text in an otherwise empty third panel
axes[2].axis('off')
stats_text = "Cell Statistics:\n\n"
for i in range(1, num_cells + 1):
    cell_pixels = np.sum(labels == i)
    stats_text += f"Cell {i}: {cell_pixels} pixels\n"
stats_text += f"\nTotal: {num_cells} cells"
# transform=transAxes: (0.1, 0.5) is a fraction of the panel, not data coords
axes[2].text(0.1, 0.5, stats_text, fontsize=12, family='monospace', 
             verticalalignment='center', transform=axes[2].transAxes)
axes[2].set_title('Cell Analysis')

plt.tight_layout()
plt.show()

## 6. Neighborhood Operations: Kernels/Filters

Most image filters work by examining a pixel and its neighbors together.

In [None]:
def apply_kernel(img, kernel):
    """Slide ``kernel`` over a 2-D image and return the weighted neighborhood sums.

    This is 2-D *cross-correlation*: the kernel is applied exactly as written,
    without the 180-degree flip that true convolution performs. For kernels
    that are unchanged by that flip (all kernels used in this lesson), the
    two operations give identical results.

    The image is padded by replicating its edge pixels (``mode='edge'``), so
    the output has the same shape as the input. Any kernel size is accepted;
    odd sizes keep the kernel centered on each output pixel.

    Args:
        img: 2-D array of intensities.
        kernel: 2-D array (or nested sequence) of weights.

    Returns:
        float32 array with the same shape as ``img``.
    """
    kernel = np.asarray(kernel, dtype=np.float64)
    h, w = img.shape
    kh, kw = kernel.shape
    pad_h, pad_w = kh // 2, kw // 2

    # Pad image (values rounded through float32 first, then accumulated in
    # float64 for accuracy)
    padded = np.pad(img.astype(np.float32), ((pad_h, pad_h), (pad_w, pad_w)),
                    mode='edge').astype(np.float64)
    acc = np.zeros((h, w), dtype=np.float64)

    # Loop over the kh*kw kernel taps instead of the h*w pixels. Each tap adds
    # one shifted copy of the image scaled by its weight, so
    # acc[i, j] == sum(padded[i:i+kh, j:j+kw] * kernel).
    for di in range(kh):
        for dj in range(kw):
            acc += kernel[di, dj] * padded[di:di + h, dj:dj + w]

    return acc.astype(np.float32)

# Create a test image with sharp edges: a bright square (200) with a darker
# inner square (100) on a black (0) background
test = np.zeros((50, 50), dtype=np.uint8)
test[15:35, 15:35] = 200
test[20:30, 20:30] = 100

# Common kernels
# Average (blur): each output pixel becomes the mean of its 3x3 neighborhood
kernel_average = np.ones((3, 3)) / 9

# Edge detection (negated Laplacian): the weights sum to 0, so flat regions
# map to 0 and only intensity changes respond
kernel_edge = np.array([
    [0, -1, 0],
    [-1, 4, -1],
    [0, -1, 0]
])

# Sharpen: the edge kernel above plus the identity (center 4 + 1 = 5)
kernel_sharpen = np.array([
    [0, -1, 0],
    [-1, 5, -1],
    [0, -1, 0]
])

# Apply kernels
blurred = apply_kernel(test, kernel_average)
edges = apply_kernel(test, kernel_edge)
sharpened = apply_kernel(test, kernel_sharpen)

fig, axes = plt.subplots(2, 4, figsize=(16, 8))

# Top row: Images
axes[0, 0].imshow(test, cmap='gray')
axes[0, 0].set_title('Original')
axes[0, 0].axis('off')

axes[0, 1].imshow(blurred, cmap='gray')
axes[0, 1].set_title('Blurred (Average)')
axes[0, 1].axis('off')

# Edge responses are signed; display their magnitude
axes[0, 2].imshow(np.abs(edges), cmap='gray')
axes[0, 2].set_title('Edge Detection')
axes[0, 2].axis('off')

# Sharpening overshoots the 0-255 range next to edges; clip before display
axes[0, 3].imshow(np.clip(sharpened, 0, 255), cmap='gray')
axes[0, 3].set_title('Sharpened')
axes[0, 3].axis('off')
# Bottom row: Kernels
def show_kernel(ax, kernel, title):
    """Draw ``kernel`` on ``ax`` as a heat map with each weight printed in its cell.

    The diverging 'RdBu' colormap is pinned to the range [-2, 2], so every
    kernel drawn by this function uses the same color scale. Weights are
    printed with two decimals, and the axis frame is hidden.
    """
    ax.imshow(kernel, cmap='RdBu', vmin=-2, vmax=2)
    n_rows, n_cols = kernel.shape[0], kernel.shape[1]
    cells = ((r, c) for r in range(n_rows) for c in range(n_cols))
    for row, col in cells:
        # text() takes (x, y) = (column, row)
        ax.text(col, row, f'{kernel[row, col]:.2f}', ha='center', va='center', fontsize=10)
    ax.set_title(title)
    ax.axis('off')

# Bottom row: the first panel is a text caption, the rest show each kernel
axes[1, 0].axis('off')
axes[1, 0].text(0.5, 0.5, '3x3\nNeighborhood', ha='center', va='center', fontsize=14)
show_kernel(axes[1, 1], kernel_average, 'Average Kernel')
# The edge/sharpen kernels are integer arrays; they are cast to float for display
show_kernel(axes[1, 2], kernel_edge.astype(float), 'Edge Kernel')
show_kernel(axes[1, 3], kernel_sharpen.astype(float), 'Sharpen Kernel')

plt.tight_layout()
plt.show()

## 7. Morphological Operations: Growing and Shrinking Regions with Neighbors

In [None]:
def dilate(img, connectivity=8):
    """Morphological dilation - expand white regions.

    An output pixel is set to 1 when any pixel under the structuring element
    centered on it is nonzero. With ``connectivity == 4`` the element is the
    plus shape (center and its 4 edge neighbors). Any other value uses the
    full 3x3 square. The image is zero-padded by one pixel, so out-of-bounds
    positions count as background.

    Returns:
        Array with the shape and dtype of ``img`` holding 0s and 1s.
    """
    kernel = (np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])
              if connectivity == 4 else np.ones((3, 3)))

    n_rows, n_cols = img.shape
    framed = np.pad(img, 1, mode='constant', constant_values=0)
    result = np.zeros_like(img)

    for r in range(n_rows):
        for c in range(n_cols):
            # framed[r:r+3, c:c+3] is the 3x3 window centered on img[r, c]
            if (framed[r:r + 3, c:c + 3] * kernel).any():
                result[r, c] = 1

    return result

def erode(img, connectivity=8):
    """Morphological erosion - shrink white regions.

    An output pixel is set to 1 only if every pixel under the structuring
    element centered on it is foreground. Any nonzero value counts as
    foreground, the same rule ``dilate`` uses, so 0/255 masks erode like
    0/1 masks.

    Args:
        img: 2-D array.
        connectivity: 4 uses the plus-shaped element (center + 4 edge
            neighbors). Any other value uses the full 3x3 square.

    Returns:
        Array with the shape and dtype of ``img`` holding 0s and 1s. The image
        is zero-padded, so foreground touching the border is always eroded.
    """
    if connectivity == 4:
        kernel = np.array([[0,1,0],[1,1,1],[0,1,0]])
    else:
        kernel = np.ones((3,3))
    # Boolean mask of the positions the structuring element covers
    footprint = kernel == 1

    h, w = img.shape
    padded = np.pad(img, 1, mode='constant', constant_values=0)
    output = np.zeros_like(img)

    for i in range(h):
        for j in range(w):
            region = padded[i:i+3, j:j+3]
            # Nonzero test (not `== 1`) keeps erode consistent with dilate
            if np.all(region[footprint] != 0):
                output[i, j] = 1

    return output

# Create a thin cross: a 2-pixel-wide horizontal bar plus a vertical bar
line = np.zeros((30, 30), dtype=np.uint8)
line[14:16, 5:25] = 1  # Horizontal line
line[5:25, 14:16] = 1  # Vertical line (cross)

# Apply operations (default connectivity=8, i.e. the full 3x3 square)
dilated_1 = dilate(line)
dilated_2 = dilate(dilated_1)
# Erode the twice-dilated cross rather than the original. No pixel of the
# 2-pixel-wide original has a complete 3x3 foreground neighborhood, so
# eroding it directly would leave an entirely black image.
eroded_1 = erode(dilated_2)

fig, axes = plt.subplots(1, 4, figsize=(16, 4))

axes[0].imshow(line, cmap='gray', interpolation='nearest')
axes[0].set_title('Original Cross')
axes[0].axis('off')

axes[1].imshow(dilated_1, cmap='gray', interpolation='nearest')
axes[1].set_title('Dilated 1x\n(Grows outward)')
axes[1].axis('off')

axes[2].imshow(dilated_2, cmap='gray', interpolation='nearest')
axes[2].set_title('Dilated 2x\n(Even larger)')
axes[2].axis('off')

axes[3].imshow(eroded_1, cmap='gray', interpolation='nearest')
axes[3].set_title('Dilated 2x, then Eroded 1x\n(Shrinks inward)')
axes[3].axis('off')

plt.suptitle('Morphological Operations Use Neighborhoods!', fontsize=14)
plt.tight_layout()
plt.show()

## 8. Border Pixels: Edge Cases

What happens at the image edges where pixels don't have all neighbors?

In [None]:
# Visualize border handling strategies: pad a 3x3 image by 1 pixel per side
small = np.array([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]
])

pad0 = np.pad(small, 1, mode='constant', constant_values=0)  # zeros outside
pad_rep = np.pad(small, 1, mode='edge')      # copy the nearest edge value
pad_ref = np.pad(small, 1, mode='reflect')   # mirror; edge value not repeated
pad_wrap = np.pad(small, 1, mode='wrap')     # opposite side wraps around

panels = [
    (pad0, 'Zero Padding\n(Add black border)'),
    (pad_rep, 'Edge Replication\n(Repeat border values)'),
    (pad_ref, 'Reflection\n(Mirror at border)'),
    (pad_wrap, 'Wrap Around\n(Circular/periodic)'),
]

fig, axes = plt.subplots(1, 4, figsize=(16, 4))

for ax, (padded, title) in zip(axes, panels):
    ax.imshow(padded, cmap='Blues')
    ax.set_title(title)
    # Print each value in its cell (text takes x = column, y = row)
    for i in range(padded.shape[0]):
        for j in range(padded.shape[1]):
            ax.text(j, i, str(padded[i, j]), ha='center', va='center')
    # Outline the original 3x3 image (cells 1..3 span 0.5..3.5) so the added
    # border stands out
    ax.add_patch(Rectangle((0.5, 0.5), small.shape[1], small.shape[0],
                           fill=False, edgecolor='red', linewidth=2))
    ax.axis('off')

plt.suptitle('Border Handling Strategies for Neighborhood Operations', fontsize=14)
plt.tight_layout()
plt.show()

## 9. Summary

In [None]:
# Print a recap of the lesson's main ideas.
print("""
SUMMARY: PIXEL NEIGHBORHOODS & CONNECTIVITY
==========================================

1. CONNECTIVITY TYPES:
   - 4-connectivity: Only horizontal/vertical neighbors
   - 8-connectivity: All surrounding pixels (including diagonals)

2. CONNECTED COMPONENTS:
   - Algorithm to find and label separate objects
   - Result depends on connectivity choice!
   - Used in: Cell counting, object detection, segmentation

3. KERNEL OPERATIONS:
   - Work on a pixel AND its neighbors together
   - Examples: Blur, sharpen, edge detection
   - Kernel values determine the effect

4. MORPHOLOGICAL OPERATIONS:
   - Dilation: Expand regions (using neighbors)
   - Erosion: Shrink regions (using neighbors)

5. BORDER HANDLING:
   - Zero padding, edge replication, reflection, wrap
   - Choice affects results at image edges

This is the foundation for filters, segmentation, and object analysis!
""")