**Summary of the Full Project Pipeline**
- Input: Raw RGB image.
- Opponency: Break into Intensity, Red-Green, and Blue-Yellow.
- Center-Surround: Apply Difference of Gaussians (DoG) at multiple scales.
- Integration: Combine scales into a master Saliency Map.
- Gating: Mask the original image to create a "Foveal" view.
- Classification: Predict the object based on the attended pixels.

### The Concept: Difference of Gaussians (DoG)

To mimic the "On-Center/Off-Surround" cell, we calculate the difference between a narrow Gaussian kernel (the center) and a wide Gaussian kernel (the surround).
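As a quick sanity check of the idea, here is a minimal NumPy sketch, independent of the PyTorch layer that follows, that builds the same kind of 2D DoG kernel. Each Gaussian is normalized to sum to 1 before subtracting, so the kernel has a positive (excitatory) center, a negative (inhibitory) surround, and sums to zero. The function name `dog_kernel_np` and the parameter values are illustrative choices, not part of the original notebook.

```python
import numpy as np

def dog_kernel_np(kernel_size=21, sigma_center=2.0, sigma_surround=5.0):
    """2D Difference-of-Gaussians kernel: narrow center minus wide surround."""
    coords = np.arange(kernel_size) - (kernel_size - 1) / 2
    gx, gy = np.meshgrid(coords, coords, indexing='ij')
    dist_sq = gx**2 + gy**2
    center = np.exp(-dist_sq / (2 * sigma_center**2))
    surround = np.exp(-dist_sq / (2 * sigma_surround**2))
    # Normalize each Gaussian to unit sum, so their difference sums to zero
    return center / center.sum() - surround / surround.sum()

k = dog_kernel_np()
mid = k.shape[0] // 2
print("center weight positive:", k[mid, mid] > 0)
print("far surround weight negative:", k[mid, 0] < 0)
print("kernel sum:", k.sum())
```

Because the weights sum to zero, any perfectly uniform patch produces zero output. Only local differences in brightness survive the filter.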

#### Implementation

We will define a `CenterSurroundLayer` module. Instead of standard dot-product attention, it uses fixed spatial filtering to determine what the model should focus on.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
import matplotlib.pyplot as plt

In [None]:
class CenterSurroundLayer(nn.Module):
    def __init__(self, channels, kernel_size=7, sigma_center=0.5, sigma_surround=3.0):
        super(CenterSurroundLayer, self).__init__()
        self.channels = channels
        self.kernel_size = kernel_size

        # Create a coordinate grid
        x = torch.arange(kernel_size) - (kernel_size - 1) / 2
        y = torch.arange(kernel_size) - (kernel_size - 1) / 2
        grid_x, grid_y = torch.meshgrid(x, y, indexing='ij')
        dist_sq = grid_x**2 + grid_y**2

        # Calculate Gaussian kernels
        center = torch.exp(-dist_sq / (2 * sigma_center**2))
        surround = torch.exp(-dist_sq / (2 * sigma_surround**2))

        # Center-Surround (On-Center / Off-Surround)
        # Normalizing each Gaussian to sum to 1 makes the DoG kernel sum to zero,
        # so uniform regions produce no response and only local contrast passes
        dog_kernel = (center / center.sum()) - (surround / surround.sum())

        # Reshape for depthwise convolution: [out_channels, in_channels/groups, H, W]
        dog_kernel = dog_kernel.view(1, 1, kernel_size, kernel_size)
        self.register_buffer('weight', dog_kernel.repeat(channels, 1, 1, 1))

    def forward(self, x):
        # Using functional conv2d with groups=channels for depthwise operation
        return F.conv2d(x, self.weight, padding=self.kernel_size//2, groups=self.channels)

In [None]:
# Example Usage:
# model = CenterSurroundLayer(channels=3)
# output = model(input_tensor)

**Why this matters in AI**

- Saliency: This layer naturally highlights "blobs" and edges, acting as a pre-processor for higher-level attention.
- Data Efficiency: Unlike Transformers, which have to learn spatial relationships from data, this layer builds in a biological prior about local contrast, so that structure does not have to be learned.
- Noise Reduction: A DoG is a band-pass filter. Subtracting the surround removes low-frequency content such as smooth illumination gradients, while the narrow center Gaussian smooths away the finest pixel-level noise.
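To see the band-pass behavior concretely, one can look at the magnitude of the Fourier transform of a 1D DoG profile. The following is a minimal NumPy sketch; the profile length and sigma values are illustrative assumptions. The gain is essentially zero at frequency 0 because the kernel sums to zero. It is small at the highest frequencies because the center Gaussian still smooths. The peak falls in between.

```python
import numpy as np

def dog_profile(length=64, sigma_center=1.0, sigma_surround=3.0):
    """1D Difference-of-Gaussians profile with unit-sum center and surround."""
    x = np.arange(length) - (length - 1) / 2
    center = np.exp(-x**2 / (2 * sigma_center**2))
    surround = np.exp(-x**2 / (2 * sigma_surround**2))
    return center / center.sum() - surround / surround.sum()

kernel = dog_profile()
response = np.abs(np.fft.rfft(kernel))  # gain for frequency bins 0 .. Nyquist
peak = int(np.argmax(response))

print(f"DC gain:      {response[0]:.2e}")
print(f"Peak gain:    {response[peak]:.3f} at bin {peak} of {len(response) - 1}")
print(f"Nyquist gain: {response[-1]:.3f}")
```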

**Visualization in notebook**

Since you are in a notebook, you should definitely visualize the kernel to ensure it looks like the "Mexican Hat" function.

In [None]:
layer = CenterSurroundLayer(channels=1, kernel_size=21, sigma_center=2.0, sigma_surround=5.0)
kernel = layer.weight.squeeze().cpu().numpy()

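vmax = abs(kernel).max()
plt.imshow(kernel, cmap='RdBu', vmin=-vmax, vmax=vmax)  # symmetric range so zero sits at the colormap's midpoint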
plt.colorbar()
plt.title("Center-Surround Receptive Field (DoG)")
plt.show()

This hard-coded version is great for understanding, but in a hybrid AI model it is common to make `sigma_center` and `sigma_surround` learnable parameters.

Making them learnable turns this fixed image-processing filter into an Adaptive Center-Surround Attention mechanism. During training, the network can adjust the scale of its focus to match the size of the features it needs to detect. Once trained, the same sigmas are applied to every image.

In PyTorch, we cannot simply make the kernel weights learnable if we want to keep the Gaussian shape, because free weights would drift away from it. Instead, we make the σ values (the standard deviations of the two Gaussians) learnable and rebuild the kernel from them on every forward pass. The kernel therefore always stays a Difference of Gaussians while σ is optimized for performance.

### Adaptive Model

In [None]:
class AdaptiveCenterSurround(nn.Module):
    def __init__(self, channels, kernel_size=15):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size

        # Initialize sigmas as learnable parameters.
        # They are kept valid by clamping in get_kernel(); note that a clamped value
        # gets zero gradient while the bound is active. A log-space or softplus
        # parametrization would avoid that.
        self.sigma_center = nn.Parameter(torch.tensor([1.0]))
        self.sigma_surround = nn.Parameter(torch.tensor([3.0]))

        # Pre-calculate coordinate grid
        x = torch.arange(kernel_size).float() - (kernel_size - 1) / 2
        y = torch.arange(kernel_size).float() - (kernel_size - 1) / 2
        grid_x, grid_y = torch.meshgrid(x, y, indexing='ij')
        self.register_buffer('dist_sq', grid_x**2 + grid_y**2)

    def get_kernel(self):
        # Keep sigma_center >= 0.5 and sigma_surround at least 0.5 above it.
        # torch.maximum (rather than .item()) keeps s_s differentiable w.r.t. s_c.
        s_c = torch.clamp(self.sigma_center, min=0.5)
        s_s = torch.maximum(self.sigma_surround, s_c + 0.5)

        # Generate Gaussians
        center = torch.exp(-self.dist_sq / (2 * s_c**2))
        surround = torch.exp(-self.dist_sq / (2 * s_s**2))

        # Normalize and subtract
        dog_kernel = (center / center.sum()) - (surround / surround.sum())
        return dog_kernel.view(1, 1, self.kernel_size, self.kernel_size).repeat(self.channels, 1, 1, 1)

    def forward(self, x):
        kernel = self.get_kernel()
        return F.conv2d(x, kernel, padding=self.kernel_size//2, groups=self.channels)
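The clamp in `get_kernel` keeps the sigmas valid, but a clamped value receives zero gradient while the bound is active. A common alternative is to store unconstrained numbers and map them through softplus: $\sigma_c = 0.5 + \text{softplus}(\text{raw}_c)$ and $\sigma_s = \sigma_c + 0.5 + \text{softplus}(\text{raw}_{gap})$. This keeps $\sigma_s > \sigma_c > 0.5$ for any real inputs and is differentiable everywhere. The following is a minimal NumPy sketch of that mapping. The names `raw_c`, `raw_gap`, and `constrained_sigmas` are hypothetical, and the 0.5 margins simply mirror the clamp bounds above. Inside the PyTorch layer, `F.softplus` would play the role of `softplus` here.

```python
import numpy as np

def softplus(z):
    # Numerically stable log(1 + exp(z))
    return np.logaddexp(0.0, z)

def inverse_softplus(y):
    # Inverse of softplus for y > 0: log(exp(y) - 1)
    return y + np.log(-np.expm1(-y))

def constrained_sigmas(raw_c, raw_gap, min_center=0.5, min_gap=0.5):
    """Map unconstrained values to sigmas with sigma_s > sigma_c > min_center."""
    sigma_c = min_center + softplus(raw_c)
    sigma_s = sigma_c + min_gap + softplus(raw_gap)
    return sigma_c, sigma_s

# Raw values that reproduce the layer's initial sigmas (1.0 and 3.0)
raw_c = inverse_softplus(1.0 - 0.5)
raw_gap = inverse_softplus(3.0 - 1.0 - 0.5)
print(constrained_sigmas(raw_c, raw_gap))

# Even extreme raw values stay inside the constraints
print(constrained_sigmas(-50.0, -50.0))
```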

### Why this Matters?

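In the primary visual cortex (V1), neurons do not share a single receptive-field size. Some cells are tuned to fine, high-frequency edges (small $\sigma$), while others respond to broader textures (large $\sigma$).

With the adaptive layer above:

1. Backpropagation computes the gradient of the loss with respect to both $\sigma$ values.
2. If finer detail reduces the loss, $\sigma_{\text{center}}$ can shrink (down to its clamp at 0.5), raising the filter's high-frequency cutoff.
3. If ignoring high-frequency noise reduces the loss, $\sigma_{\text{center}}$ can grow instead. If coarser structure matters, $\sigma_{\text{surround}}$ can grow, lowering the filter's low-frequency cutoff.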
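For reference, this is the derivative autograd works with. For the unnormalized Gaussian weight at distance $d$ from the kernel center:

$$
G_\sigma(d) = \exp\!\left(-\frac{d^2}{2\sigma^2}\right), \qquad
\frac{\partial G_\sigma(d)}{\partial \sigma} = \frac{d^2}{\sigma^3}\exp\!\left(-\frac{d^2}{2\sigma^2}\right) \ge 0
$$

Increasing $\sigma$ raises every off-center weight, and the effect is largest at $d = \sqrt{2}\,\sigma$. The layer then divides each Gaussian by its sum and subtracts the two. Autograd applies the quotient and chain rules through those steps, summing over all kernel weights $w$:

$$
\frac{\partial \mathcal{L}}{\partial \sigma} = \sum_{w} \frac{\partial \mathcal{L}}{\partial w}\,\frac{\partial w}{\partial \sigma}
$$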

### Testing the "Attention" Effect

Run this quick test in your notebook to see how the model responds to a simple stimulus, such as a white square on a black background:

In [None]:
# Create a dummy image (black with a white square)
img = torch.zeros(1, 1, 32, 32)
# Change the square's position or size to see how the response changes
img[:, :, 10:22, 10:22] = 1.0

model = AdaptiveCenterSurround(channels=1, kernel_size=15)
output = model(img)

# Plotting
fig, ax = plt.subplots(1, 2)
ax[0].imshow(img[0,0], cmap='gray')
ax[0].set_title("Input (Stimulus)")
ax[1].imshow(output.detach()[0,0], cmap='magma')
ax[1].set_title("Neural Response (Edges)")
plt.show()

## Taking It to the Next Step

Next, we model the full pipeline.

We are essentially recreating the Itti-Koch saliency model (Itti, Koch & Niebur, 1998), a classic reference model for biologically inspired computer vision. It simulates how the human visual system processes "bottom-up" (stimulus-driven) visual attention using multi-scale features.

### Step 1: The Multi-Scale Saliency Model

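The Itti-Koch model works by creating "Feature Maps" (Intensity, Color, Orientation) and then applying center-surround operations across different spatial scales. The version below is simplified. It omits the orientation channel, and instead of subtracting pyramid levels from one another, it applies the DoG layer from above to every level of the pyramid.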
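For comparison, the original Itti-Koch model does not convolve each level with a DoG kernel. Instead, it computes center-surround contrast as an across-scale difference: a fine pyramid level (center $c$) minus a coarser level (surround $s$) upsampled back to the center's size, written $|F(c) \ominus F(s)|$. In the paper, $c \in \{2, 3, 4\}$ and $s = c + \delta$ with $\delta \in \{3, 4\}$ on a nine-level Gaussian pyramid. The levels in the NumPy sketch below are smaller only so the example fits a 32×32 image. The sketch also assumes power-of-two sizes, 2×2 average pooling for the pyramid (as the code below uses), and nearest-neighbor upsampling. The function names are illustrative.

```python
import numpy as np

def downsample(img):
    """2x2 average pooling (a crude stand-in for Gaussian blur + subsampling)."""
    h, w = img.shape
    return img.reshape(h // 2, 2, w // 2, 2).mean(axis=(1, 3))

def pyramid(img, levels):
    out = [img]
    for _ in range(levels - 1):
        out.append(downsample(out[-1]))
    return out

def across_scale_contrast(img, center_level=1, surround_level=3):
    """|center - upsampled surround|, returned at the center level's resolution."""
    pyr = pyramid(img, surround_level + 1)
    center, surround = pyr[center_level], pyr[surround_level]
    factor = 2 ** (surround_level - center_level)
    # Nearest-neighbor upsampling of the coarse surround to the center's size
    surround_up = np.repeat(np.repeat(surround, factor, axis=0), factor, axis=1)
    return np.abs(center - surround_up)

flat = np.full((32, 32), 0.5)
spot = np.zeros((32, 32))
spot[12:16, 12:16] = 1.0     # small bright square

print("flat image, max contrast:", across_scale_contrast(flat).max())
contrast = across_scale_contrast(spot)
print(contrast.shape, np.unravel_index(contrast.argmax(), contrast.shape))
```

A uniform image gives zero contrast everywhere, while the bright square produces the strongest response at its own location.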

In [None]:
class SaliencyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Multi-scale processing via an image pyramid
        # (2x2 average pooling approximates Gaussian blur + subsampling)
        self.cs_layer = AdaptiveCenterSurround(channels=1, kernel_size=15)

    def get_pyramid(self, x, levels=4):
        pyramid = [x]
        for i in range(levels - 1):
            pyramid.append(F.avg_pool2d(pyramid[-1], kernel_size=2, stride=2))
        return pyramid

    def forward(self, x):
        # 1. Convert to Intensity (Greyscale)
        intensity = torch.mean(x, dim=1, keepdim=True)

        # 2. Build Pyramid (Multi-scale)
        scales = self.get_pyramid(intensity)

        # 3. Apply Center-Surround to each scale
        maps = []
        for s in scales:
            cs_map = self.cs_layer(s)
            # Resize back to original size to combine
            cs_map = F.interpolate(cs_map, size=(x.shape[2], x.shape[3]), mode='bilinear', align_corners=False)
            # Keep the contrast magnitude per scale so that positive and negative
            # responses at different scales cannot cancel each other out
            maps.append(torch.abs(cs_map))

        # 4. Aggregate Saliency (averaging the per-scale magnitudes)
        saliency_map = torch.mean(torch.stack(maps), dim=0)
        return saliency_map

### Step 2: Heatmap Visualization Tool

To make this useful, we need a way to overlay the "neural firing" (saliency) onto the original image. This helps us see what the model "thinks" is important.

In [None]:
def visualize_attention(img_path, model):
    # Load and Preprocess
    raw_img = Image.open(img_path).convert('RGB')
    transform = T.Compose([T.Resize((256, 256)), T.ToTensor()])
    input_tensor = transform(raw_img).unsqueeze(0)

    # Generate Saliency
    with torch.no_grad():
        saliency = model(input_tensor)

    # Normalize for visualization
    s_map = saliency[0, 0].cpu().numpy()
    s_map = (s_map - s_map.min()) / (s_map.max() - s_map.min() + 1e-8)  # epsilon avoids division by zero on flat maps

    # Create Heatmap Overlay
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 3, 1)
    plt.imshow(raw_img.resize((256, 256)))
    plt.title("Original Image")
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.imshow(s_map, cmap='jet')
    plt.title("Saliency Map")
    plt.axis('off')

    plt.subplot(1, 3, 3)
    # Overlaying
    plt.imshow(raw_img.resize((256, 256)))
    plt.imshow(s_map, cmap='jet', alpha=0.5) # Alpha blending
    plt.title("Attention Overlay")
    plt.axis('off')

    plt.show()

### Understanding the Mechanism

The "Center-Surround" creates a "Mexican Hat" response. In an image, this means the model ignores flat areas (like a clear blue sky) and reacts strongly to discontinuities (like a bird in that sky).
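The claim that flat regions are ignored follows directly from the kernel summing to zero. The following is a minimal 1D NumPy sketch; the signal length, bar position, and sigmas are illustrative assumptions. It filters a bright bar on a dark background. The response is near zero inside the bar and far from it, and large only around the two edges.

```python
import numpy as np

def dog_1d(sigma_center=1.0, sigma_surround=3.0, radius=10):
    """1D DoG kernel of length 2 * radius + 1 that sums to zero."""
    x = np.arange(-radius, radius + 1)
    center = np.exp(-x**2 / (2 * sigma_center**2))
    surround = np.exp(-x**2 / (2 * sigma_surround**2))
    return center / center.sum() - surround / surround.sum()

signal = np.zeros(100)
signal[40:60] = 1.0          # a 20-pixel "bird" in an empty "sky"

response = np.convolve(signal, dog_1d(), mode='same')

print("flat background:", np.abs(response[5:20]).max())
print("bar interior:   ", np.abs(response[48:52]).max())
print("near left edge: ", np.abs(response[35:45]).max())
```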

### Running the project

In [None]:
# Initialize
model = SaliencyModel()

# Run on an image (Ensure you have a sample.jpg in your directory)
visualize_attention('sample.jpg', model)

## Improving the model

Adding Color Opponency is where the neuroscience really shines. In the primate visual system, red-versus-green signals are carried mainly by the parvocellular pathway and blue-versus-yellow signals by the koniocellular pathway. This is why a red apple pops out against green leaves: not necessarily because it is brighter, but because of its chromatic contrast.

Let’s build the final, complete notebook structure.

**The Color Opponency Module**

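We first build four broadly tuned color channels from the raw $r, g, b$ values, following Itti, Koch & Niebur (1998), and then combine them into two signed opponency maps:

- $R = r - \frac{g+b}{2}$, $G = g - \frac{r+b}{2}$, $B = b - \frac{r+g}{2}$, $Y = \frac{r+g}{2} - \frac{|r-g|}{2} - b$
- $RG = R - G$
- $BY = B - Y$

The maps are kept signed. The absolute value is taken only after the center-surround stage, so both a red patch on green and a green patch on red stand out.

In [None]:
class ColorOpponency(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # x shape: [B, 3, H, W]
        r, g, b = x[:, 0, :, :], x[:, 1, :, :], x[:, 2, :, :]

        # Intensity (I)
        intensity = (r + g + b) / 3

        # Broadly tuned color channels (Itti, Koch & Niebur, 1998).
        # Simplification: the original also divides r, g, b by intensity (only where
        # intensity exceeds 1/10 of its maximum) and clips negative channels to zero.
        R = r - (g + b) / 2
        G = g - (r + b) / 2
        B = b - (r + g) / 2
        Y = (r + g) / 2 - torch.abs(r - g) / 2 - b

        # Signed opponency maps: red vs. green and blue vs. yellow.
        # The sign is kept so red-on-green and green-on-red both give contrast;
        # the center-surround stage takes the magnitude later.
        RG = (R - G).unsqueeze(1)
        BY = (B - Y).unsqueeze(1)

        return intensity.unsqueeze(1), RG, BY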
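Plugging a few pure colors into the channel definitions shows why the opponent signal should stay signed until after the center-surround stage. Red and green get RG values of equal size and opposite sign, so an absolute value taken too early would make a red apple and green leaves look identical. Grey (achromatic) pixels produce zero in both maps. A NumPy sketch follows; the helper name `opponent_channels` is illustrative.

```python
import numpy as np

def opponent_channels(rgb):
    """Signed RG and BY opponent values from broadly tuned R, G, B, Y channels.

    rgb: array of shape (..., 3) with values in [0, 1].
    """
    r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
    R = r - (g + b) / 2
    G = g - (r + b) / 2
    B = b - (r + g) / 2
    Y = (r + g) / 2 - np.abs(r - g) / 2 - b
    return R - G, B - Y

pixels = {
    'red':    [1.0, 0.0, 0.0],
    'green':  [0.0, 1.0, 0.0],
    'blue':   [0.0, 0.0, 1.0],
    'yellow': [1.0, 1.0, 0.0],
    'grey':   [0.5, 0.5, 0.5],
}
for name, rgb in pixels.items():
    rg, by = opponent_channels(np.array(rgb))
    print(f"{name:>6}: RG = {rg:+.2f}, BY = {by:+.2f}")
```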

**The Integrated Itti-Koch Model**

Now we combine Intensity, Color, and Center-Surround across multiple scales.

In [None]:
class FullIttiKochModel(nn.Module):
    def __init__(self, kernel_size=15):
        super().__init__()
        self.color_engine = ColorOpponency()
        # Shared center-surround layer
        self.cs_layer = AdaptiveCenterSurround(channels=1, kernel_size=kernel_size)

    def process_feature(self, feature_map, levels=4):
        # Create pyramid (2x2 average pooling per level)
        pyramid = [feature_map]
        for _ in range(levels - 1):
            pyramid.append(F.avg_pool2d(pyramid[-1], kernel_size=2, stride=2))

        # Apply Center-Surround, upscale back, and keep the contrast magnitude per scale
        # so signed responses from different scales or pathways cannot cancel out
        results = []
        for p in pyramid:
            out = self.cs_layer(p)
            out = F.interpolate(out, size=(feature_map.shape[2], feature_map.shape[3]), mode='bilinear', align_corners=False)
            results.append(torch.abs(out))

        return torch.mean(torch.stack(results), dim=0)

    def forward(self, x):
        I, RG, BY = self.color_engine(x)

        # Extract saliency for each pathway
        saliency_I = self.process_feature(I)
        saliency_RG = self.process_feature(RG)
        saliency_BY = self.process_feature(BY)

        # Combine maps (Linear Integration)
        combined_saliency = (saliency_I + saliency_RG + saliency_BY) / 3
        return torch.abs(combined_saliency)
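The model above integrates the three pathways with a plain average. In the original Itti-Koch model, each map is first passed through a normalization operator $\mathcal{N}(\cdot)$ with three steps:

1. Rescale the map to a fixed range $[0, M]$.
2. Compute the average $\bar{m}$ of its local maxima other than the global maximum.
3. Multiply the map by $(M - \bar{m})^2$.

This promotes maps with one strong peak and suppresses maps with many comparable peaks. The following is a simplified NumPy sketch. It uses a 3×3 neighborhood to find local maxima, and a flat plateau counts as several maxima. `itti_normalize` is a hypothetical helper name.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def itti_normalize(feature_map, M=1.0):
    """Simplified Itti-Koch map normalization N(.) for a 2D, non-negative map."""
    fm = np.asarray(feature_map, dtype=float)
    lo, hi = fm.min(), fm.max()
    if hi - lo < 1e-12:
        return np.zeros_like(fm)       # a flat map carries no conspicuity
    fm = (fm - lo) / (hi - lo) * M     # 1. rescale to [0, M]

    # 2. Local maxima: nonzero pixels equal to the max of their 3x3 neighborhood
    padded = np.pad(fm, 1, mode='edge')
    neigh_max = sliding_window_view(padded, (3, 3)).max(axis=(-2, -1))
    peaks = np.sort(fm[(fm == neigh_max) & (fm > 0)])[::-1]

    # Average of the local maxima other than the global one
    m_bar = peaks[1:].mean() if peaks.size > 1 else 0.0

    # 3. Promote maps with one dominant peak, suppress maps with many similar peaks
    return fm * (M - m_bar) ** 2

single = np.zeros((16, 16))
single[8, 8] = 1.0
many = np.zeros((16, 16))
for r, c in [(2, 2), (2, 12), (12, 2), (12, 12)]:
    many[r, c] = 1.0
uneven = np.zeros((16, 16))
uneven[3, 3] = 1.0
uneven[10, 10] = 0.5

for name, m in [('one peak', single), ('four equal peaks', many), ('1.0 and 0.5 peaks', uneven)]:
    print(f"{name:>18}: max after N(.) = {itti_normalize(m).max():.3f}")
```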

**The Visualization Tool (Final Form)**

This function will generate the heatmap and overlay it, creating that "thermal vision" effect seen in eye-tracking studies.

In [None]:
import cv2
import numpy as np

def generate_heatmap(model, image_tensor, original_img_size):
    model.eval()
    with torch.no_grad():
        saliency = model(image_tensor)

    # Process saliency map for heatmap
    s_map = saliency[0, 0].cpu().numpy()
    s_map = (s_map - s_map.min()) / (s_map.max() - s_map.min() + 1e-8)
    s_map = (s_map * 255).astype(np.uint8)

    # Resize to original image dimensions
    s_map_resized = cv2.resize(s_map, original_img_size)

    # Apply Color Map (Jet)
    heatmap = cv2.applyColorMap(s_map_resized, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    return heatmap, s_map_resized

def plot_final_results(img_path):
    # Load
    raw_img = np.array(Image.open(img_path).convert('RGB'))
    h, w, _ = raw_img.shape

    # Prep tensor
    transform = T.Compose([T.ToTensor(), T.Resize((256, 256))])
    input_tensor = transform(Image.fromarray(raw_img)).unsqueeze(0)

    # Run Model
    model = FullIttiKochModel()
    heatmap, gray_saliency = generate_heatmap(model, input_tensor, (w, h))

    # Overlay
    overlay = cv2.addWeighted(raw_img, 0.6, heatmap, 0.4, 0)

    # Display
    plt.figure(figsize=(16, 8))
    titles = ['Original', 'Saliency (Grey)', 'Heatmap', 'Attention Overlay']
    imgs = [raw_img, gray_saliency, heatmap, overlay]

    for i in range(4):
        plt.subplot(1, 4, i+1)
        plt.imshow(imgs[i], cmap='gray' if i==1 else None)
        plt.title(titles[i])
        plt.axis('off')
    plt.tight_layout()
    plt.show()

## Summary of Code

1. Bio-Plausible Filtering: Using Difference of Gaussians to simulate retinal ganglion cells.
2. Learnable Parameters: Allowing σ to be optimized via backprop (if you choose to train it).
3. Multi-Scale Processing: Using a Gaussian pyramid to find salient objects of different sizes.
4. Chromatic Opponency: Modeling the R/G and B/Y pathways of the human brain.

Next, we run the pipeline on a standard dataset (CIFAR-10) to see how it can act as a "foveal" pre-processor that reduces how much of each image a neural network has to "look" at.

### The Attention-Gated Classifier

We will create a wrapper that takes an image, generates the Itti-Koch saliency map, and uses that map to "gate" the pixels before they reach a standard CNN (here, a small two-layer ConvNet).

In [None]:
class FovealAttentionNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.saliency_extractor = FullIttiKochModel()

        # A simple CNN "Brain" to process the gated image
        self.classifier = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(32, num_classes)
        )

    def forward(self, x):
        # 1. Generate the attention mask [B, 1, H, W]
        mask = self.saliency_extractor(x)

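        # Normalize each image's mask to [0, 1] independently (per-sample min/max),
        # so one high-contrast image in the batch cannot dim the masks of the others
        mask_min = mask.amin(dim=(1, 2, 3), keepdim=True)
        mask_max = mask.amax(dim=(1, 2, 3), keepdim=True)
        mask = (mask - mask_min) / (mask_max - mask_min + 1e-8)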

        # 2. "Foveation": Multiply input by mask
        # This effectively 'blacks out' non-salient areas
        gated_x = x * mask

        # 3. Classify based ONLY on attended regions
        logits = self.classifier(gated_x)
        return logits, gated_x, mask
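When a whole batch is gated at once, the normalization range of the mask matters. If the mask is min-max rescaled over the entire batch, one high-contrast image can compress every other image's mask toward zero, so those images are almost completely blacked out. Rescaling each image by its own range avoids this. The NumPy sketch below compares both options; the function names are illustrative. In PyTorch, `Tensor.amin` and `Tensor.amax` with `dim=(1, 2, 3), keepdim=True` give the per-image version.

```python
import numpy as np

def minmax_batch(mask, eps=1e-8):
    """Rescale with a single min/max taken over the whole batch."""
    return (mask - mask.min()) / (mask.max() - mask.min() + eps)

def minmax_per_image(mask, eps=1e-8):
    """Rescale each image (axis 0 = batch) with its own min/max."""
    axes = tuple(range(1, mask.ndim))
    lo = mask.min(axis=axes, keepdims=True)
    hi = mask.max(axis=axes, keepdims=True)
    return (mask - lo) / (hi - lo + eps)

rng = np.random.default_rng(0)
masks = rng.random((2, 1, 8, 8))   # [B, 1, H, W] raw saliency values
masks[1] *= 0.1                    # image 1 has much weaker saliency overall

print("brightest gate per image, batch-wide:", minmax_batch(masks).max(axis=(1, 2, 3)))
print("brightest gate per image, per-image: ", minmax_per_image(masks).max(axis=(1, 2, 3)))
```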

### Testing on CIFAR-10

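Because CIFAR-10 images are small (32x32), the center-surround effect is very strong: the 15x15 kernel spans about half the image, and the coarsest pyramid level is only 4x4. Expect visible border artifacts, since zero padding makes the image boundary look like an edge.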

In [None]:
import torchvision
import torchvision.transforms as transforms

# Load CIFAR-10
transform = transforms.Compose([transforms.ToTensor()])
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True)

# Initialize Model
net = FovealAttentionNet(num_classes=10)
dataiter = iter(testloader)
images, labels = next(dataiter)

# Run Inference
logits, gated_images, masks = net(images)

# Visualize the 'Foveation' process
fig, axes = plt.subplots(4, 3, figsize=(10, 12))
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

for i in range(4):
    # Original
    axes[i, 0].imshow(images[i].permute(1, 2, 0))
    axes[i, 0].set_title(f"Original: {classes[labels[i]]}")

    # Mask
    axes[i, 1].imshow(masks[i, 0].detach().numpy(), cmap='gray')
    axes[i, 1].set_title("Saliency Mask")

    # Gated Input
    axes[i, 2].imshow(gated_images[i].detach().permute(1, 2, 0))
    axes[i, 2].set_title("What the Model 'Sees'")

    for ax in axes[i]: ax.axis('off')

plt.tight_layout()
plt.show()

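- Background Suppression: If there is a bird on a branch, flat regions such as clear sky are dimmed, while high-contrast regions such as the bird's edges are kept. Textured backgrounds like leaves have contrast of their own and may still pass through the gate.
- Information Bottleneck: The classifier receives a sparser input. Because low-contrast background is attenuated, it should be harder for the classifier to rely on background correlations (like "green pixels usually mean deer"). Note, however, that the gate dims those pixels rather than removing them.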

#### Comparison Training Loop

To see if this "Bio-Plausible" approach actually helps, we can compare it against a vanilla CNN. We will track the Validation Accuracy for both to see if the attention-gated model learns more efficiently or handles noise better.

In [None]:
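import torch.optim as optim

# Separate training split, so the test split stays held out for evaluation
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def evaluate(forward_fn, loader):
    """Top-1 accuracy of forward_fn (images -> logits) over a data loader."""
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            correct += (forward_fn(imgs).argmax(1) == labels).sum().item()
            total += labels.size(0)
    return correct / total

def train_and_compare(epochs=5):
    # 1. Initialize two models
    # Baseline: a FovealAttentionNet whose classifier is called directly, so no gating is applied
    standard_net = FovealAttentionNet().to(device)
    def standard_forward(x):
        return standard_net.classifier(x)

    # Attention-gated CNN (logits are the first of its three outputs)
    attention_net = FovealAttentionNet().to(device)
    def attention_forward(x):
        return attention_net(x)[0]

    # 2. Setup Optimizers (the baseline only trains the classifier it actually uses)
    opt_std = optim.Adam(standard_net.classifier.parameters(), lr=0.001)
    opt_attn = optim.Adam(attention_net.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    history = {'std_acc': [], 'attn_acc': []}

    for epoch in range(epochs):
        # --- Training Loop ---
        standard_net.train()
        attention_net.train()

        for imgs, labels in trainloader:
            imgs, labels = imgs.to(device), labels.to(device)

            # Train Standard
            opt_std.zero_grad()
            loss_std = criterion(standard_forward(imgs), labels)
            loss_std.backward()
            opt_std.step()

            # Train Attention
            opt_attn.zero_grad()
            loss_attn = criterion(attention_forward(imgs), labels)
            loss_attn.backward()
            opt_attn.step()

        # --- Validation on the held-out test split ---
        standard_net.eval()
        attention_net.eval()
        history['std_acc'].append(evaluate(standard_forward, testloader))
        history['attn_acc'].append(evaluate(attention_forward, testloader))
        print(f"Epoch {epoch+1}: standard acc = {history['std_acc'][-1]:.3f}, "
              f"attention acc = {history['attn_acc'][-1]:.3f}")

    return standard_net, attention_net, history

# Execute
std_model, attn_model, history = train_and_compare()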

#### Robustness Testing (The "Noise" Challenge)

Standard CNNs often lose accuracy sharply on noise they were not trained on. Biological center-surround processing is thought to help suppress some of this environmental "static."

Test both models on images with Gaussian noise or salt-and-pepper noise.
- Hypothesis: The attention-gated ("Metacognitive") model will maintain higher accuracy because the saliency gate dims the background and lets mainly high-contrast features through. Whether it does is an empirical question, since pixel noise also creates local contrast that the DoG responds to.

In [None]:
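def add_noise(batch, noise_level=0.1):
    # Additive zero-mean Gaussian noise with standard deviation `noise_level`.
    # Values are not clipped, so pixels can fall outside [0, 1].
    return batch + noise_level * torch.randn_like(batch)

# Usage: noisy_images = add_noise(images, noise_level=0.2)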

#### "Noise Sensitivity" Experiment

Instead of testing just one noise level, test a range. This will create a "Decay Curve." If your Metacognitive model is superior, its curve will stay "flat" longer than the standard model.

In [None]:
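import pandas as pd

noise_levels = [0.0, 0.1, 0.2, 0.3, 0.5]
num_batches = 50  # test batches to average over (batch_size=4, so 200 images)

# Use the models returned by train_and_compare() and their device
device = next(attn_model.parameters()).device
std_model.eval()
attn_model.eval()

# Correct-prediction counts per noise level: [standard, metacognitive]
correct = {level: [0, 0] for level in noise_levels}
total = 0

with torch.no_grad():
    for b, (images, labels) in enumerate(testloader):
        if b == num_batches:
            break
        images, labels = images.to(device), labels.to(device)
        total += labels.size(0)

        # Every noise level is applied to the same clean images
        for level in noise_levels:
            # 1. Add noise
            noisy_batch = add_noise(images, noise_level=level)

            # 2. Standard model: the classifier alone, no gating
            std_logits = std_model.classifier(noisy_batch)
            correct[level][0] += (std_logits.argmax(1) == labels).sum().item()

            # 3. Metacognitive (attention-gated) model
            meta_logits, _, _ = attn_model(noisy_batch)
            correct[level][1] += (meta_logits.argmax(1) == labels).sum().item()

results = [
    {'Noise Level': level, 'Standard Acc': c[0] / total, 'Metacognitive Acc': c[1] / total}
    for level, c in correct.items()
]

# Display as a Table
df = pd.DataFrame(results)
print(df)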

#### Visualize "Attention under Attack"

This is the most interesting part of the notebook. You need to see if the Center-Surround mechanism is successfully "cleaning" the image before it hits the classifier.

Run this code to compare what the "Brain" (the classifier) sees in both models:

In [None]:
level = 0.3
# A fresh test batch, on the same device as the trained model
images, labels = next(iter(testloader))
images = images.to(next(attn_model.parameters()).device)
noisy_imgs = add_noise(images, noise_level=level)

# Get the attention gate's output
with torch.no_grad():
    _, gated_imgs, masks = attn_model(noisy_imgs)

# Move to CPU for plotting
noisy_imgs, gated_imgs, masks = noisy_imgs.cpu(), gated_imgs.cpu(), masks.cpu()

# Plotting
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].imshow(noisy_imgs[0].permute(1,2,0).clip(0,1))
axes[0].set_title("Standard Model Sees This (Noisy)")

axes[1].imshow(masks[0,0], cmap='gray')
axes[1].set_title("Metacognitive Attention Map")

axes[2].imshow(gated_imgs[0].permute(1,2,0).clip(0,1))
axes[2].set_title("Gated Input (Cleaned features)")
plt.show()

#### Calculate "Information Sparsity" (The Efficiency Metric)

Since our project is about Metacognitive Attention, we want evidence that the model concentrates on a small part of the input. A simple measure is how "dark" the masks are on average.

In [None]:
def calculate_sparsity(mask):
    # Mean gate value in [0, 1]: the average fraction of input signal let through.
    # Lower values mean a sparser (darker) mask.
    return torch.mean(mask).item()

avg_gate = calculate_sparsity(masks)
print(f"On average, the gate passes {avg_gate*100:.2f}% of the input signal (mean mask value).")
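The mean mask value measures how much signal the soft gate passes, which is not the same as how many pixels the classifier relies on. For example, a mask of 0.2 everywhere passes 20% of the signal from every pixel. Counting the pixels whose gate value exceeds a threshold gives a complementary number. A NumPy sketch follows; the 0.5 threshold is an arbitrary assumption, and the toy mask is synthetic, not model output. With the notebook's tensors, one would pass `masks.numpy()` (already on the CPU from the cell above).

```python
import numpy as np

def mean_gate(mask):
    """Average fraction of the input signal the soft gate lets through."""
    return float(np.mean(mask))

def active_pixel_fraction(mask, threshold=0.5):
    """Fraction of pixel locations whose gate value exceeds `threshold`."""
    return float(np.mean(mask > threshold))

# Synthetic mask: weak gating everywhere plus one strongly attended 8x8 patch
mask = np.full((1, 1, 32, 32), 0.2)
mask[..., 8:16, 8:16] = 0.9

print(f"mean gate value:  {mean_gate(mask):.4f}")
print(f"pixels above 0.5: {active_pixel_fraction(mask):.4f}")
```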