# Paper 7: ImageNet Classification with Deep Convolutional Neural Networks
## Alex Krizhevsky, Ilya Sutskever, Geoffrey E. Hinton (2012)

### AlexNet: The CNN that Started the Deep Learning Revolution

AlexNet won the ImageNet 2012 challenge (ILSVRC-2012) with a top-5 error of 15.3%, far ahead of the second-place entry at 26.2%. The paper reignited interest in deep learning.

In [None]:
import random

import matplotlib.pyplot as plt
import torch

# Seed both RNGs used below so results are reproducible
torch.manual_seed(42)
random.seed(42)

def randint(low, high=None, size=None):
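    """
    Draw random integers from [low, high), mirroring numpy.random.randint.

    Args:
        low: Inclusive lower bound, or the exclusive upper bound if high is None.
        high: Exclusive upper bound (optional).
        size: Output shape as a tuple; None returns a single Python int.
    Returns:
        A Python int if size is None, otherwise an integer tensor of shape size.
    """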
    if high is None:
        high = low
        low = 0
    if size is None:
        return int(torch.randint(low, high, (1,)).item())
    return torch.randint(low, high, size)

def uniform(low=0.0, high=1.0, size=None):
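    """
    Draw samples from a uniform distribution, mirroring numpy.random.uniform.

    Args:
        low: Lower bound of the interval.
        high: Upper bound of the interval.
        size: Output shape; None returns a single Python float.
    Returns:
        A Python float if size is None, otherwise a float tensor of shape size.
    """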
    if size is None:
        return float(torch.empty(1).uniform_(low, high).item())
    return torch.empty(size).uniform_(low, high)

def choice(a, size=None, replace=True, p=None):
    """
    Sample from a sequence, mirroring numpy.random.choice.

    Args:
        a: Sequence to sample from, or an int n meaning range(n).
        size: Number of samples; None returns a single element.
        replace: Whether to sample with replacement.
        p: Optional probabilities, one per element of a.
    Returns:
        A single element if size is None, otherwise a list of size elements.
    """
    if isinstance(a, int):
        a = list(range(a))
    if p is not None:
        probs = torch.tensor(p, dtype=torch.float32)
        if size is None:
            idx = int(torch.multinomial(probs, 1, replacement=replace).item())
            return a[idx]
        idx = torch.multinomial(probs, size, replacement=replace).tolist()
        return [a[i] for i in idx]
    if size is None:
        return random.choice(a)
    if not replace:
        return random.sample(a, k=size)  # sampling without replacement
    return random.choices(a, k=size)


# Torch-Numpy compatibility helpers
if not hasattr(torch.Tensor, 'copy'):
    torch.Tensor.copy = torch.Tensor.clone

def _astype(self, dtype):
    """
    NumPy-style astype for tensors.

    Args:
        dtype: Python float or int, or any torch dtype.
    Returns:
        A tensor converted to the requested dtype.
    """
    if dtype is float:
        return self.float()
    if dtype is int:
        return self.int()
    return self.to(dtype)

if not hasattr(torch.Tensor, 'astype'):
    torch.Tensor.astype = _astype


## Convolutional Layer Implementation

The convolutional layer is the core building block of CNNs.

In [None]:
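def relu(x):
    """
    Rectified linear unit: max(0, x), applied elementwise.

    Args:
        x: Input tensor.
    Returns:
        Tensor of the same shape with negative values set to zero.
    """
    # torch.maximum needs two tensors, so clamp against the scalar 0 instead
    return torch.clamp(x, min=0)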

def conv2d(input_image, kernel, stride=1, padding=0):
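    """
    Naive 2D convolution (cross-correlation, as in most deep learning libraries).

    Args:
        input_image: Tensor of shape (in_channels, H, W) or (H, W).
        kernel: Filters of shape (out_channels, in_channels, kH, kW).
        stride: Step between neighbouring patches.
        padding: Number of zero pixels added on every side.
    Returns:
        output: Tensor of shape (out_channels, out_H, out_W).
    """
    if input_image.dim() == 2:
        input_image = input_image.unsqueeze(0)  # add a channel dimension
    
    in_channels, H, W = input_image.shape
    out_channels, _, kH, kW = kernel.shape
    
    # Zero-pad the last two dimensions: (left, right, top, bottom)
    if padding > 0:
        input_padded = torch.nn.functional.pad(
            input_image, (padding, padding, padding, padding),
            mode='constant', value=0)
    else:
        input_padded = input_image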
    
    # Output dimensions
    out_H = (H + 2*padding - kH) // stride + 1
    out_W = (W + 2*padding - kW) // stride + 1
    
    output = torch.zeros((out_channels, out_H, out_W))
    
    # Perform convolution
    for oc in range(out_channels):
        for i in range(out_H):
            for j in range(out_W):
                h_start = i * stride
                w_start = j * stride
                
                # Extract patch
                patch = input_padded[:, h_start:h_start+kH, w_start:w_start+kW]
                
                # Convolve with kernel
                output[oc, i, j] = torch.sum(patch * kernel[oc])
    
    return output

def max_pool2d(input_image, pool_size=2, stride=2):
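    """
    Max pooling over square windows, applied to each channel independently.

    Args:
        input_image: Tensor of shape (C, H, W).
        pool_size: Height and width of the pooling window.
        stride: Step between neighbouring windows.
    Returns:
        output: Tensor of shape (C, out_H, out_W).
    """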
    C, H, W = input_image.shape
    
    out_H = (H - pool_size) // stride + 1
    out_W = (W - pool_size) // stride + 1
    
    output = torch.zeros((C, out_H, out_W))
    
    for c in range(C):
        for i in range(out_H):
            for j in range(out_W):
                h_start = i * stride
                w_start = j * stride
                
                pool_region = input_image[c, h_start:h_start+pool_size, 
                                         w_start:w_start+pool_size]
                output[c, i, j] = torch.amax(pool_region)
    
    return output

# Test convolution
test_image = torch.randn(1, 8, 8)
test_kernel = torch.randn(3, 1, 3, 3) * 0.1

conv_output = conv2d(test_image, test_kernel, stride=1, padding=1)
print(f"Input shape: {test_image.shape}")
print(f"Kernel shape: {test_kernel.shape}")
print(f"Conv output shape: {conv_output.shape}")

pooled = max_pool2d(conv_output, pool_size=2, stride=2)
print(f"After max pooling: {pooled.shape}")
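
The output-size arithmetic in `conv2d` and `max_pool2d` above is the usual `(H + 2*padding - k) // stride + 1`. As a quick sanity check, the sketch below compares that formula against PyTorch's built-in `torch.nn.functional.conv2d` and `max_pool2d`. The helper name `out_size` is an illustrative assumption, not part of the notebook.

```python
import torch
import torch.nn.functional as F

def out_size(n, k, stride=1, padding=0):
    """Spatial output size of a convolution or pooling window (floor mode)."""
    return (n + 2 * padding - k) // stride + 1

x = torch.randn(1, 1, 8, 8)          # (batch, channels, H, W)
w = torch.randn(3, 1, 3, 3) * 0.1    # (out_channels, in_channels, kH, kW)

y = F.conv2d(x, w, stride=1, padding=1)
p = F.max_pool2d(y, kernel_size=2, stride=2)

conv_hw = out_size(8, 3, stride=1, padding=1)   # 3x3 kernel with padding 1 keeps 8x8
pool_hw = out_size(conv_hw, 2, stride=2)        # 2x2 pooling halves it

print(tuple(y.shape), conv_hw)
print(tuple(p.shape), pool_hw)
```

Note that the built-ins expect a leading batch dimension, whereas the hand-written functions above work on a single `(C, H, W)` image.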

## AlexNet Architecture (Simplified)

Original: 227x227x3 → 5 conv layers → 3 FC layers → 1000 classes

Our simplified version: 32x32x3 → 3 conv layers → 2 FC layers → 10 classes

In [None]:
class AlexNetSimplified:
    def __init__(self, num_classes=10):
        """
        Initialize weights with small Gaussian noise and biases with zeros.
        
        Args:
            num_classes: Number of output classes.
        """
        # Conv layers
        self.conv1_filters = torch.randn(32, 3, 3, 3) * 0.01
        self.conv1_bias = torch.zeros(32)
        
        self.conv2_filters = torch.randn(64, 32, 3, 3) * 0.01
        self.conv2_bias = torch.zeros(64)
        
        self.conv3_filters = torch.randn(128, 64, 3, 3) * 0.01
        self.conv3_bias = torch.zeros(128)
        
        # FC layers (after conv: 128 * 4 * 4 = 2048)
        self.fc1_weights = torch.randn(2048, 512) * 0.01
        self.fc1_bias = torch.zeros(512)
        
        self.fc2_weights = torch.randn(512, num_classes) * 0.01
        self.fc2_bias = torch.zeros(num_classes)
    
    def forward(self, x, use_dropout=False, dropout_rate=0.5):
        """
        Run a forward pass on a single image.
        
        Args:
            x: Image tensor of shape (3, 32, 32).
            use_dropout: Apply dropout after FC1 (use during training only).
            dropout_rate: Probability of zeroing each FC1 activation.
        Returns:
            Class scores (logits) of shape (num_classes,).
        """
        # Conv1 + ReLU + MaxPool
        conv1 = conv2d(x, self.conv1_filters, stride=1, padding=1)
        conv1 += self.conv1_bias[:, None, None]  # broadcast bias over H and W
        conv1 = relu(conv1)
        pool1 = max_pool2d(conv1, pool_size=2, stride=2)  # 32 x 16 x 16
        
        # Conv2 + ReLU + MaxPool
        conv2 = conv2d(pool1, self.conv2_filters, stride=1, padding=1)
        conv2 += self.conv2_bias[:, None, None]
        conv2 = relu(conv2)
        pool2 = max_pool2d(conv2, pool_size=2, stride=2)  # 64 x 8 x 8
        
        # Conv3 + ReLU + MaxPool
        conv3 = conv2d(pool2, self.conv3_filters, stride=1, padding=1)
        conv3 += self.conv3_bias[:, None, None]
        conv3 = relu(conv3)
        pool3 = max_pool2d(conv3, pool_size=2, stride=2)  # 128 x 4 x 4
        
        # Flatten
        flattened = pool3.reshape(-1)
        
        # FC1 + ReLU + Dropout
        fc1 = torch.matmul(flattened, self.fc1_weights) + self.fc1_bias
        fc1 = relu(fc1)
        
        if use_dropout:
            # Inverted dropout: rescale kept units so the expected activation is unchanged
            dropout_mask = (torch.rand_like(fc1) > dropout_rate).float()
            fc1 = fc1 * dropout_mask / (1 - dropout_rate)
        
        # FC2 (output)
        output = torch.matmul(fc1, self.fc2_weights) + self.fc2_bias
        
        return output

# Create model
alexnet = AlexNetSimplified(num_classes=10)
print("AlexNet (simplified) created")

# Test forward pass
test_img = torch.randn(3, 32, 32)
output = alexnet.forward(test_img)
print(f"Input: {tuple(test_img.shape)}")
print(f"Output: {output.shape} (class scores)")
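
For comparison, the same layer sizes can be expressed with `torch.nn` modules, which also makes the `128 * 4 * 4 = 2048` flatten size easy to verify. This is a minimal sketch, not the notebook's implementation: the class name `AlexNetSimplifiedNN` is hypothetical, and `nn.Conv2d`/`nn.Linear` use their own default initialization rather than the `0.01`-scaled Gaussian above.

```python
import torch
import torch.nn as nn

class AlexNetSimplifiedNN(nn.Module):
    """Hypothetical torch.nn counterpart of the hand-written model (same layer sizes)."""

    def __init__(self, num_classes=10, dropout_rate=0.5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),    # 32 x 16 x 16
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),   # 64 x 8 x 8
            nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),  # 128 x 4 x 4
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 512), nn.ReLU(), nn.Dropout(dropout_rate),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))

model = AlexNetSimplifiedNN().eval()   # eval() disables dropout
batch = torch.randn(2, 3, 32, 32)      # modules expect a leading batch dimension
with torch.no_grad():
    feats = model.features(batch)
    logits = model(batch)

print(tuple(feats.shape))
print(tuple(logits.shape))
```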

## Generate Synthetic Image Data

In [None]:
def generate_simple_images(num_samples=100, image_size=32):
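    """
    Generate a synthetic dataset of 10 simple pattern classes.
    
    Args:
        num_samples: Number of images; labels cycle through 0-9.
        image_size: Height and width of each square image.
    Returns:
        X: Float tensor of shape (num_samples, 3, image_size, image_size) in [0, 1].
        y: Integer tensor of class labels, shape (num_samples,).
    """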
    X = []
    y = []
    
    for i in range(num_samples):
        class_label = i % 10
        img = torch.zeros((3, image_size, image_size))
        
        if class_label == 0:  # Horizontal stripes
            for row in range(0, image_size, 4):
                img[:, row:row+2, :] = 1
        
        elif class_label == 1:  # Vertical stripes
            for col in range(0, image_size, 4):
                img[:, :, col:col+2] = 1
        
        elif class_label == 2:  # Diagonal (two pixels thick)
            for d in range(image_size):
                img[:, d, d] = 1
                if d + 1 < image_size:
                    img[:, d, d + 1] = 1
        
        elif class_label == 3:  # Checkerboard
            for i in range(0, image_size, 4):
                for j in range(0, image_size, 4):
                    if (i//4 + j//4) % 2 == 0:
                        img[:, i:i+4, j:j+4] = 1
        
        elif class_label == 4:  # Circle
            center = image_size // 2
            radius = image_size // 3
            y_grid, x_grid = torch.meshgrid(torch.arange(image_size), torch.arange(image_size), indexing='ij')
            mask = (x_grid - center)**2 + (y_grid - center)**2 <= radius**2
            img[:, mask] = 1
        
        elif class_label == 5:  # Square
            margin = image_size // 4
            img[:, margin:-margin, margin:-margin] = 1
        
        elif class_label == 6:  # Cross
            mid = image_size // 2
            thickness = 3
            img[:, mid-thickness:mid+thickness, :] = 1
            img[:, :, mid-thickness:mid+thickness] = 1
        
        elif class_label == 7:  # Triangle
            for i in range(image_size):
                width = int((i / image_size) * image_size / 2)
                start = image_size // 2 - width
                end = image_size // 2 + width
                img[:, i, start:end] = 1
        
        elif class_label == 8:  # Random noise
            img = torch.rand(3, image_size, image_size)
        
        else:  # Solid
            img[:] = 0.7
        
        # Add color variation
        color = torch.rand(3, 1, 1)
        img = img * color
        
        # Add noise
        img += torch.randn(3, image_size, image_size) * 0.1
        img = torch.clamp(img, 0, 1)
        
        X.append(img)
        y.append(class_label)
    
    # torch.tensor() cannot convert a list of tensors; stack them along a new batch axis
    return torch.stack(X), torch.tensor(y)

# Generate dataset
X_train, y_train = generate_simple_images(200)
X_test, y_test = generate_simple_images(50)

print(f"Training set: {X_train.shape}")
print(f"Test set: {X_test.shape}")

# Visualize samples
class_names = ['H-Stripes', 'V-Stripes', 'Diagonal', 'Checker', 'Circle', 
               'Square', 'Cross', 'Triangle', 'Noise', 'Solid']

fig, axes = plt.subplots(2, 5, figsize=(15, 6))
axes = axes.flatten()

for i in range(10):
    # Find first occurrence of each class
    idx = torch.where(y_train == i)[0][0]
    img = X_train[idx].permute(1, 2, 0)  # CHW -> HWC
    axes[i].imshow(img)
    axes[i].set_title(class_names[i])
    axes[i].axis('off')

plt.suptitle('Synthetic Image Dataset (10 Classes)', fontsize=14)
plt.tight_layout()
plt.show()

## Data Augmentation

AlexNet relied heavily on data augmentation (random crops, horizontal flips, and PCA-based shifts of RGB intensities) to reduce overfitting.

In [None]:
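def random_flip(img):
    """
    Flip the image horizontally with probability 0.5.
    
    Args:
        img: Image tensor of shape (C, H, W).
    Returns:
        img: The flipped copy, or the original image.
    """
    # torch.rand needs a shape, and tensors do not support negative-step slicing
    if torch.rand(1).item() > 0.5:
        return torch.flip(img, dims=[2])
    return img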

def random_crop(img, crop_size=28):
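    """
    Take a random square crop and resize it back to the original size.
    
    Args:
        img: Image tensor of shape (C, H, W).
        crop_size: Side length of the crop.
    Returns:
        resized: Cropped image upsampled to (C, H, W) by nearest neighbor.
    """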
    _, h, w = img.shape
    top = randint(0, h - crop_size + 1)
    left = randint(0, w - crop_size + 1)
    
    cropped = img[:, top:top+crop_size, left:left+crop_size]
    
    # Resize back to original
    # Simple nearest neighbor (for demo)
    scale_h = h / crop_size
    scale_w = w / crop_size
    
    resized = torch.zeros_like(img)
    for i in range(h):
        for j in range(w):
            src_i = min(int(i / scale_h), crop_size - 1)
            src_j = min(int(j / scale_w), crop_size - 1)
            resized[:, i, j] = cropped[:, src_i, src_j]
    
    return resized

def add_noise(img, noise_level=0.05):
    """
    Add Gaussian noise and clamp the result to [0, 1].
    
    Args:
        img: Image tensor of shape (C, H, W).
        noise_level: Standard deviation of the noise.
    Returns:
        Noisy image tensor with the same shape as img.
    """
    noise = torch.randn(*img.shape) * noise_level
    return torch.clamp(img + noise, 0, 1)

def augment_image(img):
    """
    Apply random flip, random crop, and noise in sequence.
    
    Args:
        img: Image tensor of shape (C, H, W).
    Returns:
        img: Augmented image tensor of the same shape.
    """
    img = random_flip(img)
    img = random_crop(img)
    img = add_noise(img)
    return img

# Demonstrate augmentation
original = X_train[0]

fig, axes = plt.subplots(2, 4, figsize=(16, 8))

axes[0, 0].imshow(original.permute(1, 2, 0))  # CHW -> HWC
axes[0, 0].set_title('Original')
axes[0, 0].axis('off')

for i in range(1, 8):
    augmented = augment_image(original.clone())
    row = i // 4
    col = i % 4
    axes[row, col].imshow(augmented.permute(1, 2, 0))  # CHW -> HWC
    axes[row, col].set_title(f'Augmented {i}')
    axes[row, col].axis('off')

plt.suptitle('Data Augmentation Examples', fontsize=14)
plt.tight_layout()
plt.show()
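
At test time the paper averages predictions over ten patches: the four corner crops, the center crop, and their horizontal reflections. Below is a minimal sketch of that patch extraction, scaled down to this notebook's 32x32 images with 28x28 crops (the paper crops 224x224 patches from 256x256 images). The helper name `ten_crop` is an illustrative assumption.

```python
import torch

def ten_crop(img, crop_size=28):
    """Return 10 patches: 4 corners + center, plus their horizontal flips.

    img: tensor of shape (C, H, W). Output shape: (10, C, crop_size, crop_size).
    """
    _, h, w = img.shape
    center_top, center_left = (h - crop_size) // 2, (w - crop_size) // 2
    offsets = [
        (0, 0),                          # top-left
        (0, w - crop_size),              # top-right
        (h - crop_size, 0),              # bottom-left
        (h - crop_size, w - crop_size),  # bottom-right
        (center_top, center_left),       # center
    ]
    crops = [img[:, t:t + crop_size, l:l + crop_size] for t, l in offsets]
    flips = [torch.flip(c, dims=[2]) for c in crops]  # mirror along the width axis
    return torch.stack(crops + flips)

img = torch.rand(3, 32, 32)
patches = ten_crop(img)
print(patches.shape)
```

The class probabilities from the ten patches would then be averaged to give the final prediction.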

## Visualize Learned Filters

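The AlexNet paper visualized its first-layer filters to show what the network had learned. The filters plotted here are randomly initialized, since this model is never trained.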

In [None]:
# Visualize first layer filters
filters = alexnet.conv1_filters  # Shape: (32, 3, 3, 3)

fig, axes = plt.subplots(4, 8, figsize=(16, 8))
axes = axes.flatten()

for i in range(min(32, len(axes))):
    # Normalize filter for visualization
    filt = filters[i].permute(1, 2, 0)  # CHW -> HWC
    filt = (filt - filt.min()) / (filt.max() - filt.min() + 1e-8)
    
    axes[i].imshow(filt)
    axes[i].axis('off')
    axes[i].set_title(f'F{i}', fontsize=8)

plt.suptitle('Conv1 Filters (32 filters, 3x3, RGB)', fontsize=14)
plt.tight_layout()
plt.show()

print("These filters are untrained; after training, first-layer filters typically detect edges, colors, and simple patterns")

## Feature Map Visualization

In [None]:
# Process an image and visualize feature maps
test_image = X_train[4]  # Circle

# Forward through first conv layer
conv1_output = conv2d(test_image, alexnet.conv1_filters, stride=1, padding=1)
conv1_output += alexnet.conv1_bias[:, None, None]  # broadcast bias over H and W
conv1_output = relu(conv1_output)

# Visualize
fig = plt.figure(figsize=(16, 10))

# Original image
ax = plt.subplot(6, 6, 1)
ax.imshow(test_image.permute(1, 2, 0))  # CHW -> HWC
ax.set_title('Input Image', fontsize=10)
ax.axis('off')

# Feature maps
for i in range(min(conv1_output.shape[0], 35)):  # 6x6 grid: 1 input + up to 35 maps
    ax = plt.subplot(6, 6, i+2)
    ax.imshow(conv1_output[i], cmap='viridis')
    ax.set_title(f'Map {i}', fontsize=8)
    ax.axis('off')

plt.suptitle('Feature Maps after Conv1 + ReLU', fontsize=14)
plt.tight_layout()
plt.show()

print("Different feature maps respond to different patterns in the image")

## Test Classification

In [None]:
def softmax(x):
    """
    Compute softmax probabilities.
    
    Args:
        x: Input data.
    Returns:
        Softmax probabilities.
    """
    exp_x = torch.exp(x - torch.amax(x))
    return exp_x / exp_x.sum()

# Test on a few images
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
axes = axes.flatten()

for i in range(10):
    idx = i * 5  # Sample every 5th image
    img = X_test[idx]
    true_label = y_test[idx]
    
    # Forward pass
    logits = alexnet.forward(img, use_dropout=False)
    probs = softmax(logits)
    pred_label = torch.argmax(probs)
    
    # Display
    axes[i].imshow(img.permute(1, 2, 0))  # CHW -> HWC
    axes[i].set_title(f'True: {class_names[true_label]}\nPred: {class_names[pred_label]}\nConf: {probs[pred_label]:.2f}',
                     fontsize=9)
    axes[i].axis('off')

plt.suptitle('AlexNet Predictions (Untrained)', fontsize=14)
plt.tight_layout()
plt.show()

print("Note: the model is untrained, so predictions are essentially random.")
print("Training would require backpropagation and gradient descent, which this notebook omits.")
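
For reference, a training loop could look like the sketch below. It uses the SGD settings reported in the paper (momentum 0.9, weight decay 0.0005, initial learning rate 0.01), but the tiny stand-in network, the random data, and the batch of 16 are assumptions made only to keep the example fast and self-contained (the paper trains with a batch size of 128).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in model and data (hypothetical; not the notebook's AlexNetSimplified)
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(8 * 16 * 16, 10),
)
X = torch.rand(16, 3, 32, 32)
y = torch.randint(0, 10, (16,))

criterion = nn.CrossEntropyLoss()  # combines log-softmax and negative log-likelihood
optimizer = torch.optim.SGD(model.parameters(), lr=0.01,
                            momentum=0.9, weight_decay=5e-4)

losses = []
for step in range(5):
    optimizer.zero_grad()           # clear gradients from the previous step
    loss = criterion(model(X), y)   # forward pass on raw logits
    loss.backward()                 # backpropagation
    optimizer.step()                # parameter update
    losses.append(loss.item())

print([round(l, 3) for l in losses])
```

`nn.CrossEntropyLoss` takes raw logits, so no explicit `softmax` call is needed during training.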

## Key Takeaways

### AlexNet Innovations (2012):

1. **ReLU Activation**: Trains much faster than sigmoid/tanh
   - No saturation for positive values
   - In the paper, a four-layer CNN reached 25% training error on CIFAR-10 about 6x faster with ReLU than with tanh

2. **Dropout**: Powerful regularization
   - Prevents overfitting
   - Used in FC layers (0.5 rate)

3. **Data Augmentation**: 
   - Random crops and flips
   - PCA-based shifts of RGB intensities (a form of color jittering)
   - Artificially increases dataset size

4. **GPU Training**: 
   - Trained on two GTX 580 GPUs (3 GB each) for five to six days
   - Enabled training of deep networks

5. **Local Response Normalization (LRN)**:
   - Lateral inhibition between feature maps
   - Less common now (Batch Norm replaced it)
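
To make the LRN item concrete, here is a minimal sketch of the paper's formula, `b[i] = a[i] / (k + alpha * sum_j a[j]**2) ** beta`, with the paper's constants `k=2`, `n=5`, `alpha=1e-4`, `beta=0.75`. The function name and the use of `n // 2` as the half-window are assumptions of this sketch. PyTorch's own `torch.nn.LocalResponseNorm` scales `alpha` by `1/n`, so it is not numerically identical.

```python
import torch

def local_response_norm(a, n=5, k=2.0, alpha=1e-4, beta=0.75):
    """Normalize each channel by the squared activity of its n nearest channels.

    a: activation tensor of shape (C, H, W). Returns a tensor of the same shape.
    """
    num_channels = a.shape[0]
    squared = a ** 2
    out = torch.empty_like(a)
    for i in range(num_channels):
        lo = max(0, i - n // 2)                 # window is clipped at the channel edges
        hi = min(num_channels - 1, i + n // 2)
        denom = (k + alpha * squared[lo:hi + 1].sum(dim=0)) ** beta
        out[i] = a[i] / denom
    return out

acts = torch.relu(torch.randn(8, 4, 4))
normed = local_response_norm(acts)
print(normed.shape)
```

Channels with strongly active neighbors are divided by a larger denominator, which is the "lateral inhibition" effect described above.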

### Architecture:
```
Input (227x227x3)
  ↓
Conv1 (96 filters, 11x11, stride 4) + ReLU + LRN + MaxPool
  ↓
Conv2 (256 filters, 5x5) + ReLU + LRN + MaxPool
  ↓
Conv3 (384 filters, 3x3) + ReLU
  ↓
Conv4 (384 filters, 3x3) + ReLU
  ↓
Conv5 (256 filters, 3x3) + ReLU + MaxPool
  ↓
FC6 (4096) + ReLU + Dropout
  ↓
FC7 (4096) + ReLU + Dropout
  ↓
FC8 (1000 classes) + Softmax
```
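
The spatial sizes implied by the diagram can be traced with the standard output-size formula. The sketch below assumes 3x3 max pooling with stride 2 (the paper's overlapping pooling), padding 2 for Conv2, and padding 1 for Conv3 to Conv5. These padding values are not stated in the diagram; they are the ones that reproduce the commonly cited 55 → 27 → 13 → 6 feature-map sizes.

```python
def out_size(n, k, stride=1, padding=0):
    """Spatial output size of a convolution or pooling window (floor mode)."""
    return (n + 2 * padding - k) // stride + 1

# (name, kernel, stride, padding, out_channels); None keeps the channel count (pooling)
layers = [
    ("conv1", 11, 4, 0, 96),
    ("pool1", 3, 2, 0, None),
    ("conv2", 5, 1, 2, 256),   # padding 2 is an assumption
    ("pool2", 3, 2, 0, None),
    ("conv3", 3, 1, 1, 384),   # padding 1 is an assumption
    ("conv4", 3, 1, 1, 384),
    ("conv5", 3, 1, 1, 256),
    ("pool5", 3, 2, 0, None),
]

size, channels = 227, 3
trace = []
for name, k, s, p, out_c in layers:
    size = out_size(size, k, s, p)
    channels = out_c if out_c is not None else channels
    trace.append((name, channels, size))
    print(f"{name}: {channels} x {size} x {size}")

flat = channels * size * size
print("FC6 input features:", flat)
```

Under these assumptions the last pooling layer yields 256 x 6 x 6 = 9216 features feeding FC6.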

### Impact:
- **Won ImageNet 2012**: 15.3% top-5 error (vs 26.2% second place)
- **Reignited deep learning**: Showed depth + data + compute works
- **GPU revolution**: Made GPUs essential for deep learning
- **Inspired modern CNNs**: VGG, ResNet, etc. built on these ideas

### Why It Worked:
1. Deep architecture (8 layers was deep in 2012!)
2. Large dataset (1.2M ImageNet images)
3. GPU acceleration (made training feasible)
4. Smart regularization (dropout + data aug)
5. ReLU activation (faster training)
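
The ReLU point comes down to gradients: tanh saturates for large inputs, so its gradient shrinks toward zero, while ReLU passes a gradient of 1 for any positive input. A small autograd sketch (the input values are chosen arbitrarily for illustration):

```python
import torch

x = torch.tensor([-5.0, 0.5, 5.0], requires_grad=True)

# d/dx relu(x): 1 for positive inputs, 0 for negative ones
relu_grad, = torch.autograd.grad(torch.relu(x).sum(), x)

# d/dx tanh(x) = 1 - tanh(x)^2, which vanishes as |x| grows
tanh_grad, = torch.autograd.grad(torch.tanh(x).sum(), x)

print("relu grad:", relu_grad.tolist())
print("tanh grad:", tanh_grad.tolist())
```

At `x = 5` the tanh gradient is below `1e-3`, while the ReLU gradient is still 1, which is why deep ReLU networks keep a usable training signal.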

### Modern Perspective:
- AlexNet is considered "simple" now
- ResNets have 100+ layers
- Batch Norm replaced LRN
- But the core ideas remain:
  - Deep hierarchical features
  - Convolution for spatial structure
  - Data augmentation
  - Regularization